1use std::convert::TryInto;
2
3use serde::de::{DeserializeSeed, Deserializer, MapAccess, Visitor};
4use serde::forward_to_deserialize_any;
5
6use super::Error;
7
/// Marker struct name recognised by `DateTimeDeserializer::deserialize_struct`.
pub static NAME: &'static str = "$__bson_DateTime";
/// The single key under which the datetime's integer payload is exposed.
pub static FIELD: &'static str = "$date";
/// Field list of the marker struct: just `FIELD`.
pub static FIELDS: &'static [&'static str] = &[FIELD];
11
12struct DateTimeKeyDeserializer {
13    key: &'static str,
14}
15
16impl DateTimeKeyDeserializer {
17    fn new(key: &'static str) -> DateTimeKeyDeserializer {
18        DateTimeKeyDeserializer { key }
19    }
20}
21
22impl<'de> Deserializer<'de> for DateTimeKeyDeserializer {
23    type Error = Error;
24
25    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
26    where
27        V: Visitor<'de>,
28    {
29        visitor.visit_str(self.key)
30    }
31
32    forward_to_deserialize_any!(
33        bool u8 u16 u32 u64 i8 i16 i32 i64 f32 f64 char str string seq
34        bytes byte_buf map struct option unit newtype_struct
35        ignored_any unit_struct tuple_struct tuple enum identifier
36    );
37}
38
/// Deserializer for a BSON datetime whose payload is held as an `i64`.
///
/// `visited` records whether the single `$date` map entry has already been
/// handed out.
pub struct DateTimeDeserializer {
    data: i64,
    visited: bool,
}

impl DateTimeDeserializer {
    /// Wraps `data`; the map entry starts out not yet consumed.
    pub fn new(data: i64) -> Self {
        let visited = false;
        Self { visited, data }
    }
}
52
53impl<'de> Deserializer<'de> for DateTimeDeserializer {
54    type Error = Error;
55
56    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
57    where
58        V: Visitor<'de>,
59    {
60        self.deserialize_struct(NAME, FIELDS, visitor)
61    }
62
63    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Error>
64    where
65        V: Visitor<'de>,
66    {
67        visitor.visit_i64(self.data)
68    }
69
70    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Error>
71    where
72        V: Visitor<'de>,
73    {
74        visitor.visit_u64(self.data.try_into()?)
75    }
76
77    fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value, Self::Error> {
78        visitor.visit_map(self)
79    }
80
81    fn deserialize_struct<V: Visitor<'de>>(
82        self,
83        name: &str,
84        _fields: &[&str],
85        visitor: V,
86    ) -> Result<V::Value, Self::Error> {
87        if name == NAME {
88            visitor.visit_map(self)
89        } else {
90            Err(Error::MalformedDocument)
91        }
92    }
93
94    forward_to_deserialize_any!(
95        bool u8 u16 u32 i8 i16 i32 f32 f64 char bytes byte_buf
96        option unit newtype_struct str string tuple
97        ignored_any seq unit_struct tuple_struct enum identifier
98    );
99}
100
101impl<'de> MapAccess<'de> for DateTimeDeserializer {
102    type Error = Error;
103
104    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Error>
105    where
106        K: DeserializeSeed<'de>,
107    {
108        match self.visited {
109            false => seed
110                .deserialize(DateTimeKeyDeserializer::new(FIELD))
111                .map(Some),
112            true => Ok(None),
113        }
114    }
115
116    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Error>
117    where
118        V: DeserializeSeed<'de>,
119    {
120        match self.visited {
121            false => {
122                self.visited = true;
123                seed.deserialize(DateTimeFieldDeserializer::new(self.data))
124            }
125            true => Err(Error::MalformedDocument),
126        }
127    }
128}
129
130struct DateTimeFieldDeserializer {
131    data: i64,
132}
133
134impl<'de> DateTimeFieldDeserializer {
135    fn new(data: i64) -> DateTimeFieldDeserializer {
136        DateTimeFieldDeserializer { data }
137    }
138}
139
140impl<'de> Deserializer<'de> for DateTimeFieldDeserializer {
141    type Error = Error;
142
143    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Error>
144    where
145        V: Visitor<'de>,
146    {
147        self.deserialize_i64(visitor)
148    }
149
150    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Error>
151    where
152        V: Visitor<'de>,
153    {
154        visitor.visit_i64(self.data)
155    }
156
157    forward_to_deserialize_any!(
158        bool u8 u16 u32 u64 i8 i16 i32 f32 f64 char seq
159        bytes byte_buf str string map struct option unit newtype_struct
160        ignored_any unit_struct tuple_struct tuple enum identifier
161    );
162}