msgpack/
de.rs

1use crate::code;
2use crate::unpack;
3use crate::unpack_error;
4use crate::BufferedRead;
5
6use serde;
7use serde::de::{self, Deserialize, DeserializeOwned, DeserializeSeed, Visitor};
8use serde::forward_to_deserialize_any;
9use std::io;
10
11use std::error;
12use std::fmt::{self, Display};
13
14#[derive(Debug)]
15pub enum DeError {
16    InvalidSize,
17    UnpackError(unpack_error::UnpackError),
18    Custom(String),
19}
20
21impl From<unpack_error::UnpackError> for DeError {
22    fn from(err: unpack_error::UnpackError) -> DeError {
23        DeError::UnpackError(err)
24    }
25}
26
27impl Display for DeError {
28    fn fmt(&self, f: &mut fmt::Formatter) -> Result<(), fmt::Error> {
29        error::Error::description(self).fmt(f)
30    }
31}
32
33impl error::Error for DeError {
34    fn description(&self) -> &str {
35        use DeError::*;
36
37        match *self {
38            InvalidSize => "invalid size",
39            UnpackError(ref e) => e.description(),
40            Custom(ref s) => s,
41        }
42    }
43
44    fn cause(&self) -> Option<&dyn error::Error> {
45        use DeError::*;
46
47        match *self {
48            UnpackError(ref e) => Some(e),
49            Custom(_) => None,
50            InvalidSize => None,
51        }
52    }
53}
54
55impl serde::de::Error for DeError {
56    fn custom<T: Display>(msg: T) -> DeError {
57        DeError::Custom(msg.to_string())
58    }
59}
60
61struct PeekReader<R> {
62    code: Option<code::Code>,
63    reader: R,
64}
65
66impl<R: io::Read> PeekReader<R> {
67    pub fn peek_code(&mut self) -> Result<&code::Code, unpack_error::UnpackError> {
68        if let Some(ref v) = self.code {
69            Ok(v)
70        } else {
71            let code = unpack::read_code(&mut self.reader)?;
72            self.code = Some(code);
73            Ok(self.code.as_ref().unwrap())
74        }
75    }
76
77    pub fn consume_code(&mut self) -> Option<code::Code> {
78        self.code.take()
79    }
80}
81
82impl<R: io::Read> io::Read for PeekReader<R> {
83    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
84        if let Some(ref v) = self.code {
85            buf[0] = v.to_u8();
86            if buf.len() > 1 {
87                self.reader.read(&mut buf[1..])
88            } else {
89                Ok(1)
90            }
91        } else {
92            self.reader.read(buf)
93        }
94    }
95}
96
97impl<'a, R: BufferedRead<'a>> BufferedRead<'a> for PeekReader<R> {
98    fn fill_buf(&self) -> io::Result<&'a [u8]> {
99        self.reader.fill_buf()
100    }
101
102    fn consume(&mut self, len: usize) {
103        self.reader.consume(len)
104    }
105}
106
107struct SeqAccess<'a, R: io::Read + 'a> {
108    de: &'a mut Deserializer<R>,
109    len: usize,
110}
111
112impl<'de, 'a, R> serde::de::SeqAccess<'de> for SeqAccess<'a, R>
113where
114    R: BufferedRead<'de> + 'a,
115{
116    type Error = DeError;
117
118    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
119    where
120        T: serde::de::DeserializeSeed<'de>,
121    {
122        if self.len > 0 {
123            self.len -= 1;
124            Ok(Some(seed.deserialize(&mut *self.de)?))
125        } else {
126            Ok(None)
127        }
128    }
129
130    fn size_hint(&self) -> Option<usize> {
131        Some(self.len)
132    }
133}
134
135struct MapAccess<'a, R: 'a> {
136    de: &'a mut Deserializer<R>,
137    len: usize,
138}
139
140impl<'de, 'a, R> de::MapAccess<'de> for MapAccess<'a, R>
141where
142    R: BufferedRead<'de> + 'a,
143{
144    type Error = DeError;
145
146    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
147    where
148        K: DeserializeSeed<'de>,
149    {
150        if self.len > 0 {
151            self.len -= 1;
152            Ok(Some(seed.deserialize(&mut *self.de)?))
153        } else {
154            Ok(None)
155        }
156    }
157
158    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
159    where
160        V: DeserializeSeed<'de>,
161    {
162        Ok(seed.deserialize(&mut *self.de)?)
163    }
164
165    fn size_hint(&self) -> Option<usize> {
166        Some(self.len)
167    }
168}
169
170// struct BytesAccess<'a, R: io::Read + 'a> {
171//     de: &'a mut Deserializer<R>,
172//     len: usize,
173// }
174
175// impl<'de, 'a, R> serde::de::SeqAccess<'de> for BytesAccess<'a, R>
176// where
177//     R: BufferedRead<'de> + 'a,
178// {
179//     type Error = DeError;
180
181//     fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
182//     where
183//         T: serde::de::DeserializeSeed<'de>,
184//     {
185//         if self.len > 0 {
186//             self.len -= 1;
187//       // seed.deserialize(MapKey { de: &mut *self.de }).map(Some),
188//             Ok(Some(seed.deserialize(&mut *self.de)?))
189//         } else {
190//             Ok(None)
191//         }
192//     }
193
194//     fn size_hint(&self) -> Option<usize> {
195//         Some(self.len)
196//     }
197// }
198
199pub struct Deserializer<R> {
200    reader: PeekReader<R>,
201}
202
203impl<R> Deserializer<R> {
204    pub fn new(r: R) -> Self {
205        Deserializer {
206            reader: PeekReader {
207                code: None,
208                reader: r,
209            },
210        }
211    }
212}
213
// Expands to a `deserialize_*` method that decodes one number with the named
// `unpack::*` helper and passes it to the matching `visit_*` method. Must be
// invoked inside the `serde::Deserializer<'de>` impl (uses `'de`, `self`,
// and `Self::Error` from that context).
macro_rules! impl_nums {
    ($dser_method:ident, $visitor_method:ident, $unpack_method:ident) => {
        #[inline]
        fn $dser_method<V>(self, visitor: V) -> Result<V::Value, Self::Error>
        where
            V: serde::de::Visitor<'de>,
        {
            visitor.$visitor_method(unpack::$unpack_method(&mut self.reader)?)
        }
    };
}
226
227impl<'de, 'a, R> serde::Deserializer<'de> for &'a mut Deserializer<R>
228where
229    R: BufferedRead<'de>,
230{
231    type Error = DeError;
232
233    impl_nums!(deserialize_u8, visit_u8, unpack_u8);
234    impl_nums!(deserialize_u16, visit_u16, unpack_u16);
235    impl_nums!(deserialize_u32, visit_u32, unpack_u32);
236    impl_nums!(deserialize_u64, visit_u64, unpack_u64);
237    impl_nums!(deserialize_i8, visit_i8, unpack_i8);
238    impl_nums!(deserialize_i16, visit_i16, unpack_i16);
239    impl_nums!(deserialize_i32, visit_i32, unpack_i32);
240    impl_nums!(deserialize_i64, visit_i64, unpack_i64);
241    impl_nums!(deserialize_f32, visit_f32, unpack_f32);
242    impl_nums!(deserialize_f64, visit_f64, unpack_f64);
243
244    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
245    where
246        V: Visitor<'de>,
247    {
248        use code::Code;
249
250        match self.reader.peek_code()? {
251            Code::Nil => self.deserialize_unit(visitor),
252            Code::True | Code::False => self.deserialize_bool(visitor),
253            Code::Uint8 | Code::PosInt(_) => self.deserialize_u8(visitor),
254            Code::Uint16 => self.deserialize_u16(visitor),
255            Code::Uint32 => self.deserialize_u32(visitor),
256            Code::Uint64 => self.deserialize_u64(visitor),
257            Code::Int8 | Code::NegInt(_) => self.deserialize_i8(visitor),
258            Code::Int16 => self.deserialize_i16(visitor),
259            Code::Int32 => self.deserialize_i32(visitor),
260            Code::Int64 => self.deserialize_i64(visitor),
261            Code::Float32 => self.deserialize_f32(visitor),
262            Code::Float64 => self.deserialize_f64(visitor),
263            Code::FixStr(_) | Code::Str8 | Code::Str16 | Code::Str32 => {
264                self.deserialize_string(visitor)
265            }
266            Code::Bin8 | Code::Bin16 | Code::Bin32 => self.deserialize_bytes(visitor),
267            Code::FixArray(_) | Code::Array16 | Code::Array32 => self.deserialize_seq(visitor),
268            Code::FixMap(_) | Code::Map16 | Code::Map32 => self.deserialize_map(visitor),
269            // Code::FixExt1 => FIXEXT1,
270            // Code::FixExt2 => FIXEXT2,
271            // Code::FixExt4 => FIXEXT4,
272            // Code::FixExt8 => FIXEXT8,
273            // Code::FixExt16 => FIXEXT16,
274            // Code::Ext8 => EXT8,
275            // Code::Ext16 => EXT16,
276            // Code::Ext32 => EXT32,
277            Code::Reserved => unreachable!(), // tmp
278            _ => unreachable!(),
279        }
280    }
281
282    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
283    where
284        V: serde::de::Visitor<'de>,
285    {
286        match self.reader.peek_code()? {
287            code::Code::Nil => {
288                let _ = self.reader.consume_code();
289                visitor.visit_none()
290            }
291            _ => visitor.visit_some(self),
292        }
293    }
294
295    fn deserialize_enum<V>(
296        self,
297        _name: &str,
298        _variants: &[&str],
299        visitor: V,
300    ) -> Result<V::Value, Self::Error>
301    where
302        V: Visitor<'de>,
303    {
304        visitor.visit_none()
305    }
306
307    fn deserialize_newtype_struct<V>(
308        self,
309        _name: &'static str,
310        visitor: V,
311    ) -> Result<V::Value, Self::Error>
312    where
313        V: Visitor<'de>,
314    {
315        visitor.visit_newtype_struct(self)
316    }
317
318    fn deserialize_unit_struct<V>(
319        self,
320        _name: &'static str,
321        visitor: V,
322    ) -> Result<V::Value, Self::Error>
323    where
324        V: Visitor<'de>,
325    {
326        visitor.visit_unit()
327    }
328
329    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
330    where
331        V: serde::de::Visitor<'de>,
332    {
333        match unpack::unpack_bool(&mut self.reader)? {
334            true => visitor.visit_bool(true),
335            false => visitor.visit_bool(false),
336        }
337    }
338
339    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
340    where
341        V: serde::de::Visitor<'de>,
342    {
343        let size = unpack::unpack_ary_header(&mut self.reader)?;
344
345        visitor.visit_seq(SeqAccess {
346            de: self,
347            len: size,
348        })
349    }
350
351    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
352    where
353        V: serde::de::Visitor<'de>,
354    {
355        let size = unpack::unpack_ary_header(&mut self.reader)?;
356        if size != len {
357            return Err(Self::Error::InvalidSize);
358        }
359
360        visitor.visit_seq(SeqAccess {
361            de: self,
362            len: size,
363        })
364    }
365
366    fn deserialize_tuple_struct<V>(
367        self,
368        _name: &'static str,
369        len: usize,
370        visitor: V,
371    ) -> Result<V::Value, Self::Error>
372    where
373        V: serde::de::Visitor<'de>,
374    {
375        self.deserialize_tuple(len, visitor)
376    }
377
378    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
379    where
380        V: serde::de::Visitor<'de>,
381    {
382        visitor.visit_unit()
383    }
384
385    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
386    where
387        V: serde::de::Visitor<'de>,
388    {
389        let body = unpack::unpack_str(&mut self.reader)?;
390        // TODO: bytes
391        visitor.visit_string(body)
392    }
393
394    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
395    where
396        V: serde::de::Visitor<'de>,
397    {
398        let body = unpack::unpack_str_ref(&mut self.reader)?;
399        // TODO: bytes_ref
400        visitor.visit_str(body)
401    }
402
403    fn deserialize_struct<V>(
404        self,
405        _name: &str,
406        fields: &'static [&'static str],
407        visitor: V,
408    ) -> Result<V::Value, Self::Error>
409    where
410        V: serde::de::Visitor<'de>,
411    {
412        // TODO
413        let _size = unpack::unpack_map_header(&mut self.reader)?;
414        visitor.visit_map(MapAccess {
415            de: self,
416            len: fields.len(),
417        })
418    }
419
420    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
421    where
422        V: serde::de::Visitor<'de>,
423    {
424        let size = unpack::unpack_map_header(&mut self.reader)?;
425        visitor.visit_map(MapAccess {
426            de: self,
427            len: size,
428        })
429    }
430
431    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
432    where
433        V: de::Visitor<'de>,
434    {
435        self.deserialize_str(visitor)
436    }
437
438    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
439    where
440        V: serde::de::Visitor<'de>,
441    {
442        let body = unpack::unpack_bin_ref(&mut self.reader)?;
443        visitor.visit_bytes(body)
444    }
445
446    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
447    where
448        V: serde::de::Visitor<'de>,
449    {
450        let body = unpack::unpack_bin(&mut self.reader)?;
451        visitor.visit_byte_buf(body)
452    }
453
454    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
455    where
456        V: de::Visitor<'de>,
457    {
458        self.deserialize_str(visitor)
459    }
460
461    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
462    where
463        V: de::Visitor<'de>,
464    {
465        visitor.visit_unit()
466    }
467}