memscope_rs/classification/
pattern_matcher.rs

1use regex::Regex;
2use std::collections::HashMap;
3
4/// Advanced pattern matcher for type names with fuzzy matching capabilities
5pub struct PatternMatcher {
6    patterns: Vec<CompiledPattern>,
7    fuzzy_threshold: f64,
8    cache: std::sync::Mutex<HashMap<String, Vec<PatternMatch>>>,
9}
10
11/// A compiled pattern with metadata
12#[derive(Debug, Clone)]
13#[allow(dead_code)]
14pub struct CompiledPattern {
15    id: String,
16    regex: Regex,
17    weight: f64,
18    tags: Vec<String>,
19    description: String,
20}
21
22/// Result of a pattern match
23#[derive(Debug, Clone)]
24pub struct PatternMatch {
25    pub pattern_id: String,
26    pub score: f64,
27    pub match_type: MatchType,
28    pub captured_groups: Vec<String>,
29    pub position: (usize, usize),
30}
31
/// Classification of how a pattern related to the input it matched.
#[derive(Debug, Clone, PartialEq)]
pub enum MatchType {
    /// The regex match spans the input from the first byte to the last.
    Exact,
    /// The regex match touches neither the start nor the end of the input.
    Partial,
    /// The regex did not match; the hit came from the edit-distance fallback.
    Fuzzy,
    /// Not produced by the matcher in this file; it only has a score
    /// multiplier assigned.
    Substring,
    /// The regex match starts at offset 0 but stops before the end of the input.
    Prefix,
    /// The regex match ends at the end of the input but does not start at offset 0.
    Suffix,
}
42
43impl PatternMatcher {
44    /// Create a new pattern matcher
45    pub fn new() -> Self {
46        Self {
47            patterns: Vec::new(),
48            fuzzy_threshold: 0.7,
49            cache: std::sync::Mutex::new(HashMap::new()),
50        }
51    }
52
53    /// Add a pattern to the matcher
54    pub fn add_pattern(
55        &mut self,
56        id: &str,
57        pattern: &str,
58        weight: f64,
59        description: &str,
60    ) -> Result<(), PatternMatcherError> {
61        let regex = Regex::new(pattern)
62            .map_err(|e| PatternMatcherError::InvalidPattern(pattern.to_string(), e.to_string()))?;
63
64        let compiled = CompiledPattern {
65            id: id.to_string(),
66            regex,
67            weight,
68            tags: Vec::new(),
69            description: description.to_string(),
70        };
71
72        self.patterns.push(compiled);
73        self.clear_cache();
74        Ok(())
75    }
76
77    /// Add a pattern with tags
78    pub fn add_pattern_with_tags(
79        &mut self,
80        id: &str,
81        pattern: &str,
82        weight: f64,
83        description: &str,
84        tags: Vec<String>,
85    ) -> Result<(), PatternMatcherError> {
86        let regex = Regex::new(pattern)
87            .map_err(|e| PatternMatcherError::InvalidPattern(pattern.to_string(), e.to_string()))?;
88
89        let compiled = CompiledPattern {
90            id: id.to_string(),
91            regex,
92            weight,
93            tags,
94            description: description.to_string(),
95        };
96
97        self.patterns.push(compiled);
98        self.clear_cache();
99        Ok(())
100    }
101
102    /// Find all matches for a given input
103    pub fn find_matches(&self, input: &str) -> Vec<PatternMatch> {
104        // Check cache first
105        if let Ok(cache) = self.cache.lock() {
106            if let Some(cached_matches) = cache.get(input) {
107                return cached_matches.clone();
108            }
109        }
110
111        let mut matches = Vec::new();
112
113        // Test each pattern
114        for pattern in &self.patterns {
115            if let Some(pattern_match) = self.test_pattern(pattern, input) {
116                matches.push(pattern_match);
117            }
118        }
119
120        // Sort by score (descending)
121        matches.sort_by(|a, b| {
122            b.score
123                .partial_cmp(&a.score)
124                .unwrap_or(std::cmp::Ordering::Equal)
125        });
126
127        // Cache the results
128        if let Ok(mut cache) = self.cache.lock() {
129            cache.insert(input.to_string(), matches.clone());
130        }
131
132        matches
133    }
134
135    /// Find the best match for a given input
136    pub fn find_best_match(&self, input: &str) -> Option<PatternMatch> {
137        self.find_matches(input).into_iter().next()
138    }
139
140    /// Find matches by tags
141    pub fn find_matches_by_tag(&self, input: &str, tag: &str) -> Vec<PatternMatch> {
142        let all_matches = self.find_matches(input);
143        all_matches
144            .into_iter()
145            .filter(|m| {
146                if let Some(pattern) = self.patterns.iter().find(|p| p.id == m.pattern_id) {
147                    pattern.tags.contains(&tag.to_string())
148                } else {
149                    false
150                }
151            })
152            .collect()
153    }
154
155    /// Test a single pattern against input
156    fn test_pattern(&self, pattern: &CompiledPattern, input: &str) -> Option<PatternMatch> {
157        // Try exact regex match first
158        if let Some(regex_match) = pattern.regex.find(input) {
159            let captured_groups = pattern
160                .regex
161                .captures(input)
162                .map(|caps| {
163                    caps.iter()
164                        .skip(1)
165                        .filter_map(|m| m.map(|m| m.as_str().to_string()))
166                        .collect()
167                })
168                .unwrap_or_default();
169
170            let match_type = if regex_match.start() == 0 && regex_match.end() == input.len() {
171                MatchType::Exact
172            } else if regex_match.start() == 0 {
173                MatchType::Prefix
174            } else if regex_match.end() == input.len() {
175                MatchType::Suffix
176            } else {
177                MatchType::Partial
178            };
179
180            let score = self.calculate_score(pattern, input, &regex_match, &match_type);
181
182            return Some(PatternMatch {
183                pattern_id: pattern.id.clone(),
184                score,
185                match_type,
186                captured_groups,
187                position: (regex_match.start(), regex_match.end()),
188            });
189        }
190
191        // Try fuzzy matching if enabled
192        if self.fuzzy_threshold > 0.0 {
193            if let Some(fuzzy_match) = self.fuzzy_match(pattern, input) {
194                return Some(fuzzy_match);
195            }
196        }
197
198        None
199    }
200
201    /// Perform fuzzy matching
202    fn fuzzy_match(&self, pattern: &CompiledPattern, input: &str) -> Option<PatternMatch> {
203        // Simple fuzzy matching based on edit distance
204        let pattern_str = pattern.regex.as_str();
205
206        // Remove regex special characters for fuzzy matching
207        let clean_pattern = self.clean_pattern_for_fuzzy(pattern_str);
208        let similarity = self.calculate_similarity(&clean_pattern, input);
209
210        if similarity >= self.fuzzy_threshold {
211            Some(PatternMatch {
212                pattern_id: pattern.id.clone(),
213                score: similarity * pattern.weight * 0.8, // Fuzzy matches get lower score
214                match_type: MatchType::Fuzzy,
215                captured_groups: Vec::new(),
216                position: (0, input.len()),
217            })
218        } else {
219            None
220        }
221    }
222
223    /// Calculate match score
224    fn calculate_score(
225        &self,
226        pattern: &CompiledPattern,
227        input: &str,
228        regex_match: &regex::Match,
229        match_type: &MatchType,
230    ) -> f64 {
231        let mut score = pattern.weight;
232
233        // Bonus for match type
234        let type_bonus = match match_type {
235            MatchType::Exact => 1.0,
236            MatchType::Prefix => 0.9,
237            MatchType::Suffix => 0.8,
238            MatchType::Partial => 0.7,
239            MatchType::Substring => 0.6,
240            MatchType::Fuzzy => 0.5,
241        };
242        score *= type_bonus;
243
244        // Bonus for coverage
245        let coverage = regex_match.len() as f64 / input.len() as f64;
246        score *= 0.5 + coverage * 0.5;
247
248        // Bonus for position (earlier matches are better)
249        let position_bonus = 1.0 - (regex_match.start() as f64 / input.len() as f64) * 0.1;
250        score *= position_bonus;
251
252        score.min(1.0)
253    }
254
255    /// Clean regex pattern for fuzzy matching
256    fn clean_pattern_for_fuzzy(&self, pattern: &str) -> String {
257        // Remove common regex special characters
258        pattern
259            .replace("^", "")
260            .replace("$", "")
261            .replace("\\", "")
262            .replace(".*", "")
263            .replace(".+", "")
264            .replace("?", "")
265            .replace("*", "")
266            .replace("+", "")
267            .replace("(", "")
268            .replace(")", "")
269            .replace("[", "")
270            .replace("]", "")
271            .replace("{", "")
272            .replace("}", "")
273            .replace("|", "")
274    }
275
276    /// Calculate string similarity using Levenshtein distance
277    fn calculate_similarity(&self, s1: &str, s2: &str) -> f64 {
278        let len1 = s1.chars().count();
279        let len2 = s2.chars().count();
280
281        if len1 == 0 {
282            return if len2 == 0 { 1.0 } else { 0.0 };
283        }
284        if len2 == 0 {
285            return 0.0;
286        }
287
288        let s1_chars: Vec<char> = s1.chars().collect();
289        let s2_chars: Vec<char> = s2.chars().collect();
290
291        let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
292
293        // Initialize first row and column
294        for (i, row) in matrix.iter_mut().enumerate().take(len1 + 1) {
295            row[0] = i;
296        }
297        for j in 0..=len2 {
298            matrix[0][j] = j;
299        }
300
301        // Fill the matrix
302        for i in 1..=len1 {
303            for j in 1..=len2 {
304                let cost = if s1_chars[i - 1] == s2_chars[j - 1] {
305                    0
306                } else {
307                    1
308                };
309                matrix[i][j] = std::cmp::min(
310                    std::cmp::min(
311                        matrix[i - 1][j] + 1, // deletion
312                        matrix[i][j - 1] + 1, // insertion
313                    ),
314                    matrix[i - 1][j - 1] + cost, // substitution
315                );
316            }
317        }
318
319        let distance = matrix[len1][len2];
320        let max_len = std::cmp::max(len1, len2);
321        1.0 - (distance as f64 / max_len as f64)
322    }
323
324    /// Set fuzzy matching threshold
325    pub fn set_fuzzy_threshold(&mut self, threshold: f64) {
326        self.fuzzy_threshold = threshold.clamp(0.0, 1.0);
327        self.clear_cache();
328    }
329
330    /// Get fuzzy matching threshold
331    pub fn get_fuzzy_threshold(&self) -> f64 {
332        self.fuzzy_threshold
333    }
334
335    /// Clear the match cache
336    pub fn clear_cache(&self) {
337        if let Ok(mut cache) = self.cache.lock() {
338            cache.clear();
339        }
340    }
341
342    /// Get pattern statistics
343    pub fn get_stats(&self) -> PatternMatcherStats {
344        let cache = self.cache.lock().unwrap();
345        let total_patterns = self.patterns.len();
346        let cached_inputs = cache.len();
347
348        let mut tag_distribution = HashMap::new();
349        for pattern in &self.patterns {
350            for tag in &pattern.tags {
351                *tag_distribution.entry(tag.clone()).or_insert(0) += 1;
352            }
353        }
354
355        PatternMatcherStats {
356            total_patterns,
357            cached_inputs,
358            fuzzy_threshold: self.fuzzy_threshold,
359            tag_distribution,
360        }
361    }
362
363    /// Get all pattern IDs
364    pub fn get_pattern_ids(&self) -> Vec<String> {
365        self.patterns.iter().map(|p| p.id.clone()).collect()
366    }
367
368    /// Get pattern by ID
369    pub fn get_pattern(&self, id: &str) -> Option<&CompiledPattern> {
370        self.patterns.iter().find(|p| p.id == id)
371    }
372
373    /// Remove pattern by ID
374    pub fn remove_pattern(&mut self, id: &str) -> bool {
375        let initial_len = self.patterns.len();
376        self.patterns.retain(|p| p.id != id);
377        let removed = self.patterns.len() != initial_len;
378        if removed {
379            self.clear_cache();
380        }
381        removed
382    }
383}
384
385impl Default for PatternMatcher {
386    fn default() -> Self {
387        Self::new()
388    }
389}
390
/// Point-in-time summary of a pattern matcher.
#[derive(Debug, Clone)]
pub struct PatternMatcherStats {
    /// Number of registered patterns.
    pub total_patterns: usize,
    /// Number of distinct inputs with memoized results.
    pub cached_inputs: usize,
    /// Fuzzy threshold in effect when the snapshot was taken.
    pub fuzzy_threshold: f64,
    /// For each tag, how many registered patterns carry it.
    pub tag_distribution: HashMap<String, usize>,
}
399
400/// Pattern matcher errors
401#[derive(Debug, thiserror::Error)]
402pub enum PatternMatcherError {
403    #[error("Invalid pattern '{0}': {1}")]
404    InvalidPattern(String, String),
405
406    #[error("Pattern not found: {0}")]
407    PatternNotFound(String),
408
409    #[error("Cache error: {0}")]
410    CacheError(String),
411}
412
413/// Builder for creating pattern matchers with common patterns
414pub struct PatternMatcherBuilder {
415    matcher: PatternMatcher,
416}
417
418impl PatternMatcherBuilder {
419    pub fn new() -> Self {
420        Self {
421            matcher: PatternMatcher::new(),
422        }
423    }
424
425    /// Add common Rust type patterns
426    pub fn with_rust_patterns(mut self) -> Result<Self, PatternMatcherError> {
427        // Primitive types
428        self.matcher.add_pattern_with_tags(
429            "primitives",
430            r"^(i8|i16|i32|i64|i128|isize|u8|u16|u32|u64|u128|usize|f32|f64|bool|char)$",
431            1.0,
432            "Rust primitive types",
433            vec!["rust".to_string(), "primitive".to_string()],
434        )?;
435
436        // String types
437        self.matcher.add_pattern_with_tags(
438            "strings",
439            r"^(String|&str|str)$",
440            1.0,
441            "Rust string types",
442            vec!["rust".to_string(), "string".to_string()],
443        )?;
444
445        // Collections
446        self.matcher.add_pattern_with_tags(
447            "collections",
448            r"^(Vec|HashMap|BTreeMap|HashSet|BTreeSet|VecDeque|LinkedList)<",
449            0.9,
450            "Rust collection types",
451            vec!["rust".to_string(), "collection".to_string()],
452        )?;
453
454        // Smart pointers
455        self.matcher.add_pattern_with_tags(
456            "smart_pointers",
457            r"^(Box|Arc|Rc|Weak)<",
458            0.9,
459            "Rust smart pointer types",
460            vec!["rust".to_string(), "smart_pointer".to_string()],
461        )?;
462
463        Ok(self)
464    }
465
466    /// Set fuzzy threshold
467    pub fn fuzzy_threshold(mut self, threshold: f64) -> Self {
468        self.matcher.set_fuzzy_threshold(threshold);
469        self
470    }
471
472    /// Build the pattern matcher
473    pub fn build(self) -> PatternMatcher {
474        self.matcher
475    }
476}
477
478impl Default for PatternMatcherBuilder {
479    fn default() -> Self {
480        Self::new()
481    }
482}
483
#[cfg(test)]
mod tests {
    use super::*;

    /// Registers one untagged pattern with weight 1.0 on `matcher`, panicking
    /// if the regex does not compile.
    fn register(matcher: &mut PatternMatcher, id: &str, regex: &str, description: &str) {
        matcher
            .add_pattern(id, regex, 1.0, description)
            .unwrap();
    }

    // An anchored pattern that stops before the end of the input is reported
    // as a prefix match.
    #[test]
    fn test_exact_match() {
        let mut matcher = PatternMatcher::new();
        register(&mut matcher, "vec", r"^Vec<", "Vector pattern");

        let found = matcher.find_matches("Vec<i32>");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].match_type, MatchType::Prefix);
    }

    // A one-character typo falls through the regex and is caught by the
    // edit-distance fallback.
    #[test]
    fn test_fuzzy_matching() {
        let mut matcher = PatternMatcher::new();
        matcher.set_fuzzy_threshold(0.6);
        register(&mut matcher, "vector", r"Vector", "Vector pattern");

        let found = matcher.find_matches("Vektor"); // Typo
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].match_type, MatchType::Fuzzy);
    }

    // Tag filtering keeps matches whose pattern carries the tag and drops the rest.
    #[test]
    fn test_pattern_with_tags() {
        let mut matcher = PatternMatcher::new();
        let tags = vec!["rust".to_string(), "collection".to_string()];
        matcher
            .add_pattern_with_tags("rust_vec", r"^Vec<", 1.0, "Rust vector", tags)
            .unwrap();

        let rust_hits = matcher.find_matches_by_tag("Vec<i32>", "rust");
        assert_eq!(rust_hits.len(), 1);

        let java_hits = matcher.find_matches_by_tag("Vec<i32>", "java");
        assert_eq!(java_hits.len(), 0);
    }

    // The preset patterns recognise a collection and a primitive.
    #[test]
    fn test_builder_with_rust_patterns() {
        let matcher = PatternMatcherBuilder::new()
            .with_rust_patterns()
            .unwrap()
            .fuzzy_threshold(0.8)
            .build();

        for input in ["Vec<i32>", "i32"].iter() {
            let found = matcher.find_matches(input);
            assert!(!found.is_empty());
        }
    }

    // Boundary values and one near-miss for the similarity metric.
    #[test]
    fn test_similarity_calculation() {
        let matcher = PatternMatcher::new();

        let fixed_cases = [
            ("hello", "hello", 1.0),
            ("hello", "", 0.0),
            ("", "hello", 0.0),
            ("", "", 1.0),
        ];
        for &(left, right, expected) in fixed_cases.iter() {
            assert_eq!(matcher.calculate_similarity(left, right), expected);
        }

        let sim = matcher.calculate_similarity("hello", "hallo");
        assert!(sim > 0.5 && sim < 1.0);
    }

    // Asking twice for the same input yields the same answer.
    #[test]
    fn test_cache_functionality() {
        let mut matcher = PatternMatcher::new();
        register(&mut matcher, "test", r"test", "Test pattern");

        // First call computes the result, the second one should use the cache
        let first = matcher.find_matches("test");
        let second = matcher.find_matches("test");

        assert_eq!(first.len(), second.len());
        assert_eq!(first[0].pattern_id, second[0].pattern_id);
    }

    // Patterns can be listed and removed by id; removing an unknown id reports false.
    #[test]
    fn test_pattern_management() {
        let mut matcher = PatternMatcher::new();

        register(&mut matcher, "test1", r"test1", "Test pattern 1");
        register(&mut matcher, "test2", r"test2", "Test pattern 2");

        assert_eq!(matcher.get_pattern_ids().len(), 2);

        assert!(matcher.remove_pattern("test1"));
        assert_eq!(matcher.get_pattern_ids().len(), 1);

        assert!(!matcher.remove_pattern("nonexistent"));
    }
}