serde_scale/
de.rs

1// Copyright (C) 2020 Stephane Raux. Distributed under the zlib license.
2
3use core::convert::TryFrom;
4use crate::{Bytes, EndOfInput, Error, Read};
5use serde::{
6    de::{DeserializeSeed, Visitor},
7    Deserialize, Deserializer as _,
8};
9
10/// Deserializes a value encoded with SCALE
11pub fn from_slice<'a, T>(v: &'a [u8]) -> Result<T, Error<EndOfInput>>
12where
13    T: Deserialize<'a>,
14{
15    T::deserialize(&mut Deserializer(v))
16}
17
/// Deserializer for the SCALE encoding
///
/// Wraps a reader `R`. The `serde::Deserializer` trait is implemented for
/// `&mut Deserializer<R>`, so values are decoded through a mutable reference to this type.
pub struct Deserializer<R>(R);
20
21impl<'de, R: Read<'de>> Deserializer<R> {
22    /// Returns a deserializer using the given reader
23    pub fn new(r: R) -> Self {
24        Self(r)
25    }
26
27    /// Returns the underlying reader
28    pub fn into_inner(self) -> R {
29        self.0
30    }
31
32    fn read_compact(&mut self) -> Result<u64, Error<R::Error>> {
33        let mut head = 0;
34        self.0.read_exact(core::slice::from_mut(&mut head))?;
35        match head & 0x3 {
36            0x0 => Ok((head >> 2) as u64),
37            0x1 => {
38                let low = (head >> 2) as u64;
39                let high = self.read_u8()? as u64;
40                Ok(low | high << 6)
41            }
42            0x2 => {
43                let low = (head >> 2) as u64;
44                let mut high = [0; 4];
45                self.0.read_exact(&mut high[..3])?;
46                let high = u32::from_le_bytes(high) as u64;
47                Ok(low | high << 6)
48            }
49            0x3 => {
50                let len = (head >> 2) as usize + 4;
51                if len > 8 {
52                    return Err(Error::CollectionTooLargeToDeserialize);
53                }
54                let mut buf = [0; 8];
55                self.0.read_exact(&mut buf[..len])?;
56                let n = u64::from_le_bytes(buf);
57                Ok(n)
58            }
59            _ => unreachable!(),
60        }
61    }
62
63    fn read_u8(&mut self) -> Result<u8, Error<R::Error>> {
64        let mut v = 0;
65        self.0.read_exact(core::slice::from_mut(&mut v))?;
66        Ok(v)
67    }
68
69    fn read_u32(&mut self) -> Result<u32, Error<R::Error>> {
70        let mut v = [0; 4];
71        self.0.read_exact(&mut v)?;
72        Ok(u32::from_le_bytes(v))
73    }
74}
75
76impl<'de, R: Read<'de>> serde::Deserializer<'de> for &mut Deserializer<R> {
77    type Error = Error<R::Error>;
78
79    fn deserialize_any<V>(self, _: V) -> Result<V::Value, Self::Error>
80    where
81        V: Visitor<'de>,
82    {
83        Err(Error::TypeMustBeKnown)
84    }
85
86    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
87    where
88        V: Visitor<'de>,
89    {
90        match self.read_u8()? {
91            0 => visitor.visit_bool(false),
92            1 => visitor.visit_bool(true),
93            found => Err(Error::ExpectedBoolean { found }),
94        }
95    }
96
97    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
98    where
99        V: Visitor<'de>,
100    {
101        let mut found = [0];
102        self.0.read_exact(&mut found)?;
103        visitor.visit_i8(i8::from_le_bytes(found))
104    }
105
106    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
107    where
108        V: Visitor<'de>,
109    {
110        let mut found = [0; 2];
111        self.0.read_exact(&mut found)?;
112        visitor.visit_i16(i16::from_le_bytes(found))
113    }
114
115    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
116    where
117        V: Visitor<'de>,
118    {
119        let mut found = [0; 4];
120        self.0.read_exact(&mut found)?;
121        visitor.visit_i32(i32::from_le_bytes(found))
122    }
123
124    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
125    where
126        V: Visitor<'de>,
127    {
128        let mut found = [0; 8];
129        self.0.read_exact(&mut found)?;
130        visitor.visit_i64(i64::from_le_bytes(found))
131
132    }
133
134    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
135    where
136        V: Visitor<'de>,
137    {
138        visitor.visit_u8(self.read_u8()?)
139    }
140
141    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
142    where
143        V: Visitor<'de>,
144    {
145        let mut found = [0; 2];
146        self.0.read_exact(&mut found)?;
147        visitor.visit_u16(u16::from_le_bytes(found))
148    }
149
150    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
151    where
152        V: Visitor<'de>,
153    {
154        visitor.visit_u32(self.read_u32()?)
155    }
156
157    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
158    where
159        V: Visitor<'de>,
160    {
161        let mut found = [0; 8];
162        self.0.read_exact(&mut found)?;
163        visitor.visit_u64(u64::from_le_bytes(found))
164    }
165
166    fn deserialize_f32<V>(self, _: V) -> Result<V::Value, Self::Error>
167    where
168        V: Visitor<'de>,
169    {
170        Err(Error::FloatingPointUnsupported)
171    }
172
173    fn deserialize_f64<V>(self, _: V) -> Result<V::Value, Self::Error>
174    where
175        V: Visitor<'de>,
176    {
177        Err(Error::FloatingPointUnsupported)
178    }
179
180    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
181    where
182        V: Visitor<'de>,
183    {
184        let found = self.read_u32()?;
185        let c = core::char::from_u32(found).ok_or(Error::InvalidCharacter { found })?;
186        visitor.visit_char(c)
187    }
188
189    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
190    where
191        V: Visitor<'de>,
192    {
193        let n = self.read_compact()?;
194        let n = usize::try_from(n).map_err(|_| Error::CollectionTooLargeToDeserialize)?;
195        self.0.read_map(n, |bytes| {
196            match bytes {
197                Bytes::Persistent(b) => {
198                    let s = core::str::from_utf8(b).map_err(Error::InvalidUnicode)?;
199                    visitor.visit_borrowed_str(s)
200                }
201                Bytes::Temporary(b) => {
202                    let s = core::str::from_utf8(b).map_err(Error::InvalidUnicode)?;
203                    visitor.visit_str(s)
204                }
205            }
206        })?
207    }
208
209    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
210    where
211        V: Visitor<'de>,
212    {
213        self.deserialize_str(visitor)
214    }
215
216    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
217    where
218        V: Visitor<'de>,
219    {
220        let n = self.read_compact()?;
221        let n = usize::try_from(n).map_err(|_| Error::CollectionTooLargeToDeserialize)?;
222        self.0.read_map(n, |bytes| {
223            match bytes {
224                Bytes::Persistent(b) => visitor.visit_borrowed_bytes(b),
225                Bytes::Temporary(b) => visitor.visit_bytes(b),
226            }
227        })?
228    }
229
230    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
231    where
232        V: Visitor<'de>,
233    {
234        self.deserialize_bytes(visitor)
235    }
236
237    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
238    where
239        V: Visitor<'de>,
240    {
241        match self.read_u8()? {
242            0 => visitor.visit_none(),
243            1 => visitor.visit_some(OptionalBoolDeserializer::discriminant_1(self)),
244            2 => visitor.visit_some(OptionalBoolDeserializer::discriminant_2(self)),
245            found_discriminant => Err(Error::InvalidOption { found_discriminant }),
246        }
247    }
248
249    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
250    where
251        V: Visitor<'de>,
252    {
253        visitor.visit_unit()
254    }
255
256    fn deserialize_unit_struct<V>(
257        self,
258        _: &'static str,
259        visitor: V,
260    ) -> Result<V::Value, Self::Error>
261    where
262        V: Visitor<'de>,
263    {
264        visitor.visit_unit()
265    }
266
267    fn deserialize_newtype_struct<V>(
268        self,
269        _: &'static str,
270        visitor: V,
271    ) -> Result<V::Value, Self::Error>
272    where
273        V: Visitor<'de>,
274    {
275        visitor.visit_newtype_struct(self)
276    }
277
278    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
279    where
280        V: Visitor<'de>,
281    {
282        let len = self.read_compact()?;
283        let len = usize::try_from(len).map_err(|_| Error::CollectionTooLargeToDeserialize)?;
284        self.deserialize_tuple(len, visitor)
285    }
286
287    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
288    where
289        V: Visitor<'de>,
290    {
291        visitor.visit_seq(Sequence {
292            deserializer: self,
293            remaining: len,
294        })
295    }
296
297    fn deserialize_tuple_struct<V>(
298        self,
299        _: &'static str,
300        len: usize,
301        visitor: V,
302    ) -> Result<V::Value, Self::Error>
303    where
304        V: Visitor<'de>,
305    {
306        self.deserialize_tuple(len, visitor)
307    }
308
309    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
310    where
311        V: Visitor<'de>,
312    {
313        let len = self.read_compact()?;
314        let len = usize::try_from(len).map_err(|_| Error::CollectionTooLargeToDeserialize)?;
315        visitor.visit_map(Map {
316            deserializer: self,
317            remaining: len,
318        })
319    }
320
321    fn deserialize_struct<V>(
322        self,
323        _: &'static str,
324        fields: &'static [&'static str],
325        visitor: V,
326    ) -> Result<V::Value, Self::Error>
327    where
328        V: Visitor<'de>,
329    {
330        self.deserialize_tuple(fields.len(), visitor)
331    }
332
333    fn deserialize_enum<V>(
334        self,
335        _: &'static str,
336        _: &'static [&'static str],
337        visitor: V,
338    ) -> Result<V::Value, Self::Error>
339    where
340        V: Visitor<'de>,
341    {
342        visitor.visit_enum(Enum {
343            deserializer: self,
344        })
345    }
346
347    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
348    where
349        V: Visitor<'de>,
350    {
351        visitor.visit_u8(self.read_u8()?)
352    }
353
354    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
355    where
356        V: Visitor<'de>,
357    {
358        self.deserialize_any(visitor)
359    }
360}
361
/// Sequence accessor that yields a fixed number of elements from a deserializer.
struct Sequence<'a, R> {
    /// Deserializer the elements are read from.
    deserializer: &'a mut Deserializer<R>,
    /// Number of elements that have not been yielded yet.
    remaining: usize,
}
366
367impl<'a, 'de, R: Read<'de>> serde::de::SeqAccess<'de> for Sequence<'a, R> {
368    type Error = Error<R::Error>;
369
370    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
371    where
372        T: DeserializeSeed<'de>,
373    {
374        if self.remaining == 0 {
375            return Ok(None);
376        }
377        self.remaining -= 1;
378        seed.deserialize(&mut *self.deserializer).map(Some)
379    }
380
381    fn size_hint(&self) -> Option<usize> {
382        Some(self.remaining)
383    }
384}
385
/// Map accessor that yields a fixed number of key-value pairs from a deserializer.
struct Map<'a, R> {
    /// Deserializer the keys and values are read from.
    deserializer: &'a mut Deserializer<R>,
    /// Number of entries whose key has not been yielded yet.
    remaining: usize,
}
390
391impl<'a, 'de, R: Read<'de>> serde::de::MapAccess<'de> for Map<'a, R> {
392    type Error = Error<R::Error>;
393
394    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
395    where
396        K: DeserializeSeed<'de>,
397    {
398        if self.remaining == 0 {
399            return Ok(None);
400        }
401        self.remaining -= 1;
402        seed.deserialize(&mut *self.deserializer).map(Some)
403    }
404
405    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
406    where
407        V: DeserializeSeed<'de>,
408    {
409        seed.deserialize(&mut *self.deserializer)
410    }
411}
412
/// Enum accessor; it serves both as the `EnumAccess` and as the `VariantAccess` for a value.
struct Enum<'a, R> {
    /// Deserializer the variant identifier and the variant payload are read from.
    deserializer: &'a mut Deserializer<R>,
}
416
417impl<'a, 'de, R: Read<'de>> serde::de::EnumAccess<'de> for Enum<'a, R> {
418    type Error = Error<R::Error>;
419    type Variant = Self;
420
421    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
422    where
423        V: DeserializeSeed<'de>,
424    {
425        Ok((seed.deserialize(&mut *self.deserializer)?, self))
426    }
427}
428
429impl<'a, 'de, R: Read<'de>> serde::de::VariantAccess<'de> for Enum<'a, R> {
430    type Error = Error<R::Error>;
431
432    fn unit_variant(self) -> Result<(), Self::Error> {
433        Ok(())
434    }
435
436    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
437    where
438        T: DeserializeSeed<'de>,
439    {
440        seed.deserialize(self.deserializer)
441    }
442
443    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
444    where
445        V: Visitor<'de>,
446    {
447        self.deserializer.deserialize_tuple(len, visitor)
448    }
449
450    fn struct_variant<V>(
451        self,
452        _: &'static [&'static str],
453        visitor: V,
454    ) -> Result<V::Value, Self::Error>
455    where
456        V: Visitor<'de>,
457    {
458        self.deserializer.deserialize_seq(visitor)
459    }
460}
461
/// Deserializer for the payload of an option whose discriminant byte was 1 or 2.
///
/// A boolean payload is produced from the discriminant alone. For every other type a
/// discriminant of 2 is rejected and a discriminant of 1 defers to the wrapped deserializer.
struct OptionalBoolDeserializer<'a, R> {
    /// Deserializer used for every payload type other than `bool`.
    inner: &'a mut Deserializer<R>,
    /// `true` when the option discriminant was 1, `false` when it was 2.
    discriminant_is_1: bool,
}
466
467impl<'a, 'de, R: Read<'de>> OptionalBoolDeserializer<'a, R> {
468    fn discriminant_1(inner: &'a mut Deserializer<R>) -> Self {
469        Self {
470            inner,
471            discriminant_is_1: true,
472        }
473    }
474
475    fn discriminant_2(inner: &'a mut Deserializer<R>) -> Self {
476        Self {
477            inner,
478            discriminant_is_1: false,
479        }
480    }
481
482    fn check_bad_discriminant(&self) -> Result<(), Error<R::Error>> {
483        if self.discriminant_is_1 {
484            Ok(())
485        } else {
486            Err(Error::InvalidOption { found_discriminant: 2 })
487        }
488    }
489}
490
491impl<'de, R: Read<'de>> serde::Deserializer<'de> for OptionalBoolDeserializer<'_, R> {
492    type Error = Error<R::Error>;
493
494    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
495    where
496        V: Visitor<'de>,
497    {
498        self.check_bad_discriminant()?;
499        self.inner.deserialize_any(visitor)
500    }
501
502    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
503    where
504        V: Visitor<'de>,
505    {
506        visitor.visit_bool(self.discriminant_is_1)
507    }
508
509    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
510    where
511        V: Visitor<'de>,
512    {
513        self.check_bad_discriminant()?;
514        self.inner.deserialize_i8(visitor)
515    }
516
517    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
518    where
519        V: Visitor<'de>,
520    {
521        self.check_bad_discriminant()?;
522        self.inner.deserialize_i16(visitor)
523    }
524
525    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
526    where
527        V: Visitor<'de>,
528    {
529        self.check_bad_discriminant()?;
530        self.inner.deserialize_i32(visitor)
531    }
532
533    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
534    where
535        V: Visitor<'de>,
536    {
537        self.check_bad_discriminant()?;
538        self.inner.deserialize_i64(visitor)
539    }
540
541    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
542    where
543        V: Visitor<'de>,
544    {
545        self.check_bad_discriminant()?;
546        self.inner.deserialize_u8(visitor)
547    }
548
549    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
550    where
551        V: Visitor<'de>,
552    {
553        self.check_bad_discriminant()?;
554        self.inner.deserialize_u16(visitor)
555    }
556
557    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
558    where
559        V: Visitor<'de>,
560    {
561        self.check_bad_discriminant()?;
562        self.inner.deserialize_u32(visitor)
563    }
564
565    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
566    where
567        V: Visitor<'de>,
568    {
569        self.check_bad_discriminant()?;
570        self.inner.deserialize_u64(visitor)
571    }
572
573    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
574    where
575        V: Visitor<'de>,
576    {
577        self.check_bad_discriminant()?;
578        self.inner.deserialize_f32(visitor)
579    }
580
581    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
582    where
583        V: Visitor<'de>,
584    {
585        self.check_bad_discriminant()?;
586        self.inner.deserialize_f64(visitor)
587    }
588
589    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value, Self::Error>
590    where
591        V: Visitor<'de>,
592    {
593        self.check_bad_discriminant()?;
594        self.inner.deserialize_char(visitor)
595    }
596
597    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
598    where
599        V: Visitor<'de>,
600    {
601        self.check_bad_discriminant()?;
602        self.inner.deserialize_str(visitor)
603    }
604
605    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
606    where
607        V: Visitor<'de>,
608    {
609        self.check_bad_discriminant()?;
610        self.inner.deserialize_string(visitor)
611    }
612
613    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
614    where
615        V: Visitor<'de>,
616    {
617        self.check_bad_discriminant()?;
618        self.inner.deserialize_bytes(visitor)
619    }
620
621    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
622    where
623        V: Visitor<'de>,
624    {
625        self.check_bad_discriminant()?;
626        self.inner.deserialize_byte_buf(visitor)
627    }
628
629    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
630    where
631        V: Visitor<'de>,
632    {
633        self.check_bad_discriminant()?;
634        self.inner.deserialize_option(visitor)
635    }
636
637    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
638    where
639        V: Visitor<'de>,
640    {
641        self.check_bad_discriminant()?;
642        self.inner.deserialize_unit(visitor)
643    }
644
645    fn deserialize_unit_struct<V>(
646        self,
647        name: &'static str,
648        visitor: V,
649    ) -> Result<V::Value, Self::Error>
650    where
651        V: Visitor<'de>,
652    {
653        self.check_bad_discriminant()?;
654        self.inner.deserialize_unit_struct(name, visitor)
655    }
656
657    fn deserialize_newtype_struct<V>(
658        self,
659        name: &'static str,
660        visitor: V,
661    ) -> Result<V::Value, Self::Error>
662    where
663        V: Visitor<'de>,
664    {
665        self.check_bad_discriminant()?;
666        self.inner.deserialize_newtype_struct(name, visitor)
667    }
668
669    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
670    where
671        V: Visitor<'de>,
672    {
673        self.check_bad_discriminant()?;
674        self.inner.deserialize_seq(visitor)
675    }
676
677    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
678    where
679        V: Visitor<'de>,
680    {
681        self.check_bad_discriminant()?;
682        self.inner.deserialize_tuple(len, visitor)
683    }
684
685    fn deserialize_tuple_struct<V>(
686        self,
687        name: &'static str,
688        len: usize,
689        visitor: V,
690    ) -> Result<V::Value, Self::Error>
691    where
692        V: Visitor<'de>,
693    {
694        self.check_bad_discriminant()?;
695        self.inner.deserialize_tuple_struct(name, len, visitor)
696    }
697
698    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
699    where
700        V: Visitor<'de>,
701    {
702        self.check_bad_discriminant()?;
703        self.inner.deserialize_map(visitor)
704    }
705
706    fn deserialize_struct<V>(
707        self,
708        name: &'static str,
709        fields: &'static [&'static str],
710        visitor: V,
711    ) -> Result<V::Value, Self::Error>
712    where
713        V: Visitor<'de>,
714    {
715        self.check_bad_discriminant()?;
716        self.inner.deserialize_struct(name, fields, visitor)
717    }
718
719    fn deserialize_enum<V>(
720        self,
721        name: &'static str,
722        variants: &'static [&'static str],
723        visitor: V,
724    ) -> Result<V::Value, Self::Error>
725    where
726        V: Visitor<'de>,
727    {
728        self.check_bad_discriminant()?;
729        self.inner.deserialize_enum(name, variants, visitor)
730    }
731
732    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
733    where
734        V: Visitor<'de>,
735    {
736        self.check_bad_discriminant()?;
737        self.inner.deserialize_identifier(visitor)
738    }
739
740    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
741    where
742        V: Visitor<'de>,
743    {
744        self.check_bad_discriminant()?;
745        self.inner.deserialize_ignored_any(visitor)
746    }
747}
748
#[cfg(test)]
mod tests {
    use crate::from_slice;

    /// A lone `0` byte decodes to `None`.
    #[test]
    fn none_bool_deserializes_from_0() {
        let decoded: Option<bool> = from_slice(&[0]).unwrap();
        assert_eq!(decoded, None);
    }

    /// A lone `1` byte decodes to `Some(true)`.
    #[test]
    fn some_true_deserializes_from_1() {
        let decoded: Option<bool> = from_slice(&[1]).unwrap();
        assert_eq!(decoded, Some(true));
    }

    /// A lone `2` byte decodes to `Some(false)`.
    #[test]
    fn some_false_deserializes_from_2() {
        let decoded: Option<bool> = from_slice(&[2]).unwrap();
        assert_eq!(decoded, Some(false));
    }
}