1mod enum_;
4mod error;
5mod seq;
6use error::CoreError;
7pub use error::Error;
8
9use crate::value::extension::DeserializeExt;
10use messagepack_core::{
11 Decode, Format,
12 decode::NbyteReader,
13 io::{IoRead, RError},
14};
15use serde::{
16 Deserialize,
17 de::{self, IntoDeserializer},
18 forward_to_deserialize_any,
19};
20
21pub fn from_slice<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error<RError>> {
23 use messagepack_core::io::SliceReader;
24 let reader = SliceReader::new(input);
25 from_trait(reader)
26}
27
/// Deserializes an owned `T` from any [`std::io::Read`] source.
///
/// Decoding failures are folded into [`std::io::Error`]:
/// * invalid data or an unexpected format marker → `ErrorKind::InvalidData`,
/// * premature end of input → `ErrorKind::UnexpectedEof`,
/// * I/O errors from the reader are returned unchanged,
/// * every other deserializer error is wrapped with `io::Error::other`.
#[cfg(feature = "std")]
pub fn from_reader<R, T>(reader: R) -> std::io::Result<T>
where
    R: std::io::Read,
    T: for<'a> Deserialize<'a>,
{
    use messagepack_core::decode::Error as DecodeError;
    use messagepack_core::io::StdReader;
    use std::io::{Error as IoError, ErrorKind};

    from_trait::<'_, StdReader<R>, T>(StdReader::new(reader)).map_err(|failure| match failure {
        // Underlying reader failures pass through untouched.
        Error::Decode(DecodeError::Io(io_err)) => io_err,
        Error::Decode(
            decode_err @ (DecodeError::InvalidData | DecodeError::UnexpectedFormat),
        ) => IoError::new(ErrorKind::InvalidData, decode_err),
        Error::Decode(decode_err @ DecodeError::UnexpectedEof) => {
            IoError::new(ErrorKind::UnexpectedEof, decode_err)
        }
        other => IoError::other(other),
    })
}
56
57fn from_trait<'de, R, T>(reader: R) -> Result<T, Error<R::Error>>
58where
59 R: IoRead<'de>,
60 T: Deserialize<'de>,
61{
62 let mut deserializer = Deserializer::from_trait(reader);
63 T::deserialize(&mut deserializer)
64}
65
66const MAX_RECURSION_DEPTH: usize = 256;
67
68struct Deserializer<R> {
69 reader: R,
70 depth: usize,
71 format: Option<Format>,
72}
73
74impl<'de, R> Deserializer<R>
75where
76 R: IoRead<'de>,
77{
78 pub fn from_trait(reader: R) -> Self {
79 Deserializer {
80 reader,
81 depth: 0,
82 format: None,
83 }
84 }
85
86 fn recurse<F, V>(&mut self, f: F) -> Result<V, Error<R::Error>>
87 where
88 F: FnOnce(&mut Self) -> V,
89 {
90 if self.depth == MAX_RECURSION_DEPTH {
91 return Err(Error::RecursionLimitExceeded);
92 }
93 self.depth += 1;
94 let result = f(self);
95 self.depth -= 1;
96 Ok(result)
97 }
98
99 fn decode_format(&mut self) -> Result<Format, Error<R::Error>> {
100 match self.format.take() {
101 Some(v) => Ok(v),
102 None => {
103 let v = Format::decode(&mut self.reader)?;
104 Ok(v)
105 }
106 }
107 }
108
109 fn decode_seq_with_format<V>(
110 &mut self,
111 format: Format,
112 visitor: V,
113 ) -> Result<V::Value, Error<R::Error>>
114 where
115 V: de::Visitor<'de>,
116 {
117 let n = match format {
118 Format::FixArray(n) => n.into(),
119 Format::Array16 => NbyteReader::<2>::read(&mut self.reader)?,
120 Format::Array32 => NbyteReader::<4>::read(&mut self.reader)?,
121 _ => return Err(CoreError::UnexpectedFormat.into()),
122 };
123 self.recurse(move |des| visitor.visit_seq(seq::FixLenAccess::new(des, n)))?
124 }
125
126 fn decode_map_with_format<V>(
127 &mut self,
128 format: Format,
129 visitor: V,
130 ) -> Result<V::Value, Error<R::Error>>
131 where
132 V: de::Visitor<'de>,
133 {
134 let n = match format {
135 Format::FixMap(n) => n.into(),
136 Format::Map16 => NbyteReader::<2>::read(&mut self.reader)?,
137 Format::Map32 => NbyteReader::<4>::read(&mut self.reader)?,
138 _ => return Err(CoreError::UnexpectedFormat.into()),
139 };
140 self.recurse(move |des| visitor.visit_map(seq::FixLenAccess::new(des, n)))?
141 }
142}
143
144impl<R> AsMut<Self> for Deserializer<R> {
145 fn as_mut(&mut self) -> &mut Self {
146 self
147 }
148}
149
150impl<'de, R> de::Deserializer<'de> for &mut Deserializer<R>
151where
152 R: IoRead<'de>,
153{
154 type Error = Error<R::Error>;
155
156 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
157 where
158 V: de::Visitor<'de>,
159 {
160 let format = self.decode_format()?;
161 match format {
162 Format::Nil => visitor.visit_unit(),
163 Format::False => visitor.visit_bool(false),
164 Format::True => visitor.visit_bool(true),
165 Format::PositiveFixInt(v) => visitor.visit_u8(v),
166 Format::Uint8 => {
167 let v = u8::decode_with_format(format, &mut self.reader)?;
168 visitor.visit_u8(v)
169 }
170 Format::Uint16 => {
171 let v = u16::decode_with_format(format, &mut self.reader)?;
172 visitor.visit_u16(v)
173 }
174 Format::Uint32 => {
175 let v = u32::decode_with_format(format, &mut self.reader)?;
176 visitor.visit_u32(v)
177 }
178 Format::Uint64 => {
179 let v = u64::decode_with_format(format, &mut self.reader)?;
180 visitor.visit_u64(v)
181 }
182 Format::NegativeFixInt(v) => visitor.visit_i8(v),
183 Format::Int8 => {
184 let v = i8::decode_with_format(format, &mut self.reader)?;
185 visitor.visit_i8(v)
186 }
187 Format::Int16 => {
188 let v = i16::decode_with_format(format, &mut self.reader)?;
189 visitor.visit_i16(v)
190 }
191 Format::Int32 => {
192 let v = i32::decode_with_format(format, &mut self.reader)?;
193 visitor.visit_i32(v)
194 }
195 Format::Int64 => {
196 let v = i64::decode_with_format(format, &mut self.reader)?;
197 visitor.visit_i64(v)
198 }
199 Format::Float32 => {
200 let v = f32::decode_with_format(format, &mut self.reader)?;
201 visitor.visit_f32(v)
202 }
203 Format::Float64 => {
204 let v = f64::decode_with_format(format, &mut self.reader)?;
205 visitor.visit_f64(v)
206 }
207 Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
208 use messagepack_core::decode::ReferenceStrDecoder;
209 let data = ReferenceStrDecoder::decode_with_format(format, &mut self.reader)?;
210 match data {
211 messagepack_core::decode::ReferenceStr::Borrowed(s) => {
212 visitor.visit_borrowed_str(s)
213 }
214 messagepack_core::decode::ReferenceStr::Copied(s) => visitor.visit_str(s),
215 }
216 }
217 Format::FixArray(_) | Format::Array16 | Format::Array32 => {
218 self.decode_seq_with_format(format, visitor)
219 }
220 Format::Bin8 | Format::Bin16 | Format::Bin32 => {
221 use messagepack_core::decode::ReferenceDecoder;
222 let data = ReferenceDecoder::decode_with_format(format, &mut self.reader)?;
223 match data {
224 messagepack_core::io::Reference::Borrowed(items) => {
225 visitor.visit_borrowed_bytes(items)
226 }
227 messagepack_core::io::Reference::Copied(items) => visitor.visit_bytes(items),
228 }
229 }
230 Format::FixMap(_) | Format::Map16 | Format::Map32 => {
231 self.decode_map_with_format(format, visitor)
232 }
233 Format::Ext8
234 | Format::Ext16
235 | Format::Ext32
236 | Format::FixExt1
237 | Format::FixExt2
238 | Format::FixExt4
239 | Format::FixExt8
240 | Format::FixExt16 => {
241 let mut de_ext = DeserializeExt::new(format, &mut self.reader)?;
242 let val = de::Deserializer::deserialize_newtype_struct(
243 &mut de_ext,
244 crate::value::extension::EXTENSION_STRUCT_NAME,
245 visitor,
246 )?;
247
248 Ok(val)
249 }
250 Format::NeverUsed => Err(CoreError::UnexpectedFormat.into()),
251 }
252 }
253
254 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
255 where
256 V: de::Visitor<'de>,
257 {
258 let format = self.decode_format()?;
259 match format {
260 Format::Nil => visitor.visit_none(),
261 _ => {
262 self.format = Some(format);
263 visitor.visit_some(self.as_mut())
264 }
265 }
266 }
267
268 fn deserialize_enum<V>(
269 self,
270 _name: &'static str,
271 _variants: &'static [&'static str],
272 visitor: V,
273 ) -> Result<V::Value, Self::Error>
274 where
275 V: de::Visitor<'de>,
276 {
277 let format = self.decode_format()?;
278 match format {
279 Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
280 let s = <&str>::decode_with_format(format, &mut self.reader)?;
281 visitor.visit_enum(s.into_deserializer())
282 }
283 Format::FixMap(_)
284 | Format::Map16
285 | Format::Map32
286 | Format::FixArray(_)
287 | Format::Array16
288 | Format::Array32 => {
289 let enum_access = enum_::Enum::new(self);
290 visitor.visit_enum(enum_access)
291 }
292 _ => Err(CoreError::UnexpectedFormat.into()),
293 }
294 }
295
296 forward_to_deserialize_any! {
297 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
298 bytes byte_buf unit unit_struct newtype_struct seq tuple
299 tuple_struct map struct identifier ignored_any
300 }
301
302 fn is_human_readable(&self) -> bool {
303 false
304 }
305}
306
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use serde::de::IgnoredAny;

    // Single-byte markers: 0xc3 decodes to `true`, 0xc2 to `false`.
    #[rstest]
    #[case([0xc3],true)]
    #[case([0xc2],false)]
    fn decode_bool<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: bool) {
        let decoded = from_slice::<bool>(buf.as_ref()).unwrap();
        assert_eq!(decoded, expected);
    }

    // A one-byte positive integer, and the two-byte form (0xcc marker + value)
    // for a value that needs the high bit.
    #[rstest]
    #[case([0x05],5)]
    #[case([0xcc, 0x80],128)]
    fn decode_uint8<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: u8) {
        let decoded = from_slice::<u8>(buf.as_ref()).unwrap();
        assert_eq!(decoded, expected);
    }

    // A 5-element array (0x95) of 0xcb-prefixed 8-byte floats.
    #[test]
    fn decode_float_vec() {
        let buf = [
            0x95, 0xcb, 0x3f, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a, 0xcb, 0x3f, 0xf3, 0x33,
            0x33, 0x33, 0x33, 0x33, 0x33, 0xcb, 0x3f, 0xf4, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcd,
            0xcb, 0x3f, 0xf6, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xcb, 0x3f, 0xf8, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];

        let decoded = from_slice::<Vec<f64>>(&buf).unwrap();
        let expected = [1.1, 1.2, 1.3, 1.4, 1.5];

        assert_eq!(decoded, expected)
    }

    // A struct encoded as a 2-entry map keyed by field name:
    // 0x63 0x6f 0x6d 0x70 0x61 0x63 0x74 is ASCII "compact",
    // 0x73 0x63 0x68 0x65 0x6d 0x61 is ASCII "schema".
    #[test]
    fn decode_struct() {
        #[derive(Deserialize)]
        struct S {
            compact: bool,
            schema: u8,
        }

        let buf: &[u8] = &[
            0x82, 0xa7, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x63, 0x74, 0xc3, 0xa6, 0x73, 0x63, 0x68,
            0x65, 0x6d, 0x61, 0x00,
        ];

        let decoded = from_slice::<S>(buf).unwrap();
        assert!(decoded.compact);
        assert_eq!(decoded.schema, 0);
    }

    // The same struct encoded positionally as a 2-element array [true, 0].
    #[test]
    fn decode_struct_from_array() {
        #[derive(Deserialize, Debug, PartialEq)]
        struct S {
            compact: bool,
            schema: u8,
        }

        let buf: &[u8] = &[0x92, 0xc3, 0x00];

        let decoded = from_slice::<S>(buf).unwrap();
        assert_eq!(
            decoded,
            S {
                compact: true,
                schema: 0
            }
        );
    }

    // A `nil` decoded as `None` must be consumed, otherwise the following
    // tuple element (5) would be misread.
    #[test]
    fn option_consumes_nil_in_sequence() {
        let buf: &[u8] = &[0x92, 0xc0, 0x05];

        let decoded = from_slice::<(Option<u8>, u8)>(buf).unwrap();
        assert_eq!(decoded, (None, 5));
    }

    // A non-nil value is re-fed to the inner deserializer and becomes `Some`.
    #[test]
    fn option_some_simple() {
        let buf: &[u8] = &[0x05];
        let decoded = from_slice::<Option<u8>>(buf).unwrap();
        assert_eq!(decoded, Some(5));
    }

    // `()` is decoded from the nil marker 0xc0.
    #[test]
    fn unit_from_nil() {
        let buf: &[u8] = &[0xc0];
        from_slice::<()>(buf).unwrap();
    }

    // Unit structs are also decoded from nil.
    #[test]
    fn unit_struct() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct U;

        let buf: &[u8] = &[0xc0];
        let decoded = from_slice::<U>(buf).unwrap();
        assert_eq!(decoded, U);
    }

    #[derive(Deserialize, PartialEq, Debug)]
    enum E {
        Unit,
        Newtype(u8),
        Tuple(u8, bool),
        Struct { a: bool },
    }
    // Unit variants are a bare string ("Unit"); the other variants are a
    // 1-entry map from the variant name to its payload.
    #[rstest]
    #[case([0xa4, 0x55, 0x6e, 0x69, 0x74],E::Unit)]
    #[case([0x81, 0xa7, 0x4e, 0x65, 0x77, 0x74, 0x79, 0x70, 0x65, 0x1b], E::Newtype(27))]
    #[case([0x81, 0xa5, 0x54, 0x75, 0x70, 0x6c, 0x65, 0x92, 0x03, 0xc3], E::Tuple(3, true))]
    #[case([0x81, 0xa6, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x81, 0xa1, 0x61, 0xc2],E::Struct { a: false })]
    fn decode_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: E) {
        let decoded = from_slice::<E>(buf.as_ref()).unwrap();
        assert_eq!(decoded, expected);
    }

    // Untagged enums go through `deserialize_any`, so each case checks that
    // self-describing decoding picks the first matching variant.
    #[derive(Deserialize, PartialEq, Debug)]
    #[serde(untagged)]
    enum Untagged {
        Bool(bool),
        U8(u8),
        Pair(u8, bool),
        Struct { a: bool },
        Nested(E),
    }

    #[rstest]
    #[case([0xc3],Untagged::Bool(true))]
    #[case([0x05],Untagged::U8(5))]
    #[case([0x92, 0x02, 0xc3],Untagged::Pair(2,true))]
    #[case([0x81, 0xa1, 0x61, 0xc2],Untagged::Struct { a: false })]
    #[case([0xa4,0x55,0x6e,0x69,0x74],Untagged::Nested(E::Unit))]
    fn decode_untagged_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: Untagged) {
        let decoded = from_slice::<Untagged>(buf.as_ref()).unwrap();
        assert_eq!(decoded, expected);
    }

    // 256 nested single-element arrays (0x91) around a nil: exactly at the
    // limit, so decoding must succeed.
    #[test]
    fn recursion_limit_ok_at_256() {
        let mut buf = vec![0x91u8; 256];
        buf.push(0xc0);

        let _ = from_slice::<IgnoredAny>(&buf).unwrap();
    }

    // One level deeper than the limit must fail with RecursionLimitExceeded.
    #[test]
    fn recursion_limit_err_over_256() {
        let mut buf = vec![0x91u8; 257];
        buf.push(0xc0);

        let err = from_slice::<IgnoredAny>(&buf).unwrap_err();
        assert!(matches!(err, Error::RecursionLimitExceeded));
    }

    // Round-trips through `from_reader` (a `std::io::Read` source) into owned
    // types, covering each scalar width plus strings, binary, arrays and maps.
    #[cfg(feature = "std")]
    #[rstest]
    #[case([0xc0],())]
    #[case([0xc3],true)]
    #[case([0xc2],false)]
    #[case([0x2a],42u8)]
    #[case([0xcc, 0x80],128u8)]
    #[case([0xcd, 0x01, 0x00],256u16)]
    #[case([0xce, 0x00, 0x01, 0x00, 0x00],65536u32)]
    #[case([0xcf, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00],4294967296u64)]
    #[case([0xff],-1i8)]
    #[case([0xd0, 0x80],-128i8)]
    #[case([0xd1, 0x80, 0x00],-32768i16)]
    #[case([0xd2, 0x80, 0x00, 0x00, 0x00],-2147483648i32)]
    #[case([0xd3, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],i64::MIN)]
    #[case([0xca, 0x41, 0x45, 0x70, 0xa4],12.34f32)]
    #[case([0xcb, 0x3f, 0xf0, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00],1.0f64)]
    #[case([0xa1, 0x61],"a".to_string())]
    #[case([0xd9, 0x05, 0x68, 0x65, 0x6c, 0x6c, 0x6f],"hello".to_string())]
    #[case([0xc4, 0x03, 0x01, 0x02, 0x03],serde_bytes::ByteBuf::from(vec![1u8, 2, 3]))]
    #[case([0x93, 0x01, 0x02, 0x03],vec![1u8, 2, 3])]
    #[case([0x82, 0xa1, 0x61, 0x01, 0xa1, 0x62, 0x02],{
        let mut m = std::collections::BTreeMap::<String, u8>::new();
        m.insert("a".to_string(), 1u8);
        m.insert("b".to_string(), 2u8);
        m
    })]
    fn decode_success_from_reader_when_owned<
        Buf: AsRef<[u8]>,
        T: serde::de::DeserializeOwned + core::fmt::Debug + PartialEq,
    >(
        #[case] buf: Buf,
        #[case] expected: T,
    ) {
        use super::from_reader;
        let mut reader = std::io::Cursor::new(buf.as_ref());
        let val = from_reader::<_, T>(&mut reader).unwrap();
        assert_eq!(val, expected)
    }
}