use std::fmt;
/// Error produced by a `Codec` implementation.
///
/// The variant records which direction failed; the payload is a
/// human-readable description of the failure.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CodecError {
    /// Turning a value into bytes failed.
    Encode(String),
    /// Turning bytes back into a value (or into a zero-copy view) failed.
    Decode(String),
}
impl CodecError {
pub fn encode(e: impl fmt::Display) -> Self {
CodecError::Encode(e.to_string())
}
pub fn decode(e: impl fmt::Display) -> Self {
CodecError::Decode(e.to_string())
}
}
impl fmt::Display for CodecError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
CodecError::Encode(msg) => write!(f, "codec encode: {msg}"),
CodecError::Decode(msg) => write!(f, "codec decode: {msg}"),
}
}
}
impl std::error::Error for CodecError {}
/// A value that can be serialized into bytes and rebuilt from them.
///
/// Implementors must be `Send` (values may be moved to another thread) and
/// `Sized` (so `decode` can return `Self` by value).
pub trait Codec
where
    Self: Send + Sized,
{
    /// Byte container returned by `encode`; anything viewable as `&[u8]`
    /// (for example `Vec<u8>`).
    type Encoded: AsRef<[u8]>;

    /// Serializes `self` into an owned byte container.
    fn encode(&self) -> Result<Self::Encoded, CodecError>;

    /// Rebuilds an owned value from `bytes`.
    fn decode(bytes: &[u8]) -> Result<Self, CodecError>;
}
/// A `Codec` that can also hand out a borrowed view of encoded bytes
/// instead of building an owned value.
pub trait ZeroCopyCodec: Codec {
    /// View type that borrows from an input buffer living for `'buf`.
    type Archived<'buf>: 'buf
    where
        Self: 'buf;

    /// Interprets `bytes` and returns a view borrowing from them, or an error
    /// if they cannot be interpreted.
    fn access<'buf>(bytes: &'buf [u8]) -> Result<Self::Archived<'buf>, CodecError>;
}
/// Re-export of the `rkyv` crate this crate depends on, available only with
/// the `rkyv` feature, so downstream code can reach it as
/// `rkyv_support::rkyv`.
#[cfg(feature = "rkyv")]
pub mod rkyv_support {
    pub use ::rkyv;
}
/// Helpers over the `flatbuffers` crate, available only with the
/// `flatbuffers` feature.
#[cfg(feature = "flatbuffers")]
pub mod flatbuf_support {
    pub use flatbuffers;

    /// Reads the root table `T` from `bytes` via `flatbuffers::root` (which
    /// requires `T: Verifiable`).
    ///
    /// # Errors
    ///
    /// A failure reported by `flatbuffers::root` is returned as
    /// `CodecError::Decode` with the message `flatbuffers: <error>`.
    pub fn root<'a, T>(bytes: &'a [u8]) -> Result<T::Inner, super::CodecError>
    where
        T: flatbuffers::Follow<'a> + flatbuffers::Verifiable + 'a,
    {
        flatbuffers::root::<T>(bytes)
            .map_err(|err| super::CodecError::decode(format!("flatbuffers: {err}")))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Pass-through codec: the encoded form is the payload itself.
    #[derive(Debug, PartialEq)]
    struct RawMessage(Vec<u8>);

    impl Codec for RawMessage {
        type Encoded = Vec<u8>;

        fn encode(&self) -> Result<Vec<u8>, CodecError> {
            Ok(self.0.clone())
        }

        fn decode(bytes: &[u8]) -> Result<Self, CodecError> {
            Ok(RawMessage(bytes.to_vec()))
        }
    }

    #[test]
    fn test_raw_codec_roundtrip() {
        let msg = RawMessage(vec![1, 2, 3, 4, 5]);
        let bytes = msg.encode().unwrap();
        let decoded = RawMessage::decode(bytes.as_ref()).unwrap();
        assert_eq!(msg, decoded);
    }

    #[test]
    fn test_codec_error_display() {
        let err = CodecError::Encode("test error".to_string());
        assert_eq!(format!("{err}"), "codec encode: test error");
        let err = CodecError::Decode("bad data".to_string());
        assert_eq!(format!("{err}"), "codec decode: bad data");
    }

    #[test]
    fn test_codec_error_is_error() {
        let err: Box<dyn std::error::Error> = Box::new(CodecError::Encode("test".to_string()));
        assert!(err.to_string().contains("test"));
    }

    /// Size of an encoded `SimplePair`: `a` then `b`, each a little-endian `u64`.
    const ENCODED_LEN: usize = 16;

    /// Owned value used to exercise both `Codec` and `ZeroCopyCodec`.
    #[derive(Debug, PartialEq, Clone)]
    struct SimplePair {
        a: u64,
        b: u64,
    }

    /// Zero-copy view over the first `ENCODED_LEN` bytes of an encoded `SimplePair`.
    ///
    /// The fields are raw little-endian byte arrays rather than `u64`s. That makes
    /// reads correct on big-endian hosts too. It also gives the type alignment 1,
    /// so any byte slice can be viewed. A `Vec<u8>` allocation is only guaranteed
    /// to be byte-aligned, and sub-slices at odd offsets must work as well.
    #[repr(C)]
    struct ArchivedSimplePair {
        a_le: [u8; 8],
        b_le: [u8; 8],
    }

    // Compile-time guard for the layout that the `unsafe` cast in `access` relies on.
    const _: () = assert!(
        std::mem::size_of::<ArchivedSimplePair>() == ENCODED_LEN
            && std::mem::align_of::<ArchivedSimplePair>() == 1
    );

    impl ArchivedSimplePair {
        /// Decodes the `a` field from its little-endian bytes.
        fn a(&self) -> u64 {
            u64::from_le_bytes(self.a_le)
        }

        /// Decodes the `b` field from its little-endian bytes.
        fn b(&self) -> u64 {
            u64::from_le_bytes(self.b_le)
        }
    }

    impl Codec for SimplePair {
        type Encoded = Vec<u8>;

        fn encode(&self) -> Result<Vec<u8>, CodecError> {
            let mut buf = Vec::with_capacity(ENCODED_LEN);
            buf.extend_from_slice(&self.a.to_le_bytes());
            buf.extend_from_slice(&self.b.to_le_bytes());
            Ok(buf)
        }

        /// Copies the fields out of the zero-copy view, so the result owns its data.
        /// Input shorter than `ENCODED_LEN` fails with "too short"; trailing bytes
        /// are ignored.
        fn decode(bytes: &[u8]) -> Result<Self, CodecError> {
            let view = Self::access(bytes)?;
            Ok(SimplePair {
                a: view.a(),
                b: view.b(),
            })
        }
    }

    impl ZeroCopyCodec for SimplePair {
        type Archived<'a> = &'a ArchivedSimplePair;

        fn access<'a>(bytes: &'a [u8]) -> Result<Self::Archived<'a>, CodecError> {
            let head = bytes
                .get(..ENCODED_LEN)
                .ok_or_else(|| CodecError::decode("too short"))?;
            // SAFETY: `head` is exactly ENCODED_LEN bytes of initialized memory
            // borrowed for 'a. `ArchivedSimplePair` is `repr(C)` over two `[u8; 8]`
            // fields, so it is ENCODED_LEN bytes, has alignment 1 and no padding
            // (enforced by the const assertion above), and every bit pattern is
            // valid. The cast pointer is therefore in bounds, suitably aligned,
            // and valid for reads for 'a.
            Ok(unsafe { &*head.as_ptr().cast::<ArchivedSimplePair>() })
        }
    }

    #[test]
    fn zc01_zero_copy_codec_trait_compiles() {
        fn assert_zero_copy<T: ZeroCopyCodec>() {}
        assert_zero_copy::<SimplePair>();
    }

    #[test]
    fn zc02_access_returns_valid_reference() {
        let pair = SimplePair { a: 42, b: 99 };
        let encoded = pair.encode().unwrap();
        let archived = SimplePair::access(&encoded).unwrap();
        assert_eq!(archived.a(), 42);
        assert_eq!(archived.b(), 99);
    }

    #[test]
    fn zc03_access_and_decode_reject_short_input() {
        let err = SimplePair::access(&[0u8; ENCODED_LEN - 1])
            .err()
            .expect("input one byte short must be rejected");
        assert_eq!(err.to_string(), "codec decode: too short");
        assert!(SimplePair::decode(&[]).is_err());
    }

    #[test]
    fn zc04_access_pointer_is_inside_input_buffer() {
        let pair = SimplePair { a: 1, b: 2 };
        let encoded = pair.encode().unwrap();
        let archived = SimplePair::access(&encoded).unwrap();
        let archived_ptr = archived as *const ArchivedSimplePair as usize;
        let buf_start = encoded.as_ptr() as usize;
        let buf_end = buf_start + encoded.len();
        assert!(
            archived_ptr >= buf_start && archived_ptr < buf_end,
            "access() pointer {archived_ptr:#x} is NOT inside buffer [{buf_start:#x}, {buf_end:#x})"
        );
    }

    #[test]
    fn zc05_decode_pointer_is_not_inside_input_buffer() {
        let pair = SimplePair { a: 1, b: 2 };
        let encoded = pair.encode().unwrap();
        let decoded = SimplePair::decode(&encoded).unwrap();
        let decoded_ptr = &decoded as *const SimplePair as usize;
        let buf_start = encoded.as_ptr() as usize;
        let buf_end = buf_start + encoded.len();
        assert!(
            decoded_ptr < buf_start || decoded_ptr >= buf_end,
            "decode() pointer {decoded_ptr:#x} IS inside buffer — should be a copy!"
        );
    }

    #[test]
    fn zc06_access_accepts_unaligned_subslice() {
        let pair = SimplePair { a: 7, b: u64::MAX };
        // Prefix one byte so the encoded payload starts at an odd offset.
        let mut shifted = vec![0u8];
        shifted.extend_from_slice(&pair.encode().unwrap());
        let archived = SimplePair::access(&shifted[1..]).unwrap();
        assert_eq!((archived.a(), archived.b()), (7, u64::MAX));
        assert_eq!(SimplePair::decode(&shifted[1..]).unwrap(), pair);
    }
}