1use crate::context::Context;
2use crate::magic::{Arguments, This};
3use crate::objects::Value;
4use crate::resolvers::Resolver;
5use crate::ExecutionError;
6use cel_parser::Expression;
7use std::cmp::Ordering;
8use std::convert::TryInto;
9use std::sync::Arc;
10
11type Result<T> = std::result::Result<T, ExecutionError>;
12
/// Per-call state handed to the native implementation of a CEL function.
#[derive(Clone)]
pub struct FunctionContext<'context> {
    /// Name of the function being invoked; used as the function name of errors
    /// built by `FunctionContext::error`.
    pub name: Arc<String>,
    /// The receiver value, if any. NOTE(review): presumably `Some` when the
    /// function is called as a method (`target.fn(...)`) and `None` for a plain
    /// call — confirm against the interpreter and the `This` extractor.
    pub this: Option<Value>,
    /// The evaluation context the call is executing in.
    pub ptx: &'context Context<'context>,
    /// The call's arguments as parsed expressions (not yet evaluated values).
    pub args: Vec<Expression>,
    /// Argument index; starts at 0 in `new`. NOTE(review): presumably advanced by
    /// the argument extractors as they consume `args` — verify in `crate::magic`.
    pub arg_idx: usize,
}
26
27impl<'context> FunctionContext<'context> {
28 pub fn new(
29 name: Arc<String>,
30 this: Option<Value>,
31 ptx: &'context Context<'context>,
32 args: Vec<Expression>,
33 ) -> Self {
34 Self {
35 name,
36 this,
37 ptx,
38 args,
39 arg_idx: 0,
40 }
41 }
42
43 pub fn resolve<R>(&self, resolver: R) -> Result<Value>
45 where
46 R: Resolver,
47 {
48 resolver.resolve(self)
49 }
50
51 pub fn error<M: ToString>(&self, message: M) -> ExecutionError {
53 ExecutionError::function_error(self.name.as_str(), message)
54 }
55}
56
57pub fn size(ftx: &FunctionContext, This(this): This<Value>) -> Result<i64> {
77 let size = match this {
78 Value::List(l) => l.len(),
79 Value::Map(m) => m.map.len(),
80 Value::String(s) => s.len(),
81 Value::Bytes(b) => b.len(),
82 value => return Err(ftx.error(format!("cannot determine the size of {value:?}"))),
83 };
84 Ok(size as i64)
85}
86
87pub fn contains(This(this): This<Value>, arg: Value) -> Result<Value> {
118 Ok(match this {
119 Value::List(v) => v.contains(&arg),
120 Value::Map(v) => v
121 .map
122 .contains_key(&arg.try_into().map_err(ExecutionError::UnsupportedKeyType)?),
123 Value::String(s) => {
124 if let Value::String(arg) = arg {
125 s.contains(arg.as_str())
126 } else {
127 false
128 }
129 }
130 Value::Bytes(b) => {
131 if let Value::Bytes(arg) = arg {
132 let s = arg.as_slice();
133 b.windows(arg.len()).any(|w| w == s)
134 } else {
135 false
136 }
137 }
138 _ => false,
139 }
140 .into())
141}
142
143pub fn string(ftx: &FunctionContext, This(this): This<Value>) -> Result<Value> {
153 Ok(match this {
154 Value::String(v) => Value::String(v.clone()),
155 #[cfg(feature = "chrono")]
156 Value::Timestamp(t) => Value::String(t.to_rfc3339().into()),
157 #[cfg(feature = "chrono")]
158 Value::Duration(v) => Value::String(crate::duration::format_duration(&v).into()),
159 Value::Int(v) => Value::String(v.to_string().into()),
160 Value::UInt(v) => Value::String(v.to_string().into()),
161 Value::Float(v) => Value::String(v.to_string().into()),
162 Value::Bytes(v) => Value::String(Arc::new(String::from_utf8_lossy(v.as_slice()).into())),
163 v => return Err(ftx.error(format!("cannot convert {v:?} to string"))),
164 })
165}
166
167pub fn bytes(value: Arc<String>) -> Result<Value> {
168 Ok(Value::Bytes(value.as_bytes().to_vec().into()))
169}
170
171pub fn double(ftx: &FunctionContext, This(this): This<Value>) -> Result<Value> {
173 Ok(match this {
174 Value::String(v) => v
175 .parse::<f64>()
176 .map(Value::Float)
177 .map_err(|e| ftx.error(format!("string parse error: {e}")))?,
178 Value::Float(v) => Value::Float(v),
179 Value::Int(v) => Value::Float(v as f64),
180 Value::UInt(v) => Value::Float(v as f64),
181 v => return Err(ftx.error(format!("cannot convert {v:?} to double"))),
182 })
183}
184
185pub fn uint(ftx: &FunctionContext, This(this): This<Value>) -> Result<Value> {
187 Ok(match this {
188 Value::String(v) => v
189 .parse::<u64>()
190 .map(Value::UInt)
191 .map_err(|e| ftx.error(format!("string parse error: {e}")))?,
192 Value::Float(v) => {
193 if v > u64::MAX as f64 || v < u64::MIN as f64 {
194 return Err(ftx.error("unsigned integer overflow"));
195 }
196 Value::UInt(v as u64)
197 }
198 Value::Int(v) => Value::UInt(
199 v.try_into()
200 .map_err(|_| ftx.error("unsigned integer overflow"))?,
201 ),
202 Value::UInt(v) => Value::UInt(v),
203 v => return Err(ftx.error(format!("cannot convert {v:?} to uint"))),
204 })
205}
206
207pub fn int(ftx: &FunctionContext, This(this): This<Value>) -> Result<Value> {
209 Ok(match this {
210 Value::String(v) => v
211 .parse::<i64>()
212 .map(Value::Int)
213 .map_err(|e| ftx.error(format!("string parse error: {e}")))?,
214 Value::Float(v) => {
215 if v > i64::MAX as f64 || v < i64::MIN as f64 {
216 return Err(ftx.error("integer overflow"));
217 }
218 Value::Int(v as i64)
219 }
220 Value::Int(v) => Value::Int(v),
221 Value::UInt(v) => Value::Int(v.try_into().map_err(|_| ftx.error("integer overflow"))?),
222 v => return Err(ftx.error(format!("cannot convert {v:?} to int"))),
223 })
224}
225
226pub fn starts_with(This(this): This<Arc<String>>, prefix: Arc<String>) -> bool {
233 this.starts_with(prefix.as_str())
234}
235
236pub fn ends_with(This(this): This<Arc<String>>, suffix: Arc<String>) -> bool {
243 this.ends_with(suffix.as_str())
244}
245
/// Returns whether the receiver string matches the regular expression `regex`.
///
/// The pattern is compiled on every call.
///
/// # Errors
/// Returns a function error containing the pattern and the parser's message
/// when `regex` is not a valid regular expression.
#[cfg(feature = "regex")]
pub fn matches(
    ftx: &FunctionContext,
    This(this): This<Arc<String>>,
    regex: Arc<String>,
) -> Result<bool> {
    regex::Regex::new(&regex)
        .map(|re| re.is_match(&this))
        .map_err(|err| ftx.error(format!("'{regex}' not a valid regex:\n{err}")))
}
263
264#[cfg(feature = "chrono")]
265pub use time::duration;
266#[cfg(feature = "chrono")]
267pub use time::timestamp;
268
/// Timestamp and duration constructors and timestamp accessors (`chrono` feature).
#[cfg(feature = "chrono")]
pub mod time {
    use super::Result;
    use crate::magic::This;
    use crate::{ExecutionError, Value};
    use chrono::{Datelike, Timelike};
    use std::sync::Arc;

    /// Parses a duration string (e.g. `"1h30m"`) into a duration value.
    ///
    /// # Errors
    /// Returns a `duration` function error when the string cannot be parsed.
    pub fn duration(value: Arc<String>) -> Result<Value> {
        Ok(Value::Duration(_duration(value.as_str())?))
    }

    /// Parses an RFC 3339 string into a timestamp value, keeping its UTC offset.
    ///
    /// # Errors
    /// Returns a `timestamp` function error when the string is not valid RFC 3339.
    pub fn timestamp(value: Arc<String>) -> Result<Value> {
        Ok(Value::Timestamp(_timestamp(value.as_str())?))
    }

    fn _duration(i: &str) -> Result<chrono::Duration> {
        // NOTE(review): the first tuple element is discarded; it is presumably
        // the unparsed remainder, in which case trailing input is ignored —
        // confirm against `crate::duration::parse_duration`.
        let (_, duration) = crate::duration::parse_duration(i)
            .map_err(|e| ExecutionError::function_error("duration", e.to_string()))?;
        Ok(duration)
    }

    fn _timestamp(i: &str) -> Result<chrono::DateTime<chrono::FixedOffset>> {
        chrono::DateTime::parse_from_rfc3339(i)
            .map_err(|e| ExecutionError::function_error("timestamp", e.to_string()))
    }

    /// Calendar year of the timestamp, in its own offset.
    pub fn timestamp_year(
        This(this): This<chrono::DateTime<chrono::FixedOffset>>,
    ) -> Result<Value> {
        Ok(this.year().into())
    }

    /// Zero-based month of the year (January = 0).
    pub fn timestamp_month(
        This(this): This<chrono::DateTime<chrono::FixedOffset>>,
    ) -> Result<Value> {
        Ok((this.month0() as i32).into())
    }

    /// Zero-based day of the year (1 January = 0).
    pub fn timestamp_year_day(
        This(this): This<chrono::DateTime<chrono::FixedOffset>>,
    ) -> Result<Value> {
        // `ordinal0` gives the day-of-year directly, replacing date arithmetic
        // that needed two `unwrap()`s. Returned as i64, as the previous
        // `num_days()` computation was.
        Ok(i64::from(this.ordinal0()).into())
    }

    /// Zero-based day of the month (the 1st = 0).
    pub fn timestamp_month_day(
        This(this): This<chrono::DateTime<chrono::FixedOffset>>,
    ) -> Result<Value> {
        Ok((this.day0() as i32).into())
    }

    /// One-based day of the month (the 1st = 1).
    pub fn timestamp_date(
        This(this): This<chrono::DateTime<chrono::FixedOffset>>,
    ) -> Result<Value> {
        Ok((this.day() as i32).into())
    }

    /// Day of the week, counted from Sunday = 0.
    pub fn timestamp_weekday(
        This(this): This<chrono::DateTime<chrono::FixedOffset>>,
    ) -> Result<Value> {
        Ok((this.weekday().num_days_from_sunday() as i32).into())
    }

    /// Hour of the day (0-23).
    pub fn timestamp_hours(
        This(this): This<chrono::DateTime<chrono::FixedOffset>>,
    ) -> Result<Value> {
        Ok((this.hour() as i32).into())
    }

    /// Minute of the hour (0-59).
    pub fn timestamp_minutes(
        This(this): This<chrono::DateTime<chrono::FixedOffset>>,
    ) -> Result<Value> {
        Ok((this.minute() as i32).into())
    }

    /// Second of the minute.
    pub fn timestamp_seconds(
        This(this): This<chrono::DateTime<chrono::FixedOffset>>,
    ) -> Result<Value> {
        Ok((this.second() as i32).into())
    }

    /// Milliseconds within the current second.
    pub fn timestamp_millis(
        This(this): This<chrono::DateTime<chrono::FixedOffset>>,
    ) -> Result<Value> {
        Ok((this.timestamp_subsec_millis() as i32).into())
    }
}
383
384pub fn max(Arguments(args): Arguments) -> Result<Value> {
385 let items = if args.len() == 1 {
387 match &args[0] {
388 Value::List(values) => values,
389 _ => return Ok(args[0].clone()),
390 }
391 } else {
392 &args
393 };
394
395 items
396 .iter()
397 .skip(1)
398 .try_fold(items.first().unwrap_or(&Value::Null), |acc, x| {
399 match acc.partial_cmp(x) {
400 Some(Ordering::Greater) => Ok(acc),
401 Some(_) => Ok(x),
402 None => Err(ExecutionError::ValuesNotComparable(acc.clone(), x.clone())),
403 }
404 })
405 .cloned()
406}
407
408pub fn min(Arguments(args): Arguments) -> Result<Value> {
409 let items = if args.len() == 1 {
411 match &args[0] {
412 Value::List(values) => values,
413 _ => return Ok(args[0].clone()),
414 }
415 } else {
416 &args
417 };
418
419 items
420 .iter()
421 .skip(1)
422 .try_fold(items.first().unwrap_or(&Value::Null), |acc, x| {
423 match acc.partial_cmp(x) {
424 Some(Ordering::Less) => Ok(acc),
425 Some(_) => Ok(x),
426 None => Err(ExecutionError::ValuesNotComparable(acc.clone(), x.clone())),
427 }
428 })
429 .cloned()
430}
431
#[cfg(test)]
mod tests {
    use crate::context::Context;
    use crate::tests::test_script;

    /// Runs one `(description, script)` case and asserts the script evaluates to `true`.
    fn assert_script(input: &(&str, &str)) {
        assert_eq!(test_script(input.1, None), Ok(true.into()), "{}", input.0);
    }

    #[test]
    fn test_size() {
        [
            ("size of list", "size([1, 2, 3]) == 3"),
            ("size of map", "size({'a': 1, 'b': 2, 'c': 3}) == 3"),
            ("size of string", "size('foo') == 3"),
            ("size of bytes", "size(b'foo') == 3"),
            ("size as a list method", "[1, 2, 3].size() == 3"),
            ("size as a string method", "'foobar'.size() == 6"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_has() {
        let tests = [
            ("map has", "has(foo.bar) == true"),
            ("map not has", "has(foo.baz) == false"),
        ];

        for (name, script) in tests {
            let mut ctx = Context::default();
            ctx.add_variable_from_value("foo", std::collections::HashMap::from([("bar", 1)]));
            assert_eq!(test_script(script, Some(ctx)), Ok(true.into()), "{name}");
        }
    }

    #[test]
    fn test_map() {
        [
            ("map list", "[1, 2, 3].map(x, x * 2) == [2, 4, 6]"),
            ("map list 2", "[1, 2, 3].map(y, y + 1) == [2, 3, 4]"),
            (
                "map list filter",
                "[1, 2, 3].map(y, y % 2 == 0, y + 1) == [3]",
            ),
            (
                "nested map",
                "[[1, 2], [2, 3]].map(x, x.map(x, x * 2)) == [[2, 4], [4, 6]]",
            ),
            (
                "map to list",
                r#"{'John': 'smart'}.map(key, key) == ['John']"#,
            ),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_filter() {
        [("filter list", "[1, 2, 3].filter(x, x > 2) == [3]")]
            .iter()
            .for_each(assert_script);
    }

    #[test]
    fn test_all() {
        [
            ("all list #1", "[0, 1, 2].all(x, x >= 0)"),
            ("all list #2", "[0, 1, 2].all(x, x > 0) == false"),
            ("all map", "{0: 0, 1:1, 2:2}.all(x, x >= 0) == true"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_exists() {
        [
            ("exist list #1", "[0, 1, 2].exists(x, x > 0)"),
            ("exist list #2", "[0, 1, 2].exists(x, x == 3) == false"),
            ("exist list #3", "[0, 1, 2, 2].exists(x, x == 2)"),
            ("exist map", "{0: 0, 1:1, 2:2}.exists(x, x > 0)"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_exists_one() {
        [
            ("exist list #1", "[0, 1, 2].exists_one(x, x > 0) == false"),
            ("exist list #2", "[0, 1, 2].exists_one(x, x == 0)"),
            ("exist map", "{0: 0, 1:1, 2:2}.exists_one(x, x == 2)"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_max() {
        [
            ("max single", "max(1) == 1"),
            ("max multiple", "max(1, 2, 3) == 3"),
            ("max negative", "max(-1, 0) == 0"),
            ("max float", "max(-1.0, 0.0) == 0.0"),
            ("max list", "max([1, 2, 3]) == 3"),
            ("max empty list", "max([]) == null"),
            ("max no args", "max() == null"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_min() {
        [
            ("min single", "min(1) == 1"),
            ("min multiple", "min(1, 2, 3) == 1"),
            ("min negative", "min(-1, 0) == -1"),
            ("min float", "min(-1.0, 0.0) == -1.0"),
            (
                "min float multiple",
                "min(1.61803, 3.1415, 2.71828, 1.41421) == 1.41421",
            ),
            ("min list", "min([1, 2, 3]) == 1"),
            ("min empty list", "min([]) == null"),
            ("min no args", "min() == null"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_starts_with() {
        [
            ("starts with true", "'foobar'.startsWith('foo') == true"),
            ("starts with false", "'foobar'.startsWith('bar') == false"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_ends_with() {
        [
            ("ends with true", "'foobar'.endsWith('bar') == true"),
            ("ends with false", "'foobar'.endsWith('foo') == false"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn test_timestamp() {
        [
            (
                "comparison",
                "timestamp('2023-05-29T00:00:00Z') > timestamp('2023-05-28T00:00:00Z')",
            ),
            (
                "comparison",
                "timestamp('2023-05-29T00:00:00Z') < timestamp('2023-05-30T00:00:00Z')",
            ),
            (
                "subtracting duration",
                "timestamp('2023-05-29T00:00:00Z') - duration('24h') == timestamp('2023-05-28T00:00:00Z')",
            ),
            (
                "subtracting date",
                "timestamp('2023-05-29T00:00:00Z') - timestamp('2023-05-28T00:00:00Z') == duration('24h')",
            ),
            (
                "adding duration",
                "timestamp('2023-05-28T00:00:00Z') + duration('24h') == timestamp('2023-05-29T00:00:00Z')",
            ),
            (
                "timestamp string",
                "timestamp('2023-05-28T00:00:00Z').string() == '2023-05-28T00:00:00+00:00'",
            ),
            (
                "timestamp getFullYear",
                "timestamp('2023-05-28T00:00:00Z').getFullYear() == 2023",
            ),
            (
                "timestamp getMonth",
                "timestamp('2023-05-28T00:00:00Z').getMonth() == 4",
            ),
            (
                "timestamp getDayOfMonth",
                "timestamp('2023-05-28T00:00:00Z').getDayOfMonth() == 27",
            ),
            (
                "timestamp getDayOfYear",
                "timestamp('2023-05-28T00:00:00Z').getDayOfYear() == 147",
            ),
            (
                "timestamp getDate",
                "timestamp('2023-05-28T00:00:00Z').getDate() == 28",
            ),
            (
                "timestamp getDayOfWeek",
                "timestamp('2023-05-28T00:00:00Z').getDayOfWeek() == 0",
            ),
            (
                "timestamp getHours",
                "timestamp('2023-05-28T02:00:00Z').getHours() == 2",
            ),
            (
                "timestamp getMinutes",
                " timestamp('2023-05-28T00:05:00Z').getMinutes() == 5",
            ),
            (
                "timestamp getSeconds",
                "timestamp('2023-05-28T00:00:06Z').getSeconds() == 6",
            ),
            (
                "timestamp getMilliseconds",
                "timestamp('2023-05-28T00:00:42.123Z').getMilliseconds() == 123",
            ),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn test_duration() {
        [
            ("duration equal 1", "duration('1s') == duration('1000ms')"),
            ("duration equal 2", "duration('1m') == duration('60s')"),
            ("duration equal 3", "duration('1h') == duration('60m')"),
            ("duration comparison 1", "duration('1m') > duration('1s')"),
            ("duration comparison 2", "duration('1m') < duration('1h')"),
            (
                "duration subtraction",
                "duration('1h') - duration('1m') == duration('59m')",
            ),
            (
                "duration addition",
                "duration('1h') + duration('1m') == duration('1h1m')",
            ),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn test_timestamp_variable() {
        let mut context = Context::default();
        let ts: chrono::DateTime<chrono::FixedOffset> =
            chrono::DateTime::parse_from_rfc3339("2023-05-29T00:00:00Z").unwrap();
        context
            .add_variable("ts", crate::Value::Timestamp(ts))
            .unwrap();

        let program = crate::Program::compile("ts == timestamp('2023-05-29T00:00:00Z')").unwrap();
        let result = program.execute(&context).unwrap();
        assert_eq!(result, true.into());
    }

    #[cfg(feature = "chrono")]
    #[test]
    fn test_chrono_string() {
        [
            ("duration", "duration('1h30m').string() == '1h30m0s'"),
            (
                "timestamp",
                "timestamp('2023-05-29T00:00:00Z').string() == '2023-05-29T00:00:00+00:00'",
            ),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_contains() {
        let tests = [
            ("list", "[1, 2, 3].contains(3) == true"),
            ("map", "{1: true, 2: true, 3: true}.contains(3) == true"),
            ("string", "'foobar'.contains('bar') == true"),
            ("bytes", "b'foobar'.contains(b'o') == true"),
        ];

        for (name, script) in tests {
            assert_eq!(test_script(script, None), Ok(true.into()), "{name}");
        }
    }

    #[cfg(feature = "regex")]
    #[test]
    fn test_matches() {
        let tests = [
            ("string", "'foobar'.matches('^[a-zA-Z]*$') == true"),
            (
                "map",
                "{'1': 'abc', '2': 'def', '3': 'ghi'}.all(key, key.matches('^[a-zA-Z]*$')) == false",
            ),
        ];

        for (name, script) in tests {
            assert_eq!(
                test_script(script, None),
                Ok(true.into()),
                ".matches failed for '{name}'"
            );
        }
    }

    #[cfg(feature = "regex")]
    #[test]
    fn test_matches_err() {
        // Only the parts of the message produced by `matches` itself are pinned;
        // the regex crate's multi-line parse-error rendering (indentation, caret
        // placement) varies between versions and is not our contract.
        match test_script("'foobar'.matches('(foo') == true", None) {
            Err(crate::ExecutionError::FunctionError { function, message }) => {
                assert_eq!(function, "matches");
                assert!(
                    message.starts_with("'(foo' not a valid regex:\n"),
                    "unexpected message: {message}"
                );
                assert!(
                    message.contains("unclosed group"),
                    "unexpected message: {message}"
                );
            }
            other => panic!("expected a FunctionError from matches, got {other:?}"),
        }
    }

    #[test]
    fn test_string() {
        [
            ("string", "'foo'.string() == 'foo'"),
            ("int", "10.string() == '10'"),
            ("float", "10.5.string() == '10.5'"),
            ("bytes", "b'foo'.string() == 'foo'"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_bytes() {
        [
            ("string", "bytes('abc') == b'abc'"),
            ("bytes", "bytes('abc') == b'\\x61b\\x63'"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_double() {
        [
            ("string", "'10'.double() == 10.0"),
            ("int", "10.double() == 10.0"),
            ("double", "10.0.double() == 10.0"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_uint() {
        [
            ("string", "'10'.uint() == 10.uint()"),
            ("double", "10.5.uint() == 10.uint()"),
        ]
        .iter()
        .for_each(assert_script);
    }

    #[test]
    fn test_int() {
        [
            ("string", "'10'.int() == 10"),
            ("int", "10.int() == 10"),
            ("uint", "10.uint().int() == 10"),
            ("double", "10.5.int() == 10"),
        ]
        .iter()
        .for_each(assert_script);
    }
}