dataforge/rules/
mod.rs

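//! Rule engine module.
//!
//! Provides a multi-level rule inheritance system, cached regular-expression compilation,
//! and a hook mechanism for custom rule handlers.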
4
5use crate::error::{DataForgeError, Result};
6use crate::memory::StringPool;
7use rand::distributions::Distribution;
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use serde_json::Value;
11use std::collections::HashMap;
12use std::sync::{Arc, RwLock};
13
14
/// The kind of a rule, together with the parameters specific to that kind.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RuleType {
    /// Regular-expression rule: `pattern` is the regex source, `flags` are optional flags.
    Regex { pattern: String, flags: Option<String> },
    /// Range rule bounded by `min` and `max`.
    Range { min: Value, max: Value },
    /// Enumeration rule over a fixed list of candidate values.
    Enum { values: Vec<Value> },
    /// Length rule with optional lower and upper bounds.
    Length { min: Option<usize>, max: Option<usize> },
    /// Format rule driven by a format string.
    Format { format: String },
    /// Custom rule identified by `name`, carrying free-form parameters.
    Custom { name: String, params: HashMap<String, Value> },
    /// Composite rule combining nested rules with a logical operator.
    Composite { operator: LogicalOperator, rules: Vec<Rule> },
}
33
34/// 逻辑操作符
35#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
36pub enum LogicalOperator {
37    And,
38    Or,
39    Not,
40}
41
/// A rule definition.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct Rule {
    /// Rule ID; used as the key under which the engine stores the rule.
    pub id: String,
    /// Rule name.
    pub name: String,
    /// Rule type and its type-specific parameters.
    pub rule_type: RuleType,
    /// Priority.
    pub priority: u32,
    /// Whether the rule is enabled.
    pub enabled: bool,
    /// Optional description.
    pub description: Option<String>,
    /// Tags attached to the rule.
    pub tags: Vec<String>,
    /// ID of the parent rule (used for inheritance), if any.
    pub parent_id: Option<String>,
}
62
/// Context passed to a rule when it is executed.
#[derive(Debug, Clone)]
pub struct RuleContext {
    /// Name of the field the rule is being executed for.
    pub field_name: String,
    /// Current value, if any.
    pub current_value: Option<Value>,
    /// Additional parameters.
    pub params: HashMap<String, Value>,
    /// History of generated values.
    pub generation_history: Vec<Value>,
}
75
/// Outcome of executing a rule.
#[derive(Debug, Clone)]
pub struct RuleResult {
    /// Whether the rule matched.
    pub matched: bool,
    /// The generated value, if any.
    pub value: Option<Value>,
    /// Error message, if any.
    pub error: Option<String>,
    /// Time spent executing the rule.
    pub execution_time: std::time::Duration,
}
88
/// Handler for custom rules.
///
/// Implementations must be `Send + Sync`.
pub trait CustomRuleHandler: Send + Sync {
    /// Handles a custom rule for the given context and returns its result.
    fn handle(&self, rule: &Rule, context: &RuleContext) -> Result<RuleResult>;

    /// Returns the handler's name.
    fn name(&self) -> &str;
}
97
/// The rule engine.
///
/// Each piece of state sits behind its own `RwLock`, so all methods take `&self`.
pub struct RuleEngine {
    /// Rule storage, keyed by rule ID.
    rules: RwLock<HashMap<String, Rule>>,
    /// Cache of compiled regular expressions, keyed by pattern source.
    regex_cache: RwLock<HashMap<String, Regex>>,
    /// Custom rule handlers, keyed by handler name.
    custom_handlers: RwLock<HashMap<String, Arc<dyn CustomRuleHandler>>>,
    /// String pool; stored on the engine but not read by any code visible in this file,
    /// hence the `dead_code` allowance.
    #[allow(dead_code)]
    string_pool: Arc<StringPool>,
    /// Rule inheritance relationships: parent rule ID mapped to the IDs of its child rules.
    inheritance_tree: RwLock<HashMap<String, Vec<String>>>,
}
112
113impl RuleEngine {
114    /// 创建新的规则引擎
115    pub fn new(string_pool: Arc<StringPool>) -> Self {
116        Self {
117            rules: RwLock::new(HashMap::new()),
118            regex_cache: RwLock::new(HashMap::new()),
119            custom_handlers: RwLock::new(HashMap::new()),
120            string_pool,
121            inheritance_tree: RwLock::new(HashMap::new()),
122        }
123    }
124
125    /// 添加规则
126    pub fn add_rule(&self, rule: Rule) -> Result<()> {
127        let rule_id = rule.id.clone();
128        
129        // 验证规则
130        self.validate_rule(&rule)?;
131        
132        // 更新继承关系
133        if let Some(parent_id) = &rule.parent_id {
134            let mut tree = self.inheritance_tree.write().unwrap();
135            tree.entry(parent_id.clone()).or_insert_with(Vec::new).push(rule_id.clone());
136        }
137        
138        // 存储规则
139        let mut rules = self.rules.write().unwrap();
140        rules.insert(rule_id, rule);
141        
142        Ok(())
143    }
144
145    /// 删除规则
146    pub fn remove_rule(&self, rule_id: &str) -> Result<()> {
147        let mut rules = self.rules.write().unwrap();
148        
149        if let Some(rule) = rules.remove(rule_id) {
150            // 清理继承关系
151            if let Some(parent_id) = &rule.parent_id {
152                let mut tree = self.inheritance_tree.write().unwrap();
153                if let Some(children) = tree.get_mut(parent_id) {
154                    children.retain(|id| id != rule_id);
155                }
156            }
157            
158            // 清理正则表达式缓存
159            if let RuleType::Regex { pattern, .. } = &rule.rule_type {
160                let mut cache = self.regex_cache.write().unwrap();
161                cache.remove(pattern);
162            }
163            
164            Ok(())
165        } else {
166            Err(DataForgeError::validation(&format!("Rule not found: {}", rule_id)))
167        }
168    }
169
170    /// 执行规则
171    pub fn execute_rule(&self, rule_id: &str, context: &RuleContext) -> Result<RuleResult> {
172        let start_time = std::time::Instant::now();
173        
174        let rule = {
175            let rules = self.rules.read().unwrap();
176            rules.get(rule_id)
177                .ok_or_else(|| DataForgeError::validation(&format!("Rule not found: {}", rule_id)))?
178                .clone()
179        };
180
181        if !rule.enabled {
182            return Ok(RuleResult {
183                matched: false,
184                value: None,
185                error: Some("Rule is disabled".to_string()),
186                execution_time: start_time.elapsed(),
187            });
188        }
189
190        let result = self.execute_rule_internal(&rule, context);
191        
192        Ok(RuleResult {
193            matched: result.is_ok(),
194            value: result.as_ref().ok().cloned(),
195            error: result.as_ref().err().map(|e| e.to_string()),
196            execution_time: start_time.elapsed(),
197        })
198    }
199
200    /// 内部规则执行逻辑
201    fn execute_rule_internal(&self, rule: &Rule, context: &RuleContext) -> Result<Value> {
202        match &rule.rule_type {
203            RuleType::Regex { pattern, flags } => {
204                self.execute_regex_rule(pattern, flags.as_deref(), context)
205            }
206            RuleType::Range { min, max } => {
207                self.execute_range_rule(min, max, context)
208            }
209            RuleType::Enum { values } => {
210                self.execute_enum_rule(values, context)
211            }
212            RuleType::Length { min, max } => {
213                self.execute_length_rule(*min, *max, context)
214            }
215            RuleType::Format { format } => {
216                self.execute_format_rule(format, context)
217            }
218            RuleType::Custom { name, params } => {
219                self.execute_custom_rule(name, params, rule, context)
220            }
221            RuleType::Composite { operator, rules } => {
222                self.execute_composite_rule(operator, rules, context)
223            }
224        }
225    }
226
227    /// 执行正则表达式规则
228    fn execute_regex_rule(&self, pattern: &str, _flags: Option<&str>, _context: &RuleContext) -> Result<Value> {
229        let _regex = self.get_or_compile_regex(pattern)?;
230        
231        // 使用rand_regex生成匹配的字符串
232        use rand_regex::Regex as RandRegex;
233        let rand_regex = RandRegex::compile(pattern, 100)
234            .map_err(|e| DataForgeError::generator(&format!("Failed to compile regex for generation: {}", e)))?;
235        
236        let mut rng = rand::thread_rng();
237        let generated = rand_regex.sample(&mut rng);
238        
239        Ok(Value::String(generated))
240    }
241
242    /// 执行范围规则
243    fn execute_range_rule(&self, min: &Value, max: &Value, _context: &RuleContext) -> Result<Value> {
244        use rand::Rng;
245        let mut rng = rand::thread_rng();
246        
247        match (min, max) {
248            (Value::Number(min_num), Value::Number(max_num)) => {
249                if let (Some(min_f), Some(max_f)) = (min_num.as_f64(), max_num.as_f64()) {
250                    let value = rng.gen_range(min_f..=max_f);
251                    Ok(Value::Number(serde_json::Number::from_f64(value).unwrap()))
252                } else {
253                    Err(DataForgeError::validation("Invalid number range"))
254                }
255            }
256            _ => Err(DataForgeError::validation("Range rule requires numeric min and max values")),
257        }
258    }
259
260    /// 执行枚举规则
261    fn execute_enum_rule(&self, values: &[Value], _context: &RuleContext) -> Result<Value> {
262        use rand::seq::SliceRandom;
263        let mut rng = rand::thread_rng();
264        
265        values.choose(&mut rng)
266            .cloned()
267            .ok_or_else(|| DataForgeError::validation("Empty enum values"))
268    }
269
270    /// 执行长度规则
271    fn execute_length_rule(&self, min: Option<usize>, max: Option<usize>, _context: &RuleContext) -> Result<Value> {
272        use rand::Rng;
273        let mut rng = rand::thread_rng();
274        
275        let min_len = min.unwrap_or(1);
276        let max_len = max.unwrap_or(20);
277        let length = rng.gen_range(min_len..=max_len);
278        
279        let chars: String = (0..length)
280            .map(|_| rng.gen_range(b'a'..=b'z') as char)
281            .collect();
282        
283        Ok(Value::String(chars))
284    }
285
286    /// 执行格式规则
287    fn execute_format_rule(&self, format: &str, context: &RuleContext) -> Result<Value> {
288        // 简单的格式替换实现
289        let mut result = format.to_string();
290        
291        // 替换常见的占位符
292        result = result.replace("{field_name}", &context.field_name);
293        result = result.replace("{random_number}", &rand::random::<u32>().to_string());
294        result = result.replace("{timestamp}", &chrono::Utc::now().timestamp().to_string());
295        
296        Ok(Value::String(result))
297    }
298
299    /// 执行自定义规则
300    fn execute_custom_rule(&self, name: &str, _params: &HashMap<String, Value>, rule: &Rule, context: &RuleContext) -> Result<Value> {
301        let handlers = self.custom_handlers.read().unwrap();
302        
303        if let Some(handler) = handlers.get(name) {
304            let result = handler.handle(rule, context)?;
305            result.value.ok_or_else(|| DataForgeError::generator("Custom rule handler returned no value"))
306        } else {
307            Err(DataForgeError::validation(&format!("Custom rule handler not found: {}", name)))
308        }
309    }
310
311    /// 执行组合规则
312    fn execute_composite_rule(&self, operator: &LogicalOperator, rules: &[Rule], context: &RuleContext) -> Result<Value> {
313        match operator {
314            LogicalOperator::And => {
315                // 所有规则都必须成功,返回最后一个规则的结果
316                let mut last_value = Value::Null;
317                for rule in rules {
318                    last_value = self.execute_rule_internal(rule, context)?;
319                }
320                Ok(last_value)
321            }
322            LogicalOperator::Or => {
323                // 任意一个规则成功即可
324                for rule in rules {
325                    if let Ok(value) = self.execute_rule_internal(rule, context) {
326                        return Ok(value);
327                    }
328                }
329                Err(DataForgeError::generator("No rule in OR composite succeeded"))
330            }
331            LogicalOperator::Not => {
332                // 规则不应该匹配,这里简单返回null
333                Ok(Value::Null)
334            }
335        }
336    }
337
338    /// 获取或编译正则表达式
339    fn get_or_compile_regex(&self, pattern: &str) -> Result<Regex> {
340        // 首先尝试从缓存获取
341        {
342            let cache = self.regex_cache.read().unwrap();
343            if let Some(regex) = cache.get(pattern) {
344                return Ok(regex.clone());
345            }
346        }
347
348        // 编译新的正则表达式
349        let regex = Regex::new(pattern)
350            .map_err(|e| DataForgeError::validation(&format!("Invalid regex pattern: {}", e)))?;
351
352        // 存入缓存
353        {
354            let mut cache = self.regex_cache.write().unwrap();
355            cache.insert(pattern.to_string(), regex.clone());
356        }
357
358        Ok(regex)
359    }
360
361    /// 验证规则
362    fn validate_rule(&self, rule: &Rule) -> Result<()> {
363        // 检查规则ID是否已存在
364        {
365            let rules = self.rules.read().unwrap();
366            if rules.contains_key(&rule.id) {
367                return Err(DataForgeError::validation(&format!("Rule ID already exists: {}", rule.id)));
368            }
369        }
370
371        // 验证父规则是否存在
372        if let Some(parent_id) = &rule.parent_id {
373            let rules = self.rules.read().unwrap();
374            if !rules.contains_key(parent_id) {
375                return Err(DataForgeError::validation(&format!("Parent rule not found: {}", parent_id)));
376            }
377        }
378
379        // 验证规则类型特定的内容
380        match &rule.rule_type {
381            RuleType::Regex { pattern, .. } => {
382                Regex::new(pattern)
383                    .map_err(|e| DataForgeError::validation(&format!("Invalid regex pattern: {}", e)))?;
384            }
385            RuleType::Range { min, max } => {
386                if !min.is_number() || !max.is_number() {
387                    return Err(DataForgeError::validation("Range rule requires numeric min and max values"));
388                }
389            }
390            RuleType::Enum { values } => {
391                if values.is_empty() {
392                    return Err(DataForgeError::validation("Enum rule requires at least one value"));
393                }
394            }
395            _ => {} // 其他类型暂不验证
396        }
397
398        Ok(())
399    }
400
401    /// 注册自定义规则处理器
402    pub fn register_custom_handler(&self, handler: Arc<dyn CustomRuleHandler>) {
403        let mut handlers = self.custom_handlers.write().unwrap();
404        handlers.insert(handler.name().to_string(), handler);
405    }
406
407    /// 获取所有规则
408    pub fn get_all_rules(&self) -> Vec<Rule> {
409        let rules = self.rules.read().unwrap();
410        rules.values().cloned().collect()
411    }
412
413    /// 根据标签查找规则
414    pub fn find_rules_by_tag(&self, tag: &str) -> Vec<Rule> {
415        let rules = self.rules.read().unwrap();
416        rules.values()
417            .filter(|rule| rule.tags.contains(&tag.to_string()))
418            .cloned()
419            .collect()
420    }
421
422    /// 获取规则继承链
423    pub fn get_inheritance_chain(&self, rule_id: &str) -> Vec<String> {
424        let mut chain = Vec::new();
425        let mut current_id = rule_id.to_string();
426        
427        let rules = self.rules.read().unwrap();
428        
429        while let Some(rule) = rules.get(&current_id) {
430            chain.push(current_id.clone());
431            if let Some(parent_id) = &rule.parent_id {
432                current_id = parent_id.clone();
433            } else {
434                break;
435            }
436        }
437        
438        chain.reverse();
439        chain
440    }
441
442    /// 清理缓存
443    pub fn clear_cache(&self) {
444        let mut cache = self.regex_cache.write().unwrap();
445        cache.clear();
446    }
447}
448
449impl Default for RuleEngine {
450    fn default() -> Self {
451        Self::new(Arc::new(StringPool::default()))
452    }
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an enabled rule with priority 100 and no description, tags or parent.
    /// Individual tests override fields with struct-update syntax.
    fn base_rule(id: &str, name: &str, rule_type: RuleType) -> Rule {
        Rule {
            id: id.to_string(),
            name: name.to_string(),
            rule_type,
            priority: 100,
            enabled: true,
            description: None,
            tags: vec![],
            parent_id: None,
        }
    }

    /// Builds a context for `field_name` with no current value, parameters or history.
    fn empty_context(field_name: &str) -> RuleContext {
        RuleContext {
            field_name: field_name.to_string(),
            current_value: None,
            params: HashMap::new(),
            generation_history: Vec::new(),
        }
    }

    #[test]
    fn test_rule_engine_creation() {
        let engine = RuleEngine::new(Arc::new(StringPool::default()));

        assert_eq!(engine.get_all_rules().len(), 0);
    }

    #[test]
    fn test_add_rule() {
        let engine = RuleEngine::default();

        let phone_pattern = RuleType::Regex {
            pattern: r"\d{3}-\d{3}-\d{4}".to_string(),
            flags: None,
        };
        let rule = Rule {
            description: Some("Test phone number rule".to_string()),
            tags: vec!["phone".to_string()],
            ..base_rule("test_rule", "Test Rule", phone_pattern)
        };

        assert!(engine.add_rule(rule).is_ok());
        assert_eq!(engine.get_all_rules().len(), 1);
    }

    #[test]
    fn test_execute_enum_rule() {
        let engine = RuleEngine::default();

        let choices: Vec<Value> = ["A", "B", "C"]
            .iter()
            .map(|s| Value::String(s.to_string()))
            .collect();
        let rule = base_rule("enum_rule", "Enum Rule", RuleType::Enum { values: choices });

        engine.add_rule(rule).unwrap();

        let context = empty_context("test_field");

        let result = engine.execute_rule("enum_rule", &context).unwrap();
        assert!(result.matched);
        assert!(result.value.is_some());

        // The membership check only applies when the generated value is a string.
        if let Some(Value::String(s)) = result.value {
            assert!(["A", "B", "C"].contains(&s.as_str()));
        }
    }

    #[test]
    fn test_rule_inheritance() {
        let engine = RuleEngine::default();

        // Register the parent rule first; a child can only reference an existing parent.
        let parent_rule = Rule {
            priority: 50,
            ..base_rule(
                "parent_rule",
                "Parent Rule",
                RuleType::Length { min: Some(5), max: Some(10) },
            )
        };
        engine.add_rule(parent_rule).unwrap();

        // Register the child rule pointing at the parent.
        let child_rule = Rule {
            parent_id: Some("parent_rule".to_string()),
            ..base_rule(
                "child_rule",
                "Child Rule",
                RuleType::Length { min: Some(3), max: Some(8) },
            )
        };
        engine.add_rule(child_rule).unwrap();

        let chain = engine.get_inheritance_chain("child_rule");
        assert_eq!(chain, vec!["parent_rule", "child_rule"]);
    }

    #[test]
    fn test_find_rules_by_tag() {
        let engine = RuleEngine::default();

        let rule1 = Rule {
            tags: vec!["test".to_string(), "demo".to_string()],
            ..base_rule("rule1", "Rule 1", RuleType::Length { min: Some(1), max: Some(10) })
        };

        let rule2 = Rule {
            priority: 50,
            tags: vec!["test".to_string()],
            ..base_rule("rule2", "Rule 2", RuleType::Length { min: Some(1), max: Some(5) })
        };

        engine.add_rule(rule1).unwrap();
        engine.add_rule(rule2).unwrap();

        let test_rules = engine.find_rules_by_tag("test");
        assert_eq!(test_rules.len(), 2);

        let demo_rules = engine.find_rules_by_tag("demo");
        assert_eq!(demo_rules.len(), 1);
    }
}