rimu_value/serde/de.rs

1use indexmap::IndexMap;
2use serde::de::{
3    self, Deserialize, DeserializeSeed, Deserializer, EnumAccess, Error as SError, Expected,
4    IntoDeserializer, MapAccess, SeqAccess, Unexpected, VariantAccess, Visitor,
5};
6use serde::forward_to_deserialize_any;
7use std::{fmt, vec};
8
9use crate::{number, SerdeValue, SerdeValueError, SerdeValueList, SerdeValueObject};
10
11impl<'de> Deserialize<'de> for SerdeValue {
12    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
13    where
14        D: Deserializer<'de>,
15    {
16        struct ValueVisitor;
17
18        impl<'de> Visitor<'de> for ValueVisitor {
19            type Value = SerdeValue;
20
21            fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
22                formatter.write_str("any value")
23            }
24
25            fn visit_bool<E>(self, b: bool) -> Result<Self::Value, E>
26            where
27                E: SError,
28            {
29                Ok(Self::Value::Boolean(b))
30            }
31
32            fn visit_i64<E>(self, i: i64) -> Result<Self::Value, E>
33            where
34                E: SError,
35            {
36                Ok(Self::Value::Number(i.into()))
37            }
38
39            fn visit_u64<E>(self, u: u64) -> Result<Self::Value, E>
40            where
41                E: SError,
42            {
43                Ok(Self::Value::Number(u.into()))
44            }
45
46            fn visit_f64<E>(self, f: f64) -> Result<Self::Value, E>
47            where
48                E: SError,
49            {
50                Ok(Self::Value::Number(f.into()))
51            }
52
53            fn visit_str<E>(self, s: &str) -> Result<Self::Value, E>
54            where
55                E: SError,
56            {
57                Ok(Self::Value::String(s.to_owned()))
58            }
59
60            fn visit_string<E>(self, s: String) -> Result<Self::Value, E>
61            where
62                E: SError,
63            {
64                Ok(Self::Value::String(s))
65            }
66
67            fn visit_unit<E>(self) -> Result<Self::Value, E>
68            where
69                E: SError,
70            {
71                Ok(Self::Value::Null)
72            }
73
74            fn visit_none<E>(self) -> Result<Self::Value, E>
75            where
76                E: SError,
77            {
78                Ok(Self::Value::Null)
79            }
80
81            fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
82            where
83                D: Deserializer<'de>,
84            {
85                Deserialize::deserialize(deserializer)
86            }
87
88            fn visit_seq<V>(self, mut visitor: V) -> Result<Self::Value, V::Error>
89            where
90                V: SeqAccess<'de>,
91            {
92                let mut vec = Vec::new();
93
94                while let Some(element) = visitor.next_element()? {
95                    vec.push(element);
96                }
97
98                Ok(Self::Value::List(vec))
99            }
100
101            fn visit_map<V>(self, mut visitor: V) -> Result<Self::Value, V::Error>
102            where
103                V: MapAccess<'de>,
104            {
105                let mut values = IndexMap::new();
106
107                while let Some((key, value)) = visitor.next_entry()? {
108                    values.insert(key, value);
109                }
110
111                Ok(Self::Value::Object(values))
112            }
113        }
114
115        deserializer.deserialize_any(ValueVisitor)
116    }
117}
118
119impl SerdeValue {
120    fn deserialize_number<'de, V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
121    where
122        V: Visitor<'de>,
123    {
124        match self {
125            SerdeValue::Number(n) => n.deserialize_any(visitor),
126            _ => Err(self.invalid_type(&visitor)),
127        }
128    }
129}
130
131fn visit_list<'de, V>(list: SerdeValueList, visitor: V) -> Result<V::Value, SerdeValueError>
132where
133    V: Visitor<'de>,
134{
135    let len = list.len();
136    let mut deserializer = SeqDeserializer::new(list);
137    let seq = visitor.visit_seq(&mut deserializer)?;
138    let remaining = deserializer.iter.len();
139    if remaining == 0 {
140        Ok(seq)
141    } else {
142        Err(SerdeValueError::invalid_length(
143            len,
144            &"fewer elements in list",
145        ))
146    }
147}
148
149fn visit_object<'de, V>(object: SerdeValueObject, visitor: V) -> Result<V::Value, SerdeValueError>
150where
151    V: Visitor<'de>,
152{
153    let len = object.len();
154    let mut deserializer = MapDeserializer::new(object);
155    let map = visitor.visit_map(&mut deserializer)?;
156    let remaining = deserializer.iter.len();
157    if remaining == 0 {
158        Ok(map)
159    } else {
160        Err(SerdeValueError::invalid_length(
161            len,
162            &"fewer elements in map",
163        ))
164    }
165}
166
167impl<'de> IntoDeserializer<'de, SerdeValueError> for SerdeValue {
168    type Deserializer = Self;
169
170    fn into_deserializer(self) -> Self::Deserializer {
171        self
172    }
173}
174
175impl<'de> Deserializer<'de> for SerdeValue {
176    type Error = SerdeValueError;
177
178    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
179    where
180        V: Visitor<'de>,
181    {
182        match self {
183            SerdeValue::Null => visitor.visit_unit(),
184            SerdeValue::Boolean(v) => visitor.visit_bool(v),
185            SerdeValue::Number(n) => n.deserialize_any(visitor),
186            SerdeValue::String(v) => visitor.visit_string(v),
187            SerdeValue::List(v) => visit_list(v, visitor),
188            SerdeValue::Object(v) => visit_object(v, visitor),
189            SerdeValue::Function(_f) => todo!(),
190        }
191    }
192
193    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
194    where
195        V: Visitor<'de>,
196    {
197        match self {
198            SerdeValue::Boolean(v) => visitor.visit_bool(v),
199            _ => Err(self.invalid_type(&visitor)),
200        }
201    }
202
203    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
204    where
205        V: Visitor<'de>,
206    {
207        self.deserialize_number(visitor)
208    }
209
210    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
211    where
212        V: Visitor<'de>,
213    {
214        self.deserialize_number(visitor)
215    }
216
217    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
218    where
219        V: Visitor<'de>,
220    {
221        self.deserialize_number(visitor)
222    }
223
224    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
225    where
226        V: Visitor<'de>,
227    {
228        self.deserialize_number(visitor)
229    }
230
231    fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
232    where
233        V: Visitor<'de>,
234    {
235        self.deserialize_number(visitor)
236    }
237
238    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
239    where
240        V: Visitor<'de>,
241    {
242        self.deserialize_number(visitor)
243    }
244
245    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
246    where
247        V: Visitor<'de>,
248    {
249        self.deserialize_number(visitor)
250    }
251
252    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
253    where
254        V: Visitor<'de>,
255    {
256        self.deserialize_number(visitor)
257    }
258
259    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
260    where
261        V: Visitor<'de>,
262    {
263        self.deserialize_number(visitor)
264    }
265
266    fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
267    where
268        V: Visitor<'de>,
269    {
270        self.deserialize_number(visitor)
271    }
272
273    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
274    where
275        V: Visitor<'de>,
276    {
277        self.deserialize_number(visitor)
278    }
279
280    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
281    where
282        V: Visitor<'de>,
283    {
284        self.deserialize_number(visitor)
285    }
286
287    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
288    where
289        V: Visitor<'de>,
290    {
291        self.deserialize_string(visitor)
292    }
293
294    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
295    where
296        V: Visitor<'de>,
297    {
298        self.deserialize_string(visitor)
299    }
300
301    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
302    where
303        V: Visitor<'de>,
304    {
305        match self {
306            SerdeValue::String(v) => visitor.visit_string(v),
307            _ => Err(self.invalid_type(&visitor)),
308        }
309    }
310
311    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
312    where
313        V: Visitor<'de>,
314    {
315        self.deserialize_byte_buf(visitor)
316    }
317
318    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
319    where
320        V: Visitor<'de>,
321    {
322        match self {
323            SerdeValue::String(v) => visitor.visit_string(v),
324            SerdeValue::List(v) => visit_list(v, visitor),
325            _ => Err(self.invalid_type(&visitor)),
326        }
327    }
328
329    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
330    where
331        V: Visitor<'de>,
332    {
333        match self {
334            SerdeValue::Null => visitor.visit_none(),
335            _ => visitor.visit_some(self),
336        }
337    }
338
339    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
340    where
341        V: Visitor<'de>,
342    {
343        match self {
344            SerdeValue::Null => visitor.visit_unit(),
345            _ => Err(self.invalid_type(&visitor)),
346        }
347    }
348
349    fn deserialize_unit_struct<V>(
350        self,
351        _name: &'static str,
352        visitor: V,
353    ) -> Result<V::Value, SerdeValueError>
354    where
355        V: Visitor<'de>,
356    {
357        self.deserialize_unit(visitor)
358    }
359
360    fn deserialize_newtype_struct<V>(
361        self,
362        _name: &'static str,
363        visitor: V,
364    ) -> Result<V::Value, SerdeValueError>
365    where
366        V: Visitor<'de>,
367    {
368        visitor.visit_newtype_struct(self)
369    }
370
371    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
372    where
373        V: Visitor<'de>,
374    {
375        match self {
376            SerdeValue::List(v) => visit_list(v, visitor),
377            _ => Err(self.invalid_type(&visitor)),
378        }
379    }
380
381    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, SerdeValueError>
382    where
383        V: Visitor<'de>,
384    {
385        self.deserialize_seq(visitor)
386    }
387
388    fn deserialize_tuple_struct<V>(
389        self,
390        _name: &'static str,
391        _len: usize,
392        visitor: V,
393    ) -> Result<V::Value, SerdeValueError>
394    where
395        V: Visitor<'de>,
396    {
397        self.deserialize_seq(visitor)
398    }
399
400    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
401    where
402        V: Visitor<'de>,
403    {
404        match self {
405            SerdeValue::Object(v) => visit_object(v, visitor),
406            _ => Err(self.invalid_type(&visitor)),
407        }
408    }
409
410    fn deserialize_struct<V>(
411        self,
412        _name: &'static str,
413        _fields: &'static [&'static str],
414        visitor: V,
415    ) -> Result<V::Value, SerdeValueError>
416    where
417        V: Visitor<'de>,
418    {
419        match self {
420            SerdeValue::List(v) => visit_list(v, visitor),
421            SerdeValue::Object(v) => visit_object(v, visitor),
422            _ => Err(self.invalid_type(&visitor)),
423        }
424    }
425
426    fn deserialize_enum<V>(
427        self,
428        _name: &str,
429        _variants: &'static [&'static str],
430        visitor: V,
431    ) -> Result<V::Value, SerdeValueError>
432    where
433        V: Visitor<'de>,
434    {
435        let (variant, value) = match self {
436            SerdeValue::Object(value) => {
437                let mut iter = value.into_iter();
438                let (variant, value) = match iter.next() {
439                    Some(v) => v,
440                    None => {
441                        return Err(SerdeValueError::invalid_value(
442                            Unexpected::Map,
443                            &"map with a single key",
444                        ));
445                    }
446                };
447                // enums are encoded in json as maps with a single key:value pair
448                if iter.next().is_some() {
449                    return Err(SerdeValueError::invalid_value(
450                        Unexpected::Map,
451                        &"map with a single key",
452                    ));
453                }
454                (SerdeValue::String(variant), Some(value))
455            }
456            SerdeValue::String(variant) => (SerdeValue::String(variant), None),
457            other => {
458                return Err(SerdeValueError::invalid_type(
459                    other.unexpected(),
460                    &"string or map",
461                ));
462            }
463        };
464
465        visitor.visit_enum(EnumDeserializer { variant, value })
466    }
467
468    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
469    where
470        V: Visitor<'de>,
471    {
472        self.deserialize_string(visitor)
473    }
474
475    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
476    where
477        V: Visitor<'de>,
478    {
479        drop(self);
480        visitor.visit_unit()
481    }
482}
483
484struct EnumDeserializer {
485    variant: SerdeValue,
486    value: Option<SerdeValue>,
487}
488
489impl<'de> EnumAccess<'de> for EnumDeserializer {
490    type Error = SerdeValueError;
491    type Variant = VariantDeserializer;
492
493    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, VariantDeserializer), SerdeValueError>
494    where
495        V: DeserializeSeed<'de>,
496    {
497        let visitor = VariantDeserializer { value: self.value };
498        seed.deserialize(self.variant).map(|v| (v, visitor))
499    }
500}
501
502struct VariantDeserializer {
503    value: Option<SerdeValue>,
504}
505
506impl<'de> VariantAccess<'de> for VariantDeserializer {
507    type Error = SerdeValueError;
508
509    fn unit_variant(self) -> Result<(), SerdeValueError> {
510        match self.value {
511            Some(value) => Deserialize::deserialize(value),
512            None => Ok(()),
513        }
514    }
515
516    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, SerdeValueError>
517    where
518        T: DeserializeSeed<'de>,
519    {
520        match self.value {
521            Some(value) => seed.deserialize(value),
522            None => Err(SerdeValueError::invalid_type(
523                Unexpected::UnitVariant,
524                &"newtype variant",
525            )),
526        }
527    }
528
529    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, SerdeValueError>
530    where
531        V: Visitor<'de>,
532    {
533        match self.value {
534            Some(SerdeValue::List(v)) => {
535                Deserializer::deserialize_any(SeqDeserializer::new(v), visitor)
536            }
537            Some(other) => Err(SerdeValueError::invalid_type(
538                other.unexpected(),
539                &"tuple variant",
540            )),
541            None => Err(SerdeValueError::invalid_type(
542                Unexpected::UnitVariant,
543                &"tuple variant",
544            )),
545        }
546    }
547
548    fn struct_variant<V>(
549        self,
550        _fields: &'static [&'static str],
551        visitor: V,
552    ) -> Result<V::Value, SerdeValueError>
553    where
554        V: Visitor<'de>,
555    {
556        match self.value {
557            Some(SerdeValue::Object(v)) => {
558                Deserializer::deserialize_any(MapDeserializer::new(v), visitor)
559            }
560            Some(other) => Err(SerdeValueError::invalid_type(
561                other.unexpected(),
562                &"struct variant",
563            )),
564            None => Err(SerdeValueError::invalid_type(
565                Unexpected::UnitVariant,
566                &"struct variant",
567            )),
568        }
569    }
570}
571
572struct SeqDeserializer {
573    iter: vec::IntoIter<SerdeValue>,
574}
575
576impl SeqDeserializer {
577    fn new(vec: Vec<SerdeValue>) -> Self {
578        SeqDeserializer {
579            iter: vec.into_iter(),
580        }
581    }
582}
583
584impl<'de> Deserializer<'de> for SeqDeserializer {
585    type Error = SerdeValueError;
586
587    #[inline]
588    fn deserialize_any<V>(mut self, visitor: V) -> Result<V::Value, SerdeValueError>
589    where
590        V: Visitor<'de>,
591    {
592        let len = self.iter.len();
593        if len == 0 {
594            visitor.visit_unit()
595        } else {
596            let ret = visitor.visit_seq(&mut self)?;
597            let remaining = self.iter.len();
598            if remaining == 0 {
599                Ok(ret)
600            } else {
601                Err(SerdeValueError::invalid_length(
602                    len,
603                    &"fewer elements in list",
604                ))
605            }
606        }
607    }
608
609    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
610    where
611        V: Visitor<'de>,
612    {
613        drop(self);
614        visitor.visit_unit()
615    }
616
617    forward_to_deserialize_any! {
618        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
619        byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct
620        map struct enum identifier
621    }
622}
623
624impl<'de> SeqAccess<'de> for SeqDeserializer {
625    type Error = SerdeValueError;
626
627    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, SerdeValueError>
628    where
629        T: DeserializeSeed<'de>,
630    {
631        match self.iter.next() {
632            Some(value) => seed.deserialize(value).map(Some),
633            None => Ok(None),
634        }
635    }
636
637    fn size_hint(&self) -> Option<usize> {
638        match self.iter.size_hint() {
639            (lower, Some(upper)) if lower == upper => Some(upper),
640            _ => None,
641        }
642    }
643}
644
645struct MapDeserializer {
646    iter: <SerdeValueObject as IntoIterator>::IntoIter,
647    value: Option<SerdeValue>,
648}
649
650impl MapDeserializer {
651    fn new(map: SerdeValueObject) -> Self {
652        MapDeserializer {
653            iter: map.into_iter(),
654            value: None,
655        }
656    }
657}
658
659impl<'de> MapAccess<'de> for MapDeserializer {
660    type Error = SerdeValueError;
661
662    fn next_key_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, SerdeValueError>
663    where
664        T: DeserializeSeed<'de>,
665    {
666        match self.iter.next() {
667            Some((key, value)) => {
668                self.value = Some(value);
669                seed.deserialize(SerdeValue::String(key)).map(Some)
670            }
671            None => Ok(None),
672        }
673    }
674
675    fn next_value_seed<T>(&mut self, seed: T) -> Result<T::Value, SerdeValueError>
676    where
677        T: DeserializeSeed<'de>,
678    {
679        match self.value.take() {
680            Some(value) => seed.deserialize(value),
681            None => panic!("visit_value called before visit_key"),
682        }
683    }
684
685    fn size_hint(&self) -> Option<usize> {
686        match self.iter.size_hint() {
687            (lower, Some(upper)) if lower == upper => Some(upper),
688            _ => None,
689        }
690    }
691}
692
693impl<'de> Deserializer<'de> for MapDeserializer {
694    type Error = SerdeValueError;
695
696    #[inline]
697    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
698    where
699        V: Visitor<'de>,
700    {
701        visitor.visit_map(self)
702    }
703
704    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, SerdeValueError>
705    where
706        V: Visitor<'de>,
707    {
708        drop(self);
709        visitor.visit_unit()
710    }
711
712    forward_to_deserialize_any! {
713        bool i8 i16 i32 i64 u8 u16 u32 u64 f32 f64 char str string bytes
714        byte_buf option unit unit_struct newtype_struct seq tuple tuple_struct
715        map struct enum identifier
716    }
717}
718
719impl SerdeValue {
720    #[cold]
721    fn invalid_type<E>(&self, exp: &dyn Expected) -> E
722    where
723        E: de::Error,
724    {
725        de::Error::invalid_type(self.unexpected(), exp)
726    }
727
728    #[cold]
729    fn unexpected(&self) -> Unexpected {
730        match self {
731            SerdeValue::Null => Unexpected::Unit,
732            SerdeValue::Boolean(b) => Unexpected::Bool(*b),
733            SerdeValue::Number(n) => number::unexpected(n),
734            SerdeValue::String(s) => Unexpected::Str(s),
735            SerdeValue::List(_) => Unexpected::Seq,
736            SerdeValue::Object(_) => Unexpected::Map,
737            SerdeValue::Function(_f) => todo!(),
738        }
739    }
740}