Skip to main content

matrixcode_core/workflow/
rule_engine.rs

1//! Rule Engine for Workflow Validation
2//!
3//! 验证规则引擎,支持条件表达式解析和验证。
4
5use anyhow::{Context, Result};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10/// 验证规则
11#[derive(Debug, Clone, Serialize, Deserialize)]
12#[serde(tag = "type", rename_all = "snake_case")]
13pub enum Rule {
14    /// 包含检查
15    Contains {
16        field: String,
17        value: String,
18        #[serde(default)]
19        case_sensitive: bool,
20    },
21    /// 长度大于等于
22    LengthGte { field: String, min: usize },
23    /// 长度小于等于
24    LengthLte { field: String, max: usize },
25    /// 正则匹配
26    Matches { field: String, pattern: String },
27    /// 相等
28    Equals {
29        field: String,
30        value: serde_json::Value,
31    },
32    /// 不相等
33    NotEquals {
34        field: String,
35        value: serde_json::Value,
36    },
37    /// 大于
38    GreaterThan { field: String, value: f64 },
39    /// 小于
40    LessThan { field: String, value: f64 },
41    /// 存在
42    Exists { field: String },
43    /// 不存在
44    NotExists { field: String },
45    /// 组合规则:全部满足
46    All { rules: Vec<Rule> },
47    /// 组合规则:任一满足
48    Any { rules: Vec<Rule> },
49    /// 取反
50    Not { rule: Box<Rule> },
51}
52
/// Outcome of validating one or more rules.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// Whether validation succeeded.
    pub passed: bool,
    /// Failure messages; `success()` produces an empty list.
    pub errors: Vec<String>,
}

impl ValidationResult {
    /// A passing result with no messages.
    pub fn success() -> Self {
        Self {
            passed: true,
            errors: vec![],
        }
    }

    /// A failing result carrying the single message `error`.
    pub fn failure(error: String) -> Self {
        let errors = vec![error];
        Self {
            passed: false,
            errors,
        }
    }

    /// Combines two results: the merged result passes only if both do, and its
    /// messages are those of `self` followed by those of `other`.
    pub fn merge(self, other: Self) -> Self {
        let Self { passed, mut errors } = self;
        errors.extend(other.errors);
        Self {
            passed: passed && other.passed,
            errors,
        }
    }
}
83
84/// 规则引擎
85#[derive(Debug, Default)]
86pub struct RuleEngine {
87    /// 缓存的正则表达式
88    regex_cache: HashMap<String, Regex>,
89}
90
91impl RuleEngine {
92    pub fn new() -> Self {
93        Self {
94            regex_cache: HashMap::new(),
95        }
96    }
97
98    /// 验证单个规则
99    pub fn validate(
100        &mut self,
101        rule: &Rule,
102        context: &HashMap<String, serde_json::Value>,
103    ) -> Result<ValidationResult> {
104        match rule {
105            Rule::Contains {
106                field,
107                value,
108                case_sensitive,
109            } => self.validate_contains(context, field, value, *case_sensitive),
110            Rule::LengthGte { field, min } => self.validate_length_gte(context, field, *min),
111            Rule::LengthLte { field, max } => self.validate_length_lte(context, field, *max),
112            Rule::Matches { field, pattern } => self.validate_matches(context, field, pattern),
113            Rule::Equals { field, value } => self.validate_equals(context, field, value),
114            Rule::NotEquals { field, value } => self.validate_not_equals(context, field, value),
115            Rule::GreaterThan { field, value } => {
116                self.validate_greater_than(context, field, *value)
117            }
118            Rule::LessThan { field, value } => self.validate_less_than(context, field, *value),
119            Rule::Exists { field } => self.validate_exists(context, field),
120            Rule::NotExists { field } => self.validate_not_exists(context, field),
121            Rule::All { rules } => self.validate_all(rules, context),
122            Rule::Any { rules } => self.validate_any(rules, context),
123            Rule::Not { rule } => {
124                let result = self.validate(rule, context)?;
125                if result.passed {
126                    Ok(ValidationResult::failure(
127                        "Condition should not be met".to_string(),
128                    ))
129                } else {
130                    Ok(ValidationResult::success())
131                }
132            }
133        }
134    }
135
136    fn validate_contains(
137        &self,
138        context: &HashMap<String, serde_json::Value>,
139        field: &str,
140        value: &str,
141        case_sensitive: bool,
142    ) -> Result<ValidationResult> {
143        match context.get(field) {
144            Some(serde_json::Value::String(s)) => {
145                let contains = if case_sensitive {
146                    s.contains(value)
147                } else {
148                    s.to_lowercase().contains(&value.to_lowercase())
149                };
150                if contains {
151                    Ok(ValidationResult::success())
152                } else {
153                    Ok(ValidationResult::failure(format!(
154                        "Field '{}' does not contain '{}'",
155                        field, value
156                    )))
157                }
158            }
159            Some(_) => Ok(ValidationResult::failure(format!(
160                "Field '{}' is not a string",
161                field
162            ))),
163            None => Ok(ValidationResult::failure(format!(
164                "Field '{}' not found",
165                field
166            ))),
167        }
168    }
169
170    fn validate_length_gte(
171        &self,
172        context: &HashMap<String, serde_json::Value>,
173        field: &str,
174        min: usize,
175    ) -> Result<ValidationResult> {
176        match context.get(field) {
177            Some(serde_json::Value::String(s)) => {
178                if s.len() >= min {
179                    Ok(ValidationResult::success())
180                } else {
181                    Ok(ValidationResult::failure(format!(
182                        "Field '{}' length {} is less than {}",
183                        field,
184                        s.len(),
185                        min
186                    )))
187                }
188            }
189            Some(serde_json::Value::Array(arr)) => {
190                if arr.len() >= min {
191                    Ok(ValidationResult::success())
192                } else {
193                    Ok(ValidationResult::failure(format!(
194                        "Field '{}' array length {} is less than {}",
195                        field,
196                        arr.len(),
197                        min
198                    )))
199                }
200            }
201            Some(_) => Ok(ValidationResult::failure(format!(
202                "Field '{}' is not a string or array",
203                field
204            ))),
205            None => Ok(ValidationResult::failure(format!(
206                "Field '{}' not found",
207                field
208            ))),
209        }
210    }
211
212    fn validate_length_lte(
213        &self,
214        context: &HashMap<String, serde_json::Value>,
215        field: &str,
216        max: usize,
217    ) -> Result<ValidationResult> {
218        match context.get(field) {
219            Some(serde_json::Value::String(s)) => {
220                if s.len() <= max {
221                    Ok(ValidationResult::success())
222                } else {
223                    Ok(ValidationResult::failure(format!(
224                        "Field '{}' length {} is greater than {}",
225                        field,
226                        s.len(),
227                        max
228                    )))
229                }
230            }
231            Some(serde_json::Value::Array(arr)) => {
232                if arr.len() <= max {
233                    Ok(ValidationResult::success())
234                } else {
235                    Ok(ValidationResult::failure(format!(
236                        "Field '{}' array length {} is greater than {}",
237                        field,
238                        arr.len(),
239                        max
240                    )))
241                }
242            }
243            Some(_) => Ok(ValidationResult::failure(format!(
244                "Field '{}' is not a string or array",
245                field
246            ))),
247            None => Ok(ValidationResult::failure(format!(
248                "Field '{}' not found",
249                field
250            ))),
251        }
252    }
253
254    fn validate_matches(
255        &mut self,
256        context: &HashMap<String, serde_json::Value>,
257        field: &str,
258        pattern: &str,
259    ) -> Result<ValidationResult> {
260        let regex = self
261            .regex_cache
262            .entry(pattern.to_string())
263            .or_insert_with(|| {
264                Regex::new(pattern).unwrap_or_else(|_| Regex::new("^(?:)$").unwrap())
265            });
266
267        match context.get(field) {
268            Some(serde_json::Value::String(s)) => {
269                if regex.is_match(s) {
270                    Ok(ValidationResult::success())
271                } else {
272                    Ok(ValidationResult::failure(format!(
273                        "Field '{}' does not match pattern '{}'",
274                        field, pattern
275                    )))
276                }
277            }
278            Some(_) => Ok(ValidationResult::failure(format!(
279                "Field '{}' is not a string",
280                field
281            ))),
282            None => Ok(ValidationResult::failure(format!(
283                "Field '{}' not found",
284                field
285            ))),
286        }
287    }
288
289    fn validate_equals(
290        &self,
291        context: &HashMap<String, serde_json::Value>,
292        field: &str,
293        value: &serde_json::Value,
294    ) -> Result<ValidationResult> {
295        match context.get(field) {
296            Some(v) => {
297                if v == value {
298                    Ok(ValidationResult::success())
299                } else {
300                    Ok(ValidationResult::failure(format!(
301                        "Field '{}' value {:?} does not equal {:?}",
302                        field, v, value
303                    )))
304                }
305            }
306            None => Ok(ValidationResult::failure(format!(
307                "Field '{}' not found",
308                field
309            ))),
310        }
311    }
312
313    fn validate_not_equals(
314        &self,
315        context: &HashMap<String, serde_json::Value>,
316        field: &str,
317        value: &serde_json::Value,
318    ) -> Result<ValidationResult> {
319        match context.get(field) {
320            Some(v) => {
321                if v != value {
322                    Ok(ValidationResult::success())
323                } else {
324                    Ok(ValidationResult::failure(format!(
325                        "Field '{}' value {:?} equals {:?}",
326                        field, v, value
327                    )))
328                }
329            }
330            None => Ok(ValidationResult::success()),
331        }
332    }
333
334    fn validate_greater_than(
335        &self,
336        context: &HashMap<String, serde_json::Value>,
337        field: &str,
338        value: f64,
339    ) -> Result<ValidationResult> {
340        match context.get(field) {
341            Some(serde_json::Value::Number(n)) => {
342                if let Some(f) = n.as_f64() {
343                    if f > value {
344                        Ok(ValidationResult::success())
345                    } else {
346                        Ok(ValidationResult::failure(format!(
347                            "Field '{}' value {} is not greater than {}",
348                            field, f, value
349                        )))
350                    }
351                } else {
352                    Ok(ValidationResult::failure(format!(
353                        "Field '{}' is not a valid number",
354                        field
355                    )))
356                }
357            }
358            Some(_) => Ok(ValidationResult::failure(format!(
359                "Field '{}' is not a number",
360                field
361            ))),
362            None => Ok(ValidationResult::failure(format!(
363                "Field '{}' not found",
364                field
365            ))),
366        }
367    }
368
369    fn validate_less_than(
370        &self,
371        context: &HashMap<String, serde_json::Value>,
372        field: &str,
373        value: f64,
374    ) -> Result<ValidationResult> {
375        match context.get(field) {
376            Some(serde_json::Value::Number(n)) => {
377                if let Some(f) = n.as_f64() {
378                    if f < value {
379                        Ok(ValidationResult::success())
380                    } else {
381                        Ok(ValidationResult::failure(format!(
382                            "Field '{}' value {} is not less than {}",
383                            field, f, value
384                        )))
385                    }
386                } else {
387                    Ok(ValidationResult::failure(format!(
388                        "Field '{}' is not a valid number",
389                        field
390                    )))
391                }
392            }
393            Some(_) => Ok(ValidationResult::failure(format!(
394                "Field '{}' is not a number",
395                field
396            ))),
397            None => Ok(ValidationResult::failure(format!(
398                "Field '{}' not found",
399                field
400            ))),
401        }
402    }
403
404    fn validate_exists(
405        &self,
406        context: &HashMap<String, serde_json::Value>,
407        field: &str,
408    ) -> Result<ValidationResult> {
409        if context.contains_key(field) {
410            Ok(ValidationResult::success())
411        } else {
412            Ok(ValidationResult::failure(format!(
413                "Field '{}' not found",
414                field
415            )))
416        }
417    }
418
419    fn validate_not_exists(
420        &self,
421        context: &HashMap<String, serde_json::Value>,
422        field: &str,
423    ) -> Result<ValidationResult> {
424        if !context.contains_key(field) {
425            Ok(ValidationResult::success())
426        } else {
427            Ok(ValidationResult::failure(format!(
428                "Field '{}' exists",
429                field
430            )))
431        }
432    }
433
434    fn validate_all(
435        &mut self,
436        rules: &[Rule],
437        context: &HashMap<String, serde_json::Value>,
438    ) -> Result<ValidationResult> {
439        let mut result = ValidationResult::success();
440        for rule in rules {
441            result = result.merge(self.validate(rule, context)?);
442        }
443        Ok(result)
444    }
445
446    fn validate_any(
447        &mut self,
448        rules: &[Rule],
449        context: &HashMap<String, serde_json::Value>,
450    ) -> Result<ValidationResult> {
451        let mut errors = Vec::new();
452        for rule in rules {
453            let result = self.validate(rule, context)?;
454            if result.passed {
455                return Ok(ValidationResult::success());
456            }
457            errors.extend(result.errors);
458        }
459        Ok(ValidationResult::failure(format!(
460            "None of the conditions met: {}",
461            errors.join("; ")
462        )))
463    }
464}
465
466/// 表达式求值(简单比较和逻辑运算)
467pub fn evaluate_expression(
468    expr: &str,
469    context: &HashMap<String, serde_json::Value>,
470) -> Result<bool> {
471    let expr = expr.trim();
472
473    // 处理 AND 逻辑
474    if expr.contains(" && ") {
475        let parts: Vec<&str> = expr.split(" && ").collect();
476        for part in parts {
477            if !evaluate_expression(part, context)? {
478                return Ok(false);
479            }
480        }
481        return Ok(true);
482    }
483
484    // 处理 OR 逻辑
485    if expr.contains(" || ") {
486        let parts: Vec<&str> = expr.split(" || ").collect();
487        for part in parts {
488            if evaluate_expression(part, context)? {
489                return Ok(true);
490            }
491        }
492        return Ok(false);
493    }
494
495    // 处理比较运算
496    // 相等检查
497    if let Some(eq_pos) = expr.find("==") {
498        let left = expr[..eq_pos].trim();
499        let right = expr[eq_pos + 2..].trim();
500        return evaluate_comparison(left, right, context, true);
501    }
502
503    // 不等检查
504    if let Some(ne_pos) = expr.find("!=") {
505        let left = expr[..ne_pos].trim();
506        let right = expr[ne_pos + 2..].trim();
507        return evaluate_comparison(left, right, context, false);
508    }
509
510    // 大于等于
511    if let Some(ge_pos) = expr.find(">=") {
512        let left = expr[..ge_pos].trim();
513        let right = expr[ge_pos + 2..].trim();
514        return evaluate_numeric_comparison(left, right, context, ">=");
515    }
516
517    // 小于等于
518    if let Some(le_pos) = expr.find("<=") {
519        let left = expr[..le_pos].trim();
520        let right = expr[le_pos + 2..].trim();
521        return evaluate_numeric_comparison(left, right, context, "<=");
522    }
523
524    // 大于
525    if let Some(gt_pos) = expr.find('>') {
526        let left = expr[..gt_pos].trim();
527        let right = expr[gt_pos + 1..].trim();
528        return evaluate_numeric_comparison(left, right, context, ">");
529    }
530
531    // 小于
532    if let Some(lt_pos) = expr.find('<') {
533        let left = expr[..lt_pos].trim();
534        let right = expr[lt_pos + 1..].trim();
535        return evaluate_numeric_comparison(left, right, context, "<");
536    }
537
538    // 布尔值检查
539    match expr {
540        "true" => Ok(true),
541        "false" => Ok(false),
542        _ => {
543            // 检查变量是否存在且为真
544            if let Some(value) = context.get(expr) {
545                Ok(value.as_bool().unwrap_or(false))
546            } else {
547                Ok(false)
548            }
549        }
550    }
551}
552
553fn evaluate_comparison(
554    left: &str,
555    right: &str,
556    context: &HashMap<String, serde_json::Value>,
557    equals: bool,
558) -> Result<bool> {
559    let left_val = resolve_value(left, context)?;
560    let right_val = resolve_value(right, context)?;
561
562    let result = left_val == right_val;
563    Ok(if equals { result } else { !result })
564}
565
566fn evaluate_numeric_comparison(
567    left: &str,
568    right: &str,
569    context: &HashMap<String, serde_json::Value>,
570    op: &str,
571) -> Result<bool> {
572    let left_val = resolve_numeric(left, context)
573        .with_context(|| format!("Failed to resolve left operand: {}", left))?;
574    let right_val = resolve_numeric(right, context)
575        .with_context(|| format!("Failed to resolve right operand: {}", right))?;
576
577    let result = match op {
578        ">" => left_val > right_val,
579        "<" => left_val < right_val,
580        ">=" => left_val >= right_val,
581        "<=" => left_val <= right_val,
582        _ => false,
583    };
584
585    Ok(result)
586}
587
588fn resolve_value(
589    expr: &str,
590    context: &HashMap<String, serde_json::Value>,
591) -> Result<serde_json::Value> {
592    // 字符串字面量
593    if expr.starts_with('"') && expr.ends_with('"') {
594        return Ok(serde_json::Value::String(
595            expr[1..expr.len() - 1].to_string(),
596        ));
597    }
598
599    // 数字字面量
600    if let Ok(n) = expr.parse::<i64>() {
601        return Ok(serde_json::Value::Number(n.into()));
602    }
603    if let Ok(n) = expr.parse::<f64>()
604        && let Some(num) = serde_json::Number::from_f64(n)
605    {
606        return Ok(serde_json::Value::Number(num));
607    }
608
609    // 布尔字面量
610    if expr == "true" {
611        return Ok(serde_json::Value::Bool(true));
612    }
613    if expr == "false" {
614        return Ok(serde_json::Value::Bool(false));
615    }
616
617    // 空值
618    if expr == "null" {
619        return Ok(serde_json::Value::Null);
620    }
621
622    // 变量引用
623    if let Some(value) = context.get(expr) {
624        return Ok(value.clone());
625    }
626
627    anyhow::bail!("Unknown value: {}", expr)
628}
629
630fn resolve_numeric(expr: &str, context: &HashMap<String, serde_json::Value>) -> Result<f64> {
631    // 数字字面量
632    if let Ok(n) = expr.parse::<f64>() {
633        return Ok(n);
634    }
635
636    // 变量引用
637    if let Some(value) = context.get(expr)
638        && let Some(n) = value.as_f64()
639    {
640        return Ok(n);
641    }
642
643    anyhow::bail!("Not a numeric value: {}", expr)
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds a validation context from `(key, value)` pairs.
    fn context_of(entries: &[(&str, serde_json::Value)]) -> HashMap<String, serde_json::Value> {
        entries
            .iter()
            .map(|(key, value)| (key.to_string(), value.clone()))
            .collect()
    }

    /// Runs `rule` through a fresh engine and reports whether it passed.
    fn passes(rule: &Rule, context: &HashMap<String, serde_json::Value>) -> bool {
        let mut engine = RuleEngine::new();
        engine.validate(rule, context).unwrap().passed
    }

    #[test]
    fn test_rule_contains() {
        let context = context_of(&[("text", json!("Hello, World!"))]);
        let rule = Rule::Contains {
            field: "text".into(),
            value: "World".into(),
            case_sensitive: true,
        };
        assert!(passes(&rule, &context));
    }

    #[test]
    fn test_rule_length_gte() {
        let context = context_of(&[("name", json!("Alice"))]);
        let rule = Rule::LengthGte {
            field: "name".into(),
            min: 3,
        };
        assert!(passes(&rule, &context));
    }

    #[test]
    fn test_rule_matches() {
        let context = context_of(&[("email", json!("test@example.com"))]);
        let rule = Rule::Matches {
            field: "email".into(),
            pattern: r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$".into(),
        };
        assert!(passes(&rule, &context));
    }

    #[test]
    fn test_rule_all() {
        let context = context_of(&[("name", json!("Alice")), ("age", json!(25))]);
        let rule = Rule::All {
            rules: vec![
                Rule::LengthGte {
                    field: "name".into(),
                    min: 3,
                },
                Rule::GreaterThan {
                    field: "age".into(),
                    value: 18.0,
                },
            ],
        };
        assert!(passes(&rule, &context));
    }

    #[test]
    fn test_evaluate_expression() {
        let context = context_of(&[("count", json!(10)), ("enabled", json!(true))]);
        let expressions = [
            "count == 10",
            "count > 5",
            "count < 20 && enabled == true",
            "count < 5 || enabled == true",
        ];
        for expr in expressions {
            assert!(
                evaluate_expression(expr, &context).unwrap(),
                "expected `{}` to evaluate to true",
                expr
            );
        }
    }
}