Skip to main content

messagepack_serde/de/
mod.rs

//! Deserialization support for MessagePack.
2
3mod enum_;
4mod error;
5mod seq;
6use error::CoreError;
7pub use error::Error;
8
9use messagepack_core::{
10    Decode, Format,
11    decode::NbyteReader,
12    io::{IoRead, RError},
13};
14use serde::{
15    Deserialize,
16    de::{self, IntoDeserializer},
17    forward_to_deserialize_any,
18};
19
20/// Deserialize from [messagepack_core::io::IoRead]
21pub fn from_core_reader<'de, R, T>(reader: R) -> Result<T, Error<R::Error>>
22where
23    R: IoRead<'de>,
24    T: Deserialize<'de>,
25{
26    let mut deserializer = Deserializer::new(reader);
27    T::deserialize(&mut deserializer)
28}
29
30/// Deserialize from slice
31pub fn from_slice<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error<RError>> {
32    use messagepack_core::io::SliceReader;
33    let reader = SliceReader::new(input);
34    from_core_reader(reader)
35}
36
#[cfg(feature = "std")]
/// Deserialize a value of type `T` from a [std::io::Read] source.
///
/// The reader is wrapped in a `messagepack_core::io::StdReader` and the
/// resulting error, if any, is converted into a [std::io::Error].
///
/// # Errors
///
/// * an underlying I/O error is returned unchanged;
/// * invalid data or an unexpected format becomes `ErrorKind::InvalidData`;
/// * an unexpected end of input becomes `ErrorKind::UnexpectedEof`;
/// * every other deserialization error is wrapped with `io::Error::other`.
pub fn from_reader<R, T>(reader: R) -> std::io::Result<T>
where
    R: std::io::Read,
    T: for<'a> Deserialize<'a>,
{
    use messagepack_core::decode::Error as DecodeError;
    use messagepack_core::io::StdReader;
    use std::io;

    from_core_reader::<'_, StdReader<R>, T>(StdReader::new(reader)).map_err(|failure| {
        match failure {
            // The reader itself failed: surface its error untouched.
            Error::Decode(DecodeError::Io(source)) => source,
            Error::Decode(
                cause @ (DecodeError::InvalidData | DecodeError::UnexpectedFormat),
            ) => io::Error::new(io::ErrorKind::InvalidData, cause),
            Error::Decode(cause @ DecodeError::UnexpectedEof) => {
                io::Error::new(io::ErrorKind::UnexpectedEof, cause)
            }
            other => io::Error::other(other),
        }
    })
}
65
66const MAX_RECURSION_DEPTH: usize = 256;
67
68struct Deserializer<R> {
69    reader: R,
70    depth: usize,
71    format: Option<Format>,
72}
73
74impl<'de, R> Deserializer<R>
75where
76    R: IoRead<'de>,
77{
78    fn new(reader: R) -> Self {
79        Deserializer {
80            reader,
81            depth: 0,
82            format: None,
83        }
84    }
85
86    fn recurse<F, V>(&mut self, f: F) -> Result<V, Error<R::Error>>
87    where
88        F: FnOnce(&mut Self) -> V,
89    {
90        if self.depth == MAX_RECURSION_DEPTH {
91            return Err(Error::RecursionLimitExceeded);
92        }
93        self.depth += 1;
94        let result = f(self);
95        self.depth -= 1;
96        Ok(result)
97    }
98
99    fn decode_format(&mut self) -> Result<Format, Error<R::Error>> {
100        match self.format.take() {
101            Some(v) => Ok(v),
102            None => {
103                let v = Format::decode(&mut self.reader)?;
104                Ok(v)
105            }
106        }
107    }
108
109    fn decode_seq_with_format<V>(
110        &mut self,
111        format: Format,
112        visitor: V,
113    ) -> Result<V::Value, Error<R::Error>>
114    where
115        V: de::Visitor<'de>,
116    {
117        let n = match format {
118            Format::FixArray(n) => n.into(),
119            Format::Array16 => NbyteReader::<2>::read(&mut self.reader)?,
120            Format::Array32 => NbyteReader::<4>::read(&mut self.reader)?,
121            _ => return Err(CoreError::UnexpectedFormat.into()),
122        };
123        self.recurse(move |des| visitor.visit_seq(seq::FixLenAccess::new(des, n)))?
124    }
125
126    fn decode_map_with_format<V>(
127        &mut self,
128        format: Format,
129        visitor: V,
130    ) -> Result<V::Value, Error<R::Error>>
131    where
132        V: de::Visitor<'de>,
133    {
134        let n = match format {
135            Format::FixMap(n) => n.into(),
136            Format::Map16 => NbyteReader::<2>::read(&mut self.reader)?,
137            Format::Map32 => NbyteReader::<4>::read(&mut self.reader)?,
138            _ => return Err(CoreError::UnexpectedFormat.into()),
139        };
140        self.recurse(move |des| visitor.visit_map(seq::FixLenAccess::new(des, n)))?
141    }
142}
143
144impl<R> AsMut<Self> for Deserializer<R> {
145    fn as_mut(&mut self) -> &mut Self {
146        self
147    }
148}
149
150impl<'de, R> de::Deserializer<'de> for &mut Deserializer<R>
151where
152    R: IoRead<'de>,
153{
154    type Error = Error<R::Error>;
155
156    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
157    where
158        V: de::Visitor<'de>,
159    {
160        let format = self.decode_format()?;
161        match format {
162            Format::Nil => visitor.visit_unit(),
163            Format::False => visitor.visit_bool(false),
164            Format::True => visitor.visit_bool(true),
165            Format::PositiveFixInt(v) => visitor.visit_u8(v),
166            Format::Uint8 => {
167                let v = u8::decode_with_format(format, &mut self.reader)?;
168                visitor.visit_u8(v)
169            }
170            Format::Uint16 => {
171                let v = u16::decode_with_format(format, &mut self.reader)?;
172                visitor.visit_u16(v)
173            }
174            Format::Uint32 => {
175                let v = u32::decode_with_format(format, &mut self.reader)?;
176                visitor.visit_u32(v)
177            }
178            Format::Uint64 => {
179                let v = u64::decode_with_format(format, &mut self.reader)?;
180                visitor.visit_u64(v)
181            }
182            Format::NegativeFixInt(v) => visitor.visit_i8(v),
183            Format::Int8 => {
184                let v = i8::decode_with_format(format, &mut self.reader)?;
185                visitor.visit_i8(v)
186            }
187            Format::Int16 => {
188                let v = i16::decode_with_format(format, &mut self.reader)?;
189                visitor.visit_i16(v)
190            }
191            Format::Int32 => {
192                let v = i32::decode_with_format(format, &mut self.reader)?;
193                visitor.visit_i32(v)
194            }
195            Format::Int64 => {
196                let v = i64::decode_with_format(format, &mut self.reader)?;
197                visitor.visit_i64(v)
198            }
199            Format::Float32 => {
200                let v = f32::decode_with_format(format, &mut self.reader)?;
201                visitor.visit_f32(v)
202            }
203            Format::Float64 => {
204                let v = f64::decode_with_format(format, &mut self.reader)?;
205                visitor.visit_f64(v)
206            }
207            Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
208                use messagepack_core::decode::ReferenceStrDecoder;
209                let data = ReferenceStrDecoder::decode_with_format(format, &mut self.reader)?;
210                match data {
211                    messagepack_core::decode::ReferenceStr::Borrowed(s) => {
212                        visitor.visit_borrowed_str(s)
213                    }
214                    messagepack_core::decode::ReferenceStr::Copied(s) => visitor.visit_str(s),
215                }
216            }
217            Format::FixArray(_) | Format::Array16 | Format::Array32 => {
218                self.decode_seq_with_format(format, visitor)
219            }
220            Format::Bin8 | Format::Bin16 | Format::Bin32 => {
221                use messagepack_core::decode::ReferenceDecoder;
222                let data = ReferenceDecoder::decode_with_format(format, &mut self.reader)?;
223                match data {
224                    messagepack_core::io::Reference::Borrowed(items) => {
225                        visitor.visit_borrowed_bytes(items)
226                    }
227                    messagepack_core::io::Reference::Copied(items) => visitor.visit_bytes(items),
228                }
229            }
230            Format::FixMap(_) | Format::Map16 | Format::Map32 => {
231                self.decode_map_with_format(format, visitor)
232            }
233            Format::Ext8
234            | Format::Ext16
235            | Format::Ext32
236            | Format::FixExt1
237            | Format::FixExt2
238            | Format::FixExt4
239            | Format::FixExt8
240            | Format::FixExt16 => {
241                let mut de_ext =
242                    crate::extension::de::DeserializeExt::new(format, &mut self.reader)?;
243                let val = de::Deserializer::deserialize_newtype_struct(
244                    &mut de_ext,
245                    crate::extension::EXTENSION_STRUCT_NAME,
246                    visitor,
247                )?;
248
249                Ok(val)
250            }
251            Format::NeverUsed => Err(CoreError::UnexpectedFormat.into()),
252        }
253    }
254
255    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
256    where
257        V: de::Visitor<'de>,
258    {
259        let format = self.decode_format()?;
260        match format {
261            Format::Nil => visitor.visit_none(),
262            _ => {
263                self.format = Some(format);
264                visitor.visit_some(self.as_mut())
265            }
266        }
267    }
268
269    fn deserialize_enum<V>(
270        self,
271        _name: &'static str,
272        _variants: &'static [&'static str],
273        visitor: V,
274    ) -> Result<V::Value, Self::Error>
275    where
276        V: de::Visitor<'de>,
277    {
278        let format = self.decode_format()?;
279        match format {
280            Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
281                let s = <&str>::decode_with_format(format, &mut self.reader)?;
282                visitor.visit_enum(s.into_deserializer())
283            }
284            Format::FixMap(_)
285            | Format::Map16
286            | Format::Map32
287            | Format::FixArray(_)
288            | Format::Array16
289            | Format::Array32 => {
290                let enum_access = enum_::Enum::new(self);
291                visitor.visit_enum(enum_access)
292            }
293            _ => Err(CoreError::UnexpectedFormat.into()),
294        }
295    }
296
297    forward_to_deserialize_any! {
298        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
299        bytes byte_buf unit unit_struct newtype_struct seq tuple
300        tuple_struct map struct identifier ignored_any
301    }
302
303    fn is_human_readable(&self) -> bool {
304        false
305    }
306}
307
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use serde::de::IgnoredAny;

    #[rstest]
    #[case([0xc3],true)]
    #[case([0xc2],false)]
    fn decode_bool<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: bool) {
        let actual: bool = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case([0x05],5)]
    #[case([0xcc, 0x80],128)]
    fn decode_uint8<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: u8) {
        let actual: u8 = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn decode_float_vec() {
        // Encoded form of [1.1, 1.2, 1.3, 1.4, 1.5]: fixarray(5) of float64.
        let buf = [
            0x95, 0xcb, 0x3f, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a, 0xcb, 0x3f, 0xf3, 0x33,
            0x33, 0x33, 0x33, 0x33, 0x33, 0xcb, 0x3f, 0xf4, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcd,
            0xcb, 0x3f, 0xf6, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xcb, 0x3f, 0xf8, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];
        let expected = [1.1, 1.2, 1.3, 1.4, 1.5];

        let actual: Vec<f64> = from_slice(&buf).unwrap();

        assert_eq!(actual, expected)
    }

    #[test]
    fn decode_struct() {
        #[derive(Deserialize)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // {"compact":true,"schema":0}
        let buf: &[u8] = &[
            0x82, 0xa7, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x63, 0x74, 0xc3, 0xa6, 0x73, 0x63, 0x68,
            0x65, 0x6d, 0x61, 0x00,
        ];

        let actual: S = from_slice(buf).unwrap();
        assert!(actual.compact);
        assert_eq!(actual.schema, 0);
    }

    #[test]
    fn decode_struct_from_array() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // [true, 0]: the fields appear in declaration order.
        let buf: &[u8] = &[0x92, 0xc3, 0x00];
        let expected = S {
            compact: true,
            schema: 0,
        };

        let actual: S = from_slice(buf).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn option_consumes_nil_in_sequence() {
        // A two-element array: [nil, 5].
        let buf: &[u8] = &[0x92, 0xc0, 0x05];

        let actual: (Option<u8>, u8) = from_slice(buf).unwrap();
        assert_eq!(actual, (None, 5));
    }

    #[test]
    fn option_some_simple() {
        let buf: &[u8] = &[0x05];
        let actual: Option<u8> = from_slice(buf).unwrap();
        assert_eq!(actual, Some(5));
    }

    #[test]
    fn unit_from_nil() {
        let buf: &[u8] = &[0xc0];
        let () = from_slice(buf).unwrap();
    }

    #[test]
    fn unit_struct() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct U;

        let buf: &[u8] = &[0xc0];
        let actual: U = from_slice(buf).unwrap();
        assert_eq!(actual, U);
    }

    #[derive(Deserialize, PartialEq, Debug)]
    enum E {
        Unit,
        Newtype(u8),
        Tuple(u8, bool),
        Struct { a: bool },
    }
    #[rstest]
    #[case([0xa4, 0x55, 0x6e, 0x69, 0x74],E::Unit)] // "Unit"
    #[case([0x81, 0xa7, 0x4e, 0x65, 0x77, 0x74, 0x79, 0x70, 0x65, 0x1b], E::Newtype(27))] // {"Newtype":27}
    #[case([0x81, 0xa5, 0x54, 0x75, 0x70, 0x6c, 0x65, 0x92, 0x03, 0xc3], E::Tuple(3, true))] // {"Tuple":[3,true]}
    #[case([0x81, 0xa6, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x81, 0xa1, 0x61, 0xc2],E::Struct { a: false })] // {"Struct":{"a":false}}
    fn decode_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: E) {
        let actual: E = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    #[derive(Deserialize, PartialEq, Debug)]
    #[serde(untagged)]
    enum Untagged {
        Bool(bool),
        U8(u8),
        Pair(u8, bool),
        Struct { a: bool },
        Nested(E),
    }

    #[rstest]
    #[case([0xc3],Untagged::Bool(true))]
    #[case([0x05],Untagged::U8(5))]
    #[case([0x92, 0x02, 0xc3],Untagged::Pair(2,true))]
    #[case([0x81, 0xa1, 0x61, 0xc2],Untagged::Struct { a: false })]
    #[case([0xa4,0x55,0x6e,0x69,0x74],Untagged::Nested(E::Unit))] // "Unit"
    fn decode_untagged_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: Untagged) {
        let actual: Untagged = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn recursion_limit_ok_at_256() {
        // 256 nested single-element arrays wrapping a nil: [[[...nil...]]]
        let mut buf = [0x91u8; 257];
        buf[256] = 0xc0;

        let _: IgnoredAny = from_slice(&buf).unwrap();
    }

    #[test]
    fn recursion_limit_err_over_256() {
        // 257 nested single-element arrays wrapping a nil: one level too deep.
        let mut buf = [0x91u8; 258];
        buf[257] = 0xc0;

        let err = from_slice::<IgnoredAny>(&buf).unwrap_err();
        assert!(matches!(err, Error::RecursionLimitExceeded));
    }

    #[cfg(feature = "std")]
    #[rstest]
    // nil -> unit
    #[case([0xc0],())]
    // bool
    #[case([0xc3],true)]
    #[case([0xc2],false)]
    // positive integers (fixint/uint*)
    #[case([0x2a],42u8)]
    #[case([0xcc, 0x80],128u8)]
    #[case([0xcd, 0x01, 0x00],256u16)]
    #[case([0xce, 0x00, 0x01, 0x00, 0x00],65536u32)]
    #[case([0xcf, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00],4294967296u64)]
    // negative integers (fixint/int*)
    #[case([0xff],-1i8)]
    #[case([0xd0, 0x80],-128i8)]
    #[case([0xd1, 0x80, 0x00],-32768i16)]
    #[case([0xd2, 0x80, 0x00, 0x00, 0x00],-2147483648i32)]
    #[case([0xd3, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],i64::MIN)]
    // floats
    #[case([0xca, 0x41, 0x45, 0x70, 0xa4],12.34f32)]
    #[case([0xcb, 0x3f, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],1.0f64)]
    // strings (fixstr/str8)
    #[case([0xa1, 0x61],"a".to_string())]
    #[case([0xd9, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f],"hello".to_string())]
    // binary (bin8): the `bin` family needs a bytes-aware target type such as
    // the ones provided by `serde_bytes`
    #[case([0xc4, 0x03, 0x01, 0x02, 0x03],serde_bytes::ByteBuf::from(vec![1u8, 2, 3]))]
    // array (fixarray)
    #[case([0x93, 0x01, 0x02, 0x03],vec![1u8, 2, 3])]
    // map (fixmap) with 2 entries: {"a":1, "b":2}
    #[case([0x82, 0xa1, 0x61, 0x01, 0xa1, 0x62, 0x02],{
        let mut m = std::collections::BTreeMap::<String, u8>::new();
        m.insert("a".to_string(), 1u8);
        m.insert("b".to_string(), 2u8);
        m
    })]
    fn decode_success_from_reader_when_owned<
        Buf: AsRef<[u8]>,
        T: serde::de::DeserializeOwned + core::fmt::Debug + PartialEq,
    >(
        #[case] buf: Buf,
        #[case] expected: T,
    ) {
        let mut cursor = std::io::Cursor::new(buf.as_ref());
        let actual: T = super::from_reader(&mut cursor).unwrap();
        assert_eq!(actual, expected)
    }
}
524}