1mod enum_;
4mod error;
5mod seq;
6use error::CoreError;
7pub use error::Error;
8
9use messagepack_core::{
10 Decode, Format,
11 decode::NbyteReader,
12 io::{IoRead, RError},
13};
14use serde::{
15 Deserialize,
16 de::{self, IntoDeserializer},
17 forward_to_deserialize_any,
18};
19
20pub fn from_slice<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error<RError>> {
22 use messagepack_core::io::SliceReader;
23 let reader = SliceReader::new(input);
24 from_trait(reader)
25}
26
#[cfg(feature = "std")]
/// Deserializes an owned value of type `T` from a `std::io::Read` source.
///
/// Decode failures are translated into `std::io::Error`:
/// * `InvalidData` / `UnexpectedFormat` become `io::ErrorKind::InvalidData`,
/// * `UnexpectedEof` becomes `io::ErrorKind::UnexpectedEof`,
/// * an underlying I/O error is returned unchanged,
/// * any other deserializer error is wrapped with `io::Error::other`.
pub fn from_reader<R, T>(reader: R) -> std::io::Result<T>
where
    R: std::io::Read,
    T: for<'a> Deserialize<'a>,
{
    use messagepack_core::decode::Error as DecodeError;
    use messagepack_core::io::StdReader;
    use std::io;

    match from_trait::<'_, StdReader<R>, T>(StdReader::new(reader)) {
        Ok(value) => Ok(value),
        Err(Error::Decode(DecodeError::Io(source))) => Err(source),
        Err(Error::Decode(cause @ DecodeError::UnexpectedEof)) => {
            Err(io::Error::new(io::ErrorKind::UnexpectedEof, cause))
        }
        Err(Error::Decode(
            cause @ (DecodeError::InvalidData | DecodeError::UnexpectedFormat),
        )) => Err(io::Error::new(io::ErrorKind::InvalidData, cause)),
        Err(other) => Err(io::Error::other(other)),
    }
}
55
56fn from_trait<'de, R, T>(reader: R) -> Result<T, Error<R::Error>>
57where
58 R: IoRead<'de>,
59 T: Deserialize<'de>,
60{
61 let mut deserializer = Deserializer::from_trait(reader);
62 T::deserialize(&mut deserializer)
63}
64
/// Maximum number of container levels (arrays/maps) the deserializer will enter
/// before failing with `Error::RecursionLimitExceeded`.
const MAX_RECURSION_DEPTH: usize = 256;
66
67struct Deserializer<R> {
68 reader: R,
69 depth: usize,
70 format: Option<Format>,
71}
72
73impl<'de, R> Deserializer<R>
74where
75 R: IoRead<'de>,
76{
77 pub fn from_trait(reader: R) -> Self {
78 Deserializer {
79 reader,
80 depth: 0,
81 format: None,
82 }
83 }
84
85 fn recurse<F, V>(&mut self, f: F) -> Result<V, Error<R::Error>>
86 where
87 F: FnOnce(&mut Self) -> V,
88 {
89 if self.depth == MAX_RECURSION_DEPTH {
90 return Err(Error::RecursionLimitExceeded);
91 }
92 self.depth += 1;
93 let result = f(self);
94 self.depth -= 1;
95 Ok(result)
96 }
97
98 fn decode_format(&mut self) -> Result<Format, Error<R::Error>> {
99 match self.format.take() {
100 Some(v) => Ok(v),
101 None => {
102 let v = Format::decode(&mut self.reader)?;
103 Ok(v)
104 }
105 }
106 }
107
108 fn decode_seq_with_format<V>(
109 &mut self,
110 format: Format,
111 visitor: V,
112 ) -> Result<V::Value, Error<R::Error>>
113 where
114 V: de::Visitor<'de>,
115 {
116 let n = match format {
117 Format::FixArray(n) => n.into(),
118 Format::Array16 => NbyteReader::<2>::read(&mut self.reader)?,
119 Format::Array32 => NbyteReader::<4>::read(&mut self.reader)?,
120 _ => return Err(CoreError::UnexpectedFormat.into()),
121 };
122 self.recurse(move |des| visitor.visit_seq(seq::FixLenAccess::new(des, n)))?
123 }
124
125 fn decode_map_with_format<V>(
126 &mut self,
127 format: Format,
128 visitor: V,
129 ) -> Result<V::Value, Error<R::Error>>
130 where
131 V: de::Visitor<'de>,
132 {
133 let n = match format {
134 Format::FixMap(n) => n.into(),
135 Format::Map16 => NbyteReader::<2>::read(&mut self.reader)?,
136 Format::Map32 => NbyteReader::<4>::read(&mut self.reader)?,
137 _ => return Err(CoreError::UnexpectedFormat.into()),
138 };
139 self.recurse(move |des| visitor.visit_map(seq::FixLenAccess::new(des, n)))?
140 }
141}
142
143impl<R> AsMut<Self> for Deserializer<R> {
144 fn as_mut(&mut self) -> &mut Self {
145 self
146 }
147}
148
149impl<'de, R> de::Deserializer<'de> for &mut Deserializer<R>
150where
151 R: IoRead<'de>,
152{
153 type Error = Error<R::Error>;
154
155 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
156 where
157 V: de::Visitor<'de>,
158 {
159 let format = self.decode_format()?;
160 match format {
161 Format::Nil => visitor.visit_unit(),
162 Format::False => visitor.visit_bool(false),
163 Format::True => visitor.visit_bool(true),
164 Format::PositiveFixInt(v) => visitor.visit_u8(v),
165 Format::Uint8 => {
166 let v = u8::decode_with_format(format, &mut self.reader)?;
167 visitor.visit_u8(v)
168 }
169 Format::Uint16 => {
170 let v = u16::decode_with_format(format, &mut self.reader)?;
171 visitor.visit_u16(v)
172 }
173 Format::Uint32 => {
174 let v = u32::decode_with_format(format, &mut self.reader)?;
175 visitor.visit_u32(v)
176 }
177 Format::Uint64 => {
178 let v = u64::decode_with_format(format, &mut self.reader)?;
179 visitor.visit_u64(v)
180 }
181 Format::NegativeFixInt(v) => visitor.visit_i8(v),
182 Format::Int8 => {
183 let v = i8::decode_with_format(format, &mut self.reader)?;
184 visitor.visit_i8(v)
185 }
186 Format::Int16 => {
187 let v = i16::decode_with_format(format, &mut self.reader)?;
188 visitor.visit_i16(v)
189 }
190 Format::Int32 => {
191 let v = i32::decode_with_format(format, &mut self.reader)?;
192 visitor.visit_i32(v)
193 }
194 Format::Int64 => {
195 let v = i64::decode_with_format(format, &mut self.reader)?;
196 visitor.visit_i64(v)
197 }
198 Format::Float32 => {
199 let v = f32::decode_with_format(format, &mut self.reader)?;
200 visitor.visit_f32(v)
201 }
202 Format::Float64 => {
203 let v = f64::decode_with_format(format, &mut self.reader)?;
204 visitor.visit_f64(v)
205 }
206 Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
207 use messagepack_core::decode::ReferenceStrDecoder;
208 let data = ReferenceStrDecoder::decode_with_format(format, &mut self.reader)?;
209 match data {
210 messagepack_core::decode::ReferenceStr::Borrowed(s) => {
211 visitor.visit_borrowed_str(s)
212 }
213 messagepack_core::decode::ReferenceStr::Copied(s) => visitor.visit_str(s),
214 }
215 }
216 Format::FixArray(_) | Format::Array16 | Format::Array32 => {
217 self.decode_seq_with_format(format, visitor)
218 }
219 Format::Bin8 | Format::Bin16 | Format::Bin32 => {
220 use messagepack_core::decode::ReferenceDecoder;
221 let data = ReferenceDecoder::decode_with_format(format, &mut self.reader)?;
222 match data {
223 messagepack_core::io::Reference::Borrowed(items) => {
224 visitor.visit_borrowed_bytes(items)
225 }
226 messagepack_core::io::Reference::Copied(items) => visitor.visit_bytes(items),
227 }
228 }
229 Format::FixMap(_) | Format::Map16 | Format::Map32 => {
230 self.decode_map_with_format(format, visitor)
231 }
232 Format::Ext8
233 | Format::Ext16
234 | Format::Ext32
235 | Format::FixExt1
236 | Format::FixExt2
237 | Format::FixExt4
238 | Format::FixExt8
239 | Format::FixExt16 => {
240 let mut de_ext =
241 crate::extension::de::DeserializeExt::new(format, &mut self.reader)?;
242 let val = de::Deserializer::deserialize_newtype_struct(
243 &mut de_ext,
244 crate::extension::EXTENSION_STRUCT_NAME,
245 visitor,
246 )?;
247
248 Ok(val)
249 }
250 Format::NeverUsed => Err(CoreError::UnexpectedFormat.into()),
251 }
252 }
253
254 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
255 where
256 V: de::Visitor<'de>,
257 {
258 let format = self.decode_format()?;
259 match format {
260 Format::Nil => visitor.visit_none(),
261 _ => {
262 self.format = Some(format);
263 visitor.visit_some(self.as_mut())
264 }
265 }
266 }
267
268 fn deserialize_enum<V>(
269 self,
270 _name: &'static str,
271 _variants: &'static [&'static str],
272 visitor: V,
273 ) -> Result<V::Value, Self::Error>
274 where
275 V: de::Visitor<'de>,
276 {
277 let format = self.decode_format()?;
278 match format {
279 Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
280 let s = <&str>::decode_with_format(format, &mut self.reader)?;
281 visitor.visit_enum(s.into_deserializer())
282 }
283 Format::FixMap(_)
284 | Format::Map16
285 | Format::Map32
286 | Format::FixArray(_)
287 | Format::Array16
288 | Format::Array32 => {
289 let enum_access = enum_::Enum::new(self);
290 visitor.visit_enum(enum_access)
291 }
292 _ => Err(CoreError::UnexpectedFormat.into()),
293 }
294 }
295
296 forward_to_deserialize_any! {
297 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
298 bytes byte_buf unit unit_struct newtype_struct seq tuple
299 tuple_struct map struct identifier ignored_any
300 }
301
302 fn is_human_readable(&self) -> bool {
303 false
304 }
305}
306
#[cfg(test)]
mod tests {
    use rstest::rstest;
    use serde::de::IgnoredAny;

    use super::*;

    // 0xc3 / 0xc2 are the `true` / `false` markers.
    #[rstest]
    #[case([0xc3], true)]
    #[case([0xc2], false)]
    fn decode_bool<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: bool) {
        let value = from_slice::<bool>(buf.as_ref()).unwrap();
        assert_eq!(value, expected);
    }

    // Positive fixint and uint8 encodings.
    #[rstest]
    #[case([0x05], 5)]
    #[case([0xcc, 0x80], 128)]
    fn decode_uint8<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: u8) {
        let value = from_slice::<u8>(buf.as_ref()).unwrap();
        assert_eq!(value, expected);
    }

    #[test]
    fn decode_float_vec() {
        // fixarray(5) of float64 values 1.1, 1.2, 1.3, 1.4, 1.5.
        let bytes = [
            0x95, 0xcb, 0x3f, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a, 0xcb, 0x3f, 0xf3, 0x33,
            0x33, 0x33, 0x33, 0x33, 0x33, 0xcb, 0x3f, 0xf4, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcd,
            0xcb, 0x3f, 0xf6, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xcb, 0x3f, 0xf8, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];
        let expected = [1.1, 1.2, 1.3, 1.4, 1.5];

        let values = from_slice::<Vec<f64>>(&bytes).unwrap();
        assert_eq!(values, expected)
    }

    #[test]
    fn decode_struct() {
        #[derive(Deserialize)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // {"compact": true, "schema": 0}
        let bytes: &[u8] = &[
            0x82, 0xa7, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x63, 0x74, 0xc3, 0xa6, 0x73, 0x63, 0x68,
            0x65, 0x6d, 0x61, 0x00,
        ];

        let value = from_slice::<S>(bytes).unwrap();
        assert!(value.compact);
        assert_eq!(value.schema, 0);
    }

    #[test]
    fn decode_struct_from_array() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // [true, 0] — positional struct encoding.
        let bytes: &[u8] = &[0x92, 0xc3, 0x00];

        let value = from_slice::<S>(bytes).unwrap();
        assert_eq!(
            value,
            S {
                compact: true,
                schema: 0
            }
        );
    }

    #[test]
    fn option_consumes_nil_in_sequence() {
        // [nil, 5]
        let bytes: &[u8] = &[0x92, 0xc0, 0x05];

        let value = from_slice::<(Option<u8>, u8)>(bytes).unwrap();
        assert_eq!(value, (None, 5));
    }

    #[test]
    fn option_some_simple() {
        let bytes: &[u8] = &[0x05];
        let value = from_slice::<Option<u8>>(bytes).unwrap();
        assert_eq!(value, Some(5));
    }

    #[test]
    fn unit_from_nil() {
        let bytes: &[u8] = &[0xc0];
        from_slice::<()>(bytes).unwrap();
    }

    #[test]
    fn unit_struct() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct U;

        let bytes: &[u8] = &[0xc0];
        let value = from_slice::<U>(bytes).unwrap();
        assert_eq!(value, U);
    }

    #[derive(Deserialize, PartialEq, Debug)]
    enum E {
        Unit,
        Newtype(u8),
        Tuple(u8, bool),
        Struct { a: bool },
    }

    // Unit variants are bare strings; other variants are single-entry maps.
    #[rstest]
    #[case([0xa4, 0x55, 0x6e, 0x69, 0x74], E::Unit)]
    #[case([0x81, 0xa7, 0x4e, 0x65, 0x77, 0x74, 0x79, 0x70, 0x65, 0x1b], E::Newtype(27))]
    #[case([0x81, 0xa5, 0x54, 0x75, 0x70, 0x6c, 0x65, 0x92, 0x03, 0xc3], E::Tuple(3, true))]
    #[case(
        [0x81, 0xa6, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x81, 0xa1, 0x61, 0xc2],
        E::Struct { a: false }
    )]
    fn decode_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: E) {
        let value = from_slice::<E>(buf.as_ref()).unwrap();
        assert_eq!(value, expected);
    }

    #[derive(Deserialize, PartialEq, Debug)]
    #[serde(untagged)]
    enum Untagged {
        Bool(bool),
        U8(u8),
        Pair(u8, bool),
        Struct { a: bool },
        Nested(E),
    }

    #[rstest]
    #[case([0xc3], Untagged::Bool(true))]
    #[case([0x05], Untagged::U8(5))]
    #[case([0x92, 0x02, 0xc3], Untagged::Pair(2, true))]
    #[case([0x81, 0xa1, 0x61, 0xc2], Untagged::Struct { a: false })]
    #[case([0xa4, 0x55, 0x6e, 0x69, 0x74], Untagged::Nested(E::Unit))]
    fn decode_untagged_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: Untagged) {
        let value = from_slice::<Untagged>(buf.as_ref()).unwrap();
        assert_eq!(value, expected);
    }

    #[test]
    fn recursion_limit_ok_at_256() {
        // 256 nested fixarray(1) headers around a single nil.
        let mut nested = vec![0x91u8; 256];
        nested.push(0xc0);

        let _ = from_slice::<IgnoredAny>(&nested).unwrap();
    }

    #[test]
    fn recursion_limit_err_over_256() {
        // One level more than the limit allows.
        let mut nested = vec![0x91u8; 257];
        nested.push(0xc0);

        let failure = from_slice::<IgnoredAny>(&nested).unwrap_err();
        assert!(matches!(failure, Error::RecursionLimitExceeded));
    }

    #[cfg(feature = "std")]
    #[rstest]
    #[case([0xc0], ())]
    #[case([0xc3], true)]
    #[case([0xc2], false)]
    #[case([0x2a], 42u8)]
    #[case([0xcc, 0x80], 128u8)]
    #[case([0xcd, 0x01, 0x00], 256u16)]
    #[case([0xce, 0x00, 0x01, 0x00, 0x00], 65536u32)]
    #[case([0xcf, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00], 4294967296u64)]
    #[case([0xff], -1i8)]
    #[case([0xd0, 0x80], -128i8)]
    #[case([0xd1, 0x80, 0x00], -32768i16)]
    #[case([0xd2, 0x80, 0x00, 0x00, 0x00], -2147483648i32)]
    #[case([0xd3, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00], i64::MIN)]
    #[case([0xca, 0x41, 0x45, 0x70, 0xa4], 12.34f32)]
    #[case([0xcb, 0x3f, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00], 1.0f64)]
    #[case([0xa1, 0x61], "a".to_string())]
    #[case([0xd9, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f], "hello".to_string())]
    #[case([0xc4, 0x03, 0x01, 0x02, 0x03], serde_bytes::ByteBuf::from(vec![1u8, 2, 3]))]
    #[case([0x93, 0x01, 0x02, 0x03], vec![1u8, 2, 3])]
    #[case([0x82, 0xa1, 0x61, 0x01, 0xa1, 0x62, 0x02], {
        let mut m = std::collections::BTreeMap::<String, u8>::new();
        m.insert("a".to_string(), 1u8);
        m.insert("b".to_string(), 2u8);
        m
    })]
    fn decode_success_from_reader_when_owned<
        Buf: AsRef<[u8]>,
        T: serde::de::DeserializeOwned + core::fmt::Debug + PartialEq,
    >(
        #[case] buf: Buf,
        #[case] expected: T,
    ) {
        // Goes through the `std::io::Read` path rather than the slice path.
        let mut cursor = std::io::Cursor::new(buf.as_ref());
        let value = from_reader::<_, T>(&mut cursor).unwrap();
        assert_eq!(value, expected)
    }
}