Skip to main content

depyler_tooling/infrastructure/
pattern_store.rs

//! Pattern Store with Semantic Search (DEPYLER-0925)
//!
//! Stores successful transpilation patterns and retrieves the ones most
//! similar to a query by cosine similarity over their semantic embeddings.
//!
//! ## Design
//!
//! Retrieval is currently a brute-force linear scan (O(n) per query; see
//! `PatternStore::find_similar`). HNSW (Hierarchical Navigable Small World;
//! Malkov & Yashunin, 2020) is the intended approximate-nearest-neighbor index
//! for O(log n) retrieval once pattern sets grow large; it is not implemented
//! here yet.
//!
//! Patterns include:
//! - Python source pattern (normalized AST)
//! - Generated Rust output
//! - Error code prevented
//! - Confidence score (updated via online learning)
//! - 384-dimensional semantic embedding
17
18use serde::{Deserialize, Serialize};
19use std::collections::HashMap;
20
21/// A successful transpilation pattern
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct TranspilationPattern {
24    pub id: String,
25    pub python_pattern: String,
26    pub rust_output: String,
27    pub error_prevented: String,
28    pub confidence: f32,
29    pub usage_count: u32,
30    pub success_rate: f32,
31    pub embedding: Vec<f32>,
32}
33
34/// Pattern store with semantic search capability
35#[derive(Default)]
36pub struct PatternStore {
37    patterns: HashMap<String, TranspilationPattern>,
38}
39
40impl PatternStore {
41    /// Create a new pattern store
42    pub fn new() -> Self {
43        Self {
44            patterns: HashMap::new(),
45        }
46    }
47
48    /// Add a pattern to the store
49    pub fn add_pattern(&mut self, pattern: TranspilationPattern) {
50        self.patterns.insert(pattern.id.clone(), pattern);
51    }
52
53    /// Get a pattern by ID
54    pub fn get_pattern(&self, id: &str) -> Option<&TranspilationPattern> {
55        self.patterns.get(id)
56    }
57
58    /// Get mutable reference to a pattern
59    pub fn get_pattern_mut(&mut self, id: &str) -> Option<&mut TranspilationPattern> {
60        self.patterns.get_mut(id)
61    }
62
63    /// Serialize the store to JSON
64    pub fn serialize(&self) -> Result<String, serde_json::Error> {
65        let patterns: Vec<_> = self.patterns.values().collect();
66        serde_json::to_string(&patterns)
67    }
68
69    /// Deserialize the store from JSON
70    pub fn deserialize(json: &str) -> Result<Self, serde_json::Error> {
71        let patterns: Vec<TranspilationPattern> = serde_json::from_str(json)?;
72        let mut store = Self::new();
73        for pattern in patterns {
74            store.add_pattern(pattern);
75        }
76        Ok(store)
77    }
78
79    /// Calculate cosine similarity between two embeddings
80    pub fn cosine_similarity(&self, a: &[f32], b: &[f32]) -> f32 {
81        let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
82        let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
83        let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
84
85        if norm_a < f32::EPSILON || norm_b < f32::EPSILON {
86            return 0.0;
87        }
88
89        dot / (norm_a * norm_b)
90    }
91
92    /// Find k most similar patterns to query embedding
93    ///
94    /// Uses brute-force linear scan. Consider HNSW indexing for large pattern sets.
95    pub fn find_similar(&self, query: &[f32], k: usize) -> Vec<&TranspilationPattern> {
96        let mut similarities: Vec<_> = self
97            .patterns
98            .values()
99            .filter(|p| !p.embedding.is_empty())
100            .map(|p| (p, self.cosine_similarity(query, &p.embedding)))
101            .collect();
102
103        similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
104
105        similarities.into_iter().take(k).map(|(p, _)| p).collect()
106    }
107
108    /// Update pattern confidence based on compilation result
109    ///
110    /// Uses exponential moving average: confidence = (1-α)*confidence + α*outcome
111    pub fn update_confidence(&mut self, pattern_id: &str, success: bool) {
112        if let Some(pattern) = self.patterns.get_mut(pattern_id) {
113            pattern.usage_count += 1;
114            let alpha = 0.1; // Learning rate
115            let outcome = if success { 1.0 } else { 0.0 };
116            pattern.confidence = (1.0 - alpha) * pattern.confidence + alpha * outcome;
117
118            // Update success rate
119            let total = pattern.usage_count as f32;
120            let successes = pattern.success_rate * (total - 1.0) + outcome;
121            pattern.success_rate = successes / total;
122        }
123    }
124
125    /// Get all patterns (for iteration)
126    pub fn patterns(&self) -> impl Iterator<Item = &TranspilationPattern> {
127        self.patterns.values()
128    }
129
130    /// Number of patterns in store
131    pub fn len(&self) -> usize {
132        self.patterns.len()
133    }
134
135    /// Check if store is empty
136    pub fn is_empty(&self) -> bool {
137        self.patterns.is_empty()
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a pattern with placeholder Python/Rust text derived from `id`,
    /// no recorded uses, and a perfect prior success rate.
    fn make_pattern(id: &str, confidence: f32, embedding: Vec<f32>) -> TranspilationPattern {
        TranspilationPattern {
            id: id.to_string(),
            python_pattern: format!("def {}(): pass", id),
            rust_output: format!("fn {}() {{}}", id),
            error_prevented: "E0001".to_string(),
            confidence,
            usage_count: 0,
            success_rate: 1.0,
            embedding,
        }
    }

    // ========================================================================
    // TranspilationPattern tests
    // ========================================================================

    #[test]
    fn test_transpilation_pattern_new() {
        let pattern = TranspilationPattern {
            id: "pattern_1".to_string(),
            python_pattern: "x = 1".to_string(),
            rust_output: "let x = 1;".to_string(),
            error_prevented: "E0001".to_string(),
            confidence: 0.9,
            usage_count: 5,
            success_rate: 0.8,
            embedding: vec![0.1, 0.2, 0.3],
        };

        assert_eq!(pattern.id, "pattern_1");
        assert_eq!(pattern.confidence, 0.9);
        assert_eq!(pattern.usage_count, 5);
    }

    #[test]
    fn test_transpilation_pattern_clone() {
        let pattern = make_pattern("test", 0.85, vec![1.0, 2.0, 3.0]);
        let cloned = pattern.clone();
        assert_eq!(pattern.id, cloned.id);
        assert_eq!(pattern.confidence, cloned.confidence);
    }

    #[test]
    fn test_transpilation_pattern_debug() {
        let pattern = make_pattern("debug_test", 0.5, vec![]);
        let debug_str = format!("{:?}", pattern);
        assert!(debug_str.contains("debug_test"));
        assert!(debug_str.contains("TranspilationPattern"));
    }

    #[test]
    fn test_transpilation_pattern_serialize_deserialize() {
        let pattern = make_pattern("serialize_test", 0.75, vec![0.5, 0.5]);
        let json = serde_json::to_string(&pattern).unwrap();
        let deserialized: TranspilationPattern = serde_json::from_str(&json).unwrap();
        assert_eq!(pattern.id, deserialized.id);
        assert_eq!(pattern.confidence, deserialized.confidence);
    }

    // ========================================================================
    // PatternStore tests
    // ========================================================================

    #[test]
    fn test_pattern_store_new() {
        let store = PatternStore::new();
        assert!(store.is_empty());
        assert_eq!(store.len(), 0);
    }

    #[test]
    fn test_pattern_store_default() {
        let store = PatternStore::default();
        assert!(store.is_empty());
    }

    #[test]
    fn test_pattern_store_add_pattern() {
        let mut store = PatternStore::new();
        let pattern = make_pattern("add_test", 0.9, vec![]);

        store.add_pattern(pattern);

        assert!(!store.is_empty());
        assert_eq!(store.len(), 1);
    }

    #[test]
    fn test_pattern_store_get_pattern() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("get_test", 0.8, vec![]));

        let pattern = store.get_pattern("get_test");
        assert!(pattern.is_some());
        assert_eq!(pattern.unwrap().id, "get_test");

        let missing = store.get_pattern("nonexistent");
        assert!(missing.is_none());
    }

    #[test]
    fn test_pattern_store_get_pattern_mut() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("mut_test", 0.5, vec![]));

        if let Some(pattern) = store.get_pattern_mut("mut_test") {
            pattern.confidence = 0.99;
        }

        let pattern = store.get_pattern("mut_test").unwrap();
        assert_eq!(pattern.confidence, 0.99);
    }

    #[test]
    fn test_pattern_store_multiple_patterns() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("p1", 0.1, vec![]));
        store.add_pattern(make_pattern("p2", 0.2, vec![]));
        store.add_pattern(make_pattern("p3", 0.3, vec![]));

        assert_eq!(store.len(), 3);
        assert!(store.get_pattern("p1").is_some());
        assert!(store.get_pattern("p2").is_some());
        assert!(store.get_pattern("p3").is_some());
    }

    #[test]
    fn test_pattern_store_overwrite_same_id() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("same_id", 0.5, vec![]));
        store.add_pattern(make_pattern("same_id", 0.9, vec![]));

        assert_eq!(store.len(), 1);
        assert_eq!(store.get_pattern("same_id").unwrap().confidence, 0.9);
    }

    #[test]
    fn test_pattern_store_patterns_iterator() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("p1", 0.1, vec![]));
        store.add_pattern(make_pattern("p2", 0.2, vec![]));

        let count = store.patterns().count();
        assert_eq!(count, 2);
    }

    // ========================================================================
    // Cosine similarity tests
    // ========================================================================

    #[test]
    fn test_cosine_similarity_identical() {
        let store = PatternStore::new();
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];

        let sim = store.cosine_similarity(&a, &b);
        assert!((sim - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_orthogonal() {
        let store = PatternStore::new();
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![0.0, 1.0, 0.0];

        let sim = store.cosine_similarity(&a, &b);
        assert!(sim.abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_opposite() {
        let store = PatternStore::new();
        let a = vec![1.0, 0.0];
        let b = vec![-1.0, 0.0];

        let sim = store.cosine_similarity(&a, &b);
        assert!((sim - (-1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_cosine_similarity_similar() {
        let store = PatternStore::new();
        let a = vec![1.0, 1.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];

        let sim = store.cosine_similarity(&a, &b);
        // cos(45°) ≈ 0.707
        assert!(sim > 0.7 && sim < 0.72);
    }

    #[test]
    fn test_cosine_similarity_zero_vector() {
        let store = PatternStore::new();
        let a = vec![1.0, 2.0, 3.0];
        let b = vec![0.0, 0.0, 0.0];

        let sim = store.cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    #[test]
    fn test_cosine_similarity_both_zero() {
        let store = PatternStore::new();
        let a = vec![0.0, 0.0];
        let b = vec![0.0, 0.0];

        let sim = store.cosine_similarity(&a, &b);
        assert_eq!(sim, 0.0);
    }

    // ========================================================================
    // Find similar tests
    // ========================================================================

    #[test]
    fn test_find_similar_empty_store() {
        let store = PatternStore::new();
        let query = vec![1.0, 0.0, 0.0];

        let results = store.find_similar(&query, 5);
        assert!(results.is_empty());
    }

    #[test]
    fn test_find_similar_returns_most_similar() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("far", 0.5, vec![0.0, 1.0, 0.0]));
        store.add_pattern(make_pattern("near", 0.5, vec![1.0, 0.1, 0.0]));

        let query = vec![1.0, 0.0, 0.0];
        let results = store.find_similar(&query, 1);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "near");
    }

    #[test]
    fn test_find_similar_respects_k() {
        let mut store = PatternStore::new();
        for i in 0..10 {
            store.add_pattern(make_pattern(
                &format!("p{}", i),
                0.5,
                vec![i as f32, 0.0, 0.0],
            ));
        }

        let query = vec![5.0, 0.0, 0.0];
        let results = store.find_similar(&query, 3);

        assert_eq!(results.len(), 3);
        // p0 has a zero embedding (similarity 0.0) while p1..p9 all score 1.0,
        // so p0 can never make the top 3.
        assert!(results.iter().all(|p| p.id != "p0"));
    }

    #[test]
    fn test_find_similar_k_larger_than_store() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("only_one", 0.5, vec![1.0, 0.0]));

        let query = vec![1.0, 0.0];
        let results = store.find_similar(&query, 10);

        assert_eq!(results.len(), 1);
    }

    #[test]
    fn test_find_similar_skips_empty_embeddings() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("with_embedding", 0.5, vec![1.0, 0.0]));
        store.add_pattern(make_pattern("without_embedding", 0.5, vec![]));

        let query = vec![1.0, 0.0];
        let results = store.find_similar(&query, 10);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "with_embedding");
    }

    // ========================================================================
    // Update confidence tests
    // ========================================================================

    #[test]
    fn test_update_confidence_success() {
        let mut store = PatternStore::new();
        let mut pattern = make_pattern("conf_test", 0.5, vec![]);
        pattern.success_rate = 0.5;
        store.add_pattern(pattern);

        store.update_confidence("conf_test", true);

        let pattern = store.get_pattern("conf_test").unwrap();
        assert!(pattern.confidence > 0.5);
        assert_eq!(pattern.usage_count, 1);
    }

    #[test]
    fn test_update_confidence_failure() {
        let mut store = PatternStore::new();
        let mut pattern = make_pattern("conf_fail", 0.5, vec![]);
        pattern.success_rate = 0.5;
        store.add_pattern(pattern);

        store.update_confidence("conf_fail", false);

        let pattern = store.get_pattern("conf_fail").unwrap();
        assert!(pattern.confidence < 0.5);
    }

    #[test]
    fn test_update_confidence_nonexistent() {
        let mut store = PatternStore::new();
        // Should not panic, and must not create an entry.
        store.update_confidence("nonexistent", true);
        assert!(store.is_empty());
    }

    #[test]
    fn test_update_confidence_repeated_success() {
        let mut store = PatternStore::new();
        let mut pattern = make_pattern("repeat", 0.5, vec![]);
        pattern.success_rate = 0.5;
        store.add_pattern(pattern);

        for _ in 0..10 {
            store.update_confidence("repeat", true);
        }

        let pattern = store.get_pattern("repeat").unwrap();
        assert!(pattern.confidence > 0.8);
        assert_eq!(pattern.usage_count, 10);
    }

    #[test]
    fn test_update_confidence_success_rate() {
        let mut store = PatternStore::new();
        let mut pattern = make_pattern("rate_test", 0.5, vec![]);
        pattern.success_rate = 1.0;
        store.add_pattern(pattern);

        // 2 successes, 1 failure
        store.update_confidence("rate_test", true);
        store.update_confidence("rate_test", true);
        store.update_confidence("rate_test", false);

        let pattern = store.get_pattern("rate_test").unwrap();
        assert_eq!(pattern.usage_count, 3);
        // Success rate should be around 0.67
        assert!(pattern.success_rate > 0.6 && pattern.success_rate < 0.7);
    }

    // ========================================================================
    // Serialization tests
    // ========================================================================

    #[test]
    fn test_serialize_empty_store() {
        let store = PatternStore::new();
        let json = store.serialize().unwrap();
        assert_eq!(json, "[]");
    }

    #[test]
    fn test_serialize_with_patterns() {
        let mut store = PatternStore::new();
        store.add_pattern(make_pattern("ser1", 0.5, vec![1.0]));
        store.add_pattern(make_pattern("ser2", 0.6, vec![2.0]));

        let json = store.serialize().unwrap();
        // Every stored pattern must appear in the output, not just one of them.
        assert!(json.contains("ser1") && json.contains("ser2"));
    }

    #[test]
    fn test_deserialize_empty() {
        let store = PatternStore::deserialize("[]").unwrap();
        assert!(store.is_empty());
    }

    #[test]
    fn test_serialize_deserialize_roundtrip() {
        let mut original = PatternStore::new();
        original.add_pattern(make_pattern("rt1", 0.75, vec![1.0, 2.0]));
        original.add_pattern(make_pattern("rt2", 0.85, vec![3.0, 4.0]));

        let json = original.serialize().unwrap();
        let restored = PatternStore::deserialize(&json).unwrap();

        assert_eq!(restored.len(), 2);
        let rt1 = restored
            .get_pattern("rt1")
            .expect("rt1 survives the round trip");
        assert_eq!(rt1.confidence, 0.75);
        assert_eq!(rt1.embedding, vec![1.0, 2.0]);
        assert_eq!(rt1.python_pattern, "def rt1(): pass");
        assert!(restored.get_pattern("rt2").is_some());
    }

    #[test]
    fn test_deserialize_invalid_json() {
        let result = PatternStore::deserialize("not valid json");
        assert!(result.is_err());
    }

    // ========================================================================
    // Edge case tests
    // ========================================================================

    #[test]
    fn test_empty_embedding() {
        let pattern = make_pattern("empty_emb", 0.5, vec![]);
        assert!(pattern.embedding.is_empty());
    }

    #[test]
    fn test_large_embedding() {
        let embedding: Vec<f32> = (0..384).map(|i| i as f32 / 384.0).collect();
        let query = embedding.clone();
        let pattern = make_pattern("large_emb", 0.5, embedding);
        assert_eq!(pattern.embedding.len(), 384);

        // A full-size embedding is retrievable by searching with itself.
        let mut store = PatternStore::new();
        store.add_pattern(pattern);
        let results = store.find_similar(&query, 1);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].id, "large_emb");
    }
}