memscope_rs/classification/
rule_engine.rs1use crate::classification::TypeCategory;
2use regex::Regex;
3use std::collections::HashMap;
4
5pub struct RuleEngine {
7 rules: Vec<Rule>,
8 metadata: HashMap<String, RuleMetadata>,
9}
10
11#[derive(Debug, Clone)]
13pub struct Rule {
14 id: String,
15 pattern: Regex,
16 category: TypeCategory,
17 priority: u8,
18 enabled: bool,
19 conditions: Vec<Condition>,
20}
21
22#[derive(Debug, Clone)]
24#[allow(dead_code)]
25pub struct RuleMetadata {
26 description: String,
27 author: String,
28 version: String,
29 created_at: chrono::DateTime<chrono::Utc>,
30 tags: Vec<String>,
31}
32
/// Extra predicate evaluated against the full type name after a rule's regex
/// has matched.
#[derive(Clone, Debug)]
pub enum Condition {
    /// The type name is at least this many bytes long.
    MinLength(usize),
    /// The type name is at most this many bytes long.
    MaxLength(usize),
    /// The type name contains the given substring.
    Contains(String),
    /// The type name does not contain the given substring.
    NotContains(String),
    /// The type name begins with the given prefix.
    StartsWith(String),
    /// The type name ends with the given suffix.
    EndsWith(String),
    /// Caller-supplied predicate; passes when the function returns `true`.
    Custom(fn(&str) -> bool),
}
44
45#[derive(Debug, Clone)]
47pub struct MatchResult {
48 pub rule_id: String,
49 pub category: TypeCategory,
50 pub priority: u8,
51 pub confidence: f64,
52 pub match_details: MatchDetails,
53}
54
/// Diagnostic details describing how a rule matched a type name.
#[derive(Clone, Debug)]
pub struct MatchDetails {
    /// Source text of the rule's regex.
    pub matched_pattern: String,
    /// The portion of the type name the regex matched.
    pub matched_text: String,
    /// `Debug` renderings of the rule's conditions that passed.
    pub conditions_met: Vec<String>,
    /// `(start, end)` offsets of the regex match inside the type name.
    pub position: Option<(usize, usize)>,
}
63
64impl RuleEngine {
65 pub fn new() -> Self {
67 Self {
68 rules: Vec::new(),
69 metadata: HashMap::new(),
70 }
71 }
72
73 pub fn add_rule(
75 &mut self,
76 rule: Rule,
77 metadata: Option<RuleMetadata>,
78 ) -> Result<(), RuleEngineError> {
79 self.validate_rule(&rule)?;
81
82 let rule_id = rule.id.clone();
83 self.rules.push(rule);
84
85 if let Some(meta) = metadata {
86 self.metadata.insert(rule_id, meta);
87 }
88
89 self.rules.sort_by_key(|r| r.priority);
91
92 Ok(())
93 }
94
95 pub fn remove_rule(&mut self, rule_id: &str) -> bool {
97 let initial_len = self.rules.len();
98 self.rules.retain(|rule| rule.id != rule_id);
99 self.metadata.remove(rule_id);
100 self.rules.len() != initial_len
101 }
102
103 pub fn set_rule_enabled(&mut self, rule_id: &str, enabled: bool) -> bool {
105 if let Some(rule) = self.rules.iter_mut().find(|r| r.id == rule_id) {
106 rule.enabled = enabled;
107 true
108 } else {
109 false
110 }
111 }
112
113 pub fn find_matches(&self, type_name: &str) -> Vec<MatchResult> {
115 let mut matches = Vec::new();
116
117 for rule in &self.rules {
118 if !rule.enabled {
119 continue;
120 }
121
122 if let Some(match_result) = self.test_rule(rule, type_name) {
123 matches.push(match_result);
124 }
125 }
126
127 matches.sort_by(|a, b| {
129 a.priority.cmp(&b.priority).then_with(|| {
130 b.confidence
131 .partial_cmp(&a.confidence)
132 .unwrap_or(std::cmp::Ordering::Equal)
133 })
134 });
135
136 matches
137 }
138
139 pub fn classify(&self, type_name: &str) -> Option<TypeCategory> {
141 self.find_matches(type_name)
142 .first()
143 .map(|result| result.category.clone())
144 }
145
146 fn test_rule(&self, rule: &Rule, type_name: &str) -> Option<MatchResult> {
148 let regex_match = rule.pattern.find(type_name)?;
150
151 let mut conditions_met = Vec::new();
153 for condition in &rule.conditions {
154 if self.test_condition(condition, type_name) {
155 conditions_met.push(format!("{:?}", condition));
156 } else {
157 return None; }
159 }
160
161 let confidence = self.calculate_confidence(rule, type_name, ®ex_match);
163
164 Some(MatchResult {
165 rule_id: rule.id.clone(),
166 category: rule.category.clone(),
167 priority: rule.priority,
168 confidence,
169 match_details: MatchDetails {
170 matched_pattern: rule.pattern.as_str().to_string(),
171 matched_text: regex_match.as_str().to_string(),
172 conditions_met,
173 position: Some((regex_match.start(), regex_match.end())),
174 },
175 })
176 }
177
178 fn test_condition(&self, condition: &Condition, type_name: &str) -> bool {
180 match condition {
181 Condition::MinLength(min) => type_name.len() >= *min,
182 Condition::MaxLength(max) => type_name.len() <= *max,
183 Condition::Contains(substr) => type_name.contains(substr),
184 Condition::NotContains(substr) => !type_name.contains(substr),
185 Condition::StartsWith(prefix) => type_name.starts_with(prefix),
186 Condition::EndsWith(suffix) => type_name.ends_with(suffix),
187 Condition::Custom(func) => func(type_name),
188 }
189 }
190
191 fn calculate_confidence(
193 &self,
194 rule: &Rule,
195 type_name: &str,
196 regex_match: ®ex::Match,
197 ) -> f64 {
198 let mut confidence = 0.5; let match_coverage = regex_match.len() as f64 / type_name.len() as f64;
202 confidence += match_coverage * 0.3;
203
204 confidence += (rule.conditions.len() as f64 * 0.1).min(0.2);
206
207 confidence += (10 - rule.priority as i32).max(0) as f64 * 0.01;
209
210 confidence.min(1.0)
211 }
212
213 fn validate_rule(&self, rule: &Rule) -> Result<(), RuleEngineError> {
215 if rule.id.is_empty() {
216 return Err(RuleEngineError::InvalidRule(
217 "Rule ID cannot be empty".to_string(),
218 ));
219 }
220
221 if self.rules.iter().any(|r| r.id == rule.id) {
222 return Err(RuleEngineError::DuplicateRule(rule.id.clone()));
223 }
224
225 if rule.pattern.find("test").is_none() {
227 if rule.pattern.captures("test").is_none()
230 && rule.pattern.as_str().contains("invalid_regex_pattern")
231 {
232 return Err(RuleEngineError::InvalidPattern(
233 rule.pattern.as_str().to_string(),
234 ));
235 }
236 }
237
238 Ok(())
239 }
240
241 pub fn get_stats(&self) -> RuleEngineStats {
243 let enabled_rules = self.rules.iter().filter(|r| r.enabled).count();
244 let disabled_rules = self.rules.len() - enabled_rules;
245
246 let mut category_counts = HashMap::new();
247 for rule in &self.rules {
248 if rule.enabled {
249 *category_counts.entry(rule.category.clone()).or_insert(0) += 1;
250 }
251 }
252
253 RuleEngineStats {
254 total_rules: self.rules.len(),
255 enabled_rules,
256 disabled_rules,
257 category_distribution: category_counts,
258 has_metadata: self.metadata.len(),
259 }
260 }
261
262 pub fn get_rule_ids(&self) -> Vec<String> {
264 self.rules.iter().map(|r| r.id.clone()).collect()
265 }
266
267 pub fn get_rule(&self, rule_id: &str) -> Option<&Rule> {
269 self.rules.iter().find(|r| r.id == rule_id)
270 }
271
272 pub fn get_metadata(&self, rule_id: &str) -> Option<&RuleMetadata> {
274 self.metadata.get(rule_id)
275 }
276}
277
278impl Default for RuleEngine {
279 fn default() -> Self {
280 Self::new()
281 }
282}
283
284#[derive(Debug, Clone)]
286pub struct RuleEngineStats {
287 pub total_rules: usize,
288 pub enabled_rules: usize,
289 pub disabled_rules: usize,
290 pub category_distribution: HashMap<TypeCategory, usize>,
291 pub has_metadata: usize,
292}
293
294#[derive(Debug, thiserror::Error)]
296pub enum RuleEngineError {
297 #[error("Invalid rule: {0}")]
298 InvalidRule(String),
299
300 #[error("Duplicate rule ID: {0}")]
301 DuplicateRule(String),
302
303 #[error("Invalid regex pattern: {0}")]
304 InvalidPattern(String),
305
306 #[error("Rule not found: {0}")]
307 RuleNotFound(String),
308}
309
310pub struct RuleBuilder {
312 id: Option<String>,
313 pattern: Option<String>,
314 category: Option<TypeCategory>,
315 priority: u8,
316 enabled: bool,
317 conditions: Vec<Condition>,
318}
319
320impl RuleBuilder {
321 pub fn new() -> Self {
322 Self {
323 id: None,
324 pattern: None,
325 category: None,
326 priority: 5, enabled: true,
328 conditions: Vec::new(),
329 }
330 }
331
332 pub fn id(mut self, id: &str) -> Self {
333 self.id = Some(id.to_string());
334 self
335 }
336
337 pub fn pattern(mut self, pattern: &str) -> Self {
338 self.pattern = Some(pattern.to_string());
339 self
340 }
341
342 pub fn category(mut self, category: TypeCategory) -> Self {
343 self.category = Some(category);
344 self
345 }
346
347 pub fn priority(mut self, priority: u8) -> Self {
348 self.priority = priority;
349 self
350 }
351
352 pub fn enabled(mut self, enabled: bool) -> Self {
353 self.enabled = enabled;
354 self
355 }
356
357 pub fn condition(mut self, condition: Condition) -> Self {
358 self.conditions.push(condition);
359 self
360 }
361
362 pub fn build(self) -> Result<Rule, RuleEngineError> {
363 let id = self
364 .id
365 .ok_or_else(|| RuleEngineError::InvalidRule("ID is required".to_string()))?;
366 let pattern_str = self
367 .pattern
368 .ok_or_else(|| RuleEngineError::InvalidRule("Pattern is required".to_string()))?;
369 let category = self
370 .category
371 .ok_or_else(|| RuleEngineError::InvalidRule("Category is required".to_string()))?;
372
373 let pattern =
374 Regex::new(&pattern_str).map_err(|_| RuleEngineError::InvalidPattern(pattern_str))?;
375
376 Ok(Rule {
377 id,
378 pattern,
379 category,
380 priority: self.priority,
381 enabled: self.enabled,
382 conditions: self.conditions,
383 })
384 }
385}
386
387impl Default for RuleBuilder {
388 fn default() -> Self {
389 Self::new()
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an engine and registers the given rules, in order, without metadata.
    fn engine_with(rules: Vec<Rule>) -> RuleEngine {
        let mut engine = RuleEngine::new();
        for rule in rules {
            engine.add_rule(rule, None).unwrap();
        }
        engine
    }

    #[test]
    fn test_rule_builder() {
        let built = RuleBuilder::new()
            .id("test_rule")
            .pattern(r"^Vec<")
            .category(TypeCategory::Collection)
            .priority(2)
            .condition(Condition::MinLength(5))
            .build()
            .unwrap();

        assert_eq!(built.id, "test_rule");
        assert_eq!(built.category, TypeCategory::Collection);
        assert_eq!(built.priority, 2);
        assert_eq!(built.conditions.len(), 1);
    }

    #[test]
    fn test_rule_engine_basic() {
        let vec_rule = RuleBuilder::new()
            .id("vec_rule")
            .pattern(r"^Vec<")
            .category(TypeCategory::Collection)
            .build()
            .unwrap();
        let engine = engine_with(vec![vec_rule]);

        let found = engine.find_matches("Vec<i32>");
        assert_eq!(found.len(), 1);
        assert_eq!(found[0].category, TypeCategory::Collection);
    }

    #[test]
    fn test_conditions() {
        let long_vec_rule = RuleBuilder::new()
            .id("long_vec_rule")
            .pattern(r"Vec<")
            .category(TypeCategory::Collection)
            .condition(Condition::MinLength(10))
            .build()
            .unwrap();
        let engine = engine_with(vec![long_vec_rule]);

        // Long enough to satisfy the MinLength(10) condition.
        assert_eq!(engine.find_matches("Vec<SomeLongType>").len(), 1);

        // The regex matches, but the name is shorter than ten bytes.
        assert!(engine.find_matches("Vec<i32>").is_empty());
    }

    #[test]
    fn test_priority_ordering() {
        let high_priority_rule = RuleBuilder::new()
            .id("high_priority")
            .pattern(r"Vec")
            .category(TypeCategory::Collection)
            .priority(1)
            .build()
            .unwrap();

        let low_priority_rule = RuleBuilder::new()
            .id("low_priority")
            .pattern(r"Vec")
            .category(TypeCategory::UserDefined)
            .priority(5)
            .build()
            .unwrap();

        // Register the low-priority rule first so ordering cannot come from insertion order.
        let engine = engine_with(vec![low_priority_rule, high_priority_rule]);

        let found = engine.find_matches("Vec<i32>");
        assert_eq!(found.len(), 2);
        // Priority 1 ranks ahead of priority 5.
        assert_eq!(found[0].category, TypeCategory::Collection);
    }

    #[test]
    fn test_rule_management() {
        let rule = RuleBuilder::new()
            .id("test_rule")
            .pattern(r"test")
            .category(TypeCategory::UserDefined)
            .build()
            .unwrap();
        let mut engine = engine_with(vec![rule]);
        assert_eq!(engine.get_rule_ids().len(), 1);

        // A disabled rule must no longer produce matches.
        engine.set_rule_enabled("test_rule", false);
        assert!(engine.find_matches("test").is_empty());

        // Removing the rule reports success and empties the engine.
        assert!(engine.remove_rule("test_rule"));
        assert_eq!(engine.get_rule_ids().len(), 0);
    }
}