use serde_json::Value;
use std::collections::HashMap;
use thiserror::Error;
/// Errors produced while evaluating a condition expression with `ConditionEvaluator`.
#[derive(Debug, Error)]
pub enum ConditionError {
/// The expression did not match any form the evaluator understands (including a bare
/// top-level name that is not present in the context).
#[error("Expression syntax error: {0}")]
SyntaxError(String),
/// General evaluation failure. NOTE(review): not constructed by the code visible here —
/// confirm whether other code relies on it before removing.
#[error("Evaluation error: {0}")]
EvaluationError(String),
/// A value of the wrong kind was used, e.g. property access on a non-object value.
#[error("Type error: {0}")]
TypeError(String),
/// A dotted variable path (such as `user.name`) named a property that does not exist.
#[error("Variable not found: {0}")]
VariableNotFound(String),
}
/// Shorthand `Result` type returned by the condition evaluator.
pub type ConditionResult<T> = Result<T, ConditionError>;
/// Evaluates boolean condition expressions against a map of named JSON values.
///
/// Public types should be inspectable and copyable; both derives are available because
/// `HashMap<String, Value>` implements `Debug` and `Clone`.
#[derive(Debug, Clone)]
pub struct ConditionEvaluator {
    /// Top-level variables, keyed by name; nested fields are reached with dotted paths.
    context: HashMap<String, Value>,
}
impl ConditionEvaluator {
pub fn new() -> Self {
Self {
context: HashMap::new(),
}
}
pub fn with_context(context: HashMap<String, Value>) -> Self {
Self { context }
}
pub fn set_variable(&mut self, name: impl Into<String>, value: Value) {
self.context.insert(name.into(), value);
}
pub fn get_variable(&self, name: &str) -> Option<&Value> {
self.context.get(name)
}
pub fn evaluate(&self, expression: &str) -> ConditionResult<bool> {
self.evaluate_simple(expression)
}
fn evaluate_simple(&self, expression: &str) -> ConditionResult<bool> {
let expr = expression.trim();
if expr == "true" {
return Ok(true);
}
if expr == "false" {
return Ok(false);
}
if let Some(result) = self.evaluate_comparison(expr)? {
return Ok(result);
}
if let Some(result) = self.evaluate_logical(expr)? {
return Ok(result);
}
if let Some(value) = self.get_variable_value(expr)? {
return self.value_to_bool(&value);
}
Err(ConditionError::SyntaxError(format!("Unable to evaluate expression: {}", expr)))
}
fn evaluate_comparison(&self, expr: &str) -> ConditionResult<Option<bool>> {
if let Some((left, right)) = expr.split_once("==") {
let left_val = self.evaluate_value(left.trim())?;
let right_val = self.evaluate_value(right.trim())?;
return Ok(Some(left_val == right_val));
}
if let Some((left, right)) = expr.split_once("!=") {
let left_val = self.evaluate_value(left.trim())?;
let right_val = self.evaluate_value(right.trim())?;
return Ok(Some(left_val != right_val));
}
for op in [">=", "<=", ">", "<"] {
if let Some((left, right)) = expr.split_once(op) {
let left_val = self.evaluate_value(left.trim())?;
let right_val = self.evaluate_value(right.trim())?;
if let (Some(a), Some(b)) = (
left_val.as_f64().or_else(|| left_val.as_i64().map(|i| i as f64)),
right_val.as_f64().or_else(|| right_val.as_i64().map(|i| i as f64)),
) {
let result = match op {
">=" => a >= b,
"<=" => a <= b,
">" => a > b,
"<" => a < b,
_ => false,
};
return Ok(Some(result));
}
}
}
Ok(None)
}
fn evaluate_logical(&self, expr: &str) -> ConditionResult<Option<bool>> {
if let Some(stripped) = expr.strip_prefix('!') {
let inner = stripped.trim();
let inner_result = self.evaluate(inner)?;
return Ok(Some(!inner_result));
}
if let Some((left, right)) = expr.split_once("&&") {
let left_result = self.evaluate(left.trim())?;
if !left_result {
return Ok(Some(false));
}
return Ok(Some(self.evaluate(right.trim())?));
}
if let Some((left, right)) = expr.split_once("||") {
let left_result = self.evaluate(left.trim())?;
if left_result {
return Ok(Some(true));
}
return Ok(Some(self.evaluate(right.trim())?));
}
Ok(None)
}
fn evaluate_value(&self, expr: &str) -> ConditionResult<Value> {
if let Some(value) = self.get_variable_value(expr)? {
return Ok(value.clone());
}
if let Ok(value) = serde_json::from_str::<Value>(expr) {
return Ok(value);
}
if let Ok(num) = expr.parse::<f64>() {
return Ok(Value::Number(
serde_json::Number::from_f64(num).unwrap_or_else(|| serde_json::Number::from(0)),
));
}
if expr == "true" {
return Ok(Value::Bool(true));
}
if expr == "false" {
return Ok(Value::Bool(false));
}
Ok(Value::String(expr.to_string()))
}
fn get_variable_value(&self, path: &str) -> ConditionResult<Option<Value>> {
let parts: Vec<&str> = path.split('.').collect();
if parts.is_empty() {
return Ok(None);
}
let root = self.context.get(parts[0]);
if root.is_none() {
return Ok(None);
}
let mut value = root.unwrap().clone();
for part in parts.iter().skip(1) {
match value {
Value::Object(ref obj) => {
value = obj
.get(*part)
.ok_or_else(|| {
ConditionError::VariableNotFound(format!("{}.{}", parts[0], part))
})?
.clone();
}
_ => {
return Err(ConditionError::TypeError(format!(
"Cannot access property '{}' on non-object",
part
)));
}
}
}
Ok(Some(value))
}
fn value_to_bool(&self, value: &Value) -> ConditionResult<bool> {
match value {
Value::Bool(b) => Ok(*b),
Value::Number(n) => Ok(n.as_f64().unwrap_or(0.0) != 0.0),
Value::String(s) => Ok(!s.is_empty()),
Value::Array(arr) => Ok(!arr.is_empty()),
Value::Object(obj) => Ok(!obj.is_empty()),
Value::Null => Ok(false),
}
}
}
impl Default for ConditionEvaluator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an evaluator with a small fixed context shared by several tests.
    fn sample_evaluator() -> ConditionEvaluator {
        let mut context = HashMap::new();
        context.insert("status".to_string(), Value::String("active".to_string()));
        context.insert("count".to_string(), Value::Number(5.into()));
        context.insert("user".to_string(), serde_json::json!({ "age": 30 }));
        ConditionEvaluator::with_context(context)
    }

    #[test]
    fn test_simple_boolean() {
        let evaluator = ConditionEvaluator::new();
        assert!(evaluator.evaluate("true").unwrap());
        assert!(!evaluator.evaluate("false").unwrap());
    }

    #[test]
    fn test_comparison_operators() {
        let evaluator = ConditionEvaluator::new();
        assert!(evaluator.evaluate("5 > 3").unwrap());
        assert!(evaluator.evaluate("3 < 5").unwrap());
        assert!(evaluator.evaluate("5 >= 5").unwrap());
        assert!(evaluator.evaluate("3 <= 5").unwrap());
        assert!(evaluator.evaluate("5 == 5").unwrap());
        assert!(evaluator.evaluate("5 != 3").unwrap());
        assert!(!evaluator.evaluate("3 > 5").unwrap());
    }

    #[test]
    fn test_variable_access() {
        let evaluator = sample_evaluator();
        assert!(evaluator.evaluate("count > 3").unwrap());
        // `status` was previously inserted but never exercised.
        assert!(evaluator.evaluate("status == \"active\"").unwrap());
        assert!(evaluator.evaluate("status != \"inactive\"").unwrap());
        // A bare variable is converted by truthiness (non-zero number -> true).
        assert!(evaluator.evaluate("count").unwrap());
    }

    #[test]
    fn test_nested_variable_access() {
        let evaluator = sample_evaluator();
        assert!(evaluator.evaluate("user.age >= 18").unwrap());
        assert!(matches!(
            evaluator.evaluate("user.name == \"x\""),
            Err(ConditionError::VariableNotFound(_))
        ));
        assert!(matches!(
            evaluator.evaluate("count.value"),
            Err(ConditionError::TypeError(_))
        ));
    }

    #[test]
    fn test_logical_operators() {
        let evaluator = ConditionEvaluator::new();
        assert!(evaluator.evaluate("true && true").unwrap());
        assert!(!evaluator.evaluate("true && false").unwrap());
        assert!(evaluator.evaluate("true || false").unwrap());
        assert!(evaluator.evaluate("false || true").unwrap());
        assert!(!evaluator.evaluate("!true").unwrap());
        assert!(evaluator.evaluate("!false").unwrap());
    }

    #[test]
    fn test_unknown_expression_is_syntax_error() {
        let evaluator = ConditionEvaluator::new();
        assert!(matches!(
            evaluator.evaluate("missing"),
            Err(ConditionError::SyntaxError(_))
        ));
    }

    #[test]
    fn test_set_and_get_variable() {
        let mut evaluator = ConditionEvaluator::default();
        evaluator.set_variable("flag", Value::Bool(true));
        assert_eq!(evaluator.get_variable("flag"), Some(&Value::Bool(true)));
        assert!(evaluator.evaluate("flag").unwrap());
    }
}