1use anyhow::{Context, Result};
6use regex::Regex;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9
/// A declarative validation rule evaluated by [`RuleEngine::validate`]
/// against a map of field names to JSON values.
///
/// Serialized as an internally tagged object whose `type` is the
/// snake_case variant name, e.g.
/// `{"type": "length_gte", "field": "name", "min": 3}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum Rule {
    /// Passes when the string field `field` contains `value` as a substring.
    Contains {
        /// Name of the context field to inspect.
        field: String,
        /// Substring to look for.
        value: String,
        /// Compare case-sensitively. Defaults to `false` when omitted in
        /// serialized input, in which case both sides are lowercased.
        #[serde(default)]
        case_sensitive: bool,
    },
    /// Passes when the string or array field `field` has length >= `min`.
    LengthGte {
        /// Name of the context field to inspect.
        field: String,
        /// Inclusive lower bound on the length.
        min: usize,
    },
    /// Passes when the string or array field `field` has length <= `max`.
    LengthLte {
        /// Name of the context field to inspect.
        field: String,
        /// Inclusive upper bound on the length.
        max: usize,
    },
    /// Passes when the string field `field` matches the regular expression
    /// `pattern` (an unanchored search; use `^...$` to match the whole value).
    Matches {
        /// Name of the context field to inspect.
        field: String,
        /// Regular expression source, compiled with the `regex` crate.
        pattern: String,
    },
    /// Passes when `field` is present and equals `value` under
    /// `serde_json::Value` equality.
    Equals {
        /// Name of the context field to inspect.
        field: String,
        /// Expected JSON value.
        value: serde_json::Value,
    },
    /// Passes when `field` differs from `value`. A missing field also passes.
    NotEquals {
        /// Name of the context field to inspect.
        field: String,
        /// JSON value the field must not equal.
        value: serde_json::Value,
    },
    /// Passes when `field` is a JSON number strictly greater than `value`.
    GreaterThan {
        /// Name of the context field to inspect.
        field: String,
        /// Exclusive lower bound.
        value: f64,
    },
    /// Passes when `field` is a JSON number strictly less than `value`.
    LessThan {
        /// Name of the context field to inspect.
        field: String,
        /// Exclusive upper bound.
        value: f64,
    },
    /// Passes when the context has a key named `field` (even if its value is null).
    Exists {
        /// Name of the context field to look for.
        field: String,
    },
    /// Passes when the context has no key named `field`.
    NotExists {
        /// Name of the context field that must be absent.
        field: String,
    },
    /// Passes when every sub-rule passes. Errors from all failing sub-rules
    /// are collected.
    All {
        /// Sub-rules that must all pass.
        rules: Vec<Rule>,
    },
    /// Passes when at least one sub-rule passes. Evaluation stops at the
    /// first passing sub-rule.
    Any {
        /// Sub-rules of which at least one must pass.
        rules: Vec<Rule>,
    },
    /// Passes when the wrapped rule fails, and fails when it passes.
    Not {
        /// Rule whose outcome is inverted.
        rule: Box<Rule>,
    },
}
77
/// Outcome of evaluating a rule: a pass/fail flag plus the reasons for failure.
#[derive(Debug, Clone)]
pub struct ValidationResult {
    /// `true` when the rule (and every result merged into this one) passed.
    pub passed: bool,
    /// Human-readable failure messages, in the order they were produced.
    pub errors: Vec<String>,
}

impl ValidationResult {
    /// Builds a passing result with no messages.
    pub fn success() -> Self {
        Self {
            errors: vec![],
            passed: true,
        }
    }

    /// Builds a failing result that carries exactly one message.
    pub fn failure(error: String) -> Self {
        let mut errors = Vec::with_capacity(1);
        errors.push(error);
        Self {
            passed: false,
            errors,
        }
    }

    /// Combines two results.
    ///
    /// The combined result passes only if both inputs passed. Its messages
    /// are `self`'s followed by `other`'s.
    pub fn merge(self, other: Self) -> Self {
        let Self { passed, mut errors } = self;
        let Self {
            passed: other_passed,
            errors: other_errors,
        } = other;
        errors.extend(other_errors);
        Self {
            passed: passed && other_passed,
            errors,
        }
    }
}
108
109#[derive(Debug, Default)]
111pub struct RuleEngine {
112 regex_cache: HashMap<String, Regex>,
114}
115
116impl RuleEngine {
117 pub fn new() -> Self {
118 Self {
119 regex_cache: HashMap::new(),
120 }
121 }
122
123 pub fn validate(&mut self, rule: &Rule, context: &HashMap<String, serde_json::Value>) -> Result<ValidationResult> {
125 match rule {
126 Rule::Contains { field, value, case_sensitive } => {
127 self.validate_contains(context, field, value, *case_sensitive)
128 }
129 Rule::LengthGte { field, min } => {
130 self.validate_length_gte(context, field, *min)
131 }
132 Rule::LengthLte { field, max } => {
133 self.validate_length_lte(context, field, *max)
134 }
135 Rule::Matches { field, pattern } => {
136 self.validate_matches(context, field, pattern)
137 }
138 Rule::Equals { field, value } => {
139 self.validate_equals(context, field, value)
140 }
141 Rule::NotEquals { field, value } => {
142 self.validate_not_equals(context, field, value)
143 }
144 Rule::GreaterThan { field, value } => {
145 self.validate_greater_than(context, field, *value)
146 }
147 Rule::LessThan { field, value } => {
148 self.validate_less_than(context, field, *value)
149 }
150 Rule::Exists { field } => {
151 self.validate_exists(context, field)
152 }
153 Rule::NotExists { field } => {
154 self.validate_not_exists(context, field)
155 }
156 Rule::All { rules } => {
157 self.validate_all(rules, context)
158 }
159 Rule::Any { rules } => {
160 self.validate_any(rules, context)
161 }
162 Rule::Not { rule } => {
163 let result = self.validate(rule, context)?;
164 if result.passed {
165 Ok(ValidationResult::failure("Condition should not be met".to_string()))
166 } else {
167 Ok(ValidationResult::success())
168 }
169 }
170 }
171 }
172
173 fn validate_contains(
174 &self,
175 context: &HashMap<String, serde_json::Value>,
176 field: &str,
177 value: &str,
178 case_sensitive: bool,
179 ) -> Result<ValidationResult> {
180 match context.get(field) {
181 Some(serde_json::Value::String(s)) => {
182 let contains = if case_sensitive {
183 s.contains(value)
184 } else {
185 s.to_lowercase().contains(&value.to_lowercase())
186 };
187 if contains {
188 Ok(ValidationResult::success())
189 } else {
190 Ok(ValidationResult::failure(format!(
191 "Field '{}' does not contain '{}'",
192 field, value
193 )))
194 }
195 }
196 Some(_) => Ok(ValidationResult::failure(format!(
197 "Field '{}' is not a string",
198 field
199 ))),
200 None => Ok(ValidationResult::failure(format!(
201 "Field '{}' not found",
202 field
203 ))),
204 }
205 }
206
207 fn validate_length_gte(
208 &self,
209 context: &HashMap<String, serde_json::Value>,
210 field: &str,
211 min: usize,
212 ) -> Result<ValidationResult> {
213 match context.get(field) {
214 Some(serde_json::Value::String(s)) => {
215 if s.len() >= min {
216 Ok(ValidationResult::success())
217 } else {
218 Ok(ValidationResult::failure(format!(
219 "Field '{}' length {} is less than {}",
220 field,
221 s.len(),
222 min
223 )))
224 }
225 }
226 Some(serde_json::Value::Array(arr)) => {
227 if arr.len() >= min {
228 Ok(ValidationResult::success())
229 } else {
230 Ok(ValidationResult::failure(format!(
231 "Field '{}' array length {} is less than {}",
232 field,
233 arr.len(),
234 min
235 )))
236 }
237 }
238 Some(_) => Ok(ValidationResult::failure(format!(
239 "Field '{}' is not a string or array",
240 field
241 ))),
242 None => Ok(ValidationResult::failure(format!(
243 "Field '{}' not found",
244 field
245 ))),
246 }
247 }
248
249 fn validate_length_lte(
250 &self,
251 context: &HashMap<String, serde_json::Value>,
252 field: &str,
253 max: usize,
254 ) -> Result<ValidationResult> {
255 match context.get(field) {
256 Some(serde_json::Value::String(s)) => {
257 if s.len() <= max {
258 Ok(ValidationResult::success())
259 } else {
260 Ok(ValidationResult::failure(format!(
261 "Field '{}' length {} is greater than {}",
262 field,
263 s.len(),
264 max
265 )))
266 }
267 }
268 Some(serde_json::Value::Array(arr)) => {
269 if arr.len() <= max {
270 Ok(ValidationResult::success())
271 } else {
272 Ok(ValidationResult::failure(format!(
273 "Field '{}' array length {} is greater than {}",
274 field,
275 arr.len(),
276 max
277 )))
278 }
279 }
280 Some(_) => Ok(ValidationResult::failure(format!(
281 "Field '{}' is not a string or array",
282 field
283 ))),
284 None => Ok(ValidationResult::failure(format!(
285 "Field '{}' not found",
286 field
287 ))),
288 }
289 }
290
291 fn validate_matches(
292 &mut self,
293 context: &HashMap<String, serde_json::Value>,
294 field: &str,
295 pattern: &str,
296 ) -> Result<ValidationResult> {
297 let regex = self.regex_cache
298 .entry(pattern.to_string())
299 .or_insert_with(|| {
300 Regex::new(pattern).unwrap_or_else(|_| Regex::new("^(?:)$").unwrap())
301 });
302
303 match context.get(field) {
304 Some(serde_json::Value::String(s)) => {
305 if regex.is_match(s) {
306 Ok(ValidationResult::success())
307 } else {
308 Ok(ValidationResult::failure(format!(
309 "Field '{}' does not match pattern '{}'",
310 field, pattern
311 )))
312 }
313 }
314 Some(_) => Ok(ValidationResult::failure(format!(
315 "Field '{}' is not a string",
316 field
317 ))),
318 None => Ok(ValidationResult::failure(format!(
319 "Field '{}' not found",
320 field
321 ))),
322 }
323 }
324
325 fn validate_equals(
326 &self,
327 context: &HashMap<String, serde_json::Value>,
328 field: &str,
329 value: &serde_json::Value,
330 ) -> Result<ValidationResult> {
331 match context.get(field) {
332 Some(v) => {
333 if v == value {
334 Ok(ValidationResult::success())
335 } else {
336 Ok(ValidationResult::failure(format!(
337 "Field '{}' value {:?} does not equal {:?}",
338 field, v, value
339 )))
340 }
341 }
342 None => Ok(ValidationResult::failure(format!(
343 "Field '{}' not found",
344 field
345 ))),
346 }
347 }
348
349 fn validate_not_equals(
350 &self,
351 context: &HashMap<String, serde_json::Value>,
352 field: &str,
353 value: &serde_json::Value,
354 ) -> Result<ValidationResult> {
355 match context.get(field) {
356 Some(v) => {
357 if v != value {
358 Ok(ValidationResult::success())
359 } else {
360 Ok(ValidationResult::failure(format!(
361 "Field '{}' value {:?} equals {:?}",
362 field, v, value
363 )))
364 }
365 }
366 None => Ok(ValidationResult::success()),
367 }
368 }
369
370 fn validate_greater_than(
371 &self,
372 context: &HashMap<String, serde_json::Value>,
373 field: &str,
374 value: f64,
375 ) -> Result<ValidationResult> {
376 match context.get(field) {
377 Some(serde_json::Value::Number(n)) => {
378 if let Some(f) = n.as_f64() {
379 if f > value {
380 Ok(ValidationResult::success())
381 } else {
382 Ok(ValidationResult::failure(format!(
383 "Field '{}' value {} is not greater than {}",
384 field, f, value
385 )))
386 }
387 } else {
388 Ok(ValidationResult::failure(format!(
389 "Field '{}' is not a valid number",
390 field
391 )))
392 }
393 }
394 Some(_) => Ok(ValidationResult::failure(format!(
395 "Field '{}' is not a number",
396 field
397 ))),
398 None => Ok(ValidationResult::failure(format!(
399 "Field '{}' not found",
400 field
401 ))),
402 }
403 }
404
405 fn validate_less_than(
406 &self,
407 context: &HashMap<String, serde_json::Value>,
408 field: &str,
409 value: f64,
410 ) -> Result<ValidationResult> {
411 match context.get(field) {
412 Some(serde_json::Value::Number(n)) => {
413 if let Some(f) = n.as_f64() {
414 if f < value {
415 Ok(ValidationResult::success())
416 } else {
417 Ok(ValidationResult::failure(format!(
418 "Field '{}' value {} is not less than {}",
419 field, f, value
420 )))
421 }
422 } else {
423 Ok(ValidationResult::failure(format!(
424 "Field '{}' is not a valid number",
425 field
426 )))
427 }
428 }
429 Some(_) => Ok(ValidationResult::failure(format!(
430 "Field '{}' is not a number",
431 field
432 ))),
433 None => Ok(ValidationResult::failure(format!(
434 "Field '{}' not found",
435 field
436 ))),
437 }
438 }
439
440 fn validate_exists(
441 &self,
442 context: &HashMap<String, serde_json::Value>,
443 field: &str,
444 ) -> Result<ValidationResult> {
445 if context.contains_key(field) {
446 Ok(ValidationResult::success())
447 } else {
448 Ok(ValidationResult::failure(format!(
449 "Field '{}' not found",
450 field
451 )))
452 }
453 }
454
455 fn validate_not_exists(
456 &self,
457 context: &HashMap<String, serde_json::Value>,
458 field: &str,
459 ) -> Result<ValidationResult> {
460 if !context.contains_key(field) {
461 Ok(ValidationResult::success())
462 } else {
463 Ok(ValidationResult::failure(format!(
464 "Field '{}' exists",
465 field
466 )))
467 }
468 }
469
470 fn validate_all(
471 &mut self,
472 rules: &[Rule],
473 context: &HashMap<String, serde_json::Value>,
474 ) -> Result<ValidationResult> {
475 let mut result = ValidationResult::success();
476 for rule in rules {
477 result = result.merge(self.validate(rule, context)?);
478 }
479 Ok(result)
480 }
481
482 fn validate_any(
483 &mut self,
484 rules: &[Rule],
485 context: &HashMap<String, serde_json::Value>,
486 ) -> Result<ValidationResult> {
487 let mut errors = Vec::new();
488 for rule in rules {
489 let result = self.validate(rule, context)?;
490 if result.passed {
491 return Ok(ValidationResult::success());
492 }
493 errors.extend(result.errors);
494 }
495 Ok(ValidationResult::failure(format!(
496 "None of the conditions met: {}",
497 errors.join("; ")
498 )))
499 }
500}
501
502pub fn evaluate_expression(expr: &str, context: &HashMap<String, serde_json::Value>) -> Result<bool> {
504 let expr = expr.trim();
505
506 if expr.contains(" && ") {
508 let parts: Vec<&str> = expr.split(" && ").collect();
509 for part in parts {
510 if !evaluate_expression(part, context)? {
511 return Ok(false);
512 }
513 }
514 return Ok(true);
515 }
516
517 if expr.contains(" || ") {
519 let parts: Vec<&str> = expr.split(" || ").collect();
520 for part in parts {
521 if evaluate_expression(part, context)? {
522 return Ok(true);
523 }
524 }
525 return Ok(false);
526 }
527
528 if let Some(eq_pos) = expr.find("==") {
531 let left = expr[..eq_pos].trim();
532 let right = expr[eq_pos + 2..].trim();
533 return evaluate_comparison(left, right, context, true);
534 }
535
536 if let Some(ne_pos) = expr.find("!=") {
538 let left = expr[..ne_pos].trim();
539 let right = expr[ne_pos + 2..].trim();
540 return evaluate_comparison(left, right, context, false);
541 }
542
543 if let Some(ge_pos) = expr.find(">=") {
545 let left = expr[..ge_pos].trim();
546 let right = expr[ge_pos + 2..].trim();
547 return evaluate_numeric_comparison(left, right, context, ">=");
548 }
549
550 if let Some(le_pos) = expr.find("<=") {
552 let left = expr[..le_pos].trim();
553 let right = expr[le_pos + 2..].trim();
554 return evaluate_numeric_comparison(left, right, context, "<=");
555 }
556
557 if let Some(gt_pos) = expr.find('>') {
559 let left = expr[..gt_pos].trim();
560 let right = expr[gt_pos + 1..].trim();
561 return evaluate_numeric_comparison(left, right, context, ">");
562 }
563
564 if let Some(lt_pos) = expr.find('<') {
566 let left = expr[..lt_pos].trim();
567 let right = expr[lt_pos + 1..].trim();
568 return evaluate_numeric_comparison(left, right, context, "<");
569 }
570
571 match expr {
573 "true" => Ok(true),
574 "false" => Ok(false),
575 _ => {
576 if let Some(value) = context.get(expr) {
578 Ok(value.as_bool().unwrap_or(false))
579 } else {
580 Ok(false)
581 }
582 }
583 }
584}
585
586fn evaluate_comparison(
587 left: &str,
588 right: &str,
589 context: &HashMap<String, serde_json::Value>,
590 equals: bool,
591) -> Result<bool> {
592 let left_val = resolve_value(left, context)?;
593 let right_val = resolve_value(right, context)?;
594
595 let result = left_val == right_val;
596 Ok(if equals { result } else { !result })
597}
598
599fn evaluate_numeric_comparison(
600 left: &str,
601 right: &str,
602 context: &HashMap<String, serde_json::Value>,
603 op: &str,
604) -> Result<bool> {
605 let left_val = resolve_numeric(left, context)
606 .with_context(|| format!("Failed to resolve left operand: {}", left))?;
607 let right_val = resolve_numeric(right, context)
608 .with_context(|| format!("Failed to resolve right operand: {}", right))?;
609
610 let result = match op {
611 ">" => left_val > right_val,
612 "<" => left_val < right_val,
613 ">=" => left_val >= right_val,
614 "<=" => left_val <= right_val,
615 _ => false,
616 };
617
618 Ok(result)
619}
620
621fn resolve_value(expr: &str, context: &HashMap<String, serde_json::Value>) -> Result<serde_json::Value> {
622 if expr.starts_with('"') && expr.ends_with('"') {
624 return Ok(serde_json::Value::String(expr[1..expr.len()-1].to_string()));
625 }
626
627 if let Ok(n) = expr.parse::<i64>() {
629 return Ok(serde_json::Value::Number(n.into()));
630 }
631 if let Ok(n) = expr.parse::<f64>()
632 && let Some(num) = serde_json::Number::from_f64(n) {
633 return Ok(serde_json::Value::Number(num));
634 }
635
636 if expr == "true" {
638 return Ok(serde_json::Value::Bool(true));
639 }
640 if expr == "false" {
641 return Ok(serde_json::Value::Bool(false));
642 }
643
644 if expr == "null" {
646 return Ok(serde_json::Value::Null);
647 }
648
649 if let Some(value) = context.get(expr) {
651 return Ok(value.clone());
652 }
653
654 anyhow::bail!("Unknown value: {}", expr)
655}
656
657fn resolve_numeric(expr: &str, context: &HashMap<String, serde_json::Value>) -> Result<f64> {
658 if let Ok(n) = expr.parse::<f64>() {
660 return Ok(n);
661 }
662
663 if let Some(value) = context.get(expr)
665 && let Some(n) = value.as_f64() {
666 return Ok(n);
667 }
668
669 anyhow::bail!("Not a numeric value: {}", expr)
670}
671
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Builds an evaluation context from `(field, value)` pairs.
    fn context_of(pairs: &[(&str, serde_json::Value)]) -> HashMap<String, serde_json::Value> {
        pairs
            .iter()
            .map(|(name, value)| (String::from(*name), value.clone()))
            .collect()
    }

    /// Evaluates `rule` with a fresh engine, panicking if evaluation errors.
    fn run(rule: &Rule, context: &HashMap<String, serde_json::Value>) -> ValidationResult {
        let mut engine = RuleEngine::new();
        engine.validate(rule, context).unwrap()
    }

    #[test]
    fn test_rule_contains() {
        let context = context_of(&[("text", json!("Hello, World!"))]);
        let rule = Rule::Contains {
            field: "text".to_owned(),
            value: "World".to_owned(),
            case_sensitive: true,
        };
        assert!(run(&rule, &context).passed);
    }

    #[test]
    fn test_rule_length_gte() {
        let context = context_of(&[("name", json!("Alice"))]);
        let rule = Rule::LengthGte {
            field: "name".to_owned(),
            min: 3,
        };
        assert!(run(&rule, &context).passed);
    }

    #[test]
    fn test_rule_matches() {
        let context = context_of(&[("email", json!("test@example.com"))]);
        let rule = Rule::Matches {
            field: "email".to_owned(),
            pattern: r"^[a-zA-Z0-9._%+-]+@[a-zA-Z0-9.-]+\.[a-zA-Z]{2,}$".to_owned(),
        };
        assert!(run(&rule, &context).passed);
    }

    #[test]
    fn test_rule_all() {
        let context = context_of(&[("name", json!("Alice")), ("age", json!(25))]);
        let rule = Rule::All {
            rules: vec![
                Rule::LengthGte {
                    field: "name".to_owned(),
                    min: 3,
                },
                Rule::GreaterThan {
                    field: "age".to_owned(),
                    value: 18.0,
                },
            ],
        };
        assert!(run(&rule, &context).passed);
    }

    #[test]
    fn test_evaluate_expression() {
        let context = context_of(&[("count", json!(10)), ("enabled", json!(true))]);
        let expressions = [
            "count == 10",
            "count > 5",
            "count < 20 && enabled == true",
            "count < 5 || enabled == true",
        ];
        for expr in expressions {
            assert!(evaluate_expression(expr, &context).unwrap());
        }
    }
}