1use core::marker::PhantomData;
2
3use serde::{de::Visitor, Deserialize, Serialize};
4use unarray::UnarrayArrayExt;
5
6use crate::{ByteArray, Plain};
7
8impl<const N: usize> Serialize for ByteArray<Plain, N> {
9 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
10 where
11 S: serde::Serializer,
12 {
13 self.inner.serialize(serializer)
14 }
15}
16
17impl<'de, const N: usize> Deserialize<'de> for ByteArray<Plain, N> {
18 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
19 where
20 D: serde::Deserializer<'de>,
21 {
22 struct V<const N: usize>;
23
24 impl<'de, const N: usize> Visitor<'de> for V<N> {
25 type Value = [u8; N];
26
27 fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
28 write!(formatter, "a byte array of length {N}")
29 }
30
31 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
32 where
33 E: serde::de::Error,
34 {
35 v.try_into().map_err(E::custom)
36 }
37
38 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
39 where
40 A: serde::de::SeqAccess<'de>,
41 {
42 let err_fn = |s: &'static str| -> Result<Self::Value, A::Error> {
43 Err(<A::Error as serde::de::Error>::custom(s))
44 };
45
46 let mut result = [None; N];
48
49 for slot in result.iter_mut() {
50 match seq.next_element()? {
51 Some(elem) => *slot = Some(elem),
52 None => return err_fn("not enough elements"),
53 }
54 }
55
56 match seq.next_element::<u8>() {
57 Ok(None) => {}
58 Ok(Some(_)) => return err_fn("too many elements"),
59 Err(_) => return err_fn("too many elements"),
60 }
61
62 Ok(result.map_option(|i| i).unwrap())
64 }
65 }
66
67 deserializer.deserialize_bytes(V::<N>).map(|inner| Self {
68 inner,
69 _marker: PhantomData,
70 })
71 }
72}
73
/// Serde support for the hex-string encoding, compiled only with the `hex` feature.
#[cfg(feature = "hex")]
mod hex_impl {
    use crate::HexString;

    use super::*;

    /// Serializes the bytes as a single hex string produced by `hex::encode`.
    impl<const N: usize> Serialize for ByteArray<HexString, N> {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            serializer.serialize_str(&hex::encode(self.inner))
        }
    }

    /// Deserializes from a hex string that decodes to exactly `N` bytes.
    ///
    /// A single leading `0x` is tolerated and ignored.
    impl<'de, const N: usize> Deserialize<'de> for ByteArray<HexString, N> {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            // Visitor that decodes a hex string into a fixed-size `[u8; N]`.
            struct V<const N: usize>;

            impl<const N: usize> Visitor<'_> for V<N> {
                type Value = [u8; N];

                fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
                    write!(formatter, "a hex string representing a byte array of length {N} (i.e. a hex string with length {})", N * 2)
                }

                fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
                where
                    E: serde::de::Error,
                {
                    // Strip at most ONE `0x` prefix. `trim_start_matches` would
                    // remove every repetition and so accept malformed input
                    // such as "0x0x01020304".
                    let digits = v.strip_prefix("0x").unwrap_or(v);

                    let mut buf = [0; N];
                    // `decode_to_slice` fails on invalid digits or when the
                    // decoded length does not match the buffer.
                    hex::decode_to_slice(digits, &mut buf).map_err(E::custom)?;
                    Ok(buf)
                }
            }

            deserializer.deserialize_str(V::<N>).map(|inner| Self {
                inner,
                _marker: PhantomData,
            })
        }
    }
}
120
#[cfg(all(test, feature = "hex"))]
mod tests {
    use serde_json::{from_value, json, to_value};

    use crate::HexString;

    use super::*;

    /// Fixture exercising both encodings side by side.
    #[derive(Debug, Deserialize, Serialize)]
    struct Foo {
        plain: ByteArray<Plain, 4>,
        hex: ByteArray<HexString, 4>,
    }

    /// A JSON value survives a deserialize -> serialize round trip unchanged.
    #[test]
    fn serialize_deserialize_sanity() {
        let value = json!({
            "plain": [1, 2, 3, 4],
            "hex": "01020304",
        });

        let Foo { plain, hex } = from_value(value.clone()).unwrap();
        let value_again = to_value(&Foo { plain, hex }).unwrap();

        assert_eq!(value, value_again);
    }

    /// Inputs that are longer or shorter than four bytes are rejected for both encodings.
    #[test]
    fn fails_if_wrong_length() {
        let plain_too_long = json!({
            "plain": [1, 2, 3, 4, 5],
            "hex": "01020304",
        });
        from_value::<Foo>(plain_too_long).unwrap_err();

        // Three elements: one fewer than the four required.
        let plain_too_short = json!({
            "plain": [1, 2, 3],
            "hex": "01020304",
        });
        from_value::<Foo>(plain_too_short).unwrap_err();

        let hex_too_long = json!({
            "plain": [1, 2, 3, 4],
            "hex": "0102030405",
        });
        from_value::<Foo>(hex_too_long).unwrap_err();

        let hex_too_short = json!({
            "plain": [1, 2, 3, 4],
            "hex": "010203",
        });
        from_value::<Foo>(hex_too_short).unwrap_err();
    }
}