use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A predicate evaluated against a context of named JSON values.
///
/// The `Send + Sync` supertraits require every implementor to be safe to send to
/// and share between threads.
pub trait ConditionProtocol: Send + Sync {
    /// Returns `true` if the condition holds for `context`.
    fn evaluate(&self, context: &HashMap<String, serde_json::Value>) -> bool;
}
/// A condition that, in addition to a yes/no answer, names the targets to route to.
pub trait RoutingConditionProtocol: ConditionProtocol {
    /// Returns the target names selected for `context`. The list may be empty.
    fn get_target(&self, context: &HashMap<String, serde_json::Value>) -> Vec<String>;
}
/// A condition described by a small textual expression.
///
/// Supported forms (whitespace around operators is optional):
///
/// * `name` — looks up `name` in the context and returns its boolean value;
///   a missing or non-boolean value evaluates to `false`.
/// * `name OP operand`, with `OP` one of `>=`, `<=`, `!=`, `==`, `>`, `<`.
///   The left side is always a context variable name. The right side is a quoted
///   string literal (`'x'` or `"x"`), an integer or finite float literal, `true`,
///   `false`, or otherwise the name of another context variable.
///
/// Anything that cannot be resolved or compared evaluates to `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExpressionCondition {
    /// The raw expression text; it is re-parsed on every evaluation.
    pub expression: String,
}

impl ExpressionCondition {
    /// Comparison operators recognised by the parser. At any single position the
    /// two-character forms are tried first, so `>=` is never read as `>` then `=`.
    const OPERATORS: [&'static str; 6] = [">=", "<=", "!=", "==", ">", "<"];

    /// Creates a condition from an expression such as `"score >= 80"`,
    /// `"status == 'approved'"`, or a bare variable name like `"is_valid"`.
    pub fn new(expression: impl Into<String>) -> Self {
        Self {
            expression: expression.into(),
        }
    }

    /// Parses the expression and evaluates it against `context`.
    fn evaluate_expression(&self, context: &HashMap<String, serde_json::Value>) -> bool {
        let expr = self.expression.trim();
        match Self::find_operator(expr) {
            Some((pos, op)) => {
                let left = expr[..pos].trim();
                let right = expr[pos + op.len()..].trim();
                self.compare(left, op, right, context)
            }
            // A lone `=` is a malformed comparison, not a variable name.
            None if expr.contains('=') => false,
            None => context
                .get(expr)
                .and_then(|v| v.as_bool())
                .unwrap_or(false),
        }
    }

    /// Returns the byte offset and text of the leftmost operator in `expr`.
    ///
    /// Scanning by position (rather than searching for each operator in turn)
    /// keeps operator characters inside a right-hand literal from being taken as
    /// the comparison. An example is the `!=` in `note == 'a!=b'`.
    fn find_operator(expr: &str) -> Option<(usize, &'static str)> {
        expr.char_indices().find_map(|(pos, _)| {
            let rest = &expr[pos..];
            Self::OPERATORS
                .iter()
                .find(|op| rest.starts_with(**op))
                .map(|op| (pos, *op))
        })
    }

    /// Compares the context variable `left` against the operand `right` using `op`.
    ///
    /// Numbers support every operator. Strings compare lexicographically. Booleans
    /// and all other JSON values support only `==` and `!=`. Returns `false` in
    /// three cases: `left` is missing from `context`, `right` cannot be resolved,
    /// or `op` is not meaningful for the operand types.
    fn compare(
        &self,
        left: &str,
        op: &str,
        right: &str,
        context: &HashMap<String, serde_json::Value>,
    ) -> bool {
        let (l, r) = match (context.get(left), Self::resolve_operand(right, context)) {
            (Some(l), Some(r)) => (l, r),
            _ => return false,
        };
        if let Some(result) = Self::compare_numbers(l, op, &r) {
            return result;
        }
        if let (Some(ls), Some(rs)) = (l.as_str(), r.as_str()) {
            return Self::ordering_satisfies(ls.cmp(rs), op);
        }
        if let (Some(lb), Some(rb)) = (l.as_bool(), r.as_bool()) {
            return match op {
                "==" => lb == rb,
                "!=" => lb != rb,
                _ => false,
            };
        }
        match op {
            "==" => *l == r,
            "!=" => *l != r,
            _ => false,
        }
    }

    /// Compares two numbers, or returns `None` if either side is not a JSON number.
    ///
    /// When both sides fit in `i64`, or both fit in `u64`, they are compared exactly.
    /// This keeps large integer IDs (beyond 2^53) from being conflated by `f64`
    /// rounding. Otherwise both sides are compared as `f64`, and `==`/`!=` use an
    /// absolute `f64::EPSILON` tolerance.
    fn compare_numbers(l: &serde_json::Value, op: &str, r: &serde_json::Value) -> Option<bool> {
        if let (Some(a), Some(b)) = (l.as_i64(), r.as_i64()) {
            return Some(Self::ordering_satisfies(a.cmp(&b), op));
        }
        if let (Some(a), Some(b)) = (l.as_u64(), r.as_u64()) {
            return Some(Self::ordering_satisfies(a.cmp(&b), op));
        }
        let (a, b) = (l.as_f64()?, r.as_f64()?);
        Some(match op {
            ">" => a > b,
            "<" => a < b,
            ">=" => a >= b,
            "<=" => a <= b,
            "==" => (a - b).abs() < f64::EPSILON,
            "!=" => (a - b).abs() >= f64::EPSILON,
            _ => false,
        })
    }

    /// Interprets `op` against an already computed ordering of left vs. right.
    fn ordering_satisfies(ordering: std::cmp::Ordering, op: &str) -> bool {
        match op {
            "==" => ordering.is_eq(),
            "!=" => ordering.is_ne(),
            ">" => ordering.is_gt(),
            "<" => ordering.is_lt(),
            ">=" => ordering.is_ge(),
            "<=" => ordering.is_le(),
            _ => false,
        }
    }

    /// Resolves the right-hand operand into a JSON value.
    ///
    /// Tried in order: a quoted string literal, an integer literal, a finite float
    /// literal, `true`/`false`, and finally a lookup of `right` in `context`.
    /// Non-finite parses (`inf`, `NaN`, `1e999`) are not treated as literals. They
    /// would otherwise become JSON `null`, and a variable with such a name would be
    /// unreachable.
    fn resolve_operand(
        right: &str,
        context: &HashMap<String, serde_json::Value>,
    ) -> Option<serde_json::Value> {
        if right.starts_with('\'') || right.starts_with('"') {
            return Some(serde_json::Value::String(Self::unquote(right).to_owned()));
        }
        if let Ok(n) = right.parse::<i64>() {
            return Some(n.into());
        }
        if let Ok(n) = right.parse::<u64>() {
            return Some(n.into());
        }
        if let Ok(n) = right.parse::<f64>() {
            if n.is_finite() {
                return Some(serde_json::json!(n));
            }
        }
        match right {
            "true" => Some(serde_json::Value::Bool(true)),
            "false" => Some(serde_json::Value::Bool(false)),
            _ => context.get(right).cloned(),
        }
    }

    /// Strips one matching pair of surrounding quotes (`'…'` or `"…"`).
    ///
    /// Quotes nested inside the pair are preserved, so `'"hi"'` yields `"hi"`.
    /// Unbalanced input such as `'abc` falls back to trimming all leading and
    /// trailing quote characters.
    fn unquote(literal: &str) -> &str {
        for quote in ['\'', '"'] {
            if let Some(inner) = literal
                .strip_prefix(quote)
                .and_then(|rest| rest.strip_suffix(quote))
            {
                return inner;
            }
        }
        literal.trim_matches(|c| c == '\'' || c == '"')
    }
}

impl ConditionProtocol for ExpressionCondition {
    /// Parses `self.expression` and evaluates it against `context`.
    fn evaluate(&self, context: &HashMap<String, serde_json::Value>) -> bool {
        self.evaluate_expression(context)
    }
}
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct DictCondition {
pub variable: String,
pub routes: HashMap<String, Vec<String>>,
pub default: Vec<String>,
}
impl DictCondition {
pub fn new(variable: impl Into<String>) -> Self {
Self {
variable: variable.into(),
routes: HashMap::new(),
default: Vec::new(),
}
}
pub fn when(mut self, value: impl Into<String>, targets: Vec<String>) -> Self {
self.routes.insert(value.into(), targets);
self
}
pub fn default_targets(mut self, targets: Vec<String>) -> Self {
self.default = targets;
self
}
}
impl ConditionProtocol for DictCondition {
fn evaluate(&self, context: &HashMap<String, serde_json::Value>) -> bool {
if let Some(value) = context.get(&self.variable) {
let key = match value {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Bool(b) => b.to_string(),
serde_json::Value::Number(n) => n.to_string(),
_ => return !self.default.is_empty(),
};
self.routes.contains_key(&key) || !self.default.is_empty()
} else {
!self.default.is_empty()
}
}
}
impl RoutingConditionProtocol for DictCondition {
fn get_target(&self, context: &HashMap<String, serde_json::Value>) -> Vec<String> {
if let Some(value) = context.get(&self.variable) {
let key = match value {
serde_json::Value::String(s) => s.clone(),
serde_json::Value::Bool(b) => b.to_string(),
serde_json::Value::Number(n) => n.to_string(),
_ => return self.default.clone(),
};
self.routes.get(&key).cloned().unwrap_or_else(|| self.default.clone())
} else {
self.default.clone()
}
}
}
/// A condition backed by an arbitrary closure over the context map.
///
/// The closure bound lives on the `impl` blocks rather than on the struct.
/// Constructing and evaluating still require
/// `F: Fn(&HashMap<String, serde_json::Value>) -> bool + Send + Sync`.
pub struct ClosureCondition<F> {
    condition: F,
}

impl<F> ClosureCondition<F>
where
    F: Fn(&HashMap<String, serde_json::Value>) -> bool + Send + Sync,
{
    /// Wraps `condition`.
    ///
    /// The bound on this constructor also drives closure signature inference, so
    /// `ClosureCondition::new(|ctx| ...)` needs no parameter type annotation.
    pub fn new(condition: F) -> Self {
        Self { condition }
    }
}

impl<F> ConditionProtocol for ClosureCondition<F>
where
    F: Fn(&HashMap<String, serde_json::Value>) -> bool + Send + Sync,
{
    /// Calls the wrapped closure with `context` and returns its result.
    fn evaluate(&self, context: &HashMap<String, serde_json::Value>) -> bool {
        (self.condition)(context)
    }
}

impl<F> std::fmt::Debug for ClosureCondition<F> {
    /// Closures are opaque, so only the type name is printed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter
            .debug_struct("ClosureCondition")
            .finish_non_exhaustive()
    }
}
pub fn evaluate_condition(
expression: &str,
context: &HashMap<String, serde_json::Value>,
) -> bool {
ExpressionCondition::new(expression).evaluate(context)
}
pub struct If;
impl If {
pub fn expr(expression: impl Into<String>) -> ExpressionCondition {
ExpressionCondition::new(expression)
}
pub fn dict(variable: impl Into<String>) -> DictCondition {
DictCondition::new(variable)
}
pub fn closure<F>(f: F) -> ClosureCondition<F>
where
F: Fn(&HashMap<String, serde_json::Value>) -> bool + Send + Sync,
{
ClosureCondition::new(f)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};

    /// Builds a context map from `(name, value)` pairs.
    fn make_context(pairs: &[(&str, Value)]) -> HashMap<String, Value> {
        let mut context = HashMap::with_capacity(pairs.len());
        for (name, value) in pairs {
            context.insert((*name).to_owned(), value.clone());
        }
        context
    }

    /// For each case, binds `var` to the given value and asserts `cond` gives the expected result.
    fn check_cases(cond: &dyn ConditionProtocol, var: &str, cases: &[(Value, bool)]) {
        for (value, expected) in cases {
            let ctx = make_context(&[(var, value.clone())]);
            assert_eq!(cond.evaluate(&ctx), *expected, "{} = {}", var, value);
        }
    }

    #[test]
    fn test_expression_condition_numeric_greater() {
        let cond = ExpressionCondition::new("score > 80");
        check_cases(&cond, "score", &[(json!(90), true), (json!(70), false)]);
    }

    #[test]
    fn test_expression_condition_numeric_equal() {
        let cond = ExpressionCondition::new("count == 5");
        check_cases(&cond, "count", &[(json!(5), true), (json!(3), false)]);
    }

    #[test]
    fn test_expression_condition_string_equal() {
        let cond = ExpressionCondition::new("status == 'approved'");
        check_cases(
            &cond,
            "status",
            &[(json!("approved"), true), (json!("rejected"), false)],
        );
    }

    #[test]
    fn test_expression_condition_not_equal() {
        let cond = ExpressionCondition::new("status != 'pending'");
        check_cases(
            &cond,
            "status",
            &[(json!("approved"), true), (json!("pending"), false)],
        );
    }

    #[test]
    fn test_expression_condition_boolean() {
        let cond = ExpressionCondition::new("is_valid");
        check_cases(&cond, "is_valid", &[(json!(true), true), (json!(false), false)]);
    }

    #[test]
    fn test_expression_condition_missing_variable() {
        let cond = ExpressionCondition::new("score > 80");
        assert!(!cond.evaluate(&HashMap::new()));
    }

    #[test]
    fn test_dict_condition_basic() {
        let cond = DictCondition::new("decision")
            .when("approved", vec!["process_task".to_string()])
            .when("rejected", vec!["revision_task".to_string()]);
        for (decision, target) in [("approved", "process_task"), ("rejected", "revision_task")] {
            let ctx = make_context(&[("decision", json!(decision))]);
            assert!(cond.evaluate(&ctx));
            assert_eq!(cond.get_target(&ctx), vec![target]);
        }
    }

    #[test]
    fn test_dict_condition_default() {
        let cond = DictCondition::new("decision")
            .when("approved", vec!["process_task".to_string()])
            .default_targets(vec!["fallback_task".to_string()]);
        let ctx = make_context(&[("decision", json!("unknown"))]);
        assert!(cond.evaluate(&ctx));
        assert_eq!(cond.get_target(&ctx), vec!["fallback_task"]);
    }

    #[test]
    fn test_dict_condition_no_match_no_default() {
        let cond =
            DictCondition::new("decision").when("approved", vec!["process_task".to_string()]);
        let ctx = make_context(&[("decision", json!("unknown"))]);
        assert!(!cond.evaluate(&ctx));
        assert!(cond.get_target(&ctx).is_empty());
    }

    #[test]
    fn test_closure_condition() {
        let cond = ClosureCondition::new(|ctx| {
            matches!(ctx.get("score").and_then(Value::as_f64), Some(s) if s > 80.0)
        });
        check_cases(&cond, "score", &[(json!(90), true), (json!(70), false)]);
    }

    #[test]
    fn test_evaluate_condition_function() {
        let ctx = make_context(&[("score", json!(90))]);
        assert!(evaluate_condition("score > 80", &ctx));
        assert!(!evaluate_condition("score < 80", &ctx));
    }

    #[test]
    fn test_if_builder() {
        let expr_ctx = make_context(&[("score", json!(90))]);
        assert!(If::expr("score > 80").evaluate(&expr_ctx));

        let dict_cond = If::dict("status").when("ok", vec!["proceed".to_string()]);
        let dict_ctx = make_context(&[("status", json!("ok"))]);
        assert!(dict_cond.evaluate(&dict_ctx));
    }

    #[test]
    fn test_expression_condition_greater_equal() {
        let cond = ExpressionCondition::new("score >= 80");
        check_cases(&cond, "score", &[(json!(80), true), (json!(79), false)]);
    }

    #[test]
    fn test_expression_condition_less_than() {
        let cond = ExpressionCondition::new("score < 50");
        check_cases(&cond, "score", &[(json!(30), true), (json!(60), false)]);
    }
}