1use crate::error::{Result, WorkflowError};
4use serde::{Deserialize, Serialize};
5use serde_json::Value;
6use std::collections::HashMap;
7
/// A node in the syntax tree of a workflow condition expression.
///
/// Trees are built with the constructor helpers on `Expression` (`literal`, `variable`,
/// `binary`, ...) or by [`parse_simple_expression`], and evaluated against a variable context
/// with [`Expression::evaluate`].
///
/// NOTE(review): this type derives `Serialize`/`Deserialize`, so variant and field names (and
/// their declaration order, for non-self-describing formats) are part of its serialized
/// form — rename or reorder with care.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum Expression {
    /// A constant JSON value; evaluates to a clone of itself.
    Literal(Value),
    /// A reference to a variable, looked up by name in the evaluation context.
    Variable(String),
    /// A binary operator applied to two sub-expressions.
    Binary {
        /// Left-hand operand.
        left: Box<Expression>,
        /// Operator to apply.
        op: BinaryOperator,
        /// Right-hand operand.
        right: Box<Expression>,
    },
    /// A unary operator applied to one sub-expression.
    Unary {
        /// Operator to apply.
        op: UnaryOperator,
        /// The operand.
        expr: Box<Expression>,
    },
    /// A call to a built-in function (currently `len`, `upper` and `lower`).
    Function {
        /// Function name; an unknown name is reported as an error at evaluation time.
        name: String,
        /// Argument expressions, evaluated before the function is invoked.
        args: Vec<Expression>,
    },
}
39
/// Operators usable in [`Expression::Binary`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum BinaryOperator {
    /// Equality (`==`).
    Eq,
    /// Inequality (`!=`).
    Ne,
    /// Less than (`<`); defined for two numbers or two strings.
    Lt,
    /// Less than or equal (`<=`); defined for two numbers or two strings.
    Le,
    /// Greater than (`>`); defined for two numbers or two strings.
    Gt,
    /// Greater than or equal (`>=`); defined for two numbers or two strings.
    Ge,
    /// Logical AND of boolean operands.
    And,
    /// Logical OR of boolean operands.
    Or,
    /// Addition (`+`) of numbers.
    Add,
    /// Subtraction (`-`) of numbers.
    Sub,
    /// Multiplication (`*`) of numbers.
    Mul,
    /// Division (`/`) of numbers.
    Div,
    /// Remainder (`%`) of numbers.
    Mod,
}
70
/// Operators usable in [`Expression::Unary`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum UnaryOperator {
    /// Logical negation of a boolean operand.
    Not,
    /// Arithmetic negation of a numeric operand.
    Neg,
}
79
/// Variable bindings available to an expression during evaluation, keyed by variable name.
pub type ExpressionContext = HashMap<String, Value>;
82
83impl Expression {
84 pub fn literal(value: Value) -> Self {
86 Self::Literal(value)
87 }
88
89 pub fn variable<S: Into<String>>(name: S) -> Self {
91 Self::Variable(name.into())
92 }
93
94 pub fn binary(left: Expression, op: BinaryOperator, right: Expression) -> Self {
96 Self::Binary {
97 left: Box::new(left),
98 op,
99 right: Box::new(right),
100 }
101 }
102
103 pub fn eq(left: Expression, right: Expression) -> Self {
105 Self::binary(left, BinaryOperator::Eq, right)
106 }
107
108 pub fn and(left: Expression, right: Expression) -> Self {
110 Self::binary(left, BinaryOperator::And, right)
111 }
112
113 pub fn or(left: Expression, right: Expression) -> Self {
115 Self::binary(left, BinaryOperator::Or, right)
116 }
117
118 pub fn logical_not(expr: Expression) -> Self {
120 Self::Unary {
121 op: UnaryOperator::Not,
122 expr: Box::new(expr),
123 }
124 }
125
126 pub fn evaluate(&self, context: &ExpressionContext) -> Result<Value> {
128 match self {
129 Expression::Literal(value) => Ok(value.clone()),
130
131 Expression::Variable(name) => context.get(name).cloned().ok_or_else(|| {
132 WorkflowError::conditional(format!("Variable '{}' not found", name))
133 }),
134
135 Expression::Binary { left, op, right } => {
136 let left_val = left.evaluate(context)?;
137 let right_val = right.evaluate(context)?;
138 self.evaluate_binary(*op, &left_val, &right_val)
139 }
140
141 Expression::Unary { op, expr } => {
142 let val = expr.evaluate(context)?;
143 self.evaluate_unary(*op, &val)
144 }
145
146 Expression::Function { name, args } => {
147 let arg_vals: Result<Vec<_>> =
148 args.iter().map(|arg| arg.evaluate(context)).collect();
149 let arg_vals = arg_vals?;
150 self.evaluate_function(name, &arg_vals)
151 }
152 }
153 }
154
155 fn evaluate_binary(&self, op: BinaryOperator, left: &Value, right: &Value) -> Result<Value> {
157 match op {
158 BinaryOperator::Eq => Ok(Value::Bool(left == right)),
159 BinaryOperator::Ne => Ok(Value::Bool(left != right)),
160 BinaryOperator::Lt => self.compare_values(left, right, |cmp| cmp.is_lt()),
161 BinaryOperator::Le => self.compare_values(left, right, |cmp| cmp.is_le()),
162 BinaryOperator::Gt => self.compare_values(left, right, |cmp| cmp.is_gt()),
163 BinaryOperator::Ge => self.compare_values(left, right, |cmp| cmp.is_ge()),
164 BinaryOperator::And => self.logical_and(left, right),
165 BinaryOperator::Or => self.logical_or(left, right),
166 BinaryOperator::Add => self.arithmetic_op(left, right, |a, b| a + b),
167 BinaryOperator::Sub => self.arithmetic_op(left, right, |a, b| a - b),
168 BinaryOperator::Mul => self.arithmetic_op(left, right, |a, b| a * b),
169 BinaryOperator::Div => {
170 self.arithmetic_op(left, right, |a, b| if b == 0.0 { f64::NAN } else { a / b })
171 }
172 BinaryOperator::Mod => self.arithmetic_op(left, right, |a, b| a % b),
173 }
174 }
175
176 fn compare_values<F>(&self, left: &Value, right: &Value, pred: F) -> Result<Value>
178 where
179 F: FnOnce(std::cmp::Ordering) -> bool,
180 {
181 let cmp = match (left, right) {
182 (Value::Number(l), Value::Number(r)) => {
183 let l = l
184 .as_f64()
185 .ok_or_else(|| WorkflowError::conditional("Invalid number"))?;
186 let r = r
187 .as_f64()
188 .ok_or_else(|| WorkflowError::conditional("Invalid number"))?;
189 l.partial_cmp(&r)
190 .ok_or_else(|| WorkflowError::conditional("NaN comparison"))?
191 }
192 (Value::String(l), Value::String(r)) => l.cmp(r),
193 _ => {
194 return Err(WorkflowError::conditional("Cannot compare these types"));
195 }
196 };
197
198 Ok(Value::Bool(pred(cmp)))
199 }
200
201 fn logical_and(&self, left: &Value, right: &Value) -> Result<Value> {
203 let left_bool = left
204 .as_bool()
205 .ok_or_else(|| WorkflowError::conditional("Expected boolean"))?;
206 let right_bool = right
207 .as_bool()
208 .ok_or_else(|| WorkflowError::conditional("Expected boolean"))?;
209 Ok(Value::Bool(left_bool && right_bool))
210 }
211
212 fn logical_or(&self, left: &Value, right: &Value) -> Result<Value> {
214 let left_bool = left
215 .as_bool()
216 .ok_or_else(|| WorkflowError::conditional("Expected boolean"))?;
217 let right_bool = right
218 .as_bool()
219 .ok_or_else(|| WorkflowError::conditional("Expected boolean"))?;
220 Ok(Value::Bool(left_bool || right_bool))
221 }
222
223 fn arithmetic_op<F>(&self, left: &Value, right: &Value, op: F) -> Result<Value>
225 where
226 F: FnOnce(f64, f64) -> f64,
227 {
228 let left_num = left
229 .as_f64()
230 .ok_or_else(|| WorkflowError::conditional("Expected number"))?;
231 let right_num = right
232 .as_f64()
233 .ok_or_else(|| WorkflowError::conditional("Expected number"))?;
234
235 let result = op(left_num, right_num);
236 Ok(serde_json::json!(result))
237 }
238
239 fn evaluate_unary(&self, op: UnaryOperator, val: &Value) -> Result<Value> {
241 match op {
242 UnaryOperator::Not => {
243 let bool_val = val
244 .as_bool()
245 .ok_or_else(|| WorkflowError::conditional("Expected boolean"))?;
246 Ok(Value::Bool(!bool_val))
247 }
248 UnaryOperator::Neg => {
249 let num_val = val
250 .as_f64()
251 .ok_or_else(|| WorkflowError::conditional("Expected number"))?;
252 Ok(serde_json::json!(-num_val))
253 }
254 }
255 }
256
257 fn evaluate_function(&self, name: &str, args: &[Value]) -> Result<Value> {
259 match name {
260 "len" => {
261 if args.len() != 1 {
262 return Err(WorkflowError::conditional("len() expects 1 argument"));
263 }
264 match &args[0] {
265 Value::String(s) => Ok(Value::Number(s.len().into())),
266 Value::Array(a) => Ok(Value::Number(a.len().into())),
267 _ => Err(WorkflowError::conditional("len() expects string or array")),
268 }
269 }
270 "upper" => {
271 if args.len() != 1 {
272 return Err(WorkflowError::conditional("upper() expects 1 argument"));
273 }
274 match &args[0] {
275 Value::String(s) => Ok(Value::String(s.to_uppercase())),
276 _ => Err(WorkflowError::conditional("upper() expects string")),
277 }
278 }
279 "lower" => {
280 if args.len() != 1 {
281 return Err(WorkflowError::conditional("lower() expects 1 argument"));
282 }
283 match &args[0] {
284 Value::String(s) => Ok(Value::String(s.to_lowercase())),
285 _ => Err(WorkflowError::conditional("lower() expects string")),
286 }
287 }
288 _ => Err(WorkflowError::conditional(format!(
289 "Unknown function '{}'",
290 name
291 ))),
292 }
293 }
294}
295
296pub fn parse_simple_expression(expr: &str) -> Result<Expression> {
299 let parts: Vec<&str> = expr.split_whitespace().collect();
300
301 if parts.len() != 3 {
302 return Err(WorkflowError::conditional(
303 "Invalid expression format. Expected: 'variable operator value'",
304 ));
305 }
306
307 let var = Expression::variable(parts[0]);
308 let value = parse_value(parts[2])?;
309
310 let op = match parts[1] {
311 "==" => BinaryOperator::Eq,
312 "!=" => BinaryOperator::Ne,
313 "<" => BinaryOperator::Lt,
314 "<=" => BinaryOperator::Le,
315 ">" => BinaryOperator::Gt,
316 ">=" => BinaryOperator::Ge,
317 _ => {
318 return Err(WorkflowError::conditional(format!(
319 "Unknown operator '{}'",
320 parts[1]
321 )));
322 }
323 };
324
325 Ok(Expression::binary(var, op, Expression::literal(value)))
326}
327
328fn parse_value(s: &str) -> Result<Value> {
330 if let Ok(num) = s.parse::<i64>() {
332 return Ok(Value::Number(num.into()));
333 }
334 if let Ok(num) = s.parse::<f64>() {
335 return Ok(serde_json::json!(num));
336 }
337
338 if let Ok(b) = s.parse::<bool>() {
340 return Ok(Value::Bool(b));
341 }
342
343 let s = s.trim_matches('\'').trim_matches('"');
345 Ok(Value::String(s.to_string()))
346}
347
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context that binds exactly one variable.
    fn context_with(name: &str, value: Value) -> ExpressionContext {
        std::iter::once((name.to_string(), value)).collect()
    }

    /// Evaluates `expr` against `ctx`, failing the test if evaluation errors.
    fn eval(expr: &Expression, ctx: &ExpressionContext) -> Value {
        expr.evaluate(ctx).expect("Failed to evaluate")
    }

    #[test]
    fn test_literal() {
        let always_true = Expression::literal(Value::Bool(true));
        assert_eq!(eval(&always_true, &ExpressionContext::new()), Value::Bool(true));
    }

    #[test]
    fn test_variable() {
        let ctx = context_with("x", Value::Number(42.into()));
        assert_eq!(eval(&Expression::variable("x"), &ctx), Value::Number(42.into()));
    }

    #[test]
    fn test_equality() {
        let success = || Value::String("success".to_string());
        let ctx = context_with("status", success());
        let is_success =
            Expression::eq(Expression::variable("status"), Expression::literal(success()));
        assert_eq!(eval(&is_success, &ctx), Value::Bool(true));
    }

    #[test]
    fn test_comparison() {
        let ctx = context_with("count", Value::Number(10.into()));
        let above_five = Expression::binary(
            Expression::variable("count"),
            BinaryOperator::Gt,
            Expression::literal(Value::Number(5.into())),
        );
        assert_eq!(eval(&above_five, &ctx), Value::Bool(true));
    }

    #[test]
    fn test_parse_simple_expression() {
        let parsed = parse_simple_expression("status == 'success'").expect("Failed to parse");
        let ctx = context_with("status", Value::String("success".to_string()));
        assert_eq!(eval(&parsed, &ctx), Value::Bool(true));
    }
}