messagepack_serde/de/mod.rs

//! Deserialization support for MessagePack, built on serde.
2
3mod enum_;
4mod error;
5mod seq;
6use error::CoreError;
7pub use error::Error;
8
9use messagepack_core::{
10    Decode, Format,
11    decode::NbyteReader,
12    io::{IoRead, RError},
13};
14use serde::{
15    Deserialize,
16    de::{self, IntoDeserializer},
17    forward_to_deserialize_any,
18};
19
20/// Deserialize from slice
21pub fn from_slice<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error<RError>> {
22    use messagepack_core::io::SliceReader;
23    let reader = SliceReader::new(input);
24    from_trait(reader)
25}
26
#[cfg(feature = "std")]
/// Deserialize a value of type `T` from any [`std::io::Read`] source.
///
/// `T` must implement `Deserialize` for every lifetime (i.e. be
/// `DeserializeOwned`), so the result never borrows from the input.
///
/// # Errors
///
/// All failures are reported as [`std::io::Error`]:
/// - malformed data or an unexpected format marker map to
///   [`std::io::ErrorKind::InvalidData`];
/// - a truncated stream maps to [`std::io::ErrorKind::UnexpectedEof`];
/// - errors raised by the underlying reader are returned unchanged;
/// - every other deserialization error is wrapped with [`std::io::Error::other`].
pub fn from_reader<R, T>(reader: R) -> std::io::Result<T>
where
    R: std::io::Read,
    T: for<'a> Deserialize<'a>,
{
    use messagepack_core::decode::Error as DecodeError;
    use messagepack_core::io::StdReader;
    use std::io;

    from_trait::<'_, StdReader<R>, T>(StdReader::new(reader)).map_err(|err| match err {
        // The underlying reader's own error is surfaced as-is.
        Error::Decode(DecodeError::Io(inner)) => inner,
        Error::Decode(e @ (DecodeError::InvalidData | DecodeError::UnexpectedFormat)) => {
            io::Error::new(io::ErrorKind::InvalidData, e)
        }
        Error::Decode(e @ DecodeError::UnexpectedEof) => {
            io::Error::new(io::ErrorKind::UnexpectedEof, e)
        }
        other => io::Error::other(other),
    })
}
55
56fn from_trait<'de, R, T>(reader: R) -> Result<T, Error<R::Error>>
57where
58    R: IoRead<'de>,
59    T: Deserialize<'de>,
60{
61    let mut deserializer = Deserializer::from_trait(reader);
62    T::deserialize(&mut deserializer)
63}
64
65const MAX_RECURSION_DEPTH: usize = 256;
66
67struct Deserializer<R> {
68    reader: R,
69    depth: usize,
70    format: Option<Format>,
71}
72
73impl<'de, R> Deserializer<R>
74where
75    R: IoRead<'de>,
76{
77    pub fn from_trait(reader: R) -> Self {
78        Deserializer {
79            reader,
80            depth: 0,
81            format: None,
82        }
83    }
84
85    fn recurse<F, V>(&mut self, f: F) -> Result<V, Error<R::Error>>
86    where
87        F: FnOnce(&mut Self) -> V,
88    {
89        if self.depth == MAX_RECURSION_DEPTH {
90            return Err(Error::RecursionLimitExceeded);
91        }
92        self.depth += 1;
93        let result = f(self);
94        self.depth -= 1;
95        Ok(result)
96    }
97
98    fn decode_format(&mut self) -> Result<Format, Error<R::Error>> {
99        match self.format.take() {
100            Some(v) => Ok(v),
101            None => {
102                let v = Format::decode(&mut self.reader)?;
103                Ok(v)
104            }
105        }
106    }
107
108    fn decode_seq_with_format<V>(
109        &mut self,
110        format: Format,
111        visitor: V,
112    ) -> Result<V::Value, Error<R::Error>>
113    where
114        V: de::Visitor<'de>,
115    {
116        let n = match format {
117            Format::FixArray(n) => n.into(),
118            Format::Array16 => NbyteReader::<2>::read(&mut self.reader)?,
119            Format::Array32 => NbyteReader::<4>::read(&mut self.reader)?,
120            _ => return Err(CoreError::UnexpectedFormat.into()),
121        };
122        self.recurse(move |des| visitor.visit_seq(seq::FixLenAccess::new(des, n)))?
123    }
124
125    fn decode_map_with_format<V>(
126        &mut self,
127        format: Format,
128        visitor: V,
129    ) -> Result<V::Value, Error<R::Error>>
130    where
131        V: de::Visitor<'de>,
132    {
133        let n = match format {
134            Format::FixMap(n) => n.into(),
135            Format::Map16 => NbyteReader::<2>::read(&mut self.reader)?,
136            Format::Map32 => NbyteReader::<4>::read(&mut self.reader)?,
137            _ => return Err(CoreError::UnexpectedFormat.into()),
138        };
139        self.recurse(move |des| visitor.visit_map(seq::FixLenAccess::new(des, n)))?
140    }
141}
142
143impl<R> AsMut<Self> for Deserializer<R> {
144    fn as_mut(&mut self) -> &mut Self {
145        self
146    }
147}
148
149impl<'de, R> de::Deserializer<'de> for &mut Deserializer<R>
150where
151    R: IoRead<'de>,
152{
153    type Error = Error<R::Error>;
154
155    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
156    where
157        V: de::Visitor<'de>,
158    {
159        let format = self.decode_format()?;
160        match format {
161            Format::Nil => visitor.visit_unit(),
162            Format::False => visitor.visit_bool(false),
163            Format::True => visitor.visit_bool(true),
164            Format::PositiveFixInt(v) => visitor.visit_u8(v),
165            Format::Uint8 => {
166                let v = u8::decode_with_format(format, &mut self.reader)?;
167                visitor.visit_u8(v)
168            }
169            Format::Uint16 => {
170                let v = u16::decode_with_format(format, &mut self.reader)?;
171                visitor.visit_u16(v)
172            }
173            Format::Uint32 => {
174                let v = u32::decode_with_format(format, &mut self.reader)?;
175                visitor.visit_u32(v)
176            }
177            Format::Uint64 => {
178                let v = u64::decode_with_format(format, &mut self.reader)?;
179                visitor.visit_u64(v)
180            }
181            Format::NegativeFixInt(v) => visitor.visit_i8(v),
182            Format::Int8 => {
183                let v = i8::decode_with_format(format, &mut self.reader)?;
184                visitor.visit_i8(v)
185            }
186            Format::Int16 => {
187                let v = i16::decode_with_format(format, &mut self.reader)?;
188                visitor.visit_i16(v)
189            }
190            Format::Int32 => {
191                let v = i32::decode_with_format(format, &mut self.reader)?;
192                visitor.visit_i32(v)
193            }
194            Format::Int64 => {
195                let v = i64::decode_with_format(format, &mut self.reader)?;
196                visitor.visit_i64(v)
197            }
198            Format::Float32 => {
199                let v = f32::decode_with_format(format, &mut self.reader)?;
200                visitor.visit_f32(v)
201            }
202            Format::Float64 => {
203                let v = f64::decode_with_format(format, &mut self.reader)?;
204                visitor.visit_f64(v)
205            }
206            Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
207                use messagepack_core::decode::ReferenceStrDecoder;
208                let data = ReferenceStrDecoder::decode_with_format(format, &mut self.reader)?;
209                match data {
210                    messagepack_core::decode::ReferenceStr::Borrowed(s) => {
211                        visitor.visit_borrowed_str(s)
212                    }
213                    messagepack_core::decode::ReferenceStr::Copied(s) => visitor.visit_str(s),
214                }
215            }
216            Format::FixArray(_) | Format::Array16 | Format::Array32 => {
217                self.decode_seq_with_format(format, visitor)
218            }
219            Format::Bin8 | Format::Bin16 | Format::Bin32 => {
220                use messagepack_core::decode::ReferenceDecoder;
221                let data = ReferenceDecoder::decode_with_format(format, &mut self.reader)?;
222                match data {
223                    messagepack_core::io::Reference::Borrowed(items) => {
224                        visitor.visit_borrowed_bytes(items)
225                    }
226                    messagepack_core::io::Reference::Copied(items) => visitor.visit_bytes(items),
227                }
228            }
229            Format::FixMap(_) | Format::Map16 | Format::Map32 => {
230                self.decode_map_with_format(format, visitor)
231            }
232            Format::Ext8
233            | Format::Ext16
234            | Format::Ext32
235            | Format::FixExt1
236            | Format::FixExt2
237            | Format::FixExt4
238            | Format::FixExt8
239            | Format::FixExt16 => {
240                let mut de_ext =
241                    crate::extension::de::DeserializeExt::new(format, &mut self.reader)?;
242                let val = de::Deserializer::deserialize_newtype_struct(
243                    &mut de_ext,
244                    crate::extension::EXTENSION_STRUCT_NAME,
245                    visitor,
246                )?;
247
248                Ok(val)
249            }
250            Format::NeverUsed => Err(CoreError::UnexpectedFormat.into()),
251        }
252    }
253
254    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
255    where
256        V: de::Visitor<'de>,
257    {
258        let format = self.decode_format()?;
259        match format {
260            Format::Nil => visitor.visit_none(),
261            _ => {
262                self.format = Some(format);
263                visitor.visit_some(self.as_mut())
264            }
265        }
266    }
267
268    fn deserialize_enum<V>(
269        self,
270        _name: &'static str,
271        _variants: &'static [&'static str],
272        visitor: V,
273    ) -> Result<V::Value, Self::Error>
274    where
275        V: de::Visitor<'de>,
276    {
277        let format = self.decode_format()?;
278        match format {
279            Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
280                let s = <&str>::decode_with_format(format, &mut self.reader)?;
281                visitor.visit_enum(s.into_deserializer())
282            }
283            Format::FixMap(_)
284            | Format::Map16
285            | Format::Map32
286            | Format::FixArray(_)
287            | Format::Array16
288            | Format::Array32 => {
289                let enum_access = enum_::Enum::new(self);
290                visitor.visit_enum(enum_access)
291            }
292            _ => Err(CoreError::UnexpectedFormat.into()),
293        }
294    }
295
296    forward_to_deserialize_any! {
297        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
298        bytes byte_buf unit unit_struct newtype_struct seq tuple
299        tuple_struct map struct identifier ignored_any
300    }
301
302    fn is_human_readable(&self) -> bool {
303        false
304    }
305}
306
#[cfg(test)]
mod tests {
    // Decoding tests: scalar, container, struct, enum and option cases via
    // `from_slice`, recursion-limit checks, and (with `std`) `from_reader`.
    use rstest::rstest;

    use super::*;
    use serde::de::IgnoredAny;

    #[rstest]
    #[case([0xc3],true)]
    #[case([0xc2],false)]
    fn decode_bool<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: bool) {
        let decoded = from_slice::<bool>(buf.as_ref()).unwrap();
        assert_eq!(decoded, expected);
    }

    #[rstest]
    #[case([0x05],5)]
    #[case([0xcc, 0x80],128)]
    fn decode_uint8<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: u8) {
        let decoded = from_slice::<u8>(buf.as_ref()).unwrap();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn decode_float_vec() {
        // [1.1,1.2,1.3,1.4,1.5] as a fixarray of float64 values
        let buf = [
            0x95, 0xcb, 0x3f, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a, 0xcb, 0x3f, 0xf3, 0x33,
            0x33, 0x33, 0x33, 0x33, 0x33, 0xcb, 0x3f, 0xf4, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcd,
            0xcb, 0x3f, 0xf6, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xcb, 0x3f, 0xf8, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];

        let decoded = from_slice::<Vec<f64>>(&buf).unwrap();
        let expected = [1.1, 1.2, 1.3, 1.4, 1.5];

        assert_eq!(decoded, expected)
    }

    #[test]
    fn decode_struct() {
        #[derive(Deserialize)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // {"compact":true,"schema":0}
        let buf: &[u8] = &[
            0x82, 0xa7, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x63, 0x74, 0xc3, 0xa6, 0x73, 0x63, 0x68,
            0x65, 0x6d, 0x61, 0x00,
        ];

        let decoded = from_slice::<S>(buf).unwrap();
        assert!(decoded.compact);
        assert_eq!(decoded.schema, 0);
    }

    #[test]
    fn decode_struct_from_array() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // [true, 0] where fields are in declaration order
        let buf: &[u8] = &[0x92, 0xc3, 0x00];

        let decoded = from_slice::<S>(buf).unwrap();
        assert_eq!(
            decoded,
            S {
                compact: true,
                schema: 0
            }
        );
    }

    #[test]
    fn option_consumes_nil_in_sequence() {
        // [None, 5] as an array of two elements
        let buf: &[u8] = &[0x92, 0xc0, 0x05];

        let decoded = from_slice::<(Option<u8>, u8)>(buf).unwrap();
        assert_eq!(decoded, (None, 5));
    }

    #[test]
    fn option_some_simple() {
        let buf: &[u8] = &[0x05];
        let decoded = from_slice::<Option<u8>>(buf).unwrap();
        assert_eq!(decoded, Some(5));
    }

    #[test]
    fn unit_from_nil() {
        let buf: &[u8] = &[0xc0];
        from_slice::<()>(buf).unwrap();
    }

    #[test]
    fn unit_struct() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct U;

        let buf: &[u8] = &[0xc0];
        let decoded = from_slice::<U>(buf).unwrap();
        assert_eq!(decoded, U);
    }

    // Externally tagged enum covering every variant shape.
    #[derive(Deserialize, PartialEq, Debug)]
    enum E {
        Unit,
        Newtype(u8),
        Tuple(u8, bool),
        Struct { a: bool },
    }
    #[rstest]
    #[case([0xa4, 0x55, 0x6e, 0x69, 0x74],E::Unit)] // "Unit"
    #[case([0x81, 0xa7, 0x4e, 0x65, 0x77, 0x74, 0x79, 0x70, 0x65, 0x1b], E::Newtype(27))] // {"Newtype":27}
    #[case([0x81, 0xa5, 0x54, 0x75, 0x70, 0x6c, 0x65, 0x92, 0x03, 0xc3], E::Tuple(3, true))] // {"Tuple":[3,true]}
    #[case([0x81, 0xa6, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x81, 0xa1, 0x61, 0xc2],E::Struct { a: false })] // {"Struct":{"a":false}}
    fn decode_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: E) {
        let decoded = from_slice::<E>(buf.as_ref()).unwrap();
        assert_eq!(decoded, expected);
    }

    // Untagged enums are resolved through `deserialize_any`.
    #[derive(Deserialize, PartialEq, Debug)]
    #[serde(untagged)]
    enum Untagged {
        Bool(bool),
        U8(u8),
        Pair(u8, bool),
        Struct { a: bool },
        Nested(E),
    }

    #[rstest]
    #[case([0xc3],Untagged::Bool(true))]
    #[case([0x05],Untagged::U8(5))]
    #[case([0x92, 0x02, 0xc3],Untagged::Pair(2,true))]
    #[case([0x81, 0xa1, 0x61, 0xc2],Untagged::Struct { a: false })]
    #[case([0xa4,0x55,0x6e,0x69,0x74],Untagged::Nested(E::Unit))] // "Unit"
    fn decode_untagged_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: Untagged) {
        let decoded = from_slice::<Untagged>(buf.as_ref()).unwrap();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn recursion_limit_ok_at_256() {
        // [[[[...]]]]: 256 nested single-element arrays around nil
        let mut buf = vec![0x91u8; 256];
        buf.push(0xc0);

        let _ = from_slice::<IgnoredAny>(&buf).unwrap();
    }

    #[test]
    fn recursion_limit_err_over_256() {
        // [[[[...]]]]: 257 nested single-element arrays around nil
        let mut buf = vec![0x91u8; 257];
        buf.push(0xc0);

        let err = from_slice::<IgnoredAny>(&buf).unwrap_err();
        assert!(matches!(err, Error::RecursionLimitExceeded));
    }

    #[cfg(feature = "std")]
    #[rstest]
    // nil -> unit
    #[case([0xc0],())]
    // bool
    #[case([0xc3],true)]
    #[case([0xc2],false)]
    // positive integers (fixint/uint*)
    #[case([0x2a],42u8)]
    #[case([0xcc, 0x80],128u8)]
    #[case([0xcd, 0x01, 0x00],256u16)]
    #[case([0xce, 0x00, 0x01, 0x00, 0x00],65536u32)]
    #[case([0xcf, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00],4294967296u64)]
    // negative integers (fixint/int*)
    #[case([0xff],-1i8)]
    #[case([0xd0, 0x80],-128i8)]
    #[case([0xd1, 0x80, 0x00],-32768i16)]
    #[case([0xd2, 0x80, 0x00, 0x00, 0x00],-2147483648i32)]
    #[case([0xd3, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],i64::MIN)]
    // floats
    #[case([0xca, 0x41, 0x45, 0x70, 0xa4],12.34f32)]
    #[case([0xcb, 0x3f, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],1.0f64)]
    // strings (fixstr/str8)
    #[case([0xa1, 0x61],"a".to_string())]
    #[case([0xd9, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f],"hello".to_string())]
    // binary (bin8); the `bin` family needs a bytes-aware type such as `serde_bytes::ByteBuf`
    #[case([0xc4, 0x03, 0x01, 0x02, 0x03],serde_bytes::ByteBuf::from(vec![1u8, 2, 3]))]
    // array (fixarray)
    #[case([0x93, 0x01, 0x02, 0x03],vec![1u8, 2, 3])]
    // map (fixmap) with 2 entries: {"a":1, "b":2}
    #[case([0x82, 0xa1, 0x61, 0x01, 0xa1, 0x62, 0x02],{
        let mut m = std::collections::BTreeMap::<String, u8>::new();
        m.insert("a".to_string(), 1u8);
        m.insert("b".to_string(), 2u8);
        m
    })]
    fn decode_success_from_reader_when_owned<
        Buf: AsRef<[u8]>,
        T: serde::de::DeserializeOwned + core::fmt::Debug + PartialEq,
    >(
        #[case] buf: Buf,
        #[case] expected: T,
    ) {
        use super::from_reader;
        let mut reader = std::io::Cursor::new(buf.as_ref());
        let val = from_reader::<_, T>(&mut reader).unwrap();
        assert_eq!(val, expected)
    }
}