hex_buffer_serde/
const_len.rs1use serde::{
4 de::{Error as DeError, Unexpected, Visitor},
5 Deserializer, Serializer,
6};
7
8use core::{array::TryFromSliceError, convert::TryFrom, fmt, marker::PhantomData, mem, slice, str};
9
10#[cfg_attr(docsrs, doc(cfg(feature = "const_len")))]
75pub trait ConstHex<T, const N: usize> {
76 type Error: fmt::Display;
78
79 fn create_bytes(value: &T) -> [u8; N];
81
82 fn from_bytes(bytes: [u8; N]) -> Result<T, Self::Error>;
89
90 fn serialize<S: Serializer>(value: &T, serializer: S) -> Result<S::Ok, S::Error> {
98 fn as_u8_slice(slice: &mut [u16]) -> &mut [u8] {
101 if slice.is_empty() {
102 &mut []
105 } else {
106 let byte_len = slice.len() * mem::size_of::<u16>();
107 let data = (slice as *mut [u16]).cast::<u8>();
108 unsafe {
109 slice::from_raw_parts_mut(data, byte_len)
112 }
113 }
114 }
115
116 let value = Self::create_bytes(value);
117 if serializer.is_human_readable() {
118 let mut hex_slice = [0_u16; N];
119 let hex_slice = as_u8_slice(&mut hex_slice);
120
121 hex::encode_to_slice(value, hex_slice).unwrap();
122 serializer.serialize_str(unsafe {
124 str::from_utf8_unchecked(hex_slice)
126 })
127 } else {
128 serializer.serialize_bytes(value.as_ref())
129 }
130 }
131
132 fn deserialize<'de, D>(deserializer: D) -> Result<T, D::Error>
139 where
140 D: Deserializer<'de>,
141 {
142 #[derive(Default)]
143 struct HexVisitor<const M: usize>;
144
145 impl<'de, const M: usize> Visitor<'de> for HexVisitor<M> {
146 type Value = [u8; M];
147
148 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
149 write!(formatter, "hex-encoded byte array of length {}", M)
150 }
151
152 fn visit_str<E: DeError>(self, value: &str) -> Result<Self::Value, E> {
153 let mut decoded = [0_u8; M];
154 hex::decode_to_slice(value, &mut decoded)
155 .map_err(|_| E::invalid_type(Unexpected::Str(value), &self))?;
156 Ok(decoded)
157 }
158
159 fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
160 <[u8; M]>::try_from(value).map_err(|_| E::invalid_length(value.len(), &self))
161 }
162 }
163
164 #[derive(Default)]
165 struct BytesVisitor<const M: usize>;
166
167 impl<'de, const M: usize> Visitor<'de> for BytesVisitor<M> {
168 type Value = [u8; M];
169
170 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
171 write!(formatter, "byte array of length {}", M)
172 }
173
174 fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
175 <[u8; M]>::try_from(value).map_err(|_| E::invalid_length(value.len(), &self))
176 }
177 }
178
179 let maybe_bytes = if deserializer.is_human_readable() {
180 deserializer.deserialize_str(HexVisitor::default())
181 } else {
182 deserializer.deserialize_bytes(BytesVisitor::default())
183 };
184 maybe_bytes.and_then(|bytes| Self::from_bytes(bytes).map_err(D::Error::custom))
185 }
186}
187
/// Marker type whose [`ConstHex`] implementation hex-encodes plain byte arrays `[u8; N]`.
///
/// Use it as `#[serde(with = "ConstHexForm")]` on an array field. Its only field is private,
/// so values cannot be constructed outside this module; the type parameter `T` merely selects
/// the matching [`ConstHex`] implementation.
#[derive(Debug)]
#[cfg_attr(docsrs, doc(cfg(feature = "const_len")))]
pub struct ConstHexForm<T>(PhantomData<T>);
193
194impl<const N: usize> ConstHex<[u8; N], N> for ConstHexForm<[u8; N]> {
195 type Error = TryFromSliceError;
196
197 fn create_bytes(buffer: &[u8; N]) -> [u8; N] {
198 *buffer
199 }
200
201 fn from_bytes(bytes: [u8; N]) -> Result<[u8; N], Self::Error> {
202 Ok(bytes)
203 }
204}
205
#[cfg(test)]
mod tests {
    use super::*;

    use alloc::string::ToString;
    use serde_derive::{Deserialize, Serialize};

    #[derive(Debug, PartialEq, Serialize, Deserialize)]
    struct Arrays {
        #[serde(with = "ConstHexForm")]
        array: [u8; 16],
        #[serde(with = "ConstHexForm")]
        longer_array: [u8; 32],
    }

    /// Human-readable round trip: arrays are written as lowercase hex strings.
    #[test]
    fn serializing_arrays() {
        let arrays = Arrays {
            array: [11; 16],
            longer_array: [240; 32],
        };
        let json = serde_json::to_string(&arrays).unwrap();
        assert!(json.contains(&"0b".repeat(16)));
        assert!(json.contains(&"f0".repeat(32)));

        let arrays_copy: Arrays = serde_json::from_str(&json).unwrap();
        assert_eq!(arrays_copy, arrays);
    }

    /// Binary round trip: non-human-readable formats carry the raw bytes, not hex.
    #[test]
    fn serializing_arrays_in_binary_format() {
        let arrays = Arrays {
            array: [11; 16],
            longer_array: [240; 32],
        };
        let buffer = bincode::serialize(&arrays).unwrap();
        assert!(buffer.windows(16).any(|window| window == [11_u8; 16]));
        assert!(buffer.windows(32).any(|window| window == [240_u8; 32]));

        let arrays_copy: Arrays = bincode::deserialize(&buffer).unwrap();
        assert_eq!(arrays_copy, arrays);
    }

    #[test]
    fn deserializing_array_with_incorrect_length() {
        let json = serde_json::json!({
            "array": "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
            "longer_array": "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
        });
        let err = serde_json::from_value::<Arrays>(json)
            .unwrap_err()
            .to_string();

        assert!(err.contains("invalid type"), "{}", err);
        assert!(err.contains("expected hex-encoded byte array"), "{}", err);
    }

    #[test]
    fn deserializing_array_with_incorrect_length_from_binary_format() {
        #[derive(Debug, Serialize, Deserialize)]
        struct ArrayHolder<const N: usize>(#[serde(with = "ConstHexForm")] [u8; N]);

        let buffer = bincode::serialize(&ArrayHolder([5; 6])).unwrap();
        let err = bincode::deserialize::<ArrayHolder<4>>(&buffer).unwrap_err();

        assert_eq!(
            err.to_string(),
            "invalid length 6, expected byte array of length 4"
        );
    }

    #[test]
    fn custom_type() {
        use ed25519_compact::PublicKey;

        struct PublicKeyHex(());
        impl ConstHex<PublicKey, 32> for PublicKeyHex {
            type Error = ed25519_compact::Error;

            fn create_bytes(pk: &PublicKey) -> [u8; 32] {
                **pk
            }

            fn from_bytes(bytes: [u8; 32]) -> Result<PublicKey, Self::Error> {
                PublicKey::from_slice(&bytes)
            }
        }

        #[derive(Debug, Serialize, Deserialize)]
        struct Holder {
            #[serde(with = "PublicKeyHex")]
            public_key: PublicKey,
        }

        let json = serde_json::json!({
            "public_key": "06fac1f22240cffd637ead6647188429fafda9c9cb7eae43386ac17f61115075",
        });
        let holder: Holder = serde_json::from_value(json).unwrap();
        assert_eq!(holder.public_key[0], 6);

        // 63 hex digits: odd length, so decoding must fail.
        let bogus_json = serde_json::json!({
            "public_key": "06fac1f22240cffd637ead6647188429fafda9c9cb7eae43386ac17f6111507",
        });
        let err = serde_json::from_value::<Holder>(bogus_json)
            .unwrap_err()
            .to_string();
        assert!(
            err.contains("expected hex-encoded byte array of length 32"),
            "{}",
            err
        );
    }
}