use serde::{
de::{Error as DeError, Unexpected, Visitor},
Deserializer, Serializer,
};
use alloc::{borrow::Cow, vec::Vec};
use core::{convert::TryFrom, fmt, marker::PhantomData};
/// Conversion of a value of type `T` to and from bytes, hex-encoded in human-readable formats.
///
/// Implementors provide [`create_bytes`](Hex::create_bytes) and
/// [`from_bytes`](Hex::from_bytes). The provided [`serialize`](Hex::serialize) and
/// [`deserialize`](Hex::deserialize) functions have the shape expected by
/// `#[serde(with = "...")]`, so an implementor can be named directly in that attribute.
///
/// The wire representation depends on the format's `is_human_readable()` flag:
/// human-readable formats use a hex string, other formats use a raw byte array.
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
pub trait Hex<T> {
    /// Error produced by [`from_bytes`](Hex::from_bytes) when bytes cannot be turned into `T`.
    type Error: fmt::Display;

    /// Returns the byte representation of `value`, borrowed where the implementor allows it.
    fn create_bytes(value: &T) -> Cow<'_, [u8]>;

    /// Reconstructs a value of type `T` from its byte representation.
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` if `bytes` do not describe a valid `T`.
    fn from_bytes(bytes: &[u8]) -> Result<T, Self::Error>;

    /// Serializes `value` as a hex string (human-readable serializers) or as a byte
    /// array (all other serializers).
    fn serialize<S: Serializer>(value: &T, serializer: S) -> Result<S::Ok, S::Error> {
        let value = Self::create_bytes(value);
        if serializer.is_human_readable() {
            serializer.serialize_str(&hex::encode(value))
        } else {
            serializer.serialize_bytes(value.as_ref())
        }
    }

    /// Deserializes a `T`.
    ///
    /// Human-readable deserializers are asked for a string, which is hex-decoded;
    /// raw bytes are accepted there as well. Other deserializers are asked for bytes.
    /// The resulting bytes are passed to [`from_bytes`](Hex::from_bytes).
    ///
    /// # Errors
    ///
    /// Fails if the input is not valid hex (reported as an invalid value),
    /// if the deserializer rejects the requested type, or if `from_bytes` fails
    /// (its error is wrapped with `de::Error::custom`).
    fn deserialize<'de, D>(deserializer: D) -> Result<T, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Visitor for human-readable formats: decodes hex strings and, as a fallback,
        // accepts raw bytes.
        struct HexVisitor;

        impl<'de> Visitor<'de> for HexVisitor {
            type Value = Vec<u8>;

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str("hex-encoded byte array")
            }

            fn visit_str<E: DeError>(self, value: &str) -> Result<Self::Value, E> {
                // A string is exactly the type we asked for; it merely is not valid hex
                // (bad digit or odd length). That is an invalid *value*, not an invalid type.
                hex::decode(value).map_err(|_| E::invalid_value(Unexpected::Str(value), &self))
            }

            // NOTE(review): presumably needed because some deserializers report
            // `is_human_readable() == true` yet hand over bytes (the flattened-field
            // CBOR test appears to exercise this) — confirm.
            fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
                Ok(value.to_vec())
            }
        }

        // Visitor for non-human-readable formats: accepts byte arrays only.
        struct BytesVisitor;

        impl<'de> Visitor<'de> for BytesVisitor {
            type Value = Vec<u8>;

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                formatter.write_str("byte array")
            }

            fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
                Ok(value.to_vec())
            }
        }

        let bytes = if deserializer.is_human_readable() {
            deserializer.deserialize_str(HexVisitor)
        } else {
            deserializer.deserialize_bytes(BytesVisitor)
        }?;
        Self::from_bytes(&bytes).map_err(D::Error::custom)
    }
}
/// Ready-made [`Hex`] implementation for any `T` that converts to and from byte slices.
///
/// Serialization views `T` through `AsRef<[u8]>`. Deserialization rebuilds it with
/// `TryFrom<&[u8]>`, and any conversion error is passed through unchanged.
/// This is a marker type: its only field is private, and it is meant to be named
/// in `#[serde(with = "HexForm::<T>")]` (or `"HexForm"` when `T` can be inferred).
#[cfg_attr(docsrs, doc(cfg(feature = "alloc")))]
#[derive(Debug)]
pub struct HexForm<T>(PhantomData<T>);

impl<T, E> Hex<T> for HexForm<T>
where
    T: AsRef<[u8]> + for<'a> TryFrom<&'a [u8], Error = E>,
    E: fmt::Display,
{
    type Error = E;

    // `T` can always lend its bytes, so serialization never copies them.
    fn create_bytes(buffer: &T) -> Cow<'_, [u8]> {
        let view: &[u8] = buffer.as_ref();
        Cow::from(view)
    }

    // Delegates directly to the type's own slice conversion.
    fn from_bytes(bytes: &[u8]) -> Result<T, Self::Error> {
        <T as TryFrom<&[u8]>>::try_from(bytes)
    }
}
// Tests for `Hex` / `HexForm`: round trips through JSON (human-readable, hex strings)
// and through bincode / CBOR (binary, raw bytes), plus error reporting for bad hex.
#[cfg(test)]
// NOTE(review): the reason for this lint suppression is not visible here; presumably it
// keeps different clippy/rustc versions quiet about the `clippy::unknown_clippy_lints`
// lint name — confirm and document.
#[allow(renamed_and_removed_lints, clippy::unknown_clippy_lints)]
mod tests {
    use super::*;
    use serde_derive::{Deserialize, Serialize};
    use serde_json::json;

    use alloc::{
        borrow::ToOwned,
        string::{String, ToString},
        vec,
    };
    use core::array::TryFromSliceError;

    /// Fixed 8-byte buffer that qualifies for `HexForm` through its
    /// `AsRef<[u8]>` and `TryFrom<&[u8]>` impls below.
    #[derive(Debug, Serialize, Deserialize)]
    struct Buffer([u8; 8]);

    impl AsRef<[u8]> for Buffer {
        fn as_ref(&self) -> &[u8] {
            &self.0
        }
    }

    impl TryFrom<&[u8]> for Buffer {
        type Error = TryFromSliceError;

        // Fails unless `slice` is exactly 8 bytes long.
        fn try_from(slice: &[u8]) -> Result<Self, Self::Error> {
            <[u8; 8]>::try_from(slice).map(Buffer)
        }
    }

    /// Struct whose `buffer` field is (de)serialized through `HexForm::<Buffer>`.
    #[derive(Debug, Serialize, Deserialize)]
    struct Test {
        #[serde(with = "HexForm::<Buffer>")]
        buffer: Buffer,
        other_field: String,
    }

    /// A JSON hex string decodes into the expected bytes and re-serializes to the same JSON.
    #[test]
    fn internal_type() {
        let json = json!({ "buffer": "0001020304050607", "other_field": "abc" });
        let value: Test = serde_json::from_value(json.clone()).unwrap();
        // The fixture encodes bytes 0..8, so the byte at index `i` must equal `i`.
        assert!(value
            .buffer
            .0
            .iter()
            .enumerate()
            .all(|(i, &byte)| i == usize::from(byte)));
        let json_copy = serde_json::to_value(&value).unwrap();
        assert_eq!(json, json_copy);
    }

    /// Invalid hex input is rejected by the visitor, and the error message mentions
    /// what was expected.
    #[test]
    fn error_reporting() {
        // "bogus" contains non-hex characters; "c0ffe" has an odd number of digits.
        let bogus_jsons = vec![
            serde_json::json!({
                "buffer": "bogus",
                "other_field": "test",
            }),
            serde_json::json!({
                "buffer": "c0ffe",
                "other_field": "test",
            }),
        ];
        for bogus_json in bogus_jsons {
            let err = serde_json::from_value::<Test>(bogus_json)
                .unwrap_err()
                .to_string();
            assert!(err.contains("expected hex-encoded byte array"), "{}", err);
        }
    }

    /// The `with` attribute changes the JSON form: a hex string instead of the
    /// number array produced by the derived `Serialize` impl of `Buffer`.
    #[test]
    fn internal_type_with_derived_serde_code() {
        // Same shape as `Test`, but without the `#[serde(with = ...)]` attribute.
        #[derive(Serialize, Deserialize)]
        struct OriginalTest {
            buffer: Buffer,
            other_field: String,
        }

        let test = Test {
            buffer: Buffer([1; 8]),
            other_field: "a".to_owned(),
        };
        assert_eq!(
            serde_json::to_value(test).unwrap(),
            json!({
                "buffer": "0101010101010101",
                "other_field": "a",
            })
        );

        let test = OriginalTest {
            buffer: Buffer([1; 8]),
            other_field: "a".to_owned(),
        };
        assert_eq!(
            serde_json::to_value(test).unwrap(),
            json!({
                "buffer": [1, 1, 1, 1, 1, 1, 1, 1],
                "other_field": "a",
            })
        );
    }

    /// `Hex` implemented by hand for a type that has no byte-conversion traits,
    /// round-tripped through JSON (hex string) and bincode (raw bytes).
    #[test]
    fn external_type() {
        #[derive(Debug, PartialEq, Eq)]
        pub struct Buffer([u8; 8]);

        // Helper type carrying the `Hex` impl for the local `Buffer`.
        struct BufferHex(());

        impl Hex<Buffer> for BufferHex {
            type Error = &'static str;

            fn create_bytes(buffer: &Buffer) -> Cow<'_, [u8]> {
                Cow::Borrowed(&buffer.0)
            }

            // Accepts exactly 8 bytes; any other length is an error.
            fn from_bytes(bytes: &[u8]) -> Result<Buffer, Self::Error> {
                if bytes.len() == 8 {
                    let mut inner = [0; 8];
                    inner.copy_from_slice(bytes);
                    Ok(Buffer(inner))
                } else {
                    Err("invalid buffer length")
                }
            }
        }

        #[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
        struct Test {
            #[serde(with = "BufferHex")]
            buffer: Buffer,
            other_field: String,
        }

        let json = json!({ "buffer": "0001020304050607", "other_field": "abc" });
        let value: Test = serde_json::from_value(json.clone()).unwrap();
        assert!(value
            .buffer
            .0
            .iter()
            .enumerate()
            .all(|(i, &byte)| i == usize::from(byte)));
        let json_copy = serde_json::to_value(&value).unwrap();
        assert_eq!(json, json_copy);

        // In the binary format the buffer bytes must appear verbatim, not hex-encoded.
        let buffer = bincode::serialize(&value).unwrap();
        let buffer_hex = hex::encode(&buffer);
        let needle = "0001020304050607";
        assert!(buffer_hex.contains(needle));
        let value_copy: Test = bincode::deserialize(&buffer).unwrap();
        assert_eq!(value_copy, value);
    }

    /// `HexForm` fields inside a `#[serde(flatten)]`-ed struct survive a CBOR round trip.
    #[test]
    fn deserializing_flattened_field() {
        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct Inner {
            #[serde(with = "HexForm")]
            x: Vec<u8>,
            #[serde(with = "HexForm")]
            y: [u8; 16],
        }

        #[derive(Debug, PartialEq, Serialize, Deserialize)]
        struct Outer {
            #[serde(flatten)]
            inner: Inner,
            z: String,
        }

        let value = Outer {
            inner: Inner {
                x: vec![1; 8],
                y: [0; 16],
            },
            z: "test".to_owned(),
        };
        let bytes = serde_cbor::to_vec(&value).unwrap();
        // Raw runs of 0x01 / 0x00 in the output show that the fields were written
        // as byte strings rather than as ASCII hex text.
        let bytes_hex = hex::encode(&bytes);
        assert!(bytes_hex.contains(&"01".repeat(8)));
        assert!(bytes_hex.contains(&"00".repeat(16)));
        let value_copy = serde_cbor::from_slice(&bytes).unwrap();
        assert_eq!(value, value_copy);
    }
}