Skip to main content

cloudillo_action/dsl/
expression.rs

1//! Expression evaluator for the Action DSL
2//!
3//! Evaluates expressions in hook contexts, supporting:
4//! - Variable references with path traversal
5//! - Template string interpolation
6//! - Comparison operations
7//! - Logical operations
8//! - Arithmetic operations
9//! - String operations
10//! - Ternary expressions
11//! - Null coalescing
12
13use super::types::*;
14use crate::hooks::HookContext;
15use crate::prelude::*;
16use serde_json::Value;
17
/// Maximum expression nesting depth to prevent stack overflow.
///
/// `ExpressionEvaluator::evaluate` returns a validation error once its depth
/// counter exceeds this value.
const MAX_DEPTH: usize = 50;
/// Maximum expression node count to prevent resource exhaustion.
///
/// Every call to `ExpressionEvaluator::evaluate` (including recursive calls for
/// sub-expressions) counts as one node; exceeding this value is a validation error.
const MAX_NODES: usize = 100;
22
/// Expression evaluator with depth and node count tracking.
///
/// Both counters are advanced by `evaluate` and checked against `MAX_DEPTH`
/// and `MAX_NODES` to bound the work a single evaluator may perform.
#[derive(Debug)]
pub struct ExpressionEvaluator {
	/// Number of `evaluate` calls currently on the call stack.
	depth: usize,
	/// Total number of `evaluate` calls made on this evaluator so far.
	node_count: usize,
}
28
29impl ExpressionEvaluator {
30	/// Create a new expression evaluator
31	pub fn new() -> Self {
32		Self { depth: 0, node_count: 0 }
33	}
34
35	/// Evaluate an expression in the given context
36	pub fn evaluate(&mut self, expr: &Expression, context: &HookContext) -> ClResult<Value> {
37		self.depth += 1;
38		self.node_count += 1;
39
40		if self.depth > MAX_DEPTH {
41			return Err(Error::ValidationError(format!(
42				"Maximum expression depth exceeded ({})",
43				MAX_DEPTH
44			)));
45		}
46		if self.node_count > MAX_NODES {
47			return Err(Error::ValidationError(format!(
48				"Maximum expression nodes exceeded ({})",
49				MAX_NODES
50			)));
51		}
52
53		let result = self.evaluate_inner(expr, context)?;
54
55		self.depth -= 1;
56		Ok(result)
57	}
58
59	fn evaluate_inner(&mut self, expr: &Expression, context: &HookContext) -> ClResult<Value> {
60		match expr {
61			// Literals
62			Expression::Null => Ok(Value::Null),
63			Expression::Bool(b) => Ok(Value::Bool(*b)),
64			Expression::Number(n) => {
65				serde_json::Number::from_f64(*n).map(Value::Number).ok_or_else(|| {
66					Error::ValidationError("Invalid number (NaN or infinity)".to_string())
67				})
68			}
69			Expression::String(s) => self.evaluate_template(s, context),
70
71			// Complex expressions
72			Expression::Comparison(c) => self.evaluate_comparison(c, context),
73			Expression::Logical(l) => self.evaluate_logical(l, context),
74			Expression::Arithmetic(a) => self.evaluate_arithmetic(a, context),
75			Expression::StringOp(s) => self.evaluate_string_op(s, context),
76			Expression::Ternary(t) => self.evaluate_ternary(t, context),
77			Expression::Coalesce(c) => self.evaluate_coalesce(c, context),
78		}
79	}
80
81	/// Evaluate template string with variable interpolation
82	/// Supports:
83	/// - Simple variables: "{variable}"
84	/// - Nested paths: "{context.tenant.type}"
85	/// - Template strings: "Key: {type}:{issuer}:{audience}"
86	fn evaluate_template(&mut self, template: &str, context: &HookContext) -> ClResult<Value> {
87		// Check if it's a simple variable reference: "{variable}"
88		if template.starts_with('{')
89			&& template.ends_with('}')
90			&& template.matches('{').count() == 1
91		{
92			let var_name = &template[1..template.len() - 1];
93			return self.get_variable(var_name, context);
94		}
95
96		// Template with embedded variables: "Key: {type}:{issuer}"
97		let mut result = String::new();
98		let mut chars = template.chars().peekable();
99
100		while let Some(ch) = chars.next() {
101			if ch == '{' {
102				// Extract variable name
103				let mut var_name = String::new();
104				while let Some(&next_ch) = chars.peek() {
105					if next_ch == '}' {
106						chars.next(); // consume '}'
107						break;
108					}
109					// Safe: we just peeked and confirmed there's a character
110					if let Some(ch) = chars.next() {
111						var_name.push(ch);
112					}
113				}
114
115				// Get variable value
116				let value = self.get_variable(&var_name, context)?;
117				let replacement = match value {
118					Value::Null => String::new(),
119					Value::String(s) => s,
120					v => v.to_string(),
121				};
122				result.push_str(&replacement);
123			} else {
124				result.push(ch);
125			}
126		}
127
128		Ok(Value::String(result))
129	}
130
131	/// Get variable from context by path
132	/// Supports:
133	/// - Direct fields: "issuer", "type", "subtype"
134	/// - Nested paths: "context.tenant.type"
135	/// - User variables: any name set by Set operation
136	fn get_variable(&self, path: &str, context: &HookContext) -> ClResult<Value> {
137		let parts: Vec<&str> = path.split('.').collect();
138
139		// Start with the root value
140		let mut current = match parts[0] {
141			// Action fields
142			"action_id" => Value::String(context.action_id.clone()),
143			"type" => Value::String(context.r#type.clone()),
144			"subtype" => context
145				.subtype
146				.as_ref()
147				.map(|s| Value::String(s.clone()))
148				.unwrap_or(Value::Null),
149			"issuer" => Value::String(context.issuer.clone()),
150			"audience" => context
151				.audience
152				.as_ref()
153				.map(|s| Value::String(s.clone()))
154				.unwrap_or(Value::Null),
155			"parent" => {
156				context.parent.as_ref().map(|s| Value::String(s.clone())).unwrap_or(Value::Null)
157			}
158			"subject" => context
159				.subject
160				.as_ref()
161				.map(|s| Value::String(s.clone()))
162				.unwrap_or(Value::Null),
163			"content" => context.content.clone().unwrap_or(Value::Null),
164			"attachments" => context
165				.attachments
166				.as_ref()
167				.map(|a| Value::Array(a.iter().map(|s| Value::String(s.clone())).collect()))
168				.unwrap_or(Value::Null),
169
170			// Timestamps
171			"created_at" => Value::String(context.created_at.clone()),
172			"expires_at" => context
173				.expires_at
174				.as_ref()
175				.map(|s| Value::String(s.clone()))
176				.unwrap_or(Value::Null),
177
178			// Context object
179			"context" => {
180				let mut obj = serde_json::Map::new();
181				obj.insert("tenant_id".to_string(), Value::Number(context.tenant_id.into()));
182				obj.insert("tenant_tag".to_string(), Value::String(context.tenant_tag.clone()));
183				obj.insert("tenant_type".to_string(), Value::String(context.tenant_type.clone()));
184				Value::Object(obj)
185			}
186
187			// Flags
188			"is_inbound" => Value::Bool(context.is_inbound),
189			"is_outbound" => Value::Bool(context.is_outbound),
190
191			// User variables
192			var_name => context.vars.get(var_name).cloned().ok_or_else(|| {
193				Error::ValidationError(format!("Variable not found: {}", var_name))
194			})?,
195		};
196
197		// Traverse nested paths
198		for part in &parts[1..] {
199			match &current {
200				Value::Object(map) => {
201					current = map.get(*part).cloned().unwrap_or(Value::Null);
202				}
203				Value::Null => return Ok(Value::Null),
204				_ => {
205					return Err(Error::ValidationError(format!(
206						"Cannot access property '{}' on non-object",
207						part
208					)))
209				}
210			}
211		}
212
213		Ok(current)
214	}
215
216	/// Evaluate comparison expression
217	fn evaluate_comparison(
218		&mut self,
219		comp: &ComparisonExpr,
220		context: &HookContext,
221	) -> ClResult<Value> {
222		match comp {
223			ComparisonExpr::Eq([left, right]) => {
224				let l = self.evaluate(left, context)?;
225				let r = self.evaluate(right, context)?;
226				Ok(Value::Bool(l == r))
227			}
228			ComparisonExpr::Ne([left, right]) => {
229				let l = self.evaluate(left, context)?;
230				let r = self.evaluate(right, context)?;
231				Ok(Value::Bool(l != r))
232			}
233			ComparisonExpr::Gt([left, right]) => {
234				let l_val = self.evaluate(left, context)?;
235				let r_val = self.evaluate(right, context)?;
236				let l = self.to_number(&l_val)?;
237				let r = self.to_number(&r_val)?;
238				Ok(Value::Bool(l > r))
239			}
240			ComparisonExpr::Gte([left, right]) => {
241				let l_val = self.evaluate(left, context)?;
242				let r_val = self.evaluate(right, context)?;
243				let l = self.to_number(&l_val)?;
244				let r = self.to_number(&r_val)?;
245				Ok(Value::Bool(l >= r))
246			}
247			ComparisonExpr::Lt([left, right]) => {
248				let l_val = self.evaluate(left, context)?;
249				let r_val = self.evaluate(right, context)?;
250				let l = self.to_number(&l_val)?;
251				let r = self.to_number(&r_val)?;
252				Ok(Value::Bool(l < r))
253			}
254			ComparisonExpr::Lte([left, right]) => {
255				let l_val = self.evaluate(left, context)?;
256				let r_val = self.evaluate(right, context)?;
257				let l = self.to_number(&l_val)?;
258				let r = self.to_number(&r_val)?;
259				Ok(Value::Bool(l <= r))
260			}
261		}
262	}
263
264	/// Evaluate logical expression
265	fn evaluate_logical(
266		&mut self,
267		logical: &LogicalExpr,
268		context: &HookContext,
269	) -> ClResult<Value> {
270		match logical {
271			LogicalExpr::And(exprs) => {
272				for expr in exprs {
273					let value = self.evaluate(expr, context)?;
274					if !self.to_bool(&value) {
275						return Ok(Value::Bool(false));
276					}
277				}
278				Ok(Value::Bool(true))
279			}
280			LogicalExpr::Or(exprs) => {
281				for expr in exprs {
282					let value = self.evaluate(expr, context)?;
283					if self.to_bool(&value) {
284						return Ok(Value::Bool(true));
285					}
286				}
287				Ok(Value::Bool(false))
288			}
289			LogicalExpr::Not(expr) => {
290				let value = self.evaluate(expr, context)?;
291				Ok(Value::Bool(!self.to_bool(&value)))
292			}
293		}
294	}
295
296	/// Evaluate arithmetic expression
297	fn evaluate_arithmetic(
298		&mut self,
299		arith: &ArithmeticExpr,
300		context: &HookContext,
301	) -> ClResult<Value> {
302		match arith {
303			ArithmeticExpr::Add(exprs) => {
304				let mut sum = 0.0;
305				for expr in exprs {
306					let val = self.evaluate(expr, context)?;
307					sum += self.to_number(&val)?;
308				}
309				serde_json::Number::from_f64(sum).map(Value::Number).ok_or_else(|| {
310					Error::ValidationError("Invalid number result (NaN or infinity)".to_string())
311				})
312			}
313			ArithmeticExpr::Subtract([left, right]) => {
314				let l_val = self.evaluate(left, context)?;
315				let r_val = self.evaluate(right, context)?;
316				let l = self.to_number(&l_val)?;
317				let r = self.to_number(&r_val)?;
318				serde_json::Number::from_f64(l - r).map(Value::Number).ok_or_else(|| {
319					Error::ValidationError("Invalid number result (NaN or infinity)".to_string())
320				})
321			}
322			ArithmeticExpr::Multiply(exprs) => {
323				let mut product = 1.0;
324				for expr in exprs {
325					let val = self.evaluate(expr, context)?;
326					product *= self.to_number(&val)?;
327				}
328				serde_json::Number::from_f64(product).map(Value::Number).ok_or_else(|| {
329					Error::ValidationError("Invalid number result (NaN or infinity)".to_string())
330				})
331			}
332			ArithmeticExpr::Divide([left, right]) => {
333				let l_val = self.evaluate(left, context)?;
334				let r_val = self.evaluate(right, context)?;
335				let l = self.to_number(&l_val)?;
336				let r = self.to_number(&r_val)?;
337				serde_json::Number::from_f64(l / r).map(Value::Number).ok_or_else(|| {
338					Error::ValidationError("Invalid number result (NaN or infinity)".to_string())
339				})
340			}
341		}
342	}
343
344	/// Evaluate string operation
345	fn evaluate_string_op(
346		&mut self,
347		string_op: &StringOpExpr,
348		context: &HookContext,
349	) -> ClResult<Value> {
350		match string_op {
351			StringOpExpr::Concat(exprs) => {
352				let mut result = String::new();
353				for expr in exprs {
354					let value = self.evaluate(expr, context)?;
355					result.push_str(&self.to_string(&value));
356				}
357				Ok(Value::String(result))
358			}
359			StringOpExpr::Contains([haystack, needle]) => {
360				let h_val = self.evaluate(haystack, context)?;
361				let n_val = self.evaluate(needle, context)?;
362				let h = self.to_string(&h_val);
363				let n = self.to_string(&n_val);
364				Ok(Value::Bool(h.contains(&n)))
365			}
366			StringOpExpr::StartsWith([string, prefix]) => {
367				let s_val = self.evaluate(string, context)?;
368				let p_val = self.evaluate(prefix, context)?;
369				let s = self.to_string(&s_val);
370				let p = self.to_string(&p_val);
371				Ok(Value::Bool(s.starts_with(&p)))
372			}
373			StringOpExpr::EndsWith([string, suffix]) => {
374				let s_val = self.evaluate(string, context)?;
375				let suf_val = self.evaluate(suffix, context)?;
376				let s = self.to_string(&s_val);
377				let suf = self.to_string(&suf_val);
378				Ok(Value::Bool(s.ends_with(&suf)))
379			}
380		}
381	}
382
383	/// Evaluate ternary expression (if-then-else)
384	fn evaluate_ternary(
385		&mut self,
386		ternary: &TernaryExpr,
387		context: &HookContext,
388	) -> ClResult<Value> {
389		let condition = self.evaluate(&ternary.r#if, context)?;
390		if self.to_bool(&condition) {
391			self.evaluate(&ternary.then, context)
392		} else {
393			self.evaluate(&ternary.r#else, context)
394		}
395	}
396
397	/// Evaluate coalesce expression (return first non-null value)
398	fn evaluate_coalesce(
399		&mut self,
400		coalesce: &CoalesceExpr,
401		context: &HookContext,
402	) -> ClResult<Value> {
403		for expr in &coalesce.coalesce {
404			let value = self.evaluate(expr, context)?;
405			if !value.is_null() {
406				return Ok(value);
407			}
408		}
409		Ok(Value::Null)
410	}
411
412	/// Convert value to boolean (truthy/falsy)
413	fn to_bool(&self, value: &Value) -> bool {
414		match value {
415			Value::Null => false,
416			Value::Bool(b) => *b,
417			Value::Number(n) => n.as_f64().unwrap_or(0.0) != 0.0,
418			Value::String(s) => !s.is_empty(),
419			Value::Array(a) => !a.is_empty(),
420			Value::Object(o) => !o.is_empty(),
421		}
422	}
423
424	/// Convert value to string
425	fn to_string(&self, value: &Value) -> String {
426		match value {
427			Value::Null => String::new(),
428			Value::Bool(b) => b.to_string(),
429			Value::Number(n) => n.to_string(),
430			Value::String(s) => s.clone(),
431			v => v.to_string(),
432		}
433	}
434
435	/// Convert value to number
436	fn to_number(&self, value: &Value) -> ClResult<f64> {
437		match value {
438			Value::Number(n) => n.as_f64().ok_or_else(|| {
439				Error::ValidationError(
440					"Invalid number value (not representable as f64)".to_string(),
441				)
442			}),
443			Value::String(s) => s.parse::<f64>().map_err(|_| {
444				Error::ValidationError(format!(
445					"Type mismatch: expected number, got string '{}'",
446					s
447				))
448			}),
449			_ => Err(Error::ValidationError(format!(
450				"Type mismatch: expected number, got {:?}",
451				value
452			))),
453		}
454	}
455}
456
457impl Default for ExpressionEvaluator {
458	fn default() -> Self {
459		Self::new()
460	}
461}
462
#[cfg(test)]
mod tests {
	use super::*;
	use std::collections::HashMap;

	/// Fixture: an outbound `CONN` action from `alice` to `bob` with string
	/// content, no optional metadata and no user variables.
	fn create_test_context() -> HookContext {
		HookContext {
			action_id: "test-action-id".to_string(),
			r#type: "CONN".to_string(),
			subtype: None,
			issuer: "alice".to_string(),
			audience: Some("bob".to_string()),
			parent: None,
			subject: None,
			content: Some(Value::String("Hello".to_string())),
			attachments: None,
			created_at: "2024-01-01T00:00:00Z".to_string(),
			expires_at: None,
			tenant_id: 1,
			tenant_tag: "example".to_string(),
			tenant_type: "person".to_string(),
			is_inbound: false,
			is_outbound: true,
			client_address: None,
			vars: HashMap::new(),
		}
	}

	/// Evaluate `expr` against the fixture context using a fresh evaluator.
	fn evaluate_expr(expr: Expression) -> Value {
		let context = create_test_context();
		ExpressionEvaluator::new()
			.evaluate(&expr, &context)
			.expect("evaluation should succeed")
	}

	/// Shorthand for a string (template) expression.
	fn template(text: &str) -> Expression {
		Expression::String(text.to_string())
	}

	#[test]
	fn test_simple_variable() {
		let result = evaluate_expr(template("{issuer}"));
		assert_eq!(result, Value::String("alice".to_string()));
	}

	#[test]
	fn test_nested_path() {
		let result = evaluate_expr(template("{context.tenant_type}"));
		assert_eq!(result, Value::String("person".to_string()));
	}

	#[test]
	fn test_template_string() {
		let result = evaluate_expr(template("{type}:{issuer}:{audience}"));
		assert_eq!(result, Value::String("CONN:alice:bob".to_string()));
	}

	#[test]
	fn test_comparison_eq() {
		// `subtype` is absent in the fixture, so it resolves to null.
		let operands = [template("{subtype}"), Expression::Null];
		let expr = Expression::Comparison(Box::new(ComparisonExpr::Eq(operands)));

		assert_eq!(evaluate_expr(expr), Value::Bool(true));
	}

	#[test]
	fn test_logical_and() {
		// A literal `true` and a non-empty string are both truthy.
		let operands = vec![Expression::Bool(true), template("{issuer}")];
		let expr = Expression::Logical(Box::new(LogicalExpr::And(operands)));

		assert_eq!(evaluate_expr(expr), Value::Bool(true));
	}
}
547
548// vim: ts=4