pythnet_sdk/wire/
de.rs

//! A module defining a serde deserializer for the format described in `ser.rs`.
2//!
3//! TL;DR: How to Use
4//! ================================================================================
5//!
6//! ```rust,ignore
7//! #[derive(Deserialize)]
8//! struct ExampleStruct {
9//!     a: (),
10//!     b: bool,
11//!     c: u8,
12//!     ...,
13//! }
14//!
15//! let bytes = ...;
16//! let s: ExampleStruct = pythnet_sdk::de::from_slice::<LittleEndian, _>(&bytes)?;
17//! ```
18//!
19//! The deserialization mechanism is a bit more complex than the serialization mechanism as it
20//! employs a visitor pattern. Rather than describe it here, the serde documentation on how to
21//! implement a deserializer can be found here:
22//!
//! <https://serde.rs/impl-deserializer.html>
24
25use {
26    crate::require,
27    byteorder::{ByteOrder, ReadBytesExt},
28    serde::{
29        de::{EnumAccess, MapAccess, SeqAccess, VariantAccess},
30        Deserialize,
31    },
32    std::{
33        io::{Cursor, Seek, SeekFrom},
34        mem::size_of,
35    },
36    thiserror::Error,
37};
38
39/// Deserialize a Pyth wire-format buffer into a type.
40///
41/// Note that this method will not consume left-over bytes ore report errors. This is due to the
42/// fact that the Pyth wire formatted is intended to allow for appending of new fields without
43/// breaking backwards compatibility. As such, it is possible that a newer version of the format
44/// will contain fields that are not known to the deserializer. This is not an error, and the
45/// deserializer will simply ignore these fields.
46pub fn from_slice<'de, B, T>(bytes: &'de [u8]) -> Result<T, DeserializerError>
47where
48    T: Deserialize<'de>,
49    B: ByteOrder,
50{
51    let mut deserializer = Deserializer::<B>::new(bytes);
52    T::deserialize(&mut deserializer)
53}
54
55#[derive(Debug, Error)]
56pub enum DeserializerError {
57    #[error("io error: {0}")]
58    Io(#[from] std::io::Error),
59
60    #[error("invalid utf8: {0}")]
61    Utf8(#[from] std::str::Utf8Error),
62
63    #[error("this type is not supported")]
64    Unsupported,
65
66    #[error("sequence too large ({0} elements), max supported is 255")]
67    SequenceTooLarge(usize),
68
69    #[error("message: {0}")]
70    Message(Box<str>),
71
72    #[error("invalid enum variant, higher than expected variant range")]
73    InvalidEnumVariant,
74
75    #[error("eof")]
76    Eof,
77}
78
/// Pyth wire-format deserializer reading from a borrowed byte buffer.
///
/// `B` selects the byte order used for multi-byte integers. String and byte payloads are handed
/// out as borrows into the original buffer, which is why the buffer lifetime `'de` is part of
/// the type.
///
/// The `B: ByteOrder` requirement is stated on the `impl` blocks that actually decode integers
/// rather than on the struct itself, so merely naming the type does not demand the bound.
pub struct Deserializer<'de, B> {
    /// Current read position within the input buffer.
    cursor: Cursor<&'de [u8]>,
    /// Zero-sized marker recording the byte order; no value of type `B` is stored.
    endian: std::marker::PhantomData<B>,
}
86
87impl serde::de::Error for DeserializerError {
88    fn custom<T: std::fmt::Display>(msg: T) -> Self {
89        DeserializerError::Message(msg.to_string().into_boxed_str())
90    }
91}
92
93impl<'de, B> Deserializer<'de, B>
94where
95    B: ByteOrder,
96{
97    pub fn new(buffer: &'de [u8]) -> Self {
98        Self {
99            cursor: Cursor::new(buffer),
100            endian: std::marker::PhantomData,
101        }
102    }
103}
104
105impl<'de, B> serde::de::Deserializer<'de> for &'_ mut Deserializer<'de, B>
106where
107    B: ByteOrder,
108{
109    type Error = DeserializerError;
110
111    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
112    where
113        V: serde::de::Visitor<'de>,
114    {
115        Err(DeserializerError::Unsupported)
116    }
117
118    fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
119    where
120        V: serde::de::Visitor<'de>,
121    {
122        Err(DeserializerError::Unsupported)
123    }
124
125    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value, Self::Error>
126    where
127        V: serde::de::Visitor<'de>,
128    {
129        let value = self.cursor.read_u8().map_err(DeserializerError::from)?;
130        visitor.visit_bool(value != 0)
131    }
132
133    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
134    where
135        V: serde::de::Visitor<'de>,
136    {
137        let value = self.cursor.read_i8().map_err(DeserializerError::from)?;
138        visitor.visit_i8(value)
139    }
140
141    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
142    where
143        V: serde::de::Visitor<'de>,
144    {
145        let value = self
146            .cursor
147            .read_i16::<B>()
148            .map_err(DeserializerError::from)?;
149
150        visitor.visit_i16(value)
151    }
152
153    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
154    where
155        V: serde::de::Visitor<'de>,
156    {
157        let value = self
158            .cursor
159            .read_i32::<B>()
160            .map_err(DeserializerError::from)?;
161
162        visitor.visit_i32(value)
163    }
164
165    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
166    where
167        V: serde::de::Visitor<'de>,
168    {
169        let value = self
170            .cursor
171            .read_i64::<B>()
172            .map_err(DeserializerError::from)?;
173
174        visitor.visit_i64(value)
175    }
176
177    fn deserialize_i128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
178    where
179        V: serde::de::Visitor<'de>,
180    {
181        let value = self
182            .cursor
183            .read_i128::<B>()
184            .map_err(DeserializerError::from)?;
185
186        visitor.visit_i128(value)
187    }
188
189    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
190    where
191        V: serde::de::Visitor<'de>,
192    {
193        let value = self.cursor.read_u8().map_err(DeserializerError::from)?;
194        visitor.visit_u8(value)
195    }
196
197    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
198    where
199        V: serde::de::Visitor<'de>,
200    {
201        let value = self
202            .cursor
203            .read_u16::<B>()
204            .map_err(DeserializerError::from)?;
205
206        visitor.visit_u16(value)
207    }
208
209    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
210    where
211        V: serde::de::Visitor<'de>,
212    {
213        let value = self
214            .cursor
215            .read_u32::<B>()
216            .map_err(DeserializerError::from)?;
217
218        visitor.visit_u32(value)
219    }
220
221    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
222    where
223        V: serde::de::Visitor<'de>,
224    {
225        let value = self
226            .cursor
227            .read_u64::<B>()
228            .map_err(DeserializerError::from)?;
229
230        visitor.visit_u64(value)
231    }
232
233    fn deserialize_u128<V>(self, visitor: V) -> Result<V::Value, Self::Error>
234    where
235        V: serde::de::Visitor<'de>,
236    {
237        let value = self
238            .cursor
239            .read_u128::<B>()
240            .map_err(DeserializerError::from)?;
241
242        visitor.visit_u128(value)
243    }
244
245    fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
246    where
247        V: serde::de::Visitor<'de>,
248    {
249        Err(DeserializerError::Unsupported)
250    }
251
252    fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
253    where
254        V: serde::de::Visitor<'de>,
255    {
256        Err(DeserializerError::Unsupported)
257    }
258
259    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
260    where
261        V: serde::de::Visitor<'de>,
262    {
263        Err(DeserializerError::Unsupported)
264    }
265
266    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
267    where
268        V: serde::de::Visitor<'de>,
269    {
270        let len = self.cursor.read_u8().map_err(DeserializerError::from)? as u64;
271
272        // We cannot use cursor read methods as they copy the data out of the internal buffer,
273        // where we actually want a pointer into that buffer. So instead, we take the internal
274        // representation (the underlying &[u8]) and slice it to get the data we want. We then
275        // advance the cursor to simulate the read.
276        //
277        // Note that we do the advance first because otherwise we run into a immutable->mutable
278        // borrow issue, but the reverse is fine.
279        self.cursor
280            .seek(SeekFrom::Current(len as i64))
281            .map_err(DeserializerError::from)?;
282
283        let buf = {
284            let buf = self.cursor.get_ref();
285            buf[(self.cursor.position() - len) as usize..]
286                .get(..len as usize)
287                .ok_or(DeserializerError::Eof)?
288        };
289
290        visitor.visit_borrowed_str(std::str::from_utf8(buf).map_err(DeserializerError::from)?)
291    }
292
293    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
294    where
295        V: serde::de::Visitor<'de>,
296    {
297        self.deserialize_str(visitor)
298    }
299
300    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
301    where
302        V: serde::de::Visitor<'de>,
303    {
304        let len = self.cursor.read_u8().map_err(DeserializerError::from)? as u64;
305
306        // See comment in deserialize_str for an explanation of this code.
307        self.cursor
308            .seek(SeekFrom::Current(len as i64))
309            .map_err(DeserializerError::from)?;
310
311        let buf = {
312            let buf = self.cursor.get_ref();
313            buf[(self.cursor.position() - len) as usize..]
314                .get(..len as usize)
315                .ok_or(DeserializerError::Eof)?
316        };
317
318        visitor.visit_borrowed_bytes(buf)
319    }
320
321    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
322    where
323        V: serde::de::Visitor<'de>,
324    {
325        self.deserialize_bytes(visitor)
326    }
327
328    fn deserialize_option<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
329    where
330        V: serde::de::Visitor<'de>,
331    {
332        Err(DeserializerError::Unsupported)
333    }
334
335    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
336    where
337        V: serde::de::Visitor<'de>,
338    {
339        visitor.visit_unit()
340    }
341
342    fn deserialize_unit_struct<V>(
343        self,
344        _name: &'static str,
345        visitor: V,
346    ) -> Result<V::Value, Self::Error>
347    where
348        V: serde::de::Visitor<'de>,
349    {
350        visitor.visit_unit()
351    }
352
353    fn deserialize_newtype_struct<V>(
354        self,
355        _name: &'static str,
356        visitor: V,
357    ) -> Result<V::Value, Self::Error>
358    where
359        V: serde::de::Visitor<'de>,
360    {
361        visitor.visit_newtype_struct(self)
362    }
363
364    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
365    where
366        V: serde::de::Visitor<'de>,
367    {
368        let len = self.cursor.read_u8().map_err(DeserializerError::from)? as usize;
369        visitor.visit_seq(SequenceIterator::new(self, len))
370    }
371
372    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
373    where
374        V: serde::de::Visitor<'de>,
375    {
376        visitor.visit_seq(SequenceIterator::new(self, len))
377    }
378
379    fn deserialize_tuple_struct<V>(
380        self,
381        _name: &'static str,
382        len: usize,
383        visitor: V,
384    ) -> Result<V::Value, Self::Error>
385    where
386        V: serde::de::Visitor<'de>,
387    {
388        visitor.visit_seq(SequenceIterator::new(self, len))
389    }
390
391    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
392    where
393        V: serde::de::Visitor<'de>,
394    {
395        let len = self.cursor.read_u8().map_err(DeserializerError::from)? as usize;
396        visitor.visit_map(SequenceIterator::new(self, len))
397    }
398
399    fn deserialize_struct<V>(
400        self,
401        _name: &'static str,
402        fields: &'static [&'static str],
403        visitor: V,
404    ) -> Result<V::Value, Self::Error>
405    where
406        V: serde::de::Visitor<'de>,
407    {
408        visitor.visit_seq(SequenceIterator::new(self, fields.len()))
409    }
410
411    fn deserialize_enum<V>(
412        self,
413        _name: &'static str,
414        variants: &'static [&'static str],
415        visitor: V,
416    ) -> Result<V::Value, Self::Error>
417    where
418        V: serde::de::Visitor<'de>,
419    {
420        // We read the discriminator here so that we can make the expected enum variant available
421        // to the `visit_enum` call.
422        let variant = self.cursor.read_u8().map_err(DeserializerError::from)?;
423        if variant >= variants.len() as u8 {
424            return Err(DeserializerError::InvalidEnumVariant);
425        }
426
427        visitor.visit_enum(Enum { de: self, variant })
428    }
429
430    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
431    where
432        V: serde::de::Visitor<'de>,
433    {
434        Err(DeserializerError::Unsupported)
435    }
436}
437
438impl<'de, 'a, B: ByteOrder> VariantAccess<'de> for &'a mut Deserializer<'de, B> {
439    type Error = DeserializerError;
440
441    fn unit_variant(self) -> Result<(), Self::Error> {
442        Ok(())
443    }
444
445    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
446    where
447        T: serde::de::DeserializeSeed<'de>,
448    {
449        seed.deserialize(self)
450    }
451
452    fn tuple_variant<V>(self, len: usize, visitor: V) -> Result<V::Value, Self::Error>
453    where
454        V: serde::de::Visitor<'de>,
455    {
456        visitor.visit_seq(SequenceIterator::new(self, len))
457    }
458
459    fn struct_variant<V>(
460        self,
461        fields: &'static [&'static str],
462        visitor: V,
463    ) -> Result<V::Value, Self::Error>
464    where
465        V: serde::de::Visitor<'de>,
466    {
467        visitor.visit_seq(SequenceIterator::new(self, fields.len()))
468    }
469}
470
471struct SequenceIterator<'de, 'a, B: ByteOrder> {
472    de: &'a mut Deserializer<'de, B>,
473    len: usize,
474}
475
476impl<'de, 'a, B: ByteOrder> SequenceIterator<'de, 'a, B> {
477    fn new(de: &'a mut Deserializer<'de, B>, len: usize) -> Self {
478        Self { de, len }
479    }
480}
481
482impl<'de, 'a, B: ByteOrder> SeqAccess<'de> for SequenceIterator<'de, 'a, B> {
483    type Error = DeserializerError;
484
485    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
486    where
487        T: serde::de::DeserializeSeed<'de>,
488    {
489        if self.len == 0 {
490            return Ok(None);
491        }
492
493        self.len -= 1;
494        seed.deserialize(&mut *self.de).map(Some)
495    }
496
497    fn size_hint(&self) -> Option<usize> {
498        Some(self.len)
499    }
500}
501
502impl<'de, 'a, B: ByteOrder> MapAccess<'de> for SequenceIterator<'de, 'a, B> {
503    type Error = DeserializerError;
504
505    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
506    where
507        K: serde::de::DeserializeSeed<'de>,
508    {
509        if self.len == 0 {
510            return Ok(None);
511        }
512
513        self.len -= 1;
514        seed.deserialize(&mut *self.de).map(Some)
515    }
516
517    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
518    where
519        V: serde::de::DeserializeSeed<'de>,
520    {
521        seed.deserialize(&mut *self.de)
522    }
523
524    fn size_hint(&self) -> Option<usize> {
525        Some(self.len)
526    }
527}
528
529struct Enum<'de, 'a, B: ByteOrder> {
530    de: &'a mut Deserializer<'de, B>,
531    variant: u8,
532}
533
534impl<'de, 'a, B: ByteOrder> EnumAccess<'de> for Enum<'de, 'a, B> {
535    type Error = DeserializerError;
536    type Variant = &'a mut Deserializer<'de, B>;
537
538    fn variant_seed<V>(self, _: V) -> Result<(V::Value, Self::Variant), Self::Error>
539    where
540        V: serde::de::DeserializeSeed<'de>,
541    {
542        // When serializing/deserializing, serde passes a variant_index into the handlers. We
543        // currently write these as u8's and have already parsed' them during deserialize_enum
544        // before we reach this point.
545        //
546        // Normally, when deserializing enum tags from a wire format that does not match the
547        // expected size, we would use a u*.into_deserializer() to feed the already parsed
548        // result into the visit_u64 visitor method during `__Field` deserialize.
549        //
550        // The problem with this however is during `visit_u64`, there is a possibility the
551        // enum variant is not valid, which triggers Serde to return an `Unexpected` error.
552        // These errors have the unfortunate side effect of triggering Rust including float
553        // operations in the resulting binary, which breaks WASM environments.
554        //
555        // To work around this, we rely on the following facts:
556        //
557        // - variant_index in Serde is always 0 indexed and contiguous
558        // - transmute_copy into a 0 sized type is safe
559        // - transmute_copy is safe to cast into __Field as long as u8 >= size_of::<__Field>()
560        //
561        // This behaviour relies on serde not changing its enum deserializer generation, but
562        // this would be a major backwards compatibility break for them so we should be safe.
563        require!(
564            size_of::<u8>() >= size_of::<V::Value>(),
565            DeserializerError::InvalidEnumVariant
566        );
567
568        Ok((
569            unsafe { std::mem::transmute_copy::<u8, V::Value>(&self.variant) },
570            self.de,
571        ))
572    }
573}