fuel_types/
canonical.rs

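//! Canonical serialization and deserialization of Fuel types.
//!
//! This module provides the `Serialize` and `Deserialize` traits (plus the derive
//! macros of the same names, re-exported from `fuel_derive`) used to encode Fuel
//! types into their canonical binary form, in which every field is padded to an
//! 8-byte boundary, and to decode them back.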
4//! allow for automatic serialization and deserialization of Fuel types.
5
6#![allow(unsafe_code)]
7
8#[cfg(feature = "alloc")]
9use alloc::{
10    vec,
11    vec::Vec,
12};
13use core::fmt;
14
15use core::mem::MaybeUninit;
16pub use fuel_derive::{
17    Deserialize,
18    Serialize,
19};
20
/// Error when serializing or deserializing.
#[derive(Debug, Eq, PartialEq)]
#[non_exhaustive]
pub enum Error {
    /// The buffer is too short for writing or reading.
    BufferIsTooShort,
    /// Got an unknown enum discriminant.
    UnknownDiscriminant,
    /// Struct prefix (set with `#[canonical(prefix = ...)]`) was invalid.
    InvalidPrefix,
    /// Allocation too large to be correct.
    AllocationLimit,
    /// Unknown error, carrying a static description of what went wrong.
    Unknown(&'static str),
}

impl Error {
    /// Returns a static, human-readable description of the error.
    ///
    /// For [`Error::Unknown`] this is the message carried by the variant.
    pub(crate) fn as_str(&self) -> &'static str {
        match self {
            Error::BufferIsTooShort => "buffer is too short",
            Error::UnknownDiscriminant => "unknown discriminant",
            Error::InvalidPrefix => {
                "prefix set with #[canonical(prefix = ...)] was invalid"
            }
            Error::AllocationLimit => "allocation too large",
            Error::Unknown(msg) => msg,
        }
    }
}

impl fmt::Display for Error {
    /// Shows a human-readable description of the `Error` (same text as `as_str`).
    fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt.write_str(self.as_str())
    }
}
57
58/// Allows writing of data.
59pub trait Output {
60    /// Write bytes to the output buffer.
61    fn write(&mut self, bytes: &[u8]) -> Result<(), Error>;
62
63    /// Write a single byte to the output buffer.
64    fn push_byte(&mut self, byte: u8) -> Result<(), Error> {
65        self.write(&[byte])
66    }
67}
68
69/// Allows serialize the type into the `Output`.
70/// https://github.com/FuelLabs/fuel-specs/blob/master/specs/protocol/tx_format.md#transaction
71pub trait Serialize {
72    /// !INTERNAL USAGE ONLY!
73    /// Array of bytes that are now aligned by themselves.
74    #[doc(hidden)]
75    const UNALIGNED_BYTES: bool = false;
76
77    /// Size of the static part of the serialized object, in bytes.
78    /// Saturates to usize::MAX on overflow.
79    fn size_static(&self) -> usize;
80
81    /// Size of the dynamic part, in bytes.
82    /// Saturates to usize::MAX on overflow.
83    fn size_dynamic(&self) -> usize;
84
85    /// Total size of the serialized object, in bytes.
86    /// Saturates to usize::MAX on overflow.
87    fn size(&self) -> usize {
88        self.size_static().saturating_add(self.size_dynamic())
89    }
90
91    /// Encodes `Self` into the `buffer`.
92    ///
93    /// It is better to not implement this function directly, instead implement
94    /// `encode_static` and `encode_dynamic`.
95    fn encode<O: Output + ?Sized>(&self, buffer: &mut O) -> Result<(), Error> {
96        self.encode_static(buffer)?;
97        self.encode_dynamic(buffer)
98    }
99
100    /// Encodes staticly-sized part of `Self`.
101    fn encode_static<O: Output + ?Sized>(&self, buffer: &mut O) -> Result<(), Error>;
102
103    /// Encodes dynamically-sized part of `Self`.
104    /// The default implementation does nothing. Dynamically-sized contains should
105    /// override this.
106    fn encode_dynamic<O: Output + ?Sized>(&self, _buffer: &mut O) -> Result<(), Error> {
107        Ok(())
108    }
109
110    /// Encodes `Self` into bytes vector. Required known size.
111    #[cfg(feature = "alloc")]
112    fn to_bytes(&self) -> Vec<u8> {
113        let mut vec = Vec::with_capacity(self.size());
114        self.encode(&mut vec).expect("Unable to encode self");
115        vec
116    }
117}
118
119/// Allows reading of data into a slice.
120pub trait Input {
121    /// Returns the remaining length of the input data.
122    fn remaining(&mut self) -> usize;
123
124    /// Peek the exact number of bytes required to fill the given buffer.
125    fn peek(&self, buf: &mut [u8]) -> Result<(), Error>;
126
127    /// Read the exact number of bytes required to fill the given buffer.
128    fn read(&mut self, buf: &mut [u8]) -> Result<(), Error>;
129
130    /// Peek a single byte from the input.
131    fn peek_byte(&mut self) -> Result<u8, Error> {
132        let mut buf = [0u8];
133        self.peek(&mut buf[..])?;
134        Ok(buf[0])
135    }
136
137    /// Read a single byte from the input.
138    fn read_byte(&mut self) -> Result<u8, Error> {
139        let mut buf = [0u8];
140        self.read(&mut buf[..])?;
141        Ok(buf[0])
142    }
143
144    /// Skips next `n` bytes.
145    fn skip(&mut self, n: usize) -> Result<(), Error>;
146}
147
148/// Allows deserialize the type from the `Input`.
149/// https://github.com/FuelLabs/fuel-specs/blob/master/specs/protocol/tx_format.md#transaction
150pub trait Deserialize: Sized {
151    /// !INTERNAL USAGE ONLY!
152    /// Array of bytes that are now aligned by themselves.
153    #[doc(hidden)]
154    const UNALIGNED_BYTES: bool = false;
155
156    /// Decodes `Self` from the `buffer`.
157    ///
158    /// It is better to not implement this function directly, instead implement
159    /// `decode_static` and `decode_dynamic`.
160    fn decode<I: Input + ?Sized>(buffer: &mut I) -> Result<Self, Error> {
161        let mut object = Self::decode_static(buffer)?;
162        object.decode_dynamic(buffer)?;
163        Ok(object)
164    }
165
166    /// Decodes static part of `Self` from the `buffer`.
167    fn decode_static<I: Input + ?Sized>(buffer: &mut I) -> Result<Self, Error>;
168
169    /// Decodes dynamic part of the information from the `buffer` to fill `Self`.
170    /// The default implementation does nothing. Dynamically-sized contains should
171    /// override this.
172    fn decode_dynamic<I: Input + ?Sized>(
173        &mut self,
174        _buffer: &mut I,
175    ) -> Result<(), Error> {
176        Ok(())
177    }
178
179    /// Helper method for deserializing `Self` from bytes.
180    fn from_bytes(mut buffer: &[u8]) -> Result<Self, Error> {
181        Self::decode(&mut buffer)
182    }
183}
184
/// Alignment, in bytes, that the data of every field is padded to (64 bits).
pub const ALIGN: usize = 8;

/// Number of zero bytes needed after `len` bytes to reach the next multiple of
/// [`ALIGN`]; zero when `len` is already aligned.
#[allow(clippy::arithmetic_side_effects)] // Cannot overflow: `len % ALIGN < ALIGN`.
const fn alignment_bytes(len: usize) -> usize {
    (ALIGN - len % ALIGN) % ALIGN
}

/// Rounds `len` up to the next multiple of [`ALIGN`], saturating at
/// `usize::MAX` instead of overflowing.
pub const fn aligned_size(len: usize) -> usize {
    match len.checked_add(alignment_bytes(len)) {
        Some(padded) => padded,
        None => usize::MAX,
    }
}
199
200macro_rules! impl_for_primitives {
201    ($t:ident, $unpadded:literal) => {
202        impl Serialize for $t {
203            const UNALIGNED_BYTES: bool = $unpadded;
204
205            #[inline(always)]
206            fn size_static(&self) -> usize {
207                aligned_size(::core::mem::size_of::<$t>())
208            }
209
210            #[inline(always)]
211            fn size_dynamic(&self) -> usize {
212                0
213            }
214
215            #[inline(always)]
216            fn encode_static<O: Output + ?Sized>(
217                &self,
218                buffer: &mut O,
219            ) -> Result<(), Error> {
220                // Primitive types are zero-padded on left side to a 8-byte boundary.
221                // The resulting value is always well-aligned.
222                let bytes = <$t>::to_be_bytes(*self);
223                for _ in 0..alignment_bytes(bytes.len()) {
224                    // Zero-pad
225                    buffer.push_byte(0)?;
226                }
227                buffer.write(bytes.as_ref())?;
228                Ok(())
229            }
230        }
231
232        impl Deserialize for $t {
233            const UNALIGNED_BYTES: bool = $unpadded;
234
235            fn decode_static<I: Input + ?Sized>(buffer: &mut I) -> Result<Self, Error> {
236                let mut asset = [0u8; ::core::mem::size_of::<$t>()];
237                buffer.skip(alignment_bytes(asset.len()))?; // Skip zero-padding
238                buffer.read(asset.as_mut())?;
239                Ok(<$t>::from_be_bytes(asset))
240            }
241        }
242    };
243}
244
245impl_for_primitives!(u8, true);
246impl_for_primitives!(u16, false);
247impl_for_primitives!(u32, false);
248impl_for_primitives!(usize, false);
249impl_for_primitives!(u64, false);
250impl_for_primitives!(u128, false);
251
252// Empty tuple `()`, i.e. the unit type takes up no space.
253impl Serialize for () {
254    fn size_static(&self) -> usize {
255        0
256    }
257
258    #[inline(always)]
259    fn size_dynamic(&self) -> usize {
260        0
261    }
262
263    #[inline(always)]
264    fn encode_static<O: Output + ?Sized>(&self, _buffer: &mut O) -> Result<(), Error> {
265        Ok(())
266    }
267}
268
269impl Deserialize for () {
270    fn decode_static<I: Input + ?Sized>(_buffer: &mut I) -> Result<Self, Error> {
271        Ok(())
272    }
273}
274
/// To protect against malicious large inputs, vector size is limited when decoding.
///
/// The limit counts *elements*, not bytes. A `Vec<T>` whose length prefix
/// exceeds it is rejected with [`Error::AllocationLimit`], and the `Vec` encoder
/// refuses to produce such a prefix. For `Vec<u8>` this corresponds to 100 MiB of
/// payload. For wider element types the memory reserved up front is
/// proportionally larger (`count * size_of::<T>()`).
pub const VEC_DECODE_LIMIT: usize = 100 * (1 << 20); // 100 Mi elements (= 100 MiB for Vec<u8>)
277
#[cfg(feature = "alloc")]
impl<T: Serialize> Serialize for Vec<T> {
    /// The static part is the element count, encoded as a `u64` (8 bytes).
    fn size_static(&self) -> usize {
        8
    }

    /// Size of the encoded elements, padded to `ALIGN`.
    ///
    /// Byte vectors are packed (one byte per element). Other vectors sum the full
    /// encoded size of every element.
    #[inline(always)]
    fn size_dynamic(&self) -> usize {
        if T::UNALIGNED_BYTES {
            aligned_size(self.len())
        } else {
            aligned_size(
                self.iter()
                    .map(|e| e.size())
                    .reduce(usize::saturating_add)
                    .unwrap_or_default(),
            )
        }
    }

    /// Encodes only the element count; the elements themselves are written by
    /// `encode_dynamic`.
    ///
    /// # Errors
    /// [`Error::AllocationLimit`] if the vector has more than `VEC_DECODE_LIMIT`
    /// elements, so that everything encoded here can also be decoded.
    #[inline(always)]
    fn encode_static<O: Output + ?Sized>(&self, buffer: &mut O) -> Result<(), Error> {
        if self.len() > VEC_DECODE_LIMIT {
            return Err(Error::AllocationLimit)
        }
        // Cannot fail after the check above; mapped to an error instead of panicking.
        let len: u64 = self.len().try_into().map_err(|_| Error::AllocationLimit)?;
        len.encode(buffer)
    }

    /// Encodes the elements.
    ///
    /// Byte vectors (`T::UNALIGNED_BYTES`) are written as one contiguous run
    /// followed by a single zero padding up to `ALIGN`. Every other element is
    /// encoded in full (static then dynamic part), in order.
    fn encode_dynamic<O: Output + ?Sized>(&self, buffer: &mut O) -> Result<(), Error> {
        if T::UNALIGNED_BYTES {
            assert!(
                ::core::mem::size_of::<T>() == 1,
                "UNALIGNED_BYTES is only supported for 1-byte element types"
            );
            // SAFETY: `UNALIGNED_BYTES` is only `true` for `u8`, so `T` is `u8`
            // (its 1-byte size is asserted above). The pointer and length come from
            // this live vector, which is borrowed for the lifetime of `bytes`.
            let bytes = unsafe {
                ::core::slice::from_raw_parts(self.as_ptr().cast::<u8>(), self.len())
            };
            buffer.write(bytes)?;
            for _ in 0..alignment_bytes(self.len()) {
                buffer.push_byte(0)?;
            }
        } else {
            for e in self.iter() {
                e.encode(buffer)?;
            }
        }
        Ok(())
    }
}
327
#[cfg(feature = "alloc")]
impl<T: Deserialize> Deserialize for Vec<T> {
    /// Decodes the element count (the static part) and returns a vector that
    /// carries that count to `decode_dynamic`.
    ///
    /// Byte vectors (`T::UNALIGNED_BYTES`) are returned zero-filled with
    /// `len() == count`, ready to be overwritten in place. All other vectors are
    /// returned empty from `Vec::with_capacity(count)`, and `decode_dynamic` then
    /// decodes `capacity()` elements.
    ///
    /// # Errors
    /// - [`Error::AllocationLimit`] if the count does not fit in `usize` or exceeds
    ///   `VEC_DECODE_LIMIT`.
    /// - [`Error::Unknown`] if `T` is zero-sized. A `Vec` of zero-sized elements
    ///   always reports `capacity() == usize::MAX`, so the count cannot be carried,
    ///   and `decode_dynamic` would otherwise try to decode `usize::MAX` elements.
    fn decode_static<I: Input + ?Sized>(buffer: &mut I) -> Result<Self, Error> {
        let cap = u64::decode(buffer)?;
        let cap: usize = cap.try_into().map_err(|_| Error::AllocationLimit)?;
        if cap > VEC_DECODE_LIMIT {
            return Err(Error::AllocationLimit)
        }

        if T::UNALIGNED_BYTES {
            assert!(
                ::core::mem::size_of::<T>() == 1,
                "UNALIGNED_BYTES is only supported for 1-byte element types"
            );
            let mut bytes = ::core::mem::ManuallyDrop::new(vec![0u8; cap]);
            // SAFETY: `UNALIGNED_BYTES` is only `true` for `u8`, so `T` is `u8`
            // (size 1 asserted above, hence alignment 1). Pointer, length and
            // capacity come from a live `Vec<u8>` whose ownership moves into the
            // new vector; `ManuallyDrop` prevents a double free.
            let vec = unsafe {
                Vec::from_raw_parts(
                    bytes.as_mut_ptr().cast::<T>(),
                    bytes.len(),
                    bytes.capacity(),
                )
            };
            Ok(vec)
        } else if ::core::mem::size_of::<T>() == 0 {
            Err(Error::Unknown("vector of zero-sized elements cannot be decoded"))
        } else {
            // NOTE(review): this reserves `cap * size_of::<T>()` bytes before any
            // element is read.
            Ok(Vec::with_capacity(cap))
        }
    }

    /// Decodes the elements announced by `decode_static`.
    ///
    /// Byte vectors are filled with one contiguous read, then the zero padding up
    /// to `ALIGN` is skipped. Other vectors decode `capacity()` elements one by one,
    /// each with its own static and dynamic part.
    fn decode_dynamic<I: Input + ?Sized>(&mut self, buffer: &mut I) -> Result<(), Error> {
        if T::UNALIGNED_BYTES {
            assert!(
                ::core::mem::size_of::<T>() == 1,
                "UNALIGNED_BYTES is only supported for 1-byte element types"
            );
            let len = self.len();
            // SAFETY: `T` is `u8` (see `decode_static`), and the slice covers
            // exactly the `len` initialised elements of this vector, which is not
            // otherwise accessed while `bytes` is alive.
            let bytes = unsafe {
                ::core::slice::from_raw_parts_mut(self.as_mut_ptr().cast::<u8>(), len)
            };
            buffer.read(bytes)?;
            // Pad based on `len()`, which is exactly the decoded count. A vector's
            // capacity is only guaranteed to be *at least* the requested size.
            buffer.skip(alignment_bytes(len))?;
        } else {
            for _ in 0..self.capacity() {
                self.push(T::decode(buffer)?);
            }
        }

        Ok(())
    }
}
373
374impl<const N: usize, T: Serialize> Serialize for [T; N] {
375    fn size_static(&self) -> usize {
376        if T::UNALIGNED_BYTES {
377            aligned_size(N)
378        } else {
379            aligned_size(
380                self.iter()
381                    .map(|e| e.size_static())
382                    .reduce(usize::saturating_add)
383                    .unwrap_or_default(),
384            )
385        }
386    }
387
388    #[inline(always)]
389    fn size_dynamic(&self) -> usize {
390        if T::UNALIGNED_BYTES {
391            0
392        } else {
393            aligned_size(
394                self.iter()
395                    .map(|e| e.size_dynamic())
396                    .reduce(usize::saturating_add)
397                    .unwrap_or_default(),
398            )
399        }
400    }
401
402    #[inline(always)]
403    fn encode_static<O: Output + ?Sized>(&self, buffer: &mut O) -> Result<(), Error> {
404        // Bytes - [u8; N] it a separate case without padding for each element.
405        // It should padded at the end if is not % ALIGN
406        if T::UNALIGNED_BYTES {
407            // SAFETY: `Type::U8` implemented only for `u8`.
408            let bytes = unsafe { ::core::mem::transmute::<&[T; N], &[u8; N]>(self) };
409            buffer.write(bytes.as_slice())?;
410            for _ in 0..alignment_bytes(N) {
411                buffer.push_byte(0)?;
412            }
413        } else {
414            for e in self.iter() {
415                e.encode_static(buffer)?;
416            }
417        }
418        Ok(())
419    }
420
421    fn encode_dynamic<O: Output + ?Sized>(&self, buffer: &mut O) -> Result<(), Error> {
422        if !T::UNALIGNED_BYTES {
423            for e in self.iter() {
424                e.encode_dynamic(buffer)?;
425            }
426        }
427
428        Ok(())
429    }
430}
431
impl<const N: usize, T: Deserialize> Deserialize for [T; N] {
    /// Decodes all `N` elements from the static part of the buffer.
    ///
    /// Byte arrays (`T::UNALIGNED_BYTES`) are read as one contiguous run of `N`
    /// bytes, followed by skipping the padding up to `ALIGN`. Any other element
    /// type is decoded element by element with `T::decode_static`, each element
    /// consuming its own padding. If an element fails to decode, the elements
    /// decoded so far are dropped and the error is returned.
    fn decode_static<I: Input + ?Sized>(buffer: &mut I) -> Result<Self, Error> {
        if T::UNALIGNED_BYTES {
            let mut bytes: [u8; N] = [0; N];
            buffer.read(bytes.as_mut())?;
            // The trailing padding is skipped, not checked to be zero.
            buffer.skip(alignment_bytes(N))?;
            // SAFETY: `UNALIGNED_BYTES` is only set to `true` by the `u8` impl in
            // this module, so `[T; N]` is `[u8; N]`. The transmute is an identity
            // reinterpretation, and `ptr::read` copies the array out of `bytes`.
            // NOTE(review): nothing in the type system enforces this. Any impl
            // that sets `UNALIGNED_BYTES = true` for a type other than a 1-byte
            // plain integer would make this undefined behaviour.
            let ref_typed: &[T; N] = unsafe { core::mem::transmute(&bytes) };
            let typed: [T; N] = unsafe { core::ptr::read(ref_typed) };
            Ok(typed)
        } else {
            // The spec doesn't say how to deserialize arrays of padded primitives
            // (such as `u16`, `u32`, `usize`), so each element keeps its own
            // padding and is decoded individually.
            // SAFETY: `uninit` is an array of `MaybeUninit`, which does not
            // require initialization, so `assume_init` on the outer
            // `MaybeUninit` is sound.
            let mut uninit: [MaybeUninit<T>; N] =
                unsafe { MaybeUninit::uninit().assume_init() };
            for i in 0..N {
                match T::decode_static(buffer) {
                    Err(e) => {
                        // Drop the elements decoded so far so they are not leaked.
                        for item in uninit.iter_mut().take(i) {
                            // SAFETY: all elements at indices `0..i` have been
                            // initialised by earlier iterations.
                            unsafe {
                                item.assume_init_drop();
                            }
                        }
                        return Err(e)
                    }
                    Ok(decoded) => {
                        // `uninit[i]` has not been written yet, so `write` neither
                        // drops nor leaks a previous value.
                        uninit[i].write(decoded);
                    }
                }
            }

            // SAFETY: the loop above either returned early or initialised all
            // `N` elements.
            let init = uninit.map(|v| unsafe { v.assume_init() });
            Ok(init)
        }
    }

    /// Decodes the dynamic part of every element, in order. Byte arrays have
    /// no dynamic part.
    fn decode_dynamic<I: Input + ?Sized>(&mut self, buffer: &mut I) -> Result<(), Error> {
        if !T::UNALIGNED_BYTES {
            for e in self.iter_mut() {
                e.decode_dynamic(buffer)?;
            }
        }

        Ok(())
    }
}
489
#[cfg(feature = "alloc")]
impl Output for Vec<u8> {
    /// Appends `bytes` to the end of the vector. This implementation never
    /// reports an error.
    fn write(&mut self, bytes: &[u8]) -> Result<(), Error> {
        self.extend(bytes);
        Ok(())
    }
}
497
498impl Output for &'_ mut [u8] {
499    fn write(&mut self, from: &[u8]) -> Result<(), Error> {
500        if from.len() > self.len() {
501            return Err(Error::BufferIsTooShort)
502        }
503        let len = from.len();
504        self[..len].copy_from_slice(from);
505        // We need to reduce the inner slice by `len`, because we already filled them.
506        let reduced = &mut self[len..];
507
508        // Compiler is not clever enough to allow it.
509        // https://stackoverflow.com/questions/25730586/how-can-i-create-my-own-data-structure-with-an-iterator-that-returns-mutable-ref
510        *self = unsafe { &mut *(reduced as *mut [u8]) };
511        Ok(())
512    }
513}
514
515impl Input for &'_ [u8] {
516    fn remaining(&mut self) -> usize {
517        self.len()
518    }
519
520    fn peek(&self, into: &mut [u8]) -> Result<(), Error> {
521        if into.len() > self.len() {
522            return Err(Error::BufferIsTooShort)
523        }
524
525        let len = into.len();
526        into.copy_from_slice(&self[..len]);
527        Ok(())
528    }
529
530    fn read(&mut self, into: &mut [u8]) -> Result<(), Error> {
531        if into.len() > self.len() {
532            return Err(Error::BufferIsTooShort)
533        }
534
535        let len = into.len();
536        into.copy_from_slice(&self[..len]);
537        *self = &self[len..];
538        Ok(())
539    }
540
541    fn skip(&mut self, n: usize) -> Result<(), Error> {
542        if n > self.len() {
543            return Err(Error::BufferIsTooShort)
544        }
545
546        *self = &self[n..];
547        Ok(())
548    }
549}
550
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips `t` through `to_bytes`/`from_bytes`. Asserts that the decoded
    /// value equals `t`, that re-encoding it gives the same bytes, and that
    /// `encode_static` writes exactly `size_static()` bytes.
    fn validate<T: Serialize + Deserialize + Eq + core::fmt::Debug>(t: T) {
        let bytes = t.to_bytes();
        let t2 = T::from_bytes(&bytes).expect("Roundtrip failed");
        assert_eq!(t, t2);
        assert_eq!(t.to_bytes(), t2.to_bytes());

        let mut vec = Vec::new();
        t.encode_static(&mut vec).expect("Encode failed");
        assert_eq!(vec.len(), t.size_static());
    }

    /// Like `validate`, but also checks the dynamic part. Static followed by
    /// dynamic encoding must be `size()` bytes long, and `encode_dynamic` on its
    /// own must write exactly `size_dynamic()` bytes.
    fn validate_enum<T: Serialize + Deserialize + Eq + fmt::Debug>(t: T) {
        let bytes = t.to_bytes();
        let t2 = T::from_bytes(&bytes).expect("Roundtrip failed");
        assert_eq!(t, t2);
        assert_eq!(t.to_bytes(), t2.to_bytes());

        let mut vec = Vec::new();
        t.encode_static(&mut vec).expect("Encode failed");
        assert_eq!(vec.len(), t.size_static());
        t.encode_dynamic(&mut vec).expect("Encode failed");
        assert_eq!(vec.len(), t.size());

        let mut vec2 = Vec::new();
        t.encode_dynamic(&mut vec2).expect("Encode failed");
        assert_eq!(vec2.len(), t.size_dynamic());
    }

    #[test]
    fn test_canonical_encode_decode() {
        // Unit type and every primitive, including boundary values.
        validate(());
        validate(123u8);
        validate(u8::MAX);
        validate(123u16);
        validate(u16::MAX);
        validate(123u32);
        validate(u32::MAX);
        validate(123u64);
        validate(u64::MAX);
        validate(123u128);
        validate(u128::MAX);
        // Vectors: empty, one and two elements, for the packed byte case and
        // for per-element padded integer types.
        validate(Vec::<u8>::new());
        validate(Vec::<u16>::new());
        validate(Vec::<u32>::new());
        validate(Vec::<u64>::new());
        validate(Vec::<u128>::new());
        validate(vec![1u8]);
        validate(vec![1u16]);
        validate(vec![1u32]);
        validate(vec![1u64]);
        validate(vec![1u128]);
        validate(vec![1u8, 2u8]);
        validate(vec![1u16, 2u16]);
        validate(vec![1u32, 2u32]);
        validate(vec![1u64, 2u64]);
        validate(vec![1u128, 2u128]);

        // Derived struct of primitives: `u8` and `u16` are each padded to one
        // 8-byte word, giving 16 static bytes and no dynamic part.
        #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
        struct TestStruct1 {
            a: u8,
            b: u16,
        }

        let t = TestStruct1 { a: 123, b: 456 };
        assert_eq!(t.size_static(), 16);
        assert_eq!(t.size(), 16);
        validate(t);

        // Derived struct mixing primitives, a vector and arrays (including an
        // empty array and arrays of padded primitives).
        #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
        struct TestStruct2 {
            a: u8,
            v: Vec<u8>,
            b: u16,
            arr0: [u8; 0],
            arr1: [u8; 2],
            arr2: [u16; 3],
            arr3: [u64; 4],
        }

        validate(TestStruct2 {
            a: 123,
            v: vec![1, 2, 3],
            b: 456,
            arr0: [],
            arr1: [1, 2],
            arr2: [1, 2, u16::MAX],
            arr3: [0, 3, 1111, u64::MAX],
        });

        // A 64-byte array is already aligned, so no padding is added.
        #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
        #[repr(transparent)]
        struct TestStruct3([u8; 64]);

        let t = TestStruct3([1; 64]);
        assert_eq!(t.size_static(), 64);
        assert_eq!(t.size(), 64);
        validate(t);

        // Struct with a `#[canonical(prefix = ...)]` marker.
        #[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
        #[canonical(prefix = 1u64)]
        struct Prefixed1 {
            a: [u8; 3],
            b: Vec<u8>,
        }
        validate(Prefixed1 {
            a: [1, 2, 3],
            b: vec![4, 5, 6],
        });

        // Field-less enum with explicit and implicit discriminants.
        #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
        #[repr(u8)]
        enum TestEnum1 {
            A,
            B,
            C = 0x13,
            D,
        }

        validate(TestEnum1::A);
        validate(TestEnum1::B);
        validate(TestEnum1::C);
        validate(TestEnum1::D);

        // Enum with data, including a variant that has a dynamic part.
        #[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
        enum TestEnum2 {
            A(u8),
            B([u8; 3]),
            C(Vec<u8>),
        }

        validate_enum(TestEnum2::A(2));
        validate_enum(TestEnum2::B([1, 2, 3]));
        validate_enum(TestEnum2::C(vec![1, 2, 3]));

        #[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq)]
        #[canonical(prefix = 2u64)]
        struct Prefixed2(u16);
        validate(Prefixed2(u16::MAX));

        // The prefix is encoded first, as a big-endian u64.
        assert_eq!(
            &Prefixed1 {
                a: [1, 2, 3],
                b: vec![4, 5]
            }
            .to_bytes()[..8],
            &[0u8, 0, 0, 0, 0, 0, 0, 1]
        );
        // Prefix word followed by the u16 left-padded with zeros to 8 bytes.
        assert_eq!(
            Prefixed2(u16::MAX).to_bytes(),
            [0u8, 0, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0xff, 0xff]
        );
    }
}