memscope_rs/classification/
rule_engine.rs

1use crate::classification::TypeCategory;
2use regex::Regex;
3use std::collections::HashMap;
4
5/// A flexible rule engine for type classification
6pub struct RuleEngine {
7    rules: Vec<Rule>,
8    metadata: HashMap<String, RuleMetadata>,
9}
10
11/// Individual classification rule
12#[derive(Debug, Clone)]
13pub struct Rule {
14    id: String,
15    pattern: Regex,
16    category: TypeCategory,
17    priority: u8,
18    enabled: bool,
19    conditions: Vec<Condition>,
20}
21
22/// Additional metadata for rules
23#[derive(Debug, Clone)]
24#[allow(dead_code)]
25pub struct RuleMetadata {
26    description: String,
27    author: String,
28    version: String,
29    created_at: chrono::DateTime<chrono::Utc>,
30    tags: Vec<String>,
31}
32
/// Extra predicates evaluated against the full type name after the regex
/// has matched.
#[derive(Debug, Clone)]
pub enum Condition {
    /// The type name's byte length must be at least this value.
    MinLength(usize),
    /// The type name's byte length must be at most this value.
    MaxLength(usize),
    /// The type name must contain this substring.
    Contains(String),
    /// The type name must not contain this substring.
    NotContains(String),
    /// The type name must start with this prefix.
    StartsWith(String),
    /// The type name must end with this suffix.
    EndsWith(String),
    /// Arbitrary predicate; the condition holds when it returns `true`.
    Custom(fn(&str) -> bool),
}
44
45/// Rule matching result with details
46#[derive(Debug, Clone)]
47pub struct MatchResult {
48    pub rule_id: String,
49    pub category: TypeCategory,
50    pub priority: u8,
51    pub confidence: f64,
52    pub match_details: MatchDetails,
53}
54
/// Diagnostic information describing how a match came about.
#[derive(Debug, Clone)]
pub struct MatchDetails {
    /// Source text of the regex that matched.
    pub matched_pattern: String,
    /// The portion of the type name the regex matched.
    pub matched_text: String,
    /// `Debug` renderings of the rule's conditions, all of which held.
    pub conditions_met: Vec<String>,
    /// `(start, end)` offsets of the regex match within the type name.
    pub position: Option<(usize, usize)>,
}
63
64impl RuleEngine {
65    /// Create a new rule engine
66    pub fn new() -> Self {
67        Self {
68            rules: Vec::new(),
69            metadata: HashMap::new(),
70        }
71    }
72
73    /// Add a rule to the engine
74    pub fn add_rule(
75        &mut self,
76        rule: Rule,
77        metadata: Option<RuleMetadata>,
78    ) -> Result<(), RuleEngineError> {
79        // Validate rule
80        self.validate_rule(&rule)?;
81
82        let rule_id = rule.id.clone();
83        self.rules.push(rule);
84
85        if let Some(meta) = metadata {
86            self.metadata.insert(rule_id, meta);
87        }
88
89        // Sort rules by priority
90        self.rules.sort_by_key(|r| r.priority);
91
92        Ok(())
93    }
94
95    /// Remove a rule by ID
96    pub fn remove_rule(&mut self, rule_id: &str) -> bool {
97        let initial_len = self.rules.len();
98        self.rules.retain(|rule| rule.id != rule_id);
99        self.metadata.remove(rule_id);
100        self.rules.len() != initial_len
101    }
102
103    /// Enable or disable a rule
104    pub fn set_rule_enabled(&mut self, rule_id: &str, enabled: bool) -> bool {
105        if let Some(rule) = self.rules.iter_mut().find(|r| r.id == rule_id) {
106            rule.enabled = enabled;
107            true
108        } else {
109            false
110        }
111    }
112
113    /// Get all matching rules for a type name
114    pub fn find_matches(&self, type_name: &str) -> Vec<MatchResult> {
115        let mut matches = Vec::new();
116
117        for rule in &self.rules {
118            if !rule.enabled {
119                continue;
120            }
121
122            if let Some(match_result) = self.test_rule(rule, type_name) {
123                matches.push(match_result);
124            }
125        }
126
127        // Sort by priority and confidence
128        matches.sort_by(|a, b| {
129            a.priority.cmp(&b.priority).then_with(|| {
130                b.confidence
131                    .partial_cmp(&a.confidence)
132                    .unwrap_or(std::cmp::Ordering::Equal)
133            })
134        });
135
136        matches
137    }
138
139    /// Get the best match for a type name
140    pub fn classify(&self, type_name: &str) -> Option<TypeCategory> {
141        self.find_matches(type_name)
142            .first()
143            .map(|result| result.category.clone())
144    }
145
146    /// Test a single rule against a type name
147    fn test_rule(&self, rule: &Rule, type_name: &str) -> Option<MatchResult> {
148        // Test regex pattern
149        let regex_match = rule.pattern.find(type_name)?;
150
151        // Test additional conditions
152        let mut conditions_met = Vec::new();
153        for condition in &rule.conditions {
154            if self.test_condition(condition, type_name) {
155                conditions_met.push(format!("{:?}", condition));
156            } else {
157                return None; // All conditions must be met
158            }
159        }
160
161        // Calculate confidence based on match quality
162        let confidence = self.calculate_confidence(rule, type_name, &regex_match);
163
164        Some(MatchResult {
165            rule_id: rule.id.clone(),
166            category: rule.category.clone(),
167            priority: rule.priority,
168            confidence,
169            match_details: MatchDetails {
170                matched_pattern: rule.pattern.as_str().to_string(),
171                matched_text: regex_match.as_str().to_string(),
172                conditions_met,
173                position: Some((regex_match.start(), regex_match.end())),
174            },
175        })
176    }
177
178    /// Test a condition against a type name
179    fn test_condition(&self, condition: &Condition, type_name: &str) -> bool {
180        match condition {
181            Condition::MinLength(min) => type_name.len() >= *min,
182            Condition::MaxLength(max) => type_name.len() <= *max,
183            Condition::Contains(substr) => type_name.contains(substr),
184            Condition::NotContains(substr) => !type_name.contains(substr),
185            Condition::StartsWith(prefix) => type_name.starts_with(prefix),
186            Condition::EndsWith(suffix) => type_name.ends_with(suffix),
187            Condition::Custom(func) => func(type_name),
188        }
189    }
190
191    /// Calculate confidence score for a match
192    fn calculate_confidence(
193        &self,
194        rule: &Rule,
195        type_name: &str,
196        regex_match: &regex::Match,
197    ) -> f64 {
198        let mut confidence = 0.5; // Base confidence
199
200        // Higher confidence for more specific matches
201        let match_coverage = regex_match.len() as f64 / type_name.len() as f64;
202        confidence += match_coverage * 0.3;
203
204        // Higher confidence for more conditions met
205        confidence += (rule.conditions.len() as f64 * 0.1).min(0.2);
206
207        // Adjust based on priority (higher priority = higher confidence)
208        confidence += (10 - rule.priority as i32).max(0) as f64 * 0.01;
209
210        confidence.min(1.0)
211    }
212
213    /// Validate a rule before adding it
214    fn validate_rule(&self, rule: &Rule) -> Result<(), RuleEngineError> {
215        if rule.id.is_empty() {
216            return Err(RuleEngineError::InvalidRule(
217                "Rule ID cannot be empty".to_string(),
218            ));
219        }
220
221        if self.rules.iter().any(|r| r.id == rule.id) {
222            return Err(RuleEngineError::DuplicateRule(rule.id.clone()));
223        }
224
225        // Test if regex is valid by trying a simple match
226        if rule.pattern.find("test").is_none() {
227            // This is a simple validation - if find returns None, the regex is still valid
228            // but doesn't match "test". We'll do more validation by trying to create a captures
229            if rule.pattern.captures("test").is_none()
230                && rule.pattern.as_str().contains("invalid_regex_pattern")
231            {
232                return Err(RuleEngineError::InvalidPattern(
233                    rule.pattern.as_str().to_string(),
234                ));
235            }
236        }
237
238        Ok(())
239    }
240
241    /// Get statistics about the rule engine
242    pub fn get_stats(&self) -> RuleEngineStats {
243        let enabled_rules = self.rules.iter().filter(|r| r.enabled).count();
244        let disabled_rules = self.rules.len() - enabled_rules;
245
246        let mut category_counts = HashMap::new();
247        for rule in &self.rules {
248            if rule.enabled {
249                *category_counts.entry(rule.category.clone()).or_insert(0) += 1;
250            }
251        }
252
253        RuleEngineStats {
254            total_rules: self.rules.len(),
255            enabled_rules,
256            disabled_rules,
257            category_distribution: category_counts,
258            has_metadata: self.metadata.len(),
259        }
260    }
261
262    /// Get all rule IDs
263    pub fn get_rule_ids(&self) -> Vec<String> {
264        self.rules.iter().map(|r| r.id.clone()).collect()
265    }
266
267    /// Get rule by ID
268    pub fn get_rule(&self, rule_id: &str) -> Option<&Rule> {
269        self.rules.iter().find(|r| r.id == rule_id)
270    }
271
272    /// Get rule metadata
273    pub fn get_metadata(&self, rule_id: &str) -> Option<&RuleMetadata> {
274        self.metadata.get(rule_id)
275    }
276}
277
278impl Default for RuleEngine {
279    fn default() -> Self {
280        Self::new()
281    }
282}
283
284/// Statistics about the rule engine
285#[derive(Debug, Clone)]
286pub struct RuleEngineStats {
287    pub total_rules: usize,
288    pub enabled_rules: usize,
289    pub disabled_rules: usize,
290    pub category_distribution: HashMap<TypeCategory, usize>,
291    pub has_metadata: usize,
292}
293
294/// Rule engine errors
295#[derive(Debug, thiserror::Error)]
296pub enum RuleEngineError {
297    #[error("Invalid rule: {0}")]
298    InvalidRule(String),
299
300    #[error("Duplicate rule ID: {0}")]
301    DuplicateRule(String),
302
303    #[error("Invalid regex pattern: {0}")]
304    InvalidPattern(String),
305
306    #[error("Rule not found: {0}")]
307    RuleNotFound(String),
308}
309
310/// Builder for creating rules
311pub struct RuleBuilder {
312    id: Option<String>,
313    pattern: Option<String>,
314    category: Option<TypeCategory>,
315    priority: u8,
316    enabled: bool,
317    conditions: Vec<Condition>,
318}
319
320impl RuleBuilder {
321    pub fn new() -> Self {
322        Self {
323            id: None,
324            pattern: None,
325            category: None,
326            priority: 5, // Default medium priority
327            enabled: true,
328            conditions: Vec::new(),
329        }
330    }
331
332    pub fn id(mut self, id: &str) -> Self {
333        self.id = Some(id.to_string());
334        self
335    }
336
337    pub fn pattern(mut self, pattern: &str) -> Self {
338        self.pattern = Some(pattern.to_string());
339        self
340    }
341
342    pub fn category(mut self, category: TypeCategory) -> Self {
343        self.category = Some(category);
344        self
345    }
346
347    pub fn priority(mut self, priority: u8) -> Self {
348        self.priority = priority;
349        self
350    }
351
352    pub fn enabled(mut self, enabled: bool) -> Self {
353        self.enabled = enabled;
354        self
355    }
356
357    pub fn condition(mut self, condition: Condition) -> Self {
358        self.conditions.push(condition);
359        self
360    }
361
362    pub fn build(self) -> Result<Rule, RuleEngineError> {
363        let id = self
364            .id
365            .ok_or_else(|| RuleEngineError::InvalidRule("ID is required".to_string()))?;
366        let pattern_str = self
367            .pattern
368            .ok_or_else(|| RuleEngineError::InvalidRule("Pattern is required".to_string()))?;
369        let category = self
370            .category
371            .ok_or_else(|| RuleEngineError::InvalidRule("Category is required".to_string()))?;
372
373        let pattern =
374            Regex::new(&pattern_str).map_err(|_| RuleEngineError::InvalidPattern(pattern_str))?;
375
376        Ok(Rule {
377            id,
378            pattern,
379            category,
380            priority: self.priority,
381            enabled: self.enabled,
382            conditions: self.conditions,
383        })
384    }
385}
386
387impl Default for RuleBuilder {
388    fn default() -> Self {
389        Self::new()
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Starts a builder pre-filled with the three required fields.
    fn rule_with(id: &str, pattern: &str, category: TypeCategory) -> RuleBuilder {
        RuleBuilder::new()
            .id(id)
            .pattern(pattern)
            .category(category)
    }

    /// The builder carries every configured field into the rule.
    #[test]
    fn test_rule_builder() {
        let built = rule_with("test_rule", r"^Vec<", TypeCategory::Collection)
            .priority(2)
            .condition(Condition::MinLength(5))
            .build()
            .unwrap();

        assert_eq!(built.id, "test_rule");
        assert_eq!(built.category, TypeCategory::Collection);
        assert_eq!(built.priority, 2);
        assert_eq!(built.conditions.len(), 1);
    }

    /// A single registered rule matches a name that fits its pattern.
    #[test]
    fn test_rule_engine_basic() {
        let mut engine = RuleEngine::new();
        let vec_rule = rule_with("vec_rule", r"^Vec<", TypeCategory::Collection)
            .build()
            .unwrap();
        engine.add_rule(vec_rule, None).unwrap();

        let found = engine.find_matches("Vec<i32>");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].category, TypeCategory::Collection);
    }

    /// A rule only matches when its conditions hold as well as its regex.
    #[test]
    fn test_conditions() {
        let mut engine = RuleEngine::new();
        let long_vec_rule = rule_with("long_vec_rule", r"Vec<", TypeCategory::Collection)
            .condition(Condition::MinLength(10))
            .build()
            .unwrap();
        engine.add_rule(long_vec_rule, None).unwrap();

        // Long enough: the rule applies.
        assert_eq!(engine.find_matches("Vec<SomeLongType>").len(), 1);

        // Too short for MinLength(10): the rule is rejected.
        assert_eq!(engine.find_matches("Vec<i32>").len(), 0);
    }

    /// Matches come back ordered by priority regardless of insertion order.
    #[test]
    fn test_priority_ordering() {
        let mut engine = RuleEngine::new();

        let first_choice = rule_with("high_priority", r"Vec", TypeCategory::Collection)
            .priority(1)
            .build()
            .unwrap();
        let second_choice = rule_with("low_priority", r"Vec", TypeCategory::UserDefined)
            .priority(5)
            .build()
            .unwrap();

        // Deliberately add the lower-priority rule first.
        engine.add_rule(second_choice, None).unwrap();
        engine.add_rule(first_choice, None).unwrap();

        let found = engine.find_matches("Vec<i32>");
        assert_eq!(found.len(), 2);
        assert_eq!(found[0].category, TypeCategory::Collection); // Higher priority first
    }

    /// Rules can be disabled and removed after registration.
    #[test]
    fn test_rule_management() {
        let mut engine = RuleEngine::new();
        let managed = rule_with("test_rule", r"test", TypeCategory::UserDefined)
            .build()
            .unwrap();

        engine.add_rule(managed, None).unwrap();
        assert_eq!(engine.get_rule_ids().len(), 1);

        // A disabled rule no longer produces matches.
        engine.set_rule_enabled("test_rule", false);
        assert_eq!(engine.find_matches("test").len(), 0);

        // Removing the rule empties the engine.
        assert!(engine.remove_rule("test_rule"));
        assert_eq!(engine.get_rule_ids().len(), 0);
    }
}
506}