// bsonrs/serde_impl/decode.rs

1use std::fmt;
2use std::vec;
3use std::result;
4use std::marker::PhantomData;
5use std::{i32, u32};
6
7use serde::de::{self, Deserialize, Deserializer, Visitor, MapAccess, SeqAccess, VariantAccess,
8                DeserializeSeed, EnumAccess};
9use serde::de::{Error, Expected, Unexpected};
10
11use indexmap::IndexMap;
12
13use crate::value::{Value, Array, UTCDateTime, TimeStamp};
14use crate::doc::{Document, IntoIter};
15use crate::decode::DecodeError;
16use crate::decode::DecodeResult;
17
18impl de::Error for DecodeError {
19    fn custom<T: fmt::Display>(msg: T) -> DecodeError {
20        DecodeError::Unknown(msg.to_string())
21    }
22
23    fn invalid_type(_unexp: Unexpected, exp: &dyn Expected) -> DecodeError {
24        DecodeError::InvalidType(exp.to_string())
25    }
26
27    fn invalid_value(_unexp: Unexpected, exp: &dyn Expected) -> DecodeError {
28        DecodeError::InvalidValue(exp.to_string())
29    }
30
31    fn invalid_length(len: usize, exp: &dyn Expected) -> DecodeError {
32        DecodeError::InvalidLength(len, exp.to_string())
33    }
34
35    fn unknown_variant(variant: &str, _expected: &'static [&'static str]) -> DecodeError {
36        DecodeError::UnknownVariant(variant.to_string())
37    }
38
39    fn unknown_field(field: &str, _expected: &'static [&'static str]) -> DecodeError {
40        DecodeError::UnknownField(field.to_string())
41    }
42
43    fn missing_field(field: &'static str) -> DecodeError {
44        DecodeError::ExpectedField(field)
45    }
46
47    fn duplicate_field(field: &'static str) -> DecodeError {
48        DecodeError::DuplicatedField(field)
49    }
50}
51
52impl<'de> Deserialize<'de> for Document {
53    /// Deserialize this value given this `Deserializer`.
54    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
55        where D: Deserializer<'de>
56    {
57        deserializer
58            .deserialize_map(ValueVisitor)
59            .and_then(|bson|
60                if let Value::Document(document) = bson {
61                    Ok(document)
62                } else {
63                    let err = format!("expected document, found extended JSON data type: {}", bson);
64                    Err(de::Error::invalid_type(Unexpected::Map, &&*err))
65            })
66    }
67}
68
69impl<'de> Deserialize<'de> for Value {
70    #[inline]
71    fn deserialize<D>(deserializer: D) -> Result<Value, D::Error>
72        where D: Deserializer<'de>
73    {
74        deserializer.deserialize_any(ValueVisitor)
75    }
76}
77
78pub struct ValueVisitor;
79
80impl<'de> Visitor<'de> for ValueVisitor {
81    type Value = Value;
82
83    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
84        write!(f, "expecting a Value")
85    }
86
87    #[inline]
88    fn visit_bool<E>(self, value: bool) -> Result<Value, E>
89        where E: Error
90    {
91        Ok(Value::Boolean(value))
92    }
93
94    #[inline]
95    fn visit_i8<E>(self, value: i8) -> Result<Value, E>
96        where E: Error
97    {
98        Ok(Value::Int32(i32::from(value)))
99    }
100
101    #[inline]
102    fn visit_u8<E>(self, value: u8) -> Result<Value, E>
103        where E: Error
104    {
105        Err(Error::invalid_type(Unexpected::Unsigned(u64::from(value)), &"a signed integer"))
106    }
107
108    #[inline]
109    fn visit_i16<E>(self, value: i16) -> Result<Value, E>
110        where E: Error
111    {
112        Ok(Value::Int32(i32::from(value)))
113    }
114
115    #[inline]
116    fn visit_u16<E>(self, value: u16) -> Result<Value, E>
117        where E: Error
118    {
119        Err(Error::invalid_type(Unexpected::Unsigned(u64::from(value)), &"a signed integer"))
120    }
121
122    #[inline]
123    fn visit_i32<E>(self, value: i32) -> Result<Value, E>
124        where E: Error
125    {
126        Ok(Value::Int32(value))
127    }
128
129    #[inline]
130    fn visit_u32<E>(self, value: u32) -> Result<Value, E>
131        where E: Error
132    {
133        Err(Error::invalid_type(Unexpected::Unsigned(u64::from(value)), &"a signed integer"))
134    }
135
136    #[inline]
137    fn visit_i64<E>(self, value: i64) -> Result<Value, E>
138        where E: Error
139    {
140        Ok(Value::Int64(value))
141    }
142
143    #[inline]
144    fn visit_u64<E>(self, value: u64) -> Result<Value, E>
145        where E: Error
146    {
147        Err(Error::invalid_type(Unexpected::Unsigned(value), &"a signed integer"))
148    }
149
150    #[inline]
151    fn visit_f64<E>(self, value: f64) -> Result<Value, E> {
152        Ok(Value::Double(value))
153    }
154
155    #[inline]
156    fn visit_str<E>(self, value: &str) -> Result<Value, E>
157        where E: de::Error
158    {
159        self.visit_string(value.to_string())
160    }
161
162    #[inline]
163    fn visit_string<E>(self, value: String) -> Result<Value, E> {
164        Ok(Value::String(value))
165    }
166
167    #[inline]
168    fn visit_none<E>(self) -> Result<Value, E> {
169        Ok(Value::Null)
170    }
171
172    #[inline]
173    fn visit_some<D>(self, deserializer: D) -> Result<Value, D::Error>
174        where D: Deserializer<'de>
175    {
176        deserializer.deserialize_any(self)
177    }
178
179    #[inline]
180    fn visit_unit<E>(self) -> Result<Value, E> {
181        Ok(Value::Null)
182    }
183
184    #[inline]
185    fn visit_seq<V>(self, mut visitor: V) -> Result<Value, V::Error>
186        where V: SeqAccess<'de>
187    {
188        let mut values = Array::new();
189
190        while let Some(elem) = visitor.next_element()? {
191            values.push(elem);
192        }
193
194        Ok(Value::Array(values))
195    }
196
197    #[inline]
198    fn visit_map<V>(self, visitor: V) -> Result<Value, V::Error>
199        where V: MapAccess<'de>
200    {
201        let values = DocumentVisitor::new().visit_map(visitor)?;
202        Ok(Value::from_extended_document(values))
203    }
204}
205
206#[derive(Default)]
207pub struct DocumentVisitor {
208    marker: PhantomData<Document>
209}
210
211impl DocumentVisitor {
212    pub fn new() -> DocumentVisitor {
213        DocumentVisitor { marker: PhantomData }
214    }
215}
216
217impl<'de> Visitor<'de> for DocumentVisitor {
218    type Value = Document;
219
220    fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
221        write!(f, "expecting ordered object")
222    }
223
224    #[inline]
225    fn visit_unit<E>(self) -> result::Result<Document, E>
226        where E: de::Error
227    {
228        Ok(Document::new())
229    }
230
231    #[inline]
232    fn visit_map<V>(self, mut visitor: V) -> result::Result<Document, V::Error>
233        where V: MapAccess<'de>
234    {
235        let mut inner = match visitor.size_hint() {
236            Some(size) => IndexMap::with_capacity(size),
237            None => IndexMap::new(),
238        };
239
240        while let Some((key, value)) = visitor.next_entry()? {
241            inner.insert(key, value);
242        }
243
244        Ok(inner.into())
245    }
246}
247
248/// Serde Decoder
249pub struct Decoder {
250    value: Option<Value>,
251}
252
253impl Decoder {
254    pub fn new(value: Value) -> Decoder {
255        Decoder { value: Some(value) }
256    }
257}
258
/// Generates `Deserializer` methods for the listed names.
///
/// Input is a list of `method_name(arg: Type, ...);` entries. Every generated
/// method ignores its extra arguments (they are bound as `_: Type`). For
/// `deserialize_enum`, the generated method always fails with the custom error
/// "unexpected Enum". Every other generated method forwards to
/// `self.deserialize_any(visitor)`.
macro_rules! forward_to_deserialize {
    // Entry point: re-invoke the macro once per listed method, tagged with
    // `func:` so that the per-method arms below are selected.
    ($(
        $name:ident ( $( $arg:ident : $ty:ty ),* );
    )*) => {
        $(
            forward_to_deserialize!{
                func: $name ( $( $arg: $ty ),* );
            }
        )*
    };

    // Special case, which must precede the generic arm because macro_rules tries
    // arms in order: enums cannot be decoded by these deserializers, so the
    // visitor is never called and an error is returned instead.
    (func: deserialize_enum ( $( $arg:ident : $ty:ty ),* );) => {
        fn deserialize_enum<V>(
            self,
            $(_: $ty,)*
            _visitor: V,
        ) -> ::std::result::Result<V::Value, Self::Error>
            where V: ::serde::de::Visitor<'de>
        {
            Err(::serde::de::Error::custom("unexpected Enum"))
        }
    };

    // Generic case: discard the type hint and let `deserialize_any` drive the
    // visitor based on the data actually present.
    (func: $name:ident ( $( $arg:ident : $ty:ty ),* );) => {
        #[inline]
        fn $name<V>(
            self,
            $(_: $ty,)*
            visitor: V,
        ) -> ::std::result::Result<V::Value, Self::Error>
            where V: ::serde::de::Visitor<'de>
        {
            self.deserialize_any(visitor)
        }
    };
}
295
296impl<'de> Deserializer<'de> for Decoder {
297    type Error = DecodeError;
298
299    #[inline]
300    fn deserialize_any<V>(mut self, visitor: V) -> DecodeResult<V::Value>
301        where V: Visitor<'de>
302    {
303        let value = match self.value.take() {
304            Some(value) => value,
305            None => return Err(DecodeError::EndOfStream),
306        };
307
308        match value {
309            Value::Double(v) => visitor.visit_f64(v),
310            Value::String(v) => visitor.visit_string(v),
311            Value::Array(v) => {
312                let len = v.len();
313                visitor.visit_seq(
314                    SeqDecoder {
315                        iter: v.into_iter(),
316                        len,
317                    }
318                )
319            }
320            Value::Document(v) => {
321                let len = v.len();
322                visitor.visit_map(
323                    MapDecoder {
324                        iter: v.into_iter(),
325                        value: None,
326                        len,
327                    }
328                )
329            }
330            Value::Boolean(v) => visitor.visit_bool(v),
331            Value::Null => visitor.visit_unit(),
332            Value::Int32(v) => visitor.visit_i32(v),
333            Value::Int64(v) => visitor.visit_i64(v),
334            Value::Binary(_, v) => visitor.visit_bytes(&v),
335            _ => {
336                let doc = value.to_extended_document();
337                let len = doc.len();
338                visitor.visit_map(
339                    MapDecoder {
340                        iter: doc.into_iter(),
341                        value: None,
342                        len,
343                    }
344                )
345            }
346        }
347    }
348
349    #[inline]
350    fn deserialize_option<V>(self, visitor: V) -> DecodeResult<V::Value>
351        where V: Visitor<'de>
352    {
353        match self.value {
354            Some(Value::Null) => visitor.visit_none(),
355            Some(_) => visitor.visit_some(self),
356            None => Err(DecodeError::EndOfStream),
357        }
358    }
359
360    #[inline]
361    fn deserialize_enum<V>(
362        mut self,
363        _name: &str,
364        _variants: &'static [&'static str],
365        visitor: V
366    ) -> DecodeResult<V::Value>
367        where V: Visitor<'de>
368    {
369        let value = match self.value.take() {
370            Some(Value::Document(value)) => value,
371            Some(Value::String(variant)) => {
372                return visitor.visit_enum(
373                    EnumDecoder {
374                        val: Value::String(variant),
375                        decoder: VariantDecoder { val: None },
376                    }
377                );
378            }
379            Some(_) => {
380                return Err(DecodeError::InvalidType("expected an enum".to_string()));
381            }
382            None => {
383                return Err(DecodeError::EndOfStream);
384            }
385        };
386
387        let mut iter = value.into_iter();
388
389        let (variant, value) = match iter.next() {
390            Some(v) => v,
391            None => return Err(DecodeError::SyntaxError("expected a variant name".to_string())),
392        };
393
394        // enums are encoded in json as maps with a single key:value pair
395        match iter.next() {
396            Some(_) => {
397                Err(DecodeError::InvalidType("expected a single key:value pair".to_string()))
398            }
399            None => {
400                visitor.visit_enum(
401                    EnumDecoder {
402                        val: Value::String(variant),
403                        decoder: VariantDecoder { val: Some(value) },
404                    }
405                )
406            }
407        }
408    }
409
410    #[inline]
411    fn deserialize_newtype_struct<V>(
412        self,
413        _name: &'static str,
414        visitor: V
415    ) -> DecodeResult<V::Value>
416        where V: Visitor<'de>
417    {
418        visitor.visit_newtype_struct(self)
419    }
420
421    forward_to_deserialize!{
422        deserialize_bool();
423        deserialize_u8();
424        deserialize_u16();
425        deserialize_u32();
426        deserialize_u64();
427        deserialize_i8();
428        deserialize_i16();
429        deserialize_i32();
430        deserialize_i64();
431        deserialize_f32();
432        deserialize_f64();
433        deserialize_char();
434        deserialize_str();
435        deserialize_string();
436        deserialize_unit();
437        deserialize_seq();
438        deserialize_bytes();
439        deserialize_map();
440        deserialize_unit_struct(name: &'static str);
441        deserialize_tuple_struct(name: &'static str, len: usize);
442        deserialize_struct(name: &'static str, fields: &'static [&'static str]);
443        deserialize_tuple(len: usize);
444        deserialize_identifier();
445        deserialize_ignored_any();
446        deserialize_byte_buf();
447    }
448}
449
450struct EnumDecoder {
451    val: Value,
452    decoder: VariantDecoder,
453}
454
455impl<'de> EnumAccess<'de> for EnumDecoder {
456    type Error = DecodeError;
457    type Variant = VariantDecoder;
458    fn variant_seed<V>(self, seed: V) -> DecodeResult<(V::Value, Self::Variant)>
459        where V: DeserializeSeed<'de>
460    {
461        let dec = Decoder::new(self.val);
462        let value = seed.deserialize(dec)?;
463        Ok((value, self.decoder))
464    }
465}
466
467struct VariantDecoder {
468    val: Option<Value>,
469}
470
471impl<'de> VariantAccess<'de> for VariantDecoder {
472    type Error = DecodeError;
473
474    fn unit_variant(mut self) -> DecodeResult<()> {
475        match self.val.take() {
476            None => Ok(()),
477            Some(val) => {
478                Value::deserialize(Decoder::new(val)).map(|_| ())
479            }
480        }
481    }
482
483    fn newtype_variant_seed<T>(mut self, seed: T) -> DecodeResult<T::Value>
484        where T: DeserializeSeed<'de>
485    {
486        let dec = Decoder::new(self.val.take().ok_or(DecodeError::EndOfStream)?);
487        seed.deserialize(dec)
488    }
489
490    fn tuple_variant<V>(mut self, _len: usize, visitor: V) -> DecodeResult<V::Value>
491        where V: Visitor<'de>
492    {
493        if let Value::Array(fields) = self.val.take().ok_or(DecodeError::EndOfStream)? {
494
495            let de = SeqDecoder {
496                len: fields.len(),
497                iter: fields.into_iter(),
498            };
499            de.deserialize_any(visitor)
500        } else {
501            return Err(DecodeError::InvalidType("expected a tuple".to_string()));
502        }
503    }
504
505    fn struct_variant<V>(
506        mut self,
507        _fields: &'static [&'static str],
508        visitor: V
509    ) -> DecodeResult<V::Value>
510        where V: Visitor<'de>
511    {
512        if let Value::Document(fields) = self.val.take().ok_or(DecodeError::EndOfStream)? {
513            let de = MapDecoder {
514                len: fields.len(),
515                iter: fields.into_iter(),
516                value: None,
517            };
518            de.deserialize_any(visitor)
519        } else {
520            return Err(DecodeError::InvalidType("expected a struct".to_string()));
521        }
522    }
523}
524
525struct SeqDecoder {
526    iter: vec::IntoIter<Value>,
527    len: usize,
528}
529
530impl<'de> Deserializer<'de> for SeqDecoder {
531    type Error = DecodeError;
532
533    #[inline]
534    fn deserialize_any<V>(self, visitor: V) -> DecodeResult<V::Value>
535        where V: Visitor<'de>
536    {
537        if self.len == 0 {
538            visitor.visit_unit()
539        } else {
540            visitor.visit_seq(self)
541        }
542    }
543
544    forward_to_deserialize!{
545        deserialize_bool();
546        deserialize_u8();
547        deserialize_u16();
548        deserialize_u32();
549        deserialize_u64();
550        deserialize_i8();
551        deserialize_i16();
552        deserialize_i32();
553        deserialize_i64();
554        deserialize_f32();
555        deserialize_f64();
556        deserialize_char();
557        deserialize_str();
558        deserialize_string();
559        deserialize_unit();
560        deserialize_option();
561        deserialize_seq();
562        deserialize_bytes();
563        deserialize_map();
564        deserialize_unit_struct(name: &'static str);
565        deserialize_newtype_struct(name: &'static str);
566        deserialize_tuple_struct(name: &'static str, len: usize);
567        deserialize_struct(name: &'static str, fields: &'static [&'static str]);
568        deserialize_tuple(len: usize);
569        deserialize_enum(name: &'static str, variants: &'static [&'static str]);
570        deserialize_identifier();
571        deserialize_ignored_any();
572        deserialize_byte_buf();
573    }
574}
575
576impl<'de> SeqAccess<'de> for SeqDecoder {
577    type Error = DecodeError;
578
579    fn next_element_seed<T>(&mut self, seed: T) -> DecodeResult<Option<T::Value>>
580        where T: DeserializeSeed<'de>
581    {
582        match self.iter.next() {
583            None => Ok(None),
584            Some(value) => {
585                self.len -= 1;
586                let de = Decoder::new(value);
587                match seed.deserialize(de) {
588                    Ok(value) => Ok(Some(value)),
589                    Err(err) => Err(err),
590                }
591            }
592        }
593    }
594
595    fn size_hint(&self) -> Option<usize> {
596        Some(self.len)
597    }
598}
599
600struct MapDecoder {
601    iter: IntoIter<String, Value>,
602    value: Option<Value>,
603    len: usize,
604}
605
606impl<'de> MapAccess<'de> for MapDecoder {
607    type Error = DecodeError;
608
609    fn next_key_seed<K>(&mut self, seed: K) -> DecodeResult<Option<K::Value>>
610        where K: DeserializeSeed<'de>
611    {
612        match self.iter.next() {
613            Some((key, value)) => {
614                self.len -= 1;
615                self.value = Some(value);
616
617                let de = Decoder::new(Value::String(key));
618                match seed.deserialize(de) {
619                    Ok(val) => Ok(Some(val)),
620                    Err(DecodeError::UnknownField(_)) => Ok(None),
621                    Err(e) => Err(e),
622                }
623            }
624            None => Ok(None),
625        }
626    }
627
628    fn next_value_seed<V>(&mut self, seed: V) -> DecodeResult<V::Value>
629        where V: DeserializeSeed<'de>
630    {
631        let value = self.value.take().ok_or(DecodeError::EndOfStream)?;
632        let de = Decoder::new(value);
633        seed.deserialize(de)
634    }
635
636    fn size_hint(&self) -> Option<usize> {
637        Some(self.len)
638    }
639}
640
641impl<'de> Deserializer<'de> for MapDecoder {
642    type Error = DecodeError;
643
644    #[inline]
645    fn deserialize_any<V>(self, visitor: V) -> DecodeResult<V::Value>
646        where V: Visitor<'de>
647    {
648        visitor.visit_map(self)
649    }
650
651    forward_to_deserialize!{
652        deserialize_bool();
653        deserialize_u8();
654        deserialize_u16();
655        deserialize_u32();
656        deserialize_u64();
657        deserialize_i8();
658        deserialize_i16();
659        deserialize_i32();
660        deserialize_i64();
661        deserialize_f32();
662        deserialize_f64();
663        deserialize_char();
664        deserialize_str();
665        deserialize_string();
666        deserialize_unit();
667        deserialize_option();
668        deserialize_seq();
669        deserialize_bytes();
670        deserialize_map();
671        deserialize_unit_struct(name: &'static str);
672        deserialize_newtype_struct(name: &'static str);
673        deserialize_tuple_struct(name: &'static str, len: usize);
674        deserialize_struct(name: &'static str, fields: &'static [&'static str]);
675        deserialize_tuple(len: usize);
676        deserialize_enum(name: &'static str, variants: &'static [&'static str]);
677        deserialize_identifier();
678        deserialize_ignored_any();
679        deserialize_byte_buf();
680    }
681}
682
683impl<'de> Deserialize<'de> for UTCDateTime {
684    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
685        where D: Deserializer<'de>
686    {
687        match Value::deserialize(deserializer)? {
688            Value::UTCDatetime(dt) => Ok(UTCDateTime(dt)),
689            _ => Err(D::Error::custom("expecting UtcDateTime")),
690        }
691    }
692}
693
694impl<'de> Deserialize<'de> for TimeStamp {
695    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
696        where D: Deserializer<'de>
697    {
698        match Value::deserialize(deserializer)? {
699            Value::TimeStamp(ts) => {
700                let ts = ts.to_le();
701
702                Ok(TimeStamp {
703                    timestamp: ((ts as u64) >> 32) as u32,
704                    increment: (ts & 0xFFFF_FFFF) as u32,
705                })
706            }
707            _ => Err(D::Error::custom("expecting UtcDateTime")),
708        }
709    }
710}