Skip to main content

memscope_rs/analysis/classification/
pattern_matcher.rs

1use crate::core::{MemScopeError, MemScopeResult};
2use regex::Regex;
3use std::collections::HashMap;
4
5/// Advanced pattern matcher for type names with fuzzy matching capabilities
6pub struct PatternMatcher {
7    patterns: Vec<CompiledPattern>,
8    fuzzy_threshold: f64,
9    cache: std::sync::Mutex<HashMap<String, Vec<PatternMatch>>>,
10}
11
/// A compiled pattern with metadata
///
/// Instances are created by `PatternMatcher::add_pattern` and
/// `PatternMatcher::add_pattern_with_tags`.
#[derive(Debug, Clone)]
pub struct CompiledPattern {
    // Identifier copied into `PatternMatch::pattern_id` and used for lookup/removal.
    id: String,
    // Compiled form of the pattern string supplied by the caller.
    regex: Regex,
    // Base value of every score computed for this pattern.
    weight: f64,
    // Labels used by `find_matches_by_tag` and counted in the stats; may be empty.
    tags: Vec<String>,
    // Free-form text supplied by the caller; exposed through `description()`.
    description: String,
}
21
22impl CompiledPattern {
23    pub fn description(&self) -> &str {
24        &self.description
25    }
26
27    pub fn tags(&self) -> &[String] {
28        &self.tags
29    }
30}
31
/// Result of a pattern match
#[derive(Debug, Clone)]
pub struct PatternMatch {
    /// Id of the pattern that produced this match.
    pub pattern_id: String,
    /// Ranking score; `find_matches` sorts results by this value, highest first.
    pub score: f64,
    /// How the pattern related to the input (see [`MatchType`]).
    pub match_type: MatchType,
    /// Text of each capture group (group 1 onwards) that took part in the
    /// match; always empty for fuzzy matches.
    pub captured_groups: Vec<String>,
    /// `(start, end)` of the regex match within the input, or
    /// `(0, input.len())` for fuzzy matches.
    pub position: (usize, usize),
}
41
/// Type of match that occurred.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` so the variant can be
/// used directly as a key in hash maps and sets (e.g. to count matches by kind).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MatchType {
    /// The regex match spans the whole input.
    Exact,
    /// The regex match touches neither the start nor the end of the input.
    Partial,
    /// The regex did not match; the edit-distance fallback did.
    Fuzzy,
    /// Has a score multiplier in `calculate_score` but is not produced by the
    /// matching code in this file.
    Substring,
    /// The regex match starts at the beginning of the input but stops before its end.
    Prefix,
    /// The regex match ends at the end of the input but starts after its beginning.
    Suffix,
}
52
53impl PatternMatcher {
54    /// Create a new pattern matcher
55    pub fn new() -> Self {
56        Self {
57            patterns: Vec::new(),
58            fuzzy_threshold: 0.7,
59            cache: std::sync::Mutex::new(HashMap::new()),
60        }
61    }
62
63    /// Add a pattern to the matcher
64    pub fn add_pattern(
65        &mut self,
66        id: &str,
67        pattern: &str,
68        weight: f64,
69        description: &str,
70    ) -> MemScopeResult<()> {
71        let regex = Regex::new(pattern).map_err(|e| {
72            MemScopeError::error(
73                "pattern_matcher",
74                "add_pattern",
75                format!("Invalid pattern '{}': {}", pattern, e),
76            )
77        })?;
78
79        let compiled = CompiledPattern {
80            id: id.to_string(),
81            regex,
82            weight,
83            tags: Vec::new(),
84            description: description.to_string(),
85        };
86
87        self.patterns.push(compiled);
88        self.clear_cache();
89        Ok(())
90    }
91
92    /// Add a pattern with tags
93    pub fn add_pattern_with_tags(
94        &mut self,
95        id: &str,
96        pattern: &str,
97        weight: f64,
98        description: &str,
99        tags: Vec<String>,
100    ) -> MemScopeResult<()> {
101        let regex = Regex::new(pattern).map_err(|e| {
102            MemScopeError::error(
103                "pattern_matcher",
104                "add_pattern_with_tags",
105                format!("Invalid pattern '{}': {}", pattern, e),
106            )
107        })?;
108
109        let compiled = CompiledPattern {
110            id: id.to_string(),
111            regex,
112            weight,
113            tags,
114            description: description.to_string(),
115        };
116
117        self.patterns.push(compiled);
118        self.clear_cache();
119        Ok(())
120    }
121
122    /// Find all matches for a given input
123    pub fn find_matches(&self, input: &str) -> Vec<PatternMatch> {
124        // Check cache first
125        if let Ok(cache) = self.cache.lock() {
126            if let Some(cached_matches) = cache.get(input) {
127                return cached_matches.clone();
128            }
129        }
130
131        let mut matches = Vec::new();
132
133        // Test each pattern
134        for pattern in &self.patterns {
135            if let Some(pattern_match) = self.test_pattern(pattern, input) {
136                matches.push(pattern_match);
137            }
138        }
139
140        // Sort by score (descending)
141        matches.sort_by(|a, b| {
142            b.score
143                .partial_cmp(&a.score)
144                .unwrap_or(std::cmp::Ordering::Equal)
145        });
146
147        // Cache the results
148        if let Ok(mut cache) = self.cache.lock() {
149            cache.insert(input.to_string(), matches.clone());
150        }
151
152        matches
153    }
154
155    /// Find the best match for a given input
156    pub fn find_best_match(&self, input: &str) -> Option<PatternMatch> {
157        self.find_matches(input).into_iter().next()
158    }
159
160    /// Find matches by tags
161    pub fn find_matches_by_tag(&self, input: &str, tag: &str) -> Vec<PatternMatch> {
162        let all_matches = self.find_matches(input);
163        all_matches
164            .into_iter()
165            .filter(|m| {
166                if let Some(pattern) = self.patterns.iter().find(|p| p.id == m.pattern_id) {
167                    pattern.tags.contains(&tag.to_string())
168                } else {
169                    false
170                }
171            })
172            .collect()
173    }
174
175    /// Test a single pattern against input
176    fn test_pattern(&self, pattern: &CompiledPattern, input: &str) -> Option<PatternMatch> {
177        // Try exact regex match first
178        if let Some(regex_match) = pattern.regex.find(input) {
179            let captured_groups = pattern
180                .regex
181                .captures(input)
182                .map(|caps| {
183                    caps.iter()
184                        .skip(1)
185                        .filter_map(|m| m.map(|m| m.as_str().to_string()))
186                        .collect()
187                })
188                .unwrap_or_default();
189
190            let match_type = if regex_match.start() == 0 && regex_match.end() == input.len() {
191                MatchType::Exact
192            } else if regex_match.start() == 0 {
193                MatchType::Prefix
194            } else if regex_match.end() == input.len() {
195                MatchType::Suffix
196            } else {
197                MatchType::Partial
198            };
199
200            let score = self.calculate_score(pattern, input, &regex_match, &match_type);
201
202            return Some(PatternMatch {
203                pattern_id: pattern.id.clone(),
204                score,
205                match_type,
206                captured_groups,
207                position: (regex_match.start(), regex_match.end()),
208            });
209        }
210
211        // Try fuzzy matching if enabled
212        if self.fuzzy_threshold > 0.0 {
213            if let Some(fuzzy_match) = self.fuzzy_match(pattern, input) {
214                return Some(fuzzy_match);
215            }
216        }
217
218        None
219    }
220
221    /// Perform fuzzy matching
222    fn fuzzy_match(&self, pattern: &CompiledPattern, input: &str) -> Option<PatternMatch> {
223        // Simple fuzzy matching based on edit distance
224        let pattern_str = pattern.regex.as_str();
225
226        // Remove regex special characters for fuzzy matching
227        let clean_pattern = self.clean_pattern_for_fuzzy(pattern_str);
228        let similarity = self.calculate_similarity(&clean_pattern, input);
229
230        if similarity >= self.fuzzy_threshold {
231            Some(PatternMatch {
232                pattern_id: pattern.id.clone(),
233                score: similarity * pattern.weight * 0.8, // Fuzzy matches get lower score
234                match_type: MatchType::Fuzzy,
235                captured_groups: Vec::new(),
236                position: (0, input.len()),
237            })
238        } else {
239            None
240        }
241    }
242
243    /// Calculate match score
244    fn calculate_score(
245        &self,
246        pattern: &CompiledPattern,
247        input: &str,
248        regex_match: &regex::Match,
249        match_type: &MatchType,
250    ) -> f64 {
251        let mut score = pattern.weight;
252
253        // Bonus for match type
254        let type_bonus = match match_type {
255            MatchType::Exact => 1.0,
256            MatchType::Prefix => 0.9,
257            MatchType::Suffix => 0.8,
258            MatchType::Partial => 0.7,
259            MatchType::Substring => 0.6,
260            MatchType::Fuzzy => 0.5,
261        };
262        score *= type_bonus;
263
264        // Bonus for coverage
265        let coverage = regex_match.len() as f64 / input.len() as f64;
266        score *= 0.5 + coverage * 0.5;
267
268        // Bonus for position (earlier matches are better)
269        let position_bonus = 1.0 - (regex_match.start() as f64 / input.len() as f64) * 0.1;
270        score *= position_bonus;
271
272        score.min(1.0)
273    }
274
275    /// Clean regex pattern for fuzzy matching
276    fn clean_pattern_for_fuzzy(&self, pattern: &str) -> String {
277        // Remove common regex special characters
278        pattern
279            .replace("^", "")
280            .replace("$", "")
281            .replace("\\", "")
282            .replace(".*", "")
283            .replace(".+", "")
284            .replace("?", "")
285            .replace("*", "")
286            .replace("+", "")
287            .replace("(", "")
288            .replace(")", "")
289            .replace("[", "")
290            .replace("]", "")
291            .replace("{", "")
292            .replace("}", "")
293            .replace("|", "")
294    }
295
296    /// Calculate string similarity using Levenshtein distance
297    fn calculate_similarity(&self, s1: &str, s2: &str) -> f64 {
298        let len1 = s1.chars().count();
299        let len2 = s2.chars().count();
300
301        if len1 == 0 {
302            return if len2 == 0 { 1.0 } else { 0.0 };
303        }
304        if len2 == 0 {
305            return 0.0;
306        }
307
308        let s1_chars: Vec<char> = s1.chars().collect();
309        let s2_chars: Vec<char> = s2.chars().collect();
310
311        let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
312
313        // Initialize first row and column
314        for (i, row) in matrix.iter_mut().enumerate().take(len1 + 1) {
315            row[0] = i;
316        }
317        for (j, row) in matrix[0].iter_mut().enumerate().take(len2 + 1) {
318            *row = j;
319        }
320
321        // Fill the matrix
322        for i in 1..=len1 {
323            for j in 1..=len2 {
324                let cost = if s1_chars[i - 1] == s2_chars[j - 1] {
325                    0
326                } else {
327                    1
328                };
329                matrix[i][j] = std::cmp::min(
330                    std::cmp::min(
331                        matrix[i - 1][j] + 1, // deletion
332                        matrix[i][j - 1] + 1, // insertion
333                    ),
334                    matrix[i - 1][j - 1] + cost, // substitution
335                );
336            }
337        }
338
339        let distance = matrix[len1][len2];
340        let max_len = std::cmp::max(len1, len2);
341        1.0 - (distance as f64 / max_len as f64)
342    }
343
344    /// Set fuzzy matching threshold
345    pub fn set_fuzzy_threshold(&mut self, threshold: f64) {
346        self.fuzzy_threshold = threshold.clamp(0.0, 1.0);
347        self.clear_cache();
348    }
349
350    /// Get fuzzy matching threshold
351    pub fn get_fuzzy_threshold(&self) -> f64 {
352        self.fuzzy_threshold
353    }
354
355    /// Clear the match cache
356    pub fn clear_cache(&self) {
357        if let Ok(mut cache) = self.cache.lock() {
358            cache.clear();
359        }
360    }
361
362    /// Get pattern statistics
363    pub fn get_stats(&self) -> MemScopeResult<PatternMatcherStats> {
364        let cache = self.cache.lock().map_err(|e| {
365            MemScopeError::system(
366                crate::core::error::SystemErrorType::Locking,
367                format!("Failed to acquire pattern cache lock: {}", e),
368            )
369        })?;
370        let total_patterns = self.patterns.len();
371        let cached_inputs = cache.len();
372
373        let mut tag_distribution = HashMap::new();
374        for pattern in &self.patterns {
375            for tag in &pattern.tags {
376                *tag_distribution.entry(tag.clone()).or_insert(0) += 1;
377            }
378        }
379
380        Ok(PatternMatcherStats {
381            total_patterns,
382            cached_inputs,
383            fuzzy_threshold: self.fuzzy_threshold,
384            tag_distribution,
385        })
386    }
387
388    /// Get all pattern IDs
389    pub fn get_pattern_ids(&self) -> Vec<String> {
390        self.patterns.iter().map(|p| p.id.clone()).collect()
391    }
392
393    /// Get pattern by ID
394    pub fn get_pattern(&self, id: &str) -> Option<&CompiledPattern> {
395        self.patterns.iter().find(|p| p.id == id)
396    }
397
398    /// Remove pattern by ID
399    pub fn remove_pattern(&mut self, id: &str) -> bool {
400        let initial_len = self.patterns.len();
401        self.patterns.retain(|p| p.id != id);
402        let removed = self.patterns.len() != initial_len;
403        if removed {
404            self.clear_cache();
405        }
406        removed
407    }
408}
409
410impl Default for PatternMatcher {
411    fn default() -> Self {
412        Self::new()
413    }
414}
415
/// Statistics about the pattern matcher
///
/// A snapshot produced by `PatternMatcher::get_stats`.
#[derive(Debug, Clone)]
pub struct PatternMatcherStats {
    /// Number of registered patterns.
    pub total_patterns: usize,
    /// Number of distinct input strings currently held in the match cache.
    pub cached_inputs: usize,
    /// Fuzzy threshold in effect when the snapshot was taken.
    pub fuzzy_threshold: f64,
    /// For each tag, how many tag entries across all patterns carry it.
    pub tag_distribution: HashMap<String, usize>,
}
424
/// Builder for creating pattern matchers with common patterns
///
/// Wraps a `PatternMatcher` that is configured step by step and handed out by
/// `build`.
pub struct PatternMatcherBuilder {
    // The matcher under construction.
    matcher: PatternMatcher,
}
429
430impl PatternMatcherBuilder {
431    pub fn new() -> Self {
432        Self {
433            matcher: PatternMatcher::new(),
434        }
435    }
436
437    /// Add common Rust type patterns
438    pub fn with_rust_patterns(mut self) -> MemScopeResult<Self> {
439        // Primitive types
440        self.matcher.add_pattern_with_tags(
441            "primitives",
442            r"^(i8|i16|i32|i64|i128|isize|u8|u16|u32|u64|u128|usize|f32|f64|bool|char)$",
443            1.0,
444            "Rust primitive types",
445            vec!["rust".to_string(), "primitive".to_string()],
446        )?;
447
448        // String types
449        self.matcher.add_pattern_with_tags(
450            "strings",
451            r"^(String|&str|str)$",
452            1.0,
453            "Rust string types",
454            vec!["rust".to_string(), "string".to_string()],
455        )?;
456
457        // Collections
458        self.matcher.add_pattern_with_tags(
459            "collections",
460            r"^(Vec|HashMap|BTreeMap|HashSet|BTreeSet|VecDeque|LinkedList)<",
461            0.9,
462            "Rust collection types",
463            vec!["rust".to_string(), "collection".to_string()],
464        )?;
465
466        // Smart pointers
467        self.matcher.add_pattern_with_tags(
468            "smart_pointers",
469            r"^(Box|Arc|Rc|Weak)<",
470            0.9,
471            "Rust smart pointer types",
472            vec!["rust".to_string(), "smart_pointer".to_string()],
473        )?;
474
475        Ok(self)
476    }
477
478    /// Set fuzzy threshold
479    pub fn fuzzy_threshold(mut self, threshold: f64) -> Self {
480        self.matcher.set_fuzzy_threshold(threshold);
481        self
482    }
483
484    /// Build the pattern matcher
485    pub fn build(self) -> PatternMatcher {
486        self.matcher
487    }
488}
489
490impl Default for PatternMatcherBuilder {
491    fn default() -> Self {
492        Self::new()
493    }
494}
495
#[cfg(test)]
mod tests {
    use super::*;

    /// A regex match covering only the start of the input is a `Prefix`; one
    /// covering the whole input is `Exact`.
    #[test]
    fn test_exact_match() {
        let mut matcher = PatternMatcher::new();
        matcher
            .add_pattern("vec", r"^Vec<", 1.0, "Vector pattern")
            .unwrap();

        let matches = matcher.find_matches("Vec<i32>");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].match_type, MatchType::Prefix);
        assert_eq!(matches[0].position, (0, 4));

        // The same pattern spans the entire input here, so the match is exact.
        let matches = matcher.find_matches("Vec<");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].match_type, MatchType::Exact);
    }

    /// An input that misses the regex but is close to its text is reported as
    /// a fuzzy match.
    #[test]
    fn test_fuzzy_matching() {
        let mut matcher = PatternMatcher::new();
        matcher.set_fuzzy_threshold(0.6);
        matcher
            .add_pattern("vector", r"Vector", 1.0, "Vector pattern")
            .unwrap();

        let matches = matcher.find_matches("Vektor"); // Typo
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].match_type, MatchType::Fuzzy);
    }

    /// Tag filtering keeps matches of tagged patterns and drops the rest.
    #[test]
    fn test_pattern_with_tags() {
        let mut matcher = PatternMatcher::new();
        matcher
            .add_pattern_with_tags(
                "rust_vec",
                r"^Vec<",
                1.0,
                "Rust vector",
                vec!["rust".to_string(), "collection".to_string()],
            )
            .unwrap();

        let matches = matcher.find_matches_by_tag("Vec<i32>", "rust");
        assert_eq!(matches.len(), 1);

        let matches = matcher.find_matches_by_tag("Vec<i32>", "java");
        assert_eq!(matches.len(), 0);
    }

    /// The built-in Rust patterns recognise a collection and a primitive.
    #[test]
    fn test_builder_with_rust_patterns() {
        let matcher = PatternMatcherBuilder::new()
            .with_rust_patterns()
            .unwrap()
            .fuzzy_threshold(0.8)
            .build();

        let matches = matcher.find_matches("Vec<i32>");
        assert!(!matches.is_empty());

        let matches = matcher.find_matches("i32");
        assert!(!matches.is_empty());
    }

    /// Boundary values of the similarity measure plus one near-miss.
    #[test]
    fn test_similarity_calculation() {
        let matcher = PatternMatcher::new();

        assert_eq!(matcher.calculate_similarity("hello", "hello"), 1.0);
        assert_eq!(matcher.calculate_similarity("hello", ""), 0.0);
        assert_eq!(matcher.calculate_similarity("", "hello"), 0.0);
        assert_eq!(matcher.calculate_similarity("", ""), 1.0);

        let sim = matcher.calculate_similarity("hello", "hallo");
        assert!(sim > 0.5 && sim < 1.0);
    }

    /// A repeated lookup returns the same result as the first one.
    #[test]
    fn test_cache_functionality() {
        let mut matcher = PatternMatcher::new();
        matcher
            .add_pattern("test", r"test", 1.0, "Test pattern")
            .unwrap();

        // First call
        let matches1 = matcher.find_matches("test");

        // Second call should use cache
        let matches2 = matcher.find_matches("test");

        assert_eq!(matches1.len(), matches2.len());
        assert_eq!(matches1[0].pattern_id, matches2[0].pattern_id);
    }

    /// Patterns can be listed and removed; removing an unknown id reports `false`.
    #[test]
    fn test_pattern_management() {
        let mut matcher = PatternMatcher::new();

        matcher
            .add_pattern("test1", r"test1", 1.0, "Test pattern 1")
            .unwrap();
        matcher
            .add_pattern("test2", r"test2", 1.0, "Test pattern 2")
            .unwrap();

        assert_eq!(matcher.get_pattern_ids().len(), 2);

        assert!(matcher.remove_pattern("test1"));
        assert_eq!(matcher.get_pattern_ids().len(), 1);

        assert!(!matcher.remove_pattern("nonexistent"));
    }
}