use alloc::boxed::Box;
use alloc::format;
use core::any::type_name;
use core::fmt;
use core::marker::PhantomData;
use base64::{engine::general_purpose, Engine as _};
use serde::{de, Deserializer, Serializer};
/// Textual representation used for byte data when the (de)serializer is
/// human-readable. Binary (non-human-readable) formats always carry the raw
/// bytes instead, regardless of the chosen variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum Encoding {
    /// Base64 via `general_purpose::STANDARD_NO_PAD`.
    Base64,
    /// Hexadecimal digits with a leading `0x` prefix.
    Hex,
}
struct B64Visitor<T>(PhantomData<T>);
impl<'de, T> de::Visitor<'de> for B64Visitor<T>
where
T: TryFromBytes,
{
type Value = T;
fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "b64-encoded {} bytes", type_name::<T>())
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: de::Error,
{
let bytes = general_purpose::STANDARD_NO_PAD
.decode(v)
.map_err(de::Error::custom)?;
T::try_from_bytes(&bytes).map_err(de::Error::custom)
}
}
/// Serde visitor for human-readable formats. It accepts a `0x`-prefixed
/// hexadecimal string, decodes the digits after the prefix, and turns the
/// resulting bytes into `T` via [`TryFromBytes`].
struct HexVisitor<T>(PhantomData<T>);

impl<'de, T> de::Visitor<'de> for HexVisitor<T>
where
    T: TryFromBytes,
{
    type Value = T;

    fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "0x-prefixed hex-encoded bytes of {}", type_name::<T>())
    }

    // Errors returned:
    // - `invalid_length` if `v` is shorter than the two-character prefix;
    // - `invalid_value` if `v` does not start with `0x`;
    // - a `custom` error if the remainder is not valid hex, or if
    //   `T::try_from_bytes` rejects the decoded bytes.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: de::Error,
    {
        if v.len() < 2 {
            return Err(E::invalid_length(v.len(), &"0x-prefixed hex-encoded bytes"));
        }
        // `strip_prefix` compares bytes without slicing. The previous `&v[..2]`
        // panicked when byte index 2 was not a char boundary (e.g. input starting
        // with "€"), turning malformed untrusted input into a panic, not an error.
        let digits = match v.strip_prefix("0x") {
            Some(rest) => rest,
            None => {
                return Err(E::invalid_value(
                    de::Unexpected::Str(v),
                    &"0x-prefixed hex-encoded bytes",
                ))
            }
        };
        let bytes = hex::decode(digits).map_err(E::custom)?;
        T::try_from_bytes(&bytes).map_err(E::custom)
    }
}
struct BytesVisitor<T>(PhantomData<T>);
impl<'de, T> de::Visitor<'de> for BytesVisitor<T>
where
T: TryFromBytes,
{
type Value = T;
fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{} bytes", type_name::<T>())
}
fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
where
E: de::Error,
{
T::try_from_bytes(v).map_err(de::Error::custom)
}
}
/// Serializes the bytes of `obj` according to the serializer's kind.
///
/// Non-human-readable serializers receive the raw bytes via `serialize_bytes`.
/// Human-readable serializers receive a string produced by `encoding`:
/// `0x`-prefixed hex for [`Encoding::Hex`], or `STANDARD_NO_PAD` base64 for
/// [`Encoding::Base64`].
pub(crate) fn serialize_with_encoding<T, S>(
    obj: &T,
    serializer: S,
    encoding: Encoding,
) -> Result<S::Ok, S::Error>
where
    T: AsRef<[u8]>,
    S: Serializer,
{
    let raw: &[u8] = obj.as_ref();
    if !serializer.is_human_readable() {
        return serializer.serialize_bytes(raw);
    }
    let text = match encoding {
        Encoding::Hex => format!("0x{}", hex::encode(raw)),
        Encoding::Base64 => general_purpose::STANDARD_NO_PAD.encode(raw),
    };
    serializer.serialize_str(&text)
}
/// Deserializes a `T` that was written by `serialize_with_encoding`.
///
/// Non-human-readable deserializers are asked for raw bytes. Human-readable
/// deserializers are asked for a string, which is decoded as `0x`-prefixed hex
/// or as base64 depending on `encoding`. The decoded bytes are converted with
/// [`TryFromBytes::try_from_bytes`].
pub(crate) fn deserialize_with_encoding<'de, T, D>(
    deserializer: D,
    encoding: Encoding,
) -> Result<T, D::Error>
where
    D: Deserializer<'de>,
    T: TryFromBytes,
{
    if !deserializer.is_human_readable() {
        return deserializer.deserialize_bytes(BytesVisitor::<T>(PhantomData));
    }
    match encoding {
        Encoding::Hex => deserializer.deserialize_str(HexVisitor::<T>(PhantomData)),
        Encoding::Base64 => deserializer.deserialize_str(B64Visitor::<T>(PhantomData)),
    }
}
/// `serialize` / `deserialize` pair with the signatures expected by
/// `#[serde(with = "...")]`. Byte-like values become `0x`-prefixed hex strings in
/// human-readable formats and raw bytes in binary formats.
pub mod as_hex {
    use super::*;

    /// Serializes `obj` as a `0x`-prefixed hex string (human-readable formats)
    /// or as raw bytes (binary formats).
    pub fn serialize<T, S>(obj: &T, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
        T: AsRef<[u8]>,
    {
        let encoding = Encoding::Hex;
        serialize_with_encoding(obj, serializer, encoding)
    }

    /// Deserializes a `T` from a `0x`-prefixed hex string (human-readable
    /// formats) or from raw bytes (binary formats).
    pub fn deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
    where
        T: TryFromBytes,
        D: Deserializer<'de>,
    {
        let encoding = Encoding::Hex;
        deserialize_with_encoding(deserializer, encoding)
    }
}
/// `serialize` / `deserialize` pair with the signatures expected by
/// `#[serde(with = "...")]`. Byte-like values become unpadded base64 strings
/// (`STANDARD_NO_PAD`) in human-readable formats and raw bytes in binary formats.
pub mod as_base64 {
    use super::*;

    /// Serializes `obj` as a base64 string (human-readable formats) or as raw
    /// bytes (binary formats).
    pub fn serialize<T, S>(obj: &T, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
        T: AsRef<[u8]>,
    {
        let encoding = Encoding::Base64;
        serialize_with_encoding(obj, serializer, encoding)
    }

    /// Deserializes a `T` from a base64 string (human-readable formats) or from
    /// raw bytes (binary formats).
    pub fn deserialize<'de, T, D>(deserializer: D) -> Result<T, D::Error>
    where
        T: TryFromBytes,
        D: Deserializer<'de>,
    {
        let encoding = Encoding::Base64;
        deserialize_with_encoding(deserializer, encoding)
    }
}
/// Fallible conversion from a borrowed byte slice into an owned value.
///
/// This is the bound required by the `as_hex` / `as_base64` deserializers. Once
/// the wire representation has been decoded into bytes, `try_from_bytes` builds
/// the final value. The error only needs `Display` because it is forwarded to
/// serde through `de::Error::custom`.
pub trait TryFromBytes
where
    Self: Sized,
{
    /// Error produced when `input` cannot be turned into `Self`.
    type Error: fmt::Display;

    /// Attempts to build `Self` from `input`.
    fn try_from_bytes(input: &[u8]) -> Result<Self, Self::Error>;
}

/// Fixed-size arrays accept only slices of exactly `N` bytes.
impl<const N: usize> TryFromBytes for [u8; N] {
    type Error = core::array::TryFromSliceError;

    fn try_from_bytes(input: &[u8]) -> Result<Self, Self::Error> {
        // A length mismatch is reported as `TryFromSliceError`.
        input.try_into()
    }
}

/// Boxed slices accept any input (including an empty one) by copying the bytes.
impl TryFromBytes for Box<[u8]> {
    type Error = core::convert::Infallible;

    fn try_from_bytes(input: &[u8]) -> Result<Self, Self::Error> {
        Ok(Box::from(input))
    }
}
#[cfg(test)]
pub(crate) mod tests {
    use core::fmt;
    use serde::{de::DeserializeOwned, Serialize};

    /// Asserts that `obj` survives a serialize-then-deserialize round trip
    /// through JSON (`serde_json`) and through MessagePack (`rmp_serde`).
    ///
    /// JSON is human-readable, so it exercises the string-encoding path.
    /// MessagePack is presumably reported as non-human-readable, which would
    /// exercise the raw-bytes path (TODO confirm against the `rmp_serde`
    /// version in use).
    ///
    /// # Panics
    /// Panics if either format fails to serialize or deserialize, or if a
    /// decoded value differs from `obj`.
    pub(crate) fn check_serialization_roundtrip<T>(obj: &T)
    where
        T: fmt::Debug + PartialEq + Serialize + DeserializeOwned,
    {
        let json_text = serde_json::to_string(obj).unwrap();
        let from_json = serde_json::from_str::<T>(&json_text).unwrap();
        assert_eq!(obj, &from_json);

        let msgpack_bytes = rmp_serde::to_vec(obj).unwrap();
        let from_msgpack = rmp_serde::from_slice::<T>(&msgpack_bytes).unwrap();
        assert_eq!(obj, &from_msgpack);
    }
}