Skip to main content

do_memory_core/pattern/
mod.rs

1//! Pattern extraction and management
2//!
3//! This module provides types and functions for working with patterns extracted from episodes.
4
5mod heuristic;
6mod similarity;
7mod types;
8
9pub use heuristic::Heuristic;
10pub use types::{Pattern, PatternEffectiveness};
11
12use crate::types::TaskContext;
13use chrono::Duration;
14
15impl Pattern {
16    /// Check if this pattern is relevant to a given context
17    #[must_use]
18    pub fn is_relevant_to(&self, query_context: &TaskContext) -> bool {
19        if let Some(pattern_context) = self.context() {
20            // Match on domain
21            if pattern_context.domain == query_context.domain {
22                return true;
23            }
24
25            // Match on language
26            if pattern_context.language == query_context.language
27                && pattern_context.language.is_some()
28            {
29                return true;
30            }
31
32            // Match on tags
33            let common_tags: Vec<_> = pattern_context
34                .tags
35                .iter()
36                .filter(|t| query_context.tags.contains(t))
37                .collect();
38
39            if !common_tags.is_empty() {
40                return true;
41            }
42        }
43
44        false
45    }
46
47    /// Get a similarity key for pattern deduplication
48    /// Patterns with identical keys are considered duplicates
49    #[must_use]
50    pub fn similarity_key(&self) -> String {
51        match self {
52            Pattern::ToolSequence { tools, context, .. } => {
53                format!("tool_seq:{}:{}", tools.join(","), context.domain)
54            }
55            Pattern::DecisionPoint {
56                condition,
57                action,
58                context,
59                ..
60            } => {
61                format!("decision:{}:{}:{}", condition, action, context.domain)
62            }
63            Pattern::ErrorRecovery {
64                error_type,
65                recovery_steps,
66                context,
67                ..
68            } => {
69                format!(
70                    "error_recovery:{}:{}:{}",
71                    error_type,
72                    recovery_steps.join(","),
73                    context.domain
74                )
75            }
76            Pattern::ContextPattern {
77                context_features,
78                recommended_approach,
79                ..
80            } => {
81                format!(
82                    "context:{}:{}",
83                    context_features.join(","),
84                    recommended_approach
85                )
86            }
87        }
88    }
89
90    /// Calculate similarity score between this pattern and another (0.0 to 1.0)
91    /// Uses edit distance for sequences and context matching
92    #[must_use]
93    pub fn similarity_score(&self, other: &Self) -> f32 {
94        // Different pattern types have zero similarity
95        if std::mem::discriminant(self) != std::mem::discriminant(other) {
96            return 0.0;
97        }
98
99        match (self, other) {
100            (
101                Pattern::ToolSequence {
102                    tools: tools1,
103                    context: ctx1,
104                    ..
105                },
106                Pattern::ToolSequence {
107                    tools: tools2,
108                    context: ctx2,
109                    ..
110                },
111            ) => {
112                let sequence_similarity = similarity::sequence_similarity(tools1, tools2);
113                let context_similarity = similarity::context_similarity(ctx1, ctx2);
114                // Weight: 70% sequence, 30% context
115                sequence_similarity * 0.7 + context_similarity * 0.3
116            }
117            (
118                Pattern::DecisionPoint {
119                    condition: cond1,
120                    action: act1,
121                    context: ctx1,
122                    ..
123                },
124                Pattern::DecisionPoint {
125                    condition: cond2,
126                    action: act2,
127                    context: ctx2,
128                    ..
129                },
130            ) => {
131                let condition_sim = similarity::string_similarity(cond1, cond2);
132                let action_sim = similarity::string_similarity(act1, act2);
133                let context_sim = similarity::context_similarity(ctx1, ctx2);
134                // Weight: 40% condition, 40% action, 20% context
135                condition_sim * 0.4 + action_sim * 0.4 + context_sim * 0.2
136            }
137            (
138                Pattern::ErrorRecovery {
139                    error_type: err1,
140                    recovery_steps: steps1,
141                    context: ctx1,
142                    ..
143                },
144                Pattern::ErrorRecovery {
145                    error_type: err2,
146                    recovery_steps: steps2,
147                    context: ctx2,
148                    ..
149                },
150            ) => {
151                let error_sim = similarity::string_similarity(err1, err2);
152                let steps_sim = similarity::sequence_similarity(steps1, steps2);
153                let context_sim = similarity::context_similarity(ctx1, ctx2);
154                // Weight: 40% error type, 40% recovery steps, 20% context
155                error_sim * 0.4 + steps_sim * 0.4 + context_sim * 0.2
156            }
157            (
158                Pattern::ContextPattern {
159                    context_features: feat1,
160                    recommended_approach: rec1,
161                    ..
162                },
163                Pattern::ContextPattern {
164                    context_features: feat2,
165                    recommended_approach: rec2,
166                    ..
167                },
168            ) => {
169                let features_sim = similarity::sequence_similarity(feat1, feat2);
170                let approach_sim = similarity::string_similarity(rec1, rec2);
171                // Weight: 60% features, 40% approach
172                features_sim * 0.6 + approach_sim * 0.4
173            }
174            _ => 0.0,
175        }
176    }
177
178    /// Calculate confidence score for this pattern
179    /// Confidence = `success_rate` * `sqrt(sample_size)`
180    #[must_use]
181    pub fn confidence(&self) -> f32 {
182        let success_rate = self.success_rate();
183        let sample_size = self.sample_size() as f32;
184
185        if sample_size == 0.0 {
186            return 0.0;
187        }
188
189        success_rate * sample_size.sqrt()
190    }
191
192    /// Merge this pattern with another similar pattern
193    /// Combines evidence and updates statistics
194    pub fn merge_with(&mut self, other: &Self) {
195        // Can only merge patterns of the same type
196        if std::mem::discriminant(self) != std::mem::discriminant(other) {
197            return;
198        }
199
200        match (self, other) {
201            (
202                Pattern::ToolSequence {
203                    success_rate: sr1,
204                    occurrence_count: oc1,
205                    avg_latency: lat1,
206                    ..
207                },
208                Pattern::ToolSequence {
209                    success_rate: sr2,
210                    occurrence_count: oc2,
211                    avg_latency: lat2,
212                    ..
213                },
214            ) => {
215                let total_count = *oc1 + *oc2;
216                // Weighted average of success rates
217                *sr1 = (*sr1 * *oc1 as f32 + *sr2 * *oc2 as f32) / total_count as f32;
218                // Weighted average of latencies
219                *lat1 = Duration::milliseconds(
220                    (lat1.num_milliseconds() * *oc1 as i64 + lat2.num_milliseconds() * *oc2 as i64)
221                        / total_count as i64,
222                );
223                *oc1 = total_count;
224            }
225            (
226                Pattern::DecisionPoint {
227                    outcome_stats: stats1,
228                    ..
229                },
230                Pattern::DecisionPoint {
231                    outcome_stats: stats2,
232                    ..
233                },
234            ) => {
235                stats1.success_count += stats2.success_count;
236                stats1.failure_count += stats2.failure_count;
237                stats1.total_count += stats2.total_count;
238                // Weighted average of durations
239                stats1.avg_duration_secs = (stats1.avg_duration_secs
240                    * (stats1.total_count - stats2.total_count) as f32
241                    + stats2.avg_duration_secs * stats2.total_count as f32)
242                    / stats1.total_count as f32;
243            }
244            (
245                Pattern::ErrorRecovery {
246                    success_rate: sr1, ..
247                },
248                Pattern::ErrorRecovery {
249                    success_rate: sr2, ..
250                },
251            ) => {
252                // Simple average for error recovery patterns
253                *sr1 = (*sr1 + *sr2) / 2.0;
254                // Keep the richer context (more tags)
255                // Context is already part of self
256            }
257            (
258                Pattern::ContextPattern {
259                    evidence: ev1,
260                    success_rate: sr1,
261                    ..
262                },
263                Pattern::ContextPattern {
264                    evidence: ev2,
265                    success_rate: sr2,
266                    ..
267                },
268            ) => {
269                let size1 = ev1.len();
270                let size2 = ev2.len();
271                // Combine evidence
272                ev1.extend_from_slice(ev2);
273                // Weighted average of success rates
274                *sr1 = (*sr1 * size1 as f32 + *sr2 * size2 as f32) / (size1 + size2) as f32;
275            }
276            _ => {}
277        }
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::ComplexityLevel;
    use uuid::Uuid;

    /// Turn string literals into the owned strings stored in pattern fields.
    fn owned(items: &[&str]) -> Vec<String> {
        items.iter().map(|item| (*item).to_string()).collect()
    }

    /// Default task context with only the domain set to "web-api".
    fn web_api_context() -> TaskContext {
        TaskContext {
            domain: "web-api".to_string(),
            ..Default::default()
        }
    }

    #[test]
    fn test_pattern_id() {
        let tool_sequence = Pattern::ToolSequence {
            id: Uuid::new_v4(),
            tools: owned(&["tool1", "tool2"]),
            context: TaskContext::default(),
            success_rate: 0.9,
            avg_latency: Duration::milliseconds(100),
            occurrence_count: 5,
            effectiveness: PatternEffectiveness::new(),
        };

        assert!(tool_sequence.success_rate() > 0.8);
        assert!(tool_sequence.context().is_some());
    }

    #[test]
    fn test_pattern_similarity_key() {
        // Both patterns share tools and domain; only their statistics differ.
        let read_write = |success_rate, latency_ms, occurrence_count| Pattern::ToolSequence {
            id: Uuid::new_v4(),
            tools: owned(&["read", "write"]),
            context: web_api_context(),
            success_rate,
            avg_latency: Duration::milliseconds(latency_ms),
            occurrence_count,
            effectiveness: PatternEffectiveness::new(),
        };

        let first = read_write(0.9, 100, 5);
        let second = read_write(0.8, 120, 3);

        // The key ignores statistics, so equal tools + domain give equal keys.
        assert_eq!(first.similarity_key(), second.similarity_key());
    }

    #[test]
    fn test_pattern_similarity_score() {
        let rust_read_write = |success_rate, latency_ms, occurrence_count| Pattern::ToolSequence {
            id: Uuid::new_v4(),
            tools: owned(&["read", "write"]),
            context: TaskContext {
                language: Some("rust".to_string()),
                ..web_api_context()
            },
            success_rate,
            avg_latency: Duration::milliseconds(latency_ms),
            occurrence_count,
            effectiveness: PatternEffectiveness::new(),
        };

        let first = rust_read_write(0.9, 100, 5);
        let second = rust_read_write(0.8, 120, 3);

        // Matching tools and matching context must score close to the maximum.
        let score = first.similarity_score(&second);
        assert!(score > 0.9);
    }

    #[test]
    fn test_pattern_confidence() {
        let sixteen_samples = Pattern::ToolSequence {
            id: Uuid::new_v4(),
            tools: owned(&["tool1"]),
            context: TaskContext::default(),
            success_rate: 0.8,
            avg_latency: Duration::milliseconds(100),
            occurrence_count: 16, // sqrt(16) = 4
            effectiveness: PatternEffectiveness::new(),
        };

        // 0.8 * sqrt(16) = 0.8 * 4 = 3.2
        let score = sixteen_samples.confidence();
        assert!((score - 3.2).abs() < 0.01);
    }

    #[test]
    fn test_pattern_merge() {
        let read_write = |success_rate, latency_ms, occurrence_count| Pattern::ToolSequence {
            id: Uuid::new_v4(),
            tools: owned(&["read", "write"]),
            context: TaskContext::default(),
            success_rate,
            avg_latency: Duration::milliseconds(latency_ms),
            occurrence_count,
            effectiveness: PatternEffectiveness::new(),
        };

        let mut merged = read_write(0.8, 100, 10);
        let incoming = read_write(0.9, 200, 10);

        merged.merge_with(&incoming);

        if let Pattern::ToolSequence {
            occurrence_count,
            success_rate,
            ..
        } = merged
        {
            // Occurrence counts add up.
            assert_eq!(occurrence_count, 20);
            // Weighted average: (0.8 * 10 + 0.9 * 10) / 20 = 0.85
            assert!((success_rate - 0.85).abs() < 0.01);
        } else {
            panic!("Expected ToolSequence");
        }
    }

    #[test]
    fn test_pattern_relevance() {
        let pattern = Pattern::ToolSequence {
            id: Uuid::new_v4(),
            tools: Vec::new(),
            context: TaskContext {
                language: Some("rust".to_string()),
                framework: None,
                complexity: ComplexityLevel::Moderate,
                domain: "web-api".to_string(),
                tags: owned(&["async"]),
            },
            success_rate: 0.9,
            avg_latency: Duration::milliseconds(100),
            occurrence_count: 1,
            effectiveness: PatternEffectiveness::new(),
        };

        // Same domain is enough to be relevant.
        let same_domain = web_api_context();
        assert!(pattern.is_relevant_to(&same_domain));

        // A different domain still matches when the language agrees.
        let same_language = TaskContext {
            language: Some("rust".to_string()),
            domain: "cli".to_string(),
            ..Default::default()
        };
        assert!(pattern.is_relevant_to(&same_language));

        // Neither domain nor language in common: not relevant.
        let unrelated = TaskContext {
            language: Some("python".to_string()),
            domain: "data-science".to_string(),
            ..Default::default()
        };
        assert!(!pattern.is_relevant_to(&unrelated));
    }

    #[test]
    fn test_heuristic_evidence_update() {
        let mut rule = Heuristic::new(
            String::from("When refactoring async code"),
            String::from("Use tokio::spawn for CPU-intensive tasks"),
            0.7,
        );

        // A fresh heuristic has no evidence yet.
        assert_eq!(rule.evidence.sample_size, 0);

        // One success: 1 of 1.
        rule.update_evidence(Uuid::new_v4(), true);
        assert_eq!(rule.evidence.sample_size, 1);
        assert_eq!(rule.evidence.success_rate, 1.0);

        // One failure: 1 of 2.
        rule.update_evidence(Uuid::new_v4(), false);
        assert_eq!(rule.evidence.sample_size, 2);
        assert_eq!(rule.evidence.success_rate, 0.5);

        // Another success: 2 of 3.
        rule.update_evidence(Uuid::new_v4(), true);
        assert_eq!(rule.evidence.sample_size, 3);
        assert!((rule.evidence.success_rate - 0.666).abs() < 0.01);
    }
}