messagepack_serde/de/
mod.rs

//! Deserialization support for MessagePack.
2
3mod enum_;
4mod error;
5mod seq;
6use error::CoreError;
7pub use error::Error;
8
9use crate::value::extension::DeserializeExt;
10use messagepack_core::{
11    Decode, Format,
12    decode::NbyteReader,
13    io::{IoRead, RError},
14};
15use serde::{
16    Deserialize,
17    de::{self, IntoDeserializer},
18    forward_to_deserialize_any,
19};
20
21/// Deserialize from slice
22pub fn from_slice<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error<RError>> {
23    use messagepack_core::io::SliceReader;
24    let reader = SliceReader::new(input);
25    from_trait(reader)
26}
27
#[cfg(feature = "std")]
/// Deserialize a value of type `T` from a [std::io::Read] source.
///
/// # Errors
///
/// Every failure is reported as a [`std::io::Error`]:
/// - an I/O error carried by the decoder is returned unchanged;
/// - invalid data or an unexpected format is wrapped with `ErrorKind::InvalidData`;
/// - an unexpected end of input is wrapped with `ErrorKind::UnexpectedEof`;
/// - any other deserialization error is wrapped with `std::io::Error::other`.
pub fn from_reader<R, T>(reader: R) -> std::io::Result<T>
where
    R: std::io::Read,
    T: for<'a> Deserialize<'a>,
{
    use messagepack_core::decode::Error as DecodeError;
    use messagepack_core::io::StdReader;
    use std::io;

    // Success leaves immediately; everything below only translates the failure.
    let failure = match from_trait::<'_, StdReader<R>, T>(StdReader::new(reader)) {
        Ok(value) => return Ok(value),
        Err(failure) => failure,
    };

    let converted = match failure {
        // The underlying reader already produced an `io::Error`; pass it through.
        Error::Decode(DecodeError::Io(source)) => source,
        Error::Decode(cause @ (DecodeError::InvalidData | DecodeError::UnexpectedFormat)) => {
            io::Error::new(io::ErrorKind::InvalidData, cause)
        }
        Error::Decode(cause @ DecodeError::UnexpectedEof) => {
            io::Error::new(io::ErrorKind::UnexpectedEof, cause)
        }
        // Non-decode errors have no dedicated `ErrorKind`.
        other => io::Error::other(other),
    };
    Err(converted)
}
56
57fn from_trait<'de, R, T>(reader: R) -> Result<T, Error<R::Error>>
58where
59    R: IoRead<'de>,
60    T: Deserialize<'de>,
61{
62    let mut deserializer = Deserializer::from_trait(reader);
63    T::deserialize(&mut deserializer)
64}
65
/// Maximum number of nested `recurse` calls (one per array or map level) allowed
/// at once; `Deserializer::recurse` fails with `Error::RecursionLimitExceeded`
/// when asked to go deeper.
const MAX_RECURSION_DEPTH: usize = 256;
67
/// State of an in-progress MessagePack deserialization.
struct Deserializer<R> {
    /// Source the encoded bytes are read from.
    reader: R,
    /// Number of `recurse` calls currently active (array/map nesting level).
    depth: usize,
    /// Format marker that was read but not yet consumed; `deserialize_option`
    /// stores it here and `decode_format` hands it back before reading again.
    format: Option<Format>,
}
73
74impl<'de, R> Deserializer<R>
75where
76    R: IoRead<'de>,
77{
78    pub fn from_trait(reader: R) -> Self {
79        Deserializer {
80            reader,
81            depth: 0,
82            format: None,
83        }
84    }
85
86    fn recurse<F, V>(&mut self, f: F) -> Result<V, Error<R::Error>>
87    where
88        F: FnOnce(&mut Self) -> V,
89    {
90        if self.depth == MAX_RECURSION_DEPTH {
91            return Err(Error::RecursionLimitExceeded);
92        }
93        self.depth += 1;
94        let result = f(self);
95        self.depth -= 1;
96        Ok(result)
97    }
98
99    fn decode_format(&mut self) -> Result<Format, Error<R::Error>> {
100        match self.format.take() {
101            Some(v) => Ok(v),
102            None => {
103                let v = Format::decode(&mut self.reader)?;
104                Ok(v)
105            }
106        }
107    }
108
109    fn decode_seq_with_format<V>(
110        &mut self,
111        format: Format,
112        visitor: V,
113    ) -> Result<V::Value, Error<R::Error>>
114    where
115        V: de::Visitor<'de>,
116    {
117        let n = match format {
118            Format::FixArray(n) => n.into(),
119            Format::Array16 => NbyteReader::<2>::read(&mut self.reader)?,
120            Format::Array32 => NbyteReader::<4>::read(&mut self.reader)?,
121            _ => return Err(CoreError::UnexpectedFormat.into()),
122        };
123        self.recurse(move |des| visitor.visit_seq(seq::FixLenAccess::new(des, n)))?
124    }
125
126    fn decode_map_with_format<V>(
127        &mut self,
128        format: Format,
129        visitor: V,
130    ) -> Result<V::Value, Error<R::Error>>
131    where
132        V: de::Visitor<'de>,
133    {
134        let n = match format {
135            Format::FixMap(n) => n.into(),
136            Format::Map16 => NbyteReader::<2>::read(&mut self.reader)?,
137            Format::Map32 => NbyteReader::<4>::read(&mut self.reader)?,
138            _ => return Err(CoreError::UnexpectedFormat.into()),
139        };
140        self.recurse(move |des| visitor.visit_map(seq::FixLenAccess::new(des, n)))?
141    }
142}
143
/// Identity `AsMut` so code holding a `&mut Deserializer<R>` can reborrow it via
/// `as_mut()`, as `deserialize_option` does when calling `visit_some`.
impl<R> AsMut<Self> for Deserializer<R> {
    fn as_mut(&mut self) -> &mut Self {
        // Hand back the very same deserializer.
        self
    }
}
149
150impl<'de, R> de::Deserializer<'de> for &mut Deserializer<R>
151where
152    R: IoRead<'de>,
153{
154    type Error = Error<R::Error>;
155
156    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
157    where
158        V: de::Visitor<'de>,
159    {
160        let format = self.decode_format()?;
161        match format {
162            Format::Nil => visitor.visit_unit(),
163            Format::False => visitor.visit_bool(false),
164            Format::True => visitor.visit_bool(true),
165            Format::PositiveFixInt(v) => visitor.visit_u8(v),
166            Format::Uint8 => {
167                let v = u8::decode_with_format(format, &mut self.reader)?;
168                visitor.visit_u8(v)
169            }
170            Format::Uint16 => {
171                let v = u16::decode_with_format(format, &mut self.reader)?;
172                visitor.visit_u16(v)
173            }
174            Format::Uint32 => {
175                let v = u32::decode_with_format(format, &mut self.reader)?;
176                visitor.visit_u32(v)
177            }
178            Format::Uint64 => {
179                let v = u64::decode_with_format(format, &mut self.reader)?;
180                visitor.visit_u64(v)
181            }
182            Format::NegativeFixInt(v) => visitor.visit_i8(v),
183            Format::Int8 => {
184                let v = i8::decode_with_format(format, &mut self.reader)?;
185                visitor.visit_i8(v)
186            }
187            Format::Int16 => {
188                let v = i16::decode_with_format(format, &mut self.reader)?;
189                visitor.visit_i16(v)
190            }
191            Format::Int32 => {
192                let v = i32::decode_with_format(format, &mut self.reader)?;
193                visitor.visit_i32(v)
194            }
195            Format::Int64 => {
196                let v = i64::decode_with_format(format, &mut self.reader)?;
197                visitor.visit_i64(v)
198            }
199            Format::Float32 => {
200                let v = f32::decode_with_format(format, &mut self.reader)?;
201                visitor.visit_f32(v)
202            }
203            Format::Float64 => {
204                let v = f64::decode_with_format(format, &mut self.reader)?;
205                visitor.visit_f64(v)
206            }
207            Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
208                use messagepack_core::decode::ReferenceStrDecoder;
209                let data = ReferenceStrDecoder::decode_with_format(format, &mut self.reader)?;
210                match data {
211                    messagepack_core::decode::ReferenceStr::Borrowed(s) => {
212                        visitor.visit_borrowed_str(s)
213                    }
214                    messagepack_core::decode::ReferenceStr::Copied(s) => visitor.visit_str(s),
215                }
216            }
217            Format::FixArray(_) | Format::Array16 | Format::Array32 => {
218                self.decode_seq_with_format(format, visitor)
219            }
220            Format::Bin8 | Format::Bin16 | Format::Bin32 => {
221                use messagepack_core::decode::ReferenceDecoder;
222                let data = ReferenceDecoder::decode_with_format(format, &mut self.reader)?;
223                match data {
224                    messagepack_core::io::Reference::Borrowed(items) => {
225                        visitor.visit_borrowed_bytes(items)
226                    }
227                    messagepack_core::io::Reference::Copied(items) => visitor.visit_bytes(items),
228                }
229            }
230            Format::FixMap(_) | Format::Map16 | Format::Map32 => {
231                self.decode_map_with_format(format, visitor)
232            }
233            Format::Ext8
234            | Format::Ext16
235            | Format::Ext32
236            | Format::FixExt1
237            | Format::FixExt2
238            | Format::FixExt4
239            | Format::FixExt8
240            | Format::FixExt16 => {
241                let mut de_ext = DeserializeExt::new(format, &mut self.reader)?;
242                let val = de::Deserializer::deserialize_newtype_struct(
243                    &mut de_ext,
244                    crate::value::extension::EXTENSION_STRUCT_NAME,
245                    visitor,
246                )?;
247
248                Ok(val)
249            }
250            Format::NeverUsed => Err(CoreError::UnexpectedFormat.into()),
251        }
252    }
253
254    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
255    where
256        V: de::Visitor<'de>,
257    {
258        let format = self.decode_format()?;
259        match format {
260            Format::Nil => visitor.visit_none(),
261            _ => {
262                self.format = Some(format);
263                visitor.visit_some(self.as_mut())
264            }
265        }
266    }
267
268    fn deserialize_enum<V>(
269        self,
270        _name: &'static str,
271        _variants: &'static [&'static str],
272        visitor: V,
273    ) -> Result<V::Value, Self::Error>
274    where
275        V: de::Visitor<'de>,
276    {
277        let format = self.decode_format()?;
278        match format {
279            Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
280                let s = <&str>::decode_with_format(format, &mut self.reader)?;
281                visitor.visit_enum(s.into_deserializer())
282            }
283            Format::FixMap(_)
284            | Format::Map16
285            | Format::Map32
286            | Format::FixArray(_)
287            | Format::Array16
288            | Format::Array32 => {
289                let enum_access = enum_::Enum::new(self);
290                visitor.visit_enum(enum_access)
291            }
292            _ => Err(CoreError::UnexpectedFormat.into()),
293        }
294    }
295
296    forward_to_deserialize_any! {
297        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
298        bytes byte_buf unit unit_struct newtype_struct seq tuple
299        tuple_struct map struct identifier ignored_any
300    }
301
302    fn is_human_readable(&self) -> bool {
303        false
304    }
305}
306
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use serde::de::IgnoredAny;

    /// The `true` / `false` markers decode to `bool`.
    #[rstest]
    #[case([0xc3], true)]
    #[case([0xc2], false)]
    fn decode_bool<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: bool) {
        assert_eq!(from_slice::<bool>(buf.as_ref()).unwrap(), expected);
    }

    /// Positive fixint and uint8 both decode to `u8`.
    #[rstest]
    #[case([0x05], 5)]
    #[case([0xcc, 0x80], 128)]
    fn decode_uint8<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: u8) {
        assert_eq!(from_slice::<u8>(buf.as_ref()).unwrap(), expected);
    }

    /// A fixarray of float64 values decodes to `Vec<f64>`.
    #[test]
    fn decode_float_vec() {
        // [1.1,1.2,1.3,1.4,1.5]
        let encoded = [
            0x95, 0xcb, 0x3f, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a, 0xcb, 0x3f, 0xf3, 0x33,
            0x33, 0x33, 0x33, 0x33, 0x33, 0xcb, 0x3f, 0xf4, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcd,
            0xcb, 0x3f, 0xf6, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xcb, 0x3f, 0xf8, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];

        assert_eq!(
            from_slice::<Vec<f64>>(&encoded).unwrap(),
            [1.1, 1.2, 1.3, 1.4, 1.5]
        )
    }

    /// A struct decodes from a map keyed by field name.
    #[test]
    fn decode_struct() {
        #[derive(Deserialize)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // {"compact":true,"schema":0}
        let encoded: &[u8] = &[
            0x82, 0xa7, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x63, 0x74, 0xc3, 0xa6, 0x73, 0x63, 0x68,
            0x65, 0x6d, 0x61, 0x00,
        ];

        let S { compact, schema } = from_slice::<S>(encoded).unwrap();
        assert!(compact);
        assert_eq!(schema, 0);
    }

    /// A struct also decodes from an array of its fields.
    #[test]
    fn decode_struct_from_array() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // [true, 0] where fields are in declaration order
        let encoded: &[u8] = &[0x92, 0xc3, 0x00];

        assert_eq!(
            from_slice::<S>(encoded).unwrap(),
            S {
                compact: true,
                schema: 0
            }
        );
    }

    /// `nil` inside a sequence is consumed as `None` and decoding continues with
    /// the next element.
    #[test]
    fn option_consumes_nil_in_sequence() {
        // [None, 5] as an array of two elements
        let encoded: &[u8] = &[0x92, 0xc0, 0x05];

        assert_eq!(
            from_slice::<(Option<u8>, u8)>(encoded).unwrap(),
            (None, 5)
        );
    }

    /// A non-nil value decodes as `Some`.
    #[test]
    fn option_some_simple() {
        let encoded: &[u8] = &[0x05];
        assert_eq!(from_slice::<Option<u8>>(encoded).unwrap(), Some(5));
    }

    /// `nil` decodes to the unit type.
    #[test]
    fn unit_from_nil() {
        let encoded: &[u8] = &[0xc0];
        from_slice::<()>(encoded).unwrap();
    }

    /// `nil` decodes to a unit struct.
    #[test]
    fn unit_struct() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct U;

        let encoded: &[u8] = &[0xc0];
        assert_eq!(from_slice::<U>(encoded).unwrap(), U);
    }

    /// Externally tagged enum covering every variant shape.
    #[derive(Deserialize, PartialEq, Debug)]
    enum E {
        Unit,
        Newtype(u8),
        Tuple(u8, bool),
        Struct { a: bool },
    }

    /// Unit variants are strings; all others are single-entry maps.
    #[rstest]
    #[case([0xa4, 0x55, 0x6e, 0x69, 0x74], E::Unit)] // "Unit"
    #[case([0x81, 0xa7, 0x4e, 0x65, 0x77, 0x74, 0x79, 0x70, 0x65, 0x1b], E::Newtype(27))] // {"Newtype":27}
    #[case([0x81, 0xa5, 0x54, 0x75, 0x70, 0x6c, 0x65, 0x92, 0x03, 0xc3], E::Tuple(3, true))] // {"Tuple":[3,true]}
    #[case([0x81, 0xa6, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x81, 0xa1, 0x61, 0xc2], E::Struct { a: false })] // {"Struct":{"a":false}}
    fn decode_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: E) {
        assert_eq!(from_slice::<E>(buf.as_ref()).unwrap(), expected);
    }

    /// Untagged enum: the variant is chosen by the shape of the encoded value.
    #[derive(Deserialize, PartialEq, Debug)]
    #[serde(untagged)]
    enum Untagged {
        Bool(bool),
        U8(u8),
        Pair(u8, bool),
        Struct { a: bool },
        Nested(E),
    }

    #[rstest]
    #[case([0xc3], Untagged::Bool(true))]
    #[case([0x05], Untagged::U8(5))]
    #[case([0x92, 0x02, 0xc3], Untagged::Pair(2, true))]
    #[case([0x81, 0xa1, 0x61, 0xc2], Untagged::Struct { a: false })]
    #[case([0xa4, 0x55, 0x6e, 0x69, 0x74], Untagged::Nested(E::Unit))] // "Unit"
    fn decode_untagged_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: Untagged) {
        assert_eq!(from_slice::<Untagged>(buf.as_ref()).unwrap(), expected);
    }

    /// Exactly 256 levels of nesting is still accepted.
    #[test]
    fn recursion_limit_ok_at_256() {
        // [[[[...]]]] 256 nested array
        let mut nested = vec![0x91u8; 256];
        nested.push(0xc0);

        from_slice::<IgnoredAny>(&nested).unwrap();
    }

    /// One level beyond the limit is rejected.
    #[test]
    fn recursion_limit_err_over_256() {
        // [[[[...]]]] 257 nested array
        let mut nested = vec![0x91u8; 257];
        nested.push(0xc0);

        let failure = from_slice::<IgnoredAny>(&nested).unwrap_err();
        assert!(matches!(failure, Error::RecursionLimitExceeded));
    }

    /// Owned types decode through `from_reader`.
    #[cfg(feature = "std")]
    #[rstest]
    // nil -> unit
    #[case([0xc0], ())]
    // bool
    #[case([0xc3], true)]
    #[case([0xc2], false)]
    // positive integers (fixint/uint*)
    #[case([0x2a], 42u8)]
    #[case([0xcc, 0x80], 128u8)]
    #[case([0xcd, 0x01, 0x00], 256u16)]
    #[case([0xce, 0x00, 0x01, 0x00, 0x00], 65536u32)]
    #[case([0xcf, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00], 4294967296u64)]
    // negative integers (fixint/int*)
    #[case([0xff], -1i8)]
    #[case([0xd0, 0x80], -128i8)]
    #[case([0xd1, 0x80, 0x00], -32768i16)]
    #[case([0xd2, 0x80, 0x00, 0x00, 0x00], -2147483648i32)]
    #[case([0xd3, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00], i64::MIN)]
    // floats
    #[case([0xca, 0x41, 0x45, 0x70, 0xa4], 12.34f32)]
    #[case([0xcb, 0x3f, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00], 1.0f64)]
    // strings (fixstr/str8)
    #[case([0xa1, 0x61], "a".to_string())]
    #[case([0xd9, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f], "hello".to_string())]
    // binary (bin8): the `bin` family needs a helper type such as `serde_bytes`
    #[case([0xc4, 0x03, 0x01, 0x02, 0x03], serde_bytes::ByteBuf::from(vec![1u8, 2, 3]))]
    // array (fixarray)
    #[case([0x93, 0x01, 0x02, 0x03], vec![1u8, 2, 3])]
    // map (fixmap) with 2 entries: {"a":1, "b":2}
    #[case([0x82, 0xa1, 0x61, 0x01, 0xa1, 0x62, 0x02], {
        let mut m = std::collections::BTreeMap::<String, u8>::new();
        m.insert("a".to_string(), 1u8);
        m.insert("b".to_string(), 2u8);
        m
    })]
    fn decode_success_from_reader_when_owned<
        Buf: AsRef<[u8]>,
        T: serde::de::DeserializeOwned + core::fmt::Debug + PartialEq,
    >(
        #[case] buf: Buf,
        #[case] expected: T,
    ) {
        use super::from_reader;
        let mut cursor = std::io::Cursor::new(buf.as_ref());
        let actual = from_reader::<_, T>(&mut cursor).unwrap();
        assert_eq!(actual, expected)
    }
}