use serde::{
de::{Error as DeError, Unexpected, Visitor},
Deserializer, Serializer,
};
use core::{array::TryFromSliceError, convert::TryFrom, fmt, marker::PhantomData, mem, slice, str};
/// Serde (de)serialization scheme for values of type `T` that convert to and from
/// a fixed-size byte array `[u8; N]`.
///
/// With a human-readable (de)serializer the bytes are represented as a hex string;
/// otherwise they are represented as raw bytes. Implementors only supply the two
/// conversions ([`Self::create_bytes`] and [`Self::from_bytes`]); `serialize` and
/// `deserialize` have default implementations built on top of them.
#[cfg_attr(docsrs, doc(cfg(feature = "const_len")))]
pub trait ConstHex<T, const N: usize> {
    /// Error returned by [`Self::from_bytes`]. It is converted into the deserializer's
    /// error via `serde::de::Error::custom`, hence the `Display` bound.
    type Error: fmt::Display;

    /// Converts `value` into its `N`-byte representation.
    fn create_bytes(value: &T) -> [u8; N];

    /// Restores a value from its `N`-byte representation.
    fn from_bytes(bytes: [u8; N]) -> Result<T, Self::Error>;

    /// Serializes `value` as a hex string (human-readable serializers) or as raw bytes
    /// (all other serializers).
    fn serialize<S: Serializer>(value: &T, serializer: S) -> Result<S::Ok, S::Error> {
        // Reinterprets a `u16` buffer as a byte buffer twice as long.
        fn as_u8_slice(units: &mut [u16]) -> &mut [u8] {
            if units.is_empty() {
                return &mut [];
            }

            // The length is read before the raw pointer is created so that the pointer
            // is the last thing derived from `units`.
            let byte_len = mem::size_of::<u16>() * units.len();
            let start = units.as_mut_ptr().cast::<u8>();
            // SAFETY: `start` points to the beginning of an exclusively borrowed,
            // initialized `u16` slice spanning exactly `byte_len` bytes; `u8` has
            // alignment 1 and any bit pattern is a valid `u8`. The returned slice
            // inherits the lifetime of the exclusive borrow of `units`.
            unsafe { slice::from_raw_parts_mut(start, byte_len) }
        }

        let bytes = Self::create_bytes(value);
        if !serializer.is_human_readable() {
            return serializer.serialize_bytes(&bytes[..]);
        }

        // A `[u16; N]` array occupies `2 * N` bytes, which is exactly the length
        // of the hex encoding of `N` bytes.
        let mut storage = [0_u16; N];
        let hex_buffer = as_u8_slice(&mut storage);
        // The output buffer is twice as long as the input, so encoding is not expected
        // to fail here.
        hex::encode_to_slice(bytes, hex_buffer).unwrap();
        // SAFETY: `hex::encode_to_slice` succeeded, so the whole buffer has been
        // overwritten with ASCII hex digits, which are valid UTF-8.
        let hex_str = unsafe { str::from_utf8_unchecked(hex_buffer) };
        serializer.serialize_str(hex_str)
    }

    /// Deserializes a value from a hex string (human-readable deserializers) or from
    /// raw bytes (all other deserializers), and then passes the decoded bytes
    /// to [`Self::from_bytes`].
    ///
    /// # Errors
    ///
    /// Fails if the input cannot be decoded into exactly `N` bytes, or if
    /// [`Self::from_bytes`] rejects the decoded bytes.
    fn deserialize<'de, D>(deserializer: D) -> Result<T, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Visitor used with human-readable formats: accepts a hex string
        // or a byte slice of length `M`.
        #[derive(Default)]
        struct HexVisitor<const M: usize>;

        impl<'de, const M: usize> Visitor<'de> for HexVisitor<M> {
            type Value = [u8; M];

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(formatter, "hex-encoded byte array of length {}", M)
            }

            fn visit_str<E: DeError>(self, value: &str) -> Result<Self::Value, E> {
                let mut output = [0_u8; M];
                // Any decoding failure (bad characters or wrong length) is reported
                // as an "invalid type" error.
                match hex::decode_to_slice(value, &mut output) {
                    Ok(()) => Ok(output),
                    Err(_) => Err(E::invalid_type(Unexpected::Str(value), &self)),
                }
            }

            fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
                match <[u8; M]>::try_from(value) {
                    Ok(array) => Ok(array),
                    Err(_) => Err(E::invalid_length(value.len(), &self)),
                }
            }
        }

        // Visitor used with non-human-readable formats: accepts only a byte slice
        // of length `M`.
        #[derive(Default)]
        struct BytesVisitor<const M: usize>;

        impl<'de, const M: usize> Visitor<'de> for BytesVisitor<M> {
            type Value = [u8; M];

            fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
                write!(formatter, "byte array of length {}", M)
            }

            fn visit_bytes<E: DeError>(self, value: &[u8]) -> Result<Self::Value, E> {
                match <[u8; M]>::try_from(value) {
                    Ok(array) => Ok(array),
                    Err(_) => Err(E::invalid_length(value.len(), &self)),
                }
            }
        }

        let bytes: [u8; N] = if deserializer.is_human_readable() {
            deserializer.deserialize_str(HexVisitor::<N>::default())?
        } else {
            deserializer.deserialize_bytes(BytesVisitor::<N>::default())?
        };
        Self::from_bytes(bytes).map_err(D::Error::custom)
    }
}
/// Data-less marker type (it only wraps `PhantomData<T>`) that selects a `ConstHex`
/// implementation for `T`.
///
/// This file implements `ConstHex` for `ConstHexForm<[u8; N]>`, which allows annotating
/// byte-array fields with `#[serde(with = "ConstHexForm")]`.
#[cfg_attr(docsrs, doc(cfg(feature = "const_len")))]
#[derive(Debug)]
pub struct ConstHexForm<T>(PhantomData<T>);
/// Identity scheme for byte arrays: a `[u8; N]` value is its own byte representation.
impl<const N: usize> ConstHex<[u8; N], N> for ConstHexForm<[u8; N]> {
    /// Nominal error type; `from_bytes` below always succeeds, so it is never produced here.
    type Error = TryFromSliceError;

    /// Returns a copy of `array`.
    fn create_bytes(array: &[u8; N]) -> [u8; N] {
        *array
    }

    /// Returns `array` unchanged; this conversion cannot fail.
    fn from_bytes(array: [u8; N]) -> Result<[u8; N], Self::Error> {
        Ok(array)
    }
}
// Tests for `ConstHex` / `ConstHexForm`, covering a human-readable format (JSON)
// and a binary format (bincode).
#[cfg(test)]
mod tests {
use super::*;
use alloc::string::ToString;
use serde_derive::{Deserialize, Serialize};
// Fixture with two byte arrays of different lengths, both (de)serialized
// through `ConstHexForm`.
#[derive(Debug, PartialEq, Serialize, Deserialize)]
struct Arrays {
#[serde(with = "ConstHexForm")]
array: [u8; 16],
#[serde(with = "ConstHexForm")]
longer_array: [u8; 32],
}
// JSON output must contain the hex encoding of the array, and must roundtrip.
#[test]
fn serializing_arrays() {
let arrays = Arrays {
array: [11; 16],
longer_array: [240; 32],
};
let json = serde_json::to_string(&arrays).unwrap();
// Byte 11 is `0b` in hex, so 16 such bytes encode to "0b" repeated 16 times.
assert!(json.contains(&"0b".repeat(16)));
let arrays_copy: Arrays = serde_json::from_str(&json).unwrap();
assert_eq!(arrays_copy, arrays);
}
// A hex string that decodes to the wrong number of bytes must be rejected.
#[test]
fn deserializing_array_with_incorrect_length() {
// `longer_array` expects 32 bytes, but the supplied string encodes only 16.
let json = serde_json::json!({
"array": "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
"longer_array": "0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b0b",
});
let err = serde_json::from_value::<Arrays>(json)
.unwrap_err()
.to_string();
// The hex visitor reports decoding failures as "invalid type" errors.
assert!(err.contains("invalid type"), "{}", err);
assert!(err.contains("expected hex-encoded byte array"), "{}", err);
}
// Same length check, but for the raw-bytes path exercised via bincode.
#[test]
fn deserializing_array_with_incorrect_length_from_binary_format() {
#[derive(Debug, Serialize, Deserialize)]
struct ArrayHolder<const N: usize>(#[serde(with = "ConstHexForm")] [u8; N]);
// Serialize 6 bytes, then attempt to read them back as a 4-byte array.
let buffer = bincode::serialize(&ArrayHolder([5; 6])).unwrap();
let err = bincode::deserialize::<ArrayHolder<4>>(&buffer).unwrap_err();
assert_eq!(
err.to_string(),
"invalid length 6, expected byte array of length 4"
);
}
// `ConstHex` can be implemented for a non-array type by a user-defined marker type.
#[test]
fn custom_type() {
use ed25519_compact::PublicKey;
struct PublicKeyHex(());
impl ConstHex<PublicKey, 32> for PublicKeyHex {
type Error = ed25519_compact::Error;
fn create_bytes(pk: &PublicKey) -> [u8; 32] {
// First `*` removes the reference, the second one dereferences the key
// into its 32-byte array.
**pk
}
fn from_bytes(bytes: [u8; 32]) -> Result<PublicKey, Self::Error> {
PublicKey::from_slice(&bytes)
}
}
#[derive(Debug, Serialize, Deserialize)]
struct Holder {
#[serde(with = "PublicKeyHex")]
public_key: PublicKey,
}
let json = serde_json::json!({
"public_key": "06fac1f22240cffd637ead6647188429fafda9c9cb7eae43386ac17f61115075",
});
let holder: Holder = serde_json::from_value(json).unwrap();
assert_eq!(holder.public_key[0], 6);
// Same string with the last hex digit removed; it no longer decodes to 32 bytes.
let bogus_json = serde_json::json!({
"public_key": "06fac1f22240cffd637ead6647188429fafda9c9cb7eae43386ac17f6111507",
});
let err = serde_json::from_value::<Holder>(bogus_json).unwrap_err();
assert!(err
.to_string()
.contains("expected hex-encoded byte array of length 32"));
}
}