Skip to main content

matrixcode_core/workflow/
rule_engine.rs

1//! Rule Engine for Workflow Validation
2//!
3//! 验证规则引擎,支持条件表达式解析和验证。
4
5use anyhow::{Context, Result};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// 验证规则
11#[derive(Debug, Clone, Serialize, Deserialize)]
12#[serde(tag = "type", rename_all = "snake_case")]
13pub enum Rule {
14    /// 包含检查
15    Contains {
16        field: String,
17        value: String,
18        #[serde(default)]
19        case_sensitive: bool,
20    },
21    /// 长度大于等于
22    LengthGte {
23        field: String,
24        min: usize,
25    },
26    /// 长度小于等于
27    LengthLte {
28        field: String,
29        max: usize,
30    },
31    /// 正则匹配
32    Matches {
33        field: String,
34        pattern: String,
35    },
36    /// 相等
37    Equals {
38        field: String,
39        value: serde_json::Value,
40    },
41    /// 不相等
42    NotEquals {
43        field: String,
44        value: serde_json::Value,
45    },
46    /// 大于
47    GreaterThan {
48        field: String,
49        value: f64,
50    },
51    /// 小于
52    LessThan {
53        field: String,
54        value: f64,
55    },
56    /// 存在
57    Exists {
58        field: String,
59    },
60    /// 不存在
61    NotExists {
62        field: String,
63    },
64    /// 组合规则:全部满足
65    All {
66        rules: Vec<Rule>,
67    },
68    /// 组合规则:任一满足
69    Any {
70        rules: Vec<Rule>,
71    },
72    /// 取反
73    Not {
74        rule: Box<Rule>,
75    },
76}
77
/// Outcome of evaluating one or more validation rules.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// `true` when the evaluated rule(s) were satisfied.
    pub passed: bool,
    /// Human-readable failure messages collected during validation.
    pub errors: Vec<String>,
}
86
87impl ValidationResult {
88    pub fn success() -> Self {
89        Self {
90            passed: true,
91            errors: Vec::new(),
92        }
93    }
94
95    pub fn failure(error: String) -> Self {
96        Self {
97            passed: false,
98            errors: vec![error],
99        }
100    }
101
102    pub fn merge(mut self, other: Self) -> Self {
103        self.passed = self.passed && other.passed;
104        self.errors.extend(other.errors);
105        self
106    }
107}
108
109/// 规则引擎
110#[derive(Debug, Default)]
111pub struct RuleEngine {
112    /// 缓存的正则表达式
113    regex_cache: HashMap<String, Regex>,
114}
115
116impl RuleEngine {
117    pub fn new() -> Self {
118        Self {
119            regex_cache: HashMap::new(),
120        }
121    }
122
123    /// 验证单个规则
124    pub fn validate(&mut self, rule: &Rule, context: &HashMap<String, serde_json::Value>) -> Result<ValidationResult> {
125        match rule {
126            Rule::Contains { field, value, case_sensitive } => {
127                self.validate_contains(context, field, value, *case_sensitive)
128            }
129            Rule::LengthGte { field, min } => {
130                self.validate_length_gte(context, field, *min)
131            }
132            Rule::LengthLte { field, max } => {
133                self.validate_length_lte(context, field, *max)
134            }
135            Rule::Matches { field, pattern } => {
136                self.validate_matches(context, field, pattern)
137            }
138            Rule::Equals { field, value } => {
139                self.validate_equals(context, field, value)
140            }
141            Rule::NotEquals { field, value } => {
142                self.validate_not_equals(context, field, value)
143            }
144            Rule::GreaterThan { field, value } => {
145                self.validate_greater_than(context, field, *value)
146            }
147            Rule::LessThan { field, value } => {
148                self.validate_less_than(context, field, *value)
149            }
150            Rule::Exists { field } => {
151                self.validate_exists(context, field)
152            }
153            Rule::NotExists { field } => {
154                self.validate_not_exists(context, field)
155            }
156            Rule::All { rules } => {
157                self.validate_all(rules, context)
158            }
159            Rule::Any { rules } => {
160                self.validate_any(rules, context)
161            }
162            Rule::Not { rule } => {
163                let result = self.validate(rule, context)?;
164                if result.passed {
165                    Ok(ValidationResult::failure("Condition should not be met".to_string()))
166                } else {
167                    Ok(ValidationResult::success())
168                }
169            }
170        }
171    }
172
173    fn validate_contains(
174        &self,
175        context: &HashMap<String, serde_json::Value>,
176        field: &str,
177        value: &str,
178        case_sensitive: bool,
179    ) -> Result<ValidationResult> {
180        match context.get(field) {
181            Some(serde_json::Value::String(s)) => {
182                let contains = if case_sensitive {
183                    s.contains(value)
184                } else {
185                    s.to_lowercase().contains(&value.to_lowercase())
186                };
187                if contains {
188                    Ok(ValidationResult::success())
189                } else {
190                    Ok(ValidationResult::failure(format!(
191                        "Field '{}' does not contain '{}'",
192                        field, value
193                    )))
194                }
195            }
196            Some(_) => Ok(ValidationResult::failure(format!(
197                "Field '{}' is not a string",
198                field
199            ))),
200            None => Ok(ValidationResult::failure(format!(
201                "Field '{}' not found",
202                field
203            ))),
204        }
205    }
206
207    fn validate_length_gte(
208        &self,
209        context: &HashMap<String, serde_json::Value>,
210        field: &str,
211        min: usize,
212    ) -> Result<ValidationResult> {
213        match context.get(field) {
214            Some(serde_json::Value::String(s)) => {
215                if s.len() >= min {
216                    Ok(ValidationResult::success())
217                } else {
218                    Ok(ValidationResult::failure(format!(
219                        "Field '{}' length {} is less than {}",
220                        field,
221                        s.len(),
222                        min
223                    )))
224                }
225            }
226            Some(serde_json::Value::Array(arr)) => {
227                if arr.len() >= min {
228                    Ok(ValidationResult::success())
229                } else {
230                    Ok(ValidationResult::failure(format!(
231                        "Field '{}' array length {} is less than {}",
232                        field,
233                        arr.len(),
234                        min
235                    )))
236                }
237            }
238            Some(_) => Ok(ValidationResult::failure(format!(
239                "Field '{}' is not a string or array",
240                field
241            ))),
242            None => Ok(ValidationResult::failure(format!(
243                "Field '{}' not found",
244                field
245            ))),
246        }
247    }
248
249    fn validate_length_lte(
250        &self,
251        context: &HashMap<String, serde_json::Value>,
252        field: &str,
253        max: usize,
254    ) -> Result<ValidationResult> {
255        match context.get(field) {
256            Some(serde_json::Value::String(s)) => {
257                if s.len() <= max {
258                    Ok(ValidationResult::success())
259                } else {
260                    Ok(ValidationResult::failure(format!(
261                        "Field '{}' length {} is greater than {}",
262                        field,
263                        s.len(),
264                        max
265                    )))
266                }
267            }
268            Some(serde_json::Value::Array(arr)) => {
269                if arr.len() <= max {
270                    Ok(ValidationResult::success())
271                } else {
272                    Ok(ValidationResult::failure(format!(
273                        "Field '{}' array length {} is greater than {}",
274                        field,
275                        arr.len(),
276                        max
277                    )))
278                }
279            }
280            Some(_) => Ok(ValidationResult::failure(format!(
281                "Field '{}' is not a string or array",
282                field
283            ))),
284            None => Ok(ValidationResult::failure(format!(
285                "Field '{}' not found",
286                field
287            ))),
288        }
289    }
290
291    fn validate_matches(
292        &mut self,
293        context: &HashMap<String, serde_json::Value>,
294        field: &str,
295        pattern: &str,
296    ) -> Result<ValidationResult> {
297        let regex = self.regex_cache
298            .entry(pattern.to_string())
299            .or_insert_with(|| {
300                Regex::new(pattern).unwrap_or_else(|_| Regex::new("^(?:)$").unwrap())
301            });
302
303        match context.get(field) {
304            Some(serde_json::Value::String(s)) => {
305                if regex.is_match(s) {
306                    Ok(ValidationResult::success())
307                } else {
308                    Ok(ValidationResult::failure(format!(
309                        "Field '{}' does not match pattern '{}'",
310                        field, pattern
311                    )))
312                }
313            }
314            Some(_) => Ok(ValidationResult::failure(format!(
315                "Field '{}' is not a string",
316                field
317            ))),
318            None => Ok(ValidationResult::failure(format!(
319                "Field '{}' not found",
320                field
321            ))),
322        }
323    }
324
325    fn validate_equals(
326        &self,
327        context: &HashMap<String, serde_json::Value>,
328        field: &str,
329        value: &serde_json::Value,
330    ) -> Result<ValidationResult> {
331        match context.get(field) {
332            Some(v) => {
333                if v == value {
334                    Ok(ValidationResult::success())
335                } else {
336                    Ok(ValidationResult::failure(format!(
337                        "Field '{}' value {:?} does not equal {:?}",
338                        field, v, value
339                    )))
340                }
341            }
342            None => Ok(ValidationResult::failure(format!(
343                "Field '{}' not found",
344                field
345            ))),
346        }
347    }
348
349    fn validate_not_equals(
350        &self,
351        context: &HashMap<String, serde_json::Value>,
352        field: &str,
353        value: &serde_json::Value,
354    ) -> Result<ValidationResult> {
355        match context.get(field) {
356            Some(v) => {
357                if v != value {
358                    Ok(ValidationResult::success())
359                } else {
360                    Ok(ValidationResult::failure(format!(
361                        "Field '{}' value {:?} equals {:?}",
362                        field, v, value
363                    )))
364                }
365            }
366            None => Ok(ValidationResult::success()),
367        }
368    }
369
370    fn validate_greater_than(
371        &self,
372        context: &HashMap<String, serde_json::Value>,
373        field: &str,
374        value: f64,
375    ) -> Result<ValidationResult> {
376        match context.get(field) {
377            Some(serde_json::Value::Number(n)) => {
378                if let Some(f) = n.as_f64() {
379                    if f > value {
380                        Ok(ValidationResult::success())
381                    } else {
382                        Ok(ValidationResult::failure(format!(
383                            "Field '{}' value {} is not greater than {}",
384                            field, f, value
385                        )))
386                    }
387                } else {
388                    Ok(ValidationResult::failure(format!(
389                        "Field '{}' is not a valid number",
390                        field
391                    )))
392                }
393            }
394            Some(_) => Ok(ValidationResult::failure(format!(
395                "Field '{}' is not a number",
396                field
397            ))),
398            None => Ok(ValidationResult::failure(format!(
399                "Field '{}' not found",
400                field
401            ))),
402        }
403    }
404
405    fn validate_less_than(
406        &self,
407        context: &HashMap<String, serde_json::Value>,
408        field: &str,
409        value: f64,
410    ) -> Result<ValidationResult> {
411        match context.get(field) {
412            Some(serde_json::Value::Number(n)) => {
413                if let Some(f) = n.as_f64() {
414                    if f < value {
415                        Ok(ValidationResult::success())
416                    } else {
417                        Ok(ValidationResult::failure(format!(
418                            "Field '{}' value {} is not less than {}",
419                            field, f, value
420                        )))
421                    }
422                } else {
423                    Ok(ValidationResult::failure(format!(
424                        "Field '{}' is not a valid number",
425                        field
426                    )))
427                }
428            }
429            Some(_) => Ok(ValidationResult::failure(format!(
430                "Field '{}' is not a number",
431                field
432            ))),
433            None => Ok(ValidationResult::failure(format!(
434                "Field '{}' not found",
435                field
436            ))),
437        }
438    }
439
440    fn validate_exists(
441        &self,
442        context: &HashMap<String, serde_json::Value>,
443        field: &str,
444    ) -> Result<ValidationResult> {
445        if context.contains_key(field) {
446            Ok(ValidationResult::success())
447        } else {
448            Ok(ValidationResult::failure(format!(
449                "Field '{}' not found",
450                field
451            )))
452        }
453    }
454
455    fn validate_not_exists(
456        &self,
457        context: &HashMap<String, serde_json::Value>,
458        field: &str,
459    ) -> Result<ValidationResult> {
460        if !context.contains_key(field) {
461            Ok(ValidationResult::success())
462        } else {
463            Ok(ValidationResult::failure(format!(
464                "Field '{}' exists",
465                field
466            )))
467        }
468    }
469
470    fn validate_all(
471        &mut self,
472        rules: &[Rule],
473        context: &HashMap<String, serde_json::Value>,
474    ) -> Result<ValidationResult> {
475        let mut result = ValidationResult::success();
476        for rule in rules {
477            result = result.merge(self.validate(rule, context)?);
478        }
479        Ok(result)
480    }
481
482    fn validate_any(
483        &mut self,
484        rules: &[Rule],
485        context: &HashMap<String, serde_json::Value>,
486    ) -> Result<ValidationResult> {
487        let mut errors = Vec::new();
488        for rule in rules {
489            let result = self.validate(rule, context)?;
490            if result.passed {
491                return Ok(ValidationResult::success());
492            }
493            errors.extend(result.errors);
494        }
495        Ok(ValidationResult::failure(format!(
496            "None of the conditions met: {}",
497            errors.join("; ")
498        )))
499    }
500}
501
502/// 表达式求值(简单比较和逻辑运算)
503pub fn evaluate_expression(expr: &str, context: &HashMap<String, serde_json::Value>) -> Result<bool> {
504    let expr = expr.trim();
505
506    // 处理 AND 逻辑
507    if expr.contains(" && ") {
508        let parts: Vec<&str> = expr.split(" && ").collect();
509        for part in parts {
510            if !evaluate_expression(part, context)? {
511                return Ok(false);
512            }
513        }
514        return Ok(true);
515    }
516
517    // 处理 OR 逻辑
518    if expr.contains(" || ") {
519        let parts: Vec<&str> = expr.split(" || ").collect();
520        for part in parts {
521            if evaluate_expression(part, context)? {
522                return Ok(true);
523            }
524        }
525        return Ok(false);
526    }
527
528    // 处理比较运算
529    // 相等检查
530    if let Some(eq_pos) = expr.find("==") {
531        let left = expr[..eq_pos].trim();
532        let right = expr[eq_pos + 2..].trim();
533        return evaluate_comparison(left, right, context, true);
534    }
535
536    // 不等检查
537    if let Some(ne_pos) = expr.find("!=") {
538        let left = expr[..ne_pos].trim();
539        let right = expr[ne_pos + 2..].trim();
540        return evaluate_comparison(left, right, context, false);
541    }
542
543    // 大于等于
544    if let Some(ge_pos) = expr.find(">=") {
545        let left = expr[..ge_pos].trim();
546        let right = expr[ge_pos + 2..].trim();
547        return evaluate_numeric_comparison(left, right, context, ">=");
548    }
549
550    // 小于等于
551    if let Some(le_pos) = expr.find("<=") {
552        let left = expr[..le_pos].trim();
553        let right = expr[le_pos + 2..].trim();
554        return evaluate_numeric_comparison(left, right, context, "<=");
555    }
556
557    // 大于
558    if let Some(gt_pos) = expr.find('>') {
559        let left = expr[..gt_pos].trim();
560        let right = expr[gt_pos + 1..].trim();
561        return evaluate_numeric_comparison(left, right, context, ">");
562    }
563
564    // 小于
565    if let Some(lt_pos) = expr.find('<') {
566        let left = expr[..lt_pos].trim();
567        let right = expr[lt_pos + 1..].trim();
568        return evaluate_numeric_comparison(left, right, context, "<");
569    }
570
571    // 布尔值检查
572    match expr {
573        "true" => Ok(true),
574        "false" => Ok(false),
575        _ => {
576            // 检查变量是否存在且为真
577            if let Some(value) = context.get(expr) {
578                Ok(value.as_bool().unwrap_or(false))
579            } else {
580                Ok(false)
581            }
582        }
583    }
584}
585
586fn evaluate_comparison(
587    left: &str,
588    right: &str,
589    context: &HashMap<String, serde_json::Value>,
590    equals: bool,
591) -> Result<bool> {
592    let left_val = resolve_value(left, context)?;
593    let right_val = resolve_value(right, context)?;
594
595    let result = left_val == right_val;
596    Ok(if equals { result } else { !result })
597}
598
599fn evaluate_numeric_comparison(
600    left: &str,
601    right: &str,
602    context: &HashMap<String, serde_json::Value>,
603    op: &str,
604) -> Result<bool> {
605    let left_val = resolve_numeric(left, context)
606        .with_context(|| format!("Failed to resolve left operand: {}", left))?;
607    let right_val = resolve_numeric(right, context)
608        .with_context(|| format!("Failed to resolve right operand: {}", right))?;
609
610    let result = match op {
611        ">" => left_val > right_val,
612        "<" => left_val < right_val,
613        ">=" => left_val >= right_val,
614        "<=" => left_val <= right_val,
615        _ => false,
616    };
617
618    Ok(result)
619}
620
621fn resolve_value(expr: &str, context: &HashMap<String, serde_json::Value>) -> Result<serde_json::Value> {
622    // 字符串字面量
623    if expr.starts_with('"') && expr.ends_with('"') {
624        return Ok(serde_json::Value::String(expr[1..expr.len()-1].to_string()));
625    }
626
627    // 数字字面量
628    if let Ok(n) = expr.parse::<i64>() {
629        return Ok(serde_json::Value::Number(n.into()));
630    }
631    if let Ok(n) = expr.parse::<f64>()
632        && let Some(num) = serde_json::Number::from_f64(n) {
633            return Ok(serde_json::Value::Number(num));
634        }
635
636    // 布尔字面量
637    if expr == "true" {
638        return Ok(serde_json::Value::Bool(true));
639    }
640    if expr == "false" {
641        return Ok(serde_json::Value::Bool(false));
642    }
643
644    // 空值
645    if expr == "null" {
646        return Ok(serde_json::Value::Null);
647    }
648
649    // 变量引用
650    if let Some(value) = context.get(expr) {
651        return Ok(value.clone());
652    }
653
654    anyhow::bail!("Unknown value: {}", expr)
655}
656
657fn resolve_numeric(expr: &str, context: &HashMap<String, serde_json::Value>) -> Result<f64> {
658    // 数字字面量
659    if let Ok(n) = expr.parse::<f64>() {
660        return Ok(n);
661    }
662
663    // 变量引用
664    if let Some(value) = context.get(expr)
665        && let Some(n) = value.as_f64() {
666            return Ok(n);
667        }
668
669    anyhow::bail!("Not a numeric value: {}", expr)
670}
671
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an evaluation context from `(name, value)` pairs.
    fn context_of(pairs: &[(&str, serde_json::Value)]) -> HashMap<String, serde_json::Value> {
        pairs
            .iter()
            .map(|(name, value)| (name.to_string(), value.clone()))
            .collect()
    }

    /// Validates `rule` with a fresh engine and reports whether it passed.
    fn passes(rule: &Rule, context: &HashMap<String, serde_json::Value>) -> bool {
        RuleEngine::new().validate(rule, context).unwrap().passed
    }

    #[test]
    fn test_rule_contains() {
        let context = context_of(&[("text", json!("Hello, World!"))]);
        let rule = Rule::Contains {
            field: "text".to_string(),
            value: "World".to_string(),
            case_sensitive: true,
        };
        assert!(passes(&rule, &context));
    }

    #[test]
    fn test_rule_length_gte() {
        let context = context_of(&[("name", json!("Alice"))]);
        let rule = Rule::LengthGte {
            field: "name".to_string(),
            min: 3,
        };
        assert!(passes(&rule, &context));
    }

    #[test]
    fn test_rule_matches() {
        let context = context_of(&[("email", json!("test@example.com"))]);
        let rule = Rule::Matches {
            field: "email".to_string(),
            pattern: r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$".to_string(),
        };
        assert!(passes(&rule, &context));
    }

    #[test]
    fn test_rule_all() {
        let context = context_of(&[("name", json!("Alice")), ("age", json!(25))]);
        let rule = Rule::All {
            rules: vec![
                Rule::LengthGte {
                    field: "name".to_string(),
                    min: 3,
                },
                Rule::GreaterThan {
                    field: "age".to_string(),
                    value: 18.0,
                },
            ],
        };
        assert!(passes(&rule, &context));
    }

    #[test]
    fn test_evaluate_expression() {
        let context = context_of(&[("count", json!(10)), ("enabled", json!(true))]);
        let holds = |expr: &str| evaluate_expression(expr, &context).unwrap();

        assert!(holds("count == 10"));
        assert!(holds("count > 5"));
        assert!(holds("count < 20 && enabled == true"));
        assert!(holds("count < 5 || enabled == true"));
    }
}