#[cfg(feature = "std")]
use alloc::string::ToString;
use alloc::{format, string::String, vec::Vec};
#[cfg(feature = "std")]
use core::cell::{Ref, RefCell};
#[cfg(feature = "std")]
use std::io::BufRead;
use crate::{Deserializable, DeserializationError};
/// A source of bytes from which primitive values and [`Deserializable`] types can be decoded.
///
/// Implementors provide the six required methods; every other method is derived from them.
/// All fixed-width integers are decoded from their little-endian byte representation.
pub trait ByteReader {
    // REQUIRED METHODS
    // --------------------------------------------------------------------------------------------

    /// Reads one byte and advances the reader past it.
    ///
    /// # Errors
    /// Returns an error if a byte could not be read.
    fn read_u8(&mut self) -> Result<u8, DeserializationError>;

    /// Returns the next byte without advancing the reader.
    ///
    /// # Errors
    /// Returns an error if no byte is available.
    fn peek_u8(&self) -> Result<u8, DeserializationError>;

    /// Reads `len` bytes, advances past them, and returns them as a slice borrowed from the
    /// reader.
    ///
    /// # Errors
    /// Returns an error if `len` bytes could not be read.
    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError>;

    /// Reads `N` bytes into a fixed-size array and advances past them.
    ///
    /// # Errors
    /// Returns an error if `N` bytes could not be read.
    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError>;

    /// Checks whether `num_bytes` more bytes can be read from this reader, returning an error
    /// if they cannot.
    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError>;

    /// Returns `true` if at least one more byte can be read from this reader.
    fn has_more_bytes(&self) -> bool;

    // PROVIDED METHODS
    // --------------------------------------------------------------------------------------------

    /// Returns the largest element count [`ByteReader::read_many_iter`] will accept for elements
    /// whose serialized form occupies `_element_size` bytes.
    ///
    /// The default places no limit and returns `usize::MAX`.
    fn max_alloc(&self, _element_size: usize) -> usize {
        usize::MAX
    }

    /// Reads one byte and interprets it as a boolean: `0` is `false` and `1` is `true`.
    ///
    /// # Errors
    /// Returns [`DeserializationError::InvalidValue`] for any other byte value, or the
    /// underlying error if the byte could not be read.
    fn read_bool(&mut self) -> Result<bool, DeserializationError> {
        match self.read_u8()? {
            0 => Ok(false),
            1 => Ok(true),
            byte => {
                Err(DeserializationError::InvalidValue(format!("{byte} is not a boolean value")))
            },
        }
    }

    /// Reads a little-endian `u16`.
    fn read_u16(&mut self) -> Result<u16, DeserializationError> {
        self.read_array::<2>().map(u16::from_le_bytes)
    }

    /// Reads a little-endian `u32`.
    fn read_u32(&mut self) -> Result<u32, DeserializationError> {
        self.read_array::<4>().map(u32::from_le_bytes)
    }

    /// Reads a little-endian `u64`.
    fn read_u64(&mut self) -> Result<u64, DeserializationError> {
        self.read_array::<8>().map(u64::from_le_bytes)
    }

    /// Reads a little-endian `u128`.
    fn read_u128(&mut self) -> Result<u128, DeserializationError> {
        self.read_array::<16>().map(u128::from_le_bytes)
    }

    /// Reads a variable-length encoded `usize`.
    ///
    /// The number of trailing zero bits in the first byte, plus one, is the encoded length in
    /// bytes (1 to 9). For lengths 1 to 8 the bytes are read as a little-endian integer and
    /// shifted right by the length to discard the marker bits. A first byte of zero (length 9)
    /// is skipped and followed by a plain 8-byte little-endian `u64`.
    ///
    /// # Errors
    /// Returns [`DeserializationError::InvalidValue`] if the decoded value exceeds
    /// `usize::MAX`, or the underlying error if the bytes could not be read.
    fn read_usize(&mut self) -> Result<usize, DeserializationError> {
        let marker = self.peek_u8()?;
        let width = marker.trailing_zeros() as usize + 1;

        let decoded = if width == 9 {
            // Discard the all-zero marker byte; the value occupies the next 8 bytes verbatim.
            self.read_u8()?;
            u64::from_le_bytes(self.read_array::<8>()?)
        } else {
            let mut scratch = [0u8; 8];
            scratch[..width].copy_from_slice(self.read_slice(width)?);
            u64::from_le_bytes(scratch) >> width
        };

        // Only reachable on targets where `usize` is narrower than 64 bits.
        if decoded > usize::MAX as u64 {
            return Err(DeserializationError::InvalidValue(format!(
                "Encoded value must be less than {}, but {} was provided",
                usize::MAX,
                decoded
            )));
        }

        Ok(decoded as usize)
    }

    /// Reads `len` bytes and returns them as an owned vector.
    fn read_vec(&mut self, len: usize) -> Result<Vec<u8>, DeserializationError> {
        self.read_slice(len).map(|bytes| bytes.to_vec())
    }

    /// Reads `num_bytes` bytes and converts them into a `String`.
    ///
    /// # Errors
    /// Returns [`DeserializationError::InvalidValue`] if the bytes are not valid UTF-8, or the
    /// underlying error if the bytes could not be read.
    fn read_string(&mut self, num_bytes: usize) -> Result<String, DeserializationError> {
        let raw = self.read_vec(num_bytes)?;
        match String::from_utf8(raw) {
            Ok(text) => Ok(text),
            Err(err) => Err(DeserializationError::InvalidValue(format!("{err}"))),
        }
    }

    /// Reads a value of type `D` by delegating to its [`Deserializable`] implementation.
    fn read<D: Deserializable>(&mut self) -> Result<D, DeserializationError>
    where
        Self: Sized,
    {
        <D as Deserializable>::read_from(self)
    }

    /// Returns an iterator that lazily reads `num_elements` values of type `D`.
    ///
    /// The count is validated up front against [`ByteReader::max_alloc`], queried with
    /// `D::min_serialized_size()`; no element is read until the iterator is advanced.
    ///
    /// # Errors
    /// Returns [`DeserializationError::InvalidValue`] if `num_elements` exceeds the limit
    /// reported by `max_alloc`.
    fn read_many_iter<D: Deserializable>(
        &mut self,
        num_elements: usize,
    ) -> Result<ReadManyIter<'_, Self, D>, DeserializationError>
    where
        Self: Sized,
    {
        let max_elements = self.max_alloc(D::min_serialized_size());
        if num_elements <= max_elements {
            Ok(ReadManyIter {
                reader: self,
                remaining: num_elements,
                _item: core::marker::PhantomData,
            })
        } else {
            Err(DeserializationError::InvalidValue(format!(
                "requested {num_elements} elements but reader can provide at most {max_elements}"
            )))
        }
    }
}
/// Iterator that reads a fixed number of `D` values from a borrowed [`ByteReader`], one value
/// per call to `next`.
pub struct ReadManyIter<'reader, R, D>
where
    R: ByteReader,
    D: Deserializable,
{
    /// Reader the elements are decoded from.
    reader: &'reader mut R,
    /// Number of elements that have not been requested yet.
    remaining: usize,
    /// Ties the iterator to its element type without storing a value of it.
    _item: core::marker::PhantomData<D>,
}
impl<'reader, R, D> Iterator for ReadManyIter<'reader, R, D>
where
    R: ByteReader,
    D: Deserializable,
{
    type Item = Result<D, DeserializationError>;

    /// Decodes the next element, or returns `None` once the requested count has been reached.
    ///
    /// The remaining count is decremented before decoding, so a failed read still uses up one
    /// of the requested elements.
    fn next(&mut self) -> Option<Self::Item> {
        let left = self.remaining.checked_sub(1)?;
        self.remaining = left;
        Some(D::read_from(self.reader))
    }

    /// Reports exactly the number of elements that have not been requested yet.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let pending = self.remaining;
        (pending, Some(pending))
    }
}
// `size_hint` returns identical lower and upper bounds, so the default `len` is exact.
impl<'reader, R: ByteReader, D: Deserializable> ExactSizeIterator for ReadManyIter<'reader, R, D> {}
/// Adapter that exposes any `std::io::Read` through the [`ByteReader`] interface.
///
/// Bytes are pulled from the wrapped reader through a `BufReader`; bytes that must be handed
/// out as a contiguous slice are additionally staged in a local buffer.
#[cfg(feature = "std")]
pub struct ReadAdapter<'a> {
    // Buffered wrapper around the underlying reader. Held in a `RefCell` because the `&self`
    // methods (`peek_u8`, `check_eor`, `has_more_bytes`) still need to fill its buffer.
    reader: RefCell<std::io::BufReader<&'a mut dyn std::io::Read>>,
    // Local staging buffer; its unread portion is `buf[pos..]`.
    buf: Vec<u8>,
    // Offset of the first unread byte in `buf`.
    pos: usize,
    // Set once a read attempt has failed or found the underlying reader empty; consulted by
    // `check_eor`.
    guaranteed_eof: bool,
}
#[cfg(feature = "std")]
impl<'a> ReadAdapter<'a> {
    /// Creates a new [`ReadAdapter`] that reads from `reader` through a 256-byte `BufReader`.
    pub fn new(reader: &'a mut dyn std::io::Read) -> Self {
        Self {
            reader: RefCell::new(std::io::BufReader::with_capacity(256, reader)),
            buf: Default::default(),
            pos: 0,
            guaranteed_eof: false,
        }
    }

    /// Returns the unread portion of the local buffer.
    ///
    /// Yields an empty slice when `pos` lies beyond the end of `buf`, which is the state left
    /// behind by the buffer reset at the end of [`Self::read_exact`].
    #[inline(always)]
    fn buffer(&self) -> &[u8] {
        self.buf.get(self.pos..).unwrap_or(&[])
    }

    /// Returns the unread portion of the local buffer, or `None` if it is empty.
    #[inline(always)]
    fn non_empty_buffer(&self) -> Option<&[u8]> {
        self.buf.get(self.pos..).filter(|b| !b.is_empty())
    }

    /// Returns the bytes currently held by the `BufReader`, without reading from the
    /// underlying reader.
    #[inline(always)]
    fn reader_buffer(&self) -> Ref<'_, [u8]> {
        Ref::map(self.reader.borrow(), |r| r.buffer())
    }

    /// Returns the `BufReader`'s buffer, filling it from the underlying reader if it is empty.
    ///
    /// # Errors
    /// Returns [`DeserializationError::UnexpectedEOF`] if no bytes are available (also marking
    /// the adapter as having reached end-of-file), and
    /// [`DeserializationError::UnknownError`] for I/O errors other than an unexpected EOF.
    fn non_empty_reader_buffer_mut(&mut self) -> Result<&[u8], DeserializationError> {
        use std::io::ErrorKind;
        let buf = self.reader.get_mut().fill_buf().map_err(|e| match e.kind() {
            ErrorKind::UnexpectedEof => DeserializationError::UnexpectedEOF,
            e => DeserializationError::UnknownError(e.to_string()),
        })?;
        if buf.is_empty() {
            self.guaranteed_eof = true;
            Err(DeserializationError::UnexpectedEOF)
        } else {
            Ok(buf)
        }
    }

    /// Same as [`Self::non_empty_reader_buffer_mut`], but usable through `&self`.
    ///
    /// Because it cannot mutate `self`, this variant does not record end-of-file in
    /// `guaranteed_eof`.
    fn non_empty_reader_buffer(&self) -> Result<Ref<'_, [u8]>, DeserializationError> {
        use std::io::ErrorKind;
        let mut reader = self.reader.borrow_mut();
        let buf = reader.fill_buf().map_err(|e| match e.kind() {
            ErrorKind::UnexpectedEof => DeserializationError::UnexpectedEOF,
            e => DeserializationError::UnknownError(e.to_string()),
        })?;
        if buf.is_empty() {
            Err(DeserializationError::UnexpectedEOF)
        } else {
            // Release the mutable borrow before handing out a shared one.
            drop(reader);
            Ok(self.reader_buffer())
        }
    }

    /// Returns `true` if the local buffer's capacity, less its unread bytes, is at least `n`.
    #[inline]
    fn has_remaining_capacity(&self, n: usize) -> bool {
        let remaining = self.buf.capacity() - self.buffer().len();
        remaining >= n
    }

    /// Reads a single byte, preferring the local buffer over the `BufReader`.
    ///
    /// # Errors
    /// Propagates the error from [`Self::non_empty_reader_buffer_mut`]; any failure marks the
    /// adapter as having reached end-of-file.
    fn pop(&mut self) -> Result<u8, DeserializationError> {
        if let Some(byte) = self.non_empty_buffer().map(|b| b[0]) {
            self.pos += 1;
            return Ok(byte);
        }
        let result = self.non_empty_reader_buffer_mut().map(|b| b[0]);
        if result.is_ok() {
            self.reader.get_mut().consume(1);
        } else {
            self.guaranteed_eof = true;
        }
        result
    }

    /// Reads exactly `N` bytes into an array.
    ///
    /// Bytes are taken from the local buffer first and then from the `BufReader`. When those
    /// two together hold fewer than `N` bytes, more data is pulled from the underlying reader
    /// until `N` bytes are available or it is exhausted. A short `BufReader` buffer is
    /// therefore never mistaken for end-of-file.
    ///
    /// # Errors
    /// Returns [`DeserializationError::UnexpectedEOF`] if the underlying reader ends before
    /// `N` bytes are available; bytes already staged in the local buffer are kept.
    fn read_exact<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        let mut output = [0; N];
        let buffered = self.buffer().len();

        if buffered >= N {
            // Everything we need is already staged locally.
            output.copy_from_slice(&self.buffer()[..N]);
            self.pos += N;
        } else {
            // At least one more byte is required, so the `BufReader` must be able to supply
            // data; this fails with `UnexpectedEOF` if the underlying reader is exhausted.
            self.non_empty_reader_buffer_mut()?;
            let needed = N - buffered;

            if self.reader.get_mut().buffer().len() >= needed {
                // Stitch the result together from the local buffer and the `BufReader`,
                // without staging the `BufReader` bytes locally.
                output[..buffered].copy_from_slice(self.buffer());
                output[buffered..].copy_from_slice(&self.reader.get_mut().buffer()[..needed]);
                self.pos += buffered;
                self.reader.get_mut().consume(needed);
            } else {
                // Not enough data is readily available: stage bytes locally until the unread
                // portion holds all `N` of them, then copy them out in one go.
                self.buffer_at_least(N)?;
                output.copy_from_slice(&self.buffer()[..N]);
                self.pos += N;
                return Ok(output);
            }
        }

        // Once the local buffer is fully consumed, drop its contents so it does not keep
        // growing. `pos` is deliberately left as-is: `buffer()` treats an out-of-range `pos`
        // as "nothing unread", and `buffer_at_least` re-anchors it before staging more data.
        if self.buffer().is_empty() && self.pos > 0 {
            self.buf.clear();
        }
        Ok(output)
    }

    /// Stages bytes from the reader until the unread portion of the local buffer holds at
    /// least `count` bytes.
    ///
    /// `count` is the total number of unread bytes required, not an increment. Each round
    /// moves the whole `BufReader` buffer into the local buffer.
    ///
    /// # Errors
    /// Returns [`DeserializationError::UnexpectedEOF`] if the underlying reader is exhausted
    /// before `count` bytes are available.
    fn buffer_at_least(&mut self, count: usize) -> Result<(), DeserializationError> {
        // `read_exact` may have emptied `buf` while leaving `pos` pointing past its end.
        // Re-anchor `pos` so that bytes appended below are visible through `buffer()`.
        if self.pos > self.buf.len() {
            self.pos = self.buf.len();
        }

        while self.buffer().len() < count {
            // Fails with `UnexpectedEOF` once the underlying reader has nothing left.
            self.non_empty_reader_buffer_mut()?;

            // Re-borrow the `BufReader` buffer through the `reader` field so that `self.buf`
            // can be extended while the chunk is still borrowed.
            let reader = self.reader.get_mut();
            let chunk = reader.buffer();
            let consumed = chunk.len();
            self.buf.extend_from_slice(chunk);
            reader.consume(consumed);
        }
        Ok(())
    }
}
#[cfg(feature = "std")]
impl ByteReader for ReadAdapter<'_> {
    /// Reads one byte; see [`ReadAdapter::pop`].
    #[inline(always)]
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        self.pop()
    }

    /// Returns the next byte without consuming it, looking at the local buffer first and at
    /// the (possibly freshly filled) `BufReader` buffer otherwise.
    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        match self.buffer().first() {
            Some(&byte) => Ok(byte),
            None => self.non_empty_reader_buffer().map(|b| b[0]),
        }
    }

    /// Reads `len` bytes and returns them as a slice of the local buffer.
    ///
    /// Before staging more data, already-consumed bytes are discarded from the front of the
    /// local buffer when at least 16 of them have accumulated and the capacity check in
    /// [`ReadAdapter::has_remaining_capacity`] fails for `len`.
    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        if len == 0 {
            return Ok(&[]);
        }

        if self.pos >= 16 && !self.has_remaining_capacity(len) {
            // Slide the unread tail to the front and forget everything before it.
            let unread = self.buffer().len();
            let tail_start = self.buf.len() - unread;
            self.buf.copy_within(tail_start.., 0);
            self.buf.truncate(unread);
            self.pos = 0;
        }

        self.buffer_at_least(len)?;

        let end = self.pos + len;
        let slice = &self.buf[self.pos..end];
        self.pos = end;
        Ok(slice)
    }

    /// Reads `N` bytes into an array; a zero-length array never touches the reader.
    #[inline]
    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        match N {
            0 => Ok([0; N]),
            _ => self.read_exact(),
        }
    }

    /// Best-effort check that `num_bytes` more bytes can be read.
    ///
    /// Succeeds when the local buffer alone, or together with the `BufReader` buffer, holds
    /// `num_bytes`. When they hold fewer, an error is reported only if the adapter has already
    /// been marked as having reached end-of-file; otherwise the check passes.
    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        let local = self.buffer().len();
        if local >= num_bytes {
            return Ok(());
        }

        // Fails right away if the reader has nothing more to give.
        let pending = self.non_empty_reader_buffer()?.len();

        if local + pending < num_bytes && self.guaranteed_eof {
            Err(DeserializationError::UnexpectedEOF)
        } else {
            Ok(())
        }
    }

    /// Returns `true` if the local buffer or the reader can supply at least one more byte.
    #[inline]
    fn has_more_bytes(&self) -> bool {
        if self.buffer().is_empty() {
            self.non_empty_reader_buffer().is_ok()
        } else {
            true
        }
    }
}
#[cfg(feature = "std")]
/// Evaluates to the unread bytes of a `std::io::Cursor`, i.e. everything from its current
/// position to the end of the underlying buffer.
///
/// A position at or past the end of the buffer yields an empty slice.
macro_rules! cursor_remaining_buf {
    ($cursor:ident) => {{
        let bytes = $cursor.get_ref().as_ref();
        let total = bytes.len();
        // Clamp the position to the buffer length before narrowing it to `usize`.
        let offset = match $cursor.position() {
            position if position < total as u64 => position as usize,
            _ => total,
        };
        &bytes[offset..]
    }};
}
// `ByteReader` over an in-memory `std::io::Cursor`; the cursor position is the read offset.
#[cfg(feature = "std")]
impl<T: AsRef<[u8]>> ByteReader for std::io::Cursor<T> {
    /// Reads the byte at the cursor position and moves the cursor one byte forward.
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        let next = cursor_remaining_buf!(self).first().copied();
        match next {
            Some(byte) => {
                self.set_position(self.position() + 1);
                Ok(byte)
            },
            None => Err(DeserializationError::UnexpectedEOF),
        }
    }

    /// Returns the byte at the cursor position without moving the cursor.
    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        match cursor_remaining_buf!(self).first() {
            Some(&byte) => Ok(byte),
            None => Err(DeserializationError::UnexpectedEOF),
        }
    }

    /// Returns the next `len` bytes and moves the cursor past them, or fails with
    /// `UnexpectedEOF` (leaving the cursor untouched) if fewer than `len` bytes remain.
    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        let pos = self.position();
        let size = self.get_ref().as_ref().len() as u64;
        let available = size.saturating_sub(pos);

        if available >= len as u64 {
            self.set_position(pos + len as u64);
            // The position may lie past the end when `len` is zero, hence the clamp.
            let start = pos.min(size) as usize;
            Ok(&self.get_ref().as_ref()[start..(start + len)])
        } else {
            Err(DeserializationError::UnexpectedEOF)
        }
    }

    /// Copies the next `N` bytes into an array and moves the cursor past them.
    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        let bytes = self.read_slice(N)?;
        let mut result = [0u8; N];
        result.copy_from_slice(bytes);
        Ok(result)
    }

    /// Fails with `UnexpectedEOF` unless at least `num_bytes` bytes remain after the cursor.
    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        let unread = cursor_remaining_buf!(self).len();
        if unread < num_bytes {
            Err(DeserializationError::UnexpectedEOF)
        } else {
            Ok(())
        }
    }

    /// Returns `true` while the cursor position is before the end of the buffer.
    #[inline]
    fn has_more_bytes(&self) -> bool {
        self.position() < self.get_ref().as_ref().len() as u64
    }
}
/// A [`ByteReader`] over a borrowed byte slice.
pub struct SliceReader<'a> {
    /// The bytes being read.
    source: &'a [u8],
    /// Offset into `source` of the next byte to be read.
    pos: usize,
}

impl<'a> SliceReader<'a> {
    /// Creates a reader positioned at the first byte of `source`.
    pub fn new(source: &'a [u8]) -> Self {
        Self { pos: 0, source }
    }
}
impl ByteReader for SliceReader<'_> {
    /// Returns the byte at the current offset and advances by one.
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        let byte = self.peek_u8()?;
        self.pos += 1;
        Ok(byte)
    }

    /// Returns the byte at the current offset without advancing.
    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        self.check_eor(1)?;
        Ok(self.source[self.pos])
    }

    /// Returns the next `len` bytes of the source and advances past them.
    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        self.check_eor(len)?;
        let start = self.pos;
        self.pos += len;
        Ok(&self.source[start..self.pos])
    }

    /// Copies the next `N` bytes of the source into an array and advances past them.
    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        let bytes = self.read_slice(N)?;
        let mut result = [0_u8; N];
        result.copy_from_slice(bytes);
        Ok(result)
    }

    /// Fails with `UnexpectedEOF` unless `num_bytes` more bytes fit within the source.
    ///
    /// The end offset is computed with overflow checking, so oversized requests are rejected
    /// rather than wrapping around.
    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        match self.pos.checked_add(num_bytes) {
            Some(end) if end <= self.source.len() => Ok(()),
            _ => Err(DeserializationError::UnexpectedEOF),
        }
    }

    /// Returns `true` while the current offset is before the end of the source.
    fn has_more_bytes(&self) -> bool {
        self.source.len() > self.pos
    }
}
/// Wrapper around a reader that limits how many bytes may be consumed from it.
pub struct BudgetedReader<R> {
    // The wrapped reader that actually supplies the bytes.
    inner: R,
    // Number of bytes that may still be consumed before reads are rejected.
    remaining: usize,
}
impl<R> BudgetedReader<R> {
pub fn new(inner: R, budget: usize) -> Self {
Self { inner, remaining: budget }
}
pub fn remaining(&self) -> usize {
self.remaining
}
fn consume_budget(&mut self, n: usize) -> Result<(), DeserializationError> {
if n > self.remaining {
return Err(DeserializationError::InvalidValue(format!(
"budget exhausted: requested {n} bytes, {} remaining",
self.remaining
)));
}
self.remaining -= n;
Ok(())
}
}
impl<R: ByteReader> ByteReader for BudgetedReader<R> {
    /// Charges one byte against the budget, then reads it from the wrapped reader.
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        self.consume_budget(1)?;
        self.inner.read_u8()
    }

    /// Peeks at the next byte of the wrapped reader; peeking is not charged against the budget.
    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        self.inner.peek_u8()
    }

    /// Charges `len` bytes against the budget, then reads them from the wrapped reader.
    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        self.consume_budget(len)?;
        self.inner.read_slice(len)
    }

    /// Charges `N` bytes against the budget, then reads them from the wrapped reader.
    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        self.consume_budget(N)?;
        self.inner.read_array()
    }

    /// Checks that `num_bytes` fit within the remaining budget and, if so, defers to the
    /// wrapped reader's own check. Nothing is charged against the budget.
    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        if num_bytes <= self.remaining {
            self.inner.check_eor(num_bytes)
        } else {
            Err(DeserializationError::InvalidValue(format!(
                "budget exhausted: requested {num_bytes} bytes, {} remaining",
                self.remaining
            )))
        }
    }

    /// Returns `true` only if budget is left and the wrapped reader has more bytes.
    fn has_more_bytes(&self) -> bool {
        self.remaining != 0 && self.inner.has_more_bytes()
    }

    /// Returns how many elements of `element_size` bytes fit into the remaining budget.
    ///
    /// Zero-sized elements are not limited by the budget, so `usize::MAX` is returned for them.
    fn max_alloc(&self, element_size: usize) -> usize {
        self.remaining.checked_div(element_size).unwrap_or(usize::MAX)
    }
}
#[cfg(all(test, feature = "std"))]
mod tests {
use core::mem::size_of;
use std::io::Cursor;
use super::*;
use crate::ByteWriter;
#[test]
fn read_adapter_empty() -> Result<(), DeserializationError> {
let mut reader = std::io::empty();
let mut adapter = ReadAdapter::new(&mut reader);
assert!(!adapter.has_more_bytes());
assert_eq!(adapter.check_eor(8), Err(DeserializationError::UnexpectedEOF));
assert_eq!(adapter.peek_u8(), Err(DeserializationError::UnexpectedEOF));
assert_eq!(adapter.read_u8(), Err(DeserializationError::UnexpectedEOF));
assert_eq!(adapter.read_slice(0), Ok([].as_slice()));
assert_eq!(adapter.read_slice(1), Err(DeserializationError::UnexpectedEOF));
assert_eq!(adapter.read_array(), Ok([]));
assert_eq!(adapter.read_array::<1>(), Err(DeserializationError::UnexpectedEOF));
Ok(())
}
#[test]
fn read_adapter_passthrough() -> Result<(), DeserializationError> {
let mut reader = std::io::repeat(0b101);
let mut adapter = ReadAdapter::new(&mut reader);
assert!(adapter.has_more_bytes());
assert_eq!(adapter.check_eor(8), Ok(()));
assert_eq!(adapter.peek_u8(), Ok(0b101));
assert_eq!(adapter.read_u8(), Ok(0b101));
assert_eq!(adapter.read_slice(0), Ok([].as_slice()));
assert_eq!(adapter.read_slice(4), Ok([0b101, 0b101, 0b101, 0b101].as_slice()));
assert_eq!(adapter.read_array(), Ok([]));
assert_eq!(adapter.read_array(), Ok([0b101, 0b101]));
Ok(())
}
#[test]
fn read_adapter_exact() {
const VALUE: usize = 2048;
let mut reader = Cursor::new(VALUE.to_le_bytes());
let mut adapter = ReadAdapter::new(&mut reader);
assert_eq!(usize::from_le_bytes(adapter.read_array().unwrap()), VALUE);
assert!(!adapter.has_more_bytes());
assert_eq!(adapter.peek_u8(), Err(DeserializationError::UnexpectedEOF));
assert_eq!(adapter.read_u8(), Err(DeserializationError::UnexpectedEOF));
}
#[test]
fn read_adapter_roundtrip() {
const VALUE: usize = 2048;
let mut cursor = Cursor::new([0; size_of::<usize>()]);
cursor.write_usize(VALUE);
cursor.set_position(0);
let mut adapter = ReadAdapter::new(&mut cursor);
assert_eq!(adapter.read_usize(), Ok(VALUE));
}
#[test]
fn read_adapter_for_file() {
use std::fs::File;
use crate::ByteWriter;
let path = std::env::temp_dir().join("read_adapter_for_file.bin");
{
let mut buf = Vec::<u8>::with_capacity(256);
buf.write_bytes(b"MAGIC\0");
buf.write_bool(true);
buf.write_u32(0xbeef);
buf.write_usize(0xfeed);
buf.write_u16(0x5);
std::fs::write(&path, &buf).unwrap();
}
let mut file = File::open(&path).unwrap();
let mut reader = ReadAdapter::new(&mut file);
assert_eq!(reader.peek_u8().unwrap(), b'M');
assert_eq!(reader.read_slice(6).unwrap(), b"MAGIC\0");
assert!(reader.read_bool().unwrap());
assert_eq!(reader.read_u32().unwrap(), 0xbeef);
assert_eq!(reader.read_usize().unwrap(), 0xfeed);
assert_eq!(reader.read_u16().unwrap(), 0x5);
assert!(!reader.has_more_bytes(), "expected there to be no more data in the input");
}
#[test]
fn read_adapter_issue_383() {
const STR_BYTES: &[u8] = b"just a string";
use std::fs::File;
use crate::ByteWriter;
let path = std::env::temp_dir().join("issue_383.bin");
{
let mut buf = vec![0u8; 1024];
unsafe {
buf.set_len(0);
}
buf.write_u128(2 * u64::MAX as u128);
unsafe {
buf.set_len(512);
}
buf.write_bytes(STR_BYTES);
buf.write_u32(0xbeef);
std::fs::write(&path, &buf).unwrap();
}
let mut file = File::open(&path).unwrap();
let mut reader = ReadAdapter::new(&mut file);
assert_eq!(reader.read_u128().unwrap(), 2 * u64::MAX as u128);
assert_eq!(reader.buf.len(), 0);
assert_eq!(reader.pos, 0);
reader.read_slice(496).unwrap();
assert_eq!(reader.buf.len(), 496);
assert_eq!(reader.pos, 496);
assert_eq!(reader.read_slice(STR_BYTES.len()).unwrap(), STR_BYTES);
assert_eq!(reader.buf.len(), 496 + STR_BYTES.len() + size_of::<u32>());
assert_eq!(reader.pos, 509);
assert_eq!(reader.read_u32().unwrap(), 0xbeef);
assert_eq!(reader.pos, 513);
assert!(!reader.has_more_bytes(), "expected there to be no more data in the input");
}
#[test]
fn budgeted_reader_basic() {
let data = [1u8, 2, 3, 4, 5, 6, 7, 8];
let inner = SliceReader::new(&data);
let mut reader = BudgetedReader::new(inner, 4);
assert_eq!(reader.remaining(), 4);
assert!(reader.has_more_bytes());
assert_eq!(reader.read_u32().unwrap(), 0x04030201);
assert_eq!(reader.remaining(), 0);
assert!(!reader.has_more_bytes());
assert!(reader.read_u8().is_err());
}
#[test]
fn budgeted_reader_peek_does_not_consume() {
let data = [42u8];
let inner = SliceReader::new(&data);
let mut reader = BudgetedReader::new(inner, 1);
assert_eq!(reader.peek_u8().unwrap(), 42);
assert_eq!(reader.peek_u8().unwrap(), 42);
assert_eq!(reader.remaining(), 1);
assert_eq!(reader.read_u8().unwrap(), 42);
assert_eq!(reader.remaining(), 0);
}
#[test]
fn budgeted_reader_check_eor_respects_budget() {
let data = [0u8; 100];
let inner = SliceReader::new(&data);
let reader = BudgetedReader::new(inner, 10);
assert!(reader.check_eor(10).is_ok());
assert!(reader.check_eor(11).is_err());
}
#[test]
fn budgeted_reader_read_slice() {
let data = [1u8, 2, 3, 4, 5];
let inner = SliceReader::new(&data);
let mut reader = BudgetedReader::new(inner, 3);
assert_eq!(reader.read_slice(3).unwrap(), &[1, 2, 3]);
assert_eq!(reader.remaining(), 0);
assert!(reader.read_slice(1).is_err());
}
#[test]
fn budgeted_reader_read_array() {
let data = [0xaau8, 0xbb, 0xcc, 0xdd];
let inner = SliceReader::new(&data);
let mut reader = BudgetedReader::new(inner, 2);
assert_eq!(reader.read_array::<2>().unwrap(), [0xaa, 0xbb]);
assert_eq!(reader.remaining(), 0);
assert!(reader.read_array::<2>().is_err());
}
#[test]
fn budgeted_reader_zero_budget() {
let data = [1u8];
let inner = SliceReader::new(&data);
let mut reader = BudgetedReader::new(inner, 0);
assert!(!reader.has_more_bytes());
assert!(reader.read_u8().is_err());
assert_eq!(reader.peek_u8().unwrap(), 1);
}
#[test]
fn budgeted_reader_max_alloc() {
let data = [0u8; 100];
let inner = SliceReader::new(&data);
let reader = BudgetedReader::new(inner, 64);
assert_eq!(reader.max_alloc(8), 8);
assert_eq!(reader.max_alloc(1), 64);
assert_eq!(reader.max_alloc(16), 4);
assert_eq!(reader.max_alloc(0), usize::MAX);
}
#[test]
fn unbounded_reader_max_alloc_returns_max() {
let data = [0u8; 100];
let reader = SliceReader::new(&data);
assert_eq!(reader.max_alloc(1), usize::MAX);
assert_eq!(reader.max_alloc(8), usize::MAX);
}
#[test]
fn slice_reader_rejects_overflowing_read_lengths() {
let data = [1u8];
let mut reader = SliceReader::new(&data);
assert_eq!(reader.read_u8().unwrap(), 1);
assert_eq!(reader.read_slice(usize::MAX), Err(DeserializationError::UnexpectedEOF));
assert_eq!(reader.check_eor(usize::MAX), Err(DeserializationError::UnexpectedEOF));
}
#[test]
fn slice_reader_accepts_fake_length_prefix() {
let mut data = Vec::new();
data.push(0); data.extend_from_slice(&1000u64.to_le_bytes());
data.extend_from_slice(&42u64.to_le_bytes());
let mut reader = SliceReader::new(&data);
let _len = reader.read_usize().unwrap();
let iter_result = reader.read_many_iter::<u64>(1000);
assert!(iter_result.is_ok());
let collect_result: Result<Vec<u64>, _> = iter_result.unwrap().collect();
assert!(collect_result.is_err());
assert!(matches!(collect_result.unwrap_err(), DeserializationError::UnexpectedEOF));
}
#[test]
fn budgeted_reader_rejects_fake_length_upfront() {
let mut data = Vec::new();
data.push(0); data.extend_from_slice(&1000u64.to_le_bytes());
data.extend_from_slice(&42u64.to_le_bytes());
let inner = SliceReader::new(&data);
let mut reader = BudgetedReader::new(inner, 64);
let _len = reader.read_usize().unwrap(); let iter_result = reader.read_many_iter::<u64>(1000);
match iter_result {
Err(DeserializationError::InvalidValue(_)) => {}, other => panic!("expected InvalidValue error, got {:?}", other.map(|_| "Ok")),
}
}
#[test]
fn budget_equals_input_length_is_safe() {
let original = vec![100u64, 200];
let mut data = Vec::new();
crate::Serializable::write_into(&original, &mut data);
let result = Vec::<u64>::read_from_bytes_with_budget(&data, data.len());
assert_eq!(result.unwrap(), vec![100, 200]);
let mut evil_data = Vec::new();
evil_data.push(0); evil_data.extend_from_slice(&1000u64.to_le_bytes());
evil_data.extend_from_slice(&42u64.to_le_bytes());
let result = Vec::<u64>::read_from_bytes_with_budget(&evil_data, evil_data.len());
assert!(result.is_err());
}
#[test]
fn min_serialized_size_bounds_flat_collections() {
let mut data = Vec::new();
data.push(0); data.extend_from_slice(&1000u64.to_le_bytes()); data.extend_from_slice(&[0u8; 16]);
let inner = SliceReader::new(&data);
let mut reader = BudgetedReader::new(inner, 80);
let _len = reader.read_usize().unwrap();
let result = reader.read_many_iter::<u64>(1000);
assert!(result.is_err());
}
#[test]
fn min_serialized_size_override_for_nested_collections() {
assert_eq!(<Vec<u64>>::min_serialized_size(), 1);
let mut data = Vec::new();
data.push(0); data.extend_from_slice(&100u64.to_le_bytes()); data.push(0b10);
let inner = SliceReader::new(&data);
let mut reader = BudgetedReader::new(inner, 110);
let _len = reader.read_usize().unwrap();
let result = reader.read_many_iter::<Vec<u64>>(100);
assert!(result.is_ok());
let collect_result: Result<Vec<Vec<u64>>, _> = result.unwrap().collect();
assert!(collect_result.is_err());
}
#[test]
fn nested_collections_still_protected_by_budget() {
let mut data = Vec::new();
data.push(0); data.extend_from_slice(&10u64.to_le_bytes()); for _ in 0..10 {
data.push(0); data.extend_from_slice(&1000u64.to_le_bytes());
}
let inner = SliceReader::new(&data);
let mut reader = BudgetedReader::new(inner, 100);
let _len = reader.read_usize().unwrap();
let result = reader.read_many_iter::<Vec<u64>>(10);
assert!(result.is_ok());
let collect_result: Result<Vec<Vec<u64>>, _> = result.unwrap().collect();
assert!(collect_result.is_err());
}
#[test]
fn tuple_min_serialized_size_excludes_padding() {
assert_eq!(<(u8, u64)>::min_serialized_size(), 9);
assert_eq!(size_of::<(u8, u64)>(), 16);
let mut data = Vec::new();
data.push(0); data.extend_from_slice(&4u64.to_le_bytes()); for i in 0u8..4 {
data.push(i); data.extend_from_slice(&(i as u64).to_le_bytes()); }
let inner = SliceReader::new(&data);
let mut reader = BudgetedReader::new(inner, 45);
let _len = reader.read_usize().unwrap();
let result = reader.read_many_iter::<(u8, u64)>(4);
assert!(result.is_ok());
let collect_result: Result<Vec<(u8, u64)>, _> = result.unwrap().collect();
assert!(collect_result.is_ok());
assert_eq!(collect_result.unwrap().len(), 4);
}
}