use std::{
fmt::{self, Debug, Formatter},
io::Read,
};
/// Error returned by [`Scanner`] reads when more bytes are requested than remain.
///
/// Both values are offsets that already include the scanner's `current_position`
/// at the time of the failed request (see `Scanner::take` / `Scanner::split_off`),
/// rather than lengths relative to the point of failure.
#[derive(Debug, Copy, Clone, PartialEq, Eq, thiserror::Error)]
#[error("Expected at least {expected} bytes, but only {actual} bytes were found")]
pub(crate) struct InvalidSize {
/// `current_position` plus the number of bytes that were requested.
pub(crate) expected: usize,
/// `current_position` plus the number of bytes that were actually available.
pub(crate) actual: usize,
}
/// A forward-only cursor over a borrowed byte slice.
///
/// `rest` holds the bytes not yet consumed. `current_position` is the offset
/// reported for the first of them. It starts at 0 unless overridden with
/// [`Scanner::with_current_position`], and every successful read advances it.
#[derive(Clone, PartialEq, Eq)]
pub(crate) struct Scanner<'buf> {
    rest: &'buf [u8],
    current_position: usize,
}

impl<'buf> Scanner<'buf> {
    /// Creates a scanner over `bytes`, starting at position 0.
    pub(crate) fn new(bytes: &'buf [u8]) -> Self {
        Scanner {
            rest: bytes,
            current_position: 0,
        }
    }

    /// Returns this scanner with its reported position replaced by `current_position`.
    ///
    /// The remaining bytes are untouched; only the offsets reported by later
    /// reads and errors change.
    pub(crate) fn with_current_position(self, current_position: usize) -> Self {
        Scanner {
            current_position,
            ..self
        }
    }

    /// Offset of the next unread byte.
    pub(crate) fn current_position(&self) -> usize {
        self.current_position
    }

    /// The bytes that have not been consumed yet.
    pub(crate) fn rest(&self) -> &'buf [u8] {
        self.rest
    }

    /// Returns `true` when no bytes remain.
    pub(crate) fn is_empty(&self) -> bool {
        self.rest.is_empty()
    }

    /// Builds the error for a request of `len` bytes that exceeds what remains.
    ///
    /// `len` may be arbitrarily large (e.g. a length decoded from the input), so
    /// saturating arithmetic is used to avoid overflowing while reporting the failure.
    fn size_error(&self, len: usize) -> InvalidSize {
        InvalidSize {
            expected: self.current_position.saturating_add(len),
            actual: self.current_position.saturating_add(self.rest.len()),
        }
    }

    /// Consumes and returns the next `len` bytes.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSize`] if fewer than `len` bytes remain. The scanner is
    /// left unchanged in that case.
    pub(crate) fn take(&mut self, len: usize) -> Result<&'buf [u8], InvalidSize> {
        if len > self.rest.len() {
            return Err(self.size_error(len));
        }
        let (bytes, rest) = self.rest.split_at(len);
        self.rest = rest;
        self.current_position += len;
        Ok(bytes)
    }

    /// Splits off the next `len` bytes as an independent scanner.
    ///
    /// The returned scanner covers exactly those bytes and starts at this
    /// scanner's current position. `self` is advanced past them.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSize`] if fewer than `len` bytes remain. `self` is left unchanged.
    pub(crate) fn split_off(&mut self, len: usize) -> Result<Self, InvalidSize> {
        let current_position = self.current_position;
        let head = self.take(len)?;
        Ok(Scanner {
            rest: head,
            current_position,
        })
    }

    /// Returns a scanner over the next `len` bytes without advancing `self`.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSize`] if fewer than `len` bytes remain.
    pub(crate) fn truncated(&self, len: usize) -> Result<Self, InvalidSize> {
        self.clone().split_off(len)
    }

    /// Consumes the next `LEN` bytes and returns them as an owned array.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSize`] if fewer than `LEN` bytes remain.
    pub(crate) fn read<const LEN: usize>(&mut self) -> Result<[u8; LEN], InvalidSize> {
        self.read_ref().copied()
    }

    /// Consumes 8 bytes, decodes them as a little-endian `u64`, and converts that to `usize`.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSize`] if fewer than 8 bytes remain.
    ///
    /// # Panics
    ///
    /// Panics if the decoded value does not fit in `usize`. This can only
    /// happen on targets where `usize` is narrower than 64 bits.
    pub(crate) fn read_usize(&mut self) -> Result<usize, InvalidSize> {
        let bytes = self.read::<8>()?;
        let value = u64::from_le_bytes(bytes);
        Ok(value
            .try_into()
            .expect("little-endian u64 does not fit in usize on this target"))
    }

    /// Consumes the next `LEN` bytes and returns a borrowed fixed-size view of them.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSize`] if fewer than `LEN` bytes remain.
    pub(crate) fn read_ref<const LEN: usize>(&mut self) -> Result<&'buf [u8; LEN], InvalidSize> {
        let bytes = self.take(LEN)?;
        // `take` returned exactly `LEN` bytes, so the slice-to-array conversion cannot fail.
        Ok(bytes.try_into().expect("Already checked"))
    }
}
impl<'buf> Read for Scanner<'buf> {
    /// Copies up to `buf.len()` remaining bytes into the front of `buf` and advances the scanner.
    ///
    /// Returns the number of bytes copied, which is `min(buf.len(), remaining)`.
    /// Once the scanner is exhausted, or when `buf` is empty, it returns `Ok(0)`.
    /// Bytes of `buf` past the returned count are left untouched.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        let bytes_read = self.rest().len().min(buf.len());
        // Cannot fail: the request is clamped to the number of remaining bytes.
        let chunk = self.take(bytes_read).expect("unreachable");
        // Fill only the prefix. `copy_from_slice` panics when the lengths differ,
        // and `buf` may be longer than the remaining input.
        buf[..bytes_read].copy_from_slice(chunk);
        Ok(bytes_read)
    }
}
/// Formats the scanner as a struct with two fields, `rest` and `current_position`.
/// The remaining bytes are rendered through `TruncatedBuffer` with a 32-byte limit.
impl<'buf> Debug for Scanner<'buf> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let preview = TruncatedBuffer {
            buffer: self.rest,
            length: 32,
        };
        let mut out = f.debug_struct("Scanner");
        out.field("rest", &preview);
        out.field("current_position", &self.current_position);
        out.finish()
    }
}
/// Debug adapter that prints at most `length` bytes of `buffer`.
///
/// If bytes are omitted, the printed prefix is followed by `...`. A buffer of
/// `length` bytes or fewer is printed in full with no suffix.
struct TruncatedBuffer<'a> {
    buffer: &'a [u8],
    length: usize,
}

impl<'a> Debug for TruncatedBuffer<'a> {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let TruncatedBuffer { buffer, length } = *self;
        // Strict `>`: a buffer of exactly `length` bytes is complete, so it gets no ellipsis.
        if buffer.len() > length {
            write!(f, "{:?}...", &buffer[..length])
        } else {
            write!(f, "{buffer:?}")
        }
    }
}