use core::fmt;
/// A destination for the bytes produced by the `write_*` encoders.
///
/// Encoders hand over their output in order, one chunk per call; what the
/// implementation does with each chunk (buffer it, hash it, ...) is up to it.
pub trait Sink {
    /// Accepts the next chunk of encoded bytes.
    fn put(&mut self, bytes: &[u8]);
}
/// Lets the generic `write_*` encoders stream their output straight into a
/// BLAKE3 hasher.
impl Sink for blake3::Hasher {
    fn put(&mut self, bytes: &[u8]) {
        // `update` returns `&mut Self` for chaining; the result is not needed here.
        blake3::Hasher::update(self, bytes);
    }
}
/// Collects encoded output in memory, appending each chunk to the end of the
/// vector.
impl Sink for Vec<u8> {
    fn put(&mut self, bytes: &[u8]) {
        self.extend(bytes);
    }
}
pub fn write_u8<S: Sink>(out: &mut S, n: u8) {
out.put(&[n]);
}
pub fn write_u16<S: Sink>(out: &mut S, n: u16) {
out.put(&n.to_le_bytes());
}
pub fn write_u32<S: Sink>(out: &mut S, n: u32) {
out.put(&n.to_le_bytes());
}
pub fn write_u64<S: Sink>(out: &mut S, n: u64) {
out.put(&n.to_le_bytes());
}
pub fn write_u128<S: Sink>(out: &mut S, n: u128) {
out.put(&n.to_le_bytes());
}
pub fn write_i8<S: Sink>(out: &mut S, n: i8) {
out.put(&n.to_le_bytes());
}
pub fn write_i16<S: Sink>(out: &mut S, n: i16) {
out.put(&n.to_le_bytes());
}
pub fn write_i32<S: Sink>(out: &mut S, n: i32) {
out.put(&n.to_le_bytes());
}
pub fn write_i64<S: Sink>(out: &mut S, n: i64) {
out.put(&n.to_le_bytes());
}
pub fn write_i128<S: Sink>(out: &mut S, n: i128) {
out.put(&n.to_le_bytes());
}
pub fn write_f32<S: Sink>(out: &mut S, n: f32) {
out.put(&n.to_le_bytes());
}
pub fn write_f64<S: Sink>(out: &mut S, n: f64) {
out.put(&n.to_le_bytes());
}
/// Appends `b` as a single byte: `1` for `true`, `0` for `false`.
pub fn write_bool<S: Sink>(out: &mut S, b: bool) {
    let byte: u8 = b.into();
    write_u8(out, byte);
}
pub fn write_str<S: Sink>(out: &mut S, s: &str) {
write_u32(out, s.len() as u32);
out.put(s.as_bytes());
}
pub fn write_bytes<S: Sink>(out: &mut S, b: &[u8]) {
write_u32(out, b.len() as u32);
out.put(b);
}
/// Appends zero bytes to `out` until its length is a multiple of `n`.
///
/// Nothing is appended when the length is already aligned. An alignment of
/// `0` is treated like `1` (no alignment) and leaves `out` unchanged.
pub fn pad_to(out: &mut Vec<u8>, n: usize) {
    // Only 0 is a multiple of 0, so a zero alignment can never be reached by
    // appending. Waiting for it would loop until memory runs out, and
    // `len % 0` would panic.
    if n == 0 {
        return;
    }
    let misalignment = out.len() % n;
    if misalignment != 0 {
        out.resize(out.len() + (n - misalignment), 0);
    }
}
/// Error returned when decoding a value from the input fails.
///
/// The enum is `#[non_exhaustive]`, so code outside this crate must keep a
/// wildcard arm when matching on it.
#[derive(Clone, Debug, PartialEq, Eq)]
#[non_exhaustive]
pub enum DecodeError {
    /// A read needed `needed` bytes but only `remaining` were left.
    UnexpectedEof { needed: usize, remaining: usize },
    /// A tag byte that is not recognised.
    UnknownTag(u8),
    /// A boolean byte other than `0` or `1`.
    InvalidBool(u8),
    /// String bytes that are not valid UTF-8.
    InvalidUtf8,
    /// A `u32` that is not a valid Unicode scalar value.
    InvalidChar(u32),
    /// A length prefix of `count` elements that the `remaining` input bytes
    /// cannot hold.
    LengthTooLarge { count: u64, remaining: usize },
    /// Values were nested deeper than the decoder allows.
    DepthExceeded,
    /// A map contained the same key more than once.
    DuplicateKey,
    /// A set contained the same element more than once.
    DuplicateElement,
    /// Tag byte `got` appeared where a value described by `expected` was required.
    UnexpectedTag { expected: &'static str, got: u8 },
    /// A variant name that is not recognised.
    UnknownVariant(String),
    /// An enum variant index that the reader's schema does not have.
    WriterOnlyVariant(u32),
    /// An enum variant index outside the valid range.
    BadVariantIndex(u32),
    /// A structurally invalid value; the payload says what was wrong.
    Malformed(&'static str),
    /// This many bytes were left over after a complete value.
    TrailingBytes(usize),
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::UnexpectedEof { needed, remaining } => write!(
                f,
                "unexpected end of input: need {needed}, have {remaining}"
            ),
            Self::UnknownTag(t) => write!(f, "unknown tag {t:#04x}"),
            Self::InvalidBool(b) => write!(f, "invalid bool byte {b:#04x}"),
            Self::InvalidUtf8 => f.write_str("invalid UTF-8 in string"),
            Self::InvalidChar(c) => write!(f, "invalid Unicode scalar {c:#x}"),
            Self::LengthTooLarge { count, remaining } => {
                write!(f, "length {count} exceeds {remaining} bytes remaining")
            }
            Self::DepthExceeded => f.write_str("maximum nesting depth exceeded"),
            Self::DuplicateKey => f.write_str("duplicate map key"),
            Self::DuplicateElement => f.write_str("duplicate set element"),
            Self::UnexpectedTag { expected, got } => {
                write!(f, "expected {expected}, got tag {got:#04x}")
            }
            Self::UnknownVariant(name) => write!(f, "unknown variant {name:?}"),
            Self::WriterOnlyVariant(i) => write!(
                f,
                "received enum variant {i} the reader schema does not have"
            ),
            Self::BadVariantIndex(i) => write!(f, "enum variant index {i} out of range"),
            Self::Malformed(what) => write!(f, "malformed value: {what}"),
            Self::TrailingBytes(n) => write!(f, "{n} trailing bytes after value"),
        }
    }
}

// No underlying cause is wrapped, so the default `source()` (None) is correct.
impl std::error::Error for DecodeError {}
pub struct Reader<'a> {
buf: &'a [u8],
pos: usize,
}
impl<'a> Reader<'a> {
#[must_use]
pub fn new(buf: &'a [u8]) -> Self {
Reader { buf, pos: 0 }
}
#[must_use]
pub fn remaining(&self) -> usize {
self.buf.len() - self.pos
}
#[must_use]
pub fn position(&self) -> usize {
self.pos
}
fn take(&mut self, n: usize) -> Result<&'a [u8], DecodeError> {
if self.remaining() < n {
return Err(DecodeError::UnexpectedEof {
needed: n,
remaining: self.remaining(),
});
}
let slice = &self.buf[self.pos..self.pos + n];
self.pos += n;
Ok(slice)
}
pub fn read_slice(&mut self, n: usize) -> Result<&'a [u8], DecodeError> {
self.take(n)
}
pub fn read_u8(&mut self) -> Result<u8, DecodeError> {
Ok(self.take(1)?[0])
}
pub fn read_u16(&mut self) -> Result<u16, DecodeError> {
Ok(u16::from_le_bytes(self.take(2)?.try_into().unwrap()))
}
pub fn read_u32(&mut self) -> Result<u32, DecodeError> {
Ok(u32::from_le_bytes(self.take(4)?.try_into().unwrap()))
}
pub fn read_u64(&mut self) -> Result<u64, DecodeError> {
Ok(u64::from_le_bytes(self.take(8)?.try_into().unwrap()))
}
pub fn read_u128(&mut self) -> Result<u128, DecodeError> {
Ok(u128::from_le_bytes(self.take(16)?.try_into().unwrap()))
}
pub fn read_i8(&mut self) -> Result<i8, DecodeError> {
Ok(i8::from_le_bytes(self.take(1)?.try_into().unwrap()))
}
pub fn read_i16(&mut self) -> Result<i16, DecodeError> {
Ok(i16::from_le_bytes(self.take(2)?.try_into().unwrap()))
}
pub fn read_i32(&mut self) -> Result<i32, DecodeError> {
Ok(i32::from_le_bytes(self.take(4)?.try_into().unwrap()))
}
pub fn read_i64(&mut self) -> Result<i64, DecodeError> {
Ok(i64::from_le_bytes(self.take(8)?.try_into().unwrap()))
}
pub fn read_i128(&mut self) -> Result<i128, DecodeError> {
Ok(i128::from_le_bytes(self.take(16)?.try_into().unwrap()))
}
pub fn read_f32(&mut self) -> Result<f32, DecodeError> {
Ok(f32::from_le_bytes(self.take(4)?.try_into().unwrap()))
}
pub fn read_f64(&mut self) -> Result<f64, DecodeError> {
Ok(f64::from_le_bytes(self.take(8)?.try_into().unwrap()))
}
pub fn read_bool(&mut self) -> Result<bool, DecodeError> {
match self.read_u8()? {
0 => Ok(false),
1 => Ok(true),
b => Err(DecodeError::InvalidBool(b)),
}
}
pub fn read_char(&mut self) -> Result<char, DecodeError> {
let n = self.read_u32()?;
char::from_u32(n).ok_or(DecodeError::InvalidChar(n))
}
pub fn read_str(&mut self) -> Result<&'a str, DecodeError> {
let len = self.read_len(1)?;
let bytes = self.take(len)?;
core::str::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8)
}
pub fn read_bytes(&mut self) -> Result<&'a [u8], DecodeError> {
let len = self.read_len(1)?;
self.take(len)
}
pub fn read_len(&mut self, min_elem_size: usize) -> Result<usize, DecodeError> {
let count = self.read_u32()? as usize;
let remaining = self.remaining();
let max = remaining
.checked_div(min_elem_size)
.unwrap_or(ZST_COUNT_CAP);
if count > max {
return Err(DecodeError::LengthTooLarge {
count: count as u64,
remaining,
});
}
Ok(count)
}
}
pub const ZST_COUNT_CAP: usize = 1 << 24;
/// Advances `r` past the padding needed to bring its position up to the next
/// multiple of `n`, mirroring how `pad_to` aligns an output buffer.
///
/// Padding bytes are skipped without being inspected. Nothing is consumed
/// when the position is already aligned. An alignment of `0` is treated like
/// `1` (no alignment).
///
/// # Errors
///
/// Returns [`DecodeError::UnexpectedEof`] if fewer padding bytes remain than
/// are required. `needed` reports the full padding length, and no bytes are
/// consumed.
pub fn skip_pad(r: &mut Reader, n: usize) -> Result<(), DecodeError> {
    // Only 0 is a multiple of 0. Without this guard a non-zero position could
    // never become aligned: the reader would drain all remaining input and
    // then fail with EOF.
    if n == 0 {
        return Ok(());
    }
    let misalignment = r.position() % n;
    if misalignment != 0 {
        r.read_slice(n - misalignment)?;
    }
    Ok(())
}