1use super::types::*;
14use crate::hooks::HookContext;
15use crate::prelude::*;
16use serde_json::Value;
17
/// Deepest nesting of `ExpressionEvaluator::evaluate` calls that is accepted;
/// going past it yields a validation error instead of recursing further.
const MAX_DEPTH: usize = 50;
/// Largest number of expression nodes a single evaluator may visit; the count
/// is compared against this limit on every `ExpressionEvaluator::evaluate` call.
const MAX_NODES: usize = 100;
22
/// Evaluates `Expression` trees against a `HookContext`, keeping counters that
/// bound how deep and how large an evaluation may get.
pub struct ExpressionEvaluator {
    /// Nesting level tracked by `evaluate` and compared against `MAX_DEPTH`.
    depth: usize,
    /// Number of nodes visited by `evaluate`, compared against `MAX_NODES`.
    node_count: usize,
}
28
29impl ExpressionEvaluator {
30 pub fn new() -> Self {
32 Self { depth: 0, node_count: 0 }
33 }
34
35 pub fn evaluate(&mut self, expr: &Expression, context: &HookContext) -> ClResult<Value> {
37 self.depth += 1;
38 self.node_count += 1;
39
40 if self.depth > MAX_DEPTH {
41 return Err(Error::ValidationError(format!(
42 "Maximum expression depth exceeded ({})",
43 MAX_DEPTH
44 )));
45 }
46 if self.node_count > MAX_NODES {
47 return Err(Error::ValidationError(format!(
48 "Maximum expression nodes exceeded ({})",
49 MAX_NODES
50 )));
51 }
52
53 let result = self.evaluate_inner(expr, context)?;
54
55 self.depth -= 1;
56 Ok(result)
57 }
58
59 fn evaluate_inner(&mut self, expr: &Expression, context: &HookContext) -> ClResult<Value> {
60 match expr {
61 Expression::Null => Ok(Value::Null),
63 Expression::Bool(b) => Ok(Value::Bool(*b)),
64 Expression::Number(n) => {
65 serde_json::Number::from_f64(*n).map(Value::Number).ok_or_else(|| {
66 Error::ValidationError("Invalid number (NaN or infinity)".to_string())
67 })
68 }
69 Expression::String(s) => self.evaluate_template(s, context),
70
71 Expression::Comparison(c) => self.evaluate_comparison(c, context),
73 Expression::Logical(l) => self.evaluate_logical(l, context),
74 Expression::Arithmetic(a) => self.evaluate_arithmetic(a, context),
75 Expression::StringOp(s) => self.evaluate_string_op(s, context),
76 Expression::Ternary(t) => self.evaluate_ternary(t, context),
77 Expression::Coalesce(c) => self.evaluate_coalesce(c, context),
78 }
79 }
80
81 fn evaluate_template(&mut self, template: &str, context: &HookContext) -> ClResult<Value> {
87 if template.starts_with('{')
89 && template.ends_with('}')
90 && template.matches('{').count() == 1
91 {
92 let var_name = &template[1..template.len() - 1];
93 return self.get_variable(var_name, context);
94 }
95
96 let mut result = String::new();
98 let mut chars = template.chars().peekable();
99
100 while let Some(ch) = chars.next() {
101 if ch == '{' {
102 let mut var_name = String::new();
104 while let Some(&next_ch) = chars.peek() {
105 if next_ch == '}' {
106 chars.next(); break;
108 }
109 if let Some(ch) = chars.next() {
111 var_name.push(ch);
112 }
113 }
114
115 let value = self.get_variable(&var_name, context)?;
117 let replacement = match value {
118 Value::Null => String::new(),
119 Value::String(s) => s,
120 v => v.to_string(),
121 };
122 result.push_str(&replacement);
123 } else {
124 result.push(ch);
125 }
126 }
127
128 Ok(Value::String(result))
129 }
130
131 fn get_variable(&self, path: &str, context: &HookContext) -> ClResult<Value> {
137 let parts: Vec<&str> = path.split('.').collect();
138
139 let mut current = match parts[0] {
141 "action_id" => Value::String(context.action_id.clone()),
143 "type" => Value::String(context.r#type.clone()),
144 "subtype" => context
145 .subtype
146 .as_ref()
147 .map(|s| Value::String(s.clone()))
148 .unwrap_or(Value::Null),
149 "issuer" => Value::String(context.issuer.clone()),
150 "audience" => context
151 .audience
152 .as_ref()
153 .map(|s| Value::String(s.clone()))
154 .unwrap_or(Value::Null),
155 "parent" => {
156 context.parent.as_ref().map(|s| Value::String(s.clone())).unwrap_or(Value::Null)
157 }
158 "subject" => context
159 .subject
160 .as_ref()
161 .map(|s| Value::String(s.clone()))
162 .unwrap_or(Value::Null),
163 "content" => context.content.clone().unwrap_or(Value::Null),
164 "attachments" => context
165 .attachments
166 .as_ref()
167 .map(|a| Value::Array(a.iter().map(|s| Value::String(s.clone())).collect()))
168 .unwrap_or(Value::Null),
169
170 "created_at" => Value::String(context.created_at.clone()),
172 "expires_at" => context
173 .expires_at
174 .as_ref()
175 .map(|s| Value::String(s.clone()))
176 .unwrap_or(Value::Null),
177
178 "context" => {
180 let mut obj = serde_json::Map::new();
181 obj.insert("tenant_id".to_string(), Value::Number(context.tenant_id.into()));
182 obj.insert("tenant_tag".to_string(), Value::String(context.tenant_tag.clone()));
183 obj.insert("tenant_type".to_string(), Value::String(context.tenant_type.clone()));
184 Value::Object(obj)
185 }
186
187 "is_inbound" => Value::Bool(context.is_inbound),
189 "is_outbound" => Value::Bool(context.is_outbound),
190
191 var_name => context.vars.get(var_name).cloned().ok_or_else(|| {
193 Error::ValidationError(format!("Variable not found: {}", var_name))
194 })?,
195 };
196
197 for part in &parts[1..] {
199 match ¤t {
200 Value::Object(map) => {
201 current = map.get(*part).cloned().unwrap_or(Value::Null);
202 }
203 Value::Null => return Ok(Value::Null),
204 _ => {
205 return Err(Error::ValidationError(format!(
206 "Cannot access property '{}' on non-object",
207 part
208 )))
209 }
210 }
211 }
212
213 Ok(current)
214 }
215
216 fn evaluate_comparison(
218 &mut self,
219 comp: &ComparisonExpr,
220 context: &HookContext,
221 ) -> ClResult<Value> {
222 match comp {
223 ComparisonExpr::Eq([left, right]) => {
224 let l = self.evaluate(left, context)?;
225 let r = self.evaluate(right, context)?;
226 Ok(Value::Bool(l == r))
227 }
228 ComparisonExpr::Ne([left, right]) => {
229 let l = self.evaluate(left, context)?;
230 let r = self.evaluate(right, context)?;
231 Ok(Value::Bool(l != r))
232 }
233 ComparisonExpr::Gt([left, right]) => {
234 let l_val = self.evaluate(left, context)?;
235 let r_val = self.evaluate(right, context)?;
236 let l = self.to_number(&l_val)?;
237 let r = self.to_number(&r_val)?;
238 Ok(Value::Bool(l > r))
239 }
240 ComparisonExpr::Gte([left, right]) => {
241 let l_val = self.evaluate(left, context)?;
242 let r_val = self.evaluate(right, context)?;
243 let l = self.to_number(&l_val)?;
244 let r = self.to_number(&r_val)?;
245 Ok(Value::Bool(l >= r))
246 }
247 ComparisonExpr::Lt([left, right]) => {
248 let l_val = self.evaluate(left, context)?;
249 let r_val = self.evaluate(right, context)?;
250 let l = self.to_number(&l_val)?;
251 let r = self.to_number(&r_val)?;
252 Ok(Value::Bool(l < r))
253 }
254 ComparisonExpr::Lte([left, right]) => {
255 let l_val = self.evaluate(left, context)?;
256 let r_val = self.evaluate(right, context)?;
257 let l = self.to_number(&l_val)?;
258 let r = self.to_number(&r_val)?;
259 Ok(Value::Bool(l <= r))
260 }
261 }
262 }
263
264 fn evaluate_logical(
266 &mut self,
267 logical: &LogicalExpr,
268 context: &HookContext,
269 ) -> ClResult<Value> {
270 match logical {
271 LogicalExpr::And(exprs) => {
272 for expr in exprs {
273 let value = self.evaluate(expr, context)?;
274 if !self.to_bool(&value) {
275 return Ok(Value::Bool(false));
276 }
277 }
278 Ok(Value::Bool(true))
279 }
280 LogicalExpr::Or(exprs) => {
281 for expr in exprs {
282 let value = self.evaluate(expr, context)?;
283 if self.to_bool(&value) {
284 return Ok(Value::Bool(true));
285 }
286 }
287 Ok(Value::Bool(false))
288 }
289 LogicalExpr::Not(expr) => {
290 let value = self.evaluate(expr, context)?;
291 Ok(Value::Bool(!self.to_bool(&value)))
292 }
293 }
294 }
295
296 fn evaluate_arithmetic(
298 &mut self,
299 arith: &ArithmeticExpr,
300 context: &HookContext,
301 ) -> ClResult<Value> {
302 match arith {
303 ArithmeticExpr::Add(exprs) => {
304 let mut sum = 0.0;
305 for expr in exprs {
306 let val = self.evaluate(expr, context)?;
307 sum += self.to_number(&val)?;
308 }
309 serde_json::Number::from_f64(sum).map(Value::Number).ok_or_else(|| {
310 Error::ValidationError("Invalid number result (NaN or infinity)".to_string())
311 })
312 }
313 ArithmeticExpr::Subtract([left, right]) => {
314 let l_val = self.evaluate(left, context)?;
315 let r_val = self.evaluate(right, context)?;
316 let l = self.to_number(&l_val)?;
317 let r = self.to_number(&r_val)?;
318 serde_json::Number::from_f64(l - r).map(Value::Number).ok_or_else(|| {
319 Error::ValidationError("Invalid number result (NaN or infinity)".to_string())
320 })
321 }
322 ArithmeticExpr::Multiply(exprs) => {
323 let mut product = 1.0;
324 for expr in exprs {
325 let val = self.evaluate(expr, context)?;
326 product *= self.to_number(&val)?;
327 }
328 serde_json::Number::from_f64(product).map(Value::Number).ok_or_else(|| {
329 Error::ValidationError("Invalid number result (NaN or infinity)".to_string())
330 })
331 }
332 ArithmeticExpr::Divide([left, right]) => {
333 let l_val = self.evaluate(left, context)?;
334 let r_val = self.evaluate(right, context)?;
335 let l = self.to_number(&l_val)?;
336 let r = self.to_number(&r_val)?;
337 serde_json::Number::from_f64(l / r).map(Value::Number).ok_or_else(|| {
338 Error::ValidationError("Invalid number result (NaN or infinity)".to_string())
339 })
340 }
341 }
342 }
343
344 fn evaluate_string_op(
346 &mut self,
347 string_op: &StringOpExpr,
348 context: &HookContext,
349 ) -> ClResult<Value> {
350 match string_op {
351 StringOpExpr::Concat(exprs) => {
352 let mut result = String::new();
353 for expr in exprs {
354 let value = self.evaluate(expr, context)?;
355 result.push_str(&self.to_string(&value));
356 }
357 Ok(Value::String(result))
358 }
359 StringOpExpr::Contains([haystack, needle]) => {
360 let h_val = self.evaluate(haystack, context)?;
361 let n_val = self.evaluate(needle, context)?;
362 let h = self.to_string(&h_val);
363 let n = self.to_string(&n_val);
364 Ok(Value::Bool(h.contains(&n)))
365 }
366 StringOpExpr::StartsWith([string, prefix]) => {
367 let s_val = self.evaluate(string, context)?;
368 let p_val = self.evaluate(prefix, context)?;
369 let s = self.to_string(&s_val);
370 let p = self.to_string(&p_val);
371 Ok(Value::Bool(s.starts_with(&p)))
372 }
373 StringOpExpr::EndsWith([string, suffix]) => {
374 let s_val = self.evaluate(string, context)?;
375 let suf_val = self.evaluate(suffix, context)?;
376 let s = self.to_string(&s_val);
377 let suf = self.to_string(&suf_val);
378 Ok(Value::Bool(s.ends_with(&suf)))
379 }
380 }
381 }
382
383 fn evaluate_ternary(
385 &mut self,
386 ternary: &TernaryExpr,
387 context: &HookContext,
388 ) -> ClResult<Value> {
389 let condition = self.evaluate(&ternary.r#if, context)?;
390 if self.to_bool(&condition) {
391 self.evaluate(&ternary.then, context)
392 } else {
393 self.evaluate(&ternary.r#else, context)
394 }
395 }
396
397 fn evaluate_coalesce(
399 &mut self,
400 coalesce: &CoalesceExpr,
401 context: &HookContext,
402 ) -> ClResult<Value> {
403 for expr in &coalesce.coalesce {
404 let value = self.evaluate(expr, context)?;
405 if !value.is_null() {
406 return Ok(value);
407 }
408 }
409 Ok(Value::Null)
410 }
411
412 fn to_bool(&self, value: &Value) -> bool {
414 match value {
415 Value::Null => false,
416 Value::Bool(b) => *b,
417 Value::Number(n) => n.as_f64().unwrap_or(0.0) != 0.0,
418 Value::String(s) => !s.is_empty(),
419 Value::Array(a) => !a.is_empty(),
420 Value::Object(o) => !o.is_empty(),
421 }
422 }
423
424 fn to_string(&self, value: &Value) -> String {
426 match value {
427 Value::Null => String::new(),
428 Value::Bool(b) => b.to_string(),
429 Value::Number(n) => n.to_string(),
430 Value::String(s) => s.clone(),
431 v => v.to_string(),
432 }
433 }
434
435 fn to_number(&self, value: &Value) -> ClResult<f64> {
437 match value {
438 Value::Number(n) => n.as_f64().ok_or_else(|| {
439 Error::ValidationError(
440 "Invalid number value (not representable as f64)".to_string(),
441 )
442 }),
443 Value::String(s) => s.parse::<f64>().map_err(|_| {
444 Error::ValidationError(format!(
445 "Type mismatch: expected number, got string '{}'",
446 s
447 ))
448 }),
449 _ => Err(Error::ValidationError(format!(
450 "Type mismatch: expected number, got {:?}",
451 value
452 ))),
453 }
454 }
455}
456
/// `Default` delegates to `ExpressionEvaluator::new`, so a default evaluator
/// is identical to a freshly constructed one.
impl Default for ExpressionEvaluator {
    fn default() -> Self {
        Self::new()
    }
}
462
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds the fixture shared by all tests: an outbound `CONN` action from
    /// `alice` to `bob` with every optional field except `audience` and
    /// `content` left unset, and no custom variables.
    fn create_test_context() -> HookContext {
        HookContext {
            action_id: "test-action-id".to_string(),
            r#type: "CONN".to_string(),
            subtype: None,
            issuer: "alice".to_string(),
            audience: Some("bob".to_string()),
            parent: None,
            subject: None,
            content: Some(Value::String("Hello".to_string())),
            attachments: None,
            created_at: "2024-01-01T00:00:00Z".to_string(),
            expires_at: None,
            tenant_id: 1,
            tenant_tag: "example".to_string(),
            tenant_type: "person".to_string(),
            is_inbound: false,
            is_outbound: true,
            client_address: None,
            vars: HashMap::new(),
        }
    }

    // A template that is exactly one `{name}` placeholder resolves to the
    // variable's own value.
    #[test]
    fn test_simple_variable() {
        let mut eval = ExpressionEvaluator::new();
        let context = create_test_context();
        let expr = Expression::String("{issuer}".to_string());

        let result = eval.evaluate(&expr, &context).expect("evaluation should succeed");
        assert_eq!(result, Value::String("alice".to_string()));
    }

    // Dotted paths descend into the synthetic `context` object.
    #[test]
    fn test_nested_path() {
        let mut eval = ExpressionEvaluator::new();
        let context = create_test_context();
        let expr = Expression::String("{context.tenant_type}".to_string());

        let result = eval.evaluate(&expr, &context).expect("evaluation should succeed");
        assert_eq!(result, Value::String("person".to_string()));
    }

    // Several placeholders mixed with literal text are interpolated into one
    // string.
    #[test]
    fn test_template_string() {
        let mut eval = ExpressionEvaluator::new();
        let context = create_test_context();
        let expr = Expression::String("{type}:{issuer}:{audience}".to_string());

        let result = eval.evaluate(&expr, &context).expect("evaluation should succeed");
        assert_eq!(result, Value::String("CONN:alice:bob".to_string()));
    }

    // `subtype` is `None` in the fixture, so it resolves to JSON null and
    // compares equal to the `Null` literal.
    #[test]
    fn test_comparison_eq() {
        let mut eval = ExpressionEvaluator::new();
        let context = create_test_context();
        let expr = Expression::Comparison(Box::new(ComparisonExpr::Eq([
            Expression::String("{subtype}".to_string()),
            Expression::Null,
        ])));

        let result = eval.evaluate(&expr, &context).expect("evaluation should succeed");
        assert_eq!(result, Value::Bool(true));
    }

    // A non-empty string ("alice") is truthy, so `true AND {issuer}` is true.
    #[test]
    fn test_logical_and() {
        let mut eval = ExpressionEvaluator::new();
        let context = create_test_context();
        let expr = Expression::Logical(Box::new(LogicalExpr::And(vec![
            Expression::Bool(true),
            Expression::String("{issuer}".to_string()),
        ])));

        let result = eval.evaluate(&expr, &context).expect("evaluation should succeed");
        assert_eq!(result, Value::Bool(true));
    }
}
547
548