use std::error::Error;
use std::fmt;
/// Failure raised while writing or reading a wire envelope.
///
/// The `Display` text of every variant states the problem and then a `Fix:` hint.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum EnvelopeError {
    /// The input ended before a read could complete.
    Truncated {
        /// Input length (from the start of the blob) the failed read required.
        needed: usize,
        /// Actual input length.
        got: usize,
    },
    /// The leading four bytes did not equal the magic the reader expected.
    BadMagic {
        /// Magic the reader was configured with.
        expected: [u8; 4],
        /// Magic actually present in the blob.
        found: [u8; 4],
    },
    /// The header's version field differed from the version the reader expected.
    VersionMismatch {
        /// Version the reader was configured with.
        expected: u32,
        /// Version stored in the blob.
        found: u32,
    },
    /// A section's length did not fit in its `u32` length prefix.
    SectionTooLarge {
        /// Length of the rejected section (bytes or words, depending on the writer call).
        len: usize,
        /// Largest length the prefix can represent.
        max: usize,
    },
}

impl fmt::Display for EnvelopeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Every field is `Copy`, so the variants can be destructured by value.
        match *self {
            EnvelopeError::Truncated { needed, got } => write!(
                f,
                "wire envelope truncated: needed {} bytes, got {}. \
                 Fix: regenerate the cache.",
                needed, got
            ),
            EnvelopeError::BadMagic { expected, found } => write!(
                f,
                "wire envelope magic mismatch: expected {:?}, found {:?}. \
                 Fix: this blob was not produced by the matching consumer.",
                expected, found
            ),
            EnvelopeError::VersionMismatch { expected, found } => write!(
                f,
                "wire envelope version {} does not match runtime {}. \
                 Fix: discard the cache and rebuild from source.",
                found, expected
            ),
            EnvelopeError::SectionTooLarge { len, max } => write!(
                f,
                "wire envelope section length {} exceeds maximum {}. \
                 Fix: split the payload into smaller sections.",
                len, max
            ),
        }
    }
}

impl Error for EnvelopeError {}
/// Incrementally builds a wire envelope.
///
/// The buffer starts with the 4 magic bytes and the little-endian `u32` version; each
/// `write_*` call then appends to the end in call order.
#[derive(Debug)]
pub struct WireWriter {
    out: Vec<u8>,
}

impl WireWriter {
    /// Starts an envelope whose header is `magic` followed by `version` (little-endian).
    #[must_use]
    pub fn new(magic: &[u8; 4], version: u32) -> Self {
        let mut header = Vec::with_capacity(8);
        header.extend_from_slice(magic);
        header.extend_from_slice(&version.to_le_bytes());
        WireWriter { out: header }
    }

    /// Appends a byte section: a little-endian `u32` byte count followed by the bytes.
    ///
    /// # Errors
    ///
    /// Returns [`EnvelopeError::SectionTooLarge`] if `bytes.len()` does not fit in a `u32`;
    /// nothing is appended in that case.
    pub fn write_section(&mut self, bytes: &[u8]) -> Result<(), EnvelopeError> {
        let prefix = Self::length_prefix(bytes.len())?;
        self.out.extend_from_slice(&prefix);
        self.out.extend_from_slice(bytes);
        Ok(())
    }

    /// Appends a word section: a little-endian `u32` word count followed by each word
    /// as 4 little-endian bytes.
    ///
    /// # Errors
    ///
    /// Returns [`EnvelopeError::SectionTooLarge`] if `words.len()` does not fit in a `u32`;
    /// nothing is appended in that case.
    pub fn write_words(&mut self, words: &[u32]) -> Result<(), EnvelopeError> {
        let prefix = Self::length_prefix(words.len())?;
        self.out.extend_from_slice(&prefix);
        self.out
            .extend(words.iter().flat_map(|word| word.to_le_bytes()));
        Ok(())
    }

    /// Appends a bare little-endian `u32` with no length prefix.
    pub fn write_u32(&mut self, value: u32) {
        let encoded = value.to_le_bytes();
        self.out.extend_from_slice(&encoded);
    }

    /// Consumes the writer and returns the encoded envelope.
    #[must_use]
    pub fn into_bytes(self) -> Vec<u8> {
        let WireWriter { out } = self;
        out
    }

    /// Encodes `count` as a little-endian `u32` length prefix, or reports it as too large.
    fn length_prefix(count: usize) -> Result<[u8; 4], EnvelopeError> {
        match u32::try_from(count) {
            Ok(narrowed) => Ok(narrowed.to_le_bytes()),
            Err(_) => Err(EnvelopeError::SectionTooLarge {
                len: count,
                max: u32::MAX as usize,
            }),
        }
    }
}
/// Cursor-based reader over an envelope laid out the way `WireWriter` writes it.
///
/// The header (4 magic bytes + little-endian `u32` version) is validated once in
/// [`WireReader::new`]; the `read_*` methods then consume the body front to back.
/// Every read is bounds-checked against the source slice and reports
/// [`EnvelopeError::Truncated`] instead of panicking, even for absurd length prefixes.
#[derive(Debug)]
pub struct WireReader<'a> {
    /// The whole envelope, header included.
    src: &'a [u8],
    /// Offset in `src` of the next unread byte.
    cursor: usize,
}

impl<'a> WireReader<'a> {
    /// Length of the fixed header: 4 magic bytes plus a 4-byte version.
    const HEADER_LEN: usize = 8;

    /// Validates the header of `bytes` and positions the reader just after it.
    ///
    /// # Errors
    ///
    /// - [`EnvelopeError::Truncated`] if `bytes` is shorter than the 8-byte header.
    /// - [`EnvelopeError::BadMagic`] if the first four bytes differ from `expected_magic`.
    /// - [`EnvelopeError::VersionMismatch`] if the little-endian version field differs
    ///   from `expected_version`.
    pub fn new(
        bytes: &'a [u8],
        expected_magic: &[u8; 4],
        expected_version: u32,
    ) -> Result<Self, EnvelopeError> {
        if bytes.len() < Self::HEADER_LEN {
            return Err(EnvelopeError::Truncated {
                needed: Self::HEADER_LEN,
                got: bytes.len(),
            });
        }
        let mut found_magic = [0u8; 4];
        found_magic.copy_from_slice(&bytes[0..4]);
        if &found_magic != expected_magic {
            return Err(EnvelopeError::BadMagic {
                expected: *expected_magic,
                found: found_magic,
            });
        }
        let version = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
        if version != expected_version {
            return Err(EnvelopeError::VersionMismatch {
                expected: expected_version,
                found: version,
            });
        }
        Ok(Self {
            src: bytes,
            cursor: Self::HEADER_LEN,
        })
    }

    /// Reads a `u32` byte count followed by that many bytes, borrowing them from the source.
    ///
    /// # Errors
    ///
    /// [`EnvelopeError::Truncated`] if the length prefix or the section body runs past the
    /// end of the input. The length prefix stays consumed if only the body is missing.
    pub fn read_section(&mut self) -> Result<&'a [u8], EnvelopeError> {
        // `u32 -> usize` is lossless on 32- and 64-bit targets.
        let len = self.read_u32()? as usize;
        self.take(len)
    }

    /// Reads a `u32` word count followed by that many little-endian `u32` words.
    ///
    /// # Errors
    ///
    /// [`EnvelopeError::Truncated`] if the count or the words run past the end of the
    /// input. The count stays consumed if only the words are missing.
    pub fn read_words(&mut self) -> Result<Vec<u32>, EnvelopeError> {
        let n_words = self.read_u32()? as usize;
        // Saturating: a hostile count must not wrap `usize` on 32-bit targets and sneak
        // past the bounds check inside `take`.
        let raw = self.take(n_words.saturating_mul(4))?;
        Ok(raw
            .chunks_exact(4)
            .map(|w| u32::from_le_bytes([w[0], w[1], w[2], w[3]]))
            .collect())
    }

    /// Reads one little-endian `u32`.
    ///
    /// # Errors
    ///
    /// [`EnvelopeError::Truncated`] if fewer than 4 bytes remain; the cursor is unchanged.
    pub fn read_u32(&mut self) -> Result<u32, EnvelopeError> {
        let raw = self.take(4)?;
        Ok(u32::from_le_bytes([raw[0], raw[1], raw[2], raw[3]]))
    }

    /// Borrows the next `len` bytes and advances past them, or reports truncation
    /// without moving the cursor.
    fn take(&mut self, len: usize) -> Result<&'a [u8], EnvelopeError> {
        let src: &'a [u8] = self.src;
        // `saturating_add` instead of `+`: `len` comes from untrusted input and an
        // overflowing sum would otherwise panic (debug) or wrap (release).
        let end = self.cursor.saturating_add(len);
        if end > src.len() {
            return Err(EnvelopeError::Truncated {
                needed: end,
                got: src.len(),
            });
        }
        let chunk = &src[self.cursor..end];
        self.cursor = end;
        Ok(chunk)
    }
}
/// Shared assertions for exercising wire-envelope encoders and decoders.
pub mod test_helpers {
    use super::{EnvelopeError, WireWriter};

    /// A value that encodes to, and decodes from, a complete wire-envelope blob.
    ///
    /// `MAGIC` and `VERSION` describe the header the encoder is expected to emit;
    /// [`assert_envelope_roundtrip`] checks that header and then tampers with it.
    pub trait WireRoundTrip: Sized {
        /// Bytes expected at offsets `0..4` of every encoded blob.
        const MAGIC: [u8; 4];
        /// Value expected, little-endian, at offsets `4..8` of every encoded blob.
        const VERSION: u32;
        /// Error type returned by [`WireRoundTrip::to_bytes`].
        type EncodeError: std::fmt::Debug;
        /// Error type returned by [`WireRoundTrip::from_bytes`].
        type DecodeError: std::fmt::Debug;

        /// Encodes `self` into a full blob, header included.
        fn to_bytes(&self) -> Result<Vec<u8>, Self::EncodeError>;
        /// Decodes a full blob, header included.
        fn from_bytes(bytes: &[u8]) -> Result<Self, Self::DecodeError>;
        /// Equality used to judge whether a round trip preserved the value.
        fn structurally_eq(&self, other: &Self) -> bool;
    }

    /// Checks that `sample` survives encode/decode and that the decoder rejects tampering.
    ///
    /// Steps: encode; verify the blob is at least 8 bytes, starts with `T::MAGIC`, and
    /// carries `T::VERSION` little-endian at bytes `4..8`; decode and compare with
    /// [`WireRoundTrip::structurally_eq`]. Then require `from_bytes` to fail on a copy
    /// with byte 0 flipped, on a copy with the version incremented (wrapping), and, when
    /// the blob has a body, on a copy missing its final byte.
    ///
    /// # Panics
    ///
    /// Panics with a descriptive message as soon as any of those checks fails.
    pub fn assert_envelope_roundtrip<T>(sample: &T)
    where
        T: WireRoundTrip + std::fmt::Debug,
    {
        let encoded = sample
            .to_bytes()
            .expect("Fix: encode sample; restore this invariant before continuing.");

        // Header: 4 magic bytes, then the version as a little-endian u32.
        assert!(
            encoded.len() >= 8,
            "wire blob must include at least the 8-byte header"
        );
        assert_eq!(
            &encoded[..4],
            &T::MAGIC[..],
            "magic mismatch in encoded blob"
        );
        let mut version_raw = [0u8; 4];
        version_raw.copy_from_slice(&encoded[4..8]);
        let version_field = u32::from_le_bytes(version_raw);
        let expected_version = T::VERSION;
        assert!(
            version_field == expected_version,
            "version mismatch in encoded blob: got {version_field}, expected {expected_version}"
        );

        let decoded = T::from_bytes(&encoded)
            .expect("Fix: decode round trip; restore this invariant before continuing.");
        assert!(
            sample.structurally_eq(&decoded),
            "round-tripped value diverges from original"
        );

        // Every tampered copy below must be refused by the decoder.
        let rejects = |blob: &[u8]| T::from_bytes(blob).is_err();

        let mut bad_magic = encoded.clone();
        bad_magic[0] ^= 0xFF;
        assert!(
            rejects(&bad_magic[..]),
            "mutated magic must surface as a typed error"
        );

        let mut bad_version = encoded.clone();
        let next_version = T::VERSION.wrapping_add(1);
        bad_version[4..8].copy_from_slice(&next_version.to_le_bytes());
        assert!(
            rejects(&bad_version[..]),
            "mutated version must surface as a typed error"
        );

        // Only meaningful when there is a body after the header to chop.
        if encoded.len() > 8 {
            let shortened = &encoded[..encoded.len() - 1];
            assert!(
                rejects(shortened),
                "truncated trailing byte must surface as a typed error"
            );
        }
    }

    /// Returns an envelope consisting solely of the header for `magic` and `version`.
    #[must_use]
    pub fn header_only(magic: &[u8; 4], version: u32) -> Vec<u8> {
        let writer = WireWriter::new(magic, version);
        writer.into_bytes()
    }

    /// Asserts that `err` is the variant named by `kind`, ignoring the variant's fields.
    ///
    /// # Panics
    ///
    /// Panics if the variant of `err` does not correspond to `kind`.
    pub fn assert_envelope_error_kind(err: &EnvelopeError, kind: ExpectedEnvelopeError) {
        let actual = match err {
            EnvelopeError::Truncated { .. } => ExpectedEnvelopeError::Truncated,
            EnvelopeError::BadMagic { .. } => ExpectedEnvelopeError::BadMagic,
            EnvelopeError::VersionMismatch { .. } => ExpectedEnvelopeError::VersionMismatch,
            EnvelopeError::SectionTooLarge { .. } => ExpectedEnvelopeError::SectionTooLarge,
        };
        assert!(
            actual == kind,
            "expected envelope error kind {kind:?}, got {err:?}"
        );
    }

    /// Field-less mirror of the envelope error variants, for kind-only assertions.
    #[derive(Debug, Clone, Copy, PartialEq, Eq)]
    pub enum ExpectedEnvelopeError {
        /// Matches the `Truncated` error variant.
        Truncated,
        /// Matches the `BadMagic` error variant.
        BadMagic,
        /// Matches the `VersionMismatch` error variant.
        VersionMismatch,
        /// Matches the `SectionTooLarge` error variant.
        SectionTooLarge,
    }
}