pythnet_sdk/wire/
array.rs

//! By default, serde only implements `Serialize`/`Deserialize` for fixed-length arrays of up
//! to 32 elements. This module relies on const generics to (de)serialize arrays of any length
//! `N`, encoding them as fixed-size tuples.
//!
//! Usage:
//!
//! ```rust,ignore
//! #[derive(Serialize)]
//! struct Example {
//!     #[serde(with = "array")]
//!     array: [u8; 55],
//! }
//! ```
14use {
15    serde::{
16        Deserialize,
17        Serialize,
18        Serializer,
19    },
20    std::mem::MaybeUninit,
21};
22
23/// Serialize an array of size N using a const generic parameter to drive serialize_seq.
24pub fn serialize<S, T, const N: usize>(array: &[T; N], serializer: S) -> Result<S::Ok, S::Error>
25where
26    S: Serializer,
27    T: Serialize,
28{
29    use serde::ser::SerializeTuple;
30    let mut seq = serializer.serialize_tuple(N)?;
31    array.iter().try_for_each(|e| seq.serialize_element(e))?;
32    seq.end()
33}
34
/// Zero-sized visitor used by `deserialize`.
///
/// It stores no data at runtime. The element type `T` (via `PhantomData`) and the
/// expected array length `N` live purely at the type level, so the visitor knows
/// how many elements to pull.
struct ArrayVisitor<T, const N: usize> {
    _marker: ::core::marker::PhantomData<T>,
}
40
41/// Implement a Visitor over our ArrayVisitor that knows how many times to
42/// call next_element using the generic.
43impl<'de, T, const N: usize> serde::de::Visitor<'de> for ArrayVisitor<T, N>
44where
45    T: Deserialize<'de>,
46{
47    type Value = [T; N];
48
49    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
50        write!(formatter, "an array of length {N}")
51    }
52
53    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
54    where
55        A: serde::de::SeqAccess<'de>,
56    {
57        // We use MaybeUninit to allocate the right amount of memory
58        // because we do not know if `T` has a constructor or a default.
59        // Without this we would have to allocate a Vec.
60        let mut array = MaybeUninit::<[T; N]>::uninit();
61        let ptr = array.as_mut_ptr() as *mut T;
62        let mut pos = 0;
63        while pos < N {
64            let next = seq
65                .next_element()?
66                .ok_or_else(|| serde::de::Error::invalid_length(pos, &self))?;
67
68            unsafe {
69                std::ptr::write(ptr.add(pos), next);
70            }
71
72            pos += 1;
73        }
74
75        // We only succeed if we fully filled the array. This prevents
76        // accidentally returning garbage.
77        if pos == N {
78            return Ok(unsafe { array.assume_init() });
79        }
80
81        Err(serde::de::Error::invalid_length(pos, &self))
82    }
83}
84
85/// Deserialize an array with an ArrayVisitor aware of `N` during deserialize.
86pub fn deserialize<'de, D, T, const N: usize>(deserializer: D) -> Result<[T; N], D::Error>
87where
88    D: serde::Deserializer<'de>,
89    T: serde::de::Deserialize<'de>,
90{
91    deserializer.deserialize_tuple(
92        N,
93        ArrayVisitor {
94            _marker: std::marker::PhantomData,
95        },
96    )
97}