1use std::fmt;
2use std::vec;
3use std::result;
4use std::marker::PhantomData;
5use std::{i32, u32};
6
7use serde::de::{self, Deserialize, Deserializer, Visitor, MapAccess, SeqAccess, VariantAccess,
8 DeserializeSeed, EnumAccess};
9use serde::de::{Error, Expected, Unexpected};
10
11use indexmap::IndexMap;
12
13use crate::value::{Value, Array, UTCDateTime, TimeStamp};
14use crate::doc::{Document, IntoIter};
15use crate::decode::DecodeError;
16use crate::decode::DecodeResult;
17
18impl de::Error for DecodeError {
19 fn custom<T: fmt::Display>(msg: T) -> DecodeError {
20 DecodeError::Unknown(msg.to_string())
21 }
22
23 fn invalid_type(_unexp: Unexpected, exp: &dyn Expected) -> DecodeError {
24 DecodeError::InvalidType(exp.to_string())
25 }
26
27 fn invalid_value(_unexp: Unexpected, exp: &dyn Expected) -> DecodeError {
28 DecodeError::InvalidValue(exp.to_string())
29 }
30
31 fn invalid_length(len: usize, exp: &dyn Expected) -> DecodeError {
32 DecodeError::InvalidLength(len, exp.to_string())
33 }
34
35 fn unknown_variant(variant: &str, _expected: &'static [&'static str]) -> DecodeError {
36 DecodeError::UnknownVariant(variant.to_string())
37 }
38
39 fn unknown_field(field: &str, _expected: &'static [&'static str]) -> DecodeError {
40 DecodeError::UnknownField(field.to_string())
41 }
42
43 fn missing_field(field: &'static str) -> DecodeError {
44 DecodeError::ExpectedField(field)
45 }
46
47 fn duplicate_field(field: &'static str) -> DecodeError {
48 DecodeError::DuplicatedField(field)
49 }
50}
51
52impl<'de> Deserialize<'de> for Document {
53 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
55 where D: Deserializer<'de>
56 {
57 deserializer
58 .deserialize_map(ValueVisitor)
59 .and_then(|bson|
60 if let Value::Document(document) = bson {
61 Ok(document)
62 } else {
63 let err = format!("expected document, found extended JSON data type: {}", bson);
64 Err(de::Error::invalid_type(Unexpected::Map, &&*err))
65 })
66 }
67}
68
69impl<'de> Deserialize<'de> for Value {
70 #[inline]
71 fn deserialize<D>(deserializer: D) -> Result<Value, D::Error>
72 where D: Deserializer<'de>
73 {
74 deserializer.deserialize_any(ValueVisitor)
75 }
76}
77
78pub struct ValueVisitor;
79
80impl<'de> Visitor<'de> for ValueVisitor {
81 type Value = Value;
82
83 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
84 write!(f, "expecting a Value")
85 }
86
87 #[inline]
88 fn visit_bool<E>(self, value: bool) -> Result<Value, E>
89 where E: Error
90 {
91 Ok(Value::Boolean(value))
92 }
93
94 #[inline]
95 fn visit_i8<E>(self, value: i8) -> Result<Value, E>
96 where E: Error
97 {
98 Ok(Value::Int32(i32::from(value)))
99 }
100
101 #[inline]
102 fn visit_u8<E>(self, value: u8) -> Result<Value, E>
103 where E: Error
104 {
105 Err(Error::invalid_type(Unexpected::Unsigned(u64::from(value)), &"a signed integer"))
106 }
107
108 #[inline]
109 fn visit_i16<E>(self, value: i16) -> Result<Value, E>
110 where E: Error
111 {
112 Ok(Value::Int32(i32::from(value)))
113 }
114
115 #[inline]
116 fn visit_u16<E>(self, value: u16) -> Result<Value, E>
117 where E: Error
118 {
119 Err(Error::invalid_type(Unexpected::Unsigned(u64::from(value)), &"a signed integer"))
120 }
121
122 #[inline]
123 fn visit_i32<E>(self, value: i32) -> Result<Value, E>
124 where E: Error
125 {
126 Ok(Value::Int32(value))
127 }
128
129 #[inline]
130 fn visit_u32<E>(self, value: u32) -> Result<Value, E>
131 where E: Error
132 {
133 Err(Error::invalid_type(Unexpected::Unsigned(u64::from(value)), &"a signed integer"))
134 }
135
136 #[inline]
137 fn visit_i64<E>(self, value: i64) -> Result<Value, E>
138 where E: Error
139 {
140 Ok(Value::Int64(value))
141 }
142
143 #[inline]
144 fn visit_u64<E>(self, value: u64) -> Result<Value, E>
145 where E: Error
146 {
147 Err(Error::invalid_type(Unexpected::Unsigned(value), &"a signed integer"))
148 }
149
150 #[inline]
151 fn visit_f64<E>(self, value: f64) -> Result<Value, E> {
152 Ok(Value::Double(value))
153 }
154
155 #[inline]
156 fn visit_str<E>(self, value: &str) -> Result<Value, E>
157 where E: de::Error
158 {
159 self.visit_string(value.to_string())
160 }
161
162 #[inline]
163 fn visit_string<E>(self, value: String) -> Result<Value, E> {
164 Ok(Value::String(value))
165 }
166
167 #[inline]
168 fn visit_none<E>(self) -> Result<Value, E> {
169 Ok(Value::Null)
170 }
171
172 #[inline]
173 fn visit_some<D>(self, deserializer: D) -> Result<Value, D::Error>
174 where D: Deserializer<'de>
175 {
176 deserializer.deserialize_any(self)
177 }
178
179 #[inline]
180 fn visit_unit<E>(self) -> Result<Value, E> {
181 Ok(Value::Null)
182 }
183
184 #[inline]
185 fn visit_seq<V>(self, mut visitor: V) -> Result<Value, V::Error>
186 where V: SeqAccess<'de>
187 {
188 let mut values = Array::new();
189
190 while let Some(elem) = visitor.next_element()? {
191 values.push(elem);
192 }
193
194 Ok(Value::Array(values))
195 }
196
197 #[inline]
198 fn visit_map<V>(self, visitor: V) -> Result<Value, V::Error>
199 where V: MapAccess<'de>
200 {
201 let values = DocumentVisitor::new().visit_map(visitor)?;
202 Ok(Value::from_extended_document(values))
203 }
204}
205
206#[derive(Default)]
207pub struct DocumentVisitor {
208 marker: PhantomData<Document>
209}
210
211impl DocumentVisitor {
212 pub fn new() -> DocumentVisitor {
213 DocumentVisitor { marker: PhantomData }
214 }
215}
216
217impl<'de> Visitor<'de> for DocumentVisitor {
218 type Value = Document;
219
220 fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
221 write!(f, "expecting ordered object")
222 }
223
224 #[inline]
225 fn visit_unit<E>(self) -> result::Result<Document, E>
226 where E: de::Error
227 {
228 Ok(Document::new())
229 }
230
231 #[inline]
232 fn visit_map<V>(self, mut visitor: V) -> result::Result<Document, V::Error>
233 where V: MapAccess<'de>
234 {
235 let mut inner = match visitor.size_hint() {
236 Some(size) => IndexMap::with_capacity(size),
237 None => IndexMap::new(),
238 };
239
240 while let Some((key, value)) = visitor.next_entry()? {
241 inner.insert(key, value);
242 }
243
244 Ok(inner.into())
245 }
246}
247
248pub struct Decoder {
250 value: Option<Value>,
251}
252
253impl Decoder {
254 pub fn new(value: Value) -> Decoder {
255 Decoder { value: Some(value) }
256 }
257}
258
// Generates `Deserializer` methods that ignore their type hint and delegate to
// `self.deserialize_any(visitor)`.
//
// Usage (inside an `impl<'de> Deserializer<'de> for ...` block):
//
//     forward_to_deserialize!{
//         deserialize_bool();
//         deserialize_struct(name: &'static str, fields: &'static [&'static str]);
//     }
//
// Each entry names a method plus the hint parameters it takes before `visitor`. The
// parameter names are only used for matching; the generated parameters are all `_`.
// `deserialize_enum` is special-cased (see its arm) and always returns an error.
macro_rules! forward_to_deserialize {
    // Entry arm: re-invoke the macro once per listed method, tagged with `func:` so one of
    // the per-method arms below handles it.
    ($(
        $name:ident ( $( $arg:ident : $ty:ty ),* );
    )*) => {
        $(
            forward_to_deserialize!{
                func: $name ( $( $arg: $ty ),* );
            }
        )*
    };

    // `deserialize_enum` is not forwarded; it fails with a custom "unexpected Enum" error.
    // macro_rules! tries arms in order, so this arm must stay above the generic arm below,
    // which would otherwise also match `deserialize_enum`.
    (func: deserialize_enum ( $( $arg:ident : $ty:ty ),* );) => {
        fn deserialize_enum<V>(
            self,
            $(_: $ty,)*
            _visitor: V,
        ) -> ::std::result::Result<V::Value, Self::Error>
        where V: ::serde::de::Visitor<'de>
        {
            Err(::serde::de::Error::custom("unexpected Enum"))
        }
    };

    // Generic arm: discard the hint arguments and forward to `deserialize_any`.
    (func: $name:ident ( $( $arg:ident : $ty:ty ),* );) => {
        #[inline]
        fn $name<V>(
            self,
            $(_: $ty,)*
            visitor: V,
        ) -> ::std::result::Result<V::Value, Self::Error>
        where V: ::serde::de::Visitor<'de>
        {
            self.deserialize_any(visitor)
        }
    };
}
295
296impl<'de> Deserializer<'de> for Decoder {
297 type Error = DecodeError;
298
299 #[inline]
300 fn deserialize_any<V>(mut self, visitor: V) -> DecodeResult<V::Value>
301 where V: Visitor<'de>
302 {
303 let value = match self.value.take() {
304 Some(value) => value,
305 None => return Err(DecodeError::EndOfStream),
306 };
307
308 match value {
309 Value::Double(v) => visitor.visit_f64(v),
310 Value::String(v) => visitor.visit_string(v),
311 Value::Array(v) => {
312 let len = v.len();
313 visitor.visit_seq(
314 SeqDecoder {
315 iter: v.into_iter(),
316 len,
317 }
318 )
319 }
320 Value::Document(v) => {
321 let len = v.len();
322 visitor.visit_map(
323 MapDecoder {
324 iter: v.into_iter(),
325 value: None,
326 len,
327 }
328 )
329 }
330 Value::Boolean(v) => visitor.visit_bool(v),
331 Value::Null => visitor.visit_unit(),
332 Value::Int32(v) => visitor.visit_i32(v),
333 Value::Int64(v) => visitor.visit_i64(v),
334 Value::Binary(_, v) => visitor.visit_bytes(&v),
335 _ => {
336 let doc = value.to_extended_document();
337 let len = doc.len();
338 visitor.visit_map(
339 MapDecoder {
340 iter: doc.into_iter(),
341 value: None,
342 len,
343 }
344 )
345 }
346 }
347 }
348
349 #[inline]
350 fn deserialize_option<V>(self, visitor: V) -> DecodeResult<V::Value>
351 where V: Visitor<'de>
352 {
353 match self.value {
354 Some(Value::Null) => visitor.visit_none(),
355 Some(_) => visitor.visit_some(self),
356 None => Err(DecodeError::EndOfStream),
357 }
358 }
359
360 #[inline]
361 fn deserialize_enum<V>(
362 mut self,
363 _name: &str,
364 _variants: &'static [&'static str],
365 visitor: V
366 ) -> DecodeResult<V::Value>
367 where V: Visitor<'de>
368 {
369 let value = match self.value.take() {
370 Some(Value::Document(value)) => value,
371 Some(Value::String(variant)) => {
372 return visitor.visit_enum(
373 EnumDecoder {
374 val: Value::String(variant),
375 decoder: VariantDecoder { val: None },
376 }
377 );
378 }
379 Some(_) => {
380 return Err(DecodeError::InvalidType("expected an enum".to_string()));
381 }
382 None => {
383 return Err(DecodeError::EndOfStream);
384 }
385 };
386
387 let mut iter = value.into_iter();
388
389 let (variant, value) = match iter.next() {
390 Some(v) => v,
391 None => return Err(DecodeError::SyntaxError("expected a variant name".to_string())),
392 };
393
394 match iter.next() {
396 Some(_) => {
397 Err(DecodeError::InvalidType("expected a single key:value pair".to_string()))
398 }
399 None => {
400 visitor.visit_enum(
401 EnumDecoder {
402 val: Value::String(variant),
403 decoder: VariantDecoder { val: Some(value) },
404 }
405 )
406 }
407 }
408 }
409
410 #[inline]
411 fn deserialize_newtype_struct<V>(
412 self,
413 _name: &'static str,
414 visitor: V
415 ) -> DecodeResult<V::Value>
416 where V: Visitor<'de>
417 {
418 visitor.visit_newtype_struct(self)
419 }
420
421 forward_to_deserialize!{
422 deserialize_bool();
423 deserialize_u8();
424 deserialize_u16();
425 deserialize_u32();
426 deserialize_u64();
427 deserialize_i8();
428 deserialize_i16();
429 deserialize_i32();
430 deserialize_i64();
431 deserialize_f32();
432 deserialize_f64();
433 deserialize_char();
434 deserialize_str();
435 deserialize_string();
436 deserialize_unit();
437 deserialize_seq();
438 deserialize_bytes();
439 deserialize_map();
440 deserialize_unit_struct(name: &'static str);
441 deserialize_tuple_struct(name: &'static str, len: usize);
442 deserialize_struct(name: &'static str, fields: &'static [&'static str]);
443 deserialize_tuple(len: usize);
444 deserialize_identifier();
445 deserialize_ignored_any();
446 deserialize_byte_buf();
447 }
448}
449
450struct EnumDecoder {
451 val: Value,
452 decoder: VariantDecoder,
453}
454
455impl<'de> EnumAccess<'de> for EnumDecoder {
456 type Error = DecodeError;
457 type Variant = VariantDecoder;
458 fn variant_seed<V>(self, seed: V) -> DecodeResult<(V::Value, Self::Variant)>
459 where V: DeserializeSeed<'de>
460 {
461 let dec = Decoder::new(self.val);
462 let value = seed.deserialize(dec)?;
463 Ok((value, self.decoder))
464 }
465}
466
467struct VariantDecoder {
468 val: Option<Value>,
469}
470
471impl<'de> VariantAccess<'de> for VariantDecoder {
472 type Error = DecodeError;
473
474 fn unit_variant(mut self) -> DecodeResult<()> {
475 match self.val.take() {
476 None => Ok(()),
477 Some(val) => {
478 Value::deserialize(Decoder::new(val)).map(|_| ())
479 }
480 }
481 }
482
483 fn newtype_variant_seed<T>(mut self, seed: T) -> DecodeResult<T::Value>
484 where T: DeserializeSeed<'de>
485 {
486 let dec = Decoder::new(self.val.take().ok_or(DecodeError::EndOfStream)?);
487 seed.deserialize(dec)
488 }
489
490 fn tuple_variant<V>(mut self, _len: usize, visitor: V) -> DecodeResult<V::Value>
491 where V: Visitor<'de>
492 {
493 if let Value::Array(fields) = self.val.take().ok_or(DecodeError::EndOfStream)? {
494
495 let de = SeqDecoder {
496 len: fields.len(),
497 iter: fields.into_iter(),
498 };
499 de.deserialize_any(visitor)
500 } else {
501 return Err(DecodeError::InvalidType("expected a tuple".to_string()));
502 }
503 }
504
505 fn struct_variant<V>(
506 mut self,
507 _fields: &'static [&'static str],
508 visitor: V
509 ) -> DecodeResult<V::Value>
510 where V: Visitor<'de>
511 {
512 if let Value::Document(fields) = self.val.take().ok_or(DecodeError::EndOfStream)? {
513 let de = MapDecoder {
514 len: fields.len(),
515 iter: fields.into_iter(),
516 value: None,
517 };
518 de.deserialize_any(visitor)
519 } else {
520 return Err(DecodeError::InvalidType("expected a struct".to_string()));
521 }
522 }
523}
524
525struct SeqDecoder {
526 iter: vec::IntoIter<Value>,
527 len: usize,
528}
529
530impl<'de> Deserializer<'de> for SeqDecoder {
531 type Error = DecodeError;
532
533 #[inline]
534 fn deserialize_any<V>(self, visitor: V) -> DecodeResult<V::Value>
535 where V: Visitor<'de>
536 {
537 if self.len == 0 {
538 visitor.visit_unit()
539 } else {
540 visitor.visit_seq(self)
541 }
542 }
543
544 forward_to_deserialize!{
545 deserialize_bool();
546 deserialize_u8();
547 deserialize_u16();
548 deserialize_u32();
549 deserialize_u64();
550 deserialize_i8();
551 deserialize_i16();
552 deserialize_i32();
553 deserialize_i64();
554 deserialize_f32();
555 deserialize_f64();
556 deserialize_char();
557 deserialize_str();
558 deserialize_string();
559 deserialize_unit();
560 deserialize_option();
561 deserialize_seq();
562 deserialize_bytes();
563 deserialize_map();
564 deserialize_unit_struct(name: &'static str);
565 deserialize_newtype_struct(name: &'static str);
566 deserialize_tuple_struct(name: &'static str, len: usize);
567 deserialize_struct(name: &'static str, fields: &'static [&'static str]);
568 deserialize_tuple(len: usize);
569 deserialize_enum(name: &'static str, variants: &'static [&'static str]);
570 deserialize_identifier();
571 deserialize_ignored_any();
572 deserialize_byte_buf();
573 }
574}
575
576impl<'de> SeqAccess<'de> for SeqDecoder {
577 type Error = DecodeError;
578
579 fn next_element_seed<T>(&mut self, seed: T) -> DecodeResult<Option<T::Value>>
580 where T: DeserializeSeed<'de>
581 {
582 match self.iter.next() {
583 None => Ok(None),
584 Some(value) => {
585 self.len -= 1;
586 let de = Decoder::new(value);
587 match seed.deserialize(de) {
588 Ok(value) => Ok(Some(value)),
589 Err(err) => Err(err),
590 }
591 }
592 }
593 }
594
595 fn size_hint(&self) -> Option<usize> {
596 Some(self.len)
597 }
598}
599
600struct MapDecoder {
601 iter: IntoIter<String, Value>,
602 value: Option<Value>,
603 len: usize,
604}
605
606impl<'de> MapAccess<'de> for MapDecoder {
607 type Error = DecodeError;
608
609 fn next_key_seed<K>(&mut self, seed: K) -> DecodeResult<Option<K::Value>>
610 where K: DeserializeSeed<'de>
611 {
612 match self.iter.next() {
613 Some((key, value)) => {
614 self.len -= 1;
615 self.value = Some(value);
616
617 let de = Decoder::new(Value::String(key));
618 match seed.deserialize(de) {
619 Ok(val) => Ok(Some(val)),
620 Err(DecodeError::UnknownField(_)) => Ok(None),
621 Err(e) => Err(e),
622 }
623 }
624 None => Ok(None),
625 }
626 }
627
628 fn next_value_seed<V>(&mut self, seed: V) -> DecodeResult<V::Value>
629 where V: DeserializeSeed<'de>
630 {
631 let value = self.value.take().ok_or(DecodeError::EndOfStream)?;
632 let de = Decoder::new(value);
633 seed.deserialize(de)
634 }
635
636 fn size_hint(&self) -> Option<usize> {
637 Some(self.len)
638 }
639}
640
641impl<'de> Deserializer<'de> for MapDecoder {
642 type Error = DecodeError;
643
644 #[inline]
645 fn deserialize_any<V>(self, visitor: V) -> DecodeResult<V::Value>
646 where V: Visitor<'de>
647 {
648 visitor.visit_map(self)
649 }
650
651 forward_to_deserialize!{
652 deserialize_bool();
653 deserialize_u8();
654 deserialize_u16();
655 deserialize_u32();
656 deserialize_u64();
657 deserialize_i8();
658 deserialize_i16();
659 deserialize_i32();
660 deserialize_i64();
661 deserialize_f32();
662 deserialize_f64();
663 deserialize_char();
664 deserialize_str();
665 deserialize_string();
666 deserialize_unit();
667 deserialize_option();
668 deserialize_seq();
669 deserialize_bytes();
670 deserialize_map();
671 deserialize_unit_struct(name: &'static str);
672 deserialize_newtype_struct(name: &'static str);
673 deserialize_tuple_struct(name: &'static str, len: usize);
674 deserialize_struct(name: &'static str, fields: &'static [&'static str]);
675 deserialize_tuple(len: usize);
676 deserialize_enum(name: &'static str, variants: &'static [&'static str]);
677 deserialize_identifier();
678 deserialize_ignored_any();
679 deserialize_byte_buf();
680 }
681}
682
683impl<'de> Deserialize<'de> for UTCDateTime {
684 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
685 where D: Deserializer<'de>
686 {
687 match Value::deserialize(deserializer)? {
688 Value::UTCDatetime(dt) => Ok(UTCDateTime(dt)),
689 _ => Err(D::Error::custom("expecting UtcDateTime")),
690 }
691 }
692}
693
694impl<'de> Deserialize<'de> for TimeStamp {
695 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
696 where D: Deserializer<'de>
697 {
698 match Value::deserialize(deserializer)? {
699 Value::TimeStamp(ts) => {
700 let ts = ts.to_le();
701
702 Ok(TimeStamp {
703 timestamp: ((ts as u64) >> 32) as u32,
704 increment: (ts & 0xFFFF_FFFF) as u32,
705 })
706 }
707 _ => Err(D::Error::custom("expecting UtcDateTime")),
708 }
709 }
710}