// p3_util/array_serialization.rs

1use alloc::vec::Vec;
2use core::marker::PhantomData;
3
4use serde::de::{SeqAccess, Visitor};
5use serde::ser::SerializeTuple;
6use serde::{Deserialize, Deserializer, Serialize, Serializer};
7
8pub fn serialize<S: Serializer, T: Serialize, const N: usize>(
9    data: &[T; N],
10    ser: S,
11) -> Result<S::Ok, S::Error> {
12    let mut s = ser.serialize_tuple(N)?;
13    for item in data {
14        s.serialize_element(item)?;
15    }
16    s.end()
17}
18
19struct ArrayVisitor<T, const N: usize>(PhantomData<T>);
20
21impl<'de, T, const N: usize> Visitor<'de> for ArrayVisitor<T, N>
22where
23    T: Deserialize<'de>,
24{
25    type Value = [T; N];
26
27    fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
28        formatter.write_fmt(format_args!("an array of length {}", N))
29    }
30
31    #[inline]
32    fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
33    where
34        A: SeqAccess<'de>,
35    {
36        let mut data = Vec::with_capacity(N);
37        for _ in 0..N {
38            match seq.next_element()? {
39                Some(val) => data.push(val),
40                None => return Err(serde::de::Error::invalid_length(N, &self)),
41            }
42        }
43        match data.try_into() {
44            Ok(arr) => Ok(arr),
45            Err(_) => unreachable!(),
46        }
47    }
48}
49pub fn deserialize<'de, D, T, const N: usize>(deserializer: D) -> Result<[T; N], D::Error>
50where
51    D: Deserializer<'de>,
52    T: Deserialize<'de>,
53{
54    deserializer.deserialize_tuple(N, ArrayVisitor::<T, N>(PhantomData))
55}
56
#[cfg(test)]
mod tests {
    use serde::{Deserialize, Serialize};

    use super::*;

    /// A helper wrapper struct to use serialize/deserialize hooks on arrays.
    #[derive(Serialize, Deserialize, Debug, PartialEq)]
    #[serde(bound(serialize = "", deserialize = ""))]
    struct Wrapper<const N: usize> {
        #[serde(serialize_with = "serialize", deserialize_with = "deserialize")]
        arr: [u32; N],
    }

    #[test]
    fn test_array_serde_roundtrip() {
        let original = Wrapper::<3> { arr: [10, 20, 30] };

        let json = serde_json::to_string(&original).unwrap();
        assert_eq!(json, r#"{"arr":[10,20,30]}"#);

        let deserialized: Wrapper<3> = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, original);

        let parsed: Wrapper<3> = serde_json::from_str(r#"{"arr":[10,20,30]}"#).unwrap();
        assert_eq!(parsed.arr, [10, 20, 30]);
    }

    #[test]
    fn test_deserialize_wrong_length() {
        let json = r#"{"arr":[1,2]}"#;

        let err = serde_json::from_str::<Wrapper<3>>(json).unwrap_err();
        // The visitor's `expecting` description should appear in the error.
        assert!(
            err.to_string().contains("an array of length 3"),
            "unexpected error: {}",
            err
        );
    }

    #[test]
    fn test_deserialize_too_long() {
        // Only 3 elements are consumed; the surplus element must not be
        // silently accepted.
        let json = r#"{"arr":[1,2,3,4]}"#;

        let result: Result<Wrapper<3>, _> = serde_json::from_str(json);
        assert!(result.is_err());
    }

    #[test]
    fn test_empty_array() {
        let data = Wrapper::<0> { arr: [] };

        let json = serde_json::to_string(&data).unwrap();
        assert_eq!(json, r#"{"arr":[]}"#);

        let parsed: Wrapper<0> = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, data);
    }
}