cel_interpreter/
objects.rs

1use crate::context::Context;
2use crate::functions::FunctionContext;
3use crate::ExecutionError;
4use cel_parser::ast::*;
5use std::cmp::Ordering;
6use std::collections::HashMap;
7use std::convert::{Infallible, TryFrom, TryInto};
8use std::fmt::{Display, Formatter};
9use std::ops;
10use std::sync::Arc;
11
12#[derive(Debug, PartialEq, Clone)]
13pub struct Map {
14    pub map: Arc<HashMap<Key, Value>>,
15}
16
17impl PartialOrd for Map {
18    fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
19        None
20    }
21}
22
23impl Map {
24    /// Returns a reference to the value corresponding to the key. Implicitly converts between int
25    /// and uint keys.
26    pub fn get(&self, key: &Key) -> Option<&Value> {
27        self.map.get(key).or_else(|| {
28            // Also check keys that are cross type comparable.
29            let converted = match key {
30                Key::Int(k) => Key::Uint(u64::try_from(*k).ok()?),
31                Key::Uint(k) => Key::Int(i64::try_from(*k).ok()?),
32                _ => return None,
33            };
34            self.map.get(&converted)
35        })
36    }
37}
38
/// The subset of CEL values that may be used as map keys.
///
/// The derived `Ord` orders first by variant, in declaration order
/// (`Int` < `Uint` < `Bool` < `String`), then by the wrapped value.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum Key {
    /// Signed 64-bit integer key.
    Int(i64),
    /// Unsigned 64-bit integer key.
    Uint(u64),
    /// Boolean key.
    Bool(bool),
    /// String key, shared behind an `Arc`.
    String(Arc<String>),
}
46
47/// Implement conversions from primitive types to [`Key`]
48impl From<String> for Key {
49    fn from(v: String) -> Self {
50        Key::String(v.into())
51    }
52}
53
54impl From<Arc<String>> for Key {
55    fn from(v: Arc<String>) -> Self {
56        Key::String(v.clone())
57    }
58}
59
60impl<'a> From<&'a str> for Key {
61    fn from(v: &'a str) -> Self {
62        Key::String(Arc::new(v.into()))
63    }
64}
65
66impl From<bool> for Key {
67    fn from(v: bool) -> Self {
68        Key::Bool(v)
69    }
70}
71
72impl From<i64> for Key {
73    fn from(v: i64) -> Self {
74        Key::Int(v)
75    }
76}
77
78impl From<u64> for Key {
79    fn from(v: u64) -> Self {
80        Key::Uint(v)
81    }
82}
83
84impl serde::Serialize for Key {
85    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
86    where
87        S: serde::Serializer,
88    {
89        match self {
90            Key::Int(v) => v.serialize(serializer),
91            Key::Uint(v) => v.serialize(serializer),
92            Key::Bool(v) => v.serialize(serializer),
93            Key::String(v) => v.serialize(serializer),
94        }
95    }
96}
97
98impl Display for Key {
99    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
100        match self {
101            Key::Int(v) => write!(f, "{}", v),
102            Key::Uint(v) => write!(f, "{}", v),
103            Key::Bool(v) => write!(f, "{}", v),
104            Key::String(v) => write!(f, "{}", v),
105        }
106    }
107}
108
109/// Implement conversions from [`Key`] into [`Value`]
110impl TryInto<Key> for Value {
111    type Error = Value;
112
113    #[inline(always)]
114    fn try_into(self) -> Result<Key, Self::Error> {
115        match self {
116            Value::Int(v) => Ok(Key::Int(v)),
117            Value::UInt(v) => Ok(Key::Uint(v)),
118            Value::String(v) => Ok(Key::String(v)),
119            Value::Bool(v) => Ok(Key::Bool(v)),
120            _ => Err(self),
121        }
122    }
123}
124
125// Implement conversion from HashMap<K, V> into CelMap
126impl<K: Into<Key>, V: Into<Value>> From<HashMap<K, V>> for Map {
127    fn from(map: HashMap<K, V>) -> Self {
128        let mut new_map = HashMap::new();
129        for (k, v) in map {
130            new_map.insert(k.into(), v.into());
131        }
132        Map {
133            map: Arc::new(new_map),
134        }
135    }
136}
137
138pub trait TryIntoValue {
139    type Error: std::error::Error + 'static + Send + Sync;
140    fn try_into_value(self) -> Result<Value, Self::Error>;
141}
142
143impl<T: serde::Serialize> TryIntoValue for T {
144    type Error = crate::ser::SerializationError;
145    fn try_into_value(self) -> Result<Value, Self::Error> {
146        crate::ser::to_value(self)
147    }
148}
149impl TryIntoValue for Value {
150    type Error = Infallible;
151    fn try_into_value(self) -> Result<Value, Self::Error> {
152        Ok(self)
153    }
154}
155
156#[derive(Debug, Clone)]
157pub enum Value {
158    List(Arc<Vec<Value>>),
159    Map(Map),
160
161    Function(Arc<String>, Option<Box<Value>>),
162
163    // Atoms
164    Int(i64),
165    UInt(u64),
166    Float(f64),
167    String(Arc<String>),
168    Bytes(Arc<Vec<u8>>),
169    Bool(bool),
170    #[cfg(feature = "chrono")]
171    Duration(chrono::Duration),
172    #[cfg(feature = "chrono")]
173    Timestamp(chrono::DateTime<chrono::FixedOffset>),
174    Null,
175}
176
/// The runtime type tag of a [`Value`], used mainly in error messages.
#[derive(Debug, Clone, Copy)]
pub enum ValueType {
    List,
    Map,
    Function,
    Int,
    UInt,
    Float,
    String,
    Bytes,
    Bool,
    Duration,
    Timestamp,
    Null,
}
192
193impl Display for ValueType {
194    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
195        match self {
196            ValueType::List => write!(f, "list"),
197            ValueType::Map => write!(f, "map"),
198            ValueType::Function => write!(f, "function"),
199            ValueType::Int => write!(f, "int"),
200            ValueType::UInt => write!(f, "uint"),
201            ValueType::Float => write!(f, "float"),
202            ValueType::String => write!(f, "string"),
203            ValueType::Bytes => write!(f, "bytes"),
204            ValueType::Bool => write!(f, "bool"),
205            ValueType::Duration => write!(f, "duration"),
206            ValueType::Timestamp => write!(f, "timestamp"),
207            ValueType::Null => write!(f, "null"),
208        }
209    }
210}
211
212impl Value {
213    pub fn type_of(&self) -> ValueType {
214        match self {
215            Value::List(_) => ValueType::List,
216            Value::Map(_) => ValueType::Map,
217            Value::Function(_, _) => ValueType::Function,
218            Value::Int(_) => ValueType::Int,
219            Value::UInt(_) => ValueType::UInt,
220            Value::Float(_) => ValueType::Float,
221            Value::String(_) => ValueType::String,
222            Value::Bytes(_) => ValueType::Bytes,
223            Value::Bool(_) => ValueType::Bool,
224            #[cfg(feature = "chrono")]
225            Value::Duration(_) => ValueType::Duration,
226            #[cfg(feature = "chrono")]
227            Value::Timestamp(_) => ValueType::Timestamp,
228            Value::Null => ValueType::Null,
229        }
230    }
231
232    pub fn error_expected_type(&self, expected: ValueType) -> ExecutionError {
233        ExecutionError::UnexpectedType {
234            got: self.type_of().to_string(),
235            want: expected.to_string(),
236        }
237    }
238}
239
240impl From<&Value> for Value {
241    fn from(value: &Value) -> Self {
242        value.clone()
243    }
244}
245
246impl PartialEq for Value {
247    fn eq(&self, other: &Self) -> bool {
248        match (self, other) {
249            (Value::Map(a), Value::Map(b)) => a == b,
250            (Value::List(a), Value::List(b)) => a == b,
251            (Value::Function(a1, a2), Value::Function(b1, b2)) => a1 == b1 && a2 == b2,
252            (Value::Int(a), Value::Int(b)) => a == b,
253            (Value::UInt(a), Value::UInt(b)) => a == b,
254            (Value::Float(a), Value::Float(b)) => a == b,
255            (Value::String(a), Value::String(b)) => a == b,
256            (Value::Bytes(a), Value::Bytes(b)) => a == b,
257            (Value::Bool(a), Value::Bool(b)) => a == b,
258            (Value::Null, Value::Null) => true,
259            #[cfg(feature = "chrono")]
260            (Value::Duration(a), Value::Duration(b)) => a == b,
261            #[cfg(feature = "chrono")]
262            (Value::Timestamp(a), Value::Timestamp(b)) => a == b,
263            // Allow different numeric types to be compared without explicit casting.
264            (Value::Int(a), Value::UInt(b)) => a
265                .to_owned()
266                .try_into()
267                .map(|a: u64| a == *b)
268                .unwrap_or(false),
269            (Value::Int(a), Value::Float(b)) => (*a as f64) == *b,
270            (Value::UInt(a), Value::Int(b)) => a
271                .to_owned()
272                .try_into()
273                .map(|a: i64| a == *b)
274                .unwrap_or(false),
275            (Value::UInt(a), Value::Float(b)) => (*a as f64) == *b,
276            (Value::Float(a), Value::Int(b)) => *a == (*b as f64),
277            (Value::Float(a), Value::UInt(b)) => *a == (*b as f64),
278            (_, _) => false,
279        }
280    }
281}
282
283impl Eq for Value {}
284
285impl PartialOrd for Value {
286    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
287        match (self, other) {
288            (Value::Int(a), Value::Int(b)) => Some(a.cmp(b)),
289            (Value::UInt(a), Value::UInt(b)) => Some(a.cmp(b)),
290            (Value::Float(a), Value::Float(b)) => a.partial_cmp(b),
291            (Value::String(a), Value::String(b)) => Some(a.cmp(b)),
292            (Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)),
293            (Value::Null, Value::Null) => Some(Ordering::Equal),
294            #[cfg(feature = "chrono")]
295            (Value::Duration(a), Value::Duration(b)) => Some(a.cmp(b)),
296            #[cfg(feature = "chrono")]
297            (Value::Timestamp(a), Value::Timestamp(b)) => Some(a.cmp(b)),
298            // Allow different numeric types to be compared without explicit casting.
299            (Value::Int(a), Value::UInt(b)) => Some(
300                a.to_owned()
301                    .try_into()
302                    .map(|a: u64| a.cmp(b))
303                    // If the i64 doesn't fit into a u64 it must be less than 0.
304                    .unwrap_or(Ordering::Less),
305            ),
306            (Value::Int(a), Value::Float(b)) => (*a as f64).partial_cmp(b),
307            (Value::UInt(a), Value::Int(b)) => Some(
308                a.to_owned()
309                    .try_into()
310                    .map(|a: i64| a.cmp(b))
311                    // If the u64 doesn't fit into a i64 it must be greater than i64::MAX.
312                    .unwrap_or(Ordering::Greater),
313            ),
314            (Value::UInt(a), Value::Float(b)) => (*a as f64).partial_cmp(b),
315            (Value::Float(a), Value::Int(b)) => a.partial_cmp(&(*b as f64)),
316            (Value::Float(a), Value::UInt(b)) => a.partial_cmp(&(*b as f64)),
317            _ => None,
318        }
319    }
320}
321
322impl From<&Key> for Value {
323    fn from(value: &Key) -> Self {
324        match value {
325            Key::Int(v) => Value::Int(*v),
326            Key::Uint(v) => Value::UInt(*v),
327            Key::Bool(v) => Value::Bool(*v),
328            Key::String(v) => Value::String(v.clone()),
329        }
330    }
331}
332
333impl From<Key> for Value {
334    fn from(value: Key) -> Self {
335        match value {
336            Key::Int(v) => Value::Int(v),
337            Key::Uint(v) => Value::UInt(v),
338            Key::Bool(v) => Value::Bool(v),
339            Key::String(v) => Value::String(v),
340        }
341    }
342}
343
344impl From<&Key> for Key {
345    fn from(key: &Key) -> Self {
346        key.clone()
347    }
348}
349
350// Convert Vec<T> to Value
351impl<T: Into<Value>> From<Vec<T>> for Value {
352    fn from(v: Vec<T>) -> Self {
353        Value::List(v.into_iter().map(|v| v.into()).collect::<Vec<_>>().into())
354    }
355}
356
357// Convert Vec<u8> to Value
358impl From<Vec<u8>> for Value {
359    fn from(v: Vec<u8>) -> Self {
360        Value::Bytes(v.into())
361    }
362}
363
364// Convert String to Value
365impl From<String> for Value {
366    fn from(v: String) -> Self {
367        Value::String(v.into())
368    }
369}
370
371impl From<&str> for Value {
372    fn from(v: &str) -> Self {
373        Value::String(v.to_string().into())
374    }
375}
376
377// Convert Option<T> to Value
378impl<T: Into<Value>> From<Option<T>> for Value {
379    fn from(v: Option<T>) -> Self {
380        match v {
381            Some(v) => v.into(),
382            None => Value::Null,
383        }
384    }
385}
386
387// Convert HashMap<K, V> to Value
388impl<K: Into<Key>, V: Into<Value>> From<HashMap<K, V>> for Value {
389    fn from(v: HashMap<K, V>) -> Self {
390        Value::Map(v.into())
391    }
392}
393
394impl From<ExecutionError> for ResolveResult {
395    fn from(value: ExecutionError) -> Self {
396        Err(value)
397    }
398}
399
400pub type ResolveResult = Result<Value, ExecutionError>;
401
402impl From<Value> for ResolveResult {
403    fn from(value: Value) -> Self {
404        Ok(value)
405    }
406}
407
408impl Value {
409    pub fn resolve_all(expr: &[Expression], ctx: &Context) -> ResolveResult {
410        let mut res = Vec::with_capacity(expr.len());
411        for expr in expr {
412            res.push(Value::resolve(expr, ctx)?);
413        }
414        Ok(Value::List(res.into()))
415    }
416
417    #[inline(always)]
418    pub fn resolve(expr: &Expression, ctx: &Context) -> ResolveResult {
419        match expr {
420            Expression::Atom(atom) => Ok(atom.into()),
421            Expression::Arithmetic(left, op, right) => {
422                let left = Value::resolve(left, ctx)?;
423                let right = Value::resolve(right, ctx)?;
424
425                match op {
426                    ArithmeticOp::Add => left + right,
427                    ArithmeticOp::Subtract => left - right,
428                    ArithmeticOp::Divide => left / right,
429                    ArithmeticOp::Multiply => left * right,
430                    ArithmeticOp::Modulus => left % right,
431                }
432            }
433            Expression::Relation(left, op, right) => {
434                let left = Value::resolve(left, ctx)?;
435                let right = Value::resolve(right, ctx)?;
436                let res = match op {
437                    RelationOp::LessThan => {
438                        left.partial_cmp(&right)
439                            .ok_or(ExecutionError::ValuesNotComparable(left, right))?
440                            == Ordering::Less
441                    }
442                    RelationOp::LessThanEq => {
443                        left.partial_cmp(&right)
444                            .ok_or(ExecutionError::ValuesNotComparable(left, right))?
445                            != Ordering::Greater
446                    }
447                    RelationOp::GreaterThan => {
448                        left.partial_cmp(&right)
449                            .ok_or(ExecutionError::ValuesNotComparable(left, right))?
450                            == Ordering::Greater
451                    }
452                    RelationOp::GreaterThanEq => {
453                        left.partial_cmp(&right)
454                            .ok_or(ExecutionError::ValuesNotComparable(left, right))?
455                            != Ordering::Less
456                    }
457                    RelationOp::Equals => right.eq(&left),
458                    RelationOp::NotEquals => right.ne(&left),
459                    RelationOp::In => match (left, right) {
460                        (Value::String(l), Value::String(r)) => r.contains(&*l),
461                        (any, Value::List(v)) => v.contains(&any),
462                        (any, Value::Map(m)) => match any.try_into() {
463                            Ok(key) => m.map.contains_key(&key),
464                            Err(_) => false,
465                        },
466                        (left, right) => Err(ExecutionError::ValuesNotComparable(left, right))?,
467                    },
468                };
469                Value::Bool(res).into()
470            }
471            Expression::Ternary(cond, left, right) => {
472                let cond = Value::resolve(cond, ctx)?;
473                if cond.to_bool() {
474                    Value::resolve(left, ctx)
475                } else {
476                    Value::resolve(right, ctx)
477                }
478            }
479            Expression::Or(left, right) => {
480                let left = Value::resolve(left, ctx)?;
481                if left.to_bool() {
482                    left.into()
483                } else {
484                    Value::resolve(right, ctx)
485                }
486            }
487            Expression::And(left, right) => {
488                let left = Value::resolve(left, ctx)?;
489                if !left.to_bool() {
490                    Value::Bool(false).into()
491                } else {
492                    let right = Value::resolve(right, ctx)?;
493                    Value::Bool(right.to_bool()).into()
494                }
495            }
496            Expression::Unary(op, expr) => {
497                let expr = Value::resolve(expr, ctx)?;
498                match op {
499                    UnaryOp::Not => Ok(Value::Bool(!expr.to_bool())),
500                    UnaryOp::DoubleNot => Ok(Value::Bool(expr.to_bool())),
501                    UnaryOp::Minus => match expr {
502                        Value::Int(i) => Ok(Value::Int(-i)),
503                        Value::Float(i) => Ok(Value::Float(-i)),
504                        value => Err(ExecutionError::UnsupportedUnaryOperator("minus", value)),
505                    },
506                    UnaryOp::DoubleMinus => match expr {
507                        Value::Int(_) => Ok(expr),
508                        Value::UInt(_) => Ok(expr),
509                        Value::Float(_) => Ok(expr),
510                        value => Err(ExecutionError::UnsupportedUnaryOperator("negate", value)),
511                    },
512                }
513            }
514            Expression::Member(left, right) => {
515                let left = Value::resolve(left, ctx)?;
516                left.member(right, ctx)
517            }
518            Expression::List(items) => {
519                let list = items
520                    .iter()
521                    .map(|i| Value::resolve(i, ctx))
522                    .collect::<Result<Vec<_>, _>>()?;
523                Value::List(list.into()).into()
524            }
525            Expression::Map(items) => {
526                let mut map = HashMap::default();
527                for (k, v) in items.iter() {
528                    let key = Value::resolve(k, ctx)?
529                        .try_into()
530                        .map_err(ExecutionError::UnsupportedKeyType)?;
531                    let value = Value::resolve(v, ctx)?;
532                    map.insert(key, value);
533                }
534                Ok(Value::Map(Map {
535                    map: Arc::from(map),
536                }))
537            }
538            Expression::Ident(name) => ctx.get_variable(&***name),
539            Expression::FunctionCall(name, target, args) => {
540                if let Expression::Ident(name) = &**name {
541                    let func = ctx
542                        .get_function(&**name)
543                        .ok_or_else(|| ExecutionError::UndeclaredReference(Arc::clone(name)))?;
544                    match target {
545                        None => {
546                            let mut ctx =
547                                FunctionContext::new(name.clone(), None, ctx, args.clone());
548                            func.call_with_context(&mut ctx)
549                        }
550                        Some(target) => {
551                            let mut ctx = FunctionContext::new(
552                                name.clone(),
553                                Some(Value::resolve(target, ctx)?),
554                                ctx,
555                                args.clone(),
556                            );
557                            func.call_with_context(&mut ctx)
558                        }
559                    }
560                } else {
561                    Err(ExecutionError::UnsupportedFunctionCallIdentifierType(
562                        (**name).clone(),
563                    ))
564                }
565            }
566        }
567    }
568
569    // >> a(b)
570    // Member(Ident("a"),
571    //        FunctionCall([Ident("b")]))
572    // >> a.b(c)
573    // Member(Member(Ident("a"),
574    //               Attribute("b")),
575    //        FunctionCall([Ident("c")]))
576
577    fn member(self, member: &Member, ctx: &Context) -> ResolveResult {
578        match member {
579            Member::Index(idx) => {
580                let idx = Value::resolve(idx, ctx)?;
581                match (self, idx) {
582                    (Value::List(items), Value::Int(idx)) => items
583                        .get(idx as usize)
584                        .cloned()
585                        .unwrap_or(Value::Null)
586                        .into(),
587                    (Value::String(str), Value::Int(idx)) => {
588                        match str.get(idx as usize..(idx + 1) as usize) {
589                            None => Ok(Value::Null),
590                            Some(str) => Ok(Value::String(str.to_string().into())),
591                        }
592                    }
593                    (Value::Map(map), Value::String(property)) => map
594                        .get(&property.into())
595                        .cloned()
596                        .unwrap_or(Value::Null)
597                        .into(),
598                    (Value::Map(map), Value::Bool(property)) => map
599                        .get(&property.into())
600                        .cloned()
601                        .unwrap_or(Value::Null)
602                        .into(),
603                    (Value::Map(map), Value::Int(property)) => map
604                        .get(&property.into())
605                        .cloned()
606                        .unwrap_or(Value::Null)
607                        .into(),
608                    (Value::Map(map), Value::UInt(property)) => map
609                        .get(&property.into())
610                        .cloned()
611                        .unwrap_or(Value::Null)
612                        .into(),
613                    (Value::Map(_), index) => Err(ExecutionError::UnsupportedMapIndex(index)),
614                    (Value::List(_), index) => Err(ExecutionError::UnsupportedListIndex(index)),
615                    (value, index) => Err(ExecutionError::UnsupportedIndex(value, index)),
616                }
617            }
618            Member::Fields(_) => Err(ExecutionError::UnsupportedFieldsConstruction(
619                member.clone(),
620            )),
621            Member::Attribute(name) => {
622                // This will always either be because we're trying to access
623                // a property on self, or a method on self.
624                let child = match self {
625                    Value::Map(ref m) => m.map.get(&name.clone().into()).cloned(),
626                    _ => None,
627                };
628
629                // If the property is both an attribute and a method, then we
630                // give priority to the property. Maybe we can implement lookahead
631                // to see if the next token is a function call?
632                match (child, ctx.has_function(&***name)) {
633                    (None, false) => ExecutionError::NoSuchKey(name.clone()).into(),
634                    (Some(child), _) => child.into(),
635                    (None, true) => Value::Function(name.clone(), Some(self.into())).into(),
636                }
637            }
638        }
639    }
640
641    #[inline(always)]
642    fn to_bool(&self) -> bool {
643        match self {
644            Value::List(v) => !v.is_empty(),
645            Value::Map(v) => !v.map.is_empty(),
646            Value::Int(v) => *v != 0,
647            Value::UInt(v) => *v != 0,
648            Value::Float(v) => *v != 0.0,
649            Value::String(v) => !v.is_empty(),
650            Value::Bytes(v) => !v.is_empty(),
651            Value::Bool(v) => *v,
652            Value::Null => false,
653            #[cfg(feature = "chrono")]
654            Value::Duration(v) => v.num_nanoseconds().map(|n| n != 0).unwrap_or(false),
655            #[cfg(feature = "chrono")]
656            Value::Timestamp(v) => v.timestamp_nanos_opt().unwrap_or_default() > 0,
657            Value::Function(_, _) => false,
658        }
659    }
660}
661
662impl From<&Atom> for Value {
663    #[inline(always)]
664    fn from(atom: &Atom) -> Self {
665        match atom {
666            Atom::Int(v) => Value::Int(*v),
667            Atom::UInt(v) => Value::UInt(*v),
668            Atom::Float(v) => Value::Float(*v),
669            Atom::String(v) => Value::String(v.clone()),
670            Atom::Bytes(v) => Value::Bytes(v.clone()),
671            Atom::Bool(v) => Value::Bool(*v),
672            Atom::Null => Value::Null,
673        }
674    }
675}
676
677impl ops::Add<Value> for Value {
678    type Output = ResolveResult;
679
680    #[inline(always)]
681    fn add(self, rhs: Value) -> Self::Output {
682        match (self, rhs) {
683            (Value::Int(l), Value::Int(r)) => Value::Int(l + r).into(),
684            (Value::UInt(l), Value::UInt(r)) => Value::UInt(l + r).into(),
685
686            // Float matrix
687            (Value::Float(l), Value::Float(r)) => Value::Float(l + r).into(),
688            (Value::Int(l), Value::Float(r)) => Value::Float(l as f64 + r).into(),
689            (Value::Float(l), Value::Int(r)) => Value::Float(l + r as f64).into(),
690            (Value::UInt(l), Value::Float(r)) => Value::Float(l as f64 + r).into(),
691            (Value::Float(l), Value::UInt(r)) => Value::Float(l + r as f64).into(),
692
693            (Value::List(l), Value::List(r)) => {
694                Value::List(l.iter().chain(r.iter()).cloned().collect::<Vec<_>>().into()).into()
695            }
696            (Value::String(l), Value::String(r)) => {
697                let mut new = String::with_capacity(l.len() + r.len());
698                new.push_str(&l);
699                new.push_str(&r);
700                Value::String(new.into()).into()
701            }
702            // Merge two maps should overwrite keys in the left map with the right map
703            (Value::Map(l), Value::Map(r)) => {
704                let mut new = HashMap::default();
705                for (k, v) in l.map.iter() {
706                    new.insert(k.clone(), v.clone());
707                }
708                for (k, v) in r.map.iter() {
709                    new.insert(k.clone(), v.clone());
710                }
711                Value::Map(Map { map: Arc::new(new) }).into()
712            }
713            #[cfg(feature = "chrono")]
714            (Value::Duration(l), Value::Duration(r)) => Value::Duration(l + r).into(),
715            #[cfg(feature = "chrono")]
716            (Value::Timestamp(l), Value::Duration(r)) => Value::Timestamp(l + r).into(),
717            #[cfg(feature = "chrono")]
718            (Value::Duration(l), Value::Timestamp(r)) => Value::Timestamp(r + l).into(),
719            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
720                "add", left, right,
721            )),
722        }
723    }
724}
725
726impl ops::Sub<Value> for Value {
727    type Output = ResolveResult;
728
729    #[inline(always)]
730    fn sub(self, rhs: Value) -> Self::Output {
731        match (self, rhs) {
732            (Value::Int(l), Value::Int(r)) => Value::Int(l - r).into(),
733            (Value::UInt(l), Value::UInt(r)) => Value::UInt(l - r).into(),
734
735            // Float matrix
736            (Value::Float(l), Value::Float(r)) => Value::Float(l - r).into(),
737            (Value::Int(l), Value::Float(r)) => Value::Float(l as f64 - r).into(),
738            (Value::Float(l), Value::Int(r)) => Value::Float(l - r as f64).into(),
739            (Value::UInt(l), Value::Float(r)) => Value::Float(l as f64 - r).into(),
740            (Value::Float(l), Value::UInt(r)) => Value::Float(l - r as f64).into(),
741            // todo: implement checked sub for these over-flowable operations
742            #[cfg(feature = "chrono")]
743            (Value::Duration(l), Value::Duration(r)) => Value::Duration(l - r).into(),
744            #[cfg(feature = "chrono")]
745            (Value::Timestamp(l), Value::Duration(r)) => Value::Timestamp(l - r).into(),
746            #[cfg(feature = "chrono")]
747            (Value::Timestamp(l), Value::Timestamp(r)) => Value::Duration(l - r).into(),
748            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
749                "sub", left, right,
750            )),
751        }
752    }
753}
754
755impl ops::Div<Value> for Value {
756    type Output = ResolveResult;
757
758    #[inline(always)]
759    fn div(self, rhs: Value) -> Self::Output {
760        match (self, rhs) {
761            (Value::Int(l), Value::Int(r)) => Value::Int(l / r).into(),
762            (Value::UInt(l), Value::UInt(r)) => Value::UInt(l / r).into(),
763
764            // Float matrix
765            (Value::Float(l), Value::Float(r)) => Value::Float(l / r).into(),
766            (Value::Int(l), Value::Float(r)) => Value::Float(l as f64 / r).into(),
767            (Value::Float(l), Value::Int(r)) => Value::Float(l / r as f64).into(),
768            (Value::UInt(l), Value::Float(r)) => Value::Float(l as f64 / r).into(),
769            (Value::Float(l), Value::UInt(r)) => Value::Float(l / r as f64).into(),
770
771            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
772                "div", left, right,
773            )),
774        }
775    }
776}
777
778impl ops::Mul<Value> for Value {
779    type Output = ResolveResult;
780
781    #[inline(always)]
782    fn mul(self, rhs: Value) -> Self::Output {
783        match (self, rhs) {
784            (Value::Int(l), Value::Int(r)) => Value::Int(l * r).into(),
785            (Value::UInt(l), Value::UInt(r)) => Value::UInt(l * r).into(),
786
787            // Float matrix
788            (Value::Float(l), Value::Float(r)) => Value::Float(l * r).into(),
789            (Value::Int(l), Value::Float(r)) => Value::Float(l as f64 * r).into(),
790            (Value::Float(l), Value::Int(r)) => Value::Float(l * r as f64).into(),
791            (Value::UInt(l), Value::Float(r)) => Value::Float(l as f64 * r).into(),
792            (Value::Float(l), Value::UInt(r)) => Value::Float(l * r as f64).into(),
793
794            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
795                "mul", left, right,
796            )),
797        }
798    }
799}
800
801impl ops::Rem<Value> for Value {
802    type Output = ResolveResult;
803
804    #[inline(always)]
805    fn rem(self, rhs: Value) -> Self::Output {
806        match (self, rhs) {
807            (Value::Int(l), Value::Int(r)) => Value::Int(l % r).into(),
808            (Value::UInt(l), Value::UInt(r)) => Value::UInt(l % r).into(),
809
810            // Float matrix
811            (Value::Float(l), Value::Float(r)) => Value::Float(l % r).into(),
812            (Value::Int(l), Value::Float(r)) => Value::Float(l as f64 % r).into(),
813            (Value::Float(l), Value::Int(r)) => Value::Float(l % r as f64).into(),
814            (Value::UInt(l), Value::Float(r)) => Value::Float(l as f64 % r).into(),
815            (Value::Float(l), Value::UInt(r)) => Value::Float(l % r as f64).into(),
816
817            (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
818                "rem", left, right,
819            )),
820        }
821    }
822}
823
#[cfg(test)]
/// End-to-end tests that compile CEL expressions with `Program::compile` and evaluate them
/// against a `Context`, covering map/list access, comparisons, and arithmetic operators.
mod tests {
    use crate::{objects::Key, Context, ExecutionError, Program, Value};
    use std::collections::HashMap;
    use std::sync::Arc;

    #[test]
    fn test_indexed_map_access() {
        let mut context = Context::default();
        let mut headers = HashMap::new();
        headers.insert("Content-Type", "application/json".to_string());
        context.add_variable_from_value("headers", headers);

        let program = Program::compile("headers[\"Content-Type\"]").unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(value, "application/json".into());
    }

    #[test]
    fn test_numeric_map_access() {
        let mut context = Context::default();
        let mut numbers = HashMap::new();
        // The map is keyed by uint while the expression indexes with an int literal, which
        // exercises the int/uint key conversion in `Map::get`.
        numbers.insert(Key::Uint(1), "one".to_string());
        context.add_variable_from_value("numbers", numbers);

        let program = Program::compile("numbers[1]").unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(value, "one".into());
    }

    #[test]
    fn test_heterogeneous_compare() {
        let context = Context::default();

        let program = Program::compile("1 < uint(2)").unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(value, true.into());

        let program = Program::compile("1 < 1.1").unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(value, true.into());

        let program = Program::compile("uint(0) > -10").unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(
            value,
            true.into(),
            "negative signed ints should be less than uints"
        );
    }

    #[test]
    fn test_float_compare() {
        let context = Context::default();

        let program = Program::compile("1.0 > 0.0").unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(value, true.into());

        let program = Program::compile("double('NaN') == double('NaN')").unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(value, false.into(), "NaN should not equal itself");

        let program = Program::compile("1.0 > double('NaN')").unwrap();
        let result = program.execute(&context);
        assert!(
            result.is_err(),
            "NaN should not be comparable with inequality operators"
        );
    }

    #[test]
    fn test_invalid_compare() {
        let context = Context::default();

        let program = Program::compile("{} == []").unwrap();
        let value = program.execute(&context).unwrap();
        assert_eq!(value, false.into());
    }

    #[test]
    fn test_size_fn_var() {
        // `size` is used both as a function and as a variable name in the same expression.
        let program = Program::compile("size(requests) + size == 5").unwrap();
        let mut context = Context::default();
        let requests = vec![Value::Int(42), Value::Int(42)];
        context
            .add_variable("requests", Value::List(Arc::new(requests)))
            .unwrap();
        context.add_variable("size", Value::Int(3)).unwrap();
        assert_eq!(program.execute(&context).unwrap(), Value::Bool(true));
    }

    /// Compiles `program`, executes it against an empty context, and asserts that execution
    /// fails with exactly `expected`.
    fn test_execution_error(program: &str, expected: ExecutionError) {
        let program = Program::compile(program).unwrap();
        let result = program.execute(&Context::default());
        assert_eq!(result.unwrap_err(), expected);
    }

    #[test]
    fn test_invalid_sub() {
        test_execution_error(
            "'foo' - 10",
            ExecutionError::UnsupportedBinaryOperator("sub", "foo".into(), Value::Int(10)),
        );
    }

    #[test]
    fn test_invalid_add() {
        test_execution_error(
            "'foo' + 10",
            ExecutionError::UnsupportedBinaryOperator("add", "foo".into(), Value::Int(10)),
        );
    }

    #[test]
    fn test_invalid_mul() {
        test_execution_error(
            "'foo' * 10",
            ExecutionError::UnsupportedBinaryOperator("mul", "foo".into(), Value::Int(10)),
        );
    }

    #[test]
    fn test_invalid_div() {
        test_execution_error(
            "'foo' / 10",
            ExecutionError::UnsupportedBinaryOperator("div", "foo".into(), Value::Int(10)),
        );
    }

    #[test]
    fn test_invalid_rem() {
        test_execution_error(
            "'foo' % 10",
            ExecutionError::UnsupportedBinaryOperator("rem", "foo".into(), Value::Int(10)),
        );
    }

    #[test]
    fn test_integer_mul_and_rem() {
        let context = Context::default();
        let cases: &[(&str, Value)] = &[
            ("3 * 4", Value::Int(12)),
            ("7 % 3", Value::Int(1)),
            ("uint(7) % uint(2)", Value::UInt(1)),
        ];
        for (expr, expected) in cases.iter() {
            let program = Program::compile(expr).unwrap();
            let value = program.execute(&context).unwrap();
            assert_eq!(&value, expected, "unexpected result for `{}`", expr);
        }
    }

    #[test]
    fn out_of_bound_list_access() {
        let program = Program::compile("list[10]").unwrap();
        let mut context = Context::default();
        context
            .add_variable("list", Value::List(Arc::new(vec![])))
            .unwrap();
        let result = program.execute(&context);
        assert_eq!(result.unwrap(), Value::Null);
    }

    #[test]
    fn reference_to_value() {
        let test = "example".to_string();
        let direct: Value = test.as_str().into();
        assert_eq!(direct, Value::String(Arc::new(String::from("example"))));

        let vec = vec![test.as_str()];
        let indirect: Value = vec.into();
        assert_eq!(
            indirect,
            Value::List(Arc::new(vec![Value::String(Arc::new(String::from(
                "example"
            )))]))
        );
    }

    #[test]
    fn test_short_circuit_and() {
        let mut context = Context::default();
        let data: HashMap<String, String> = HashMap::new();
        context.add_variable_from_value("data", data);

        // `data.x` does not exist, so the right-hand side would fail if it were evaluated.
        let program = Program::compile("has(data.x) && data.x.startsWith(\"foo\")").unwrap();
        let value = program.execute(&context);
        assert!(
            value.is_ok(),
            "The AND expression should support short-circuit evaluation."
        );
    }
}