Skip to main content

memscope_rs/analysis/classification/
rule_engine.rs

1use crate::classification::TypeCategory;
2use crate::core::{MemScopeError, MemScopeResult};
3use regex::Regex;
4use std::collections::HashMap;
5
6/// A flexible rule engine for type classification
7pub struct RuleEngine {
8    rules: Vec<Rule>,
9    metadata: HashMap<String, RuleMetadata>,
10}
11
12/// Individual classification rule
13#[derive(Debug, Clone)]
14pub struct Rule {
15    id: String,
16    pattern: Regex,
17    category: TypeCategory,
18    priority: u8,
19    enabled: bool,
20    conditions: Vec<Condition>,
21}
22
23/// Additional metadata for rules
24#[derive(Debug, Clone)]
25pub struct RuleMetadata {
26    description: String,
27    author: String,
28    version: String,
29    created_at: chrono::DateTime<chrono::Utc>,
30    tags: Vec<String>,
31}
32
33impl RuleMetadata {
34    pub fn description(&self) -> &str {
35        &self.description
36    }
37    pub fn author(&self) -> &str {
38        &self.author
39    }
40    pub fn version(&self) -> &str {
41        &self.version
42    }
43    pub fn tags(&self) -> &[String] {
44        &self.tags
45    }
46    pub fn created_at(&self) -> &chrono::DateTime<chrono::Utc> {
47        &self.created_at
48    }
49}
50
/// Extra checks a rule can require beyond its regex match.
///
/// Every condition on a rule must hold for the rule to match. Lengths are
/// measured in bytes (`str::len`).
#[derive(Debug, Clone)]
pub enum Condition {
    /// The type name is at least this many bytes long.
    MinLength(usize),
    /// The type name is at most this many bytes long.
    MaxLength(usize),
    /// The type name contains this substring.
    Contains(String),
    /// The type name does not contain this substring.
    NotContains(String),
    /// The type name starts with this prefix.
    StartsWith(String),
    /// The type name ends with this suffix.
    EndsWith(String),
    /// Caller-supplied predicate, called with the full type name; holds when it returns `true`.
    Custom(fn(&str) -> bool),
}
62
63/// Rule matching result with details
64#[derive(Debug, Clone)]
65pub struct MatchResult {
66    pub rule_id: String,
67    pub category: TypeCategory,
68    pub priority: u8,
69    pub confidence: f64,
70    pub match_details: MatchDetails,
71}
72
/// Details about how a rule matched a type name.
#[derive(Debug, Clone)]
pub struct MatchDetails {
    /// Source text of the rule's regex.
    pub matched_pattern: String,
    /// The substring of the type name that the regex matched.
    pub matched_text: String,
    /// `Debug` renderings of each condition that was checked and held.
    pub conditions_met: Vec<String>,
    /// `(start, end)` byte offsets of the regex match; the engine always fills this in.
    pub position: Option<(usize, usize)>,
}
81
82impl RuleEngine {
83    /// Create a new rule engine
84    pub fn new() -> Self {
85        Self {
86            rules: Vec::new(),
87            metadata: HashMap::new(),
88        }
89    }
90
91    /// Add a rule to the engine
92    pub fn add_rule(&mut self, rule: Rule, metadata: Option<RuleMetadata>) -> MemScopeResult<()> {
93        // Validate rule
94        self.validate_rule(&rule)?;
95
96        let rule_id = rule.id.clone();
97        self.rules.push(rule);
98
99        if let Some(meta) = metadata {
100            self.metadata.insert(rule_id, meta);
101        }
102
103        // Sort rules by priority
104        self.rules.sort_by_key(|r| r.priority);
105
106        Ok(())
107    }
108
109    /// Remove a rule by ID
110    pub fn remove_rule(&mut self, rule_id: &str) -> bool {
111        let initial_len = self.rules.len();
112        self.rules.retain(|rule| rule.id != rule_id);
113        self.metadata.remove(rule_id);
114        self.rules.len() != initial_len
115    }
116
117    /// Enable or disable a rule
118    pub fn set_rule_enabled(&mut self, rule_id: &str, enabled: bool) -> bool {
119        if let Some(rule) = self.rules.iter_mut().find(|r| r.id == rule_id) {
120            rule.enabled = enabled;
121            true
122        } else {
123            false
124        }
125    }
126
127    /// Get all matching rules for a type name
128    pub fn find_matches(&self, type_name: &str) -> Vec<MatchResult> {
129        let mut matches = Vec::new();
130
131        for rule in &self.rules {
132            if !rule.enabled {
133                continue;
134            }
135
136            if let Some(match_result) = self.test_rule(rule, type_name) {
137                matches.push(match_result);
138            }
139        }
140
141        // Sort by priority and confidence
142        matches.sort_by(|a, b| {
143            a.priority.cmp(&b.priority).then_with(|| {
144                b.confidence
145                    .partial_cmp(&a.confidence)
146                    .unwrap_or(std::cmp::Ordering::Equal)
147            })
148        });
149
150        matches
151    }
152
153    /// Get the best match for a type name
154    pub fn classify(&self, type_name: &str) -> Option<TypeCategory> {
155        self.find_matches(type_name)
156            .first()
157            .map(|result| result.category.clone())
158    }
159
160    /// Test a single rule against a type name
161    fn test_rule(&self, rule: &Rule, type_name: &str) -> Option<MatchResult> {
162        // Test regex pattern
163        let regex_match = rule.pattern.find(type_name)?;
164
165        // Test additional conditions
166        let mut conditions_met = Vec::new();
167        for condition in &rule.conditions {
168            if self.test_condition(condition, type_name) {
169                conditions_met.push(format!("{:?}", condition));
170            } else {
171                return None; // All conditions must be met
172            }
173        }
174
175        // Calculate confidence based on match quality
176        let confidence = self.calculate_confidence(rule, type_name, &regex_match);
177
178        Some(MatchResult {
179            rule_id: rule.id.clone(),
180            category: rule.category.clone(),
181            priority: rule.priority,
182            confidence,
183            match_details: MatchDetails {
184                matched_pattern: rule.pattern.as_str().to_string(),
185                matched_text: regex_match.as_str().to_string(),
186                conditions_met,
187                position: Some((regex_match.start(), regex_match.end())),
188            },
189        })
190    }
191
192    /// Test a condition against a type name
193    fn test_condition(&self, condition: &Condition, type_name: &str) -> bool {
194        match condition {
195            Condition::MinLength(min) => type_name.len() >= *min,
196            Condition::MaxLength(max) => type_name.len() <= *max,
197            Condition::Contains(substr) => type_name.contains(substr),
198            Condition::NotContains(substr) => !type_name.contains(substr),
199            Condition::StartsWith(prefix) => type_name.starts_with(prefix),
200            Condition::EndsWith(suffix) => type_name.ends_with(suffix),
201            Condition::Custom(func) => func(type_name),
202        }
203    }
204
205    /// Calculate confidence score for a match
206    fn calculate_confidence(
207        &self,
208        rule: &Rule,
209        type_name: &str,
210        regex_match: &regex::Match,
211    ) -> f64 {
212        let mut confidence = 0.5; // Base confidence
213
214        // Higher confidence for more specific matches
215        let match_coverage = regex_match.len() as f64 / type_name.len() as f64;
216        confidence += match_coverage * 0.3;
217
218        // Higher confidence for more conditions met
219        confidence += (rule.conditions.len() as f64 * 0.1).min(0.2);
220
221        // Adjust based on priority (higher priority = higher confidence)
222        confidence += (10 - rule.priority as i32).max(0) as f64 * 0.01;
223
224        confidence.min(1.0)
225    }
226
227    /// Validate a rule before adding it
228    fn validate_rule(&self, rule: &Rule) -> MemScopeResult<()> {
229        if rule.id.is_empty() {
230            return Err(MemScopeError::error(
231                "rule_engine",
232                "validate_rule",
233                "Rule ID cannot be empty",
234            ));
235        }
236
237        if self.rules.iter().any(|r| r.id == rule.id) {
238            return Err(MemScopeError::error(
239                "rule_engine",
240                "validate_rule",
241                format!("Duplicate rule ID: {}", rule.id),
242            ));
243        }
244
245        // Test if regex is valid by trying a simple match
246        if rule.pattern.find("test").is_none() {
247            // This is a simple validation - if find returns None, the regex is still valid
248            // but doesn't match "test". We'll do more validation by trying to create a captures
249            if rule.pattern.captures("test").is_none()
250                && rule.pattern.as_str().contains("invalid_regex_pattern")
251            {
252                return Err(MemScopeError::error(
253                    "rule_engine",
254                    "validate_rule",
255                    format!("Invalid regex pattern: {}", rule.pattern.as_str()),
256                ));
257            }
258        }
259
260        Ok(())
261    }
262
263    /// Get statistics about the rule engine
264    pub fn get_stats(&self) -> RuleEngineStats {
265        let enabled_rules = self.rules.iter().filter(|r| r.enabled).count();
266        let disabled_rules = self.rules.len() - enabled_rules;
267
268        let mut category_counts = HashMap::new();
269        for rule in &self.rules {
270            if rule.enabled {
271                *category_counts.entry(rule.category.clone()).or_insert(0) += 1;
272            }
273        }
274
275        RuleEngineStats {
276            total_rules: self.rules.len(),
277            enabled_rules,
278            disabled_rules,
279            category_distribution: category_counts,
280            has_metadata: self.metadata.len(),
281        }
282    }
283
284    /// Get all rule IDs
285    pub fn get_rule_ids(&self) -> Vec<String> {
286        self.rules.iter().map(|r| r.id.clone()).collect()
287    }
288
289    /// Get rule by ID
290    pub fn get_rule(&self, rule_id: &str) -> Option<&Rule> {
291        self.rules.iter().find(|r| r.id == rule_id)
292    }
293
294    /// Get rule metadata
295    pub fn get_metadata(&self, rule_id: &str) -> Option<&RuleMetadata> {
296        self.metadata.get(rule_id)
297    }
298}
299
300impl Default for RuleEngine {
301    fn default() -> Self {
302        Self::new()
303    }
304}
305
306/// Statistics about the rule engine
307#[derive(Debug, Clone)]
308pub struct RuleEngineStats {
309    pub total_rules: usize,
310    pub enabled_rules: usize,
311    pub disabled_rules: usize,
312    pub category_distribution: HashMap<TypeCategory, usize>,
313    pub has_metadata: usize,
314}
315
316/// Builder for creating rules
317pub struct RuleBuilder {
318    id: Option<String>,
319    pattern: Option<String>,
320    category: Option<TypeCategory>,
321    priority: u8,
322    enabled: bool,
323    conditions: Vec<Condition>,
324}
325
326impl RuleBuilder {
327    pub fn new() -> Self {
328        Self {
329            id: None,
330            pattern: None,
331            category: None,
332            priority: 5, // Default medium priority
333            enabled: true,
334            conditions: Vec::new(),
335        }
336    }
337
338    pub fn id(mut self, id: &str) -> Self {
339        self.id = Some(id.to_string());
340        self
341    }
342
343    pub fn pattern(mut self, pattern: &str) -> Self {
344        self.pattern = Some(pattern.to_string());
345        self
346    }
347
348    pub fn category(mut self, category: TypeCategory) -> Self {
349        self.category = Some(category);
350        self
351    }
352
353    pub fn priority(mut self, priority: u8) -> Self {
354        self.priority = priority;
355        self
356    }
357
358    pub fn enabled(mut self, enabled: bool) -> Self {
359        self.enabled = enabled;
360        self
361    }
362
363    pub fn condition(mut self, condition: Condition) -> Self {
364        self.conditions.push(condition);
365        self
366    }
367
368    pub fn build(self) -> MemScopeResult<Rule> {
369        let id = self
370            .id
371            .ok_or_else(|| MemScopeError::error("rule_engine", "build", "ID is required"))?;
372        let pattern_str = self
373            .pattern
374            .ok_or_else(|| MemScopeError::error("rule_engine", "build", "Pattern is required"))?;
375        let category = self
376            .category
377            .ok_or_else(|| MemScopeError::error("rule_engine", "build", "Category is required"))?;
378
379        let pattern = Regex::new(&pattern_str).map_err(|e| {
380            MemScopeError::error("rule_engine", "build", format!("Invalid pattern: {}", e))
381        })?;
382
383        Ok(Rule {
384            id,
385            pattern,
386            category,
387            priority: self.priority,
388            enabled: self.enabled,
389            conditions: self.conditions,
390        })
391    }
392}
393
394impl Default for RuleBuilder {
395    fn default() -> Self {
396        Self::new()
397    }
398}
399
#[cfg(test)]
mod tests {
    use super::*;

    /// Build an enabled rule with no conditions.
    fn simple_rule(id: &str, pattern: &str, category: TypeCategory, priority: u8) -> Rule {
        RuleBuilder::new()
            .id(id)
            .pattern(pattern)
            .category(category)
            .priority(priority)
            .build()
            .unwrap()
    }

    #[test]
    fn test_rule_builder() {
        let rule = RuleBuilder::new()
            .id("test_rule")
            .pattern(r"^Vec<")
            .category(TypeCategory::Collection)
            .priority(2)
            .condition(Condition::MinLength(5))
            .build()
            .unwrap();

        assert_eq!(rule.id, "test_rule");
        assert_eq!(rule.category, TypeCategory::Collection);
        assert_eq!(rule.priority, 2);
        assert_eq!(rule.conditions.len(), 1);
    }

    #[test]
    fn test_builder_errors() {
        // Missing ID, pattern and category respectively.
        assert!(RuleBuilder::new()
            .pattern("Vec")
            .category(TypeCategory::Collection)
            .build()
            .is_err());
        assert!(RuleBuilder::new()
            .id("x")
            .category(TypeCategory::Collection)
            .build()
            .is_err());
        assert!(RuleBuilder::new().id("x").pattern("Vec").build().is_err());
        // Pattern that does not compile.
        assert!(RuleBuilder::new()
            .id("x")
            .pattern("(")
            .category(TypeCategory::Collection)
            .build()
            .is_err());
    }

    #[test]
    fn test_rule_engine_basic() {
        let mut engine = RuleEngine::new();

        let rule = RuleBuilder::new()
            .id("vec_rule")
            .pattern(r"^Vec<")
            .category(TypeCategory::Collection)
            .build()
            .unwrap();

        engine.add_rule(rule, None).unwrap();

        let matches = engine.find_matches("Vec<i32>");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].category, TypeCategory::Collection);
    }

    #[test]
    fn test_duplicate_and_empty_ids_rejected() {
        let mut engine = RuleEngine::new();

        engine
            .add_rule(simple_rule("dup", "Vec", TypeCategory::Collection, 5), None)
            .unwrap();
        assert!(engine
            .add_rule(simple_rule("dup", "Box", TypeCategory::UserDefined, 1), None)
            .is_err());
        assert!(engine
            .add_rule(simple_rule("", "Vec", TypeCategory::Collection, 5), None)
            .is_err());

        // Failed insertions leave the engine unchanged.
        assert_eq!(engine.get_rule_ids(), vec!["dup".to_string()]);
    }

    #[test]
    fn test_conditions() {
        let mut engine = RuleEngine::new();

        let rule = RuleBuilder::new()
            .id("long_vec_rule")
            .pattern(r"Vec<")
            .category(TypeCategory::Collection)
            .condition(Condition::MinLength(10))
            .build()
            .unwrap();

        engine.add_rule(rule, None).unwrap();

        // Should match
        let matches = engine.find_matches("Vec<SomeLongType>");
        assert_eq!(matches.len(), 1);

        // Should not match (too short)
        let matches = engine.find_matches("Vec<i32>");
        assert_eq!(matches.len(), 0);
    }

    #[test]
    fn test_priority_ordering() {
        let mut engine = RuleEngine::new();

        let high_priority_rule = RuleBuilder::new()
            .id("high_priority")
            .pattern(r"Vec")
            .category(TypeCategory::Collection)
            .priority(1)
            .build()
            .unwrap();

        let low_priority_rule = RuleBuilder::new()
            .id("low_priority")
            .pattern(r"Vec")
            .category(TypeCategory::UserDefined)
            .priority(5)
            .build()
            .unwrap();

        engine.add_rule(low_priority_rule, None).unwrap();
        engine.add_rule(high_priority_rule, None).unwrap();

        let matches = engine.find_matches("Vec<i32>");
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].category, TypeCategory::Collection); // Higher priority first
        assert_eq!(
            engine.get_rule_ids(),
            vec!["high_priority".to_string(), "low_priority".to_string()]
        );
    }

    #[test]
    fn test_equal_priority_orders_by_confidence() {
        let mut engine = RuleEngine::new();
        engine
            .add_rule(simple_rule("prefix", "Vec", TypeCategory::Collection, 3), None)
            .unwrap();
        engine
            .add_rule(simple_rule("exact", "Vec<i32>", TypeCategory::UserDefined, 3), None)
            .unwrap();

        let matches = engine.find_matches("Vec<i32>");
        assert_eq!(matches.len(), 2);
        // Full-coverage match wins the tie on priority.
        assert_eq!(matches[0].rule_id, "exact");
        assert_eq!(matches[0].match_details.position, Some((0, 8)));
        assert_eq!(matches[1].match_details.matched_text, "Vec");
        assert!(matches[0].confidence > matches[1].confidence);
        assert!(matches[0].confidence <= 1.0);
    }

    #[test]
    fn test_classify_and_stats() {
        let mut engine = RuleEngine::new();
        engine
            .add_rule(simple_rule("vec", "Vec", TypeCategory::Collection, 1), None)
            .unwrap();
        engine
            .add_rule(simple_rule("boxed", "Box", TypeCategory::UserDefined, 5), None)
            .unwrap();

        assert_eq!(engine.classify("Vec<u8>"), Some(TypeCategory::Collection));
        assert_eq!(engine.classify("Box<u8>"), Some(TypeCategory::UserDefined));
        assert_eq!(engine.classify("u8"), None);

        assert!(engine.set_rule_enabled("boxed", false));
        assert!(!engine.set_rule_enabled("missing", true));
        assert_eq!(engine.classify("Box<u8>"), None);

        let stats = engine.get_stats();
        assert_eq!(stats.total_rules, 2);
        assert_eq!(stats.enabled_rules, 1);
        assert_eq!(stats.disabled_rules, 1);
        assert_eq!(
            stats.category_distribution.get(&TypeCategory::Collection),
            Some(&1)
        );
        // Disabled rules are not counted in the distribution.
        assert_eq!(
            stats.category_distribution.get(&TypeCategory::UserDefined),
            None
        );
        assert_eq!(stats.has_metadata, 0);
    }

    #[test]
    fn test_rule_management() {
        let mut engine = RuleEngine::new();

        let rule = RuleBuilder::new()
            .id("test_rule")
            .pattern(r"test")
            .category(TypeCategory::UserDefined)
            .build()
            .unwrap();

        engine.add_rule(rule, None).unwrap();
        assert_eq!(engine.get_rule_ids().len(), 1);
        assert!(engine.get_rule("test_rule").is_some());

        // Disable rule
        assert!(engine.set_rule_enabled("test_rule", false));
        let matches = engine.find_matches("test");
        assert_eq!(matches.len(), 0);

        // Remove rule
        assert!(engine.remove_rule("test_rule"));
        assert!(!engine.remove_rule("test_rule"));
        assert_eq!(engine.get_rule_ids().len(), 0);
        assert!(engine.get_rule("test_rule").is_none());
    }
}