pythnet_sdk/wire/
array.rs

//! By default, serde only knows how to (de)serialize fixed-length arrays of
//! small sizes (its built-in implementations stop at a length of 32). This
//! module provides `serialize`/`deserialize` functions, driven by const
//! generics, that handle arrays of any length `N`.
4//!
5//! Usage:
6//!
//! ```rust,ignore
8//! #[derive(Serialize)]
9//! struct Example {
10//!     #[serde(with = "array")]
11//!     array: [u8; 55],
12//! }
13//! ```
14use {
15    serde::{Deserialize, Serialize, Serializer},
16    std::mem::MaybeUninit,
17};
18
19/// Serialize an array of size N using a const generic parameter to drive serialize_seq.
20pub fn serialize<S, T, const N: usize>(array: &[T; N], serializer: S) -> Result<S::Ok, S::Error>
21where
22    S: Serializer,
23    T: Serialize,
24{
25    use serde::ser::SerializeTuple;
26    let mut seq = serializer.serialize_tuple(N)?;
27    array.iter().try_for_each(|e| seq.serialize_element(e))?;
28    seq.end()
29}
30
/// Zero-sized serde visitor for `[T; N]`.
///
/// It stores no data: the element type `T` is tracked through `PhantomData`,
/// and the target length is carried by the const generic `N`.
struct ArrayVisitor<T, const N: usize> {
    _marker: ::std::marker::PhantomData<T>,
}
36
37/// Implement a Visitor over our ArrayVisitor that knows how many times to
38/// call next_element using the generic.
39impl<'de, T, const N: usize> serde::de::Visitor<'de> for ArrayVisitor<T, N>
40where
41    T: Deserialize<'de>,
42{
43    type Value = [T; N];
44
45    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
46        write!(formatter, "an array of length {N}")
47    }
48
49    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
50    where
51        A: serde::de::SeqAccess<'de>,
52    {
53        // We use MaybeUninit to allocate the right amount of memory
54        // because we do not know if `T` has a constructor or a default.
55        // Without this we would have to allocate a Vec.
56        let mut array = MaybeUninit::<[T; N]>::uninit();
57        let ptr = array.as_mut_ptr() as *mut T;
58        let mut pos = 0;
59        while pos < N {
60            let next = seq
61                .next_element()?
62                .ok_or_else(|| serde::de::Error::invalid_length(pos, &self))?;
63
64            unsafe {
65                std::ptr::write(ptr.add(pos), next);
66            }
67
68            pos += 1;
69        }
70
71        // We only succeed if we fully filled the array. This prevents
72        // accidentally returning garbage.
73        if pos == N {
74            return Ok(unsafe { array.assume_init() });
75        }
76
77        Err(serde::de::Error::invalid_length(pos, &self))
78    }
79}
80
81/// Deserialize an array with an ArrayVisitor aware of `N` during deserialize.
82pub fn deserialize<'de, D, T, const N: usize>(deserializer: D) -> Result<[T; N], D::Error>
83where
84    D: serde::Deserializer<'de>,
85    T: serde::de::Deserialize<'de>,
86{
87    deserializer.deserialize_tuple(
88        N,
89        ArrayVisitor {
90            _marker: std::marker::PhantomData,
91        },
92    )
93}