messagepack_serde/de/mod.rs

1mod enum_;
2mod error;
3mod seq;
4
5pub use error::{CoreError, Error};
6
7use crate::value::extension::DeserializeExt;
8use messagepack_core::{
9    Decode, Format,
10    decode::{NbyteReader, NilDecoder},
11};
12use serde::{
13    Deserialize,
14    de::{self, IntoDeserializer},
15};
16
/// MessagePack deserializer that reads from a borrowed byte slice.
///
/// `input` always holds the part of the buffer that has not been decoded yet;
/// successful decodes move it forward past the bytes they consumed.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
pub struct Deserializer<'de> {
    /// Remaining, not-yet-decoded bytes.
    input: &'de [u8],
}
21
22impl<'de> Deserializer<'de> {
23    pub fn from_slice(input: &'de [u8]) -> Self {
24        Deserializer { input }
25    }
26
27    fn decode<V: Decode<'de>>(&mut self) -> Result<V::Value, Error> {
28        let (decoded, rest) = V::decode(self.input)?;
29        self.input = rest;
30        Ok(decoded)
31    }
32
33    fn decode_with_format<V: Decode<'de>>(&mut self, format: Format) -> Result<V::Value, Error> {
34        let (decoded, rest) = V::decode_with_format(format, self.input)?;
35        self.input = rest;
36        Ok(decoded)
37    }
38
39    fn decode_seq_with_format<V>(&mut self, format: Format, visitor: V) -> Result<V::Value, Error>
40    where
41        V: de::Visitor<'de>,
42    {
43        let n = match format {
44            Format::FixArray(n) => n.into(),
45            Format::Array16 => {
46                let (n, buf) = NbyteReader::<2>::read(self.input)?;
47                self.input = buf;
48                n
49            }
50            Format::Array32 => {
51                let (n, buf) = NbyteReader::<4>::read(self.input)?;
52                self.input = buf;
53                n
54            }
55            _ => return Err(CoreError::UnexpectedFormat.into()),
56        };
57        visitor.visit_seq(seq::FixLenAccess::new(self, n))
58    }
59
60    fn decode_map_with_format<V>(&mut self, format: Format, visitor: V) -> Result<V::Value, Error>
61    where
62        V: de::Visitor<'de>,
63    {
64        let n = match format {
65            Format::FixMap(n) => n.into(),
66            Format::Map16 => {
67                let (n, buf) = NbyteReader::<2>::read(self.input)?;
68                self.input = buf;
69                n
70            }
71            Format::Map32 => {
72                let (n, buf) = NbyteReader::<4>::read(self.input)?;
73                self.input = buf;
74                n
75            }
76            _ => return Err(CoreError::UnexpectedFormat.into()),
77        };
78        visitor.visit_map(seq::FixLenAccess::new(self, n))
79    }
80}
81
82impl AsMut<Self> for Deserializer<'_> {
83    fn as_mut(&mut self) -> &mut Self {
84        self
85    }
86}
87
88pub fn from_slice<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error> {
89    from_slice_with_config(input)
90}
91
92pub fn from_slice_with_config<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error> {
93    let mut deserializer = Deserializer::from_slice(input);
94    T::deserialize(&mut deserializer)
95}
96
/// Reads all of `reader` into memory and deserializes a `T` from it.
///
/// The temporary buffer does not outlive the call, so `T` must own its data
/// (`for<'a> Deserialize<'a>`). Decoding failures are reported as `std::io::Error`s.
#[cfg(feature = "std")]
pub fn from_reader<R, T>(reader: &mut R) -> std::io::Result<T>
where
    R: std::io::Read,
    T: for<'a> Deserialize<'a>,
{
    from_reader_with_config::<R, T>(reader)
}
105
/// Reads `reader` to EOF into an owned buffer and deserializes a `T` from it.
///
/// # Errors
/// I/O errors from `read_to_end` are returned unchanged. A MessagePack decoding
/// failure is wrapped in a `std::io::Error` of kind [`std::io::ErrorKind::Other`].
#[cfg(feature = "std")]
pub fn from_reader_with_config<R, T>(reader: &mut R) -> std::io::Result<T>
where
    R: std::io::Read,
    T: for<'a> Deserialize<'a>,
{
    let mut buf = Vec::new();
    reader.read_to_end(&mut buf)?;

    let mut deserializer = Deserializer::from_slice(&buf);
    // `io::Error::other` is the dedicated constructor for `ErrorKind::Other`.
    T::deserialize(&mut deserializer).map_err(std::io::Error::other)
}
118
119impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
120    type Error = Error;
121
122    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
123    where
124        V: de::Visitor<'de>,
125    {
126        let format = self.decode::<Format>()?;
127        match format {
128            Format::Nil => visitor.visit_none(),
129            Format::False => visitor.visit_bool(false),
130            Format::True => visitor.visit_bool(true),
131            Format::PositiveFixInt(v) => visitor.visit_u8(v),
132            Format::Uint8 => {
133                let v = self.decode_with_format::<u8>(format)?;
134                visitor.visit_u8(v)
135            }
136            Format::Uint16 => {
137                let v = self.decode_with_format::<u16>(format)?;
138                visitor.visit_u16(v)
139            }
140            Format::Uint32 => {
141                let v = self.decode_with_format::<u32>(format)?;
142                visitor.visit_u32(v)
143            }
144            Format::Uint64 => {
145                let v = self.decode_with_format::<u64>(format)?;
146                visitor.visit_u64(v)
147            }
148            Format::NegativeFixInt(v) => visitor.visit_i8(v),
149            Format::Int8 => {
150                let v = self.decode_with_format::<i8>(format)?;
151                visitor.visit_i8(v)
152            }
153            Format::Int16 => {
154                let v = self.decode_with_format::<i16>(format)?;
155                visitor.visit_i16(v)
156            }
157            Format::Int32 => {
158                let v = self.decode_with_format::<i32>(format)?;
159                visitor.visit_i32(v)
160            }
161            Format::Int64 => {
162                let v = self.decode_with_format::<i64>(format)?;
163                visitor.visit_i64(v)
164            }
165            Format::Float32 => {
166                let v = self.decode_with_format::<f32>(format)?;
167                visitor.visit_f32(v)
168            }
169            Format::Float64 => {
170                let v = self.decode_with_format::<f64>(format)?;
171                visitor.visit_f64(v)
172            }
173            Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
174                let v = self.decode_with_format::<&str>(format)?;
175                visitor.visit_borrowed_str(v)
176            }
177            Format::FixArray(_) | Format::Array16 | Format::Array32 => {
178                self.decode_seq_with_format(format, visitor)
179            }
180            Format::Bin8 | Format::Bin16 | Format::Bin32 => {
181                let v = self.decode_with_format::<&[u8]>(format)?;
182                visitor.visit_borrowed_bytes(v)
183            }
184            Format::FixMap(_) | Format::Map16 | Format::Map32 => {
185                self.decode_map_with_format(format, visitor)
186            }
187            Format::Ext8
188            | Format::Ext16
189            | Format::Ext32
190            | Format::FixExt1
191            | Format::FixExt2
192            | Format::FixExt4
193            | Format::FixExt8
194            | Format::FixExt16 => {
195                let mut de_ext = DeserializeExt::new(format, self.input)?;
196                let val = (&mut de_ext).deserialize_newtype_struct(
197                    crate::value::extension::EXTENSION_STRUCT_NAME,
198                    visitor,
199                )?;
200                self.input = de_ext.input;
201
202                Ok(val)
203            }
204            Format::NeverUsed => Err(CoreError::UnexpectedFormat.into()),
205        }
206    }
207
208    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
209    where
210        V: de::Visitor<'de>,
211    {
212        let decoded = self.decode::<bool>()?;
213        visitor.visit_bool(decoded)
214    }
215
216    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
217    where
218        V: de::Visitor<'de>,
219    {
220        let decoded = self.decode::<i8>()?;
221        visitor.visit_i8(decoded)
222    }
223
224    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
225    where
226        V: de::Visitor<'de>,
227    {
228        let decoded = self.decode::<i16>()?;
229        visitor.visit_i16(decoded)
230    }
231
232    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
233    where
234        V: de::Visitor<'de>,
235    {
236        let decoded = self.decode::<i32>()?;
237        visitor.visit_i32(decoded)
238    }
239
240    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
241    where
242        V: de::Visitor<'de>,
243    {
244        let decoded = self.decode::<i64>()?;
245        visitor.visit_i64(decoded)
246    }
247
248    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
249    where
250        V: de::Visitor<'de>,
251    {
252        let decoded = self.decode::<u8>()?;
253        visitor.visit_u8(decoded)
254    }
255
256    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
257    where
258        V: de::Visitor<'de>,
259    {
260        let decoded = self.decode::<u16>()?;
261        visitor.visit_u16(decoded)
262    }
263
264    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
265    where
266        V: de::Visitor<'de>,
267    {
268        let decoded = self.decode::<u32>()?;
269        visitor.visit_u32(decoded)
270    }
271
272    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
273    where
274        V: de::Visitor<'de>,
275    {
276        let decoded = self.decode::<u64>()?;
277        visitor.visit_u64(decoded)
278    }
279
280    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
281    where
282        V: de::Visitor<'de>,
283    {
284        let decoded = self.decode::<f32>()?;
285        visitor.visit_f32(decoded)
286    }
287
288    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
289    where
290        V: de::Visitor<'de>,
291    {
292        let decoded = self.decode::<f64>()?;
293        visitor.visit_f64(decoded)
294    }
295
296    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
297    where
298        V: de::Visitor<'de>,
299    {
300        self.deserialize_str(visitor)
301    }
302
303    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
304    where
305        V: de::Visitor<'de>,
306    {
307        let decoded = self.decode::<&str>()?;
308        visitor.visit_borrowed_str(decoded)
309    }
310
311    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
312    where
313        V: de::Visitor<'de>,
314    {
315        self.deserialize_str(visitor)
316    }
317
318    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
319    where
320        V: de::Visitor<'de>,
321    {
322        let decoded = self.decode::<&[u8]>()?;
323        visitor.visit_borrowed_bytes(decoded)
324    }
325
326    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
327    where
328        V: de::Visitor<'de>,
329    {
330        self.deserialize_bytes(visitor)
331    }
332
333    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
334    where
335        V: de::Visitor<'de>,
336    {
337        let is_null = NilDecoder::decode(self.input).is_ok();
338        if is_null {
339            visitor.visit_none()
340        } else {
341            visitor.visit_some(self)
342        }
343    }
344
345    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
346    where
347        V: de::Visitor<'de>,
348    {
349        self.decode::<()>()?;
350        visitor.visit_unit()
351    }
352
353    fn deserialize_unit_struct<V>(
354        self,
355        _name: &'static str,
356        visitor: V,
357    ) -> Result<V::Value, Self::Error>
358    where
359        V: de::Visitor<'de>,
360    {
361        self.deserialize_unit(visitor)
362    }
363
364    fn deserialize_newtype_struct<V>(
365        self,
366        _name: &'static str,
367        visitor: V,
368    ) -> Result<V::Value, Self::Error>
369    where
370        V: de::Visitor<'de>,
371    {
372        visitor.visit_newtype_struct(self)
373    }
374
375    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
376    where
377        V: de::Visitor<'de>,
378    {
379        let (format, rest) = Format::decode(self.input)?;
380
381        let mut des = Deserializer::from_slice(rest);
382        let val = des.decode_seq_with_format(format, visitor)?;
383        self.input = des.input;
384
385        Ok(val)
386    }
387
388    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
389    where
390        V: de::Visitor<'de>,
391    {
392        self.deserialize_seq(visitor)
393    }
394
395    fn deserialize_tuple_struct<V>(
396        self,
397        _name: &'static str,
398        _len: usize,
399        visitor: V,
400    ) -> Result<V::Value, Self::Error>
401    where
402        V: de::Visitor<'de>,
403    {
404        self.deserialize_seq(visitor)
405    }
406
407    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
408    where
409        V: de::Visitor<'de>,
410    {
411        let (format, rest) = Format::decode(self.input)?;
412
413        let mut des = Deserializer::from_slice(rest);
414        let val = des.decode_map_with_format(format, visitor)?;
415        self.input = des.input;
416
417        Ok(val)
418    }
419
420    fn deserialize_struct<V>(
421        self,
422        _name: &'static str,
423        _fields: &'static [&'static str],
424        visitor: V,
425    ) -> Result<V::Value, Self::Error>
426    where
427        V: de::Visitor<'de>,
428    {
429        self.deserialize_map(visitor)
430    }
431
432    fn deserialize_enum<V>(
433        self,
434        _name: &'static str,
435        _variants: &'static [&'static str],
436        visitor: V,
437    ) -> Result<V::Value, Self::Error>
438    where
439        V: de::Visitor<'de>,
440    {
441        let ident = self.decode::<&str>();
442        match ident {
443            Ok(ident) => visitor.visit_enum(ident.into_deserializer()),
444            _ => {
445                let (format, rest) = Format::decode(self.input)?;
446
447                let mut des = Deserializer::from_slice(rest);
448                let val = match format {
449                    Format::FixMap(_)
450                    | Format::Map16
451                    | Format::Map32
452                    | Format::FixArray(_)
453                    | Format::Array16
454                    | Format::Array32 => visitor.visit_enum(enum_::Enum::new(&mut des)),
455                    _ => Err(CoreError::UnexpectedFormat.into()),
456                }?;
457
458                self.input = des.input;
459
460                Ok(val)
461            }
462        }
463    }
464
465    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
466    where
467        V: de::Visitor<'de>,
468    {
469        self.deserialize_str(visitor)
470    }
471
472    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
473    where
474        V: de::Visitor<'de>,
475    {
476        self.deserialize_any(visitor)
477    }
478
479    fn is_human_readable(&self) -> bool {
480        false
481    }
482}
483
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;

    #[rstest]
    #[case([0xc3],true)]
    #[case([0xc2],false)]
    fn decode_bool<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: bool) {
        let decoded = from_slice::<bool>(buf.as_ref()).unwrap();
        assert_eq!(decoded, expected);
    }

    #[rstest]
    #[case([0x05],5)]
    #[case([0xcc, 0x80],128)]
    fn decode_uint8<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: u8) {
        let decoded = from_slice::<u8>(buf.as_ref()).unwrap();
        assert_eq!(decoded, expected);
    }

    #[rstest]
    #[case([0xc0], None)] // nil
    #[case([0x05], Some(5))]
    fn decode_option<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: Option<u8>) {
        let decoded = from_slice::<Option<u8>>(buf.as_ref()).unwrap();
        assert_eq!(decoded, expected);
    }

    #[test]
    fn decode_float_vec() {
        // [1.1,1.2,1.3,1.4,1.5]
        let buf = [
            0x95, 0xcb, 0x3f, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a, 0xcb, 0x3f, 0xf3, 0x33,
            0x33, 0x33, 0x33, 0x33, 0x33, 0xcb, 0x3f, 0xf4, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcd,
            0xcb, 0x3f, 0xf6, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xcb, 0x3f, 0xf8, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];

        let decoded = from_slice::<Vec<f64>>(&buf).unwrap();
        let expected = [1.1, 1.2, 1.3, 1.4, 1.5];

        assert_eq!(decoded, expected)
    }

    #[test]
    fn decode_struct() {
        #[derive(Deserialize)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // {"compact":true,"schema":0}
        let buf: &[u8] = &[
            0x82, 0xa7, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x63, 0x74, 0xc3, 0xa6, 0x73, 0x63, 0x68,
            0x65, 0x6d, 0x61, 0x00,
        ];

        let decoded = from_slice::<S>(buf).unwrap();
        assert!(decoded.compact);
        assert_eq!(decoded.schema, 0);
    }

    #[derive(Deserialize, PartialEq, Debug)]
    enum E {
        Unit,
        Newtype(u8),
        Tuple(u8, bool),
        Struct { a: bool },
    }
    #[rstest]
    #[case([0xa4, 0x55, 0x6e, 0x69, 0x74],E::Unit)] // "Unit"
    #[case([0x81, 0xa7, 0x4e, 0x65, 0x77, 0x74, 0x79, 0x70, 0x65, 0x1b], E::Newtype(27))] // {"Newtype":27}
    #[case([0x81, 0xa5, 0x54, 0x75, 0x70, 0x6c, 0x65, 0x92, 0x03, 0xc3], E::Tuple(3, true))] // {"Tuple":[3,true]}
    #[case([0x81, 0xa6, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x81, 0xa1, 0x61, 0xc2],E::Struct { a: false })] // {"Struct":{"a":false}}
    fn decode_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: E) {
        let decoded = from_slice::<E>(buf.as_ref()).unwrap();
        assert_eq!(decoded, expected);
    }

    #[rstest]
    fn decode_extension() {
        use crate::value::extension::ExtensionRef;

        // fixext1: type 123, payload [0x12]
        let buf: &[u8] = &[0xd4, 0x7b, 0x12];

        let ext = from_slice::<ExtensionRef<'_>>(buf).unwrap();
        assert_eq!(ext.kind, 123);
        assert_eq!(ext.data, [0x12_u8])
    }
}