pythnet_sdk/wire/
de.rs

//! A module defining a serde deserializer for the format described in ser.rs
2//!
3//! TL;DR: How to Use
4//! ================================================================================
5//!
6//! ```rust,ignore
7//! #[derive(Deserialize)]
8//! struct ExampleStruct {
9//!     a: (),
10//!     b: bool,
11//!     c: u8,
12//!     ...,
13//! }
14//!
15//! let bytes = ...;
16//! let s: ExampleStruct = pythnet_sdk::de::from_slice::<LittleEndian, _>(&bytes)?;
17//! ```
18//!
19//! The deserialization mechanism is a bit more complex than the serialization mechanism as it
20//! employs a visitor pattern. Rather than describe it here, the serde documentation on how to
21//! implement a deserializer can be found here:
22//!
23//! https://serde.rs/impl-deserializer.html
24
25use {
26    crate::require,
27    byteorder::{
28        ByteOrder,
29        ReadBytesExt,
30    },
31    serde::{
32        de::{
33            EnumAccess,
34            MapAccess,
35            SeqAccess,
36            VariantAccess,
37        },
38        Deserialize,
39    },
40    std::{
41        io::{
42            Cursor,
43            Seek,
44            SeekFrom,
45        },
46        mem::size_of,
47    },
48    thiserror::Error,
49};
50
51/// Deserialize a Pyth wire-format buffer into a type.
52///
53/// Note that this method will not consume left-over bytes ore report errors. This is due to the
54/// fact that the Pyth wire formatted is intended to allow for appending of new fields without
55/// breaking backwards compatibility. As such, it is possible that a newer version of the format
56/// will contain fields that are not known to the deserializer. This is not an error, and the
57/// deserializer will simply ignore these fields.
58pub fn from_slice<'de, B, T>(bytes: &'de [u8]) -> Result<T, DeserializerError>
59where
60    T: Deserialize<'de>,
61    B: ByteOrder,
62{
63    let mut deserializer = Deserializer::<B>::new(bytes);
64    T::deserialize(&mut deserializer)
65}
66
67#[derive(Debug, Error)]
68pub enum DeserializerError {
69    #[error("io error: {0}")]
70    Io(#[from] std::io::Error),
71
72    #[error("invalid utf8: {0}")]
73    Utf8(#[from] std::str::Utf8Error),
74
75    #[error("this type is not supported")]
76    Unsupported,
77
78    #[error("sequence too large ({0} elements), max supported is 255")]
79    SequenceTooLarge(usize),
80
81    #[error("message: {0}")]
82    Message(Box<str>),
83
84    #[error("invalid enum variant, higher than expected variant range")]
85    InvalidEnumVariant,
86
87    #[error("eof")]
88    Eof,
89}
90
91pub struct Deserializer<'de, B>
92where
93    B: ByteOrder,
94{
95    cursor: Cursor<&'de [u8]>,
96    endian: std::marker::PhantomData<B>,
97}
98
99impl serde::de::Error for DeserializerError {
100    fn custom<T: std::fmt::Display>(msg: T) -> Self {
101        DeserializerError::Message(msg.to_string().into_boxed_str())
102    }
103}
104
105impl<'de, B> Deserializer<'de, B>
106where
107    B: ByteOrder,
108{
109    pub fn new(buffer: &'de [u8]) -> Self {
110        Self {
111            cursor: Cursor::new(buffer),
112            endian: std::marker::PhantomData,
113        }
114    }
115}
116
117impl<'de, B> serde::de::Deserializer<'de> for &'_ mut Deserializer<'de, B>
118where
119    B: ByteOrder,
120{
121    type Error = DeserializerError;
122
123    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
124    where
125        V: serde::de::Visitor<'de>,
126    {
127        Err(DeserializerError::Unsupported)
128    }
129
130    fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
131    where
132        V: serde::de::Visitor<'de>,
133    {
134        Err(DeserializerError::Unsupported)
135    }
136
137    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
138    where
139        V: serde::de::Visitor<'de>,
140    {
141        let value = self.cursor.read_u8().map_err(DeserializerError::from)?;
142        visitor.visit_bool(value != 0)
143    }
144
145    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
146    where
147        V: serde::de::Visitor<'de>,
148    {
149        let value = self.cursor.read_i8().map_err(DeserializerError::from)?;
150        visitor.visit_i8(value)
151    }
152
153    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
154    where
155        V: serde::de::Visitor<'de>,
156    {
157        let value = self
158            .cursor
159            .read_i16::<B>()
160            .map_err(DeserializerError::from)?;
161
162        visitor.visit_i16(value)
163    }
164
165    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
166    where
167        V: serde::de::Visitor<'de>,
168    {
169        let value = self
170            .cursor
171            .read_i32::<B>()
172            .map_err(DeserializerError::from)?;
173
174        visitor.visit_i32(value)
175    }
176
177    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
178    where
179        V: serde::de::Visitor<'de>,
180    {
181        let value = self
182            .cursor
183            .read_i64::<B>()
184            .map_err(DeserializerError::from)?;
185
186        visitor.visit_i64(value)
187    }
188
189    fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
190    where
191        V: serde::de::Visitor<'de>,
192    {
193        let value = self
194            .cursor
195            .read_i128::<B>()
196            .map_err(DeserializerError::from)?;
197
198        visitor.visit_i128(value)
199    }
200
201    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
202    where
203        V: serde::de::Visitor<'de>,
204    {
205        let value = self.cursor.read_u8().map_err(DeserializerError::from)?;
206        visitor.visit_u8(value)
207    }
208
209    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
210    where
211        V: serde::de::Visitor<'de>,
212    {
213        let value = self
214            .cursor
215            .read_u16::<B>()
216            .map_err(DeserializerError::from)?;
217
218        visitor.visit_u16(value)
219    }
220
221    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
222    where
223        V: serde::de::Visitor<'de>,
224    {
225        let value = self
226            .cursor
227            .read_u32::<B>()
228            .map_err(DeserializerError::from)?;
229
230        visitor.visit_u32(value)
231    }
232
233    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
234    where
235        V: serde::de::Visitor<'de>,
236    {
237        let value = self
238            .cursor
239            .read_u64::<B>()
240            .map_err(DeserializerError::from)?;
241
242        visitor.visit_u64(value)
243    }
244
245    fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
246    where
247        V: serde::de::Visitor<'de>,
248    {
249        let value = self
250            .cursor
251            .read_u128::<B>()
252            .map_err(DeserializerError::from)?;
253
254        visitor.visit_u128(value)
255    }
256
257    fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
258    where
259        V: serde::de::Visitor<'de>,
260    {
261        Err(DeserializerError::Unsupported)
262    }
263
264    fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
265    where
266        V: serde::de::Visitor<'de>,
267    {
268        Err(DeserializerError::Unsupported)
269    }
270
271    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
272    where
273        V: serde::de::Visitor<'de>,
274    {
275        Err(DeserializerError::Unsupported)
276    }
277
278    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
279    where
280        V: serde::de::Visitor<'de>,
281    {
282        let len = self.cursor.read_u8().map_err(DeserializerError::from)? as u64;
283
284        // We cannot use cursor read methods as they copy the data out of the internal buffer,
285        // where we actually want a pointer into that buffer. So instead, we take the internal
286        // representation (the underlying &[u8]) and slice it to get the data we want. We then
287        // advance the cursor to simulate the read.
288        //
289        // Note that we do the advance first because otherwise we run into a immutable->mutable
290        // borrow issue, but the reverse is fine.
291        self.cursor
292            .seek(SeekFrom::Current(len as i64))
293            .map_err(DeserializerError::from)?;
294
295        let buf = {
296            let buf = self.cursor.get_ref();
297            buf[(self.cursor.position() - len) as usize..]
298                .get(..len as usize)
299                .ok_or(DeserializerError::Eof)?
300        };
301
302        visitor.visit_borrowed_str(std::str::from_utf8(buf).map_err(DeserializerError::from)?)
303    }
304
305    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
306    where
307        V: serde::de::Visitor<'de>,
308    {
309        self.deserialize_str(visitor)
310    }
311
312    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
313    where
314        V: serde::de::Visitor<'de>,
315    {
316        let len = self.cursor.read_u8().map_err(DeserializerError::from)? as u64;
317
318        // See comment in deserialize_str for an explanation of this code.
319        self.cursor
320            .seek(SeekFrom::Current(len as i64))
321            .map_err(DeserializerError::from)?;
322
323        let buf = {
324            let buf = self.cursor.get_ref();
325            buf[(self.cursor.position() - len) as usize..]
326                .get(..len as usize)
327                .ok_or(DeserializerError::Eof)?
328        };
329
330        visitor.visit_borrowed_bytes(buf)
331    }
332
333    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
334    where
335        V: serde::de::Visitor<'de>,
336    {
337        self.deserialize_bytes(visitor)
338    }
339
340    fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
341    where
342        V: serde::de::Visitor<'de>,
343    {
344        Err(DeserializerError::Unsupported)
345    }
346
347    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
348    where
349        V: serde::de::Visitor<'de>,
350    {
351        visitor.visit_unit()
352    }
353
354    fn deserialize_unit_struct<V>(
355        self,
356        _name: &'static str,
357        visitor: V,
358    ) -> Result<V::Value, Self::Error>
359    where
360        V: serde::de::Visitor<'de>,
361    {
362        visitor.visit_unit()
363    }
364
365    fn deserialize_newtype_struct<V>(
366        self,
367        _name: &'static str,
368        visitor: V,
369    ) -> Result<V::Value, Self::Error>
370    where
371        V: serde::de::Visitor<'de>,
372    {
373        visitor.visit_newtype_struct(self)
374    }
375
376    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
377    where
378        V: serde::de::Visitor<'de>,
379    {
380        let len = self.cursor.read_u8().map_err(DeserializerError::from)? as usize;
381        visitor.visit_seq(SequenceIterator::new(self, len))
382    }
383
384    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
385    where
386        V: serde::de::Visitor<'de>,
387    {
388        visitor.visit_seq(SequenceIterator::new(self, len))
389    }
390
391    fn deserialize_tuple_struct<V>(
392        self,
393        _name: &'static str,
394        len: usize,
395        visitor: V,
396    ) -> Result<V::Value, Self::Error>
397    where
398        V: serde::de::Visitor<'de>,
399    {
400        visitor.visit_seq(SequenceIterator::new(self, len))
401    }
402
403    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
404    where
405        V: serde::de::Visitor<'de>,
406    {
407        let len = self.cursor.read_u8().map_err(DeserializerError::from)? as usize;
408        visitor.visit_map(SequenceIterator::new(self, len))
409    }
410
411    fn deserialize_struct<V>(
412        self,
413        _name: &'static str,
414        fields: &'static [&'static str],
415        visitor: V,
416    ) -> Result<V::Value, Self::Error>
417    where
418        V: serde::de::Visitor<'de>,
419    {
420        visitor.visit_seq(SequenceIterator::new(self, fields.len()))
421    }
422
423    fn deserialize_enum<V>(
424        self,
425        _name: &'static str,
426        variants: &'static [&'static str],
427        visitor: V,
428    ) -> Result<V::Value, Self::Error>
429    where
430        V: serde::de::Visitor<'de>,
431    {
432        // We read the discriminator here so that we can make the expected enum variant available
433        // to the `visit_enum` call.
434        let variant = self.cursor.read_u8().map_err(DeserializerError::from)?;
435        if variant >= variants.len() as u8 {
436            return Err(DeserializerError::InvalidEnumVariant);
437        }
438
439        visitor.visit_enum(Enum { de: self, variant })
440    }
441
442    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
443    where
444        V: serde::de::Visitor<'de>,
445    {
446        Err(DeserializerError::Unsupported)
447    }
448}
449
450impl<'de, 'a, B: ByteOrder> VariantAccess<'de> for &'a mut Deserializer<'de, B> {
451    type Error = DeserializerError;
452
453    fn unit_variant(self) -> Result<(), Self::Error> {
454        Ok(())
455    }
456
457    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
458    where
459        T: serde::de::DeserializeSeed<'de>,
460    {
461        seed.deserialize(self)
462    }
463
464    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
465    where
466        V: serde::de::Visitor<'de>,
467    {
468        visitor.visit_seq(SequenceIterator::new(self, len))
469    }
470
471    fn struct_variant<V>(
472        self,
473        fields: &'static [&'static str],
474        visitor: V,
475    ) -> Result<V::Value, Self::Error>
476    where
477        V: serde::de::Visitor<'de>,
478    {
479        visitor.visit_seq(SequenceIterator::new(self, fields.len()))
480    }
481}
482
483struct SequenceIterator<'de, 'a, B: ByteOrder> {
484    de:  &'a mut Deserializer<'de, B>,
485    len: usize,
486}
487
488impl<'de, 'a, B: ByteOrder> SequenceIterator<'de, 'a, B> {
489    fn new(de: &'a mut Deserializer<'de, B>, len: usize) -> Self {
490        Self { de, len }
491    }
492}
493
494impl<'de, 'a, B: ByteOrder> SeqAccess<'de> for SequenceIterator<'de, 'a, B> {
495    type Error = DeserializerError;
496
497    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
498    where
499        T: serde::de::DeserializeSeed<'de>,
500    {
501        if self.len == 0 {
502            return Ok(None);
503        }
504
505        self.len -= 1;
506        seed.deserialize(&mut *self.de).map(Some)
507    }
508
509    fn size_hint(&self) -> Option<usize> {
510        Some(self.len)
511    }
512}
513
514impl<'de, 'a, B: ByteOrder> MapAccess<'de> for SequenceIterator<'de, 'a, B> {
515    type Error = DeserializerError;
516
517    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
518    where
519        K: serde::de::DeserializeSeed<'de>,
520    {
521        if self.len == 0 {
522            return Ok(None);
523        }
524
525        self.len -= 1;
526        seed.deserialize(&mut *self.de).map(Some)
527    }
528
529    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
530    where
531        V: serde::de::DeserializeSeed<'de>,
532    {
533        seed.deserialize(&mut *self.de)
534    }
535
536    fn size_hint(&self) -> Option<usize> {
537        Some(self.len)
538    }
539}
540
541struct Enum<'de, 'a, B: ByteOrder> {
542    de:      &'a mut Deserializer<'de, B>,
543    variant: u8,
544}
545
546impl<'de, 'a, B: ByteOrder> EnumAccess<'de> for Enum<'de, 'a, B> {
547    type Error = DeserializerError;
548    type Variant = &'a mut Deserializer<'de, B>;
549
550    fn variant_seed<V>(self, _: V) -> Result<(V::Value, Self::Variant), Self::Error>
551    where
552        V: serde::de::DeserializeSeed<'de>,
553    {
554        // When serializing/deserializing, serde passes a variant_index into the handlers. We
555        // currently write these as u8's and have already parsed' them during deserialize_enum
556        // before we reach this point.
557        //
558        // Normally, when deserializing enum tags from a wire format that does not match the
559        // expected size, we would use a u*.into_deserializer() to feed the already parsed
560        // result into the visit_u64 visitor method during `__Field` deserialize.
561        //
562        // The problem with this however is during `visit_u64`, there is a possibility the
563        // enum variant is not valid, which triggers Serde to return an `Unexpected` error.
564        // These errors have the unfortunate side effect of triggering Rust including float
565        // operations in the resulting binary, which breaks WASM environments.
566        //
567        // To work around this, we rely on the following facts:
568        //
569        // - variant_index in Serde is always 0 indexed and contiguous
570        // - transmute_copy into a 0 sized type is safe
571        // - transmute_copy is safe to cast into __Field as long as u8 >= size_of::<__Field>()
572        //
573        // This behaviour relies on serde not changing its enum deserializer generation, but
574        // this would be a major backwards compatibility break for them so we should be safe.
575        require!(
576            size_of::<u8>() >= size_of::<V::Value>(),
577            DeserializerError::InvalidEnumVariant
578        );
579
580        Ok((
581            unsafe { std::mem::transmute_copy::<u8, V::Value>(&self.variant) },
582            self.de,
583        ))
584    }
585}