// ricecoder_refactoring/patterns/store.rs

//! Pattern storage and retrieval
2
3use crate::error::{RefactoringError, Result};
4use super::{RefactoringPattern, PatternScope};
5use std::collections::HashMap;
6use std::sync::Arc;
7use tokio::sync::RwLock;
8
9/// Stores and manages refactoring patterns
10pub struct PatternStore {
11    global_patterns: Arc<RwLock<HashMap<String, RefactoringPattern>>>,
12    project_patterns: Arc<RwLock<HashMap<String, RefactoringPattern>>>,
13}
14
15impl PatternStore {
16    /// Create a new pattern store
17    pub fn new() -> Self {
18        Self {
19            global_patterns: Arc::new(RwLock::new(HashMap::new())),
20            project_patterns: Arc::new(RwLock::new(HashMap::new())),
21        }
22    }
23
24    /// Add a pattern to the store
25    pub async fn add_pattern(&self, pattern: RefactoringPattern) -> Result<()> {
26        match pattern.scope {
27            PatternScope::Global => {
28                let mut patterns = self.global_patterns.write().await;
29                patterns.insert(pattern.name.clone(), pattern);
30            }
31            PatternScope::Project => {
32                let mut patterns = self.project_patterns.write().await;
33                patterns.insert(pattern.name.clone(), pattern);
34            }
35        }
36        Ok(())
37    }
38
39    /// Get a pattern by name (project patterns take precedence)
40    pub async fn get_pattern(&self, name: &str) -> Result<Option<RefactoringPattern>> {
41        // Check project patterns first
42        {
43            let patterns = self.project_patterns.read().await;
44            if let Some(pattern) = patterns.get(name) {
45                return Ok(Some(pattern.clone()));
46            }
47        }
48
49        // Fall back to global patterns
50        let patterns = self.global_patterns.read().await;
51        Ok(patterns.get(name).cloned())
52    }
53
54    /// List all patterns
55    pub async fn list_patterns(&self) -> Result<Vec<RefactoringPattern>> {
56        let mut patterns = vec![];
57
58        // Add global patterns
59        {
60            let global = self.global_patterns.read().await;
61            patterns.extend(global.values().cloned());
62        }
63
64        // Add project patterns
65        {
66            let project = self.project_patterns.read().await;
67            patterns.extend(project.values().cloned());
68        }
69
70        Ok(patterns)
71    }
72
73    /// Remove a pattern
74    pub async fn remove_pattern(&self, name: &str) -> Result<()> {
75        let mut project = self.project_patterns.write().await;
76        if project.remove(name).is_some() {
77            return Ok(());
78        }
79
80        let mut global = self.global_patterns.write().await;
81        if global.remove(name).is_some() {
82            return Ok(());
83        }
84
85        Err(RefactoringError::Other(format!("Pattern not found: {}", name)))
86    }
87
88    /// Clear all patterns
89    pub async fn clear(&self) -> Result<()> {
90        self.global_patterns.write().await.clear();
91        self.project_patterns.write().await.clear();
92        Ok(())
93    }
94
95    /// Get pattern count
96    pub async fn pattern_count(&self) -> Result<usize> {
97        let global_count = self.global_patterns.read().await.len();
98        let project_count = self.project_patterns.read().await.len();
99        Ok(global_count + project_count)
100    }
101}
102
103impl Default for PatternStore {
104    fn default() -> Self {
105        Self::new()
106    }
107}
108
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a parameterless pattern with the given name, template and scope.
    fn make_pattern(name: &str, template: &str, scope: PatternScope) -> RefactoringPattern {
        RefactoringPattern {
            name: name.to_string(),
            description: format!("{} pattern", name),
            template: template.to_string(),
            parameters: vec![],
            scope,
        }
    }

    #[tokio::test]
    async fn test_add_and_get_pattern() -> Result<()> {
        let store = PatternStore::new();
        store
            .add_pattern(make_pattern("test", "template", PatternScope::Global))
            .await?;

        let retrieved = store.get_pattern("test").await?.expect("pattern should exist");
        assert_eq!(retrieved.name, "test");
        assert_eq!(retrieved.template, "template");

        Ok(())
    }

    #[tokio::test]
    async fn test_get_missing_pattern_returns_none() -> Result<()> {
        let store = PatternStore::new();
        assert!(store.get_pattern("missing").await?.is_none());
        Ok(())
    }

    #[tokio::test]
    async fn test_project_patterns_take_precedence() -> Result<()> {
        let store = PatternStore::new();
        store
            .add_pattern(make_pattern("test", "global", PatternScope::Global))
            .await?;
        store
            .add_pattern(make_pattern("test", "project", PatternScope::Project))
            .await?;

        let retrieved = store.get_pattern("test").await?.expect("pattern should exist");
        assert_eq!(retrieved.template, "project");

        Ok(())
    }

    #[tokio::test]
    async fn test_list_patterns() -> Result<()> {
        let store = PatternStore::new();
        store
            .add_pattern(make_pattern("pattern1", "template1", PatternScope::Global))
            .await?;
        store
            .add_pattern(make_pattern("pattern2", "template2", PatternScope::Project))
            .await?;

        let patterns = store.list_patterns().await?;
        assert_eq!(patterns.len(), 2);

        Ok(())
    }

    #[tokio::test]
    async fn test_remove_pattern() -> Result<()> {
        let store = PatternStore::new();
        store
            .add_pattern(make_pattern("test", "template", PatternScope::Global))
            .await?;

        store.remove_pattern("test").await?;
        assert!(store.get_pattern("test").await?.is_none());

        Ok(())
    }

    #[tokio::test]
    async fn test_remove_missing_pattern_errors() -> Result<()> {
        let store = PatternStore::new();
        assert!(store.remove_pattern("missing").await.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_remove_project_pattern_reveals_global() -> Result<()> {
        let store = PatternStore::new();
        store
            .add_pattern(make_pattern("test", "global", PatternScope::Global))
            .await?;
        store
            .add_pattern(make_pattern("test", "project", PatternScope::Project))
            .await?;

        // The first removal only drops the shadowing project entry.
        store.remove_pattern("test").await?;
        let retrieved = store.get_pattern("test").await?.expect("global pattern should remain");
        assert_eq!(retrieved.template, "global");

        // The second removal drops the global entry.
        store.remove_pattern("test").await?;
        assert!(store.get_pattern("test").await?.is_none());

        Ok(())
    }

    #[tokio::test]
    async fn test_clear_and_pattern_count() -> Result<()> {
        let store = PatternStore::new();
        assert_eq!(store.pattern_count().await?, 0);

        store
            .add_pattern(make_pattern("a", "ta", PatternScope::Global))
            .await?;
        store
            .add_pattern(make_pattern("b", "tb", PatternScope::Project))
            .await?;
        assert_eq!(store.pattern_count().await?, 2);

        store.clear().await?;
        assert_eq!(store.pattern_count().await?, 0);
        assert!(store.list_patterns().await?.is_empty());

        Ok(())
    }
}