Skip to main content

lib_q_stark_util/
array_serialization.rs

1use alloc::vec::Vec;
2use core::marker::PhantomData;
3
4use serde::de::{
5    SeqAccess,
6    Visitor,
7};
8use serde::ser::SerializeTuple;
9use serde::{
10    Deserialize,
11    Deserializer,
12    Serialize,
13    Serializer,
14};
15
16pub fn serialize<S: Serializer, T: Serialize, const N: usize>(
17    data: &[T; N],
18    ser: S,
19) -> Result<S::Ok, S::Error> {
20    let mut s = ser.serialize_tuple(N)?;
21    for item in data {
22        s.serialize_element(item)?;
23    }
24    s.end()
25}
26
27struct ArrayVisitor<T, const N: usize>(PhantomData<T>);
28
29impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
30where
31    T: Deserialize<'de>,
32{
33    type Value = [T; N];
34
35    fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
36        formatter.write_fmt(format_args!("an array of length {N}"))
37    }
38
39    #[inline]
40    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
41    where
42        A: SeqAccess<'de>,
43    {
44        let mut data = Vec::with_capacity(N);
45        for _ in 0..N {
46            match seq.next_element()? {
47                Some(val) => data.push(val),
48                None => return Err(serde::de::Error::invalid_length(N, &self)),
49            }
50        }
51        data.try_into().map_or_else(|_| unreachable!(), Ok)
52    }
53}
54pub fn deserialize<'de, D, T, const N: usize>(deserializer: D) -> Result<[T; N], D::Error>
55where
56    D: Deserializer<'de>,
57    T: Deserialize<'de>,
58{
59    deserializer.deserialize_tuple(N, ArrayVisitor::<T, N>(PhantomData))
60}
61
#[cfg(test)]
mod tests {
    use serde::{
        Deserialize,
        Serialize,
    };

    use super::*;

    /// Helper wrapper that routes a `[u32; N]` field through this module's
    /// `serialize`/`deserialize` hooks.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    #[serde(bound(serialize = "", deserialize = ""))]
    struct Wrapper<const N: usize> {
        #[serde(serialize_with = "serialize", deserialize_with = "deserialize")]
        arr: [u32; N],
    }

    /// Helper wrapper with a non-`Copy` element type, to exercise owned elements.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    struct Names {
        #[serde(serialize_with = "serialize", deserialize_with = "deserialize")]
        names: [String; 2],
    }

    #[test]
    fn test_array_serde_roundtrip() {
        let original = Wrapper::<3> { arr: [10, 20, 30] };

        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"arr":[10,20,30]}"#);

        let deserialized: Wrapper<3> = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, original);
    }

    #[test]
    fn test_deserialize_wrong_length() {
        let json = r#"{"arr":[1,2]}"#;

        let result: Result<Wrapper<3>, _> = serde_json::from_str(json);
        let err = result.expect_err("a 2-element array must not parse as [u32; 3]");
        assert!(
            err.to_string().contains("expected an array of length 3"),
            "unexpected error message: {err}"
        );
    }

    #[test]
    fn test_deserialize_too_long() {
        let json = r#"{"arr":[1,2,3,4]}"#;

        let result: Result<Wrapper<3>, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_array() {
        let data = Wrapper::<0> { arr: [] };

        let json = serde_json::to_string(&data).unwrap();
        assert_eq!(json, r#"{"arr":[]}"#);

        let parsed: Wrapper<0> = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, data);
    }

    #[test]
    fn test_non_copy_elements_roundtrip() {
        let original = Names {
            names: [String::from("a"), String::from("b")],
        };

        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"names":["a","b"]}"#);

        let parsed: Names = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, original);
    }
}