mur_core/policy/
patterns.rs1use anyhow::{Context, Result};
5use serde::{Deserialize, Serialize};
6use std::path::{Path, PathBuf};
7use std::sync::Mutex;
8use std::time::Instant;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct Pattern {
13 pub id: String,
15 pub name: String,
17 pub description: String,
19 #[serde(default)]
21 pub tags: Vec<String>,
22 #[serde(default = "default_inject")]
24 pub inject: String,
25 #[serde(default)]
27 pub match_actions: Vec<String>,
28}
29
/// Serde default for `Pattern::inject`: unless a file says otherwise, a pattern is only
/// injected when an action matches it.
fn default_inject() -> String {
    String::from("on_match")
}
33
34pub struct PatternStore {
36 patterns_dir: PathBuf,
37 cache: Mutex<Option<(Instant, Vec<Pattern>)>>,
38}
39
40impl PatternStore {
41 pub fn new(patterns_dir: &Path) -> Self {
43 Self {
44 patterns_dir: patterns_dir.to_path_buf(),
45 cache: Mutex::new(None),
46 }
47 }
48
49 fn cached_load(&self) -> Result<Vec<Pattern>> {
51 let mut cache = self.cache.lock().unwrap_or_else(|e| e.into_inner());
52 if let Some((loaded_at, ref patterns)) = *cache {
53 if loaded_at.elapsed() < std::time::Duration::from_secs(30) {
54 return Ok(patterns.clone());
55 }
56 }
57 let patterns = self.load_all()?;
58 *cache = Some((Instant::now(), patterns.clone()));
59 Ok(patterns)
60 }
61
62 pub fn default_dir() -> PathBuf {
64 directories::BaseDirs::new()
65 .map(|d| d.home_dir().join(".mur").join("patterns"))
66 .unwrap_or_else(|| PathBuf::from(".mur/patterns"))
67 }
68
69 pub fn load_all(&self) -> Result<Vec<Pattern>> {
71 if !self.patterns_dir.exists() {
72 return Ok(Vec::new());
73 }
74
75 let mut patterns = Vec::new();
76 let glob_pattern = self
77 .patterns_dir
78 .join("*.yaml")
79 .to_string_lossy()
80 .to_string();
81
82 for entry in glob::glob(&glob_pattern).context("Invalid glob pattern")? {
83 match entry {
84 Ok(path) => match self.load_pattern(&path) {
85 Ok(p) => patterns.push(p),
86 Err(e) => {
87 tracing::warn!("Failed to load pattern {:?}: {}", path, e);
88 }
89 },
90 Err(e) => {
91 tracing::warn!("Glob error: {}", e);
92 }
93 }
94 }
95
96 let yml_pattern = self
98 .patterns_dir
99 .join("*.yml")
100 .to_string_lossy()
101 .to_string();
102
103 for entry in glob::glob(&yml_pattern).context("Invalid glob pattern")? {
104 match entry {
105 Ok(path) => match self.load_pattern(&path) {
106 Ok(p) => {
107 if !patterns.iter().any(|existing| existing.id == p.id) {
108 patterns.push(p);
109 }
110 }
111 Err(e) => {
112 tracing::warn!("Failed to load pattern {:?}: {}", path, e);
113 }
114 },
115 Err(e) => {
116 tracing::warn!("Glob error: {}", e);
117 }
118 }
119 }
120
121 Ok(patterns)
122 }
123
124 pub fn load_pattern(&self, path: &Path) -> Result<Pattern> {
126 let content =
127 std::fs::read_to_string(path).context(format!("Reading pattern {:?}", path))?;
128 let pattern: Pattern =
129 serde_yaml::from_str(&content).context(format!("Parsing pattern {:?}", path))?;
130 Ok(pattern)
131 }
132
133 pub fn get_matching_patterns(
140 &self,
141 action_type: &str,
142 action_command: &str,
143 ) -> Result<Vec<Pattern>> {
144 let all = self.cached_load()?;
145 let mut matching = Vec::new();
146
147 for pattern in all {
148 if pattern.inject == "always" {
149 matching.push(pattern);
150 continue;
151 }
152
153 if pattern.inject == "on_match" {
154 let matches = pattern.match_actions.iter().any(|matcher| {
156 super::rules::pattern_matches(action_type, matcher) || super::rules::pattern_matches(action_command, matcher)
157 });
158 if matches
160 || (pattern.match_actions.is_empty()
161 && pattern.tags.iter().any(|t| {
162 super::rules::pattern_matches(action_type, t) || super::rules::pattern_matches(action_command, t)
163 }))
164 {
165 matching.push(pattern);
166 }
167 }
168 }
169
170 Ok(matching)
171 }
172
173 pub fn format_context(
175 &self,
176 action_type: &str,
177 action_command: &str,
178 ) -> Result<Option<String>> {
179 let patterns = self.get_matching_patterns(action_type, action_command)?;
180 if patterns.is_empty() {
181 return Ok(None);
182 }
183
184 let mut context = String::from("## Relevant Patterns\n\n");
185 for pattern in &patterns {
186 context.push_str(&format!("### {}\n", pattern.name));
187 context.push_str(&pattern.description);
188 context.push_str("\n\n");
189 }
190
191 Ok(Some(context))
192 }
193
194 pub fn save_pattern(&self, pattern: &Pattern) -> Result<()> {
196 validate_pattern_id(&pattern.id)?;
197 std::fs::create_dir_all(&self.patterns_dir)?;
198 let path = self.patterns_dir.join(format!("{}.yaml", pattern.id));
199 let yaml = serde_yaml::to_string(pattern)?;
200 std::fs::write(path, yaml)?;
201 Ok(())
202 }
203
204 pub fn remove_pattern(&self, id: &str) -> Result<bool> {
206 validate_pattern_id(id)?;
207 let path = self.patterns_dir.join(format!("{}.yaml", id));
208 if path.exists() {
209 std::fs::remove_file(&path)?;
210 return Ok(true);
211 }
212 let path_yml = self.patterns_dir.join(format!("{}.yml", id));
213 if path_yml.exists() {
214 std::fs::remove_file(&path_yml)?;
215 return Ok(true);
216 }
217 Ok(false)
218 }
219}
220
221fn validate_pattern_id(id: &str) -> Result<()> {
223 if id.is_empty() {
224 anyhow::bail!("Pattern ID cannot be empty");
225 }
226 if !id
227 .chars()
228 .all(|c| c.is_alphanumeric() || c == '_' || c == '-')
229 {
230 anyhow::bail!(
231 "Pattern ID must only contain alphanumeric characters, underscores, and hyphens"
232 );
233 }
234 Ok(())
235}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a store over a fresh temp dir pre-populated with `(file name, contents)`
    /// pairs. The `TempDir` is returned as well: dropping it deletes the directory, so
    /// callers must keep it bound for the duration of the test.
    fn store_with(files: &[(&str, &str)]) -> (tempfile::TempDir, PatternStore) {
        let dir = tempfile::TempDir::new().unwrap();
        for &(file_name, contents) in files {
            std::fs::write(dir.path().join(file_name), contents).unwrap();
        }
        let store = PatternStore::new(dir.path());
        (dir, store)
    }

    #[test]
    fn test_pattern_parse() {
        let source = r#"
id: error-handling
name: Error Handling Pattern
description: |
  Always use Result<T> for fallible operations.
  Log errors with tracing before propagating.
tags: [rust, error-handling]
inject: on_match
match_actions: ["code", "fix", "refactor"]
"#;
        let parsed: Pattern = serde_yaml::from_str(source).unwrap();
        assert_eq!(
            (parsed.id.as_str(), parsed.tags.len(), parsed.match_actions.len()),
            ("error-handling", 2, 3)
        );
    }

    #[test]
    fn test_load_all_empty() {
        let (_dir, store) = store_with(&[]);
        assert!(store.load_all().unwrap().is_empty());
    }

    #[test]
    fn test_load_all_from_files() {
        let yaml = r#"
id: test-pattern
name: Test Pattern
description: A test pattern for unit tests
tags: [test]
inject: always
"#;
        let (_dir, store) = store_with(&[("test-pattern.yaml", yaml)]);

        let loaded = store.load_all().unwrap();
        assert_eq!(loaded.len(), 1);
        assert_eq!(loaded[0].id, "test-pattern");
    }

    #[test]
    fn test_get_matching_always() {
        let yaml = r#"
id: always-inject
name: Always Inject
description: This always injects
tags: []
inject: always
"#;
        let (_dir, store) = store_with(&[("always-inject.yaml", yaml)]);

        let hits = store.get_matching_patterns("execute", "cargo build").unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn test_get_matching_on_match() {
        let yaml = r#"
id: deploy-pattern
name: Deploy Safety
description: Check all services before deploying
tags: [deploy]
inject: on_match
match_actions: ["deploy*"]
"#;
        let (_dir, store) = store_with(&[("deploy-pattern.yaml", yaml)]);

        let hits = store
            .get_matching_patterns("execute", "deploy production")
            .unwrap();
        assert_eq!(hits.len(), 1);

        let misses = store.get_matching_patterns("read", "git log").unwrap();
        assert!(misses.is_empty());
    }

    #[test]
    fn test_format_context() {
        let yaml = r#"
id: ctx-pattern
name: Context Pattern
description: Use this context when coding
inject: always
"#;
        let (_dir, store) = store_with(&[("ctx-pattern.yaml", yaml)]);

        let rendered = store
            .format_context("code", "write code")
            .unwrap()
            .expect("an always-injected pattern should produce context");
        assert!(rendered.contains("Context Pattern"));
        assert!(rendered.contains("Use this context when coding"));
    }

    #[test]
    fn test_save_and_remove_pattern() {
        let (_dir, store) = store_with(&[]);
        let pattern = Pattern {
            id: "saveable".into(),
            name: "Saveable".into(),
            description: "test".into(),
            tags: vec![],
            inject: "always".into(),
            match_actions: vec![],
        };

        store.save_pattern(&pattern).unwrap();
        assert_eq!(store.load_all().unwrap().len(), 1);

        assert!(store.remove_pattern("saveable").unwrap());
        assert!(store.load_all().unwrap().is_empty());
    }

    #[test]
    fn test_remove_nonexistent() {
        let (_dir, store) = store_with(&[]);
        let removed = store.remove_pattern("nope").unwrap();
        assert!(!removed);
    }

    #[test]
    fn test_pattern_matches() {
        use crate::policy::rules::pattern_matches as matches_rule;

        assert!(matches_rule("deploy production", "deploy*"));
        assert!(matches_rule("run tests", "test"));
        assert!(!matches_rule("read files", "deploy"));
    }

    #[test]
    fn test_nonexistent_dir() {
        let store = PatternStore::new(Path::new("/nonexistent/path"));
        assert!(store.load_all().unwrap().is_empty());
    }
}