cel_interpreter/
objects.rs

1use crate::context::Context;
2use crate::functions::FunctionContext;
3use crate::{ExecutionError, Expression};
4use cel_parser::ast::{operators, EntryExpr, Expr};
5use cel_parser::reference::Val;
6use std::cmp::Ordering;
7use std::collections::HashMap;
8use std::convert::{Infallible, TryFrom, TryInto};
9use std::fmt::{Display, Formatter};
10use std::ops;
11use std::ops::Deref;
12use std::sync::Arc;
13
14#[derive(Debug, PartialEq, Clone)]
15#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
16pub struct Map {
17    pub map: Arc<HashMap<Key, Value>>,
18}
19
20impl PartialOrd for Map {
21    fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
22        None
23    }
24}
25
26impl Map {
27    /// Returns a reference to the value corresponding to the key. Implicitly converts between int
28    /// and uint keys.
29    pub fn get(&self, key: &Key) -> Option<&Value> {
30        self.map.get(key).or_else(|| {
31            // Also check keys that are cross type comparable.
32            let converted = match key {
33                Key::Int(k) => Key::Uint(u64::try_from(*k).ok()?),
34                Key::Uint(k) => Key::Int(i64::try_from(*k).ok()?),
35                _ => return None,
36            };
37            self.map.get(&converted)
38        })
39    }
40}
41
/// The value types that may be used as keys of a [`Map`].
///
/// The derived ordering compares variants in declaration order first
/// (`Int` < `Uint` < `Bool` < `String`) and payloads second.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum Key {
    /// Signed 64-bit integer key.
    Int(i64),
    /// Unsigned 64-bit integer key.
    Uint(u64),
    /// Boolean key.
    Bool(bool),
    /// Reference-counted string key.
    String(Arc<String>),
}
50
51/// Implement conversions from primitive types to [`Key`]
52impl From<String> for Key {
53    fn from(v: String) -> Self {
54        Key::String(v.into())
55    }
56}
57
58impl From<Arc<String>> for Key {
59    fn from(v: Arc<String>) -> Self {
60        Key::String(v)
61    }
62}
63
64impl<'a> From<&'a str> for Key {
65    fn from(v: &'a str) -> Self {
66        Key::String(Arc::new(v.into()))
67    }
68}
69
70impl From<bool> for Key {
71    fn from(v: bool) -> Self {
72        Key::Bool(v)
73    }
74}
75
76impl From<i64> for Key {
77    fn from(v: i64) -> Self {
78        Key::Int(v)
79    }
80}
81
82impl From<u64> for Key {
83    fn from(v: u64) -> Self {
84        Key::Uint(v)
85    }
86}
87
88impl serde::Serialize for Key {
89    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
90    where
91        S: serde::Serializer,
92    {
93        match self {
94            Key::Int(v) => v.serialize(serializer),
95            Key::Uint(v) => v.serialize(serializer),
96            Key::Bool(v) => v.serialize(serializer),
97            Key::String(v) => v.serialize(serializer),
98        }
99    }
100}
101
102impl Display for Key {
103    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
104        match self {
105            Key::Int(v) => write!(f, "{v}"),
106            Key::Uint(v) => write!(f, "{v}"),
107            Key::Bool(v) => write!(f, "{v}"),
108            Key::String(v) => write!(f, "{v}"),
109        }
110    }
111}
112
113/// Implement conversions from [`Key`] into [`Value`]
114impl TryInto<Key> for Value {
115    type Error = Value;
116
117    #[inline(always)]
118    fn try_into(self) -> Result<Key, Self::Error> {
119        match self {
120            Value::Int(v) => Ok(Key::Int(v)),
121            Value::UInt(v) => Ok(Key::Uint(v)),
122            Value::String(v) => Ok(Key::String(v)),
123            Value::Bool(v) => Ok(Key::Bool(v)),
124            _ => Err(self),
125        }
126    }
127}
128
129// Implement conversion from HashMap<K, V> into CelMap
130impl<K: Into<Key>, V: Into<Value>> From<HashMap<K, V>> for Map {
131    fn from(map: HashMap<K, V>) -> Self {
132        let mut new_map = HashMap::with_capacity(map.len());
133        for (k, v) in map {
134            new_map.insert(k.into(), v.into());
135        }
136        Map {
137            map: Arc::new(new_map),
138        }
139    }
140}
141
142pub trait TryIntoValue {
143    type Error: std::error::Error + 'static + Send + Sync;
144    fn try_into_value(self) -> Result<Value, Self::Error>;
145}
146
147impl<T: serde::Serialize> TryIntoValue for T {
148    type Error = crate::ser::SerializationError;
149    fn try_into_value(self) -> Result<Value, Self::Error> {
150        crate::ser::to_value(self)
151    }
152}
153impl TryIntoValue for Value {
154    type Error = Infallible;
155    fn try_into_value(self) -> Result<Value, Self::Error> {
156        Ok(self)
157    }
158}
159
160#[derive(Debug, Clone)]
161#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
162pub enum Value {
163    List(Arc<Vec<Value>>),
164    Map(Map),
165
166    Function(Arc<String>, Option<Box<Value>>),
167
168    // Atoms
169    Int(i64),
170    UInt(u64),
171    Float(f64),
172    String(Arc<String>),
173    Bytes(Arc<Vec<u8>>),
174    Bool(bool),
175    #[cfg(feature = "chrono")]
176    Duration(chrono::Duration),
177    #[cfg(feature = "chrono")]
178    Timestamp(chrono::DateTime<chrono::FixedOffset>),
179    Null,
180}
181
182impl From<Val> for Value {
183    fn from(val: Val) -> Self {
184        match val {
185            Val::String(s) => Value::String(Arc::new(s)),
186            Val::Boolean(b) => Value::Bool(b),
187            Val::Int(i) => Value::Int(i),
188            Val::UInt(u) => Value::UInt(u),
189            Val::Double(d) => Value::Float(d),
190            Val::Bytes(bytes) => Value::Bytes(Arc::new(bytes)),
191            Val::Null => Value::Null,
192        }
193    }
194}
195
/// The kind of a [`Value`], without its payload.
#[derive(Debug, Copy, Clone)]
pub enum ValueType {
    /// Kind of `Value::List`.
    List,
    /// Kind of `Value::Map`.
    Map,
    /// Kind of `Value::Function`.
    Function,
    /// Kind of `Value::Int`.
    Int,
    /// Kind of `Value::UInt`.
    UInt,
    /// Kind of `Value::Float`.
    Float,
    /// Kind of `Value::String`.
    String,
    /// Kind of `Value::Bytes`.
    Bytes,
    /// Kind of `Value::Bool`.
    Bool,
    /// Kind of `Value::Duration`.
    Duration,
    /// Kind of `Value::Timestamp`.
    Timestamp,
    /// Kind of `Value::Null`.
    Null,
}
211
212impl Display for ValueType {
213    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
214        match self {
215            ValueType::List => write!(f, "list"),
216            ValueType::Map => write!(f, "map"),
217            ValueType::Function => write!(f, "function"),
218            ValueType::Int => write!(f, "int"),
219            ValueType::UInt => write!(f, "uint"),
220            ValueType::Float => write!(f, "float"),
221            ValueType::String => write!(f, "string"),
222            ValueType::Bytes => write!(f, "bytes"),
223            ValueType::Bool => write!(f, "bool"),
224            ValueType::Duration => write!(f, "duration"),
225            ValueType::Timestamp => write!(f, "timestamp"),
226            ValueType::Null => write!(f, "null"),
227        }
228    }
229}
230
231impl Value {
232    pub fn type_of(&self) -> ValueType {
233        match self {
234            Value::List(_) => ValueType::List,
235            Value::Map(_) => ValueType::Map,
236            Value::Function(_, _) => ValueType::Function,
237            Value::Int(_) => ValueType::Int,
238            Value::UInt(_) => ValueType::UInt,
239            Value::Float(_) => ValueType::Float,
240            Value::String(_) => ValueType::String,
241            Value::Bytes(_) => ValueType::Bytes,
242            Value::Bool(_) => ValueType::Bool,
243            #[cfg(feature = "chrono")]
244            Value::Duration(_) => ValueType::Duration,
245            #[cfg(feature = "chrono")]
246            Value::Timestamp(_) => ValueType::Timestamp,
247            Value::Null => ValueType::Null,
248        }
249    }
250
251    pub fn error_expected_type(&self, expected: ValueType) -> ExecutionError {
252        ExecutionError::UnexpectedType {
253            got: self.type_of().to_string(),
254            want: expected.to_string(),
255        }
256    }
257}
258
259impl From<&Value> for Value {
260    fn from(value: &Value) -> Self {
261        value.clone()
262    }
263}
264
265impl PartialEq for Value {
266    fn eq(&self, other: &Self) -> bool {
267        match (self, other) {
268            (Value::Map(a), Value::Map(b)) => a == b,
269            (Value::List(a), Value::List(b)) => a == b,
270            (Value::Function(a1, a2), Value::Function(b1, b2)) => a1 == b1 && a2 == b2,
271            (Value::Int(a), Value::Int(b)) => a == b,
272            (Value::UInt(a), Value::UInt(b)) => a == b,
273            (Value::Float(a), Value::Float(b)) => a == b,
274            (Value::String(a), Value::String(b)) => a == b,
275            (Value::Bytes(a), Value::Bytes(b)) => a == b,
276            (Value::Bool(a), Value::Bool(b)) => a == b,
277            (Value::Null, Value::Null) => true,
278            #[cfg(feature = "chrono")]
279            (Value::Duration(a), Value::Duration(b)) => a == b,
280            #[cfg(feature = "chrono")]
281            (Value::Timestamp(a), Value::Timestamp(b)) => a == b,
282            // Allow different numeric types to be compared without explicit casting.
283            (Value::Int(a), Value::UInt(b)) => a
284                .to_owned()
285                .try_into()
286                .map(|a: u64| a == *b)
287                .unwrap_or(false),
288            (Value::Int(a), Value::Float(b)) => (*a as f64) == *b,
289            (Value::UInt(a), Value::Int(b)) => a
290                .to_owned()
291                .try_into()
292                .map(|a: i64| a == *b)
293                .unwrap_or(false),
294            (Value::UInt(a), Value::Float(b)) => (*a as f64) == *b,
295            (Value::Float(a), Value::Int(b)) => *a == (*b as f64),
296            (Value::Float(a), Value::UInt(b)) => *a == (*b as f64),
297            (_, _) => false,
298        }
299    }
300}
301
// Marker impl declaring `Value` equality to be a full equivalence relation.
// NOTE(review): `Value::Float` is compared with `f64`'s `==`, under which NaN is not equal to
// itself, so reflexivity does not hold for NaN floats — confirm this is acceptable.
impl Eq for Value {}
303
304impl PartialOrd for Value {
305    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
306        match (self, other) {
307            (Value::Int(a), Value::Int(b)) => Some(a.cmp(b)),
308            (Value::UInt(a), Value::UInt(b)) => Some(a.cmp(b)),
309            (Value::Float(a), Value::Float(b)) => a.partial_cmp(b),
310            (Value::String(a), Value::String(b)) => Some(a.cmp(b)),
311            (Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)),
312            (Value::Null, Value::Null) => Some(Ordering::Equal),
313            #[cfg(feature = "chrono")]
314            (Value::Duration(a), Value::Duration(b)) => Some(a.cmp(b)),
315            #[cfg(feature = "chrono")]
316            (Value::Timestamp(a), Value::Timestamp(b)) => Some(a.cmp(b)),
317            // Allow different numeric types to be compared without explicit casting.
318            (Value::Int(a), Value::UInt(b)) => Some(
319                a.to_owned()
320                    .try_into()
321                    .map(|a: u64| a.cmp(b))
322                    // If the i64 doesn't fit into a u64 it must be less than 0.
323                    .unwrap_or(Ordering::Less),
324            ),
325            (Value::Int(a), Value::Float(b)) => (*a as f64).partial_cmp(b),
326            (Value::UInt(a), Value::Int(b)) => Some(
327                a.to_owned()
328                    .try_into()
329                    .map(|a: i64| a.cmp(b))
330                    // If the u64 doesn't fit into a i64 it must be greater than i64::MAX.
331                    .unwrap_or(Ordering::Greater),
332            ),
333            (Value::UInt(a), Value::Float(b)) => (*a as f64).partial_cmp(b),
334            (Value::Float(a), Value::Int(b)) => a.partial_cmp(&(*b as f64)),
335            (Value::Float(a), Value::UInt(b)) => a.partial_cmp(&(*b as f64)),
336            _ => None,
337        }
338    }
339}
340
341impl From<&Key> for Value {
342    fn from(value: &Key) -> Self {
343        match value {
344            Key::Int(v) => Value::Int(*v),
345            Key::Uint(v) => Value::UInt(*v),
346            Key::Bool(v) => Value::Bool(*v),
347            Key::String(v) => Value::String(v.clone()),
348        }
349    }
350}
351
352impl From<Key> for Value {
353    fn from(value: Key) -> Self {
354        match value {
355            Key::Int(v) => Value::Int(v),
356            Key::Uint(v) => Value::UInt(v),
357            Key::Bool(v) => Value::Bool(v),
358            Key::String(v) => Value::String(v),
359        }
360    }
361}
362
363impl From<&Key> for Key {
364    fn from(key: &Key) -> Self {
365        key.clone()
366    }
367}
368
369// Convert Vec<T> to Value
370impl<T: Into<Value>> From<Vec<T>> for Value {
371    fn from(v: Vec<T>) -> Self {
372        Value::List(v.into_iter().map(|v| v.into()).collect::<Vec<_>>().into())
373    }
374}
375
376// Convert Vec<u8> to Value
377impl From<Vec<u8>> for Value {
378    fn from(v: Vec<u8>) -> Self {
379        Value::Bytes(v.into())
380    }
381}
382
383// Convert String to Value
384impl From<String> for Value {
385    fn from(v: String) -> Self {
386        Value::String(v.into())
387    }
388}
389
390impl From<&str> for Value {
391    fn from(v: &str) -> Self {
392        Value::String(v.to_string().into())
393    }
394}
395
396// Convert Option<T> to Value
397impl<T: Into<Value>> From<Option<T>> for Value {
398    fn from(v: Option<T>) -> Self {
399        match v {
400            Some(v) => v.into(),
401            None => Value::Null,
402        }
403    }
404}
405
406// Convert HashMap<K, V> to Value
407impl<K: Into<Key>, V: Into<Value>> From<HashMap<K, V>> for Value {
408    fn from(v: HashMap<K, V>) -> Self {
409        Value::Map(v.into())
410    }
411}
412
413impl From<ExecutionError> for ResolveResult {
414    fn from(value: ExecutionError) -> Self {
415        Err(value)
416    }
417}
418
/// Result of resolving an expression: the computed [`Value`] or the [`ExecutionError`] that
/// stopped evaluation.
pub type ResolveResult = Result<Value, ExecutionError>;
420
421impl From<Value> for ResolveResult {
422    fn from(value: Value) -> Self {
423        Ok(value)
424    }
425}
426
427impl Value {
428    pub fn resolve_all(expr: &[Expression], ctx: &Context) -> ResolveResult {
429        let mut res = Vec::with_capacity(expr.len());
430        for expr in expr {
431            res.push(Value::resolve(expr, ctx)?);
432        }
433        Ok(Value::List(res.into()))
434    }
435
436    #[inline(always)]
437    pub fn resolve(expr: &Expression, ctx: &Context) -> ResolveResult {
438        match &expr.expr {
439            Expr::Literal(val) => Ok(val.clone().into()),
440            Expr::Call(call) => {
441                if call.args.len() == 3 && call.func_name == operators::CONDITIONAL {
442                    let cond = Value::resolve(&call.args[0], ctx)?;
443                    return if cond.to_bool() {
444                        Value::resolve(&call.args[1], ctx)
445                    } else {
446                        Value::resolve(&call.args[2], ctx)
447                    };
448                }
449                if call.args.len() == 2 {
450                    let left = Value::resolve(&call.args[0], ctx)?;
451                    match call.func_name.as_str() {
452                        operators::ADD => return left + Value::resolve(&call.args[1], ctx)?,
453                        operators::SUBSTRACT => return left - Value::resolve(&call.args[1], ctx)?,
454                        operators::DIVIDE => return left / Value::resolve(&call.args[1], ctx)?,
455                        operators::MULTIPLY => return left * Value::resolve(&call.args[1], ctx)?,
456                        operators::MODULO => return left % Value::resolve(&call.args[1], ctx)?,
457                        operators::EQUALS => {
458                            return Value::Bool(left.eq(&Value::resolve(&call.args[1], ctx)?))
459                                .into()
460                        }
461                        operators::NOT_EQUALS => {
462                            return Value::Bool(left.ne(&Value::resolve(&call.args[1], ctx)?))
463                                .into()
464                        }
465                        operators::LESS => {
466                            let right = Value::resolve(&call.args[1], ctx)?;
467                            return Value::Bool(
468                                left.partial_cmp(&right)
469                                    .ok_or(ExecutionError::ValuesNotComparable(left, right))?
470                                    == Ordering::Less,
471                            )
472                            .into();
473                        }
474                        operators::LESS_EQUALS => {
475                            let right = Value::resolve(&call.args[1], ctx)?;
476                            return Value::Bool(
477                                left.partial_cmp(&right)
478                                    .ok_or(ExecutionError::ValuesNotComparable(left, right))?
479                                    != Ordering::Greater,
480                            )
481                            .into();
482                        }
483                        operators::GREATER => {
484                            let right = Value::resolve(&call.args[1], ctx)?;
485                            return Value::Bool(
486                                left.partial_cmp(&right)
487                                    .ok_or(ExecutionError::ValuesNotComparable(left, right))?
488                                    == Ordering::Greater,
489                            )
490                            .into();
491                        }
492                        operators::GREATER_EQUALS => {
493                            let right = Value::resolve(&call.args[1], ctx)?;
494                            return Value::Bool(
495                                left.partial_cmp(&right)
496                                    .ok_or(ExecutionError::ValuesNotComparable(left, right))?
497                                    != Ordering::Less,
498                            )
499                            .into();
500                        }
501                        operators::IN => {
502                            let right = Value::resolve(&call.args[1], ctx)?;
503                            match (left, right) {
504                                (Value::String(l), Value::String(r)) => {
505                                    return Value::Bool(r.contains(&*l)).into()
506                                }
507                                (any, Value::List(v)) => {
508                                    return Value::Bool(v.contains(&any)).into()
509                                }
510                                (any, Value::Map(m)) => match any.try_into() {
511                                    Ok(key) => return Value::Bool(m.map.contains_key(&key)).into(),
512                                    Err(_) => return Value::Bool(false).into(),
513                                },
514                                (left, right) => {
515                                    Err(ExecutionError::ValuesNotComparable(left, right))?
516                                }
517                            }
518                        }
519                        operators::LOGICAL_OR => {
520                            return if left.to_bool() {
521                                left.into()
522                            } else {
523                                Value::resolve(&call.args[1], ctx)
524                            };
525                        }
526                        operators::LOGICAL_AND => {
527                            return if !left.to_bool() {
528                                Value::Bool(false)
529                            } else {
530                                let right = Value::resolve(&call.args[1], ctx)?;
531                                Value::Bool(right.to_bool())
532                            }
533                            .into();
534                        }
535                        operators::INDEX => {
536                            let value = left;
537                            let idx = Value::resolve(&call.args[1], ctx)?;
538                            return match (value, idx) {
539                                (Value::List(items), Value::Int(idx)) => items
540                                    .get(idx as usize)
541                                    .cloned()
542                                    .unwrap_or(Value::Null)
543                                    .into(),
544                                (Value::String(str), Value::Int(idx)) => {
545                                    match str.get(idx as usize..(idx + 1) as usize) {
546                                        None => Ok(Value::Null),
547                                        Some(str) => Ok(Value::String(str.to_string().into())),
548                                    }
549                                }
550                                (Value::Map(map), Value::String(property)) => map
551                                    .get(&property.into())
552                                    .cloned()
553                                    .unwrap_or(Value::Null)
554                                    .into(),
555                                (Value::Map(map), Value::Bool(property)) => map
556                                    .get(&property.into())
557                                    .cloned()
558                                    .unwrap_or(Value::Null)
559                                    .into(),
560                                (Value::Map(map), Value::Int(property)) => map
561                                    .get(&property.into())
562                                    .cloned()
563                                    .unwrap_or(Value::Null)
564                                    .into(),
565                                (Value::Map(map), Value::UInt(property)) => map
566                                    .get(&property.into())
567                                    .cloned()
568                                    .unwrap_or(Value::Null)
569                                    .into(),
570                                (Value::Map(_), index) => {
571                                    Err(ExecutionError::UnsupportedMapIndex(index))
572                                }
573                                (Value::List(_), index) => {
574                                    Err(ExecutionError::UnsupportedListIndex(index))
575                                }
576                                (value, index) => {
577                                    Err(ExecutionError::UnsupportedIndex(value, index))
578                                }
579                            };
580                        }
581                        _ => (),
582                    }
583                }
584                if call.args.len() == 1 {
585                    let expr = Value::resolve(&call.args[0], ctx)?;
586                    match call.func_name.as_str() {
587                        operators::LOGICAL_NOT => return Ok(Value::Bool(!expr.to_bool())),
588                        operators::NEGATE => {
589                            return match expr {
590                                Value::Int(i) => Ok(Value::Int(-i)),
591                                Value::Float(f) => Ok(Value::Float(-f)),
592                                value => {
593                                    Err(ExecutionError::UnsupportedUnaryOperator("minus", value))
594                                }
595                            }
596                        }
597                        operators::NOT_STRICTLY_FALSE => {
598                            return match expr {
599                                Value::Bool(b) => Ok(Value::Bool(b)),
600                                _ => Ok(Value::Bool(true)),
601                            }
602                        }
603                        _ => (),
604                    }
605                }
606                let func = ctx.get_function(call.func_name.as_str()).ok_or_else(|| {
607                    ExecutionError::UndeclaredReference(call.func_name.clone().into())
608                })?;
609                match &call.target {
610                    None => {
611                        let mut ctx = FunctionContext::new(
612                            call.func_name.clone().into(),
613                            None,
614                            ctx,
615                            call.args.clone(),
616                        );
617                        (func)(&mut ctx)
618                    }
619                    Some(target) => {
620                        let mut ctx = FunctionContext::new(
621                            call.func_name.clone().into(),
622                            Some(Value::resolve(target, ctx)?),
623                            ctx,
624                            call.args.clone(),
625                        );
626                        (func)(&mut ctx)
627                    }
628                }
629            }
630            Expr::Ident(name) => ctx.get_variable(name),
631            Expr::Select(select) => {
632                let left = Value::resolve(select.operand.deref(), ctx)?;
633                if select.test {
634                    match &left {
635                        Value::Map(map) => {
636                            for key in map.map.deref().keys() {
637                                if key.to_string().eq(&select.field) {
638                                    return Ok(Value::Bool(true));
639                                }
640                            }
641                            Ok(Value::Bool(false))
642                        }
643                        _ => Ok(Value::Bool(false)),
644                    }
645                } else {
646                    left.member(&select.field, ctx)
647                }
648            }
649            Expr::List(list_expr) => {
650                let list = list_expr
651                    .elements
652                    .iter()
653                    .map(|i| Value::resolve(i, ctx))
654                    .collect::<Result<Vec<_>, _>>()?;
655                Value::List(list.into()).into()
656            }
657            Expr::Map(map_expr) => {
658                let mut map = HashMap::with_capacity(map_expr.entries.len());
659                for entry in map_expr.entries.iter() {
660                    let (k, v) = match &entry.expr {
661                        EntryExpr::StructField(_) => panic!("WAT?"),
662                        EntryExpr::MapEntry(e) => (&e.key, &e.value),
663                    };
664                    let key = Value::resolve(k, ctx)?
665                        .try_into()
666                        .map_err(ExecutionError::UnsupportedKeyType)?;
667                    let value = Value::resolve(v, ctx)?;
668                    map.insert(key, value);
669                }
670                Ok(Value::Map(Map {
671                    map: Arc::from(map),
672                }))
673            }
674            Expr::Comprehension(comprehension) => {
675                let accu_init = Value::resolve(comprehension.accu_init.deref(), ctx)?;
676                let iter = Value::resolve(comprehension.iter_range.deref(), ctx)?;
677                let mut ctx = ctx.new_inner_scope();
678                ctx.add_variable(&comprehension.accu_var, accu_init)
679                    .expect("Failed to add accu variable");
680
681                match iter {
682                    Value::List(items) => {
683                        for item in items.deref() {
684                            if !Value::resolve(&comprehension.loop_cond, &ctx)?.to_bool() {
685                                break;
686                            }
687                            ctx.add_variable_from_value(&comprehension.iter_var, item.clone());
688                            let accu = Value::resolve(comprehension.loop_step.deref(), &ctx)?;
689                            ctx.add_variable_from_value(&comprehension.accu_var, accu);
690                        }
691                    }
692                    Value::Map(map) => {
693                        for key in map.map.deref().keys() {
694                            if !Value::resolve(&comprehension.loop_cond, &ctx)?.to_bool() {
695                                break;
696                            }
697                            ctx.add_variable_from_value(&comprehension.iter_var, key.clone());
698                            let accu = Value::resolve(comprehension.loop_step.deref(), &ctx)?;
699                            ctx.add_variable_from_value(&comprehension.accu_var, accu);
700                        }
701                    }
702                    t => todo!("Support {t:?}"),
703                }
704                Value::resolve(comprehension.result.deref(), &ctx)
705            }
706            Expr::Struct(_) => todo!("Support structs!"),
707            Expr::Unspecified => panic!("Can't evaluate Unspecified Expr"),
708        }
709    }
710
711    // >> a(b)
712    // Member(Ident("a"),
713    //        FunctionCall([Ident("b")]))
714    // >> a.b(c)
715    // Member(Member(Ident("a"),
716    //               Attribute("b")),
717    //        FunctionCall([Ident("c")]))
718
719    fn member(self, name: &str, ctx: &Context) -> ResolveResult {
720        // todo! Ideally we would avoid creating a String just to create a Key for lookup in the
721        // map, but this would require something like the `hashbrown` crate's `Equivalent` trait.
722        let name: Arc<String> = name.to_owned().into();
723
724        // This will always either be because we're trying to access
725        // a property on self, or a method on self.
726        let child = match self {
727            Value::Map(ref m) => m.map.get(&name.clone().into()).cloned(),
728            _ => None,
729        };
730
731        // If the property is both an attribute and a method, then we
732        // give priority to the property. Maybe we can implement lookahead
733        // to see if the next token is a function call?
734        if let Some(child) = child {
735            child.into()
736        } else if ctx.has_function(&name) {
737            Value::Function(name.clone(), Some(self.into())).into()
738        } else {
739            ExecutionError::NoSuchKey(name.clone()).into()
740        }
741    }
742
743    #[inline(always)]
744    fn to_bool(&self) -> bool {
745        match self {
746            Value::List(v) => !v.is_empty(),
747            Value::Map(v) => !v.map.is_empty(),
748            Value::Int(v) => *v != 0,
749            Value::UInt(v) => *v != 0,
750            Value::Float(v) => *v != 0.0,
751            Value::String(v) => !v.is_empty(),
752            Value::Bytes(v) => !v.is_empty(),
753            Value::Bool(v) => *v,
754            Value::Null => false,
755            #[cfg(feature = "chrono")]
756            Value::Duration(v) => v.num_nanoseconds().map(|n| n != 0).unwrap_or(false),
757            #[cfg(feature = "chrono")]
758            Value::Timestamp(v) => v.timestamp_nanos_opt().unwrap_or_default() > 0,
759            Value::Function(_, _) => false,
760        }
761    }
762}
763
764impl ops::Add<Value> for Value {
765    type Output = ResolveResult;
766
767    #[inline(always)]
768    fn add(self, rhs: Value) -> Self::Output {
769        match (self, rhs) {
770            (Value::Int(l), Value::Int(r)) => l
771                .checked_add(r)
772                .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
773                .map(Value::Int),
774
775            (Value::UInt(l), Value::UInt(r)) => l
776                .checked_add(r)
777                .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
778                .map(Value::UInt),
779
780            (Value::Float(l), Value::Float(r)) => Value::Float(l + r).into(),
781
782            (Value::List(mut l), Value::List(mut r)) => {
783                {
784                    // If this is the only reference to `l`, we can append to it in place.
785                    // `l` is replaced with a clone otherwise.
786                    let l = Arc::make_mut(&mut l);
787
788                    // Likewise, if this is the only reference to `r`, we can move its values
789                    // instead of cloning them.
790                    match Arc::get_mut(&mut r) {
791                        Some(r) => l.append(r),
792                        None => l.extend(r.iter().cloned()),
793                    }
794                }
795
796                Ok(Value::List(l))
797            }
798            (Value::String(mut l), Value::String(r)) => {
799                // If this is the only reference to `l`, we can append to it in place.
800                // `l` is replaced with a clone otherwise.
801                Arc::make_mut(&mut l).push_str(&r);
802                Ok(Value::String(l))
803            }
804            #[cfg(feature = "chrono")]
805            (Value::Duration(l), Value::Duration(r)) => l
806                .checked_add(&r)
807                .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
808                .map(Value::Duration),
809            #[cfg(feature = "chrono")]
810            (Value::Timestamp(l), Value::Duration(r)) => l
811                .checked_add_signed(r)
812                .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
813                .map(Value::Timestamp),
814            #[cfg(feature = "chrono")]
815            (Value::Duration(l), Value::Timestamp(r)) => r
816                .checked_add_signed(l)
817                .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
818                .map(Value::Timestamp),
819            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
820                "add", left, right,
821            )),
822        }
823    }
824}
825
826impl ops::Sub<Value> for Value {
827    type Output = ResolveResult;
828
829    #[inline(always)]
830    fn sub(self, rhs: Value) -> Self::Output {
831        match (self, rhs) {
832            (Value::Int(l), Value::Int(r)) => l
833                .checked_sub(r)
834                .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
835                .map(Value::Int),
836
837            (Value::UInt(l), Value::UInt(r)) => l
838                .checked_sub(r)
839                .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
840                .map(Value::UInt),
841
842            (Value::Float(l), Value::Float(r)) => Value::Float(l - r).into(),
843
844            #[cfg(feature = "chrono")]
845            (Value::Duration(l), Value::Duration(r)) => l
846                .checked_sub(&r)
847                .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
848                .map(Value::Duration),
849            #[cfg(feature = "chrono")]
850            (Value::Timestamp(l), Value::Duration(r)) => l
851                .checked_sub_signed(r)
852                .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
853                .map(Value::Timestamp),
854            #[cfg(feature = "chrono")]
855            (Value::Timestamp(l), Value::Timestamp(r)) => {
856                Value::Duration(l.signed_duration_since(r)).into()
857            }
858            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
859                "sub", left, right,
860            )),
861        }
862    }
863}
864
865impl ops::Div<Value> for Value {
866    type Output = ResolveResult;
867
868    #[inline(always)]
869    fn div(self, rhs: Value) -> Self::Output {
870        match (self, rhs) {
871            (Value::Int(l), Value::Int(r)) => {
872                if r == 0 {
873                    Err(ExecutionError::DivisionByZero(l.into()))
874                } else {
875                    l.checked_div(r)
876                        .ok_or(ExecutionError::Overflow("div", l.into(), r.into()))
877                        .map(Value::Int)
878                }
879            }
880
881            (Value::UInt(l), Value::UInt(r)) => l
882                .checked_div(r)
883                .ok_or(ExecutionError::DivisionByZero(l.into()))
884                .map(Value::UInt),
885
886            (Value::Float(l), Value::Float(r)) => Value::Float(l / r).into(),
887
888            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
889                "div", left, right,
890            )),
891        }
892    }
893}
894
895impl ops::Mul<Value> for Value {
896    type Output = ResolveResult;
897
898    #[inline(always)]
899    fn mul(self, rhs: Value) -> Self::Output {
900        match (self, rhs) {
901            (Value::Int(l), Value::Int(r)) => l
902                .checked_mul(r)
903                .ok_or(ExecutionError::Overflow("mul", l.into(), r.into()))
904                .map(Value::Int),
905
906            (Value::UInt(l), Value::UInt(r)) => l
907                .checked_mul(r)
908                .ok_or(ExecutionError::Overflow("mul", l.into(), r.into()))
909                .map(Value::UInt),
910
911            (Value::Float(l), Value::Float(r)) => Value::Float(l * r).into(),
912
913            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
914                "mul", left, right,
915            )),
916        }
917    }
918}
919
920impl ops::Rem<Value> for Value {
921    type Output = ResolveResult;
922
923    #[inline(always)]
924    fn rem(self, rhs: Value) -> Self::Output {
925        match (self, rhs) {
926            (Value::Int(l), Value::Int(r)) => {
927                if r == 0 {
928                    Err(ExecutionError::RemainderByZero(l.into()))
929                } else {
930                    l.checked_rem(r)
931                        .ok_or(ExecutionError::Overflow("rem", l.into(), r.into()))
932                        .map(Value::Int)
933                }
934            }
935
936            (Value::UInt(l), Value::UInt(r)) => l
937                .checked_rem(r)
938                .ok_or(ExecutionError::RemainderByZero(l.into()))
939                .map(Value::UInt),
940
941            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
942                "rem", left, right,
943            )),
944        }
945    }
946}
947
#[cfg(test)]
mod tests {
    use crate::{objects::Key, Context, ExecutionError, Program, Value};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// A string-keyed map variable can be indexed with a string literal.
    #[test]
    fn test_indexed_map_access() {
        let mut headers = HashMap::new();
        headers.insert("Content-Type", "application/json".to_string());
        let mut ctx = Context::default();
        ctx.add_variable_from_value("headers", headers);

        let compiled = Program::compile("headers[\"Content-Type\"]").unwrap();
        let outcome = compiled.execute(&ctx).unwrap();
        assert_eq!(outcome, "application/json".into());
    }

    /// An entry stored under `Key::Uint(1)` is found through the literal index `1`.
    #[test]
    fn test_numeric_map_access() {
        let mut numbers = HashMap::new();
        numbers.insert(Key::Uint(1), "one".to_string());
        let mut ctx = Context::default();
        ctx.add_variable_from_value("numbers", numbers);

        let compiled = Program::compile("numbers[1]").unwrap();
        let outcome = compiled.execute(&ctx).unwrap();
        assert_eq!(outcome, "one".into());
    }

    /// Ordering comparisons across int, uint and double operands evaluate to `true` for the
    /// cases below.
    #[test]
    fn test_heterogeneous_compare() {
        let context = Context::default();
        let eval = |expr: &str| Program::compile(expr).unwrap().execute(&context).unwrap();

        assert_eq!(eval("1 < uint(2)"), true.into());
        assert_eq!(eval("1 < 1.1"), true.into());
        assert_eq!(
            eval("uint(0) > -10"),
            true.into(),
            "negative signed ints should be less than uints"
        );
    }

    /// Doubles compare as expected, NaN is not equal to itself, and ordering against NaN is
    /// an error.
    #[test]
    fn test_float_compare() {
        let context = Context::default();
        let run = |expr: &str| Program::compile(expr).unwrap().execute(&context);

        assert_eq!(run("1.0 > 0.0").unwrap(), true.into());

        assert_eq!(
            run("double('NaN') == double('NaN')").unwrap(),
            false.into(),
            "NaN should not equal itself"
        );

        assert!(
            run("1.0 > double('NaN')").is_err(),
            "NaN should not be comparable with inequality operators"
        );
    }

    /// Comparing an empty map with an empty list for equality yields `false`.
    #[test]
    fn test_invalid_compare() {
        let ctx = Context::default();

        let compiled = Program::compile("{} == []").unwrap();
        let outcome = compiled.execute(&ctx).unwrap();
        assert_eq!(outcome, false.into());
    }

    /// `size` can be called as a function and read as a variable within one expression.
    #[test]
    fn test_size_fn_var() {
        let compiled = Program::compile("size(requests) + size == 5").unwrap();

        let mut ctx = Context::default();
        let requests = Value::List(Arc::new(vec![Value::Int(42), Value::Int(42)]));
        ctx.add_variable("requests", requests).unwrap();
        ctx.add_variable("size", Value::Int(3)).unwrap();

        let outcome = compiled.execute(&ctx).unwrap();
        assert_eq!(outcome, Value::Bool(true));
    }

    /// Compiles `program`, runs it against a default context and asserts that execution
    /// fails with exactly `expected`.
    fn test_execution_error(program: &str, expected: ExecutionError) {
        let compiled = Program::compile(program).unwrap();
        let actual = compiled.execute(&Context::default()).unwrap_err();
        assert_eq!(actual, expected);
    }

    /// Subtracting an int from a string is rejected.
    #[test]
    fn test_invalid_sub() {
        let expected =
            ExecutionError::UnsupportedBinaryOperator("sub", "foo".into(), Value::Int(10));
        test_execution_error("'foo' - 10", expected);
    }

    /// Adding an int to a string is rejected.
    #[test]
    fn test_invalid_add() {
        let expected =
            ExecutionError::UnsupportedBinaryOperator("add", "foo".into(), Value::Int(10));
        test_execution_error("'foo' + 10", expected);
    }

    /// Dividing a string by an int is rejected.
    #[test]
    fn test_invalid_div() {
        let expected =
            ExecutionError::UnsupportedBinaryOperator("div", "foo".into(), Value::Int(10));
        test_execution_error("'foo' / 10", expected);
    }

    /// Taking the remainder of a string by an int is rejected.
    #[test]
    fn test_invalid_rem() {
        let expected =
            ExecutionError::UnsupportedBinaryOperator("rem", "foo".into(), Value::Int(10));
        test_execution_error("'foo' % 10", expected);
    }

    /// Indexing an empty list at position 10 evaluates to `Value::Null`.
    #[test]
    fn out_of_bound_list_access() {
        let compiled = Program::compile("list[10]").unwrap();

        let mut ctx = Context::default();
        let empty = Value::List(Arc::new(vec![]));
        ctx.add_variable("list", empty).unwrap();

        let outcome = compiled.execute(&ctx);
        assert_eq!(outcome.unwrap(), Value::Null);
    }

    /// `&str` converts into a string value, and a `Vec<&str>` into a list of string values.
    #[test]
    fn reference_to_value() {
        let source = "example".to_string();
        let expected = Value::String(Arc::new(String::from("example")));

        let direct: Value = source.as_str().into();
        assert_eq!(direct, expected);

        let borrowed = vec![source.as_str()];
        let indirect: Value = borrowed.into();
        assert_eq!(indirect, Value::List(Arc::new(vec![expected])));
    }

    /// With an empty `data` map, `has(data.x) && data.x.startsWith("foo")` evaluates without
    /// error.
    #[test]
    fn test_short_circuit_and() {
        let data: HashMap<String, String> = HashMap::new();
        let mut ctx = Context::default();
        ctx.add_variable_from_value("data", data);

        let compiled = Program::compile("has(data.x) && data.x.startsWith(\"foo\")").unwrap();
        let value = compiled.execute(&ctx);
        println!("{value:?}");
        assert!(
            value.is_ok(),
            "The AND expression should support short-circuit evaluation."
        );
    }

    /// Signed division/remainder by zero and overflowing signed arithmetic report the
    /// matching execution error.
    #[test]
    fn invalid_int_math() {
        use ExecutionError::*;

        let cases = [
            ("1 / 0".to_string(), DivisionByZero(1.into())),
            ("1 % 0".to_string(), RemainderByZero(1.into())),
            (
                format!("{} + 1", i64::MAX),
                Overflow("add", i64::MAX.into(), 1.into()),
            ),
            (
                format!("{} - 1", i64::MIN),
                Overflow("sub", i64::MIN.into(), 1.into()),
            ),
            (
                format!("{} * 2", i64::MAX),
                Overflow("mul", i64::MAX.into(), 2.into()),
            ),
            (
                format!("{} / -1", i64::MIN),
                Overflow("div", i64::MIN.into(), (-1).into()),
            ),
            (
                format!("{} % -1", i64::MIN),
                Overflow("rem", i64::MIN.into(), (-1).into()),
            ),
        ];

        for (expr, expected) in cases {
            test_execution_error(&expr, expected);
        }
    }

    /// Unsigned division/remainder by zero and overflowing unsigned arithmetic report the
    /// matching execution error.
    #[test]
    fn invalid_uint_math() {
        use ExecutionError::*;

        let cases = [
            ("1u / 0u".to_string(), DivisionByZero(1u64.into())),
            ("1u % 0u".to_string(), RemainderByZero(1u64.into())),
            (
                format!("{}u + 1u", u64::MAX),
                Overflow("add", u64::MAX.into(), 1u64.into()),
            ),
            (
                "0u - 1u".to_string(),
                Overflow("sub", 0u64.into(), 1u64.into()),
            ),
            (
                format!("{}u * 2u", u64::MAX),
                Overflow("mul", u64::MAX.into(), 2u64.into()),
            ),
        ];

        for (expr, expected) in cases {
            test_execution_error(&expr, expected);
        }
    }
}