use snafu::ensure;
use crate::stream::Reads;
use super::{encode::Encodable, CodecError, DataFormat, DataHeader, UnexpectedDataFormatSnafu};
/// Size, in bytes, of the scratch buffer that `ReadsDecodable::skip_blob` reads into while
/// discarding data.
pub const TEMP_BUFFER_SIZE: usize = 1024;
/// A value whose contents can be replaced by decoding from a [`ReadsDecodable`] source.
pub trait Decodable: Encodable {
    /// Decodes into `self` from `reader`.
    ///
    /// [`ReadsDecodable::read_data_into`] passes `Some(header)` when `Self::FORMAT` is
    /// structured (the header has already been read from `reader` at that point) and `None`
    /// otherwise.
    fn decode(
        &mut self,
        reader: &mut (impl ReadsDecodable + ?Sized),
        header: Option<DataHeader>,
    ) -> Result<(), CodecError>;

    /// Returns the header if one was supplied and its format ordinal is listed in
    /// `supported_ordinals`.
    ///
    /// # Errors
    ///
    /// * An "unexpected data format" error when `header` is `None`.
    /// * An "unsupported data format" error when the header's ordinal is not listed.
    #[inline(always)]
    fn ensure_header(
        header: Option<DataHeader>,
        supported_ordinals: &[u8],
    ) -> Result<DataHeader, CodecError> {
        use super::UnsupportedDataFormatSnafu;

        let present = match header {
            Some(present) => present,
            None => {
                let missing = UnexpectedDataFormatSnafu {
                    expected: Self::FORMAT,
                    actual: header,
                }
                .build();
                return Err(missing.into());
            }
        };

        let ordinal = present.format.ordinal;
        ensure!(
            supported_ordinals.contains(&ordinal),
            UnsupportedDataFormatSnafu { ordinal }
        );
        Ok(present)
    }

    /// Succeeds only when no header was supplied.
    ///
    /// # Errors
    ///
    /// An "unexpected data format" error carrying the offending header when `header` is
    /// `Some`.
    #[inline(always)]
    fn ensure_no_header(header: Option<DataHeader>) -> Result<(), CodecError> {
        if header.is_some() {
            let unexpected = UnexpectedDataFormatSnafu {
                expected: Self::FORMAT,
                actual: header,
            }
            .build();
            return Err(unexpected.into());
        }
        Ok(())
    }
}
/// A byte source that codec data can be decoded from.
///
/// Implementors supply [`read`](Self::read); every other method has a default built on it.
pub trait ReadsDecodable {
    /// Reads bytes into `buf` and returns how many were written.
    ///
    /// The default methods of this trait treat a return value of `0` as end of input.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, CodecError>;

    /// Fills `buf` completely by calling [`read`](Self::read) until enough bytes arrived.
    ///
    /// # Errors
    ///
    /// [`CodecError::UnexpectedEof`] when `read` returns `0` before `buf` is full; any error
    /// from `read` is propagated.
    fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), CodecError> {
        let mut filled = 0;
        while filled < buf.len() {
            match self.read(&mut buf[filled..])? {
                0 => return Err(CodecError::UnexpectedEof),
                n => filled += n,
            }
        }
        Ok(())
    }

    /// Hook invoked when decoding descends into a nested scope. The default accepts every
    /// scope.
    fn enter_scope(&mut self) -> Result<(), CodecError> {
        Ok(())
    }

    /// Hook invoked when a scope opened through [`enter_scope`](Self::enter_scope) ends. The
    /// default does nothing.
    fn exit_scope(&mut self) {}

    /// Decodes a `T`, starting from `T::default()`.
    fn read_data<T: Decodable + Default>(&mut self) -> Result<T, CodecError> {
        let mut value = T::default();
        self.read_data_into(&mut value).map(|()| value)
    }

    /// Decodes into an existing value.
    ///
    /// For structured formats a scope is entered, a [`DataHeader`] is read and handed to
    /// [`Decodable::decode`]; other formats are decoded directly with no header.
    fn read_data_into<T: Decodable>(&mut self, data: &mut T) -> Result<(), CodecError> {
        if !T::FORMAT.is_structured() {
            return data.decode(self, None);
        }

        // The scope stays open until this function returns, i.e. for the whole decode.
        let mut scope = DecodingScope::enter(self)?;
        let header = scope.read_data::<DataHeader>()?;
        data.decode(&mut *scope, Some(header))?;
        Ok(())
    }

    /// Reads and discards `length` bytes.
    ///
    /// # Errors
    ///
    /// [`CodecError::UnexpectedEof`] when the source runs dry first; any error from
    /// [`read`](Self::read) is propagated.
    fn skip_blob(&mut self, length: usize) -> Result<(), CodecError> {
        let mut scratch = [0u8; TEMP_BUFFER_SIZE];
        let mut outstanding = length;
        while outstanding > 0 {
            let chunk = outstanding.min(TEMP_BUFFER_SIZE);
            let n = self.read(&mut scratch[..chunk])?;
            if n == 0 {
                return Err(CodecError::UnexpectedEof);
            }
            // Saturating: the loop simply ends if a reader reports more than was requested.
            outstanding = outstanding.saturating_sub(n);
        }
        Ok(())
    }

    /// Skips one headered sequence of data and returns the number of bytes accounted for:
    /// the header's blob size plus whatever each of the `header.count` items reports.
    fn skip_data(&mut self) -> Result<usize, CodecError> {
        let mut scope = DecodingScope::enter(self)?;
        let header = scope.read_data::<DataHeader>()?;
        let item_format = header.format;

        let mut total = DataHeader::FORMAT.as_data_format().blob_size as usize;
        for _ in 0..header.count {
            total += scope.skip_data_with_format(item_format)?;
        }
        Ok(total)
    }

    /// Skips a single item laid out as `format`: its blob first, then one nested
    /// [`skip_data`](Self::skip_data) per data field. Returns the bytes accounted for.
    fn skip_data_with_format(&mut self, format: DataFormat) -> Result<usize, CodecError> {
        let blob_len = format.blob_size as usize;
        self.skip_blob(blob_len)?;
        (0..format.data_fields)
            .try_fold(blob_len, |total, _| self.skip_data().map(|n| total + n))
    }
}
/// Guard representing one open decoding scope on a reader.
///
/// Created by `DecodingScope::enter`, which calls `enter_scope` on the reader; dropping the
/// guard calls `exit_scope`. It dereferences to the wrapped reader.
pub(crate) struct DecodingScope<'a, R: ReadsDecodable + ?Sized> {
/// The reader this scope was opened on.
reader: &'a mut R,
}
impl<'a, R: ReadsDecodable + ?Sized> DecodingScope<'a, R> {
    /// Opens a scope on `reader`.
    ///
    /// If the reader's `enter_scope` fails, that error is returned and no guard is created,
    /// so `exit_scope` is not called for the failed attempt.
    pub(crate) fn enter(reader: &'a mut R) -> Result<Self, CodecError> {
        reader.enter_scope().map(|()| Self { reader })
    }
}
impl<R: ReadsDecodable + ?Sized> Drop for DecodingScope<'_, R> {
    /// Closes the scope by notifying the wrapped reader.
    fn drop(&mut self) {
        R::exit_scope(&mut *self.reader);
    }
}
impl<R: ReadsDecodable + ?Sized> core::ops::Deref for DecodingScope<'_, R> {
    type Target = R;

    /// Gives shared access to the wrapped reader.
    fn deref(&self) -> &Self::Target {
        &*self.reader
    }
}
impl<R: ReadsDecodable + ?Sized> core::ops::DerefMut for DecodingScope<'_, R> {
    /// Gives exclusive access to the wrapped reader so decoding can continue through the
    /// guard.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut *self.reader
    }
}
impl<R: Reads> ReadsDecodable for R {
fn read(&mut self, buf: &mut [u8]) -> Result<usize, CodecError> {
Ok(Reads::read(self, buf)?)
}
fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), CodecError> {
Ok(Reads::read_exact(self, buf)?)
}
fn read_data_into<T: Decodable>(&mut self, data: &mut T) -> Result<(), CodecError> {
LimitedReader::new(&mut *self).read_data_into(data)
}
fn skip_data(&mut self) -> Result<usize, CodecError> {
LimitedReader::new(&mut *self).skip_data()
}
}
/// Wraps a reader and tracks how many bytes were read and how deeply decoding scopes are
/// nested, failing reads and scope entries that would go past the configured limits.
pub struct LimitedReader<'a, R: Reads> {
/// The wrapped reader that all reads are forwarded to.
reader: &'a mut R,
/// Running count of bytes read through this wrapper.
bytes_read: u64,
/// Upper bound for `bytes_read`; reads past it fail with `CodecError::ByteLimitExceeded`.
max_bytes: u64,
/// Number of scopes currently open (changed by `enter_scope` / `exit_scope`).
depth: u32,
/// Upper bound for `depth`; entering past it fails with `CodecError::DepthLimitExceeded`.
max_depth: u32,
}
impl<'a, R: Reads> LimitedReader<'a, R> {
    /// Wraps `reader` with [`DEFAULT_MAX_BYTES`] and [`DEFAULT_MAX_DEPTH`] as limits.
    pub fn new(reader: &'a mut R) -> Self {
        Self::with_limits(reader, DEFAULT_MAX_BYTES, DEFAULT_MAX_DEPTH)
    }

    /// Wraps `reader` with both limits set to the maximum value of their integer type.
    pub fn unlimited(reader: &'a mut R) -> Self {
        Self::with_limits(reader, u64::MAX, u32::MAX)
    }

    /// Replaces the byte limit, builder style.
    pub fn max_bytes(self, max: u64) -> Self {
        Self {
            max_bytes: max,
            ..self
        }
    }

    /// Replaces the scope-depth limit, builder style.
    pub fn max_depth(self, max: u32) -> Self {
        Self {
            max_depth: max,
            ..self
        }
    }

    /// Number of bytes read through this wrapper so far.
    pub fn bytes_read(&self) -> u64 {
        self.bytes_read
    }

    /// Shared constructor: counters start at zero, limits are as given.
    fn with_limits(reader: &'a mut R, max_bytes: u64, max_depth: u32) -> Self {
        Self {
            reader,
            bytes_read: 0,
            max_bytes,
            depth: 0,
            max_depth,
        }
    }
}
/// Decoding through a [`LimitedReader`] enforces its byte budget on every read and its
/// depth budget on every scope entry.
impl<R: Reads> ReadsDecodable for LimitedReader<'_, R> {
    /// Reads at most as many bytes as the remaining budget allows.
    ///
    /// # Errors
    ///
    /// [`CodecError::ByteLimitExceeded`] when the budget is exhausted and `buf` is not
    /// empty; errors from the wrapped reader are propagated.
    fn read(&mut self, buf: &mut [u8]) -> Result<usize, CodecError> {
        let remaining = self.max_bytes.saturating_sub(self.bytes_read);
        if remaining == 0 && !buf.is_empty() {
            return Err(CodecError::ByteLimitExceeded);
        }

        // Compare in u64 so a large budget is never truncated on targets where `usize` is
        // narrower than 64 bits; the cast only happens once `remaining` is known to fit.
        let limit = if remaining < buf.len() as u64 {
            remaining as usize
        } else {
            buf.len()
        };

        let n = self.reader.read(&mut buf[..limit])?;
        self.bytes_read = self.bytes_read.saturating_add(n as u64);
        Ok(n)
    }

    /// Fills `buf` completely, provided the whole request fits in the remaining budget.
    ///
    /// The budget is checked up front, so nothing is read when the request is too large.
    /// `bytes_read` is only advanced after the wrapped reader succeeded.
    ///
    /// # Errors
    ///
    /// [`CodecError::ByteLimitExceeded`] when `bytes_read + buf.len()` would exceed the
    /// limit (including when that sum does not fit in a `u64`); errors from the wrapped
    /// reader are propagated.
    fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), CodecError> {
        let len = buf.len() as u64;
        // `checked_add` keeps the comparison meaningful even next to `u64::MAX`.
        let total = match self.bytes_read.checked_add(len) {
            Some(total) if total <= self.max_bytes => total,
            _ => return Err(CodecError::ByteLimitExceeded),
        };

        self.reader.read_exact(buf)?;
        self.bytes_read = total;
        Ok(())
    }

    /// Opens one more scope if the depth limit permits it.
    ///
    /// The limit is checked before the counter is touched: when entry fails no scope guard
    /// is created and `exit_scope` is never called for the attempt, so a failed entry must
    /// leave `depth` unchanged.
    ///
    /// # Errors
    ///
    /// [`CodecError::DepthLimitExceeded`] when `max_depth` scopes are already open.
    fn enter_scope(&mut self) -> Result<(), CodecError> {
        if self.depth >= self.max_depth {
            return Err(CodecError::DepthLimitExceeded);
        }
        self.depth += 1;
        Ok(())
    }

    /// Closes one scope; the counter never drops below zero.
    fn exit_scope(&mut self) {
        self.depth = self.depth.saturating_sub(1);
    }
}
/// Byte limit installed by `LimitedReader::new` (64 MiB).
pub const DEFAULT_MAX_BYTES: u64 = 64 * 1024 * 1024;
/// Scope-depth limit installed by `LimitedReader::new`.
pub const DEFAULT_MAX_DEPTH: u32 = 64;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{codec::tests::*, types::Text};

    /// Walks the shared test fixture by hand: header first, then each field in order.
    #[test]
    fn decodes() -> Result<(), CodecError> {
        let mut encoded = Vec::new();
        encode_test_data(&mut encoded);
        let mut input = encoded.as_slice();

        let header: DataHeader = input.read_data()?;
        assert_eq!(1, header.count);
        assert_eq!(0, header.format.ordinal);
        assert_eq!(12, header.format.blob_size);
        assert_eq!(1, header.format.data_fields);

        let expected = TestData::default();
        assert_eq!(expected.num_a, input.read_data()?);
        assert_eq!(expected.num_b, input.read_data()?);

        let mut text = Text::default();
        input.read_data_into(&mut text)?;
        assert_eq!(expected.text, text);
        Ok(())
    }

    /// A budget smaller than the payload fails; a budget of exactly the payload size passes.
    #[test]
    fn limited_reader_byte_limit() {
        use crate::codec::WritesEncodable;

        let text = Text::from("hello, limited world!");
        let mut encoded = vec![];
        encoded.write_data(&text).unwrap();
        let exact_budget = encoded.len() as u64;

        // Too small.
        let mut slice = encoded.as_slice();
        let result = LimitedReader::new(&mut slice)
            .max_bytes(8)
            .read_data::<Text>();
        assert!(
            matches!(result, Err(CodecError::ByteLimitExceeded)),
            "expected ByteLimitExceeded, got {result:?}"
        );

        // Exactly enough.
        let mut slice = encoded.as_slice();
        let decoded = LimitedReader::new(&mut slice)
            .max_bytes(exact_budget)
            .read_data::<Text>()
            .expect("should decode within exact limit");
        assert_eq!(text, decoded);
    }

    /// A doubly nested vector fails with a depth limit of 1 and decodes with a limit of 2.
    #[test]
    fn limited_reader_depth_limit() {
        use crate::codec::WritesEncodable;

        let data: Vec<Vec<u32>> = vec![vec![1, 2], vec![3, 4]];
        let mut encoded = vec![];
        encoded.write_data(&data).unwrap();

        let mut slice = encoded.as_slice();
        let result = LimitedReader::new(&mut slice)
            .max_depth(1)
            .read_data::<Vec<Vec<u32>>>();
        assert!(
            matches!(result, Err(CodecError::DepthLimitExceeded)),
            "expected DepthLimitExceeded, got {result:?}"
        );

        let mut slice = encoded.as_slice();
        let decoded = LimitedReader::new(&mut slice)
            .max_depth(2)
            .read_data::<Vec<Vec<u32>>>()
            .expect("should decode at depth 2");
        assert_eq!(data, decoded);
    }

    /// The byte budget is shared by successive reads on the same `LimitedReader`.
    #[test]
    fn limited_reader_cumulative_bytes() {
        use crate::codec::WritesEncodable;

        let text_a = Text::from("hello world!!");
        let text_b = Text::from("goodbye world!");

        let mut payload = vec![];
        payload.write_data(&text_a).unwrap();
        payload.write_data(&text_b).unwrap();
        let full_budget = payload.len() as u64;

        // Budget covering the first value plus four more bytes.
        let mut first_only = vec![];
        first_only.write_data(&text_a).unwrap();
        let tight_limit = first_only.len() as u64 + 4;

        let mut slice = payload.as_slice();
        let mut limited = LimitedReader::new(&mut slice).max_bytes(tight_limit);
        let _a: Text = limited.read_data().unwrap();
        let result_b: Result<Text, _> = limited.read_data();
        assert!(
            matches!(result_b, Err(CodecError::ByteLimitExceeded)),
            "expected ByteLimitExceeded on second field, got {result_b:?}"
        );

        // With the full payload size as budget both values decode.
        let mut slice = payload.as_slice();
        let mut limited = LimitedReader::new(&mut slice).max_bytes(full_budget);
        let a: Text = limited.read_data().unwrap();
        let b: Text = limited.read_data().unwrap();
        assert_eq!(text_a, a);
        assert_eq!(text_b, b);
    }

    /// Reading straight from a slice (no explicit `LimitedReader`) still decodes nested data.
    #[test]
    fn limited_reader_auto_wrap_succeeds() -> Result<(), CodecError> {
        use crate::codec::WritesEncodable;

        let data: Vec<Vec<u32>> = vec![vec![1, 2, 3], vec![4, 5]];
        let mut encoded = vec![];
        encoded.write_data(&data).unwrap();

        let decoded: Vec<Vec<u32>> = encoded.as_slice().read_data()?;
        assert_eq!(data, decoded);
        Ok(())
    }

    /// One byte short of the encoded size fails; the exact encoded size succeeds.
    #[test]
    fn limited_reader_struct_byte_accumulation() {
        use crate::codec::WritesEncodable;

        let data: Vec<Vec<u32>> = vec![vec![1, 2], vec![3, 4]];
        let mut encoded = vec![];
        encoded.write_data(&data).unwrap();
        let total = encoded.len();

        let mut slice = encoded.as_slice();
        let result = LimitedReader::new(&mut slice)
            .max_bytes(total as u64 - 1)
            .read_data::<Vec<Vec<u32>>>();
        assert!(
            matches!(result, Err(CodecError::ByteLimitExceeded)),
            "expected ByteLimitExceeded with budget {}, got {result:?}",
            total - 1,
        );

        let mut slice = encoded.as_slice();
        let decoded = LimitedReader::new(&mut slice)
            .max_bytes(total as u64)
            .read_data::<Vec<Vec<u32>>>()
            .expect("exact budget should succeed");
        assert_eq!(data, decoded);
    }

    /// `skip_data` reports lengths that split two back-to-back encodings apart exactly.
    #[test]
    fn splits_off_group_sequences() -> Result<(), CodecError> {
        let mut expected = vec![];
        encode_test_data(&mut expected);

        // Two copies of the fixture, one after the other.
        let mut doubled = vec![];
        encode_test_data(&mut doubled);
        encode_test_data(&mut doubled);

        let untouched = doubled.as_slice();
        let mut cursor = doubled.as_slice();

        let first_len = cursor.skip_data()?;
        let (first, rest) = untouched.split_at(first_len);
        assert_eq!(expected, first);

        let second_len = cursor.skip_data()?;
        let (second, _) = rest.split_at(second_len);
        assert_eq!(expected, second);
        Ok(())
    }
}