use crate::cacheable::Cacheable;
use crate::error::WireFormatError;
use serde::{Serialize, de::DeserializeOwned};
/// Major version of the binary wire format, written little-endian into every header.
///
/// `decode_header` rejects any payload whose major version differs from this value.
pub const WIRE_FORMAT_MAJOR: u16 = 1;

/// Eight-byte magic prefix that opens every encoded header.
const MAGIC: &[u8; 8] = b"SASSI\0W\0";

/// Kind byte for a single standalone value (`WireKind::Value`).
pub(crate) const KIND_VALUE: u8 = 0x01;
/// Kind byte for `WireKind::FileEntry`.
pub(crate) const KIND_FILE_ENTRY: u8 = 0x02;
/// Kind byte for a counted sequence of entries (`WireKind::PunnuEntries`).
pub(crate) const KIND_PUNNU_ENTRIES: u8 = 0x03;
/// Kind byte that `WireKind` has no variant for; `decode_header` rejects this and every
/// higher kind byte as `UnsupportedKind`.
pub(crate) const KIND_PUNNU_ENTRIES_WITH_HINTS: u8 = 0x04;

/// Size of the fixed header prefix that precedes the variable-length type name.
const HEADER_FIXED_LEN: usize = MAGIC.len() // magic
    + std::mem::size_of::<u16>() // major version (LE)
    + 1 // kind byte
    + 1 // flags byte
    + std::mem::size_of::<u16>(); // type-name length (LE)
/// The kind of frame a header announces; serialized as a single kind byte.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub(crate) enum WireKind {
    /// A single standalone value.
    Value,
    /// A file entry frame.
    FileEntry,
    /// A counted sequence of entries.
    PunnuEntries,
}

impl WireKind {
    /// Returns the kind byte written into the header for this frame kind.
    pub(crate) fn as_u8(self) -> u8 {
        let byte = match self {
            WireKind::Value => KIND_VALUE,
            WireKind::FileEntry => KIND_FILE_ENTRY,
            WireKind::PunnuEntries => KIND_PUNNU_ENTRIES,
        };
        byte
    }
}
/// Appends the wire header for payloads of type `T` to `out`.
///
/// Layout: magic (8 bytes) | major version (LE `u16`) | kind byte | flags byte (always 0) |
/// type-name length (LE `u16`) | type-name bytes.
///
/// # Errors
///
/// Returns `WireFormatError::MalformedHeader` if `T::cache_type_name()` is longer than
/// `u16::MAX` bytes. In that case nothing is written to `out`.
pub(crate) fn encode_header<T: Cacheable>(
    kind: WireKind,
    out: &mut Vec<u8>,
) -> Result<(), WireFormatError> {
    let name = T::cache_type_name().as_bytes();
    // Validate the length before touching `out` so a failure leaves the buffer unchanged.
    let name_len = u16::try_from(name.len()).map_err(|_| {
        WireFormatError::MalformedHeader("cache type name exceeds u16 length".into())
    })?;
    let major = WIRE_FORMAT_MAJOR.to_le_bytes();
    let kind_and_flags = [kind.as_u8(), 0];
    let length = name_len.to_le_bytes();
    let sections: [&[u8]; 5] = [MAGIC, &major, &kind_and_flags, &length, name];
    for section in sections {
        out.extend_from_slice(section);
    }
    Ok(())
}
/// Validates the wire header at the front of `bytes` and returns the bytes that follow it.
///
/// The checks run in this order:
/// 1. leading `{` byte
/// 2. minimum fixed-header length
/// 3. magic
/// 4. major version
/// 5. kind byte is supported
/// 6. kind byte matches `expected`
/// 7. flags byte is zero
/// 8. type name fits in the input
/// 9. type name is UTF-8
/// 10. type name equals `T::cache_type_name()`
///
/// # Errors
///
/// - `VersionMismatch` for a leading `{` (reported as version 0) or a different major version.
/// - `MalformedHeader` for truncated input or a non-UTF-8 type name.
/// - `InvalidMagic` when the first eight bytes are not `MAGIC`.
/// - `UnsupportedKind` for kind bytes at or above `KIND_PUNNU_ENTRIES_WITH_HINTS`.
/// - `KindMismatch` when the kind byte is not `expected`.
/// - `UnsupportedFlags` when the flags byte is non-zero.
/// - `TypeNameMismatch` when the encoded type name differs from `T::cache_type_name()`.
pub(crate) fn decode_header<T: Cacheable>(
    bytes: &[u8],
    expected: WireKind,
) -> Result<&[u8], WireFormatError> {
    // A leading `{` is reported as version 0. Presumably this marks a JSON encoding that
    // predates this binary format (TODO confirm).
    if matches!(bytes.first(), Some(b'{')) {
        return Err(WireFormatError::VersionMismatch {
            got: 0,
            expected: WIRE_FORMAT_MAJOR,
        });
    }
    if bytes.len() < HEADER_FIXED_LEN {
        return Err(WireFormatError::MalformedHeader("header too short".into()));
    }
    let (fixed, after_fixed) = bytes.split_at(HEADER_FIXED_LEN);

    if fixed[..8] != MAGIC[..] {
        return Err(WireFormatError::InvalidMagic);
    }

    let major = u16::from_le_bytes([fixed[8], fixed[9]]);
    if major != WIRE_FORMAT_MAJOR {
        return Err(WireFormatError::VersionMismatch {
            got: major,
            expected: WIRE_FORMAT_MAJOR,
        });
    }

    let kind = fixed[10];
    if kind >= KIND_PUNNU_ENTRIES_WITH_HINTS {
        return Err(WireFormatError::UnsupportedKind { kind });
    }
    let expected_kind = expected.as_u8();
    if kind != expected_kind {
        return Err(WireFormatError::KindMismatch {
            got: kind,
            expected: expected_kind,
        });
    }

    let flags = fixed[11];
    if flags != 0 {
        return Err(WireFormatError::UnsupportedFlags { flags });
    }

    let name_len = usize::from(u16::from_le_bytes([fixed[12], fixed[13]]));
    if after_fixed.len() < name_len {
        return Err(WireFormatError::MalformedHeader(
            "type name extends past input".into(),
        ));
    }
    let (name_bytes, body) = after_fixed.split_at(name_len);

    match std::str::from_utf8(name_bytes) {
        Err(err) => Err(WireFormatError::MalformedHeader(err.to_string())),
        Ok(got) => {
            let expected_name = T::cache_type_name();
            if got == expected_name {
                Ok(body)
            } else {
                Err(WireFormatError::TypeNameMismatch {
                    got: got.to_owned(),
                    expected: expected_name,
                })
            }
        }
    }
}
/// Decodes exactly one postcard value from `body`; the whole input must be consumed.
///
/// # Errors
///
/// Returns `WireFormatError::Codec` in two cases:
/// - postcard fails to decode a `T`;
/// - bytes remain after the value.
pub(crate) fn decode_postcard_exact<T>(body: &[u8]) -> Result<T, WireFormatError>
where
    T: DeserializeOwned,
{
    match postcard::take_from_bytes(body) {
        // An empty remainder means the value spanned the entire body.
        Ok((value, [])) => Ok(value),
        Ok((_, _)) => Err(WireFormatError::Codec(
            "trailing bytes after postcard body".into(),
        )),
        Err(err) => Err(WireFormatError::Codec(err.to_string())),
    }
}
/// Encodes `payload` as a standalone value frame: a `WireKind::Value` header followed by the
/// postcard-encoded payload.
///
/// # Errors
///
/// Returns `WireFormatError::MalformedHeader` if `T::cache_type_name()` is longer than
/// `u16::MAX` bytes, or `WireFormatError::Codec` if postcard serialization fails.
pub fn to_vec<T>(payload: &T) -> Result<Vec<u8>, WireFormatError>
where
    T: Cacheable + Serialize,
{
    let mut encoded = Vec::new();
    encode_header::<T>(WireKind::Value, &mut encoded)?;
    append_postcard(payload, &mut encoded).map(|()| encoded)
}
/// Serializes `payload` with postcard and appends the resulting bytes to `out`.
///
/// # Errors
///
/// Returns `WireFormatError::Codec` if postcard serialization fails. In that case `out` is
/// left unchanged.
pub(crate) fn append_postcard<T>(payload: &T, out: &mut Vec<u8>) -> Result<(), WireFormatError>
where
    T: Serialize + ?Sized,
{
    match postcard::to_allocvec(payload) {
        Ok(mut encoded) => {
            out.append(&mut encoded);
            Ok(())
        }
        Err(err) => Err(WireFormatError::Codec(err.to_string())),
    }
}
pub fn from_slice<T>(bytes: &[u8]) -> Result<T, WireFormatError>
where
T: Cacheable + DeserializeOwned,
{
let body = decode_header::<T>(bytes, WireKind::Value)?;
decode_postcard_exact(body)
}
/// Encodes `entries` as a `WireKind::PunnuEntries` frame.
///
/// Layout: header | entry count (LE `u32`) | each entry postcard-encoded back to back.
///
/// # Errors
///
/// - Header errors from `encode_header`.
/// - `WireFormatError::Codec` if there are more than `u32::MAX` entries.
/// - `WireFormatError::Codec` if any entry fails to serialize.
pub(crate) fn encode_punnu_entries<T>(entries: &[&T]) -> Result<Vec<u8>, WireFormatError>
where
    T: Cacheable + Serialize,
{
    let mut out = Vec::new();
    encode_header::<T>(WireKind::PunnuEntries, &mut out)?;
    let count: u32 = entries
        .len()
        .try_into()
        .map_err(|_| WireFormatError::Codec("too many punnu entries".into()))?;
    out.extend_from_slice(&count.to_le_bytes());
    entries
        .iter()
        .try_for_each(|&entry| append_postcard(entry, &mut out))?;
    Ok(out)
}
/// Validates a `WireKind::PunnuEntries` header and reads the little-endian `u32` entry count.
///
/// Returns the count together with the bytes that follow it (the encoded entries).
///
/// # Errors
///
/// Returns any error from `decode_header`, or `WireFormatError::MalformedHeader` if fewer
/// than four bytes follow the header.
pub(crate) fn decode_punnu_entries_len<T>(bytes: &[u8]) -> Result<(usize, &[u8]), WireFormatError>
where
    T: Cacheable,
{
    let body = decode_header::<T>(bytes, WireKind::PunnuEntries)?;
    match body {
        [b0, b1, b2, b3, rest @ ..] => {
            let count = u32::from_le_bytes([*b0, *b1, *b2, *b3]);
            Ok((count as usize, rest))
        }
        _ => Err(WireFormatError::MalformedHeader(
            "punnu entries body missing count".into(),
        )),
    }
}
/// Decodes `count` back-to-back postcard entries from `body`, requiring that every byte is
/// consumed.
///
/// Presumably `count` and `body` are the pair returned by `decode_punnu_entries_len`
/// (the call site is not shown here).
///
/// # Errors
///
/// Returns `WireFormatError::Codec` in three cases:
/// - the capacity reservation fails;
/// - any entry fails to decode;
/// - bytes remain after the last entry.
pub(crate) fn decode_punnu_entries_body<T>(
    mut body: &[u8],
    count: usize,
) -> Result<Vec<T>, WireFormatError>
where
    T: Cacheable + DeserializeOwned,
{
    // `count` is read from the payload itself, so a corrupt or hostile header could claim
    // billions of entries. Trusting it for the reservation would request that much memory
    // before a single entry is parsed. A payload normally spends at least one byte per entry,
    // so the remaining input length is a safe ceiling. If an entry ever decodes from zero
    // bytes, the vector simply grows past this reservation.
    let capacity = count.min(body.len());
    let mut entries: Vec<T> = Vec::new();
    entries.try_reserve_exact(capacity).map_err(|err| {
        WireFormatError::Codec(format!(
            "could not reserve capacity for {capacity} punnu entries: {err}"
        ))
    })?;
    for _ in 0..count {
        let (entry, rest) = postcard::take_from_bytes(body)
            .map_err(|err| WireFormatError::Codec(err.to_string()))?;
        entries.push(entry);
        body = rest;
    }
    if !body.is_empty() {
        return Err(WireFormatError::Codec(
            "trailing bytes after punnu entries body".into(),
        ));
    }
    Ok(entries)
}