1mod enum_;
4mod error;
5mod seq;
6use error::CoreError;
7pub use error::Error;
8
9use crate::value::extension::DeserializeExt;
10use messagepack_core::{Decode, Format, decode::NbyteReader};
11use serde::{
12 Deserialize,
13 de::{self, IntoDeserializer},
14 forward_to_deserialize_any,
15};
16
17pub fn from_slice<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error> {
19 let mut deserializer = Deserializer::from_slice(input);
20 T::deserialize(&mut deserializer)
21}
22
/// Reads `reader` to its end and deserializes a `T` from the collected bytes.
///
/// The whole stream is buffered in memory first, which is why `T` must be
/// deserializable for any lifetime (it cannot borrow from the temporary buffer).
///
/// # Errors
///
/// Returns the I/O error raised while reading, or the deserialization error
/// wrapped with [`std::io::Error::other`].
#[cfg(feature = "std")]
pub fn from_reader<R, T>(reader: &mut R) -> std::io::Result<T>
where
    R: std::io::Read,
    T: for<'a> Deserialize<'a>,
{
    let mut bytes = Vec::new();
    reader.read_to_end(&mut bytes)?;

    let mut de = Deserializer::from_slice(bytes.as_slice());
    match T::deserialize(&mut de) {
        Ok(value) => Ok(value),
        Err(err) => Err(std::io::Error::other(err)),
    }
}
36
/// Deepest nesting level `Deserializer::recurse` will enter; once the depth
/// counter equals this value, further nesting yields
/// `Error::RecursionLimitExceeded`.
const MAX_RECURSION_DEPTH: usize = 256;
38
/// Cursor over a MessagePack byte slice plus the current nesting depth.
#[derive(Clone, Debug, Eq, Ord, PartialEq, PartialOrd)]
struct Deserializer<'de> {
    /// Bytes that have not been consumed yet.
    input: &'de [u8],
    /// Number of nested `recurse` calls currently active.
    depth: usize,
}
44
45impl<'de> Deserializer<'de> {
46 pub fn from_slice(input: &'de [u8]) -> Self {
47 Deserializer { input, depth: 0 }
48 }
49
50 fn recurse<F, V>(&mut self, f: F) -> Result<V, Error>
51 where
52 F: FnOnce(&mut Self) -> V,
53 {
54 if self.depth == MAX_RECURSION_DEPTH {
55 return Err(Error::RecursionLimitExceeded);
56 }
57 self.depth += 1;
58 let result = f(self);
59 self.depth -= 1;
60 Ok(result)
61 }
62
63 fn decode<V: Decode<'de>>(&mut self) -> Result<V::Value, Error> {
64 let (decoded, rest) = V::decode(self.input)?;
65 self.input = rest;
66 Ok(decoded)
67 }
68
69 fn decode_with_format<V: Decode<'de>>(&mut self, format: Format) -> Result<V::Value, Error> {
70 let (decoded, rest) = V::decode_with_format(format, self.input)?;
71 self.input = rest;
72 Ok(decoded)
73 }
74
75 fn decode_seq_with_format<V>(&mut self, format: Format, visitor: V) -> Result<V::Value, Error>
76 where
77 V: de::Visitor<'de>,
78 {
79 let n = match format {
80 Format::FixArray(n) => n.into(),
81 Format::Array16 => {
82 let (n, buf) = NbyteReader::<2>::read(self.input)?;
83 self.input = buf;
84 n
85 }
86 Format::Array32 => {
87 let (n, buf) = NbyteReader::<4>::read(self.input)?;
88 self.input = buf;
89 n
90 }
91 _ => return Err(CoreError::UnexpectedFormat.into()),
92 };
93 self.recurse(move |des| visitor.visit_seq(seq::FixLenAccess::new(des, n)))?
94 }
95
96 fn decode_map_with_format<V>(&mut self, format: Format, visitor: V) -> Result<V::Value, Error>
97 where
98 V: de::Visitor<'de>,
99 {
100 let n = match format {
101 Format::FixMap(n) => n.into(),
102 Format::Map16 => {
103 let (n, buf) = NbyteReader::<2>::read(self.input)?;
104 self.input = buf;
105 n
106 }
107 Format::Map32 => {
108 let (n, buf) = NbyteReader::<4>::read(self.input)?;
109 self.input = buf;
110 n
111 }
112 _ => return Err(CoreError::UnexpectedFormat.into()),
113 };
114 self.recurse(move |des| visitor.visit_map(seq::FixLenAccess::new(des, n)))?
115 }
116}
117
118impl AsMut<Self> for Deserializer<'_> {
119 fn as_mut(&mut self) -> &mut Self {
120 self
121 }
122}
123
124impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
125 type Error = Error;
126
127 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
128 where
129 V: de::Visitor<'de>,
130 {
131 let format = self.decode::<Format>()?;
132 match format {
133 Format::Nil => visitor.visit_unit(),
134 Format::False => visitor.visit_bool(false),
135 Format::True => visitor.visit_bool(true),
136 Format::PositiveFixInt(v) => visitor.visit_u8(v),
137 Format::Uint8 => {
138 let v = self.decode_with_format::<u8>(format)?;
139 visitor.visit_u8(v)
140 }
141 Format::Uint16 => {
142 let v = self.decode_with_format::<u16>(format)?;
143 visitor.visit_u16(v)
144 }
145 Format::Uint32 => {
146 let v = self.decode_with_format::<u32>(format)?;
147 visitor.visit_u32(v)
148 }
149 Format::Uint64 => {
150 let v = self.decode_with_format::<u64>(format)?;
151 visitor.visit_u64(v)
152 }
153 Format::NegativeFixInt(v) => visitor.visit_i8(v),
154 Format::Int8 => {
155 let v = self.decode_with_format::<i8>(format)?;
156 visitor.visit_i8(v)
157 }
158 Format::Int16 => {
159 let v = self.decode_with_format::<i16>(format)?;
160 visitor.visit_i16(v)
161 }
162 Format::Int32 => {
163 let v = self.decode_with_format::<i32>(format)?;
164 visitor.visit_i32(v)
165 }
166 Format::Int64 => {
167 let v = self.decode_with_format::<i64>(format)?;
168 visitor.visit_i64(v)
169 }
170 Format::Float32 => {
171 let v = self.decode_with_format::<f32>(format)?;
172 visitor.visit_f32(v)
173 }
174 Format::Float64 => {
175 let v = self.decode_with_format::<f64>(format)?;
176 visitor.visit_f64(v)
177 }
178 Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
179 let v = self.decode_with_format::<&str>(format)?;
180 visitor.visit_borrowed_str(v)
181 }
182 Format::FixArray(_) | Format::Array16 | Format::Array32 => {
183 self.decode_seq_with_format(format, visitor)
184 }
185 Format::Bin8 | Format::Bin16 | Format::Bin32 => {
186 let v = self.decode_with_format::<&[u8]>(format)?;
187 visitor.visit_borrowed_bytes(v)
188 }
189 Format::FixMap(_) | Format::Map16 | Format::Map32 => {
190 self.decode_map_with_format(format, visitor)
191 }
192 Format::Ext8
193 | Format::Ext16
194 | Format::Ext32
195 | Format::FixExt1
196 | Format::FixExt2
197 | Format::FixExt4
198 | Format::FixExt8
199 | Format::FixExt16 => {
200 let mut de_ext = DeserializeExt::new(format, self.input)?;
201 let val = (&mut de_ext).deserialize_newtype_struct(
202 crate::value::extension::EXTENSION_STRUCT_NAME,
203 visitor,
204 )?;
205 self.input = de_ext.input;
206
207 Ok(val)
208 }
209 Format::NeverUsed => Err(CoreError::UnexpectedFormat.into()),
210 }
211 }
212
213 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
214 where
215 V: de::Visitor<'de>,
216 {
217 let (first, rest) = self.input.split_first().ok_or(CoreError::EofFormat)?;
218
219 let format = Format::from_byte(*first);
220 match format {
221 Format::Nil => {
222 self.input = rest;
223 visitor.visit_none()
224 }
225 _ => visitor.visit_some(self),
226 }
227 }
228
229 fn deserialize_enum<V>(
230 self,
231 _name: &'static str,
232 _variants: &'static [&'static str],
233 visitor: V,
234 ) -> Result<V::Value, Self::Error>
235 where
236 V: de::Visitor<'de>,
237 {
238 let ident = self.decode::<&str>();
239 match ident {
240 Ok(ident) => visitor.visit_enum(ident.into_deserializer()),
241 _ => {
242 let (format, rest) = Format::decode(self.input)?;
243 let mut des = Deserializer::from_slice(rest);
244 des.depth = self.depth;
246 let val = match format {
247 Format::FixMap(_)
248 | Format::Map16
249 | Format::Map32
250 | Format::FixArray(_)
251 | Format::Array16
252 | Format::Array32 => {
253 des.recurse(|d| visitor.visit_enum(enum_::Enum::new(d)))?
254 }
255 _ => Err(CoreError::UnexpectedFormat.into()),
256 }?;
257 self.input = des.input;
258
259 Ok(val)
260 }
261 }
262 }
263
264 forward_to_deserialize_any! {
265 bool i8 i16 i32 i64 i128 u8 u16 u32 u64 u128 f32 f64 char str string
266 bytes byte_buf unit unit_struct newtype_struct seq tuple
267 tuple_struct map struct identifier ignored_any
268 }
269
270 fn is_human_readable(&self) -> bool {
271 false
272 }
273}
274
#[cfg(test)]
mod tests {
    use rstest::rstest;

    use super::*;
    use serde::de::IgnoredAny;

    /// `0xc3` / `0xc2` decode to `true` / `false`.
    #[rstest]
    #[case([0xc3], true)]
    #[case([0xc2], false)]
    fn decode_bool<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: bool) {
        let actual: bool = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    /// A positive fixint and an explicit `uint8` both decode to `u8`.
    #[rstest]
    #[case([0x05], 5)]
    #[case([0xcc, 0x80], 128)]
    fn decode_uint8<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: u8) {
        let actual: u8 = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn decode_float_vec() {
        // Five-element array of float64 values: [1.1, 1.2, 1.3, 1.4, 1.5].
        let bytes = [
            0x95, 0xcb, 0x3f, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a, 0xcb, 0x3f, 0xf3, 0x33,
            0x33, 0x33, 0x33, 0x33, 0x33, 0xcb, 0x3f, 0xf4, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcd,
            0xcb, 0x3f, 0xf6, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xcb, 0x3f, 0xf8, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];
        let expected = [1.1, 1.2, 1.3, 1.4, 1.5];

        let actual: Vec<f64> = from_slice(&bytes).unwrap();

        assert_eq!(actual, expected)
    }

    #[test]
    fn decode_struct() {
        #[derive(Deserialize)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // {"compact": true, "schema": 0}
        let bytes: &[u8] = &[
            0x82, 0xa7, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x63, 0x74, 0xc3, 0xa6, 0x73, 0x63, 0x68,
            0x65, 0x6d, 0x61, 0x00,
        ];

        let S { compact, schema } = from_slice::<S>(bytes).unwrap();
        assert!(compact);
        assert_eq!(schema, 0);
    }

    #[test]
    fn option_consumes_nil_in_sequence() {
        // [nil, 5]: the nil must be consumed so the next element reads as 5.
        let bytes: &[u8] = &[0x92, 0xc0, 0x05];

        let pair: (Option<u8>, u8) = from_slice(bytes).unwrap();
        assert_eq!(pair, (None, 5));
    }

    #[test]
    fn option_some_simple() {
        let bytes: &[u8] = &[0x05];
        let value: Option<u8> = from_slice(bytes).unwrap();
        assert_eq!(value, Some(5));
    }

    #[test]
    fn unit_from_nil() {
        let bytes: &[u8] = &[0xc0];
        from_slice::<()>(bytes).unwrap();
    }

    #[test]
    fn unit_struct() {
        #[derive(Debug, Deserialize, PartialEq)]
        struct U;

        let bytes: &[u8] = &[0xc0];
        let value: U = from_slice(bytes).unwrap();
        assert_eq!(value, U);
    }

    #[derive(Deserialize, PartialEq, Debug)]
    enum E {
        Unit,
        Newtype(u8),
        Tuple(u8, bool),
        Struct { a: bool },
    }

    /// A unit variant is a bare string; the other variants are single-entry
    /// maps keyed by the variant name.
    #[rstest]
    #[case([0xa4, 0x55, 0x6e, 0x69, 0x74], E::Unit)]
    #[case([0x81, 0xa7, 0x4e, 0x65, 0x77, 0x74, 0x79, 0x70, 0x65, 0x1b], E::Newtype(27))]
    #[case([0x81, 0xa5, 0x54, 0x75, 0x70, 0x6c, 0x65, 0x92, 0x03, 0xc3], E::Tuple(3, true))]
    #[case(
        [0x81, 0xa6, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x81, 0xa1, 0x61, 0xc2],
        E::Struct { a: false }
    )]
    fn decode_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: E) {
        let actual: E = from_slice(buf.as_ref()).unwrap();
        assert_eq!(actual, expected);
    }

    #[test]
    fn recursion_limit_ok_at_256() {
        // 256 nested one-element arrays wrapping a nil.
        let mut nested = vec![0x91u8; 256];
        nested.push(0xc0);

        let _ = from_slice::<IgnoredAny>(&nested).unwrap();
    }

    #[test]
    fn recursion_limit_err_over_256() {
        // One level more than the limit allows.
        let mut nested = vec![0x91u8; 257];
        nested.push(0xc0);

        let err = from_slice::<IgnoredAny>(&nested).unwrap_err();
        assert!(matches!(err, Error::RecursionLimitExceeded));
    }
}