use crate::parser::Decoder;
use std::io::Read;
pub trait ReadExt: Read {
fn read_bytes<T>(&mut self) -> Result<T, std::io::Error>
where
Self: Sized,
T: Decoder<()>,
{
T::decode(self, ())
}
fn read_bytes_with<T, P>(&mut self, params: P) -> Result<T, std::io::Error>
where
Self: Sized,
T: Decoder<P>,
{
T::decode(self, params)
}
fn skip_bytes<const SKIP: usize>(&mut self) -> Result<&mut Self, std::io::Error> {
if SKIP != 0 {
const CHUNK: usize = 256;
let mut buf = [0u8; CHUNK];
let mut remaining = SKIP;
while remaining > 0 {
let n = remaining.min(CHUNK);
self.read_exact(&mut buf[..n])?;
if buf[..n].iter().any(|&b| b != 0) {
return Err(std::io::Error::new(
std::io::ErrorKind::InvalidData,
"Skipped bytes were expected to be zeroed.",
));
}
remaining -= n;
}
}
Ok(self)
}
}
impl<R> ReadExt for R where R: Read {}
#[cfg(test)]
mod read_ext_tests {
use super::*;
use std::io::{self, Read};
#[derive(Debug)]
struct MockReader {
data: Vec<u8>,
position: usize,
}
impl MockReader {
fn new(data: Vec<u8>) -> Self {
MockReader { data, position: 0 }
}
}
impl Read for MockReader {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
let len = std::cmp::min(buf.len(), self.data.len() - self.position);
if len > 0 {
buf[..len].copy_from_slice(&self.data[self.position..self.position + len]);
self.position += len;
Ok(len)
} else {
Ok(0) }
}
}
#[test]
fn test_no_skip_valid_data() {
let data = vec![0x12, 0x34, 0x56, 0x78];
let mut reader = MockReader::new(data);
let result: Result<u32, _> = reader.read_bytes();
assert_eq!(result.unwrap(), 0x78563412);
}
#[test]
fn test_skip_valid_data() {
let data = vec![0, 0, 0, 0, 0x12, 0x34, 0x56, 0x78];
let mut reader = MockReader::new(data);
let result: Result<u32, _> = reader.skip_bytes::<4>().unwrap().read_bytes();
assert_eq!(result.unwrap(), 0x78563412);
}
#[test]
fn test_skip_invalid_data() {
let data = vec![0, 0, 1, 0, 0x12, 0x34, 0x56, 0x78];
let mut reader = MockReader::new(data);
let result = reader.skip_bytes::<4>();
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), io::ErrorKind::InvalidData);
}
#[test]
fn test_read_fails() {
let data = vec![0x12, 0x34];
let mut reader = MockReader::new(data);
let result: Result<u32, _> = reader.read_bytes();
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);
}
#[test]
fn test_zero_length_read() {
let data: Vec<u8> = vec![];
let mut reader = MockReader::new(data);
let result: Result<[u8; 0], _> = reader.read_bytes();
assert!(result.is_ok());
}
#[test]
fn test_empty_reader() {
let data: Vec<u8> = vec![];
let mut reader = MockReader::new(data);
let result: Result<u32, _> = reader.read_bytes();
assert!(result.is_err());
assert_eq!(result.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);
}
}