1use serde::{
4 de::{Error as DeError, Unexpected, Visitor},
5 Deserializer, Serializer,
6};
7
8use alloc::{borrow::Cow, vec::Vec};
9use core::{convert::TryFrom, fmt, marker::PhantomData};
10
11#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
20pub trait Hex<T> {
21 type Error: fmt::Display;
23
24 fn create_bytes(value: &T) -> Cow<'_, [u8]>;
28
29 fn from_bytes(bytes: &[u8]) -> Result<T, Self::Error>;
36
37 fn serialize<S: Serializer>(value: &T, serializer: S) -> Result<S::Ok, S::Error> {
46 let value = Self::create_bytes(value);
47 if serializer.is_human_readable() {
48 serializer.serialize_str(&hex::encode(value))
49 } else {
50 serializer.serialize_bytes(value.as_ref())
51 }
52 }
53
54 fn deserialize<'de, D>(deserializer: D) -> Result<T, D::Error>
61 where
62 D: Deserializer<'de>,
63 {
64 struct HexVisitor;
65
66 impl<'de> Visitor<'de> for HexVisitor {
67 type Value = Vec<u8>;
68
69 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
70 formatter.write_str("hex-encoded byte array")
71 }
72
73 fn visit_str<E: DeError>(self, value: &str) -> Result<Self::Value, E> {
74 hex::decode(value).map_err(|_| E::invalid_type(Unexpected::Str(value), &self))
75 }
76
77 fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
79 Ok(value.to_vec())
80 }
81 }
82
83 struct BytesVisitor;
84
85 impl<'de> Visitor<'de> for BytesVisitor {
86 type Value = Vec<u8>;
87
88 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
89 formatter.write_str("byte array")
90 }
91
92 fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
93 Ok(value.to_vec())
94 }
95
96 fn visit_byte_buf<E: DeError>(self, value: Vec<u8>) -> Result<Self::Value, E> {
97 Ok(value)
98 }
99 }
100
101 let maybe_bytes = if deserializer.is_human_readable() {
102 deserializer.deserialize_str(HexVisitor)
103 } else {
104 deserializer.deserialize_byte_buf(BytesVisitor)
105 };
106 maybe_bytes.and_then(|bytes| Self::from_bytes(&bytes).map_err(D::Error::custom))
107 }
108}
109
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
/// Marker type implementing [`Hex`] for every `T` that can be viewed as a byte slice
/// (`AsRef<[u8]>`) and rebuilt from one (`TryFrom<&[u8]>`).
///
/// It holds no data and is meant to be named in serde attributes, e.g.
/// `#[serde(with = "HexForm::<T>")]`, or `#[serde(with = "HexForm")]` when `T` can be
/// inferred from the field type.
pub struct HexForm<T>(PhantomData<T>);

// Written by hand rather than derived: `#[derive(Debug)]` would require `T: Debug`,
// although the struct only stores a `PhantomData<T>`, which is `Debug` for every `T`.
// The rendered output is identical to what the derive would produce.
impl<T> fmt::Debug for HexForm<T> {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.debug_tuple("HexForm").field(&self.0).finish()
    }
}
121
122impl<T, E> Hex<T> for HexForm<T>
123where
124 T: AsRef<[u8]> + for<'a> TryFrom<&'a [u8], Error = E>,
125 E: fmt::Display,
126{
127 type Error = E;
128
129 fn create_bytes(buffer: &T) -> Cow<'_, [u8]> {
130 Cow::Borrowed(buffer.as_ref())
131 }
132
133 fn from_bytes(bytes: &[u8]) -> Result<T, Self::Error> {
134 T::try_from(bytes)
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;

    use serde_derive::{Deserialize, Serialize};
    use serde_json::json;

    use alloc::{
        borrow::ToOwned,
        string::{String, ToString},
        vec,
    };
    use core::array::TryFromSliceError;

    /// Fixed-size buffer that gets hex (de)serialization from the blanket `HexForm` impl
    /// by implementing `AsRef<[u8]>` and `TryFrom<&[u8]>`.
    #[derive(Debug, Serialize, Deserialize)]
    struct Buffer([u8; 8]);

    impl AsRef<[u8]> for Buffer {
        fn as_ref(&self) -> &[u8] {
            &self.0
        }
    }

    impl TryFrom<&[u8]> for Buffer {
        type Error = TryFromSliceError;

        fn try_from(slice: &[u8]) -> Result<Self, Self::Error> {
            <[u8; 8]>::try_from(slice).map(Buffer)
        }
    }

    #[derive(Debug, Serialize, Deserialize)]
    struct Test {
        #[serde(with = "HexForm::<Buffer>")]
        buffer: Buffer,
        other_field: String,
    }

    #[test]
    fn internal_type() {
        let json = json!({ "buffer": "0001020304050607", "other_field": "abc" });
        let value: Test = serde_json::from_value(json.clone()).unwrap();
        assert_eq!(value.buffer.0, [0, 1, 2, 3, 4, 5, 6, 7]);

        let json_copy = serde_json::to_value(&value).unwrap();
        assert_eq!(json, json_copy);
    }

    #[test]
    fn error_reporting() {
        // "bogus" contains non-hex characters; "c0ffe" has an odd number of digits.
        for &bogus_hex in &["bogus", "c0ffe"] {
            let bogus_json = json!({ "buffer": bogus_hex, "other_field": "test" });
            let err = serde_json::from_value::<Test>(bogus_json)
                .unwrap_err()
                .to_string();
            assert!(err.contains("expected hex-encoded byte array"), "{}", err);
        }
    }

    #[test]
    fn internal_type_with_derived_serde_code() {
        // Same shape as `Test`, but without `#[serde(with = ...)]`, so the buffer falls
        // back to serde's derived (array) representation.
        #[derive(Serialize, Deserialize)]
        struct OriginalTest {
            buffer: Buffer,
            other_field: String,
        }

        let test = Test {
            buffer: Buffer([1; 8]),
            other_field: "a".to_owned(),
        };
        assert_eq!(
            serde_json::to_value(test).unwrap(),
            json!({
                "buffer": "0101010101010101",
                "other_field": "a",
            })
        );

        let test = OriginalTest {
            buffer: Buffer([1; 8]),
            other_field: "a".to_owned(),
        };
        assert_eq!(
            serde_json::to_value(test).unwrap(),
            json!({
                "buffer": [1, 1, 1, 1, 1, 1, 1, 1],
                "other_field": "a",
            })
        );
    }

    #[test]
    fn external_type() {
        #[derive(Debug, PartialEq, Eq)]
        pub struct Buffer([u8; 8]);

        /// Hand-written `Hex` implementation for a type with no byte-conversion traits.
        struct BufferHex(());

        impl Hex<Buffer> for BufferHex {
            type Error = &'static str;

            fn create_bytes(buffer: &Buffer) -> Cow<'_, [u8]> {
                Cow::Borrowed(&buffer.0)
            }

            fn from_bytes(bytes: &[u8]) -> Result<Buffer, Self::Error> {
                if bytes.len() == 8 {
                    let mut inner = [0; 8];
                    inner.copy_from_slice(bytes);
                    Ok(Buffer(inner))
                } else {
                    Err("invalid buffer length")
                }
            }
        }

        #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
        struct Test {
            #[serde(with = "BufferHex")]
            buffer: Buffer,
            other_field: String,
        }

        let json = json!({ "buffer": "0001020304050607", "other_field": "abc" });
        let value: Test = serde_json::from_value(json.clone()).unwrap();
        assert_eq!(value.buffer.0, [0, 1, 2, 3, 4, 5, 6, 7]);

        let json_copy = serde_json::to_value(&value).unwrap();
        assert_eq!(json, json_copy);

        // In a binary format the buffer must be stored as raw bytes rather than as a hex
        // string, so the bytes 00..07 appear verbatim in the encoded output.
        let buffer = bincode::serialize(&value).unwrap();
        let buffer_hex = hex::encode(&buffer);
        let needle = "0001020304050607";
        assert!(buffer_hex.contains(needle), "{}", buffer_hex);

        let value_copy: Test = bincode::deserialize(&buffer).unwrap();
        assert_eq!(value_copy, value);
    }

    #[test]
    fn deserializing_flattened_field() {
        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct Inner {
            #[serde(with = "HexForm")]
            x: Vec<u8>,
            #[serde(with = "HexForm")]
            y: [u8; 16],
        }

        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct Outer {
            #[serde(flatten)]
            inner: Inner,
            z: String,
        }

        let value = Outer {
            inner: Inner {
                x: vec![1; 8],
                y: [0; 16],
            },
            z: "test".to_owned(),
        };

        // Both hex fields must be encoded as raw bytes by the (binary) CBOR serializer.
        let bytes = serde_cbor::to_vec(&value).unwrap();
        let bytes_hex = hex::encode(&bytes);
        assert!(bytes_hex.contains(&"01".repeat(8)), "{}", bytes_hex);
        assert!(bytes_hex.contains(&"00".repeat(16)), "{}", bytes_hex);

        let value_copy: Outer = serde_cbor::from_slice(&bytes).unwrap();
        assert_eq!(value, value_copy);
    }
}