use std::hash::{BuildHasher, BuildHasherDefault, Hash};
use twox_hash::XxHash64;
/// Marker trait for key types: hashable, comparable for equality, thread-safe and `'static`.
pub trait Key: Send + Sync + 'static + Hash + Eq {}

/// Marker trait for value types: thread-safe and `'static`; no hashing or equality required.
pub trait Value: Send + Sync + 'static {}

// Blanket impls: any type satisfying the bounds automatically qualifies.
impl<K> Key for K where K: Send + Sync + 'static + Hash + Eq {}

impl<V> Value for V where V: Send + Sync + 'static {}
/// A [`BuildHasher`] that is `Send + Sync + 'static`, so it can be shared across threads.
pub trait HashBuilder: BuildHasher + Send + Sync + 'static {}

// Every builder that meets the bounds is a `HashBuilder`; no manual impls are needed.
impl<S: BuildHasher + Send + Sync + 'static> HashBuilder for S {}
/// Default hash builder: every hasher is created through [`BuildHasherDefault`], i.e. with
/// `XxHash64::default()`, so all instances of this builder hash identically.
///
/// NOTE(review): presumably this means a fixed seed (unlike std's `RandomState`); confirm
/// against twox_hash's `Default` impl. Also note the alias names a *builder*, not a `Hasher`,
/// despite sharing its name with `std::hash::DefaultHasher`.
pub type DefaultHasher = BuildHasherDefault<XxHash64>;
/// Errors produced while encoding or decoding values through [`Code`].
///
/// NOTE(review): `Io` and `Bincode` are not marked `#[source]`, so `Error::source()` returns
/// `None` for them; the inner error is only exposed through the `Display` message.
#[derive(Debug, thiserror::Error)]
pub enum CodeError {
    /// The encoded form did not fit. Produced from an I/O `WriteZero` error (e.g. encoding into
    /// a fixed-size buffer that fills up) and, with `serde`, from bincode's size-limit error.
    #[error("exceed size limit")]
    SizeLimit,
    /// Any other I/O failure reported by the underlying reader or writer.
    #[error("io error: {0}")]
    Io(std::io::Error),
    /// A bincode failure that is neither an I/O error nor a size-limit error.
    #[cfg(feature = "serde")]
    #[error("bincode error: {0}")]
    Bincode(bincode::Error),
    /// Bytes were read successfully but do not form a valid value (e.g. a `bool` byte other
    /// than 0/1, or non-UTF-8 `String` data); the offending bytes are carried along.
    #[error("unrecognized data: {0:?}")]
    Unrecognized(Vec<u8>),
    /// Catch-all for other errors; `#[from]` lets `?` convert a boxed error into this variant.
    #[error("other error: {0}")]
    Other(#[from] Box<dyn std::error::Error + Send + Sync>),
}
/// Shorthand for a `Result` whose error type is [`CodeError`].
pub type CodeResult<T> = std::result::Result<T, CodeError>;
/// Converts an I/O error into a [`CodeError`].
///
/// `WriteZero` is what `write_all` reports when the writer stops accepting bytes (for example a
/// full `&mut [u8]`), so it is classified as [`CodeError::SizeLimit`]. Every other kind is
/// wrapped unchanged in [`CodeError::Io`].
impl From<std::io::Error> for CodeError {
    fn from(err: std::io::Error) -> Self {
        if err.kind() == std::io::ErrorKind::WriteZero {
            Self::SizeLimit
        } else {
            Self::Io(err)
        }
    }
}
/// Converts a bincode error into a [`CodeError`], unwrapping cases that have a dedicated variant.
///
/// Wrapped I/O errors are routed through `From<std::io::Error>`, so a `WriteZero` still becomes
/// `SizeLimit`. bincode's own size-limit error maps to `SizeLimit` directly. Everything else is
/// kept as `Bincode`.
#[cfg(feature = "serde")]
impl From<bincode::Error> for CodeError {
    fn from(err: bincode::Error) -> Self {
        match *err {
            bincode::ErrorKind::Io(io_err) => Self::from(io_err),
            bincode::ErrorKind::SizeLimit => Self::SizeLimit,
            _ => Self::Bincode(err),
        }
    }
}
/// A [`Key`] that also implements [`Code`], so it can be encoded and decoded.
pub trait StorageKey: Key + Code {}
impl<K: Key + Code> StorageKey for K {}

/// A [`Value`] that also implements [`Code`]. (`'static` is already implied by `Value`; it is
/// spelled out explicitly here.)
pub trait StorageValue: Value + 'static + Code {}
impl<V: Value + Code> StorageValue for V {}
/// Binary encoding contract required by [`StorageKey`] and [`StorageValue`].
///
/// With the `serde` feature, a blanket impl encodes any `Serialize + DeserializeOwned` type with
/// `bincode`. Without it, this file provides hand-written impls for the primitive numeric
/// types, `bool`, `Vec<u8>` and `String`.
pub trait Code {
    /// Writes the encoded form of `self` to `writer`.
    ///
    /// # Errors
    ///
    /// Returns a [`CodeError`]. Via `From<std::io::Error>`, a writer that runs out of room
    /// (`WriteZero`) surfaces as [`CodeError::SizeLimit`].
    fn encode(&self, writer: &mut impl std::io::Write) -> std::result::Result<(), CodeError>;
    /// Reads one encoded value from `reader` and reconstructs it.
    ///
    /// # Errors
    ///
    /// Returns a [`CodeError`] if reading fails or the bytes do not form a valid value.
    fn decode(reader: &mut impl std::io::Read) -> std::result::Result<Self, CodeError>
    where
        Self: Sized;
    /// Size, in bytes, of the encoded form of `self`.
    ///
    /// The hand-written impls in this file return the exact encoded length. The serde impl
    /// delegates to `bincode::serialized_size`.
    fn estimated_size(&self) -> usize;
}
/// With the `serde` feature, every `Serialize + DeserializeOwned` type is encoded with `bincode`.
#[cfg(feature = "serde")]
impl<T> Code for T
where
    T: serde::Serialize + serde::de::DeserializeOwned,
{
    fn encode(&self, writer: &mut impl std::io::Write) -> std::result::Result<(), CodeError> {
        bincode::serialize_into(writer, self)?;
        Ok(())
    }

    fn decode(reader: &mut impl std::io::Read) -> std::result::Result<Self, CodeError> {
        let value = bincode::deserialize_from(reader)?;
        Ok(value)
    }

    /// Returns the size reported by `bincode::serialized_size`.
    ///
    /// # Panics
    ///
    /// Panics if `bincode` cannot size the value (e.g. the `Serialize` impl itself errors).
    fn estimated_size(&self) -> usize {
        let size: u64 = bincode::serialized_size(self).unwrap();
        size as usize
    }
}
// Despite its name, this macro generates the *non*-serde `Code` impls. Each listed primitive is
// written as its fixed-width little-endian byte representation (`to_le_bytes`) and read back with
// `from_le_bytes`.
macro_rules! impl_serde_for_numeric_types {
    ($($t:ty),*) => {
        $(
            #[cfg(not(feature = "serde"))]
            impl Code for $t {
                fn encode(&self, writer: &mut impl std::io::Write) -> std::result::Result<(), CodeError> {
                    let bytes = self.to_le_bytes();
                    writer.write_all(&bytes)?;
                    Ok(())
                }

                fn decode(reader: &mut impl std::io::Read) -> std::result::Result<Self, CodeError> {
                    let mut bytes = [0u8; std::mem::size_of::<$t>()];
                    reader.read_exact(&mut bytes)?;
                    Ok(Self::from_le_bytes(bytes))
                }

                fn estimated_size(&self) -> usize {
                    // Fixed width: the encoding is exactly the in-memory size of the primitive.
                    std::mem::size_of::<Self>()
                }
            }
        )*
    };
}
// Invokes `$macro` with the full list of primitive numeric types. It is used here to generate
// the non-serde `Code` impls, and in the tests to generate one round-trip test per type.
macro_rules! for_all_numeric_types {
    ($macro:ident) => {
        $macro! {
            u8, u16, u32, u64, u128, usize,
            i8, i16, i32, i64, i128, isize,
            f32, f64
        }
    };
}

for_all_numeric_types! { impl_serde_for_numeric_types }
/// A `bool` is encoded as a single byte: `0` for `false`, `1` for `true`.
#[cfg(not(feature = "serde"))]
impl Code for bool {
    fn encode(&self, writer: &mut impl std::io::Write) -> std::result::Result<(), CodeError> {
        let byte = u8::from(*self);
        writer.write_all(&[byte])?;
        Ok(())
    }

    /// Reads one byte. Anything other than `0` or `1` is rejected with
    /// [`CodeError::Unrecognized`] carrying that byte.
    fn decode(reader: &mut impl std::io::Read) -> std::result::Result<Self, CodeError> {
        let mut byte = [0u8; 1];
        reader.read_exact(&mut byte)?;
        match byte {
            [0] => Ok(false),
            [1] => Ok(true),
            unknown => Err(CodeError::Unrecognized(unknown.to_vec())),
        }
    }

    fn estimated_size(&self) -> usize {
        std::mem::size_of::<u8>()
    }
}
/// A byte buffer is encoded as a `usize` length prefix followed by the raw bytes.
///
/// The prefix uses the `usize` impl: little-endian, `size_of::<usize>()` bytes wide. Its width
/// therefore depends on the target's pointer size.
#[cfg(not(feature = "serde"))]
impl Code for Vec<u8> {
    /// Writes the length prefix, then the bytes themselves.
    fn encode(&self, writer: &mut impl std::io::Write) -> std::result::Result<(), CodeError> {
        self.len().encode(writer)?;
        writer.write_all(self).map_err(CodeError::from)
    }

    /// Reads a length prefix, then exactly that many bytes.
    ///
    /// # Errors
    ///
    /// Returns an I/O-derived [`CodeError`] if the reader fails. A stream that ends before the
    /// prefix or payload is complete surfaces as [`CodeError::Io`] with `UnexpectedEof`.
    fn decode(reader: &mut impl std::io::Read) -> std::result::Result<Self, CodeError>
    where
        Self: Sized,
    {
        let len = usize::decode(reader)?;
        // The buffer must be initialised before it is handed to `Read::read_exact`. `Read`
        // implementations may read from the slice they are given, so exposing uninitialised
        // memory through `Vec::set_len` is undefined behaviour.
        // NOTE(review): `len` is taken from the stream unchecked, so a corrupted prefix can
        // still request a very large allocation.
        let mut buf = vec![0u8; len];
        reader.read_exact(&mut buf)?;
        Ok(buf)
    }

    fn estimated_size(&self) -> usize {
        std::mem::size_of::<usize>() + self.len()
    }
}
/// A `String` is encoded as a `usize` length prefix followed by its UTF-8 bytes.
///
/// The prefix is little-endian and `size_of::<usize>()` bytes wide. This is the same layout
/// as the `Vec<u8>` impl.
#[cfg(not(feature = "serde"))]
impl Code for String {
    /// Writes the byte length, then the UTF-8 bytes.
    fn encode(&self, writer: &mut impl std::io::Write) -> std::result::Result<(), CodeError> {
        self.len().encode(writer)?;
        writer.write_all(self.as_bytes()).map_err(CodeError::from)
    }

    /// Reads a length prefix and that many bytes, then validates them as UTF-8.
    ///
    /// # Errors
    ///
    /// - An I/O-derived [`CodeError`] if the reader fails. A premature end surfaces as
    ///   [`CodeError::Io`] with `UnexpectedEof`.
    /// - [`CodeError::Unrecognized`] carrying the raw bytes if they are not valid UTF-8.
    fn decode(reader: &mut impl std::io::Read) -> std::result::Result<Self, CodeError>
    where
        Self: Sized,
    {
        let len = usize::decode(reader)?;
        // Zero-initialise before reading. `Read` impls may inspect the buffer they are given,
        // so handing them uninitialised memory (via `Vec::set_len`) is undefined behaviour.
        // NOTE(review): `len` is taken from the stream unchecked, so a corrupted prefix can
        // still request a very large allocation.
        let mut bytes = vec![0u8; len];
        reader.read_exact(&mut bytes)?;
        String::from_utf8(bytes).map_err(|e| CodeError::Unrecognized(e.into_bytes()))
    }

    fn estimated_size(&self) -> usize {
        std::mem::size_of::<usize>() + self.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[cfg(feature = "serde")]
    mod serde {
        use super::*;

        // Encoding a `u64` into a 4-byte buffer must surface as `SizeLimit`, not a raw I/O error.
        #[test]
        fn test_encode_overflow() {
            let mut buf = [0u8; 4];
            assert!(matches! {1u64.encode(&mut buf.as_mut()), Err(CodeError::SizeLimit)});
        }
    }

    #[cfg(not(feature = "serde"))]
    mod non_serde {
        use super::*;

        // Encoding a `u64` into a 4-byte buffer must surface as `SizeLimit`, not a raw I/O error.
        #[test]
        fn test_encode_overflow() {
            let mut buf = [0u8; 4];
            assert!(matches! {1u64.encode(&mut buf.as_mut()), Err(CodeError::SizeLimit)});
        }

        // Round-trips `0`, `MIN` and `MAX` for every type listed by `for_all_numeric_types!`.
        macro_rules! impl_serde_test_for_numeric_types {
            ($($t:ty),*) => {
                paste::paste! {
                    $(
                        #[test]
                        fn [<test_ $t _serde>]() {
                            for a in [0 as $t, <$t>::MIN, <$t>::MAX] {
                                let mut buf = vec![0xffu8; a.estimated_size()];
                                a.encode(&mut buf.as_mut_slice()).unwrap();
                                let b = <$t>::decode(&mut buf.as_slice()).unwrap();
                                assert_eq!(a, b);
                            }
                        }
                    )*
                }
            };
        }

        for_all_numeric_types! { impl_serde_test_for_numeric_types }

        #[test]
        fn test_bool_serde() {
            for a in [false, true] {
                let mut buf = vec![0xffu8; a.estimated_size()];
                a.encode(&mut buf.as_mut_slice()).unwrap();
                let b = bool::decode(&mut buf.as_slice()).unwrap();
                assert_eq!(a, b);
            }
        }

        #[test]
        fn test_bool_decode_rejects_invalid_byte() {
            let buf = [2u8];
            match bool::decode(&mut buf.as_slice()) {
                Err(CodeError::Unrecognized(bytes)) => assert_eq!(bytes, vec![2u8]),
                other => panic!("unexpected result: {other:?}"),
            }
        }

        #[test]
        fn test_vec_u8_serde() {
            let mut a = vec![0u8; 42];
            rand::fill(&mut a[..]);
            let mut buf = vec![0xffu8; a.estimated_size()];
            a.encode(&mut buf.as_mut_slice()).unwrap();
            let b = Vec::<u8>::decode(&mut buf.as_slice()).unwrap();
            assert_eq!(a, b);
        }

        #[test]
        fn test_empty_vec_u8_serde() {
            let a: Vec<u8> = Vec::new();
            let mut buf = vec![0xffu8; a.estimated_size()];
            a.encode(&mut buf.as_mut_slice()).unwrap();
            let b = Vec::<u8>::decode(&mut buf.as_slice()).unwrap();
            assert_eq!(a, b);
        }

        // A payload shorter than its length prefix must be reported as an I/O EOF error.
        #[test]
        fn test_vec_u8_decode_truncated() {
            let a = vec![7u8; 16];
            let mut buf = vec![0u8; a.estimated_size()];
            a.encode(&mut buf.as_mut_slice()).unwrap();
            buf.truncate(buf.len() - 1);
            match Vec::<u8>::decode(&mut buf.as_slice()) {
                Err(CodeError::Io(e)) => {
                    assert_eq!(e.kind(), std::io::ErrorKind::UnexpectedEof)
                }
                other => panic!("unexpected result: {other:?}"),
            }
        }

        #[test]
        fn test_string_serde() {
            for a in ["", "hello world", "héllo, 世界"] {
                let a = a.to_string();
                let mut buf = vec![0xffu8; a.estimated_size()];
                a.encode(&mut buf.as_mut_slice()).unwrap();
                let b = String::decode(&mut buf.as_slice()).unwrap();
                assert_eq!(a, b);
            }
        }

        // `Vec<u8>` and `String` share a wire format, so invalid UTF-8 can be injected via `Vec<u8>`.
        #[test]
        fn test_string_decode_rejects_invalid_utf8() {
            let raw = vec![0xffu8, 0xfe];
            let mut buf = vec![0u8; raw.estimated_size()];
            raw.encode(&mut buf.as_mut_slice()).unwrap();
            match String::decode(&mut buf.as_slice()) {
                Err(CodeError::Unrecognized(bytes)) => assert_eq!(bytes, raw),
                other => panic!("unexpected result: {other:?}"),
            }
        }
    }
}