bcs_link/
de.rs

// Copyright (c) The Diem Core Contributors
// SPDX-License-Identifier: Apache-2.0
3
4use crate::error::{Error, Result};
5use serde::de::{self, Deserialize, DeserializeSeed, IntoDeserializer, Visitor};
6use std::convert::TryFrom;
7
8/// Deserializes a `&[u8]` into a type.
9///
10/// This function will attempt to interpret `bytes` as the BCS serialized form of `T` and
11/// deserialize `T` from `bytes`.
12///
13/// # Examples
14///
15/// ```
16/// use bcs::from_bytes;
17/// use serde::Deserialize;
18///
19/// #[derive(Deserialize)]
20/// struct Ip([u8; 4]);
21///
22/// #[derive(Deserialize)]
23/// struct Port(u16);
24///
25/// #[derive(Deserialize)]
26/// struct SocketAddr {
27///     ip: Ip,
28///     port: Port,
29/// }
30///
31/// let bytes = vec![0x7f, 0x00, 0x00, 0x01, 0x41, 0x1f];
32/// let socket_addr: SocketAddr = from_bytes(&bytes).unwrap();
33///
34/// assert_eq!(socket_addr.ip.0, [127, 0, 0, 1]);
35/// assert_eq!(socket_addr.port.0, 8001);
36/// ```
37pub fn from_bytes<'a, T>(bytes: &'a [u8]) -> Result<T>
38where
39    T: Deserialize<'a>,
40{
41    let mut deserializer = Deserializer::new(bytes, crate::MAX_CONTAINER_DEPTH);
42    let t = T::deserialize(&mut deserializer)?;
43    deserializer.end().map(move |_| t)
44}
45
46/// Same as `from_bytes` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH`
47pub fn from_bytes_with_limit<'a, T>(bytes: &'a [u8], limit: usize) -> Result<T>
48where
49    T: Deserialize<'a>,
50{
51    if limit > crate::MAX_CONTAINER_DEPTH {
52        return Err(Error::NotSupported(
53            "limit exceeds the max allowed depth 500",
54        ));
55    }
56    let mut deserializer = Deserializer::new(bytes, limit);
57    let t = T::deserialize(&mut deserializer)?;
58    deserializer.end().map(move |_| t)
59}
60
61/// Perform a stateful deserialization from a `&[u8]` using the provided `seed`.
62pub fn from_bytes_seed<'a, T>(seed: T, bytes: &'a [u8]) -> Result<T::Value>
63where
64    T: DeserializeSeed<'a>,
65{
66    let mut deserializer = Deserializer::new(bytes, crate::MAX_CONTAINER_DEPTH);
67    let t = seed.deserialize(&mut deserializer)?;
68    deserializer.end().map(move |_| t)
69}
70
71/// Same as `from_bytes_seed` but use `limit` as max container depth instead of MAX_CONTAINER_DEPTH`
72pub fn from_bytes_seed_with_limit<'a, T>(seed: T, bytes: &'a [u8], limit: usize) -> Result<T::Value>
73where
74    T: DeserializeSeed<'a>,
75{
76    if limit > crate::MAX_CONTAINER_DEPTH {
77        return Err(Error::NotSupported(
78            "limit exceeds the max allowed depth 500",
79        ));
80    }
81    let mut deserializer = Deserializer::new(bytes, limit);
82    let t = seed.deserialize(&mut deserializer)?;
83    deserializer.end().map(move |_| t)
84}
85
86/// Deserialization implementation for BCS
87struct Deserializer<'de> {
88    input: &'de [u8],
89    max_remaining_depth: usize,
90}
91
92impl<'de> Deserializer<'de> {
93    /// Creates a new `Deserializer` which will be deserializing the provided
94    /// input.
95    fn new(input: &'de [u8], max_remaining_depth: usize) -> Self {
96        Deserializer {
97            input,
98            max_remaining_depth,
99        }
100    }
101
102    /// The `Deserializer::end` method should be called after a type has been
103    /// fully deserialized. This allows the `Deserializer` to validate that
104    /// the there are no more bytes remaining in the input stream.
105    fn end(&mut self) -> Result<()> {
106        if self.input.is_empty() {
107            Ok(())
108        } else {
109            Err(Error::RemainingInput)
110        }
111    }
112}
113
114impl<'de> Deserializer<'de> {
115    fn peek(&mut self) -> Result<u8> {
116        self.input.first().copied().ok_or(Error::Eof)
117    }
118
119    fn next(&mut self) -> Result<u8> {
120        let byte = self.peek()?;
121        self.input = &self.input[1..];
122        Ok(byte)
123    }
124
125    fn parse_bool(&mut self) -> Result<bool> {
126        let byte = self.next()?;
127
128        match byte {
129            0 => Ok(false),
130            1 => Ok(true),
131            _ => Err(Error::ExpectedBoolean),
132        }
133    }
134
135    fn fill_slice(&mut self, slice: &mut [u8]) -> Result<()> {
136        for byte in slice {
137            *byte = self.next()?;
138        }
139        Ok(())
140    }
141
142    fn parse_u8(&mut self) -> Result<u8> {
143        self.next()
144    }
145
146    fn parse_u16(&mut self) -> Result<u16> {
147        let mut le_bytes = [0; 2];
148        self.fill_slice(&mut le_bytes)?;
149        Ok(u16::from_le_bytes(le_bytes))
150    }
151
152    fn parse_u32(&mut self) -> Result<u32> {
153        let mut le_bytes = [0; 4];
154        self.fill_slice(&mut le_bytes)?;
155        Ok(u32::from_le_bytes(le_bytes))
156    }
157
158    fn parse_u64(&mut self) -> Result<u64> {
159        let mut le_bytes = [0; 8];
160        self.fill_slice(&mut le_bytes)?;
161        Ok(u64::from_le_bytes(le_bytes))
162    }
163
164    fn parse_u128(&mut self) -> Result<u128> {
165        let mut le_bytes = [0; 16];
166        self.fill_slice(&mut le_bytes)?;
167        Ok(u128::from_le_bytes(le_bytes))
168    }
169
170    #[allow(clippy::integer_arithmetic)]
171    fn parse_u32_from_uleb128(&mut self) -> Result<u32> {
172        let mut value: u64 = 0;
173        for shift in (0..32).step_by(7) {
174            let byte = self.next()?;
175            let digit = byte & 0x7f;
176            value |= u64::from(digit) << shift;
177            // If the highest bit of `byte` is 0, return the final value.
178            if digit == byte {
179                if shift > 0 && digit == 0 {
180                    // We only accept canonical ULEB128 encodings, therefore the
181                    // heaviest (and last) base-128 digit must be non-zero.
182                    return Err(Error::NonCanonicalUleb128Encoding);
183                }
184                // Decoded integer must not overflow.
185                return u32::try_from(value)
186                    .map_err(|_| Error::IntegerOverflowDuringUleb128Decoding);
187            }
188        }
189        // Decoded integer must not overflow.
190        Err(Error::IntegerOverflowDuringUleb128Decoding)
191    }
192
193    fn parse_length(&mut self) -> Result<usize> {
194        let len = self.parse_u32_from_uleb128()? as usize;
195        if len > crate::MAX_SEQUENCE_LENGTH {
196            return Err(Error::ExceededMaxLen(len));
197        }
198        Ok(len)
199    }
200
201    fn parse_bytes(&mut self) -> Result<&'de [u8]> {
202        let len = self.parse_length()?;
203        let slice = self.input.get(..len).ok_or(Error::Eof)?;
204        self.input = &self.input[len..];
205        Ok(slice)
206    }
207
208    fn parse_string(&mut self) -> Result<&'de str> {
209        let slice = self.parse_bytes()?;
210        std::str::from_utf8(slice).map_err(|_| Error::Utf8)
211    }
212
213    fn enter_named_container(&mut self, name: &'static str) -> Result<()> {
214        if self.max_remaining_depth == 0 {
215            return Err(Error::ExceededContainerDepthLimit(name));
216        }
217        self.max_remaining_depth -= 1;
218        Ok(())
219    }
220
221    fn leave_named_container(&mut self) {
222        self.max_remaining_depth += 1;
223    }
224}
225
226impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
227    type Error = Error;
228
229    // BCS is not a self-describing format so we can't implement `deserialize_any`
230    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
231    where
232        V: Visitor<'de>,
233    {
234        Err(Error::NotSupported("deserialize_any"))
235    }
236
237    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
238    where
239        V: Visitor<'de>,
240    {
241        visitor.visit_bool(self.parse_bool()?)
242    }
243
244    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
245    where
246        V: Visitor<'de>,
247    {
248        visitor.visit_i8(self.parse_u8()? as i8)
249    }
250
251    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
252    where
253        V: Visitor<'de>,
254    {
255        visitor.visit_i16(self.parse_u16()? as i16)
256    }
257
258    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
259    where
260        V: Visitor<'de>,
261    {
262        visitor.visit_i32(self.parse_u32()? as i32)
263    }
264
265    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
266    where
267        V: Visitor<'de>,
268    {
269        visitor.visit_i64(self.parse_u64()? as i64)
270    }
271
272    fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value>
273    where
274        V: Visitor<'de>,
275    {
276        visitor.visit_i128(self.parse_u128()? as i128)
277    }
278
279    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
280    where
281        V: Visitor<'de>,
282    {
283        visitor.visit_u8(self.parse_u8()?)
284    }
285
286    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
287    where
288        V: Visitor<'de>,
289    {
290        visitor.visit_u16(self.parse_u16()?)
291    }
292
293    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
294    where
295        V: Visitor<'de>,
296    {
297        visitor.visit_u32(self.parse_u32()?)
298    }
299
300    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
301    where
302        V: Visitor<'de>,
303    {
304        visitor.visit_u64(self.parse_u64()?)
305    }
306
307    fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value>
308    where
309        V: Visitor<'de>,
310    {
311        visitor.visit_u128(self.parse_u128()?)
312    }
313
314    fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value>
315    where
316        V: Visitor<'de>,
317    {
318        Err(Error::NotSupported("deserialize_f32"))
319    }
320
321    fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value>
322    where
323        V: Visitor<'de>,
324    {
325        Err(Error::NotSupported("deserialize_f64"))
326    }
327
328    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value>
329    where
330        V: Visitor<'de>,
331    {
332        Err(Error::NotSupported("deserialize_char"))
333    }
334
335    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
336    where
337        V: Visitor<'de>,
338    {
339        visitor.visit_borrowed_str(self.parse_string()?)
340    }
341
342    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
343    where
344        V: Visitor<'de>,
345    {
346        self.deserialize_str(visitor)
347    }
348
349    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
350    where
351        V: Visitor<'de>,
352    {
353        visitor.visit_borrowed_bytes(self.parse_bytes()?)
354    }
355
356    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
357    where
358        V: Visitor<'de>,
359    {
360        self.deserialize_bytes(visitor)
361    }
362
363    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
364    where
365        V: Visitor<'de>,
366    {
367        let byte = self.next()?;
368
369        match byte {
370            0 => visitor.visit_none(),
371            1 => visitor.visit_some(self),
372            _ => Err(Error::ExpectedOption),
373        }
374    }
375
376    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
377    where
378        V: Visitor<'de>,
379    {
380        visitor.visit_unit()
381    }
382
383    fn deserialize_unit_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
384    where
385        V: Visitor<'de>,
386    {
387        self.enter_named_container(name)?;
388        let r = self.deserialize_unit(visitor);
389        self.leave_named_container();
390        r
391    }
392
393    fn deserialize_newtype_struct<V>(self, name: &'static str, visitor: V) -> Result<V::Value>
394    where
395        V: Visitor<'de>,
396    {
397        self.enter_named_container(name)?;
398        let r = visitor.visit_newtype_struct(&mut *self);
399        self.leave_named_container();
400        r
401    }
402
403    fn deserialize_seq<V>(mut self, visitor: V) -> Result<V::Value>
404    where
405        V: Visitor<'de>,
406    {
407        let len = self.parse_length()?;
408        visitor.visit_seq(SeqDeserializer::new(&mut self, len))
409    }
410
411    fn deserialize_tuple<V>(mut self, len: usize, visitor: V) -> Result<V::Value>
412    where
413        V: Visitor<'de>,
414    {
415        visitor.visit_seq(SeqDeserializer::new(&mut self, len))
416    }
417
418    fn deserialize_tuple_struct<V>(
419        mut self,
420        name: &'static str,
421        len: usize,
422        visitor: V,
423    ) -> Result<V::Value>
424    where
425        V: Visitor<'de>,
426    {
427        self.enter_named_container(name)?;
428        let r = visitor.visit_seq(SeqDeserializer::new(&mut self, len));
429        self.leave_named_container();
430        r
431    }
432
433    fn deserialize_map<V>(mut self, visitor: V) -> Result<V::Value>
434    where
435        V: Visitor<'de>,
436    {
437        let len = self.parse_length()?;
438        visitor.visit_map(MapDeserializer::new(&mut self, len))
439    }
440
441    fn deserialize_struct<V>(
442        mut self,
443        name: &'static str,
444        fields: &'static [&'static str],
445        visitor: V,
446    ) -> Result<V::Value>
447    where
448        V: Visitor<'de>,
449    {
450        self.enter_named_container(name)?;
451        let r = visitor.visit_seq(SeqDeserializer::new(&mut self, fields.len()));
452        self.leave_named_container();
453        r
454    }
455
456    fn deserialize_enum<V>(
457        self,
458        name: &'static str,
459        _variants: &'static [&'static str],
460        visitor: V,
461    ) -> Result<V::Value>
462    where
463        V: Visitor<'de>,
464    {
465        self.enter_named_container(name)?;
466        let r = visitor.visit_enum(&mut *self);
467        self.leave_named_container();
468        r
469    }
470
471    // BCS does not utilize identifiers, so throw them away
472    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
473    where
474        V: Visitor<'de>,
475    {
476        self.deserialize_bytes(_visitor)
477    }
478
479    // BCS is not a self-describing format so we can't implement `deserialize_ignored_any`
480    fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value>
481    where
482        V: Visitor<'de>,
483    {
484        Err(Error::NotSupported("deserialize_ignored_any"))
485    }
486
487    // BCS is not a human readable format
488    fn is_human_readable(&self) -> bool {
489        false
490    }
491}
492
493struct SeqDeserializer<'a, 'de: 'a> {
494    de: &'a mut Deserializer<'de>,
495    remaining: usize,
496}
497
498impl<'a, 'de> SeqDeserializer<'a, 'de> {
499    fn new(de: &'a mut Deserializer<'de>, remaining: usize) -> Self {
500        Self { de, remaining }
501    }
502}
503
504impl<'de, 'a> de::SeqAccess<'de> for SeqDeserializer<'a, 'de> {
505    type Error = Error;
506
507    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
508    where
509        T: DeserializeSeed<'de>,
510    {
511        if self.remaining == 0 {
512            Ok(None)
513        } else {
514            self.remaining -= 1;
515            seed.deserialize(&mut *self.de).map(Some)
516        }
517    }
518
519    fn size_hint(&self) -> Option<usize> {
520        Some(self.remaining)
521    }
522}
523
524struct MapDeserializer<'a, 'de: 'a> {
525    de: &'a mut Deserializer<'de>,
526    remaining: usize,
527    previous_key_bytes: Option<&'a [u8]>,
528}
529
530impl<'a, 'de> MapDeserializer<'a, 'de> {
531    fn new(de: &'a mut Deserializer<'de>, remaining: usize) -> Self {
532        Self {
533            de,
534            remaining,
535            previous_key_bytes: None,
536        }
537    }
538}
539
540impl<'de, 'a> de::MapAccess<'de> for MapDeserializer<'a, 'de> {
541    type Error = Error;
542
543    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
544    where
545        K: DeserializeSeed<'de>,
546    {
547        match self.remaining.checked_sub(1) {
548            None => Ok(None),
549            Some(remaining) => {
550                let previous_input_slice = self.de.input;
551                let key_value = seed.deserialize(&mut *self.de)?;
552                let key_len = previous_input_slice
553                    .len()
554                    .saturating_sub(self.de.input.len());
555                let key_bytes = &previous_input_slice[..key_len];
556                if let Some(previous_key_bytes) = self.previous_key_bytes {
557                    if previous_key_bytes >= key_bytes {
558                        return Err(Error::NonCanonicalMap);
559                    }
560                }
561                self.remaining = remaining;
562                self.previous_key_bytes = Some(key_bytes);
563                Ok(Some(key_value))
564            }
565        }
566    }
567
568    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
569    where
570        V: DeserializeSeed<'de>,
571    {
572        seed.deserialize(&mut *self.de)
573    }
574
575    fn size_hint(&self) -> Option<usize> {
576        Some(self.remaining)
577    }
578}
579
580impl<'de, 'a> de::EnumAccess<'de> for &'a mut Deserializer<'de> {
581    type Error = Error;
582    type Variant = Self;
583
584    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant)>
585    where
586        V: DeserializeSeed<'de>,
587    {
588        let variant_index = self.parse_u32_from_uleb128()?;
589        let result: Result<V::Value> = seed.deserialize(variant_index.into_deserializer());
590        Ok((result?, self))
591    }
592}
593
594impl<'de, 'a> de::VariantAccess<'de> for &'a mut Deserializer<'de> {
595    type Error = Error;
596
597    fn unit_variant(self) -> Result<()> {
598        Ok(())
599    }
600
601    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value>
602    where
603        T: DeserializeSeed<'de>,
604    {
605        seed.deserialize(self)
606    }
607
608    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value>
609    where
610        V: Visitor<'de>,
611    {
612        de::Deserializer::deserialize_tuple(self, len, visitor)
613    }
614
615    fn struct_variant<V>(self, fields: &'static [&'static str], visitor: V) -> Result<V::Value>
616    where
617        V: Visitor<'de>,
618    {
619        de::Deserializer::deserialize_tuple(self, fields.len(), visitor)
620    }
621}