Skip to main content

messagepack_serde/de/
mod.rs

1//! Deserialize support for messagepack
2
3mod enum_;
4mod error;
5mod seq;
6use error::CoreError;
7pub use error::Error;
8
9use messagepack_core::{
10    Decode, Format,
11    decode::NbyteReader,
12    io::{IoRead, RError},
13};
14use serde::{
15    Deserialize,
16    de::{self, IntoDeserializer},
17    forward_to_deserialize_any,
18};
19
20/// Deserialize from [messagepack_core::io::IoRead]
21pub fn from_core_reader<'de, R, T>(reader: R) -> Result<T, Error<R::Error>>
22where
23    R: IoRead<'de>,
24    T: Deserialize<'de>,
25{
26    let mut deserializer = Deserializer::new(reader);
27    T::deserialize(&mut deserializer)
28}
29
30/// Deserialize from slice
31#[inline]
32pub fn from_slice<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error<RError>> {
33    use messagepack_core::io::SliceReader;
34    let reader = SliceReader::new(input);
35    from_core_reader(reader)
36}
37
#[cfg(feature = "std")]
/// Deserialize an owned `T` from a [std::io::Read] source.
///
/// # Errors
///
/// Every failure is reported as a [`std::io::Error`]:
///
/// * an I/O error raised by the underlying reader is returned unchanged;
/// * `InvalidData` and `UnexpectedFormat` decode errors become
///   [`std::io::ErrorKind::InvalidData`];
/// * an `UnexpectedEof` decode error becomes
///   [`std::io::ErrorKind::UnexpectedEof`];
/// * any other deserialization error is wrapped with
///   [`std::io::Error::other`].
#[inline]
pub fn from_reader<R, T>(reader: R) -> std::io::Result<T>
where
    R: std::io::Read,
    T: for<'a> Deserialize<'a>,
{
    use messagepack_core::decode::Error as DecodeError;
    use messagepack_core::io::StdReader;
    use std::io::{Error as IoError, ErrorKind};

    from_core_reader::<'_, StdReader<R>, T>(StdReader::new(reader)).map_err(
        |failure| match failure {
            // The reader's own error already is an `io::Error`; pass it through.
            Error::Decode(DecodeError::Io(source)) => source,
            Error::Decode(
                cause @ (DecodeError::InvalidData | DecodeError::UnexpectedFormat),
            ) => IoError::new(ErrorKind::InvalidData, cause),
            Error::Decode(cause @ DecodeError::UnexpectedEof) => {
                IoError::new(ErrorKind::UnexpectedEof, cause)
            }
            other => IoError::other(other),
        },
    )
}
67
/// Upper bound on how many `Deserializer::recurse` levels (nested arrays and
/// maps) may be active at once before decoding fails with
/// `Error::RecursionLimitExceeded`.
const MAX_RECURSION_DEPTH: usize = 256;
69
70struct Deserializer<R> {
71    reader: R,
72    depth: usize,
73    format: Option<Format>,
74}
75
76impl<'de, R> Deserializer<R>
77where
78    R: IoRead<'de>,
79{
80    fn new(reader: R) -> Self {
81        Deserializer {
82            reader,
83            depth: 0,
84            format: None,
85        }
86    }
87
88    fn recurse<F, V>(&mut self, f: F) -> Result<V, Error<R::Error>>
89    where
90        F: FnOnce(&mut Self) -> V,
91    {
92        if self.depth == MAX_RECURSION_DEPTH {
93            return Err(Error::RecursionLimitExceeded);
94        }
95        self.depth += 1;
96        let result = f(self);
97        self.depth -= 1;
98        Ok(result)
99    }
100
101    fn decode_format(&mut self) -> Result<Format, Error<R::Error>> {
102        match self.format.take() {
103            Some(v) => Ok(v),
104            None => {
105                let v = Format::decode(&mut self.reader)?;
106                Ok(v)
107            }
108        }
109    }
110
111    fn decode_seq_with_format<V>(
112        &mut self,
113        format: Format,
114        visitor: V,
115    ) -> Result<V::Value, Error<R::Error>>
116    where
117        V: de::Visitor<'de>,
118    {
119        let n = match format {
120            Format::FixArray(n) => n.into(),
121            Format::Array16 => NbyteReader::<2>::read(&mut self.reader)?,
122            Format::Array32 => NbyteReader::<4>::read(&mut self.reader)?,
123            _ => return Err(CoreError::UnexpectedFormat.into()),
124        };
125        self.recurse(move |des| visitor.visit_seq(seq::FixLenAccess::new(des, n)))?
126    }
127
128    fn decode_map_with_format<V>(
129        &mut self,
130        format: Format,
131        visitor: V,
132    ) -> Result<V::Value, Error<R::Error>>
133    where
134        V: de::Visitor<'de>,
135    {
136        let n = match format {
137            Format::FixMap(n) => n.into(),
138            Format::Map16 => NbyteReader::<2>::read(&mut self.reader)?,
139            Format::Map32 => NbyteReader::<4>::read(&mut self.reader)?,
140            _ => return Err(CoreError::UnexpectedFormat.into()),
141        };
142        self.recurse(move |des| visitor.visit_map(seq::FixLenAccess::new(des, n)))?
143    }
144}
145
146impl<R> AsMut<Self> for Deserializer<R> {
147    fn as_mut(&mut self) -> &mut Self {
148        self
149    }
150}
151
152impl<'de, R> de::Deserializer<'de> for &mut Deserializer<R>
153where
154    R: IoRead<'de>,
155{
156    type Error = Error<R::Error>;
157
158    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
159    where
160        V: de::Visitor<'de>,
161    {
162        let format = self.decode_format()?;
163        match format {
164            Format::Nil => visitor.visit_unit(),
165            Format::False => visitor.visit_bool(false),
166            Format::True => visitor.visit_bool(true),
167            Format::PositiveFixInt(v) => visitor.visit_u8(v),
168            Format::Uint8 => {
169                let v = u8::decode_with_format(format, &mut self.reader)?;
170                visitor.visit_u8(v)
171            }
172            Format::Uint16 => {
173                let v = u16::decode_with_format(format, &mut self.reader)?;
174                visitor.visit_u16(v)
175            }
176            Format::Uint32 => {
177                let v = u32::decode_with_format(format, &mut self.reader)?;
178                visitor.visit_u32(v)
179            }
180            Format::Uint64 => {
181                let v = u64::decode_with_format(format, &mut self.reader)?;
182                visitor.visit_u64(v)
183            }
184            Format::NegativeFixInt(v) => visitor.visit_i8(v),
185            Format::Int8 => {
186                let v = i8::decode_with_format(format, &mut self.reader)?;
187                visitor.visit_i8(v)
188            }
189            Format::Int16 => {
190                let v = i16::decode_with_format(format, &mut self.reader)?;
191                visitor.visit_i16(v)
192            }
193            Format::Int32 => {
194                let v = i32::decode_with_format(format, &mut self.reader)?;
195                visitor.visit_i32(v)
196            }
197            Format::Int64 => {
198                let v = i64::decode_with_format(format, &mut self.reader)?;
199                visitor.visit_i64(v)
200            }
201            Format::Float32 => {
202                let v = f32::decode_with_format(format, &mut self.reader)?;
203                visitor.visit_f32(v)
204            }
205            Format::Float64 => {
206                let v = f64::decode_with_format(format, &mut self.reader)?;
207                visitor.visit_f64(v)
208            }
209            Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
210                use messagepack_core::decode::ReferenceStrDecoder;
211                let data = ReferenceStrDecoder::decode_with_format(format, &mut self.reader)?;
212                match data {
213                    messagepack_core::decode::ReferenceStr::Borrowed(s) => {
214                        visitor.visit_borrowed_str(s)
215                    }
216                    messagepack_core::decode::ReferenceStr::Copied(s) => visitor.visit_str(s),
217                }
218            }
219            Format::FixArray(_) | Format::Array16 | Format::Array32 => {
220                self.decode_seq_with_format(format, visitor)
221            }
222            Format::Bin8 | Format::Bin16 | Format::Bin32 => {
223                use messagepack_core::decode::ReferenceDecoder;
224                let data = ReferenceDecoder::decode_with_format(format, &mut self.reader)?;
225                match data {
226                    messagepack_core::io::Reference::Borrowed(items) => {
227                        visitor.visit_borrowed_bytes(items)
228                    }
229                    messagepack_core::io::Reference::Copied(items) => visitor.visit_bytes(items),
230                }
231            }
232            Format::FixMap(_) | Format::Map16 | Format::Map32 => {
233                self.decode_map_with_format(format, visitor)
234            }
235            Format::Ext8
236            | Format::Ext16
237            | Format::Ext32
238            | Format::FixExt1
239            | Format::FixExt2
240            | Format::FixExt4
241            | Format::FixExt8
242            | Format::FixExt16 => {
243                let mut de_ext =
244                    crate::extension::de::DeserializeExt::new(format, &mut self.reader)?;
245                let val = de::Deserializer::deserialize_newtype_struct(
246                    &mut de_ext,
247                    crate::extension::EXTENSION_STRUCT_NAME,
248                    visitor,
249                )?;
250
251                Ok(val)
252            }
253            Format::NeverUsed => Err(CoreError::UnexpectedFormat.into()),
254        }
255    }
256
257    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
258    where
259        V: de::Visitor<'de>,
260    {
261        let format = self.decode_format()?;
262        match format {
263            Format::Nil => visitor.visit_none(),
264            _ => {
265                self.format = Some(format);
266                visitor.visit_some(self.as_mut())
267            }
268        }
269    }
270
271    fn deserialize_enum<V>(
272        self,
273        _name: &'static str,
274        _variants: &'static [&'static str],
275        visitor: V,
276    ) -> Result<V::Value, Self::Error>
277    where
278        V: de::Visitor<'de>,
279    {
280        let format = self.decode_format()?;
281        match format {
282            Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
283                let s = <&str>::decode_with_format(format, &mut self.reader)?;
284                visitor.visit_enum(s.into_deserializer())
285            }
286            Format::FixMap(_)
287            | Format::Map16
288            | Format::Map32
289            | Format::FixArray(_)
290            | Format::Array16
291            | Format::Array32 => {
292                let enum_access = enum_::Enum::new(self);
293                visitor.visit_enum(enum_access)
294            }
295            _ => Err(CoreError::UnexpectedFormat.into()),
296        }
297    }
298
299    forward_to_deserialize_any! {
300        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
301        bytes byte_buf unit unit_struct newtype_struct seq tuple
302        tuple_struct map struct identifier ignored_any
303    }
304
305    fn is_human_readable(&self) -> bool {
306        false
307    }
308}
309
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use serde::de::IgnoredAny;

    /// The `true` / `false` markers decode to `bool`.
    #[rstest]
    #[case([0xc3], true)]
    #[case([0xc2], false)]
    fn decode_bool<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: bool) {
        let actual: bool = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    /// Both the positive fixint and the `uint 8` encodings decode to `u8`.
    #[rstest]
    #[case([0x05], 5)]
    #[case([0xcc, 0x80], 128)]
    fn decode_uint8<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: u8) {
        let actual: u8 = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    /// A fixarray of five `float 64` values decodes to `Vec<f64>`.
    #[test]
    fn decode_float_vec() {
        // [1.1,1.2,1.3,1.4,1.5]
        let bytes = [
            0x95, 0xcb, 0x3f, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a, 0xcb, 0x3f, 0xf3, 0x33,
            0x33, 0x33, 0x33, 0x33, 0x33, 0xcb, 0x3f, 0xf4, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcd,
            0xcb, 0x3f, 0xf6, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xcb, 0x3f, 0xf8, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];
        let wanted = [1.1, 1.2, 1.3, 1.4, 1.5];

        let actual: Vec<f64> = from_slice(&bytes).unwrap();

        assert_eq!(actual, wanted)
    }

    /// A struct decodes from a map keyed by field name.
    #[test]
    fn decode_struct() {
        #[derive(Deserialize)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // {"compact":true,"schema":0}
        let bytes: &[u8] = &[
            0x82, 0xa7, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x63, 0x74, 0xc3, 0xa6, 0x73, 0x63, 0x68,
            0x65, 0x6d, 0x61, 0x00,
        ];

        let actual: S = from_slice(bytes).unwrap();
        assert!(actual.compact);
        assert_eq!(actual.schema, 0);
    }

    /// A struct also decodes from an array holding its fields in order.
    #[test]
    fn decode_struct_from_array() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // [true, 0] where fields are in declaration order
        let bytes: &[u8] = &[0x92, 0xc3, 0x00];
        let wanted = S {
            compact: true,
            schema: 0,
        };

        let actual: S = from_slice(bytes).unwrap();
        assert_eq!(actual, wanted);
    }

    /// A nil inside a sequence is consumed as `None` and decoding continues
    /// with the following element.
    #[test]
    fn option_consumes_nil_in_sequence() {
        // [None, 5] as an array of two elements
        let bytes: &[u8] = &[0x92, 0xc0, 0x05];

        let actual: (Option<u8>, u8) = from_slice(bytes).unwrap();
        assert_eq!(actual, (None, 5));
    }

    /// A non-nil value decodes to `Some`.
    #[test]
    fn option_some_simple() {
        let bytes: &[u8] = &[0x05];
        let actual: Option<u8> = from_slice(bytes).unwrap();
        assert_eq!(actual, Some(5));
    }

    /// nil decodes to `()`.
    #[test]
    fn unit_from_nil() {
        let bytes: &[u8] = &[0xc0];
        let () = from_slice(bytes).unwrap();
    }

    /// nil decodes to a unit struct.
    #[test]
    fn unit_struct() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct U;

        let bytes: &[u8] = &[0xc0];
        let actual: U = from_slice(bytes).unwrap();
        assert_eq!(actual, U);
    }

    #[derive(Deserialize, PartialEq, Debug)]
    enum E {
        Unit,
        Newtype(u8),
        Tuple(u8, bool),
        Struct { a: bool },
    }

    /// Externally tagged enums: a bare string for unit variants, a
    /// single-entry map for the others.
    #[rstest]
    #[case([0xa4, 0x55, 0x6e, 0x69, 0x74], E::Unit)] // "Unit"
    #[case([0x81, 0xa7, 0x4e, 0x65, 0x77, 0x74, 0x79, 0x70, 0x65, 0x1b], E::Newtype(27))] // {"Newtype":27}
    #[case([0x81, 0xa5, 0x54, 0x75, 0x70, 0x6c, 0x65, 0x92, 0x03, 0xc3], E::Tuple(3, true))] // {"Tuple":[3,true]}
    #[case([0x81, 0xa6, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x81, 0xa1, 0x61, 0xc2], E::Struct { a: false })] // {"Struct":{"a":false}}
    fn decode_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: E) {
        let actual: E = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    #[derive(Deserialize, PartialEq, Debug)]
    #[serde(untagged)]
    enum Untagged {
        Bool(bool),
        U8(u8),
        Pair(u8, bool),
        Struct { a: bool },
        Nested(E),
    }

    /// Untagged enums pick the variant matching the decoded value.
    #[rstest]
    #[case([0xc3], Untagged::Bool(true))]
    #[case([0x05], Untagged::U8(5))]
    #[case([0x92, 0x02, 0xc3], Untagged::Pair(2, true))]
    #[case([0x81, 0xa1, 0x61, 0xc2], Untagged::Struct { a: false })]
    #[case([0xa4, 0x55, 0x6e, 0x69, 0x74], Untagged::Nested(E::Unit))] // "Unit"
    fn decode_untagged_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: Untagged) {
        let actual: Untagged = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    /// Exactly 256 levels of array nesting are still accepted.
    #[test]
    fn recursion_limit_ok_at_256() {
        // [[[[...]]]] 256 nested array
        let nested: Vec<u8> = core::iter::repeat(0x91u8)
            .take(256)
            .chain(core::iter::once(0xc0))
            .collect();

        let _: IgnoredAny = from_slice(&nested).unwrap();
    }

    /// One level more than the limit is rejected.
    #[test]
    fn recursion_limit_err_over_256() {
        // [[[[...]]]] 257 nested array
        let nested: Vec<u8> = core::iter::repeat(0x91u8)
            .take(257)
            .chain(core::iter::once(0xc0))
            .collect();

        let failure = from_slice::<IgnoredAny>(&nested).unwrap_err();
        assert!(matches!(failure, Error::RecursionLimitExceeded));
    }

    /// Owned values of every basic kind decode through `from_reader`.
    #[cfg(feature = "std")]
    #[rstest]
    // nil -> unit
    #[case([0xc0], ())]
    // bool
    #[case([0xc3], true)]
    #[case([0xc2], false)]
    // positive integers (fixint/uint*)
    #[case([0x2a], 42u8)]
    #[case([0xcc, 0x80], 128u8)]
    #[case([0xcd, 0x01, 0x00], 256u16)]
    #[case([0xce, 0x00, 0x01, 0x00, 0x00], 65536u32)]
    #[case([0xcf, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00], 4294967296u64)]
    // negative integers (fixint/int*)
    #[case([0xff], -1i8)]
    #[case([0xd0, 0x80], -128i8)]
    #[case([0xd1, 0x80, 0x00], -32768i16)]
    #[case([0xd2, 0x80, 0x00, 0x00, 0x00], -2147483648i32)]
    #[case([0xd3, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00], i64::MIN)]
    // floats
    #[case([0xca, 0x41, 0x45, 0x70, 0xa4], 12.34f32)]
    #[case([0xcb, 0x3f, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00], 1.0f64)]
    // strings (fixstr/str8)
    #[case([0xa1, 0x61], "a".to_string())]
    #[case([0xd9, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f], "hello".to_string())]
    // binary (bin8): the `bin` family needs a wrapper such as `serde_bytes`
    #[case([0xc4, 0x03, 0x01, 0x02, 0x03], serde_bytes::ByteBuf::from(vec![1u8, 2, 3]))]
    // array (fixarray)
    #[case([0x93, 0x01, 0x02, 0x03], vec![1u8, 2, 3])]
    // map (fixmap) with 2 entries: {"a":1, "b":2}
    #[case([0x82, 0xa1, 0x61, 0x01, 0xa1, 0x62, 0x02], {
        let mut m = std::collections::BTreeMap::<String, u8>::new();
        m.insert("a".to_string(), 1u8);
        m.insert("b".to_string(), 2u8);
        m
    })]
    fn decode_success_from_reader_when_owned<
        Buf: AsRef<[u8]>,
        T: serde::de::DeserializeOwned + core::fmt::Debug + PartialEq,
    >(
        #[case] buf: Buf,
        #[case] expected: T,
    ) {
        let mut cursor = std::io::Cursor::new(buf.as_ref());
        let actual = super::from_reader::<_, T>(&mut cursor).unwrap();
        assert_eq!(actual, expected)
    }
}
526}