messagepack_serde/de/
mod.rs

//! Deserialization support for the MessagePack format.
2
3mod enum_;
4mod error;
5mod seq;
6use error::CoreError;
7pub use error::Error;
8
9use crate::value::extension::DeserializeExt;
10use messagepack_core::{Decode, Format, decode::NbyteReader};
11use serde::{
12    Deserialize,
13    de::{self, IntoDeserializer},
14    forward_to_deserialize_any,
15};
16
17/// Deserialize from slice
18pub fn from_slice<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error> {
19    let mut deserializer = Deserializer::from_slice(input);
20    T::deserialize(&mut deserializer)
21}
22
/// Deserialize a value of type `T` from a [`std::io::Read`] source.
///
/// The reader is first drained to its end into an in-memory buffer; the value
/// is then decoded from that buffer, which is why `T` must not borrow from
/// the input (`for<'a> Deserialize<'a>`).
///
/// # Errors
///
/// Returns the I/O error raised while reading, or the deserialization error
/// wrapped with [`std::io::Error::other`].
#[cfg(feature = "std")]
pub fn from_reader<R, T>(reader: &mut R) -> std::io::Result<T>
where
    R: std::io::Read,
    T: for<'a> Deserialize<'a>,
{
    let mut bytes = Vec::new();
    reader.read_to_end(&mut bytes)?;

    from_slice(&bytes).map_err(std::io::Error::other)
}
36
/// Maximum nesting depth accepted while deserializing: once the deserializer's
/// `depth` has reached this value, `Deserializer::recurse` refuses to go
/// deeper and returns `Error::RecursionLimitExceeded`.
const MAX_RECURSION_DEPTH: usize = 256;
38
/// Cursor over a MessagePack byte slice that drives serde deserialization.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
struct Deserializer<'de> {
    /// Bytes that have not been consumed yet.
    input: &'de [u8],
    /// Number of nested containers currently being decoded.
    depth: usize,
}
44
45impl<'de> Deserializer<'de> {
46    pub fn from_slice(input: &'de [u8]) -> Self {
47        Deserializer { input, depth: 0 }
48    }
49
50    fn recurse<F, V>(&mut self, f: F) -> Result<V, Error>
51    where
52        F: FnOnce(&mut Self) -> V,
53    {
54        if self.depth == MAX_RECURSION_DEPTH {
55            return Err(Error::RecursionLimitExceeded);
56        }
57        self.depth += 1;
58        let result = f(self);
59        self.depth -= 1;
60        Ok(result)
61    }
62
63    fn decode<V: Decode<'de>>(&mut self) -> Result<V::Value, Error> {
64        let (decoded, rest) = V::decode(self.input)?;
65        self.input = rest;
66        Ok(decoded)
67    }
68
69    fn decode_with_format<V: Decode<'de>>(&mut self, format: Format) -> Result<V::Value, Error> {
70        let (decoded, rest) = V::decode_with_format(format, self.input)?;
71        self.input = rest;
72        Ok(decoded)
73    }
74
75    fn decode_seq_with_format<V>(&mut self, format: Format, visitor: V) -> Result<V::Value, Error>
76    where
77        V: de::Visitor<'de>,
78    {
79        let n = match format {
80            Format::FixArray(n) => n.into(),
81            Format::Array16 => {
82                let (n, buf) = NbyteReader::<2>::read(self.input)?;
83                self.input = buf;
84                n
85            }
86            Format::Array32 => {
87                let (n, buf) = NbyteReader::<4>::read(self.input)?;
88                self.input = buf;
89                n
90            }
91            _ => return Err(CoreError::UnexpectedFormat.into()),
92        };
93        self.recurse(move |des| visitor.visit_seq(seq::FixLenAccess::new(des, n)))?
94    }
95
96    fn decode_map_with_format<V>(&mut self, format: Format, visitor: V) -> Result<V::Value, Error>
97    where
98        V: de::Visitor<'de>,
99    {
100        let n = match format {
101            Format::FixMap(n) => n.into(),
102            Format::Map16 => {
103                let (n, buf) = NbyteReader::<2>::read(self.input)?;
104                self.input = buf;
105                n
106            }
107            Format::Map32 => {
108                let (n, buf) = NbyteReader::<4>::read(self.input)?;
109                self.input = buf;
110                n
111            }
112            _ => return Err(CoreError::UnexpectedFormat.into()),
113        };
114        self.recurse(move |des| visitor.visit_map(seq::FixLenAccess::new(des, n)))?
115    }
116}
117
118impl AsMut<Self> for Deserializer<'_> {
119    fn as_mut(&mut self) -> &mut Self {
120        self
121    }
122}
123
124impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
125    type Error = Error;
126
127    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
128    where
129        V: de::Visitor<'de>,
130    {
131        let format = self.decode::<Format>()?;
132        match format {
133            Format::Nil => visitor.visit_unit(),
134            Format::False => visitor.visit_bool(false),
135            Format::True => visitor.visit_bool(true),
136            Format::PositiveFixInt(v) => visitor.visit_u8(v),
137            Format::Uint8 => {
138                let v = self.decode_with_format::<u8>(format)?;
139                visitor.visit_u8(v)
140            }
141            Format::Uint16 => {
142                let v = self.decode_with_format::<u16>(format)?;
143                visitor.visit_u16(v)
144            }
145            Format::Uint32 => {
146                let v = self.decode_with_format::<u32>(format)?;
147                visitor.visit_u32(v)
148            }
149            Format::Uint64 => {
150                let v = self.decode_with_format::<u64>(format)?;
151                visitor.visit_u64(v)
152            }
153            Format::NegativeFixInt(v) => visitor.visit_i8(v),
154            Format::Int8 => {
155                let v = self.decode_with_format::<i8>(format)?;
156                visitor.visit_i8(v)
157            }
158            Format::Int16 => {
159                let v = self.decode_with_format::<i16>(format)?;
160                visitor.visit_i16(v)
161            }
162            Format::Int32 => {
163                let v = self.decode_with_format::<i32>(format)?;
164                visitor.visit_i32(v)
165            }
166            Format::Int64 => {
167                let v = self.decode_with_format::<i64>(format)?;
168                visitor.visit_i64(v)
169            }
170            Format::Float32 => {
171                let v = self.decode_with_format::<f32>(format)?;
172                visitor.visit_f32(v)
173            }
174            Format::Float64 => {
175                let v = self.decode_with_format::<f64>(format)?;
176                visitor.visit_f64(v)
177            }
178            Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
179                let v = self.decode_with_format::<&str>(format)?;
180                visitor.visit_borrowed_str(v)
181            }
182            Format::FixArray(_) | Format::Array16 | Format::Array32 => {
183                self.decode_seq_with_format(format, visitor)
184            }
185            Format::Bin8 | Format::Bin16 | Format::Bin32 => {
186                let v = self.decode_with_format::<&[u8]>(format)?;
187                visitor.visit_borrowed_bytes(v)
188            }
189            Format::FixMap(_) | Format::Map16 | Format::Map32 => {
190                self.decode_map_with_format(format, visitor)
191            }
192            Format::Ext8
193            | Format::Ext16
194            | Format::Ext32
195            | Format::FixExt1
196            | Format::FixExt2
197            | Format::FixExt4
198            | Format::FixExt8
199            | Format::FixExt16 => {
200                let mut de_ext = DeserializeExt::new(format, self.input)?;
201                let val = (&mut de_ext).deserialize_newtype_struct(
202                    crate::value::extension::EXTENSION_STRUCT_NAME,
203                    visitor,
204                )?;
205                self.input = de_ext.input;
206
207                Ok(val)
208            }
209            Format::NeverUsed => Err(CoreError::UnexpectedFormat.into()),
210        }
211    }
212
213    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
214    where
215        V: de::Visitor<'de>,
216    {
217        let (first, rest) = self.input.split_first().ok_or(CoreError::EofFormat)?;
218
219        let format = Format::from_byte(*first);
220        match format {
221            Format::Nil => {
222                self.input = rest;
223                visitor.visit_none()
224            }
225            _ => visitor.visit_some(self),
226        }
227    }
228
229    fn deserialize_enum<V>(
230        self,
231        _name: &'static str,
232        _variants: &'static [&'static str],
233        visitor: V,
234    ) -> Result<V::Value, Self::Error>
235    where
236        V: de::Visitor<'de>,
237    {
238        let ident = self.decode::<&str>();
239        match ident {
240            Ok(ident) => visitor.visit_enum(ident.into_deserializer()),
241            _ => {
242                let (format, rest) = Format::decode(self.input)?;
243                let mut des = Deserializer::from_slice(rest);
244                // inherit depth
245                des.depth = self.depth;
246                let val = match format {
247                    Format::FixMap(_)
248                    | Format::Map16
249                    | Format::Map32
250                    | Format::FixArray(_)
251                    | Format::Array16
252                    | Format::Array32 => {
253                        des.recurse(|d| visitor.visit_enum(enum_::Enum::new(d)))?
254                    }
255                    _ => Err(CoreError::UnexpectedFormat.into()),
256                }?;
257                self.input = des.input;
258
259                Ok(val)
260            }
261        }
262    }
263
264    forward_to_deserialize_any! {
265        bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
266        bytes byte_buf unit unit_struct newtype_struct seq tuple
267        tuple_struct map struct identifier ignored_any
268    }
269
270    fn is_human_readable(&self) -> bool {
271        false
272    }
273}
274
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use serde::de::IgnoredAny;

    /// Build `depth` single-element fixarray markers wrapped around a final nil.
    fn nested_arrays(depth: usize) -> Vec<u8> {
        let mut bytes = vec![0x91u8; depth];
        bytes.push(0xc0);
        bytes
    }

    #[rstest]
    #[case([0xc3], true)]
    #[case([0xc2], false)]
    fn decode_bool<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: bool) {
        let actual: bool = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    #[rstest]
    #[case([0x05], 5)]
    #[case([0xcc, 0x80], 128)]
    fn decode_uint8<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: u8) {
        let actual: u8 = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn decode_float_vec() {
        // [1.1,1.2,1.3,1.4,1.5]
        let buf = [
            0x95, 0xcb, 0x3f, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a, 0xcb, 0x3f, 0xf3, 0x33,
            0x33, 0x33, 0x33, 0x33, 0x33, 0xcb, 0x3f, 0xf4, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcd,
            0xcb, 0x3f, 0xf6, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xcb, 0x3f, 0xf8, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];
        let expected = [1.1, 1.2, 1.3, 1.4, 1.5];

        let actual: Vec<f64> = from_slice(&buf).unwrap();

        assert_eq!(actual, expected);
    }

    #[test]
    fn decode_struct() {
        #[derive(Deserialize)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // {"compact":true,"schema":0}
        let buf: &[u8] = &[
            0x82, 0xa7, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x63, 0x74, 0xc3, 0xa6, 0x73, 0x63, 0x68,
            0x65, 0x6d, 0x61, 0x00,
        ];

        let S { compact, schema } = from_slice::<S>(buf).unwrap();
        assert!(compact);
        assert_eq!(schema, 0);
    }

    #[test]
    fn option_consumes_nil_in_sequence() {
        // [nil, 5]: the nil must be consumed so that 5 is read as the second element
        let buf: &[u8] = &[0x92, 0xc0, 0x05];

        let pair: (Option<u8>, u8) = from_slice(buf).unwrap();
        assert_eq!(pair, (None, 5));
    }

    #[test]
    fn option_some_simple() {
        let buf: &[u8] = &[0x05];
        let value: Option<u8> = from_slice(buf).unwrap();
        assert_eq!(value, Some(5));
    }

    #[test]
    fn unit_from_nil() {
        let buf: &[u8] = &[0xc0];
        from_slice::<()>(buf).unwrap();
    }

    #[test]
    fn unit_struct() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct U;

        let buf: &[u8] = &[0xc0];
        let value: U = from_slice(buf).unwrap();
        assert_eq!(value, U);
    }

    #[derive(Deserialize, PartialEq, Debug)]
    enum E {
        Unit,
        Newtype(u8),
        Tuple(u8, bool),
        Struct { a: bool },
    }

    #[rstest]
    #[case([0xa4, 0x55, 0x6e, 0x69, 0x74], E::Unit)] // "Unit"
    #[case([0x81, 0xa7, 0x4e, 0x65, 0x77, 0x74, 0x79, 0x70, 0x65, 0x1b], E::Newtype(27))] // {"Newtype":27}
    #[case([0x81, 0xa5, 0x54, 0x75, 0x70, 0x6c, 0x65, 0x92, 0x03, 0xc3], E::Tuple(3, true))] // {"Tuple":[3,true]}
    #[case([0x81, 0xa6, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x81, 0xa1, 0x61, 0xc2], E::Struct { a: false })] // {"Struct":{"a":false}}
    fn decode_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: E) {
        let actual: E = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn recursion_limit_ok_at_256() {
        // [[[[...]]]] 256 nested arrays
        let buf = nested_arrays(256);

        let _ = from_slice::<IgnoredAny>(&buf).unwrap();
    }

    #[test]
    fn recursion_limit_err_over_256() {
        // [[[[...]]]] 257 nested arrays
        let buf = nested_arrays(257);

        let err = from_slice::<IgnoredAny>(&buf).unwrap_err();
        assert!(matches!(err, Error::RecursionLimitExceeded));
    }
}