use std::io::{self};
use crate::{Command, Error};
pub struct Reader<T: std::io::Read> {
inner: T,
repeated: u8,
op: Command,
max_output: Option<u64>,
position: u64,
}
impl<T: io::Read> Reader<T> {
pub fn new(inner: T) -> Self {
Self {
inner,
repeated: 0,
op: Command::Literal(0),
max_output: None,
position: 0,
}
}
pub fn with_max_output(inner: T, max_output: u64) -> Self {
Self {
inner,
repeated: 0,
op: Command::Literal(0),
max_output: Some(max_output),
position: 0,
}
}
#[inline]
fn produce_next_byte(&mut self) -> Result<Option<u8>, Error> {
let mut buf = [0u8];
match self.op {
Command::Literal(0) | Command::Repeat(0) => match self.inner.read(&mut buf) {
Ok(0) => Ok(None),
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => Ok(None),
Err(e) => Err(e)?,
Ok(_) => {
self.op = buf[0].into();
if let Command::Repeat(_) = self.op {
match self.inner.read(&mut buf) {
Ok(0) => return Err(Error::NotEnoughInputData),
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
return Err(Error::NotEnoughInputData);
}
Err(e) => return Err(e)?,
Ok(_) => {
self.repeated = buf[0];
}
}
}
self.produce_next_byte()
}
},
Command::Escape => {
self.op = Command::Literal(0);
match self.inner.read(&mut buf) {
Ok(0) => Err(Error::NotEnoughInputData),
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
Err(Error::NotEnoughInputData)
}
Err(e) => Err(e)?,
Ok(_) => Ok(Some(buf[0])),
}
}
Command::Literal(c) => {
self.op = Command::Literal(c - 1);
match self.inner.read(&mut buf) {
Ok(0) => Err(Error::NotEnoughInputData),
Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => {
Err(Error::NotEnoughInputData)
}
Err(e) => Err(e)?,
Ok(_) => Ok(Some(buf[0])),
}
}
Command::Repeat(count) => {
self.op = Command::Repeat(count - 1);
Ok(Some(self.repeated))
}
}
}
pub fn into_inner(self) -> T {
self.inner
}
}
impl<T: io::Read + io::Seek> Reader<T> {
    /// Reports whether there is nothing left to decode.
    ///
    /// This is true only when no command is in progress (`Literal(0)` /
    /// `Repeat(0)`) and the inner stream's position is at or beyond its length.
    /// It does not take `max_output` into account.
    ///
    /// # Errors
    /// Propagates any error from querying the inner stream's position or length.
    #[inline]
    pub fn is_empty(&mut self) -> io::Result<bool> {
        let between_commands = matches!(self.op, Command::Literal(0) | Command::Repeat(0));
        if !between_commands {
            return Ok(false);
        }
        let cursor = self.inner.stream_position()?;
        let total = self.inner.stream_len()?;
        Ok(cursor >= total)
    }
}
impl<T: io::Read> io::Read for Reader<T> {
    /// Decodes up to `buf.len()` bytes into `buf` and returns how many were written.
    ///
    /// Stops early when the encoded input ends, or once the total number of
    /// bytes produced across all calls reaches `max_output` (if set).
    ///
    /// # Errors
    /// Any decoding or I/O error from the underlying decoder, converted into
    /// `io::Error`.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let limit = self.max_output.unwrap_or(u64::MAX);
        let mut written = 0;
        while written < buf.len() && self.position < limit {
            let Some(decoded) = self.produce_next_byte()? else {
                break;
            };
            self.position += 1;
            buf[written] = decoded;
            written += 1;
        }
        Ok(written)
    }
}
impl<T: io::Read> io::Seek for Reader<T> {
    /// Resolves `pos` against the decoded stream. Only "seeks" that land on the
    /// current position are supported, because the decoder cannot move backwards
    /// or skip ahead without decoding.
    ///
    /// `SeekFrom::End` is resolved against `max_output`, matching `stream_len`.
    ///
    /// # Errors
    /// Returns an `ErrorKind::Unsupported` error for any target other than the
    /// current position, including `SeekFrom::End` without a `max_output` and
    /// offsets that overflow or go negative.
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        let target = match pos {
            io::SeekFrom::Start(offset) => Some(offset),
            io::SeekFrom::Current(delta) => self.position.checked_add_signed(delta),
            io::SeekFrom::End(delta) => self
                .max_output
                .and_then(|len| len.checked_add_signed(delta)),
        };
        match target {
            Some(offset) if offset == self.position => Ok(offset),
            _ => Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "Reader can only seek to its current position",
            )),
        }
    }

    /// Returns the configured `max_output` as the stream length.
    ///
    /// This is the output cap, not a measured length: if the encoded input ends
    /// earlier, fewer bytes will actually be produced.
    ///
    /// # Errors
    /// Fails when the reader was created without a `max_output`.
    #[inline]
    fn stream_len(&mut self) -> io::Result<u64> {
        if let Some(max_output) = self.max_output {
            Ok(max_output)
        } else {
            Err(io::Error::other(
                "Cannot determine stream length without max_output",
            ))
        }
    }

    /// Returns the number of decoded bytes produced so far by `Read::read`.
    #[inline]
    fn stream_position(&mut self) -> io::Result<u64> {
        Ok(self.position)
    }
}
#[cfg(test)]
mod test {
    use std::io::{self, Read, Seek};

    #[test]
    fn reading_literals() {
        let mut reader = crate::Reader::new(io::Cursor::new(b"\x01\xAB\xCD"));
        let mut output = vec![0u8; 2];
        assert_eq!(reader.read(&mut output).unwrap(), 2);
        assert_eq!(output, b"\xAB\xCD");
    }

    #[test]
    fn reading_longer_values() {
        let input = b"\xFE\xAA\x02\x80\x00\x2A\xFD\xAA\x03\x80\x00\x2A\x22\xF7\xAA";
        let expectation = b"\xAA\xAA\xAA\x80\x00\x2A\xAA\xAA\xAA\xAA\x80\x00\x2A\x22\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA\xAA";
        let inner = io::Cursor::new(input);
        let mut reader = crate::Reader::new(inner);
        let mut output = vec![0u8; 24];
        assert_eq!(reader.read(&mut output).unwrap(), 24);
        assert_eq!(output, expectation);
    }

    /// `max_output` caps the total number of bytes produced across reads.
    #[test]
    fn max_output_caps_decoded_length() {
        let mut reader = crate::Reader::with_max_output(io::Cursor::new(b"\xFE\xAA"), 2);
        let mut output = vec![0u8; 3];
        assert_eq!(reader.read(&mut output).unwrap(), 2);
        assert_eq!(&output[..2], b"\xAA\xAA");
        assert_eq!(reader.read(&mut output).unwrap(), 0);
        assert_eq!(reader.stream_position().unwrap(), 2);
    }

    /// A literal run that promises more bytes than the input holds is an error.
    #[test]
    fn truncated_literal_is_an_error() {
        let mut reader = crate::Reader::new(io::Cursor::new(b"\x02\xAB"));
        let mut output = vec![0u8; 3];
        assert!(reader.read(&mut output).is_err());
    }
}