bytekind/array/serde.rs

1use core::marker::PhantomData;
2
3use serde::{de::Visitor, Deserialize, Serialize};
4use unarray::UnarrayArrayExt;
5
6use crate::{ByteArray, Plain};
7
8impl<const N: usize> Serialize for ByteArray<Plain, N> {
9    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
10    where
11        S: serde::Serializer,
12    {
13        self.inner.serialize(serializer)
14    }
15}
16
17impl<'de, const N: usize> Deserialize<'de> for ByteArray<Plain, N> {
18    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
19    where
20        D: serde::Deserializer<'de>,
21    {
22        struct V<const N: usize>;
23
24        impl<'de, const N: usize> Visitor<'de> for V<N> {
25            type Value = [u8; N];
26
27            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
28                write!(formatter, "a byte array of length {N}")
29            }
30
31            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
32            where
33                E: serde::de::Error,
34            {
35                v.try_into().map_err(E::custom)
36            }
37
38            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
39            where
40                A: serde::de::SeqAccess<'de>,
41            {
42                let err_fn = |s: &'static str| -> Result<Self::Value, A::Error> {
43                    Err(<A::Error as serde::de::Error>::custom(s))
44                };
45
46                // this is not the most efficient algorithm but it works
47                let mut result = [None; N];
48
49                for slot in result.iter_mut() {
50                    match seq.next_element()? {
51                        Some(elem) => *slot = Some(elem),
52                        None => return err_fn("not enough elements"),
53                    }
54                }
55
56                match seq.next_element::<u8>() {
57                    Ok(None) => {}
58                    Ok(Some(_)) => return err_fn("too many elements"),
59                    Err(_) => return err_fn("too many elements"),
60                }
61
62                // unwrap is fine here because all elements should be `Some`
63                Ok(result.map_option(|i| i).unwrap())
64            }
65        }
66
67        deserializer.deserialize_bytes(V::<N>).map(|inner| Self {
68            inner,
69            _marker: PhantomData,
70        })
71    }
72}
73
#[cfg(feature = "hex")]
mod hex_impl {
    use crate::HexString;

    use super::*;

    /// Serializes the bytes as a hex string produced by `hex::encode`; no `0x` prefix is added.
    impl<const N: usize> Serialize for ByteArray<HexString, N> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            serializer.serialize_str(&hex::encode(self.inner))
        }
    }

    /// Deserializes a hex string that encodes exactly `N` bytes (`2 * N` hex digits).
    ///
    /// A single leading `0x` prefix is accepted and ignored.
    ///
    /// # Errors
    ///
    /// Returns an `invalid_length` error if the digit count (after the optional prefix) is not
    /// `2 * N`. Other decoding errors from `hex::decode_to_slice`, such as invalid characters,
    /// are passed through `Error::custom`.
    impl<'de, const N: usize> Deserialize<'de> for ByteArray<HexString, N> {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            struct V<const N: usize>;

            impl<const N: usize> Visitor<'_> for V<N> {
                type Value = [u8; N];

                fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
                    write!(formatter, "a hex string representing a byte array of length {N} (i.e. a hex string with length {})", N * 2)
                }

                fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
                where
                    E: serde::de::Error,
                {
                    // Strip at most one `0x`. `trim_start_matches` would strip repeated
                    // prefixes and accept malformed input such as "0x0x01020304".
                    let digits = v.strip_prefix("0x").unwrap_or(v);

                    // Report length mismatches against `expecting`, which describes the
                    // required string length.
                    if digits.len() != N * 2 {
                        return Err(E::invalid_length(digits.len(), &self));
                    }

                    let mut buf = [0; N];
                    hex::decode_to_slice(digits, &mut buf).map_err(E::custom)?;
                    Ok(buf)
                }
            }

            deserializer.deserialize_str(V::<N>).map(|inner| Self {
                inner,
                _marker: PhantomData,
            })
        }
    }
}
120
#[cfg(all(test, feature = "hex"))]
mod tests {
    use serde_json::{from_value, json, to_value};

    use crate::HexString;

    use super::*;

    /// Fixture holding one array of each encoding, both of length 4.
    #[derive(Debug, Deserialize, Serialize)]
    struct Foo {
        plain: ByteArray<Plain, 4>,
        hex: ByteArray<HexString, 4>,
    }

    #[test]
    fn serialize_deserialize_sanity() {
        let value = json!({
            "plain": [1, 2, 3, 4],
            "hex": "01020304",
        });

        let Foo { plain, hex } = from_value(value.clone()).unwrap();
        let value_again = to_value(&Foo { plain, hex }).unwrap();

        assert_eq!(value, value_again);
    }

    #[test]
    fn fails_if_wrong_length() {
        let plain_too_long = json!({
            "plain": [1, 2, 3, 4, 5],
            "hex": "01020304",
        });
        from_value::<Foo>(plain_too_long).unwrap_err();

        // Previously a copy of `plain_too_long`, so the too-short path was never exercised.
        let plain_too_short = json!({
            "plain": [1, 2, 3],
            "hex": "01020304",
        });
        from_value::<Foo>(plain_too_short).unwrap_err();

        let hex_too_long = json!({
            "plain": [1, 2, 3, 4],
            "hex": "0102030405",
        });
        from_value::<Foo>(hex_too_long).unwrap_err();

        let hex_too_short = json!({
            "plain": [1, 2, 3, 4],
            "hex": "010203",
        });
        from_value::<Foo>(hex_too_short).unwrap_err();
    }

    #[test]
    fn fails_if_plain_element_is_not_a_byte() {
        let out_of_range = json!({
            "plain": [1, 2, 3, 256],
            "hex": "01020304",
        });
        from_value::<Foo>(out_of_range).unwrap_err();
    }

    #[test]
    fn hex_accepts_0x_prefix_and_serializes_without_it() {
        let value = json!({
            "plain": [1, 2, 3, 4],
            "hex": "0x01020304",
        });

        let foo: Foo = from_value(value).unwrap();
        let value_again = to_value(&foo).unwrap();

        assert_eq!(value_again["hex"], json!("01020304"));
    }
}