1mod enum_;
2mod error;
3mod seq;
4
5pub use error::{CoreError, Error};
6
7use crate::value::extension::DeserializeExt;
8use messagepack_core::{
9 Decode, Format,
10 decode::{NbyteReader, NilDecoder},
11};
12use serde::{
13 Deserialize,
14 de::{self, IntoDeserializer},
15};
16
/// MessagePack deserializer that reads values straight out of a borrowed byte slice.
///
/// `input` is always the not-yet-consumed tail of the original buffer: each successful
/// decode step replaces it with the remainder after the decoded value. Because the
/// slice lives for `'de`, decoded `&str` / `&[u8]` values can borrow from it directly.
#[derive(Debug, Clone, PartialOrd, Ord, PartialEq, Eq)]
pub struct Deserializer<'de> {
    /// Bytes that have not been decoded yet.
    input: &'de [u8],
}
21
22impl<'de> Deserializer<'de> {
23 pub fn from_slice(input: &'de [u8]) -> Self {
24 Deserializer { input }
25 }
26
27 fn decode<V: Decode<'de>>(&mut self) -> Result<V::Value, Error> {
28 let (decoded, rest) = V::decode(self.input)?;
29 self.input = rest;
30 Ok(decoded)
31 }
32
33 fn decode_with_format<V: Decode<'de>>(&mut self, format: Format) -> Result<V::Value, Error> {
34 let (decoded, rest) = V::decode_with_format(format, self.input)?;
35 self.input = rest;
36 Ok(decoded)
37 }
38
39 fn decode_seq_with_format<V>(&mut self, format: Format, visitor: V) -> Result<V::Value, Error>
40 where
41 V: de::Visitor<'de>,
42 {
43 let n = match format {
44 Format::FixArray(n) => n.into(),
45 Format::Array16 => {
46 let (n, buf) = NbyteReader::<2>::read(self.input)?;
47 self.input = buf;
48 n
49 }
50 Format::Array32 => {
51 let (n, buf) = NbyteReader::<4>::read(self.input)?;
52 self.input = buf;
53 n
54 }
55 _ => return Err(CoreError::UnexpectedFormat.into()),
56 };
57 visitor.visit_seq(seq::FixLenAccess::new(self, n))
58 }
59
60 fn decode_map_with_format<V>(&mut self, format: Format, visitor: V) -> Result<V::Value, Error>
61 where
62 V: de::Visitor<'de>,
63 {
64 let n = match format {
65 Format::FixMap(n) => n.into(),
66 Format::Map16 => {
67 let (n, buf) = NbyteReader::<2>::read(self.input)?;
68 self.input = buf;
69 n
70 }
71 Format::Map32 => {
72 let (n, buf) = NbyteReader::<4>::read(self.input)?;
73 self.input = buf;
74 n
75 }
76 _ => return Err(CoreError::UnexpectedFormat.into()),
77 };
78 visitor.visit_map(seq::FixLenAccess::new(self, n))
79 }
80}
81
82impl AsMut<Self> for Deserializer<'_> {
83 fn as_mut(&mut self) -> &mut Self {
84 self
85 }
86}
87
88pub fn from_slice<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error> {
89 from_slice_with_config(input)
90}
91
92pub fn from_slice_with_config<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error> {
93 let mut deserializer = Deserializer::from_slice(input);
94 T::deserialize(&mut deserializer)
95}
96
/// Reads `reader` to the end and deserializes an owned `T` from the collected bytes.
///
/// # Errors
/// I/O failures are returned as-is; decoding failures are wrapped in a
/// [`std::io::Error`] of kind [`std::io::ErrorKind::Other`].
#[cfg(feature = "std")]
pub fn from_reader<R, T>(reader: &mut R) -> std::io::Result<T>
where
    T: for<'a> Deserialize<'a>,
    R: std::io::Read,
{
    from_reader_with_config::<R, T>(reader)
}
105
/// Buffers all of `reader` in memory, then deserializes an owned `T` from it.
///
/// `T` must implement `Deserialize` for every lifetime because the temporary buffer is
/// dropped before this function returns.
///
/// # Errors
/// I/O failures are returned as-is; decoding failures are wrapped in a
/// [`std::io::Error`] of kind [`std::io::ErrorKind::Other`].
#[cfg(feature = "std")]
pub fn from_reader_with_config<R, T>(reader: &mut R) -> std::io::Result<T>
where
    R: std::io::Read,
    T: for<'a> Deserialize<'a>,
{
    let mut bytes = Vec::new();
    reader.read_to_end(&mut bytes)?;

    // Statement-level temporary: the borrowing deserializer is gone before `bytes` drops.
    let decoded = T::deserialize(&mut Deserializer::from_slice(&bytes));
    decoded.map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err))
}
118
119impl<'de> de::Deserializer<'de> for &mut Deserializer<'de> {
120 type Error = Error;
121
122 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
123 where
124 V: de::Visitor<'de>,
125 {
126 let format = self.decode::<Format>()?;
127 match format {
128 Format::Nil => visitor.visit_none(),
129 Format::False => visitor.visit_bool(false),
130 Format::True => visitor.visit_bool(true),
131 Format::PositiveFixInt(v) => visitor.visit_u8(v),
132 Format::Uint8 => {
133 let v = self.decode_with_format::<u8>(format)?;
134 visitor.visit_u8(v)
135 }
136 Format::Uint16 => {
137 let v = self.decode_with_format::<u16>(format)?;
138 visitor.visit_u16(v)
139 }
140 Format::Uint32 => {
141 let v = self.decode_with_format::<u32>(format)?;
142 visitor.visit_u32(v)
143 }
144 Format::Uint64 => {
145 let v = self.decode_with_format::<u64>(format)?;
146 visitor.visit_u64(v)
147 }
148 Format::NegativeFixInt(v) => visitor.visit_i8(v),
149 Format::Int8 => {
150 let v = self.decode_with_format::<i8>(format)?;
151 visitor.visit_i8(v)
152 }
153 Format::Int16 => {
154 let v = self.decode_with_format::<i16>(format)?;
155 visitor.visit_i16(v)
156 }
157 Format::Int32 => {
158 let v = self.decode_with_format::<i32>(format)?;
159 visitor.visit_i32(v)
160 }
161 Format::Int64 => {
162 let v = self.decode_with_format::<i64>(format)?;
163 visitor.visit_i64(v)
164 }
165 Format::Float32 => {
166 let v = self.decode_with_format::<f32>(format)?;
167 visitor.visit_f32(v)
168 }
169 Format::Float64 => {
170 let v = self.decode_with_format::<f64>(format)?;
171 visitor.visit_f64(v)
172 }
173 Format::FixStr(_) | Format::Str8 | Format::Str16 | Format::Str32 => {
174 let v = self.decode_with_format::<&str>(format)?;
175 visitor.visit_borrowed_str(v)
176 }
177 Format::FixArray(_) | Format::Array16 | Format::Array32 => {
178 self.decode_seq_with_format(format, visitor)
179 }
180 Format::Bin8 | Format::Bin16 | Format::Bin32 => {
181 let v = self.decode_with_format::<&[u8]>(format)?;
182 visitor.visit_borrowed_bytes(v)
183 }
184 Format::FixMap(_) | Format::Map16 | Format::Map32 => {
185 self.decode_map_with_format(format, visitor)
186 }
187 Format::Ext8
188 | Format::Ext16
189 | Format::Ext32
190 | Format::FixExt1
191 | Format::FixExt2
192 | Format::FixExt4
193 | Format::FixExt8
194 | Format::FixExt16 => {
195 let mut de_ext = DeserializeExt::new(format, self.input)?;
196 let val = (&mut de_ext).deserialize_newtype_struct(
197 crate::value::extension::EXTENSION_STRUCT_NAME,
198 visitor,
199 )?;
200 self.input = de_ext.input;
201
202 Ok(val)
203 }
204 Format::NeverUsed => Err(CoreError::UnexpectedFormat.into()),
205 }
206 }
207
208 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
209 where
210 V: de::Visitor<'de>,
211 {
212 let decoded = self.decode::<bool>()?;
213 visitor.visit_bool(decoded)
214 }
215
216 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
217 where
218 V: de::Visitor<'de>,
219 {
220 let decoded = self.decode::<i8>()?;
221 visitor.visit_i8(decoded)
222 }
223
224 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
225 where
226 V: de::Visitor<'de>,
227 {
228 let decoded = self.decode::<i16>()?;
229 visitor.visit_i16(decoded)
230 }
231
232 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
233 where
234 V: de::Visitor<'de>,
235 {
236 let decoded = self.decode::<i32>()?;
237 visitor.visit_i32(decoded)
238 }
239
240 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
241 where
242 V: de::Visitor<'de>,
243 {
244 let decoded = self.decode::<i64>()?;
245 visitor.visit_i64(decoded)
246 }
247
248 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
249 where
250 V: de::Visitor<'de>,
251 {
252 let decoded = self.decode::<u8>()?;
253 visitor.visit_u8(decoded)
254 }
255
256 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
257 where
258 V: de::Visitor<'de>,
259 {
260 let decoded = self.decode::<u16>()?;
261 visitor.visit_u16(decoded)
262 }
263
264 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
265 where
266 V: de::Visitor<'de>,
267 {
268 let decoded = self.decode::<u32>()?;
269 visitor.visit_u32(decoded)
270 }
271
272 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
273 where
274 V: de::Visitor<'de>,
275 {
276 let decoded = self.decode::<u64>()?;
277 visitor.visit_u64(decoded)
278 }
279
280 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
281 where
282 V: de::Visitor<'de>,
283 {
284 let decoded = self.decode::<f32>()?;
285 visitor.visit_f32(decoded)
286 }
287
288 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
289 where
290 V: de::Visitor<'de>,
291 {
292 let decoded = self.decode::<f64>()?;
293 visitor.visit_f64(decoded)
294 }
295
296 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
297 where
298 V: de::Visitor<'de>,
299 {
300 self.deserialize_str(visitor)
301 }
302
303 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
304 where
305 V: de::Visitor<'de>,
306 {
307 let decoded = self.decode::<&str>()?;
308 visitor.visit_borrowed_str(decoded)
309 }
310
311 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
312 where
313 V: de::Visitor<'de>,
314 {
315 self.deserialize_str(visitor)
316 }
317
318 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
319 where
320 V: de::Visitor<'de>,
321 {
322 let decoded = self.decode::<&[u8]>()?;
323 visitor.visit_borrowed_bytes(decoded)
324 }
325
326 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
327 where
328 V: de::Visitor<'de>,
329 {
330 self.deserialize_bytes(visitor)
331 }
332
333 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
334 where
335 V: de::Visitor<'de>,
336 {
337 let is_null = NilDecoder::decode(self.input).is_ok();
338 if is_null {
339 visitor.visit_none()
340 } else {
341 visitor.visit_some(self)
342 }
343 }
344
345 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
346 where
347 V: de::Visitor<'de>,
348 {
349 self.decode::<()>()?;
350 visitor.visit_unit()
351 }
352
353 fn deserialize_unit_struct<V>(
354 self,
355 _name: &'static str,
356 visitor: V,
357 ) -> Result<V::Value, Self::Error>
358 where
359 V: de::Visitor<'de>,
360 {
361 self.deserialize_unit(visitor)
362 }
363
364 fn deserialize_newtype_struct<V>(
365 self,
366 _name: &'static str,
367 visitor: V,
368 ) -> Result<V::Value, Self::Error>
369 where
370 V: de::Visitor<'de>,
371 {
372 visitor.visit_newtype_struct(self)
373 }
374
375 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
376 where
377 V: de::Visitor<'de>,
378 {
379 let (format, rest) = Format::decode(self.input)?;
380
381 let mut des = Deserializer::from_slice(rest);
382 let val = des.decode_seq_with_format(format, visitor)?;
383 self.input = des.input;
384
385 Ok(val)
386 }
387
388 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
389 where
390 V: de::Visitor<'de>,
391 {
392 self.deserialize_seq(visitor)
393 }
394
395 fn deserialize_tuple_struct<V>(
396 self,
397 _name: &'static str,
398 _len: usize,
399 visitor: V,
400 ) -> Result<V::Value, Self::Error>
401 where
402 V: de::Visitor<'de>,
403 {
404 self.deserialize_seq(visitor)
405 }
406
407 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
408 where
409 V: de::Visitor<'de>,
410 {
411 let (format, rest) = Format::decode(self.input)?;
412
413 let mut des = Deserializer::from_slice(rest);
414 let val = des.decode_map_with_format(format, visitor)?;
415 self.input = des.input;
416
417 Ok(val)
418 }
419
420 fn deserialize_struct<V>(
421 self,
422 _name: &'static str,
423 _fields: &'static [&'static str],
424 visitor: V,
425 ) -> Result<V::Value, Self::Error>
426 where
427 V: de::Visitor<'de>,
428 {
429 self.deserialize_map(visitor)
430 }
431
432 fn deserialize_enum<V>(
433 self,
434 _name: &'static str,
435 _variants: &'static [&'static str],
436 visitor: V,
437 ) -> Result<V::Value, Self::Error>
438 where
439 V: de::Visitor<'de>,
440 {
441 let ident = self.decode::<&str>();
442 match ident {
443 Ok(ident) => visitor.visit_enum(ident.into_deserializer()),
444 _ => {
445 let (format, rest) = Format::decode(self.input)?;
446
447 let mut des = Deserializer::from_slice(rest);
448 let val = match format {
449 Format::FixMap(_)
450 | Format::Map16
451 | Format::Map32
452 | Format::FixArray(_)
453 | Format::Array16
454 | Format::Array32 => visitor.visit_enum(enum_::Enum::new(&mut des)),
455 _ => Err(CoreError::UnexpectedFormat.into()),
456 }?;
457
458 self.input = des.input;
459
460 Ok(val)
461 }
462 }
463 }
464
465 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
466 where
467 V: de::Visitor<'de>,
468 {
469 self.deserialize_str(visitor)
470 }
471
472 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
473 where
474 V: de::Visitor<'de>,
475 {
476 self.deserialize_any(visitor)
477 }
478
479 fn is_human_readable(&self) -> bool {
480 false
481 }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use rstest::rstest;

    #[rstest]
    #[case([0xc3], true)]
    #[case([0xc2], false)]
    fn decode_bool<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: bool) {
        let value = from_slice::<bool>(buf.as_ref()).unwrap();
        assert_eq!(value, expected);
    }

    #[rstest]
    #[case([0x05], 5)]
    #[case([0xcc, 0x80], 128)]
    fn decode_uint8<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: u8) {
        let value = from_slice::<u8>(buf.as_ref()).unwrap();
        assert_eq!(value, expected);
    }

    #[test]
    fn decode_float_vec() {
        // fixarray(5) of float64 (0xcb): [1.1, 1.2, 1.3, 1.4, 1.5]
        let buf = [
            0x95, 0xcb, 0x3f, 0xf1, 0x99, 0x99, 0x99, 0x99, 0x99, 0x9a, 0xcb, 0x3f, 0xf3, 0x33,
            0x33, 0x33, 0x33, 0x33, 0x33, 0xcb, 0x3f, 0xf4, 0xcc, 0xcc, 0xcc, 0xcc, 0xcc, 0xcd,
            0xcb, 0x3f, 0xf6, 0x66, 0x66, 0x66, 0x66, 0x66, 0x66, 0xcb, 0x3f, 0xf8, 0x00, 0x00,
            0x00, 0x00, 0x00, 0x00,
        ];

        let values = from_slice::<Vec<f64>>(&buf).unwrap();
        let expected = [1.1, 1.2, 1.3, 1.4, 1.5];

        assert_eq!(values, expected)
    }

    #[test]
    fn decode_struct() {
        #[derive(Deserialize)]
        struct S {
            compact: bool,
            schema: u8,
        }

        // fixmap(2): "compact" => true, "schema" => 0
        let buf: &[u8] = &[
            0x82, 0xa7, 0x63, 0x6f, 0x6d, 0x70, 0x61, 0x63, 0x74, 0xc3, 0xa6, 0x73, 0x63, 0x68,
            0x65, 0x6d, 0x61, 0x00,
        ];

        let value = from_slice::<S>(buf).unwrap();
        assert!(value.compact);
        assert_eq!(value.schema, 0);
    }

    #[derive(Deserialize, PartialEq, Debug)]
    enum E {
        Unit,
        Newtype(u8),
        Tuple(u8, bool),
        Struct { a: bool },
    }

    #[rstest]
    // "Unit"
    #[case([0xa4, 0x55, 0x6e, 0x69, 0x74], E::Unit)]
    // {"Newtype": 27}
    #[case([0x81, 0xa7, 0x4e, 0x65, 0x77, 0x74, 0x79, 0x70, 0x65, 0x1b], E::Newtype(27))]
    // {"Tuple": [3, true]}
    #[case([0x81, 0xa5, 0x54, 0x75, 0x70, 0x6c, 0x65, 0x92, 0x03, 0xc3], E::Tuple(3, true))]
    // {"Struct": {"a": false}}
    #[case(
        [0x81, 0xa6, 0x53, 0x74, 0x72, 0x75, 0x63, 0x74, 0x81, 0xa1, 0x61, 0xc2],
        E::Struct { a: false }
    )]
    fn decode_enum<Buf: AsRef<[u8]>>(#[case] buf: Buf, #[case] expected: E) {
        let value = from_slice::<E>(buf.as_ref()).unwrap();
        assert_eq!(value, expected);
    }

    #[test]
    fn decode_extension() {
        use crate::value::extension::ExtensionRef;

        // fixext1: type 123 (0x7b), one payload byte 0x12
        let buf: &[u8] = &[0xd4, 0x7b, 0x12];

        let ext = from_slice::<ExtensionRef<'_>>(buf).unwrap();
        assert_eq!(ext.kind, 123);
        assert_eq!(ext.data, [0x12_u8])
    }
}