1use anyhow::{Context, Result};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
10#[derive(Debug, Clone, Serialize, Deserialize)]
12#[serde(tag = "type", rename_all = "snake_case")]
13pub enum Rule {
14 Contains {
16 field: String,
17 value: String,
18 #[serde(default)]
19 case_sensitive: bool,
20 },
21 LengthGte { field: String, min: usize },
23 LengthLte { field: String, max: usize },
25 Matches { field: String, pattern: String },
27 Equals {
29 field: String,
30 value: serde_json::Value,
31 },
32 NotEquals {
34 field: String,
35 value: serde_json::Value,
36 },
37 GreaterThan { field: String, value: f64 },
39 LessThan { field: String, value: f64 },
41 Exists { field: String },
43 NotExists { field: String },
45 All { rules: Vec<Rule> },
47 Any { rules: Vec<Rule> },
49 Not { rule: Box<Rule> },
51}
52
/// Outcome of evaluating a rule: whether it was satisfied, plus the messages
/// explaining any failure.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` when the rule was satisfied.
    pub passed: bool,
    /// Human-readable failure messages, in the order they were produced.
    pub errors: Vec<String>,
}

impl ValidationResult {
    /// A passing result carrying no messages.
    pub fn success() -> Self {
        Self {
            errors: vec![],
            passed: true,
        }
    }

    /// A failing result carrying the single message `error`.
    pub fn failure(error: String) -> Self {
        let errors = vec![error];
        Self {
            errors,
            passed: false,
        }
    }

    /// Combines two results.
    ///
    /// The combined result passes only if both inputs passed. Its messages are
    /// `self`'s followed by `other`'s.
    pub fn merge(self, other: Self) -> Self {
        let Self { passed, mut errors } = self;
        errors.extend(other.errors);
        Self {
            passed: passed && other.passed,
            errors,
        }
    }
}
83
84#[derive(Debug, Default)]
86pub struct RuleEngine {
87 regex_cache: HashMap<String, Regex>,
89}
90
91impl RuleEngine {
92 pub fn new() -> Self {
93 Self {
94 regex_cache: HashMap::new(),
95 }
96 }
97
98 pub fn validate(
100 &mut self,
101 rule: &Rule,
102 context: &HashMap<String, serde_json::Value>,
103 ) -> Result<ValidationResult> {
104 match rule {
105 Rule::Contains {
106 field,
107 value,
108 case_sensitive,
109 } => self.validate_contains(context, field, value, *case_sensitive),
110 Rule::LengthGte { field, min } => self.validate_length_gte(context, field, *min),
111 Rule::LengthLte { field, max } => self.validate_length_lte(context, field, *max),
112 Rule::Matches { field, pattern } => self.validate_matches(context, field, pattern),
113 Rule::Equals { field, value } => self.validate_equals(context, field, value),
114 Rule::NotEquals { field, value } => self.validate_not_equals(context, field, value),
115 Rule::GreaterThan { field, value } => {
116 self.validate_greater_than(context, field, *value)
117 }
118 Rule::LessThan { field, value } => self.validate_less_than(context, field, *value),
119 Rule::Exists { field } => self.validate_exists(context, field),
120 Rule::NotExists { field } => self.validate_not_exists(context, field),
121 Rule::All { rules } => self.validate_all(rules, context),
122 Rule::Any { rules } => self.validate_any(rules, context),
123 Rule::Not { rule } => {
124 let result = self.validate(rule, context)?;
125 if result.passed {
126 Ok(ValidationResult::failure(
127 "Condition should not be met".to_string(),
128 ))
129 } else {
130 Ok(ValidationResult::success())
131 }
132 }
133 }
134 }
135
136 fn validate_contains(
137 &self,
138 context: &HashMap<String, serde_json::Value>,
139 field: &str,
140 value: &str,
141 case_sensitive: bool,
142 ) -> Result<ValidationResult> {
143 match context.get(field) {
144 Some(serde_json::Value::String(s)) => {
145 let contains = if case_sensitive {
146 s.contains(value)
147 } else {
148 s.to_lowercase().contains(&value.to_lowercase())
149 };
150 if contains {
151 Ok(ValidationResult::success())
152 } else {
153 Ok(ValidationResult::failure(format!(
154 "Field '{}' does not contain '{}'",
155 field, value
156 )))
157 }
158 }
159 Some(_) => Ok(ValidationResult::failure(format!(
160 "Field '{}' is not a string",
161 field
162 ))),
163 None => Ok(ValidationResult::failure(format!(
164 "Field '{}' not found",
165 field
166 ))),
167 }
168 }
169
170 fn validate_length_gte(
171 &self,
172 context: &HashMap<String, serde_json::Value>,
173 field: &str,
174 min: usize,
175 ) -> Result<ValidationResult> {
176 match context.get(field) {
177 Some(serde_json::Value::String(s)) => {
178 if s.len() >= min {
179 Ok(ValidationResult::success())
180 } else {
181 Ok(ValidationResult::failure(format!(
182 "Field '{}' length {} is less than {}",
183 field,
184 s.len(),
185 min
186 )))
187 }
188 }
189 Some(serde_json::Value::Array(arr)) => {
190 if arr.len() >= min {
191 Ok(ValidationResult::success())
192 } else {
193 Ok(ValidationResult::failure(format!(
194 "Field '{}' array length {} is less than {}",
195 field,
196 arr.len(),
197 min
198 )))
199 }
200 }
201 Some(_) => Ok(ValidationResult::failure(format!(
202 "Field '{}' is not a string or array",
203 field
204 ))),
205 None => Ok(ValidationResult::failure(format!(
206 "Field '{}' not found",
207 field
208 ))),
209 }
210 }
211
212 fn validate_length_lte(
213 &self,
214 context: &HashMap<String, serde_json::Value>,
215 field: &str,
216 max: usize,
217 ) -> Result<ValidationResult> {
218 match context.get(field) {
219 Some(serde_json::Value::String(s)) => {
220 if s.len() <= max {
221 Ok(ValidationResult::success())
222 } else {
223 Ok(ValidationResult::failure(format!(
224 "Field '{}' length {} is greater than {}",
225 field,
226 s.len(),
227 max
228 )))
229 }
230 }
231 Some(serde_json::Value::Array(arr)) => {
232 if arr.len() <= max {
233 Ok(ValidationResult::success())
234 } else {
235 Ok(ValidationResult::failure(format!(
236 "Field '{}' array length {} is greater than {}",
237 field,
238 arr.len(),
239 max
240 )))
241 }
242 }
243 Some(_) => Ok(ValidationResult::failure(format!(
244 "Field '{}' is not a string or array",
245 field
246 ))),
247 None => Ok(ValidationResult::failure(format!(
248 "Field '{}' not found",
249 field
250 ))),
251 }
252 }
253
254 fn validate_matches(
255 &mut self,
256 context: &HashMap<String, serde_json::Value>,
257 field: &str,
258 pattern: &str,
259 ) -> Result<ValidationResult> {
260 let regex = self
261 .regex_cache
262 .entry(pattern.to_string())
263 .or_insert_with(|| {
264 Regex::new(pattern).unwrap_or_else(|_| Regex::new("^(?:)$").unwrap())
265 });
266
267 match context.get(field) {
268 Some(serde_json::Value::String(s)) => {
269 if regex.is_match(s) {
270 Ok(ValidationResult::success())
271 } else {
272 Ok(ValidationResult::failure(format!(
273 "Field '{}' does not match pattern '{}'",
274 field, pattern
275 )))
276 }
277 }
278 Some(_) => Ok(ValidationResult::failure(format!(
279 "Field '{}' is not a string",
280 field
281 ))),
282 None => Ok(ValidationResult::failure(format!(
283 "Field '{}' not found",
284 field
285 ))),
286 }
287 }
288
289 fn validate_equals(
290 &self,
291 context: &HashMap<String, serde_json::Value>,
292 field: &str,
293 value: &serde_json::Value,
294 ) -> Result<ValidationResult> {
295 match context.get(field) {
296 Some(v) => {
297 if v == value {
298 Ok(ValidationResult::success())
299 } else {
300 Ok(ValidationResult::failure(format!(
301 "Field '{}' value {:?} does not equal {:?}",
302 field, v, value
303 )))
304 }
305 }
306 None => Ok(ValidationResult::failure(format!(
307 "Field '{}' not found",
308 field
309 ))),
310 }
311 }
312
313 fn validate_not_equals(
314 &self,
315 context: &HashMap<String, serde_json::Value>,
316 field: &str,
317 value: &serde_json::Value,
318 ) -> Result<ValidationResult> {
319 match context.get(field) {
320 Some(v) => {
321 if v != value {
322 Ok(ValidationResult::success())
323 } else {
324 Ok(ValidationResult::failure(format!(
325 "Field '{}' value {:?} equals {:?}",
326 field, v, value
327 )))
328 }
329 }
330 None => Ok(ValidationResult::success()),
331 }
332 }
333
334 fn validate_greater_than(
335 &self,
336 context: &HashMap<String, serde_json::Value>,
337 field: &str,
338 value: f64,
339 ) -> Result<ValidationResult> {
340 match context.get(field) {
341 Some(serde_json::Value::Number(n)) => {
342 if let Some(f) = n.as_f64() {
343 if f > value {
344 Ok(ValidationResult::success())
345 } else {
346 Ok(ValidationResult::failure(format!(
347 "Field '{}' value {} is not greater than {}",
348 field, f, value
349 )))
350 }
351 } else {
352 Ok(ValidationResult::failure(format!(
353 "Field '{}' is not a valid number",
354 field
355 )))
356 }
357 }
358 Some(_) => Ok(ValidationResult::failure(format!(
359 "Field '{}' is not a number",
360 field
361 ))),
362 None => Ok(ValidationResult::failure(format!(
363 "Field '{}' not found",
364 field
365 ))),
366 }
367 }
368
369 fn validate_less_than(
370 &self,
371 context: &HashMap<String, serde_json::Value>,
372 field: &str,
373 value: f64,
374 ) -> Result<ValidationResult> {
375 match context.get(field) {
376 Some(serde_json::Value::Number(n)) => {
377 if let Some(f) = n.as_f64() {
378 if f < value {
379 Ok(ValidationResult::success())
380 } else {
381 Ok(ValidationResult::failure(format!(
382 "Field '{}' value {} is not less than {}",
383 field, f, value
384 )))
385 }
386 } else {
387 Ok(ValidationResult::failure(format!(
388 "Field '{}' is not a valid number",
389 field
390 )))
391 }
392 }
393 Some(_) => Ok(ValidationResult::failure(format!(
394 "Field '{}' is not a number",
395 field
396 ))),
397 None => Ok(ValidationResult::failure(format!(
398 "Field '{}' not found",
399 field
400 ))),
401 }
402 }
403
404 fn validate_exists(
405 &self,
406 context: &HashMap<String, serde_json::Value>,
407 field: &str,
408 ) -> Result<ValidationResult> {
409 if context.contains_key(field) {
410 Ok(ValidationResult::success())
411 } else {
412 Ok(ValidationResult::failure(format!(
413 "Field '{}' not found",
414 field
415 )))
416 }
417 }
418
419 fn validate_not_exists(
420 &self,
421 context: &HashMap<String, serde_json::Value>,
422 field: &str,
423 ) -> Result<ValidationResult> {
424 if !context.contains_key(field) {
425 Ok(ValidationResult::success())
426 } else {
427 Ok(ValidationResult::failure(format!(
428 "Field '{}' exists",
429 field
430 )))
431 }
432 }
433
434 fn validate_all(
435 &mut self,
436 rules: &[Rule],
437 context: &HashMap<String, serde_json::Value>,
438 ) -> Result<ValidationResult> {
439 let mut result = ValidationResult::success();
440 for rule in rules {
441 result = result.merge(self.validate(rule, context)?);
442 }
443 Ok(result)
444 }
445
446 fn validate_any(
447 &mut self,
448 rules: &[Rule],
449 context: &HashMap<String, serde_json::Value>,
450 ) -> Result<ValidationResult> {
451 let mut errors = Vec::new();
452 for rule in rules {
453 let result = self.validate(rule, context)?;
454 if result.passed {
455 return Ok(ValidationResult::success());
456 }
457 errors.extend(result.errors);
458 }
459 Ok(ValidationResult::failure(format!(
460 "None of the conditions met: {}",
461 errors.join("; ")
462 )))
463 }
464}
465
466pub fn evaluate_expression(
468 expr: &str,
469 context: &HashMap<String, serde_json::Value>,
470) -> Result<bool> {
471 let expr = expr.trim();
472
473 if expr.contains(" && ") {
475 let parts: Vec<&str> = expr.split(" && ").collect();
476 for part in parts {
477 if !evaluate_expression(part, context)? {
478 return Ok(false);
479 }
480 }
481 return Ok(true);
482 }
483
484 if expr.contains(" || ") {
486 let parts: Vec<&str> = expr.split(" || ").collect();
487 for part in parts {
488 if evaluate_expression(part, context)? {
489 return Ok(true);
490 }
491 }
492 return Ok(false);
493 }
494
495 if let Some(eq_pos) = expr.find("==") {
498 let left = expr[..eq_pos].trim();
499 let right = expr[eq_pos + 2..].trim();
500 return evaluate_comparison(left, right, context, true);
501 }
502
503 if let Some(ne_pos) = expr.find("!=") {
505 let left = expr[..ne_pos].trim();
506 let right = expr[ne_pos + 2..].trim();
507 return evaluate_comparison(left, right, context, false);
508 }
509
510 if let Some(ge_pos) = expr.find(">=") {
512 let left = expr[..ge_pos].trim();
513 let right = expr[ge_pos + 2..].trim();
514 return evaluate_numeric_comparison(left, right, context, ">=");
515 }
516
517 if let Some(le_pos) = expr.find("<=") {
519 let left = expr[..le_pos].trim();
520 let right = expr[le_pos + 2..].trim();
521 return evaluate_numeric_comparison(left, right, context, "<=");
522 }
523
524 if let Some(gt_pos) = expr.find('>') {
526 let left = expr[..gt_pos].trim();
527 let right = expr[gt_pos + 1..].trim();
528 return evaluate_numeric_comparison(left, right, context, ">");
529 }
530
531 if let Some(lt_pos) = expr.find('<') {
533 let left = expr[..lt_pos].trim();
534 let right = expr[lt_pos + 1..].trim();
535 return evaluate_numeric_comparison(left, right, context, "<");
536 }
537
538 match expr {
540 "true" => Ok(true),
541 "false" => Ok(false),
542 _ => {
543 if let Some(value) = context.get(expr) {
545 Ok(value.as_bool().unwrap_or(false))
546 } else {
547 Ok(false)
548 }
549 }
550 }
551}
552
553fn evaluate_comparison(
554 left: &str,
555 right: &str,
556 context: &HashMap<String, serde_json::Value>,
557 equals: bool,
558) -> Result<bool> {
559 let left_val = resolve_value(left, context)?;
560 let right_val = resolve_value(right, context)?;
561
562 let result = left_val == right_val;
563 Ok(if equals { result } else { !result })
564}
565
566fn evaluate_numeric_comparison(
567 left: &str,
568 right: &str,
569 context: &HashMap<String, serde_json::Value>,
570 op: &str,
571) -> Result<bool> {
572 let left_val = resolve_numeric(left, context)
573 .with_context(|| format!("Failed to resolve left operand: {}", left))?;
574 let right_val = resolve_numeric(right, context)
575 .with_context(|| format!("Failed to resolve right operand: {}", right))?;
576
577 let result = match op {
578 ">" => left_val > right_val,
579 "<" => left_val < right_val,
580 ">=" => left_val >= right_val,
581 "<=" => left_val <= right_val,
582 _ => false,
583 };
584
585 Ok(result)
586}
587
588fn resolve_value(
589 expr: &str,
590 context: &HashMap<String, serde_json::Value>,
591) -> Result<serde_json::Value> {
592 if expr.starts_with('"') && expr.ends_with('"') {
594 return Ok(serde_json::Value::String(
595 expr[1..expr.len() - 1].to_string(),
596 ));
597 }
598
599 if let Ok(n) = expr.parse::<i64>() {
601 return Ok(serde_json::Value::Number(n.into()));
602 }
603 if let Ok(n) = expr.parse::<f64>()
604 && let Some(num) = serde_json::Number::from_f64(n)
605 {
606 return Ok(serde_json::Value::Number(num));
607 }
608
609 if expr == "true" {
611 return Ok(serde_json::Value::Bool(true));
612 }
613 if expr == "false" {
614 return Ok(serde_json::Value::Bool(false));
615 }
616
617 if expr == "null" {
619 return Ok(serde_json::Value::Null);
620 }
621
622 if let Some(value) = context.get(expr) {
624 return Ok(value.clone());
625 }
626
627 anyhow::bail!("Unknown value: {}", expr)
628}
629
630fn resolve_numeric(expr: &str, context: &HashMap<String, serde_json::Value>) -> Result<f64> {
631 if let Ok(n) = expr.parse::<f64>() {
633 return Ok(n);
634 }
635
636 if let Some(value) = context.get(expr)
638 && let Some(n) = value.as_f64()
639 {
640 return Ok(n);
641 }
642
643 anyhow::bail!("Not a numeric value: {}", expr)
644}
645
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an evaluation context from `(field, value)` pairs.
    fn context_of(pairs: &[(&str, serde_json::Value)]) -> HashMap<String, serde_json::Value> {
        pairs
            .iter()
            .map(|(field, value)| (field.to_string(), value.clone()))
            .collect()
    }

    /// Validates `rule` with a fresh engine and returns whether it passed.
    fn passes(rule: &Rule, context: &HashMap<String, serde_json::Value>) -> bool {
        let mut engine = RuleEngine::new();
        engine
            .validate(rule, context)
            .expect("rule should evaluate without error")
            .passed
    }

    #[test]
    fn test_rule_contains() {
        let context = context_of(&[("text", json!("Hello, World!"))]);
        let rule = Rule::Contains {
            field: "text".into(),
            value: "World".into(),
            case_sensitive: true,
        };
        assert!(passes(&rule, &context));
    }

    #[test]
    fn test_rule_length_gte() {
        let context = context_of(&[("name", json!("Alice"))]);
        let rule = Rule::LengthGte {
            field: "name".into(),
            min: 3,
        };
        assert!(passes(&rule, &context));
    }

    #[test]
    fn test_rule_matches() {
        let context = context_of(&[("email", json!("test@example.com"))]);
        let rule = Rule::Matches {
            field: "email".into(),
            pattern: r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$".into(),
        };
        assert!(passes(&rule, &context));
    }

    #[test]
    fn test_rule_all() {
        let context = context_of(&[("name", json!("Alice")), ("age", json!(25))]);
        let rule = Rule::All {
            rules: vec![
                Rule::LengthGte {
                    field: "name".into(),
                    min: 3,
                },
                Rule::GreaterThan {
                    field: "age".into(),
                    value: 18.0,
                },
            ],
        };
        assert!(passes(&rule, &context));
    }

    #[test]
    fn test_evaluate_expression() {
        let context = context_of(&[("count", json!(10)), ("enabled", json!(true))]);
        let truthy = [
            "count == 10",
            "count > 5",
            "count < 20 && enabled == true",
            "count < 5 || enabled == true",
        ];
        for expr in truthy {
            assert!(
                evaluate_expression(expr, &context).unwrap(),
                "expected `{expr}` to evaluate to true"
            );
        }
    }
}