1use crate::{constants::*, Error, Result};
2use serde::{
3 de::{
4 self, DeserializeSeed, EnumAccess, IntoDeserializer, MapAccess, SeqAccess, VariantAccess,
5 Visitor,
6 },
7 Deserialize,
8};
9use std::{convert::TryFrom, io, io::Read, marker::PhantomData, slice};
10
11pub fn from_slice<'de, T>(slice: &'de [u8]) -> Result<T>
13where
14 T: Deserialize<'de>,
15{
16 from_reader(slice)
17}
18
19pub fn from_reader<'de, R, T>(reader: R) -> Result<T>
21where
22 T: Deserialize<'de>,
23 R: Read,
24{
25 let mut de = Deserializer::new(reader);
26 let value = Deserialize::deserialize(&mut de)?;
27 de.end()?;
28 Ok(value)
29}
30
/// Streaming deserializer that decodes values from a byte source `R`.
///
/// Decoding methods are only available when `R: Read`. That bound is stated on the `impl`
/// blocks that actually read, not on the type itself, so naming the type imposes no
/// requirement.
pub struct Deserializer<R> {
    /// Source of encoded bytes.
    reader: R,
    /// Discriminator that has been read ahead (peeked) but not yet consumed, stored as a
    /// `(type, low bits)` pair.
    last_discriminator: Option<(u8, u8)>,
}
39
40impl<R> Deserializer<R>
41where
42 R: Read,
43{
44 pub fn new(reader: R) -> Self {
46 Self {
47 reader,
48 last_discriminator: None,
49 }
50 }
51
52 pub fn end(&mut self) -> Result<()> {
56 match self.read_discriminator() {
57 Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(()),
58 _ => Err(Error::TrailingBytes),
59 }
60 }
61
62 #[allow(clippy::should_implement_trait)]
64 pub fn into_iter<'de, T>(self) -> StreamDeserializer<'de, R, T>
65 where
66 T: Deserialize<'de>,
67 {
68 StreamDeserializer {
69 de: self,
70 failed: false,
71 output: PhantomData,
72 lifetime: PhantomData,
73 }
74 }
75
76 fn read_discriminator(&mut self) -> Result<(u8, u8)> {
77 let mut d = 0;
78 self.reader.read_exact(slice::from_mut(&mut d))?;
79 Ok((d & TYPE_MASK, d & !TYPE_MASK))
80 }
81
82 fn peek_discriminator(&mut self) -> Result<(u8, u8)> {
83 if self.last_discriminator.is_none() {
84 self.last_discriminator = Some(self.read_discriminator()?);
85 }
86 Ok(self.last_discriminator.unwrap())
87 }
88
89 fn consume_discriminator(&mut self) -> Result<(u8, u8)> {
90 self.last_discriminator
91 .take()
92 .map(Result::Ok)
93 .unwrap_or_else(|| self.read_discriminator())
94 }
95
96 fn read_i64(&mut self, len: usize) -> Result<i64> {
97 let mut buf = [0u8; 8];
98 let start = 8 - len;
99 self.reader.read_exact(&mut buf[start..])?;
100 if buf[start] & 0x80 != 0 {
101 buf[0..start].fill(0xFF);
103 }
104 Ok(i64::from_be_bytes(buf))
105 }
106
107 fn read_null(&mut self) -> Result<()> {
108 let (typ, bits) = self.consume_discriminator()?;
109 if typ != TYPE_NULL {
110 return Err(Error::WrongType);
111 }
112 if bits != 0 {
113 return Err(Error::InvalidValue);
114 }
115 Ok(())
116 }
117
118 fn read_boolean(&mut self) -> Result<bool> {
119 let (typ, bits) = self.consume_discriminator()?;
120 if typ != TYPE_BOOLEAN {
121 return Err(Error::WrongType);
122 }
123 if bits > 1 {
124 return Err(Error::InvalidValue);
125 }
126 Ok(bits == 1)
127 }
128
129 fn read_integer(&mut self) -> Result<i64> {
130 let (typ, len) = self.consume_discriminator()?;
131 if typ != TYPE_INTEGER {
132 return Err(Error::WrongType);
133 }
134 if !len.is_power_of_two() {
135 return Err(Error::InvalidLength);
136 }
137 self.read_i64(len as usize)
138 }
139
140 fn read_float(&mut self) -> Result<f64> {
141 let (typ, len) = self.consume_discriminator()?;
142 if typ != TYPE_FLOAT {
143 return Err(Error::WrongType);
144 }
145 if len != 8 {
146 return Err(Error::InvalidLength);
147 }
148 let mut buf = [0u8; 8];
149 self.reader.read_exact(&mut buf)?;
150 Ok(f64::from_be_bytes(buf))
151 }
152
153 fn read_string(&mut self) -> Result<String> {
154 let (typ, llen) = self.consume_discriminator()?;
155 if typ != TYPE_STRING {
156 return Err(Error::WrongType);
157 }
158 if !llen.is_power_of_two() {
159 return Err(Error::InvalidLengthOfLength);
160 }
161 let len = self.read_i64(llen as usize)?;
162 if len < 0 {
163 return Err(Error::InvalidLength);
164 }
165 let mut s = String::with_capacity(len as usize);
166 let read = (&mut self.reader).take(len as u64).read_to_string(&mut s)?;
167 if read != len as usize {
168 return Err(Error::eof());
169 }
170 Ok(s)
171 }
172
173 fn read_raw(&mut self) -> Result<Vec<u8>> {
174 let (typ, llen) = self.consume_discriminator()?;
175 if typ != TYPE_RAW {
176 return Err(Error::WrongType);
177 }
178 if !llen.is_power_of_two() {
179 return Err(Error::InvalidLengthOfLength);
180 }
181 let len = self.read_i64(llen as usize)?;
182 if len < 0 {
183 return Err(Error::InvalidLength);
184 }
185 let mut v = Vec::with_capacity(len as usize);
186 let read = (&mut self.reader).take(len as u64).read_to_end(&mut v)?;
187 if read != len as usize {
188 return Err(Error::eof());
189 }
190 Ok(v)
191 }
192
193 fn read_list_start(&mut self) -> Result<()> {
194 let (typ, bits) = self.consume_discriminator()?;
195 if typ != TYPE_LIST {
196 return Err(Error::WrongType);
197 }
198 if bits != 0 {
199 return Err(Error::InvalidValue);
200 }
201 Ok(())
202 }
203
204 fn read_dictionary_start(&mut self) -> Result<()> {
205 let (typ, bits) = self.consume_discriminator()?;
206 if typ != TYPE_DICTIONARY {
207 return Err(Error::WrongType);
208 }
209 if bits != 0 {
210 return Err(Error::InvalidValue);
211 }
212 Ok(())
213 }
214
215 fn peek_end(&mut self) -> Result<bool> {
216 let (typ, bits) = self.peek_discriminator()?;
217 if typ != TYPE_END {
218 return Ok(false);
219 }
220 if bits != 0 {
221 return Err(Error::InvalidValue);
222 }
223 Ok(true)
224 }
225
226 fn read_end(&mut self) -> Result<()> {
227 let (typ, bits) = self.consume_discriminator()?;
228 if typ != TYPE_END {
229 return Err(Error::WrongType);
230 }
231 if bits != 0 {
232 return Err(Error::InvalidValue);
233 }
234 Ok(())
235 }
236}
237
238impl<'de, 'a, R> de::Deserializer<'de> for &'a mut Deserializer<R>
239where
240 R: Read,
241{
242 type Error = Error;
243
244 fn deserialize_any<V>(self, visitor: V) -> Result<V::Value>
245 where
246 V: Visitor<'de>,
247 {
248 let (typ, _) = self.peek_discriminator()?;
249
250 match typ {
251 TYPE_NULL => self.deserialize_unit(visitor),
252 TYPE_BOOLEAN => self.deserialize_bool(visitor),
253 TYPE_INTEGER => self.deserialize_i64(visitor),
254 TYPE_FLOAT => self.deserialize_f64(visitor),
255 TYPE_STRING => self.deserialize_str(visitor),
256 TYPE_RAW => self.deserialize_bytes(visitor),
257 TYPE_LIST => self.deserialize_seq(visitor),
258 TYPE_DICTIONARY => self.deserialize_map(visitor),
259 _ => Err(Error::InvalidType),
260 }
261 }
262
263 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
264 where
265 V: Visitor<'de>,
266 {
267 let value = self.read_boolean()?;
268 visitor.visit_bool(value)
269 }
270
271 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
272 where
273 V: Visitor<'de>,
274 {
275 let value = self.read_integer()?;
276 visitor.visit_i8(i8::try_from(value)?)
277 }
278
279 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
280 where
281 V: Visitor<'de>,
282 {
283 let value = self.read_integer()?;
284 visitor.visit_i16(i16::try_from(value)?)
285 }
286
287 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
288 where
289 V: Visitor<'de>,
290 {
291 let value = self.read_integer()?;
292 visitor.visit_i32(i32::try_from(value)?)
293 }
294
295 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
296 where
297 V: Visitor<'de>,
298 {
299 let value = self.read_integer()?;
300 visitor.visit_i64(value)
301 }
302
303 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
304 where
305 V: Visitor<'de>,
306 {
307 let value = self.read_integer()?;
308 visitor.visit_u8(u8::try_from(value)?)
309 }
310
311 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
312 where
313 V: Visitor<'de>,
314 {
315 let value = self.read_integer()?;
316 visitor.visit_u16(u16::try_from(value)?)
317 }
318
319 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
320 where
321 V: Visitor<'de>,
322 {
323 let value = self.read_integer()?;
324 visitor.visit_u32(u32::try_from(value)?)
325 }
326
327 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
328 where
329 V: Visitor<'de>,
330 {
331 let value = self.read_integer()?;
332 visitor.visit_u64(u64::try_from(value)?)
333 }
334
335 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
336 where
337 V: Visitor<'de>,
338 {
339 let value = self.read_float()?;
340 visitor.visit_f32(value as f32)
341 }
342
343 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
344 where
345 V: Visitor<'de>,
346 {
347 let value = self.read_float()?;
348 visitor.visit_f64(value)
349 }
350
351 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
352 where
353 V: Visitor<'de>,
354 {
355 let value = self.read_integer()?;
356 visitor.visit_char(char::try_from(u32::try_from(value)?)?)
357 }
358
359 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
360 where
361 V: Visitor<'de>,
362 {
363 self.deserialize_string(visitor)
364 }
365
366 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
367 where
368 V: Visitor<'de>,
369 {
370 let value = self.read_string()?;
371 visitor.visit_string(value)
372 }
373
374 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
375 where
376 V: Visitor<'de>,
377 {
378 self.deserialize_byte_buf(visitor)
379 }
380
381 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
382 where
383 V: Visitor<'de>,
384 {
385 let value = self.read_raw()?;
386 visitor.visit_byte_buf(value)
387 }
388
389 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
390 where
391 V: Visitor<'de>,
392 {
393 let (typ, _) = self.peek_discriminator()?;
394 match typ {
395 TYPE_NULL => visitor.visit_none(),
396 TYPE_BOOLEAN | TYPE_INTEGER | TYPE_FLOAT | TYPE_STRING | TYPE_RAW | TYPE_LIST
397 | TYPE_DICTIONARY => visitor.visit_some(self),
398 _ => Err(Error::WrongType),
399 }
400 }
401
402 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
403 where
404 V: Visitor<'de>,
405 {
406 self.read_null()?;
407 visitor.visit_unit()
408 }
409
410 fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
411 where
412 V: Visitor<'de>,
413 {
414 self.deserialize_unit(visitor)
415 }
416
417 fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
418 where
419 V: Visitor<'de>,
420 {
421 visitor.visit_newtype_struct(self)
422 }
423
424 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
425 where
426 V: Visitor<'de>,
427 {
428 self.read_list_start()?;
429 let value = visitor.visit_seq(&mut *self)?;
430 self.read_end()?;
431 Ok(value)
432 }
433
434 fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value>
435 where
436 V: Visitor<'de>,
437 {
438 self.deserialize_seq(visitor)
439 }
440
441 fn deserialize_tuple_struct<V>(
442 self,
443 _name: &'static str,
444 _len: usize,
445 visitor: V,
446 ) -> Result<V::Value>
447 where
448 V: Visitor<'de>,
449 {
450 self.deserialize_seq(visitor)
451 }
452
453 fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
454 where
455 V: Visitor<'de>,
456 {
457 self.read_dictionary_start()?;
458 let value = visitor.visit_map(&mut *self)?;
459 self.read_end()?;
460 Ok(value)
461 }
462
463 fn deserialize_struct<V>(
464 self,
465 _name: &'static str,
466 fields: &'static [&'static str],
467 visitor: V,
468 ) -> Result<V::Value>
469 where
470 V: Visitor<'de>,
471 {
472 self.deserialize_tuple(fields.len(), visitor)
473 }
474
475 fn deserialize_enum<V>(
476 self,
477 _name: &'static str,
478 _variants: &'static [&'static str],
479 visitor: V,
480 ) -> Result<V::Value>
481 where
482 V: Visitor<'de>,
483 {
484 visitor.visit_enum(self)
485 }
486
487 fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value>
488 where
489 V: Visitor<'de>,
490 {
491 self.deserialize_str(visitor)
492 }
493
494 fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
495 where
496 V: Visitor<'de>,
497 {
498 self.deserialize_any(visitor)
499 }
500}
501
502impl<'de, 'a, R> SeqAccess<'de> for Deserializer<R>
503where
504 R: Read,
505{
506 type Error = Error;
507
508 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
509 where
510 T: DeserializeSeed<'de>,
511 {
512 if self.peek_end()? {
513 return Ok(None);
514 }
515
516 seed.deserialize(self).map(Some)
517 }
518}
519
520impl<'de, 'a, R> MapAccess<'de> for Deserializer<R>
521where
522 R: Read,
523{
524 type Error = Error;
525
526 fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
527 where
528 K: DeserializeSeed<'de>,
529 {
530 if self.peek_end()? {
531 return Ok(None);
532 }
533
534 seed.deserialize(self).map(Some)
535 }
536
537 fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
538 where
539 V: DeserializeSeed<'de>,
540 {
541 seed.deserialize(self)
542 }
543}
544
545impl<'de, 'a, R> EnumAccess<'de> for &'a mut Deserializer<R>
546where
547 R: Read,
548{
549 type Error = Error;
550 type Variant = Self;
551
552 fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
553 where
554 V: DeserializeSeed<'de>,
555 {
556 let (typ, _) = self.peek_discriminator()?;
557 match typ {
558 TYPE_INTEGER => {}
559 TYPE_LIST => self.read_list_start()?,
560 _ => return Err(Error::WrongType),
561 }
562 let variant_index = u32::try_from(self.read_integer()?)?;
563 let value: Result<_> = seed.deserialize(variant_index.into_deserializer());
564 Ok((value?, self))
565 }
566}
567
568impl<'de, 'a, R> VariantAccess<'de> for &'a mut Deserializer<R>
569where
570 R: Read,
571{
572 type Error = Error;
573
574 fn unit_variant(self) -> Result<()> {
575 Ok(())
576 }
577
578 fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
579 where
580 T: DeserializeSeed<'de>,
581 {
582 let value = seed.deserialize(&mut *self)?;
583 self.read_end()?;
584 Ok(value)
585 }
586
587 fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value>
588 where
589 V: Visitor<'de>,
590 {
591 self.read_list_start()?;
592 let value = visitor.visit_seq(&mut *self)?;
593 self.read_end()?;
594 self.read_end()?;
595 Ok(value)
596 }
597
598 fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value>
599 where
600 V: Visitor<'de>,
601 {
602 self.read_list_start()?;
603 let value = visitor.visit_seq(&mut *self)?;
604 self.read_end()?;
605 self.read_end()?;
606 Ok(value)
607 }
608}
609
610pub struct StreamDeserializer<'de, R, T>
612where
613 R: Read,
614 T: Deserialize<'de>,
615{
616 de: Deserializer<R>,
617 failed: bool,
618 output: PhantomData<T>,
619 lifetime: PhantomData<&'de ()>,
620}
621
622impl<'de, R, T> Iterator for StreamDeserializer<'de, R, T>
623where
624 R: Read,
625 T: Deserialize<'de>,
626{
627 type Item = Result<T>;
628
629 fn next(&mut self) -> Option<Result<T>> {
630 if self.failed {
631 return None;
632 }
633
634 match Deserialize::deserialize(&mut self.de) {
635 Err(e) => {
636 self.failed = true;
637 if e.is_eof() {
638 None
639 } else {
640 Some(Err(e))
641 }
642 }
643 ok => Some(ok),
644 }
645 }
646}
647
#[cfg(test)]
mod test {
    use super::*;
    use hex_literal::hex;
    use serde::Deserialize;
    use std::collections::HashMap;

    #[test]
    fn from_slice_maps() {
        // Dictionary of alternating keys and values: { "bar": 456 (0x01C8), "foo": 123 (0x7B) }.
        let buf = hex!("70 4103626172 2201C8 4103666F6F 217B 80");
        let actual: HashMap<String, u32> = from_slice(&buf).unwrap();

        let expected: HashMap<String, u32> =
            vec![("foo".to_string(), 123), ("bar".to_string(), 456)]
                .into_iter()
                .collect();
        assert_eq!(actual, expected);
    }

    #[test]
    fn from_slice_structs() {
        #[derive(Deserialize, Debug, PartialEq, Eq)]
        struct Test {
            x: bool,
            y: u32,
            z: Vec<String>,
        }

        // Structs are positional lists: [true, 17, ["foo", "bar"]].
        let buf = hex!("60 11 2111 60 4103666F6F 4103626172 80 80");
        let actual: Test = from_slice(&buf).unwrap();
        assert_eq!(
            actual,
            Test {
                x: true,
                y: 17,
                z: vec!["foo".to_string(), "bar".to_string()],
            }
        );
    }

    #[test]
    fn from_slice_enums() {
        #[derive(Deserialize, Debug, PartialEq, Eq)]
        enum Test {
            UnitVariant,
            NewTypeVariant(u32),
            TupleVariant(bool, u32),
            StructVariant { x: bool, y: u32 },
        }

        fn check(bytes: &[u8], expected: Test) {
            let actual: Test = from_slice(bytes).unwrap();
            assert_eq!(actual, expected);
        }

        // A unit variant is just its variant index.
        check(&hex!("2100"), Test::UnitVariant);
        // Other variants are a list holding the index followed by the payload.
        check(&hex!("60 2101 2111 80"), Test::NewTypeVariant(17));
        check(&hex!("60 2102 60 11 2111 80 80"), Test::TupleVariant(true, 17));
        check(
            &hex!("60 2103 60 11 2111 80 80"),
            Test::StructVariant { x: true, y: 17 },
        );
    }

    #[test]
    fn from_slice_options() {
        // Null decodes as `None`.
        let none: Option<u32> = from_slice(&hex!("00")).unwrap();
        assert_eq!(none, None);

        // Any other value decodes as `Some`.
        let some: Option<i32> = from_slice(&hex!("2111")).unwrap();
        assert_eq!(some, Some(17));
    }

    #[test]
    fn stream_deserializer() {
        // A value of the wrong type is reported once and then ends the stream.
        let buf = hex!("2100 2101 80 2103");
        let results: Vec<Result<u64>> = Deserializer::new(&buf[..]).into_iter().collect();
        assert_eq!(results.len(), 3);
        assert_eq!(*results[0].as_ref().unwrap(), 0);
        assert_eq!(*results[1].as_ref().unwrap(), 1);
        assert!(matches!(results[2], Err(Error::WrongType)));

        // A well-formed stream yields every value and stops at end of input.
        let buf = hex!("2100 2101 2102 2103");
        let values = Deserializer::new(&buf[..])
            .into_iter()
            .collect::<Result<Vec<u64>>>()
            .unwrap();
        assert_eq!(values, vec![0, 1, 2, 3]);
    }
}
747
748