1use crate::context::Context;
2use crate::functions::FunctionContext;
3use crate::{ExecutionError, Expression};
4use cel_parser::ast::{operators, EntryExpr, Expr};
5use cel_parser::reference::Val;
6use std::cmp::Ordering;
7use std::collections::HashMap;
8use std::convert::{Infallible, TryFrom, TryInto};
9use std::fmt::{Display, Formatter};
10use std::ops;
11use std::ops::Deref;
12use std::sync::Arc;
13
14#[derive(Debug, PartialEq, Clone)]
15#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
16pub struct Map {
17 pub map: Arc<HashMap<Key, Value>>,
18}
19
20impl PartialOrd for Map {
21 fn partial_cmp(&self, _: &Self) -> Option<Ordering> {
22 None
23 }
24}
25
26impl Map {
27 pub fn get(&self, key: &Key) -> Option<&Value> {
30 self.map.get(key).or_else(|| {
31 let converted = match key {
33 Key::Int(k) => Key::Uint(u64::try_from(*k).ok()?),
34 Key::Uint(k) => Key::Int(i64::try_from(*k).ok()?),
35 _ => return None,
36 };
37 self.map.get(&converted)
38 })
39 }
40}
41
/// A value that can be used as a map key.
///
/// The derived ordering compares the variant first, in declaration order
/// (`Int`, `Uint`, `Bool`, `String`), and only then the payload.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
pub enum Key {
    /// A signed 64-bit integer key.
    Int(i64),
    /// An unsigned 64-bit integer key.
    Uint(u64),
    /// A boolean key.
    Bool(bool),
    /// A string key, shared behind an [`Arc`].
    String(Arc<String>),
}
50
51impl From<String> for Key {
53 fn from(v: String) -> Self {
54 Key::String(v.into())
55 }
56}
57
58impl From<Arc<String>> for Key {
59 fn from(v: Arc<String>) -> Self {
60 Key::String(v)
61 }
62}
63
64impl<'a> From<&'a str> for Key {
65 fn from(v: &'a str) -> Self {
66 Key::String(Arc::new(v.into()))
67 }
68}
69
70impl From<bool> for Key {
71 fn from(v: bool) -> Self {
72 Key::Bool(v)
73 }
74}
75
76impl From<i64> for Key {
77 fn from(v: i64) -> Self {
78 Key::Int(v)
79 }
80}
81
82impl From<u64> for Key {
83 fn from(v: u64) -> Self {
84 Key::Uint(v)
85 }
86}
87
88impl serde::Serialize for Key {
89 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
90 where
91 S: serde::Serializer,
92 {
93 match self {
94 Key::Int(v) => v.serialize(serializer),
95 Key::Uint(v) => v.serialize(serializer),
96 Key::Bool(v) => v.serialize(serializer),
97 Key::String(v) => v.serialize(serializer),
98 }
99 }
100}
101
102impl Display for Key {
103 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
104 match self {
105 Key::Int(v) => write!(f, "{v}"),
106 Key::Uint(v) => write!(f, "{v}"),
107 Key::Bool(v) => write!(f, "{v}"),
108 Key::String(v) => write!(f, "{v}"),
109 }
110 }
111}
112
113impl TryInto<Key> for Value {
115 type Error = Value;
116
117 #[inline(always)]
118 fn try_into(self) -> Result<Key, Self::Error> {
119 match self {
120 Value::Int(v) => Ok(Key::Int(v)),
121 Value::UInt(v) => Ok(Key::Uint(v)),
122 Value::String(v) => Ok(Key::String(v)),
123 Value::Bool(v) => Ok(Key::Bool(v)),
124 _ => Err(self),
125 }
126 }
127}
128
129impl<K: Into<Key>, V: Into<Value>> From<HashMap<K, V>> for Map {
131 fn from(map: HashMap<K, V>) -> Self {
132 let mut new_map = HashMap::with_capacity(map.len());
133 for (k, v) in map {
134 new_map.insert(k.into(), v.into());
135 }
136 Map {
137 map: Arc::new(new_map),
138 }
139 }
140}
141
142pub trait TryIntoValue {
143 type Error: std::error::Error + 'static + Send + Sync;
144 fn try_into_value(self) -> Result<Value, Self::Error>;
145}
146
147impl<T: serde::Serialize> TryIntoValue for T {
148 type Error = crate::ser::SerializationError;
149 fn try_into_value(self) -> Result<Value, Self::Error> {
150 crate::ser::to_value(self)
151 }
152}
153impl TryIntoValue for Value {
154 type Error = Infallible;
155 fn try_into_value(self) -> Result<Value, Self::Error> {
156 Ok(self)
157 }
158}
159
160#[derive(Debug, Clone)]
161#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
162pub enum Value {
163 List(Arc<Vec<Value>>),
164 Map(Map),
165
166 Function(Arc<String>, Option<Box<Value>>),
167
168 Int(i64),
170 UInt(u64),
171 Float(f64),
172 String(Arc<String>),
173 Bytes(Arc<Vec<u8>>),
174 Bool(bool),
175 #[cfg(feature = "chrono")]
176 Duration(chrono::Duration),
177 #[cfg(feature = "chrono")]
178 Timestamp(chrono::DateTime<chrono::FixedOffset>),
179 Null,
180}
181
182impl From<Val> for Value {
183 fn from(val: Val) -> Self {
184 match val {
185 Val::String(s) => Value::String(Arc::new(s)),
186 Val::Boolean(b) => Value::Bool(b),
187 Val::Int(i) => Value::Int(i),
188 Val::UInt(u) => Value::UInt(u),
189 Val::Double(d) => Value::Float(d),
190 Val::Bytes(bytes) => Value::Bytes(Arc::new(bytes)),
191 Val::Null => Value::Null,
192 }
193 }
194}
195
/// Type tag describing the variant of a [`Value`].
///
/// `PartialEq`, `Eq` and `Hash` are derived so callers can compare tags
/// directly (e.g. `v.type_of() == ValueType::Int`) or use them as set/map keys.
/// Before, only `Clone`, `Copy` and `Debug` were available, which forced
/// callers to `match` or compare rendered strings.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum ValueType {
    List,
    Map,
    Function,
    Int,
    UInt,
    Float,
    String,
    Bytes,
    Bool,
    Duration,
    Timestamp,
    Null,
}
211
212impl Display for ValueType {
213 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
214 match self {
215 ValueType::List => write!(f, "list"),
216 ValueType::Map => write!(f, "map"),
217 ValueType::Function => write!(f, "function"),
218 ValueType::Int => write!(f, "int"),
219 ValueType::UInt => write!(f, "uint"),
220 ValueType::Float => write!(f, "float"),
221 ValueType::String => write!(f, "string"),
222 ValueType::Bytes => write!(f, "bytes"),
223 ValueType::Bool => write!(f, "bool"),
224 ValueType::Duration => write!(f, "duration"),
225 ValueType::Timestamp => write!(f, "timestamp"),
226 ValueType::Null => write!(f, "null"),
227 }
228 }
229}
230
231impl Value {
232 pub fn type_of(&self) -> ValueType {
233 match self {
234 Value::List(_) => ValueType::List,
235 Value::Map(_) => ValueType::Map,
236 Value::Function(_, _) => ValueType::Function,
237 Value::Int(_) => ValueType::Int,
238 Value::UInt(_) => ValueType::UInt,
239 Value::Float(_) => ValueType::Float,
240 Value::String(_) => ValueType::String,
241 Value::Bytes(_) => ValueType::Bytes,
242 Value::Bool(_) => ValueType::Bool,
243 #[cfg(feature = "chrono")]
244 Value::Duration(_) => ValueType::Duration,
245 #[cfg(feature = "chrono")]
246 Value::Timestamp(_) => ValueType::Timestamp,
247 Value::Null => ValueType::Null,
248 }
249 }
250
251 pub fn error_expected_type(&self, expected: ValueType) -> ExecutionError {
252 ExecutionError::UnexpectedType {
253 got: self.type_of().to_string(),
254 want: expected.to_string(),
255 }
256 }
257}
258
259impl From<&Value> for Value {
260 fn from(value: &Value) -> Self {
261 value.clone()
262 }
263}
264
265impl PartialEq for Value {
266 fn eq(&self, other: &Self) -> bool {
267 match (self, other) {
268 (Value::Map(a), Value::Map(b)) => a == b,
269 (Value::List(a), Value::List(b)) => a == b,
270 (Value::Function(a1, a2), Value::Function(b1, b2)) => a1 == b1 && a2 == b2,
271 (Value::Int(a), Value::Int(b)) => a == b,
272 (Value::UInt(a), Value::UInt(b)) => a == b,
273 (Value::Float(a), Value::Float(b)) => a == b,
274 (Value::String(a), Value::String(b)) => a == b,
275 (Value::Bytes(a), Value::Bytes(b)) => a == b,
276 (Value::Bool(a), Value::Bool(b)) => a == b,
277 (Value::Null, Value::Null) => true,
278 #[cfg(feature = "chrono")]
279 (Value::Duration(a), Value::Duration(b)) => a == b,
280 #[cfg(feature = "chrono")]
281 (Value::Timestamp(a), Value::Timestamp(b)) => a == b,
282 (Value::Int(a), Value::UInt(b)) => a
284 .to_owned()
285 .try_into()
286 .map(|a: u64| a == *b)
287 .unwrap_or(false),
288 (Value::Int(a), Value::Float(b)) => (*a as f64) == *b,
289 (Value::UInt(a), Value::Int(b)) => a
290 .to_owned()
291 .try_into()
292 .map(|a: i64| a == *b)
293 .unwrap_or(false),
294 (Value::UInt(a), Value::Float(b)) => (*a as f64) == *b,
295 (Value::Float(a), Value::Int(b)) => *a == (*b as f64),
296 (Value::Float(a), Value::UInt(b)) => *a == (*b as f64),
297 (_, _) => false,
298 }
299 }
300}
301
302impl Eq for Value {}
303
304impl PartialOrd for Value {
305 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
306 match (self, other) {
307 (Value::Int(a), Value::Int(b)) => Some(a.cmp(b)),
308 (Value::UInt(a), Value::UInt(b)) => Some(a.cmp(b)),
309 (Value::Float(a), Value::Float(b)) => a.partial_cmp(b),
310 (Value::String(a), Value::String(b)) => Some(a.cmp(b)),
311 (Value::Bool(a), Value::Bool(b)) => Some(a.cmp(b)),
312 (Value::Null, Value::Null) => Some(Ordering::Equal),
313 #[cfg(feature = "chrono")]
314 (Value::Duration(a), Value::Duration(b)) => Some(a.cmp(b)),
315 #[cfg(feature = "chrono")]
316 (Value::Timestamp(a), Value::Timestamp(b)) => Some(a.cmp(b)),
317 (Value::Int(a), Value::UInt(b)) => Some(
319 a.to_owned()
320 .try_into()
321 .map(|a: u64| a.cmp(b))
322 .unwrap_or(Ordering::Less),
324 ),
325 (Value::Int(a), Value::Float(b)) => (*a as f64).partial_cmp(b),
326 (Value::UInt(a), Value::Int(b)) => Some(
327 a.to_owned()
328 .try_into()
329 .map(|a: i64| a.cmp(b))
330 .unwrap_or(Ordering::Greater),
332 ),
333 (Value::UInt(a), Value::Float(b)) => (*a as f64).partial_cmp(b),
334 (Value::Float(a), Value::Int(b)) => a.partial_cmp(&(*b as f64)),
335 (Value::Float(a), Value::UInt(b)) => a.partial_cmp(&(*b as f64)),
336 _ => None,
337 }
338 }
339}
340
341impl From<&Key> for Value {
342 fn from(value: &Key) -> Self {
343 match value {
344 Key::Int(v) => Value::Int(*v),
345 Key::Uint(v) => Value::UInt(*v),
346 Key::Bool(v) => Value::Bool(*v),
347 Key::String(v) => Value::String(v.clone()),
348 }
349 }
350}
351
352impl From<Key> for Value {
353 fn from(value: Key) -> Self {
354 match value {
355 Key::Int(v) => Value::Int(v),
356 Key::Uint(v) => Value::UInt(v),
357 Key::Bool(v) => Value::Bool(v),
358 Key::String(v) => Value::String(v),
359 }
360 }
361}
362
363impl From<&Key> for Key {
364 fn from(key: &Key) -> Self {
365 key.clone()
366 }
367}
368
369impl<T: Into<Value>> From<Vec<T>> for Value {
371 fn from(v: Vec<T>) -> Self {
372 Value::List(v.into_iter().map(|v| v.into()).collect::<Vec<_>>().into())
373 }
374}
375
376impl From<Vec<u8>> for Value {
378 fn from(v: Vec<u8>) -> Self {
379 Value::Bytes(v.into())
380 }
381}
382
383impl From<String> for Value {
385 fn from(v: String) -> Self {
386 Value::String(v.into())
387 }
388}
389
390impl From<&str> for Value {
391 fn from(v: &str) -> Self {
392 Value::String(v.to_string().into())
393 }
394}
395
396impl<T: Into<Value>> From<Option<T>> for Value {
398 fn from(v: Option<T>) -> Self {
399 match v {
400 Some(v) => v.into(),
401 None => Value::Null,
402 }
403 }
404}
405
406impl<K: Into<Key>, V: Into<Value>> From<HashMap<K, V>> for Value {
408 fn from(v: HashMap<K, V>) -> Self {
409 Value::Map(v.into())
410 }
411}
412
413impl From<ExecutionError> for ResolveResult {
414 fn from(value: ExecutionError) -> Self {
415 Err(value)
416 }
417}
418
419pub type ResolveResult = Result<Value, ExecutionError>;
420
421impl From<Value> for ResolveResult {
422 fn from(value: Value) -> Self {
423 Ok(value)
424 }
425}
426
427impl Value {
428 pub fn resolve_all(expr: &[Expression], ctx: &Context) -> ResolveResult {
429 let mut res = Vec::with_capacity(expr.len());
430 for expr in expr {
431 res.push(Value::resolve(expr, ctx)?);
432 }
433 Ok(Value::List(res.into()))
434 }
435
436 #[inline(always)]
437 pub fn resolve(expr: &Expression, ctx: &Context) -> ResolveResult {
438 match &expr.expr {
439 Expr::Literal(val) => Ok(val.clone().into()),
440 Expr::Call(call) => {
441 if call.args.len() == 3 && call.func_name == operators::CONDITIONAL {
442 let cond = Value::resolve(&call.args[0], ctx)?;
443 return if cond.to_bool() {
444 Value::resolve(&call.args[1], ctx)
445 } else {
446 Value::resolve(&call.args[2], ctx)
447 };
448 }
449 if call.args.len() == 2 {
450 let left = Value::resolve(&call.args[0], ctx)?;
451 match call.func_name.as_str() {
452 operators::ADD => return left + Value::resolve(&call.args[1], ctx)?,
453 operators::SUBSTRACT => return left - Value::resolve(&call.args[1], ctx)?,
454 operators::DIVIDE => return left / Value::resolve(&call.args[1], ctx)?,
455 operators::MULTIPLY => return left * Value::resolve(&call.args[1], ctx)?,
456 operators::MODULO => return left % Value::resolve(&call.args[1], ctx)?,
457 operators::EQUALS => {
458 return Value::Bool(left.eq(&Value::resolve(&call.args[1], ctx)?))
459 .into()
460 }
461 operators::NOT_EQUALS => {
462 return Value::Bool(left.ne(&Value::resolve(&call.args[1], ctx)?))
463 .into()
464 }
465 operators::LESS => {
466 let right = Value::resolve(&call.args[1], ctx)?;
467 return Value::Bool(
468 left.partial_cmp(&right)
469 .ok_or(ExecutionError::ValuesNotComparable(left, right))?
470 == Ordering::Less,
471 )
472 .into();
473 }
474 operators::LESS_EQUALS => {
475 let right = Value::resolve(&call.args[1], ctx)?;
476 return Value::Bool(
477 left.partial_cmp(&right)
478 .ok_or(ExecutionError::ValuesNotComparable(left, right))?
479 != Ordering::Greater,
480 )
481 .into();
482 }
483 operators::GREATER => {
484 let right = Value::resolve(&call.args[1], ctx)?;
485 return Value::Bool(
486 left.partial_cmp(&right)
487 .ok_or(ExecutionError::ValuesNotComparable(left, right))?
488 == Ordering::Greater,
489 )
490 .into();
491 }
492 operators::GREATER_EQUALS => {
493 let right = Value::resolve(&call.args[1], ctx)?;
494 return Value::Bool(
495 left.partial_cmp(&right)
496 .ok_or(ExecutionError::ValuesNotComparable(left, right))?
497 != Ordering::Less,
498 )
499 .into();
500 }
501 operators::IN => {
502 let right = Value::resolve(&call.args[1], ctx)?;
503 match (left, right) {
504 (Value::String(l), Value::String(r)) => {
505 return Value::Bool(r.contains(&*l)).into()
506 }
507 (any, Value::List(v)) => {
508 return Value::Bool(v.contains(&any)).into()
509 }
510 (any, Value::Map(m)) => match any.try_into() {
511 Ok(key) => return Value::Bool(m.map.contains_key(&key)).into(),
512 Err(_) => return Value::Bool(false).into(),
513 },
514 (left, right) => {
515 Err(ExecutionError::ValuesNotComparable(left, right))?
516 }
517 }
518 }
519 operators::LOGICAL_OR => {
520 return if left.to_bool() {
521 left.into()
522 } else {
523 Value::resolve(&call.args[1], ctx)
524 };
525 }
526 operators::LOGICAL_AND => {
527 return if !left.to_bool() {
528 Value::Bool(false)
529 } else {
530 let right = Value::resolve(&call.args[1], ctx)?;
531 Value::Bool(right.to_bool())
532 }
533 .into();
534 }
535 operators::INDEX => {
536 let value = left;
537 let idx = Value::resolve(&call.args[1], ctx)?;
538 return match (value, idx) {
539 (Value::List(items), Value::Int(idx)) => items
540 .get(idx as usize)
541 .cloned()
542 .unwrap_or(Value::Null)
543 .into(),
544 (Value::String(str), Value::Int(idx)) => {
545 match str.get(idx as usize..(idx + 1) as usize) {
546 None => Ok(Value::Null),
547 Some(str) => Ok(Value::String(str.to_string().into())),
548 }
549 }
550 (Value::Map(map), Value::String(property)) => map
551 .get(&property.into())
552 .cloned()
553 .unwrap_or(Value::Null)
554 .into(),
555 (Value::Map(map), Value::Bool(property)) => map
556 .get(&property.into())
557 .cloned()
558 .unwrap_or(Value::Null)
559 .into(),
560 (Value::Map(map), Value::Int(property)) => map
561 .get(&property.into())
562 .cloned()
563 .unwrap_or(Value::Null)
564 .into(),
565 (Value::Map(map), Value::UInt(property)) => map
566 .get(&property.into())
567 .cloned()
568 .unwrap_or(Value::Null)
569 .into(),
570 (Value::Map(_), index) => {
571 Err(ExecutionError::UnsupportedMapIndex(index))
572 }
573 (Value::List(_), index) => {
574 Err(ExecutionError::UnsupportedListIndex(index))
575 }
576 (value, index) => {
577 Err(ExecutionError::UnsupportedIndex(value, index))
578 }
579 };
580 }
581 _ => (),
582 }
583 }
584 if call.args.len() == 1 {
585 let expr = Value::resolve(&call.args[0], ctx)?;
586 match call.func_name.as_str() {
587 operators::LOGICAL_NOT => return Ok(Value::Bool(!expr.to_bool())),
588 operators::NEGATE => {
589 return match expr {
590 Value::Int(i) => Ok(Value::Int(-i)),
591 Value::Float(f) => Ok(Value::Float(-f)),
592 value => {
593 Err(ExecutionError::UnsupportedUnaryOperator("minus", value))
594 }
595 }
596 }
597 operators::NOT_STRICTLY_FALSE => {
598 return match expr {
599 Value::Bool(b) => Ok(Value::Bool(b)),
600 _ => Ok(Value::Bool(true)),
601 }
602 }
603 _ => (),
604 }
605 }
606 let func = ctx.get_function(call.func_name.as_str()).ok_or_else(|| {
607 ExecutionError::UndeclaredReference(call.func_name.clone().into())
608 })?;
609 match &call.target {
610 None => {
611 let mut ctx = FunctionContext::new(
612 call.func_name.clone().into(),
613 None,
614 ctx,
615 call.args.clone(),
616 );
617 (func)(&mut ctx)
618 }
619 Some(target) => {
620 let mut ctx = FunctionContext::new(
621 call.func_name.clone().into(),
622 Some(Value::resolve(target, ctx)?),
623 ctx,
624 call.args.clone(),
625 );
626 (func)(&mut ctx)
627 }
628 }
629 }
630 Expr::Ident(name) => ctx.get_variable(name),
631 Expr::Select(select) => {
632 let left = Value::resolve(select.operand.deref(), ctx)?;
633 if select.test {
634 match &left {
635 Value::Map(map) => {
636 for key in map.map.deref().keys() {
637 if key.to_string().eq(&select.field) {
638 return Ok(Value::Bool(true));
639 }
640 }
641 Ok(Value::Bool(false))
642 }
643 _ => Ok(Value::Bool(false)),
644 }
645 } else {
646 left.member(&select.field, ctx)
647 }
648 }
649 Expr::List(list_expr) => {
650 let list = list_expr
651 .elements
652 .iter()
653 .map(|i| Value::resolve(i, ctx))
654 .collect::<Result<Vec<_>, _>>()?;
655 Value::List(list.into()).into()
656 }
657 Expr::Map(map_expr) => {
658 let mut map = HashMap::with_capacity(map_expr.entries.len());
659 for entry in map_expr.entries.iter() {
660 let (k, v) = match &entry.expr {
661 EntryExpr::StructField(_) => panic!("WAT?"),
662 EntryExpr::MapEntry(e) => (&e.key, &e.value),
663 };
664 let key = Value::resolve(k, ctx)?
665 .try_into()
666 .map_err(ExecutionError::UnsupportedKeyType)?;
667 let value = Value::resolve(v, ctx)?;
668 map.insert(key, value);
669 }
670 Ok(Value::Map(Map {
671 map: Arc::from(map),
672 }))
673 }
674 Expr::Comprehension(comprehension) => {
675 let accu_init = Value::resolve(comprehension.accu_init.deref(), ctx)?;
676 let iter = Value::resolve(comprehension.iter_range.deref(), ctx)?;
677 let mut ctx = ctx.new_inner_scope();
678 ctx.add_variable(&comprehension.accu_var, accu_init)
679 .expect("Failed to add accu variable");
680
681 match iter {
682 Value::List(items) => {
683 for item in items.deref() {
684 if !Value::resolve(&comprehension.loop_cond, &ctx)?.to_bool() {
685 break;
686 }
687 ctx.add_variable_from_value(&comprehension.iter_var, item.clone());
688 let accu = Value::resolve(comprehension.loop_step.deref(), &ctx)?;
689 ctx.add_variable_from_value(&comprehension.accu_var, accu);
690 }
691 }
692 Value::Map(map) => {
693 for key in map.map.deref().keys() {
694 if !Value::resolve(&comprehension.loop_cond, &ctx)?.to_bool() {
695 break;
696 }
697 ctx.add_variable_from_value(&comprehension.iter_var, key.clone());
698 let accu = Value::resolve(comprehension.loop_step.deref(), &ctx)?;
699 ctx.add_variable_from_value(&comprehension.accu_var, accu);
700 }
701 }
702 t => todo!("Support {t:?}"),
703 }
704 Value::resolve(comprehension.result.deref(), &ctx)
705 }
706 Expr::Struct(_) => todo!("Support structs!"),
707 Expr::Unspecified => panic!("Can't evaluate Unspecified Expr"),
708 }
709 }
710
711 fn member(self, name: &str, ctx: &Context) -> ResolveResult {
720 let name: Arc<String> = name.to_owned().into();
723
724 let child = match self {
727 Value::Map(ref m) => m.map.get(&name.clone().into()).cloned(),
728 _ => None,
729 };
730
731 if let Some(child) = child {
735 child.into()
736 } else if ctx.has_function(&name) {
737 Value::Function(name.clone(), Some(self.into())).into()
738 } else {
739 ExecutionError::NoSuchKey(name.clone()).into()
740 }
741 }
742
743 #[inline(always)]
744 fn to_bool(&self) -> bool {
745 match self {
746 Value::List(v) => !v.is_empty(),
747 Value::Map(v) => !v.map.is_empty(),
748 Value::Int(v) => *v != 0,
749 Value::UInt(v) => *v != 0,
750 Value::Float(v) => *v != 0.0,
751 Value::String(v) => !v.is_empty(),
752 Value::Bytes(v) => !v.is_empty(),
753 Value::Bool(v) => *v,
754 Value::Null => false,
755 #[cfg(feature = "chrono")]
756 Value::Duration(v) => v.num_nanoseconds().map(|n| n != 0).unwrap_or(false),
757 #[cfg(feature = "chrono")]
758 Value::Timestamp(v) => v.timestamp_nanos_opt().unwrap_or_default() > 0,
759 Value::Function(_, _) => false,
760 }
761 }
762}
763
764impl ops::Add<Value> for Value {
765 type Output = ResolveResult;
766
767 #[inline(always)]
768 fn add(self, rhs: Value) -> Self::Output {
769 match (self, rhs) {
770 (Value::Int(l), Value::Int(r)) => l
771 .checked_add(r)
772 .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
773 .map(Value::Int),
774
775 (Value::UInt(l), Value::UInt(r)) => l
776 .checked_add(r)
777 .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
778 .map(Value::UInt),
779
780 (Value::Float(l), Value::Float(r)) => Value::Float(l + r).into(),
781
782 (Value::List(mut l), Value::List(mut r)) => {
783 {
784 let l = Arc::make_mut(&mut l);
787
788 match Arc::get_mut(&mut r) {
791 Some(r) => l.append(r),
792 None => l.extend(r.iter().cloned()),
793 }
794 }
795
796 Ok(Value::List(l))
797 }
798 (Value::String(mut l), Value::String(r)) => {
799 Arc::make_mut(&mut l).push_str(&r);
802 Ok(Value::String(l))
803 }
804 #[cfg(feature = "chrono")]
805 (Value::Duration(l), Value::Duration(r)) => l
806 .checked_add(&r)
807 .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
808 .map(Value::Duration),
809 #[cfg(feature = "chrono")]
810 (Value::Timestamp(l), Value::Duration(r)) => l
811 .checked_add_signed(r)
812 .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
813 .map(Value::Timestamp),
814 #[cfg(feature = "chrono")]
815 (Value::Duration(l), Value::Timestamp(r)) => r
816 .checked_add_signed(l)
817 .ok_or(ExecutionError::Overflow("add", l.into(), r.into()))
818 .map(Value::Timestamp),
819 (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
820 "add", left, right,
821 )),
822 }
823 }
824}
825
826impl ops::Sub<Value> for Value {
827 type Output = ResolveResult;
828
829 #[inline(always)]
830 fn sub(self, rhs: Value) -> Self::Output {
831 match (self, rhs) {
832 (Value::Int(l), Value::Int(r)) => l
833 .checked_sub(r)
834 .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
835 .map(Value::Int),
836
837 (Value::UInt(l), Value::UInt(r)) => l
838 .checked_sub(r)
839 .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
840 .map(Value::UInt),
841
842 (Value::Float(l), Value::Float(r)) => Value::Float(l - r).into(),
843
844 #[cfg(feature = "chrono")]
845 (Value::Duration(l), Value::Duration(r)) => l
846 .checked_sub(&r)
847 .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
848 .map(Value::Duration),
849 #[cfg(feature = "chrono")]
850 (Value::Timestamp(l), Value::Duration(r)) => l
851 .checked_sub_signed(r)
852 .ok_or(ExecutionError::Overflow("sub", l.into(), r.into()))
853 .map(Value::Timestamp),
854 #[cfg(feature = "chrono")]
855 (Value::Timestamp(l), Value::Timestamp(r)) => {
856 Value::Duration(l.signed_duration_since(r)).into()
857 }
858 (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
859 "sub", left, right,
860 )),
861 }
862 }
863}
864
865impl ops::Div<Value> for Value {
866 type Output = ResolveResult;
867
868 #[inline(always)]
869 fn div(self, rhs: Value) -> Self::Output {
870 match (self, rhs) {
871 (Value::Int(l), Value::Int(r)) => {
872 if r == 0 {
873 Err(ExecutionError::DivisionByZero(l.into()))
874 } else {
875 l.checked_div(r)
876 .ok_or(ExecutionError::Overflow("div", l.into(), r.into()))
877 .map(Value::Int)
878 }
879 }
880
881 (Value::UInt(l), Value::UInt(r)) => l
882 .checked_div(r)
883 .ok_or(ExecutionError::DivisionByZero(l.into()))
884 .map(Value::UInt),
885
886 (Value::Float(l), Value::Float(r)) => Value::Float(l / r).into(),
887
888 (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
889 "div", left, right,
890 )),
891 }
892 }
893}
894
895impl ops::Mul<Value> for Value {
896 type Output = ResolveResult;
897
898 #[inline(always)]
899 fn mul(self, rhs: Value) -> Self::Output {
900 match (self, rhs) {
901 (Value::Int(l), Value::Int(r)) => l
902 .checked_mul(r)
903 .ok_or(ExecutionError::Overflow("mul", l.into(), r.into()))
904 .map(Value::Int),
905
906 (Value::UInt(l), Value::UInt(r)) => l
907 .checked_mul(r)
908 .ok_or(ExecutionError::Overflow("mul", l.into(), r.into()))
909 .map(Value::UInt),
910
911 (Value::Float(l), Value::Float(r)) => Value::Float(l * r).into(),
912
913 (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
914 "mul", left, right,
915 )),
916 }
917 }
918}
919
920impl ops::Rem<Value> for Value {
921 type Output = ResolveResult;
922
923 #[inline(always)]
924 fn rem(self, rhs: Value) -> Self::Output {
925 match (self, rhs) {
926 (Value::Int(l), Value::Int(r)) => {
927 if r == 0 {
928 Err(ExecutionError::RemainderByZero(l.into()))
929 } else {
930 l.checked_rem(r)
931 .ok_or(ExecutionError::Overflow("rem", l.into(), r.into()))
932 .map(Value::Int)
933 }
934 }
935
936 (Value::UInt(l), Value::UInt(r)) => l
937 .checked_rem(r)
938 .ok_or(ExecutionError::RemainderByZero(l.into()))
939 .map(Value::UInt),
940
941 (left, right) => Err(ExecutionError::UnsupportedBinaryOperator(
942 "rem", left, right,
943 )),
944 }
945 }
946}
947
#[cfg(test)]
mod tests {
    use crate::{objects::Key, Context, ExecutionError, Program, Value};
    use std::collections::HashMap;
    use std::sync::Arc;

    /// Compiles `source` (panicking if it does not parse) and evaluates it in
    /// `context`.
    fn eval(source: &str, context: &Context) -> Result<Value, ExecutionError> {
        Program::compile(source).unwrap().execute(context)
    }

    /// Asserts that `source` fails in an empty context with exactly `expected`.
    fn assert_execution_error(source: &str, expected: ExecutionError) {
        let result = eval(source, &Context::default());
        assert_eq!(result.unwrap_err(), expected);
    }

    #[test]
    fn test_indexed_map_access() {
        let mut context = Context::default();
        let headers = HashMap::from([("Content-Type", "application/json".to_string())]);
        context.add_variable_from_value("headers", headers);

        let value = eval("headers[\"Content-Type\"]", &context).unwrap();
        assert_eq!(value, "application/json".into());
    }

    #[test]
    fn test_numeric_map_access() {
        let mut context = Context::default();
        let numbers = HashMap::from([(Key::Uint(1), "one".to_string())]);
        context.add_variable_from_value("numbers", numbers);

        // `1` is an Int literal; the lookup has to fall back to the Uint key.
        let value = eval("numbers[1]", &context).unwrap();
        assert_eq!(value, "one".into());
    }

    #[test]
    fn test_heterogeneous_compare() {
        let context = Context::default();

        assert_eq!(eval("1 < uint(2)", &context).unwrap(), true.into());
        assert_eq!(eval("1 < 1.1", &context).unwrap(), true.into());
        assert_eq!(
            eval("uint(0) > -10", &context).unwrap(),
            true.into(),
            "negative signed ints should be less than uints"
        );
    }

    #[test]
    fn test_float_compare() {
        let context = Context::default();

        assert_eq!(eval("1.0 > 0.0", &context).unwrap(), true.into());
        assert_eq!(
            eval("double('NaN') == double('NaN')", &context).unwrap(),
            false.into(),
            "NaN should not equal itself"
        );
        assert!(
            eval("1.0 > double('NaN')", &context).is_err(),
            "NaN should not be comparable with inequality operators"
        );
    }

    #[test]
    fn test_invalid_compare() {
        let value = eval("{} == []", &Context::default()).unwrap();
        assert_eq!(value, false.into());
    }

    #[test]
    fn test_size_fn_var() {
        let mut context = Context::default();
        let requests = vec![Value::Int(42), Value::Int(42)];
        context
            .add_variable("requests", Value::List(Arc::new(requests)))
            .unwrap();
        context.add_variable("size", Value::Int(3)).unwrap();

        let value = eval("size(requests) + size == 5", &context).unwrap();
        assert_eq!(value, Value::Bool(true));
    }

    #[test]
    fn test_invalid_sub() {
        assert_execution_error(
            "'foo' - 10",
            ExecutionError::UnsupportedBinaryOperator("sub", "foo".into(), Value::Int(10)),
        );
    }

    #[test]
    fn test_invalid_add() {
        assert_execution_error(
            "'foo' + 10",
            ExecutionError::UnsupportedBinaryOperator("add", "foo".into(), Value::Int(10)),
        );
    }

    #[test]
    fn test_invalid_div() {
        assert_execution_error(
            "'foo' / 10",
            ExecutionError::UnsupportedBinaryOperator("div", "foo".into(), Value::Int(10)),
        );
    }

    #[test]
    fn test_invalid_rem() {
        assert_execution_error(
            "'foo' % 10",
            ExecutionError::UnsupportedBinaryOperator("rem", "foo".into(), Value::Int(10)),
        );
    }

    #[test]
    fn out_of_bound_list_access() {
        let mut context = Context::default();
        context
            .add_variable("list", Value::List(Arc::new(vec![])))
            .unwrap();

        assert_eq!(eval("list[10]", &context).unwrap(), Value::Null);
    }

    #[test]
    fn reference_to_value() {
        let text = "example".to_string();
        let expected = || Value::String(Arc::new(String::from("example")));

        let direct: Value = text.as_str().into();
        assert_eq!(direct, expected());

        let indirect: Value = vec![text.as_str()].into();
        assert_eq!(indirect, Value::List(Arc::new(vec![expected()])));
    }

    #[test]
    fn test_short_circuit_and() {
        let mut context = Context::default();
        let data: HashMap<String, String> = HashMap::new();
        context.add_variable_from_value("data", data);

        let value = eval("has(data.x) && data.x.startsWith(\"foo\")", &context);
        println!("{value:?}");
        assert!(
            value.is_ok(),
            "The AND expression should support short-circuit evaluation."
        );
    }

    #[test]
    fn invalid_int_math() {
        use ExecutionError::*;

        let cases = [
            ("1 / 0", DivisionByZero(1.into())),
            ("1 % 0", RemainderByZero(1.into())),
            (
                &format!("{} + 1", i64::MAX),
                Overflow("add", i64::MAX.into(), 1.into()),
            ),
            (
                &format!("{} - 1", i64::MIN),
                Overflow("sub", i64::MIN.into(), 1.into()),
            ),
            (
                &format!("{} * 2", i64::MAX),
                Overflow("mul", i64::MAX.into(), 2.into()),
            ),
            (
                &format!("{} / -1", i64::MIN),
                Overflow("div", i64::MIN.into(), (-1).into()),
            ),
            (
                &format!("{} % -1", i64::MIN),
                Overflow("rem", i64::MIN.into(), (-1).into()),
            ),
        ];

        for (source, expected) in cases {
            assert_execution_error(source, expected);
        }
    }

    #[test]
    fn invalid_uint_math() {
        use ExecutionError::*;

        let cases = [
            ("1u / 0u", DivisionByZero(1u64.into())),
            ("1u % 0u", RemainderByZero(1u64.into())),
            (
                &format!("{}u + 1u", u64::MAX),
                Overflow("add", u64::MAX.into(), 1u64.into()),
            ),
            ("0u - 1u", Overflow("sub", 0u64.into(), 1u64.into())),
            (
                &format!("{}u * 2u", u64::MAX),
                Overflow("mul", u64::MAX.into(), 2u64.into()),
            ),
        ];

        for (source, expected) in cases {
            assert_execution_error(source, expected);
        }
    }
}