cel_interpreter/
ser.rs

// The serde_json crate implements a Serializer for its own Value enum, which is
// almost exactly the same as our Value enum, so this is more or less copied
// from [serde_json](https://github.com/serde-rs/json/blob/master/src/value/ser.rs),
// also mentioned in the [serde documentation](https://serde.rs/).
5
6use crate::{objects::Key, Value};
7use serde::{
8    ser::{self, Impossible, SerializeStruct},
9    Serialize,
10};
11use std::{collections::HashMap, fmt::Display, iter::FromIterator, sync::Arc};
12use thiserror::Error;
13
14#[cfg(feature = "chrono")]
15use chrono::FixedOffset;
16
/// The [`serde::Serializer`] that turns any `Serialize` value into a CEL
/// [`Value`]; this is what [`to_value`] drives.
pub struct Serializer;
/// A [`serde::Serializer`] for map keys, producing CEL [`Key`]s.
///
/// Only booleans, integers, chars/strings and unit enum variants (optionally
/// wrapped in `Some` or a newtype struct) are accepted; every other shape is
/// rejected with [`SerializationError::InvalidKey`].
pub struct KeySerializer;
19
/// A wrapper Duration type which allows conversion to [Value::Duration] for
/// types using automatic conversion with [serde::Serialize].
///
/// With this crate's serializer the wrapper becomes a CEL duration. Any other
/// serde serializer sees a newtype struct (under a private marker name) that
/// wraps a `Duration { secs, nanos }` struct built from
/// `chrono::Duration::num_seconds` and `chrono::Duration::subsec_nanos`.
///
/// # Examples
///
/// ```
/// use cel_interpreter::{Context, Duration, Program};
/// use serde::Serialize;
///
/// #[derive(Serialize)]
/// struct MyStruct {
///     dur: Duration,
/// }
///
/// let mut context = Context::default();
///
/// // MyStruct will be implicitly serialized into the CEL appropriate types
/// context
///     .add_variable(
///         "foo",
///         MyStruct {
///             dur: chrono::Duration::hours(2).into(),
///         },
///     )
///     .unwrap();
///
/// let program = Program::compile("foo.dur == duration('2h')").unwrap();
/// let value = program.execute(&context).unwrap();
/// assert_eq!(value, true.into());
/// ```
#[cfg(feature = "chrono")]
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
pub struct Duration(pub chrono::Duration);
53
#[cfg(feature = "chrono")]
impl Duration {
    // serde has no native duration type. Wrapping the value in a newtype struct
    // carrying this private name lets our `Serializer` spot it and rebuild a
    // duration, while other serializers just see an ordinary newtype struct.
    const NAME: &'static str = "$__cel_private_Duration";

    // Name and field names of the inner proxy struct, modelled on serde's
    // representation of `std::time::Duration`.
    const STRUCT_NAME: &'static str = "Duration";
    const SECS_FIELD: &'static str = "secs";
    const NANOS_FIELD: &'static str = "nanos";
}
64
#[cfg(feature = "chrono")]
impl From<Duration> for chrono::Duration {
    /// Unwraps the inner [`chrono::Duration`].
    fn from(Duration(inner): Duration) -> Self {
        inner
    }
}
71
#[cfg(feature = "chrono")]
impl From<chrono::Duration> for Duration {
    /// Wraps a [`chrono::Duration`] so it serializes as a CEL duration.
    fn from(inner: chrono::Duration) -> Self {
        Duration(inner)
    }
}
78
#[cfg(feature = "chrono")]
impl ser::Serialize for Duration {
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        /// Borrowed view of the inner duration, serialized as a two-field
        /// `Duration { secs, nanos }` struct.
        struct Parts<'a>(&'a chrono::Duration);

        impl Serialize for Parts<'_> {
            fn serialize<S: ser::Serializer>(
                &self,
                serializer: S,
            ) -> std::result::Result<S::Ok, S::Error> {
                let inner = self.0;
                let mut fields = serializer.serialize_struct(Duration::STRUCT_NAME, 2)?;
                fields.serialize_field(Duration::SECS_FIELD, &inner.num_seconds())?;
                fields.serialize_field(Duration::NANOS_FIELD, &inner.subsec_nanos())?;
                fields.end()
            }
        }

        // chrono's own Serialize impl for its Duration is not stable and relies
        // on private fields, so emit a shape that mimics serde's impl for
        // `std::time::Duration`, wrapped in the private marker newtype.
        serializer.serialize_newtype_struct(Self::NAME, &Parts(&self.0))
    }
}
103
/// A wrapper Timestamp type which allows conversion to [Value::Timestamp] for
/// types using automatic conversion with [serde::Serialize].
///
/// With this crate's serializer the wrapper becomes a CEL timestamp. Any other
/// serde serializer sees a newtype struct (under a private marker name) that
/// wraps whatever chrono's `DateTime` `Serialize` impl produces. This crate's
/// serializer expects that to be a string it can parse back into a
/// `DateTime<FixedOffset>`.
///
/// # Examples
///
/// ```
/// use cel_interpreter::{Context, Timestamp, Program};
/// use serde::Serialize;
///
/// #[derive(Serialize)]
/// struct MyStruct {
///     ts: Timestamp,
/// }
///
/// let mut context = Context::default();
///
/// // MyStruct will be implicitly serialized into the CEL appropriate types
/// context
///     .add_variable(
///         "foo",
///         MyStruct {
///             ts: chrono::DateTime::parse_from_rfc3339("2025-01-01T00:00:00Z")
///                 .unwrap()
///                 .into(),
///         },
///     )
///     .unwrap();
///
/// let program = Program::compile("foo.ts == timestamp('2025-01-01T00:00:00Z')").unwrap();
/// let value = program.execute(&context).unwrap();
/// assert_eq!(value, true.into());
/// ```
#[cfg(feature = "chrono")]
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Debug, Hash)]
pub struct Timestamp(pub chrono::DateTime<FixedOffset>);
139
#[cfg(feature = "chrono")]
impl Timestamp {
    // serde has no native timestamp type. Wrapping the value in a newtype
    // struct carrying this private name lets our `Serializer` spot it and
    // rebuild a timestamp, while other serializers just see an ordinary
    // newtype struct.
    const NAME: &'static str = "$__cel_private_Timestamp";
}
147
#[cfg(feature = "chrono")]
impl From<Timestamp> for chrono::DateTime<FixedOffset> {
    /// Unwraps the inner `DateTime<FixedOffset>`.
    fn from(Timestamp(inner): Timestamp) -> Self {
        inner
    }
}
154
#[cfg(feature = "chrono")]
impl From<chrono::DateTime<FixedOffset>> for Timestamp {
    /// Wraps a `DateTime<FixedOffset>` so it serializes as a CEL timestamp.
    fn from(inner: chrono::DateTime<FixedOffset>) -> Self {
        Timestamp(inner)
    }
}
161
#[cfg(feature = "chrono")]
impl ser::Serialize for Timestamp {
    /// Serializes the wrapped `DateTime` inside the private marker newtype so
    /// this crate's serializer can recognise it as a timestamp.
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: ser::Serializer,
    {
        let Timestamp(datetime) = self;
        serializer.serialize_newtype_struct(Timestamp::NAME, datetime)
    }
}
171
/// Errors produced while converting a value into a CEL [`Value`].
///
/// No `#[error]` attributes are used; `Display` is implemented by hand and
/// prints the inner message as-is.
#[derive(Error, Debug, PartialEq, Clone)]
pub enum SerializationError {
    /// A map key could not be represented as a CEL [`Key`], or a map value
    /// arrived without a preceding key.
    InvalidKey(String),
    /// Any other failure, including messages raised through
    /// [`serde::ser::Error::custom`].
    SerdeError(String),
}
177
178impl ser::Error for SerializationError {
179    fn custom<T>(msg: T) -> Self
180    where
181        T: std::fmt::Display,
182    {
183        SerializationError::SerdeError(msg.to_string())
184    }
185}
186
187impl Display for SerializationError {
188    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
189        match self {
190            SerializationError::SerdeError(msg) => formatter.write_str(msg),
191            SerializationError::InvalidKey(msg) => formatter.write_str(msg),
192        }
193    }
194}
195
196pub type Result<T> = std::result::Result<T, SerializationError>;
197
198pub fn to_value<T>(value: T) -> Result<Value>
199where
200    T: Serialize,
201{
202    value.serialize(Serializer)
203}
204
205impl ser::Serializer for Serializer {
206    type Ok = Value;
207    type Error = SerializationError;
208
209    type SerializeSeq = SerializeVec;
210    type SerializeTuple = SerializeVec;
211    type SerializeTupleStruct = SerializeVec;
212    type SerializeTupleVariant = SerializeTupleVariant;
213    type SerializeMap = SerializeMap;
214    type SerializeStruct = SerializeMap;
215    type SerializeStructVariant = SerializeStructVariant;
216
217    fn serialize_bool(self, v: bool) -> Result<Value> {
218        Ok(Value::Bool(v))
219    }
220
221    fn serialize_i8(self, v: i8) -> Result<Value> {
222        self.serialize_i64(i64::from(v))
223    }
224
225    fn serialize_i16(self, v: i16) -> Result<Value> {
226        self.serialize_i64(i64::from(v))
227    }
228
229    fn serialize_i32(self, v: i32) -> Result<Value> {
230        self.serialize_i64(i64::from(v))
231    }
232
233    fn serialize_i64(self, v: i64) -> Result<Value> {
234        Ok(Value::Int(v))
235    }
236
237    fn serialize_u8(self, v: u8) -> Result<Value> {
238        self.serialize_u64(u64::from(v))
239    }
240
241    fn serialize_u16(self, v: u16) -> Result<Value> {
242        self.serialize_u64(u64::from(v))
243    }
244
245    fn serialize_u32(self, v: u32) -> Result<Value> {
246        self.serialize_u64(u64::from(v))
247    }
248
249    fn serialize_u64(self, v: u64) -> Result<Value> {
250        Ok(Value::UInt(v))
251    }
252
253    fn serialize_f32(self, v: f32) -> Result<Value> {
254        self.serialize_f64(f64::from(v))
255    }
256
257    fn serialize_f64(self, v: f64) -> Result<Value> {
258        Ok(Value::Float(v))
259    }
260
261    fn serialize_char(self, v: char) -> Result<Value> {
262        self.serialize_str(&v.to_string())
263    }
264
265    fn serialize_str(self, v: &str) -> Result<Value> {
266        Ok(Value::String(Arc::new(v.to_string())))
267    }
268
269    fn serialize_bytes(self, v: &[u8]) -> Result<Value> {
270        Ok(Value::Bytes(Arc::new(v.to_vec())))
271    }
272
273    fn serialize_none(self) -> Result<Value> {
274        self.serialize_unit()
275    }
276
277    fn serialize_some<T>(self, value: &T) -> Result<Value>
278    where
279        T: ?Sized + Serialize,
280    {
281        value.serialize(self)
282    }
283
284    fn serialize_unit(self) -> Result<Value> {
285        Ok(Value::Null)
286    }
287
288    fn serialize_unit_struct(self, _name: &'static str) -> Result<Value> {
289        self.serialize_unit()
290    }
291
292    fn serialize_unit_variant(
293        self,
294        _name: &'static str,
295        _variant_index: u32,
296        variant: &'static str,
297    ) -> Result<Value> {
298        self.serialize_str(variant)
299    }
300
301    fn serialize_newtype_struct<T>(self, name: &'static str, value: &T) -> Result<Value>
302    where
303        T: ?Sized + Serialize,
304    {
305        match name {
306            #[cfg(feature = "chrono")]
307            Duration::NAME => value.serialize(TimeSerializer::Duration),
308            #[cfg(feature = "chrono")]
309            Timestamp::NAME => value.serialize(TimeSerializer::Timestamp),
310            _ => value.serialize(self),
311        }
312    }
313
314    fn serialize_newtype_variant<T>(
315        self,
316        _name: &'static str,
317        _variant_index: u32,
318        variant: &'static str,
319        value: &T,
320    ) -> Result<Value>
321    where
322        T: ?Sized + Serialize,
323    {
324        Ok(HashMap::from_iter([(variant.to_string(), value.serialize(Serializer)?)]).into())
325    }
326
327    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq> {
328        Ok(SerializeVec {
329            vec: Vec::with_capacity(_len.unwrap_or(0)),
330        })
331    }
332
333    fn serialize_tuple(self, len: usize) -> Result<Self::SerializeTuple> {
334        self.serialize_seq(Some(len))
335    }
336
337    fn serialize_tuple_struct(
338        self,
339        _name: &'static str,
340        len: usize,
341    ) -> Result<Self::SerializeTupleStruct> {
342        self.serialize_seq(Some(len))
343    }
344
345    fn serialize_tuple_variant(
346        self,
347        _name: &'static str,
348        _variant_index: u32,
349        variant: &'static str,
350        _len: usize,
351    ) -> Result<Self::SerializeTupleVariant> {
352        Ok(SerializeTupleVariant {
353            name: String::from(variant),
354            vec: Vec::with_capacity(_len),
355        })
356    }
357
358    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
359        Ok(SerializeMap {
360            map: HashMap::new(),
361            next_key: None,
362        })
363    }
364
365    fn serialize_struct(self, _name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
366        self.serialize_map(Some(len))
367    }
368
369    fn serialize_struct_variant(
370        self,
371        _name: &'static str,
372        _variant_index: u32,
373        variant: &'static str,
374        _len: usize,
375    ) -> Result<Self::SerializeStructVariant> {
376        Ok(SerializeStructVariant {
377            name: String::from(variant),
378            map: HashMap::new(),
379        })
380    }
381}
382
/// Collects the elements of a sequence, tuple or tuple struct; finishes as a
/// [`Value::List`].
pub struct SerializeVec {
    vec: Vec<Value>,
}

/// Collects the fields of a tuple enum variant. Finishes as a single-entry map
/// keyed by the variant name and holding the collected fields.
pub struct SerializeTupleVariant {
    // Variant name, used as the key of the resulting map.
    name: String,
    vec: Vec<Value>,
}

/// Collects the entries of a map or struct.
pub struct SerializeMap {
    map: HashMap<Key, Value>,
    // Key from the most recent `serialize_key` call, awaiting its value.
    next_key: Option<Key>,
}

/// Collects the fields of a struct enum variant. Finishes as a single-entry map
/// from the variant name to a map of the fields.
pub struct SerializeStructVariant {
    // Variant name, used as the key of the outer map.
    name: String,
    map: HashMap<Key, Value>,
}
401
/// Collects the `secs` and `nanos` fields of the proxy struct emitted by
/// [`Duration`]'s `Serialize` impl, then converts them into a CEL [`Value`]
/// via `chrono::Duration`.
///
/// Despite the name, this is only used for durations: `TimeSerializer` hands
/// it out solely for the `Duration` marker. Timestamps arrive as strings.
#[cfg(feature = "chrono")]
#[derive(Debug, Default)]
struct SerializeTimestamp {
    // Whole seconds, as written from `chrono::Duration::num_seconds`.
    secs: i64,
    // Sub-second part, as written from `chrono::Duration::subsec_nanos`.
    nanos: i32,
}
408
409impl ser::SerializeSeq for SerializeVec {
410    type Ok = Value;
411    type Error = SerializationError;
412
413    fn serialize_element<T>(&mut self, value: &T) -> Result<()>
414    where
415        T: ?Sized + Serialize,
416    {
417        self.vec.push(to_value(value)?);
418        Ok(())
419    }
420
421    fn end(self) -> Result<Value> {
422        Ok(Value::List(Arc::new(self.vec)))
423    }
424}
425
426impl ser::SerializeTuple for SerializeVec {
427    type Ok = Value;
428    type Error = SerializationError;
429
430    fn serialize_element<T>(&mut self, value: &T) -> Result<()>
431    where
432        T: ?Sized + Serialize,
433    {
434        serde::ser::SerializeSeq::serialize_element(self, value)
435    }
436
437    fn end(self) -> Result<Value> {
438        serde::ser::SerializeSeq::end(self)
439    }
440}
441
442impl ser::SerializeTupleStruct for SerializeVec {
443    type Ok = Value;
444    type Error = SerializationError;
445
446    fn serialize_field<T>(&mut self, value: &T) -> Result<()>
447    where
448        T: ?Sized + Serialize,
449    {
450        serde::ser::SerializeSeq::serialize_element(self, value)
451    }
452
453    fn end(self) -> Result<Value> {
454        serde::ser::SerializeSeq::end(self)
455    }
456}
457
458impl ser::SerializeTupleVariant for SerializeTupleVariant {
459    type Ok = Value;
460    type Error = SerializationError;
461
462    fn serialize_field<T>(&mut self, value: &T) -> Result<()>
463    where
464        T: ?Sized + Serialize,
465    {
466        self.vec.push(to_value(value)?);
467        Ok(())
468    }
469
470    fn end(self) -> Result<Value> {
471        let map = HashMap::from_iter([(self.name, Arc::new(self.vec))]);
472        Ok(map.into())
473    }
474}
475
476impl ser::SerializeMap for SerializeMap {
477    type Ok = Value;
478    type Error = SerializationError;
479
480    fn serialize_key<T>(&mut self, key: &T) -> Result<()>
481    where
482        T: ?Sized + Serialize,
483    {
484        self.next_key = Some(key.serialize(KeySerializer)?);
485        Ok(())
486    }
487
488    fn serialize_value<T>(&mut self, value: &T) -> Result<()>
489    where
490        T: ?Sized + Serialize,
491    {
492        self.map.insert(
493            self.next_key.clone().ok_or_else(|| {
494                SerializationError::InvalidKey(
495                    "serialize_value called before serialize_key".to_string(),
496                )
497            })?,
498            value.serialize(Serializer)?,
499        );
500        Ok(())
501    }
502
503    fn end(self) -> Result<Value> {
504        Ok(self.map.into())
505    }
506}
507
508impl ser::SerializeStruct for SerializeMap {
509    type Ok = Value;
510    type Error = SerializationError;
511
512    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
513    where
514        T: ?Sized + Serialize,
515    {
516        serde::ser::SerializeMap::serialize_entry(self, key, value)
517    }
518
519    fn end(self) -> Result<Value> {
520        serde::ser::SerializeMap::end(self)
521    }
522}
523
524impl ser::SerializeStructVariant for SerializeStructVariant {
525    type Ok = Value;
526    type Error = SerializationError;
527
528    fn serialize_field<T>(&mut self, key: &'static str, value: &T) -> Result<()>
529    where
530        T: ?Sized + Serialize,
531    {
532        self.map
533            .insert(key.serialize(KeySerializer)?, to_value(value)?);
534        Ok(())
535    }
536
537    fn end(self) -> Result<Value> {
538        let map: HashMap<String, Value> = HashMap::from_iter([(self.name, self.map.into())]);
539        Ok(map.into())
540    }
541}
542
/// Rebuilds a CEL duration from the `Duration { secs, nanos }` proxy struct
/// written by [`Duration`]'s `Serialize` impl.
#[cfg(feature = "chrono")]
impl ser::SerializeStruct for SerializeTimestamp {
    type Ok = Value;
    type Error = SerializationError;

    /// Stores the `secs` or `nanos` field.
    ///
    /// # Errors
    ///
    /// Returns [`SerializationError::SerdeError`] in any of these cases:
    /// - the field name is unknown;
    /// - the value is not an integer;
    /// - `nanos` does not fit in an `i32`.
    fn serialize_field<T>(
        &mut self,
        key: &'static str,
        value: &T,
    ) -> std::result::Result<(), Self::Error>
    where
        T: ?Sized + Serialize,
    {
        match key {
            Duration::SECS_FIELD => {
                let Value::Int(val) = value.serialize(Serializer)? else {
                    return Err(SerializationError::SerdeError(
                        "invalid type of value in duration struct".to_owned(),
                    ));
                };
                self.secs = val;
                Ok(())
            }
            Duration::NANOS_FIELD => {
                let Value::Int(val) = value.serialize(Serializer)? else {
                    return Err(SerializationError::SerdeError(
                        "invalid type of value in duration struct".to_owned(),
                    ));
                };
                self.nanos = val.try_into().map_err(|_| {
                    SerializationError::SerdeError(
                        "duration struct nanos field is invalid".to_owned(),
                    )
                })?;
                Ok(())
            }
            _ => Err(SerializationError::SerdeError(
                "invalid field in duration struct".to_owned(),
            )),
        }
    }

    /// Combines the collected fields into a CEL duration value.
    ///
    /// # Errors
    ///
    /// Returns [`SerializationError::SerdeError`] instead of panicking when the
    /// fields describe a duration that `chrono::Duration` cannot represent.
    fn end(self) -> std::result::Result<Self::Ok, Self::Error> {
        // chrono::Duration is bounded by about ±i64::MAX milliseconds, and
        // `chrono::Duration::seconds` panics outside that range. Reject
        // out-of-range seconds first so malformed input cannot panic.
        const MAX_SECS: i64 = i64::MAX / 1_000;
        if !(-MAX_SECS..=MAX_SECS).contains(&self.secs) {
            return Err(SerializationError::SerdeError(
                "duration struct secs field is out of range".to_owned(),
            ));
        }
        let duration = chrono::Duration::seconds(self.secs)
            .checked_add(&chrono::Duration::nanoseconds(self.nanos.into()))
            .ok_or_else(|| {
                SerializationError::SerdeError("duration struct is out of range".to_owned())
            })?;
        Ok(duration.into())
    }
}
591
592impl ser::Serializer for KeySerializer {
593    type Ok = Key;
594    type Error = SerializationError;
595
596    type SerializeSeq = Impossible<Key, SerializationError>;
597    type SerializeTuple = Impossible<Key, SerializationError>;
598    type SerializeTupleStruct = Impossible<Key, SerializationError>;
599    type SerializeTupleVariant = Impossible<Key, SerializationError>;
600    type SerializeMap = Impossible<Key, SerializationError>;
601    type SerializeStruct = Impossible<Key, SerializationError>;
602    type SerializeStructVariant = Impossible<Key, SerializationError>;
603
604    fn serialize_bool(self, v: bool) -> Result<Key> {
605        Ok(Key::Bool(v))
606    }
607
608    fn serialize_i8(self, v: i8) -> Result<Key> {
609        self.serialize_i64(i64::from(v))
610    }
611
612    fn serialize_i16(self, v: i16) -> Result<Key> {
613        self.serialize_i64(i64::from(v))
614    }
615
616    fn serialize_i32(self, v: i32) -> Result<Key> {
617        self.serialize_i64(i64::from(v))
618    }
619
620    fn serialize_i64(self, v: i64) -> Result<Key> {
621        Ok(Key::Int(v))
622    }
623
624    fn serialize_u8(self, v: u8) -> Result<Key> {
625        self.serialize_u64(u64::from(v))
626    }
627
628    fn serialize_u16(self, v: u16) -> Result<Key> {
629        self.serialize_u64(u64::from(v))
630    }
631
632    fn serialize_u32(self, v: u32) -> Result<Key> {
633        self.serialize_u64(u64::from(v))
634    }
635
636    fn serialize_u64(self, v: u64) -> Result<Key> {
637        Ok(Key::Uint(v))
638    }
639
640    fn serialize_f32(self, _v: f32) -> Result<Key> {
641        Err(SerializationError::InvalidKey(
642            "Float is not supported".to_string(),
643        ))
644    }
645
646    fn serialize_f64(self, _v: f64) -> Result<Key> {
647        Err(SerializationError::InvalidKey(
648            "Float is not supported".to_string(),
649        ))
650    }
651
652    fn serialize_char(self, v: char) -> Result<Key> {
653        self.serialize_str(&v.to_string())
654    }
655
656    fn serialize_str(self, v: &str) -> Result<Key> {
657        Ok(Key::String(Arc::new(v.to_string())))
658    }
659
660    fn serialize_bytes(self, _v: &[u8]) -> Result<Key> {
661        Err(SerializationError::InvalidKey(
662            "Bytes are not supported".to_string(),
663        ))
664    }
665
666    fn serialize_none(self) -> Result<Key> {
667        Err(SerializationError::InvalidKey(
668            "None is not supported".to_string(),
669        ))
670    }
671
672    fn serialize_some<T>(self, value: &T) -> Result<Key>
673    where
674        T: ?Sized + Serialize,
675    {
676        value.serialize(self)
677    }
678
679    fn serialize_unit(self) -> Result<Key> {
680        Err(SerializationError::InvalidKey(
681            "Null is not supported".to_string(),
682        ))
683    }
684
685    fn serialize_unit_struct(self, _name: &'static str) -> Result<Key> {
686        Err(SerializationError::InvalidKey(
687            "Empty unit structs are not supported".to_string(),
688        ))
689    }
690
691    fn serialize_unit_variant(
692        self,
693        _name: &'static str,
694        _variant_index: u32,
695        variant: &'static str,
696    ) -> Result<Key> {
697        Ok(Key::String(Arc::new(variant.to_string())))
698    }
699
700    fn serialize_newtype_struct<T>(self, _name: &'static str, value: &T) -> Result<Key>
701    where
702        T: ?Sized + Serialize,
703    {
704        value.serialize(KeySerializer)
705    }
706
707    fn serialize_newtype_variant<T>(
708        self,
709        _name: &'static str,
710        _variant_index: u32,
711        _variant: &'static str,
712        _value: &T,
713    ) -> Result<Key>
714    where
715        T: ?Sized + Serialize,
716    {
717        Err(SerializationError::InvalidKey(
718            "Newtype variant is not supported".to_string(),
719        ))
720    }
721
722    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq> {
723        Err(SerializationError::InvalidKey(
724            "Sequences are not supported".to_string(),
725        ))
726    }
727
728    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple> {
729        Err(SerializationError::InvalidKey(
730            "Tuples are not supported".to_string(),
731        ))
732    }
733
734    fn serialize_tuple_struct(
735        self,
736        _name: &'static str,
737        _len: usize,
738    ) -> Result<Self::SerializeTupleStruct> {
739        Err(SerializationError::InvalidKey(
740            "Structs are not supported".to_string(),
741        ))
742    }
743
744    fn serialize_tuple_variant(
745        self,
746        _name: &'static str,
747        _variant_index: u32,
748        _variant: &'static str,
749        _len: usize,
750    ) -> Result<Self::SerializeTupleVariant> {
751        Err(SerializationError::InvalidKey(
752            "Tuple variants are not supported".to_string(),
753        ))
754    }
755
756    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
757        Err(SerializationError::InvalidKey(
758            "Map variants are not supported".to_string(),
759        ))
760    }
761
762    fn serialize_struct(self, _name: &'static str, _len: usize) -> Result<Self::SerializeStruct> {
763        Err(SerializationError::InvalidKey(
764            "Structs are not supported".to_string(),
765        ))
766    }
767
768    fn serialize_struct_variant(
769        self,
770        _name: &'static str,
771        _variant_index: u32,
772        _variant: &'static str,
773        _len: usize,
774    ) -> Result<Self::SerializeStructVariant> {
775        Err(SerializationError::InvalidKey(
776            "Struct variants are not supported".to_string(),
777        ))
778    }
779}
780
/// Serializer for the contents of the private marker newtype structs emitted
/// by [`Duration`] and [`Timestamp`]. The variant records which marker was
/// seen.
#[cfg(feature = "chrono")]
#[derive(Debug)]
enum TimeSerializer {
    /// Expects the `Duration { secs, nanos }` proxy struct.
    Duration,
    /// Expects a string that parses as `chrono::DateTime<FixedOffset>`.
    Timestamp,
}
787
#[cfg(feature = "chrono")]
impl TimeSerializer {
    /// Error for any serde shape other than the one this marker expects.
    ///
    /// This crate's own `Duration` and `Timestamp` wrappers never reach these
    /// paths. However, any type can serialize a newtype struct under one of
    /// the private marker names, so a library must report this as an error
    /// rather than panic.
    fn unexpected<T>(&self, shape: &str) -> Result<T> {
        Err(SerializationError::SerdeError(format!(
            "unexpected {shape} inside {:?} marker newtype struct",
            self
        )))
    }
}

/// Rebuilds CEL durations and timestamps from the contents of the private
/// marker newtype structs.
#[cfg(feature = "chrono")]
impl ser::Serializer for TimeSerializer {
    type Ok = Value;
    type Error = SerializationError;

    type SerializeStruct = SerializeTimestamp;

    // Never constructed: every method that would return one of these reports
    // an error instead, so the existing types are reused to satisfy the trait.
    type SerializeSeq = SerializeVec;
    type SerializeTuple = SerializeVec;
    type SerializeTupleStruct = SerializeVec;
    type SerializeTupleVariant = SerializeTupleVariant;
    type SerializeMap = SerializeMap;
    type SerializeStructVariant = SerializeStructVariant;

    /// Accepts only the two-field `Duration` proxy struct under the
    /// `Duration` marker.
    fn serialize_struct(self, name: &'static str, len: usize) -> Result<Self::SerializeStruct> {
        if !matches!(self, Self::Duration { .. }) || name != Duration::STRUCT_NAME {
            return Err(SerializationError::SerdeError(
                "expected Duration struct with Duration marker newtype struct".to_owned(),
            ));
        }
        if len != 2 {
            return Err(SerializationError::SerdeError(
                "expected Duration struct to have 2 fields".to_owned(),
            ));
        }
        Ok(SerializeTimestamp::default())
    }

    /// Accepts only a timestamp string under the `Timestamp` marker and parses
    /// it into a `DateTime<FixedOffset>`.
    fn serialize_str(self, v: &str) -> Result<Value> {
        if !matches!(self, Self::Timestamp) {
            return Err(SerializationError::SerdeError(
                "expected Timestamp string with Timestamp marker newtype struct".to_owned(),
            ));
        }
        Ok(v.parse::<chrono::DateTime<FixedOffset>>()
            .map_err(|e| SerializationError::SerdeError(e.to_string()))?
            .into())
    }

    fn serialize_bool(self, _v: bool) -> Result<Value> {
        self.unexpected("bool")
    }

    fn serialize_i8(self, _v: i8) -> Result<Value> {
        self.unexpected("i8")
    }

    fn serialize_i16(self, _v: i16) -> Result<Value> {
        self.unexpected("i16")
    }

    fn serialize_i32(self, _v: i32) -> Result<Value> {
        self.unexpected("i32")
    }

    fn serialize_i64(self, _v: i64) -> Result<Value> {
        self.unexpected("i64")
    }

    fn serialize_u8(self, _v: u8) -> Result<Value> {
        self.unexpected("u8")
    }

    fn serialize_u16(self, _v: u16) -> Result<Value> {
        self.unexpected("u16")
    }

    fn serialize_u32(self, _v: u32) -> Result<Value> {
        self.unexpected("u32")
    }

    fn serialize_u64(self, _v: u64) -> Result<Value> {
        self.unexpected("u64")
    }

    fn serialize_f32(self, _v: f32) -> Result<Value> {
        self.unexpected("f32")
    }

    fn serialize_f64(self, _v: f64) -> Result<Value> {
        self.unexpected("f64")
    }

    fn serialize_char(self, _v: char) -> Result<Value> {
        self.unexpected("char")
    }

    fn serialize_bytes(self, _v: &[u8]) -> Result<Value> {
        self.unexpected("bytes")
    }

    fn serialize_none(self) -> Result<Value> {
        self.unexpected("none")
    }

    fn serialize_some<T>(self, _value: &T) -> Result<Value>
    where
        T: ?Sized + Serialize,
    {
        self.unexpected("some")
    }

    fn serialize_unit(self) -> Result<Value> {
        self.unexpected("unit")
    }

    fn serialize_unit_struct(self, _name: &'static str) -> Result<Value> {
        self.unexpected("unit struct")
    }

    fn serialize_unit_variant(
        self,
        _name: &'static str,
        _variant_index: u32,
        _variant: &'static str,
    ) -> Result<Value> {
        self.unexpected("unit variant")
    }

    fn serialize_newtype_struct<T>(self, _name: &'static str, _value: &T) -> Result<Value>
    where
        T: ?Sized + Serialize,
    {
        self.unexpected("newtype struct")
    }

    fn serialize_newtype_variant<T>(
        self,
        _name: &'static str,
        _variant_index: u32,
        _variant: &'static str,
        _value: &T,
    ) -> Result<Value>
    where
        T: ?Sized + Serialize,
    {
        self.unexpected("newtype variant")
    }

    fn serialize_seq(self, _len: Option<usize>) -> Result<Self::SerializeSeq> {
        self.unexpected("sequence")
    }

    fn serialize_tuple(self, _len: usize) -> Result<Self::SerializeTuple> {
        self.unexpected("tuple")
    }

    fn serialize_tuple_struct(
        self,
        _name: &'static str,
        _len: usize,
    ) -> Result<Self::SerializeTupleStruct> {
        self.unexpected("tuple struct")
    }

    fn serialize_tuple_variant(
        self,
        _name: &'static str,
        _variant_index: u32,
        _variant: &'static str,
        _len: usize,
    ) -> Result<Self::SerializeTupleVariant> {
        self.unexpected("tuple variant")
    }

    fn serialize_map(self, _len: Option<usize>) -> Result<Self::SerializeMap> {
        self.unexpected("map")
    }

    fn serialize_struct_variant(
        self,
        _name: &'static str,
        _variant_index: u32,
        _variant: &'static str,
        _len: usize,
    ) -> Result<Self::SerializeStructVariant> {
        self.unexpected("struct variant")
    }
}
968
969#[cfg(test)]
970mod tests {
971    use crate::{objects::Key, to_value, Value};
972    use crate::{Context, Program};
973    use serde::Serialize;
974    use serde_bytes::Bytes;
975    use std::{collections::HashMap, iter::FromIterator, sync::Arc};
976
977    #[cfg(feature = "chrono")]
978    use super::{Duration, Timestamp};
979
980    macro_rules! primitive_test {
981        ($functionName:ident, $strValue: literal, $value: expr) => {
982            #[test]
983            fn $functionName() {
984                let program = Program::compile($strValue).unwrap();
985                let result = program.execute(&Context::default());
986                assert_eq!(Value::from($value), result.unwrap());
987            }
988        };
989    }
990
991    primitive_test!(test_u64_zero, "0u", 0_u64);
992    primitive_test!(test_i64_zero, "0", 0_i64);
993    primitive_test!(test_f64_zero, "0.0", 0_f64);
994    //primitive_test!(test_f64_zero, "0.", 0_f64); this test fails
995    primitive_test!(test_bool_false, "false", false);
996    primitive_test!(test_bool_true, "true", true);
997    primitive_test!(test_string_empty, "\"\"", "");
998    primitive_test!(test_string_non_empty, "\"test\"", "test");
999    primitive_test!(test_byte_ones, r#"b"\001\001""#, vec!(1_u8, 1_u8));
1000    // primitive_test!(test_triple_double_quoted_string, #"r"""""""#, "");
1001    // primitive_test!(test_triple_single_quoted_string, "r''''''", "");
1002    primitive_test!(test_utf8_character_as_bytes, "b'ΓΏ'", vec!(195_u8, 191_u8));
1003
1004    #[test]
1005    fn test_json_data_conversion() {
1006        #[derive(Serialize)]
1007        struct TestPrimitives {
1008            bool: bool,
1009            u8: u8,
1010            u16: u16,
1011            u32: u32,
1012            u64: u64,
1013            int8: i8,
1014            int16: i16,
1015            int32: i32,
1016            int64: i64,
1017            f32: f32,
1018            f64: f64,
1019            char: char,
1020            string: String,
1021            bytes: &'static Bytes,
1022        }
1023
1024        let test = TestPrimitives {
1025            bool: true,
1026            int8: 8_i8,
1027            int16: 16_i16,
1028            int32: 32_i32,
1029            int64: 64_i64,
1030            u8: 8_u8,
1031            u16: 16_u16,
1032            u32: 32_u32,
1033            u64: 64_u64,
1034            f32: 0.32_f32,
1035            f64: 0.64_f64,
1036            char: 'a',
1037            string: "string".to_string(),
1038            bytes: Bytes::new(&[1_u8, 1_u8, 1_u8, 1_u8]),
1039        };
1040
1041        let serialized = to_value(test).unwrap();
1042        let expected: Value = HashMap::from_iter([
1043            (Key::String(Arc::new("bool".to_string())), Value::Bool(true)),
1044            (Key::String(Arc::new("int8".to_string())), Value::Int(8)),
1045            (Key::String(Arc::new("int16".to_string())), Value::Int(16)),
1046            (Key::String(Arc::new("int32".to_string())), Value::Int(32)),
1047            (Key::String(Arc::new("int64".to_string())), Value::Int(64)),
1048            (Key::String(Arc::new("u8".to_string())), Value::UInt(8)),
1049            (Key::String(Arc::new("u16".to_string())), Value::UInt(16)),
1050            (Key::String(Arc::new("u32".to_string())), Value::UInt(32)),
1051            (Key::String(Arc::new("u64".to_string())), Value::UInt(64)),
1052            (
1053                Key::String(Arc::new("f32".to_string())),
1054                Value::Float(f64::from(0.32_f32)),
1055            ),
1056            (Key::String(Arc::new("f64".to_string())), Value::Float(0.64)),
1057            (
1058                Key::String(Arc::new("char".to_string())),
1059                Value::String(Arc::new("a".to_string())),
1060            ),
1061            (
1062                Key::String(Arc::new("string".to_string())),
1063                Value::String(Arc::new("string".to_string())),
1064            ),
1065            (
1066                Key::String(Arc::new("bytes".to_string())),
1067                Value::Bytes(Arc::new(vec![1, 1, 1, 1])),
1068            ),
1069        ])
1070        .into();
1071
1072        // Test with CEL because iterator is not implemented for Value::Map
1073        let program = Program::compile(
1074            "expected.all(key, (has(serialized[key]) && (serialized[key] == expected[key])))",
1075        )
1076        .unwrap();
1077        let mut context = Context::default();
1078        context.add_variable("expected", expected).unwrap();
1079        context.add_variable("serialized", serialized).unwrap();
1080        let value = program.execute(&context).unwrap();
1081        assert_eq!(value, true.into())
1082    }
1083
    /// Fixture enum with one variant per serde enum shape (unit, newtype, tuple,
    /// struct), plus an `Option` payload and a map payload. The tests expect the
    /// externally tagged form: a unit variant serializes to its name, and every
    /// other variant to a single-entry map keyed by the variant name.
    #[derive(Serialize)]
    enum TestCompoundTypes {
        Unit,
        Newtype(u32),
        // Exercises both `None` (-> Null) and `Some(_)` payloads.
        Wrapped(Option<u8>),
        Tuple(u32, u32),
        Struct {
            a: i32,
            // Non-string (bool) map keys and nested collections.
            nested: HashMap<bool, HashMap<String, Vec<String>>>,
        },
        // Map whose values are serialized through `serialize_bytes`.
        Map(HashMap<String, &'static Bytes>),
    }
1096    #[test]
1097    fn test_unit() {
1098        let unit = to_value(TestCompoundTypes::Unit).unwrap();
1099        let expected: Value = "Unit".into();
1100        let program = Program::compile("test == expected").unwrap();
1101        let mut context = Context::default();
1102        context.add_variable("expected", expected).unwrap();
1103        context.add_variable("test", unit).unwrap();
1104        let value = program.execute(&context).unwrap();
1105        assert_eq!(value, true.into())
1106    }
1107    #[test]
1108    fn test_newtype() {
1109        let newtype = to_value(TestCompoundTypes::Newtype(32)).unwrap();
1110        let expected: Value = HashMap::from([("Newtype", Value::UInt(32))]).into();
1111        let program = Program::compile("test == expected").unwrap();
1112        let mut context = Context::default();
1113        context.add_variable("expected", expected).unwrap();
1114        context.add_variable("test", newtype).unwrap();
1115        let value = program.execute(&context).unwrap();
1116        assert_eq!(value, true.into())
1117    }
1118    #[test]
1119    fn test_options() {
1120        // Test Option serialization
1121        let wrapped = to_value(TestCompoundTypes::Wrapped(None)).unwrap();
1122        let expected: Value = HashMap::from([("Wrapped", Value::Null)]).into();
1123        let program = Program::compile("test == expected").unwrap();
1124        let mut context = Context::default();
1125        context.add_variable("expected", expected).unwrap();
1126        context.add_variable("test", wrapped).unwrap();
1127        let value = program.execute(&context).unwrap();
1128        assert_eq!(value, true.into());
1129
1130        let wrapped = to_value(TestCompoundTypes::Wrapped(Some(8))).unwrap();
1131        let expected: Value = HashMap::from([("Wrapped", Value::UInt(8))]).into();
1132        let program = Program::compile("test == expected").unwrap();
1133        let mut context = Context::default();
1134        context.add_variable("expected", expected).unwrap();
1135        context.add_variable("test", wrapped).unwrap();
1136        let value = program.execute(&context).unwrap();
1137        assert_eq!(value, true.into())
1138    }
1139
1140    #[test]
1141    fn test_tuples() {
1142        // Test Tuple serialization
1143        let tuple = to_value(TestCompoundTypes::Tuple(12, 16)).unwrap();
1144        let expected: Value = HashMap::from([(
1145            "Tuple",
1146            Value::List(Arc::new(vec![12_u64.into(), 16_u64.into()])),
1147        )])
1148        .into();
1149        let program = Program::compile("test == expected").unwrap();
1150        let mut context = Context::default();
1151        context.add_variable("expected", expected).unwrap();
1152        context.add_variable("test", tuple).unwrap();
1153        let value = program.execute(&context).unwrap();
1154        assert_eq!(value, true.into())
1155    }
1156
1157    #[test]
1158    fn test_structs() {
1159        // Test Struct serialization
1160        let test_struct = TestCompoundTypes::Struct {
1161            a: 32_i32,
1162            nested: HashMap::from_iter([(
1163                true,
1164                HashMap::from_iter([(
1165                    "Test".to_string(),
1166                    vec!["a".to_string(), "b".to_string(), "c".to_string()],
1167                )]),
1168            )]),
1169        };
1170        let expected: Value = HashMap::<Key, Value>::from([(
1171            "Struct".into(),
1172            HashMap::<Key, Value>::from_iter([
1173                ("a".into(), 32_i32.into()),
1174                (
1175                    "nested".into(),
1176                    HashMap::<Key, Value>::from_iter([(
1177                        true.into(),
1178                        HashMap::<Key, Value>::from_iter([(
1179                            "Test".into(),
1180                            vec!["a".to_string(), "b".to_string(), "c".to_string()].into(),
1181                        )])
1182                        .into(),
1183                    )])
1184                    .into(),
1185                ),
1186            ])
1187            .into(),
1188        )])
1189        .into();
1190        let program = Program::compile("expected.all(key, test[key] == expected[key])").unwrap();
1191        let mut context = Context::default();
1192        context.add_variable("expected", expected).unwrap();
1193        context.add_variable("test", test_struct).unwrap();
1194        let value = program.execute(&context).unwrap();
1195        assert_eq!(value, true.into());
1196    }
1197
1198    #[test]
1199    fn test_maps() {
1200        // Test Map serialization
1201        let map = to_value(TestCompoundTypes::Map(
1202            HashMap::<String, &'static Bytes>::from_iter([(
1203                "Test".to_string(),
1204                Bytes::new(&[0_u8, 0_u8, 0_u8, 0_u8]),
1205            )]),
1206        ))
1207        .unwrap();
1208        let expected: Value = HashMap::from([(
1209            "Map",
1210            HashMap::<Key, Value>::from_iter([(
1211                "Test".into(),
1212                Value::Bytes(Arc::new(vec![0_u8, 0_u8, 0_u8, 0_u8])),
1213            )]),
1214        )])
1215        .into();
1216        assert_eq!(map, expected)
1217    }
1218
    /// Fixture holding one of each time wrapper (`Duration` and `Timestamp`),
    /// used by the chrono-gated serialization tests.
    #[cfg(feature = "chrono")]
    #[derive(Serialize)]
    struct TestTimeTypes {
        dur: Duration,
        ts: Timestamp,
    }
1225
1226    #[cfg(feature = "chrono")]
1227    #[test]
1228    fn test_time_types() {
1229        use chrono::FixedOffset;
1230
1231        let tests = to_value([
1232            TestTimeTypes {
1233                dur: chrono::Duration::milliseconds(1527).into(),
1234                ts: chrono::DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00")
1235                    .unwrap()
1236                    .into(),
1237            },
1238            // Let's test chrono::Duration's particular handling around math
1239            // and negatives and timestamps from BCE.
1240            TestTimeTypes {
1241                dur: chrono::Duration::milliseconds(-1527).into(),
1242                ts: "-0001-12-01T00:00:00-08:00"
1243                    .parse::<chrono::DateTime<FixedOffset>>()
1244                    .unwrap()
1245                    .into(),
1246            },
1247            TestTimeTypes {
1248                dur: (chrono::Duration::seconds(1) - chrono::Duration::nanoseconds(1000000001))
1249                    .into(),
1250                ts: chrono::DateTime::parse_from_rfc3339("0001-12-01T00:00:00+08:00")
1251                    .unwrap()
1252                    .into(),
1253            },
1254            TestTimeTypes {
1255                dur: (chrono::Duration::seconds(-1) + chrono::Duration::nanoseconds(1000000001))
1256                    .into(),
1257                ts: chrono::DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00")
1258                    .unwrap()
1259                    .into(),
1260            },
1261        ])
1262        .unwrap();
1263        let expected: Value = vec![
1264            Value::Map(
1265                HashMap::<_, Value>::from([
1266                    ("dur", chrono::Duration::milliseconds(1527).into()),
1267                    (
1268                        "ts",
1269                        chrono::DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00")
1270                            .unwrap()
1271                            .into(),
1272                    ),
1273                ])
1274                .into(),
1275            ),
1276            Value::Map(
1277                HashMap::<_, Value>::from([
1278                    ("dur", chrono::Duration::nanoseconds(-1527000000).into()),
1279                    (
1280                        "ts",
1281                        "-0001-12-01T00:00:00-08:00"
1282                            .parse::<chrono::DateTime<FixedOffset>>()
1283                            .unwrap()
1284                            .into(),
1285                    ),
1286                ])
1287                .into(),
1288            ),
1289            Value::Map(
1290                HashMap::<_, Value>::from([
1291                    ("dur", chrono::Duration::nanoseconds(-1).into()),
1292                    (
1293                        "ts",
1294                        chrono::DateTime::parse_from_rfc3339("0001-12-01T00:00:00+08:00")
1295                            .unwrap()
1296                            .into(),
1297                    ),
1298                ])
1299                .into(),
1300            ),
1301            Value::Map(
1302                HashMap::<_, Value>::from([
1303                    ("dur", chrono::Duration::nanoseconds(1).into()),
1304                    (
1305                        "ts",
1306                        chrono::DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00")
1307                            .unwrap()
1308                            .into(),
1309                    ),
1310                ])
1311                .into(),
1312            ),
1313        ]
1314        .into();
1315        assert_eq!(tests, expected);
1316
1317        let program = Program::compile("test == expected").unwrap();
1318        let mut context = Context::default();
1319        context.add_variable("expected", expected).unwrap();
1320        context.add_variable("test", tests).unwrap();
1321        let value = program.execute(&context).unwrap();
1322        assert_eq!(value, true.into());
1323    }
1324
1325    #[cfg(feature = "chrono")]
1326    #[cfg(feature = "json")]
1327    #[test]
1328    fn test_time_json() {
1329        use chrono::FixedOffset;
1330
1331        // Test that Durations and Timestamps serialize correctly with
1332        // serde_json.
1333        let tests = [
1334            TestTimeTypes {
1335                dur: chrono::Duration::milliseconds(1527).into(),
1336                ts: chrono::DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00")
1337                    .unwrap()
1338                    .into(),
1339            },
1340            TestTimeTypes {
1341                dur: chrono::Duration::milliseconds(-1527).into(),
1342                ts: "-0001-12-01T00:00:00-08:00"
1343                    .parse::<chrono::DateTime<FixedOffset>>()
1344                    .unwrap()
1345                    .into(),
1346            },
1347            TestTimeTypes {
1348                dur: (chrono::Duration::seconds(1) - chrono::Duration::nanoseconds(1000000001))
1349                    .into(),
1350                ts: chrono::DateTime::parse_from_rfc3339("0001-12-01T00:00:00+08:00")
1351                    .unwrap()
1352                    .into(),
1353            },
1354            TestTimeTypes {
1355                dur: (chrono::Duration::seconds(-1) + chrono::Duration::nanoseconds(1000000001))
1356                    .into(),
1357                ts: chrono::DateTime::parse_from_rfc3339("1996-12-19T16:39:57-08:00")
1358                    .unwrap()
1359                    .into(),
1360            },
1361        ];
1362
1363        let expect = "[\
1364{\"dur\":{\"secs\":1,\"nanos\":527000000},\"ts\":\"1996-12-19T16:39:57-08:00\"},\
1365{\"dur\":{\"secs\":-1,\"nanos\":-527000000},\"ts\":\"-0001-12-01T00:00:00-08:00\"},\
1366{\"dur\":{\"secs\":0,\"nanos\":-1},\"ts\":\"0001-12-01T00:00:00+08:00\"},\
1367{\"dur\":{\"secs\":0,\"nanos\":1},\"ts\":\"1996-12-19T16:39:57-08:00\"}\
1368]";
1369        let actual = serde_json::to_string(&tests).unwrap();
1370        assert_eq!(actual, expect);
1371    }
1372}