// oxigdal_workflow/conditional/expressions.rs

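//! Conditional expression evaluation.
//!
//! Defines the [`Expression`] tree, its operators, an evaluator that reduces a
//! tree to a JSON value, and a parser for simple `variable operator value`
//! conditions.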
2
use crate::error::{Result, WorkflowError};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
7
8/// A conditional expression.
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub enum Expression {
11    /// Literal value.
12    Literal(Value),
13    /// Variable reference.
14    Variable(String),
15    /// Binary operation.
16    Binary {
17        /// Left operand.
18        left: Box<Expression>,
19        /// Binary operator.
20        op: BinaryOperator,
21        /// Right operand.
22        right: Box<Expression>,
23    },
24    /// Unary operation.
25    Unary {
26        /// Unary operator.
27        op: UnaryOperator,
28        /// Expression to apply operator to.
29        expr: Box<Expression>,
30    },
31    /// Function call.
32    Function {
33        /// Function name.
34        name: String,
35        /// Function arguments.
36        args: Vec<Expression>,
37    },
38}
39
40/// Binary operators.
41#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
42pub enum BinaryOperator {
43    /// Equality (==).
44    Eq,
45    /// Inequality (!=).
46    Ne,
47    /// Less than (<).
48    Lt,
49    /// Less than or equal (<=).
50    Le,
51    /// Greater than (>).
52    Gt,
53    /// Greater than or equal (>=).
54    Ge,
55    /// Logical AND (&&).
56    And,
57    /// Logical OR (||).
58    Or,
59    /// Addition (+).
60    Add,
61    /// Subtraction (-).
62    Sub,
63    /// Multiplication (*).
64    Mul,
65    /// Division (/).
66    Div,
67    /// Modulo (%).
68    Mod,
69}
70
71/// Unary operators.
72#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
73pub enum UnaryOperator {
74    /// Logical NOT (!).
75    Not,
76    /// Negation (-).
77    Neg,
78}
79
80/// Expression context for variable lookup.
81pub type ExpressionContext = HashMap<String, Value>;
82
83impl Expression {
84    /// Create a literal expression.
85    pub fn literal(value: Value) -> Self {
86        Self::Literal(value)
87    }
88
89    /// Create a variable reference expression.
90    pub fn variable<S: Into<String>>(name: S) -> Self {
91        Self::Variable(name.into())
92    }
93
94    /// Create a binary expression.
95    pub fn binary(left: Expression, op: BinaryOperator, right: Expression) -> Self {
96        Self::Binary {
97            left: Box::new(left),
98            op,
99            right: Box::new(right),
100        }
101    }
102
103    /// Create an equality expression.
104    pub fn eq(left: Expression, right: Expression) -> Self {
105        Self::binary(left, BinaryOperator::Eq, right)
106    }
107
108    /// Create a logical AND expression.
109    pub fn and(left: Expression, right: Expression) -> Self {
110        Self::binary(left, BinaryOperator::And, right)
111    }
112
113    /// Create a logical OR expression.
114    pub fn or(left: Expression, right: Expression) -> Self {
115        Self::binary(left, BinaryOperator::Or, right)
116    }
117
118    /// Create a NOT expression.
119    pub fn logical_not(expr: Expression) -> Self {
120        Self::Unary {
121            op: UnaryOperator::Not,
122            expr: Box::new(expr),
123        }
124    }
125
126    /// Evaluate the expression.
127    pub fn evaluate(&self, context: &ExpressionContext) -> Result<Value> {
128        match self {
129            Expression::Literal(value) => Ok(value.clone()),
130
131            Expression::Variable(name) => context.get(name).cloned().ok_or_else(|| {
132                WorkflowError::conditional(format!("Variable '{}' not found", name))
133            }),
134
135            Expression::Binary { left, op, right } => {
136                let left_val = left.evaluate(context)?;
137                let right_val = right.evaluate(context)?;
138                self.evaluate_binary(*op, &left_val, &right_val)
139            }
140
141            Expression::Unary { op, expr } => {
142                let val = expr.evaluate(context)?;
143                self.evaluate_unary(*op, &val)
144            }
145
146            Expression::Function { name, args } => {
147                let arg_vals: Result<Vec<_>> =
148                    args.iter().map(|arg| arg.evaluate(context)).collect();
149                let arg_vals = arg_vals?;
150                self.evaluate_function(name, &arg_vals)
151            }
152        }
153    }
154
155    /// Evaluate a binary operation.
156    fn evaluate_binary(&self, op: BinaryOperator, left: &Value, right: &Value) -> Result<Value> {
157        match op {
158            BinaryOperator::Eq => Ok(Value::Bool(left == right)),
159            BinaryOperator::Ne => Ok(Value::Bool(left != right)),
160            BinaryOperator::Lt => self.compare_values(left, right, |cmp| cmp.is_lt()),
161            BinaryOperator::Le => self.compare_values(left, right, |cmp| cmp.is_le()),
162            BinaryOperator::Gt => self.compare_values(left, right, |cmp| cmp.is_gt()),
163            BinaryOperator::Ge => self.compare_values(left, right, |cmp| cmp.is_ge()),
164            BinaryOperator::And => self.logical_and(left, right),
165            BinaryOperator::Or => self.logical_or(left, right),
166            BinaryOperator::Add => self.arithmetic_op(left, right, |a, b| a + b),
167            BinaryOperator::Sub => self.arithmetic_op(left, right, |a, b| a - b),
168            BinaryOperator::Mul => self.arithmetic_op(left, right, |a, b| a * b),
169            BinaryOperator::Div => {
170                self.arithmetic_op(left, right, |a, b| if b == 0.0 { f64::NAN } else { a / b })
171            }
172            BinaryOperator::Mod => self.arithmetic_op(left, right, |a, b| a % b),
173        }
174    }
175
176    /// Compare two values.
177    fn compare_values<F>(&self, left: &Value, right: &Value, pred: F) -> Result<Value>
178    where
179        F: FnOnce(std::cmp::Ordering) -> bool,
180    {
181        let cmp = match (left, right) {
182            (Value::Number(l), Value::Number(r)) => {
183                let l = l
184                    .as_f64()
185                    .ok_or_else(|| WorkflowError::conditional("Invalid number"))?;
186                let r = r
187                    .as_f64()
188                    .ok_or_else(|| WorkflowError::conditional("Invalid number"))?;
189                l.partial_cmp(&r)
190                    .ok_or_else(|| WorkflowError::conditional("NaN comparison"))?
191            }
192            (Value::String(l), Value::String(r)) => l.cmp(r),
193            _ => {
194                return Err(WorkflowError::conditional("Cannot compare these types"));
195            }
196        };
197
198        Ok(Value::Bool(pred(cmp)))
199    }
200
201    /// Logical AND operation.
202    fn logical_and(&self, left: &Value, right: &Value) -> Result<Value> {
203        let left_bool = left
204            .as_bool()
205            .ok_or_else(|| WorkflowError::conditional("Expected boolean"))?;
206        let right_bool = right
207            .as_bool()
208            .ok_or_else(|| WorkflowError::conditional("Expected boolean"))?;
209        Ok(Value::Bool(left_bool && right_bool))
210    }
211
212    /// Logical OR operation.
213    fn logical_or(&self, left: &Value, right: &Value) -> Result<Value> {
214        let left_bool = left
215            .as_bool()
216            .ok_or_else(|| WorkflowError::conditional("Expected boolean"))?;
217        let right_bool = right
218            .as_bool()
219            .ok_or_else(|| WorkflowError::conditional("Expected boolean"))?;
220        Ok(Value::Bool(left_bool || right_bool))
221    }
222
223    /// Arithmetic operation.
224    fn arithmetic_op<F>(&self, left: &Value, right: &Value, op: F) -> Result<Value>
225    where
226        F: FnOnce(f64, f64) -> f64,
227    {
228        let left_num = left
229            .as_f64()
230            .ok_or_else(|| WorkflowError::conditional("Expected number"))?;
231        let right_num = right
232            .as_f64()
233            .ok_or_else(|| WorkflowError::conditional("Expected number"))?;
234
235        let result = op(left_num, right_num);
236        Ok(serde_json::json!(result))
237    }
238
239    /// Evaluate a unary operation.
240    fn evaluate_unary(&self, op: UnaryOperator, val: &Value) -> Result<Value> {
241        match op {
242            UnaryOperator::Not => {
243                let bool_val = val
244                    .as_bool()
245                    .ok_or_else(|| WorkflowError::conditional("Expected boolean"))?;
246                Ok(Value::Bool(!bool_val))
247            }
248            UnaryOperator::Neg => {
249                let num_val = val
250                    .as_f64()
251                    .ok_or_else(|| WorkflowError::conditional("Expected number"))?;
252                Ok(serde_json::json!(-num_val))
253            }
254        }
255    }
256
257    /// Evaluate a function call.
258    fn evaluate_function(&self, name: &str, args: &[Value]) -> Result<Value> {
259        match name {
260            "len" => {
261                if args.len() != 1 {
262                    return Err(WorkflowError::conditional("len() expects 1 argument"));
263                }
264                match &args[0] {
265                    Value::String(s) => Ok(Value::Number(s.len().into())),
266                    Value::Array(a) => Ok(Value::Number(a.len().into())),
267                    _ => Err(WorkflowError::conditional("len() expects string or array")),
268                }
269            }
270            "upper" => {
271                if args.len() != 1 {
272                    return Err(WorkflowError::conditional("upper() expects 1 argument"));
273                }
274                match &args[0] {
275                    Value::String(s) => Ok(Value::String(s.to_uppercase())),
276                    _ => Err(WorkflowError::conditional("upper() expects string")),
277                }
278            }
279            "lower" => {
280                if args.len() != 1 {
281                    return Err(WorkflowError::conditional("lower() expects 1 argument"));
282                }
283                match &args[0] {
284                    Value::String(s) => Ok(Value::String(s.to_lowercase())),
285                    _ => Err(WorkflowError::conditional("lower() expects string")),
286                }
287            }
288            _ => Err(WorkflowError::conditional(format!(
289                "Unknown function '{}'",
290                name
291            ))),
292        }
293    }
294}
295
296/// Parse a simple conditional expression from a string.
297/// Format: "variable operator value" (e.g., "status == 'success'")
298pub fn parse_simple_expression(expr: &str) -> Result<Expression> {
299    let parts: Vec<&str> = expr.split_whitespace().collect();
300
301    if parts.len() != 3 {
302        return Err(WorkflowError::conditional(
303            "Invalid expression format. Expected: 'variable operator value'",
304        ));
305    }
306
307    let var = Expression::variable(parts[0]);
308    let value = parse_value(parts[2])?;
309
310    let op = match parts[1] {
311        "==" => BinaryOperator::Eq,
312        "!=" => BinaryOperator::Ne,
313        "<" => BinaryOperator::Lt,
314        "<=" => BinaryOperator::Le,
315        ">" => BinaryOperator::Gt,
316        ">=" => BinaryOperator::Ge,
317        _ => {
318            return Err(WorkflowError::conditional(format!(
319                "Unknown operator '{}'",
320                parts[1]
321            )));
322        }
323    };
324
325    Ok(Expression::binary(var, op, Expression::literal(value)))
326}
327
328/// Parse a value from a string.
329fn parse_value(s: &str) -> Result<Value> {
330    // Try to parse as number
331    if let Ok(num) = s.parse::<i64>() {
332        return Ok(Value::Number(num.into()));
333    }
334    if let Ok(num) = s.parse::<f64>() {
335        return Ok(serde_json::json!(num));
336    }
337
338    // Try to parse as boolean
339    if let Ok(b) = s.parse::<bool>() {
340        return Ok(Value::Bool(b));
341    }
342
343    // Parse as string (remove quotes if present)
344    let s = s.trim_matches('\'').trim_matches('"');
345    Ok(Value::String(s.to_string()))
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a context holding the single binding `name -> value`.
    fn context_with(name: &str, value: Value) -> ExpressionContext {
        let mut ctx = HashMap::new();
        ctx.insert(name.to_string(), value);
        ctx
    }

    /// A literal evaluates to itself, independent of the context.
    #[test]
    fn test_literal() {
        let empty = HashMap::new();
        let result = Expression::literal(Value::Bool(true))
            .evaluate(&empty)
            .expect("Failed to evaluate");
        assert_eq!(result, Value::Bool(true));
    }

    /// A variable evaluates to the value bound to it in the context.
    #[test]
    fn test_variable() {
        let ctx = context_with("x", Value::Number(42.into()));

        let result = Expression::variable("x")
            .evaluate(&ctx)
            .expect("Failed to evaluate");
        assert_eq!(result, Value::Number(42.into()));
    }

    /// `==` between a string variable and an identical string literal is true.
    #[test]
    fn test_equality() {
        let ctx = context_with("status", Value::String("success".to_string()));

        let lhs = Expression::variable("status");
        let rhs = Expression::literal(Value::String("success".to_string()));
        let result = Expression::eq(lhs, rhs)
            .evaluate(&ctx)
            .expect("Failed to evaluate");

        assert_eq!(result, Value::Bool(true));
    }

    /// `>` compares a numeric variable against a numeric literal.
    #[test]
    fn test_comparison() {
        let ctx = context_with("count", Value::Number(10.into()));

        let lhs = Expression::variable("count");
        let rhs = Expression::literal(Value::Number(5.into()));
        let result = Expression::binary(lhs, BinaryOperator::Gt, rhs)
            .evaluate(&ctx)
            .expect("Failed to evaluate");

        assert_eq!(result, Value::Bool(true));
    }

    /// A parsed `variable operator value` string evaluates like a hand-built tree.
    #[test]
    fn test_parse_simple_expression() {
        let ctx = context_with("status", Value::String("success".to_string()));

        let parsed = parse_simple_expression("status == 'success'").expect("Failed to parse");
        let result = parsed.evaluate(&ctx).expect("Failed to evaluate");

        assert_eq!(result, Value::Bool(true));
    }
}