use std::io;
/// Streaming decoder for a byte-oriented run-length encoding whose escape
/// sequence is the two-byte marker `0x81 0x82`.
///
/// Compressed bytes are pulled from `inner` on demand, and decoding stops once
/// `uncompressed_size` bytes have been handed out.
///
/// The type itself carries no trait bound on `R`; the `io::Read` requirement
/// lives on the `impl` blocks that actually read from `inner`.
pub struct RleReader<R> {
    /// Source of the compressed byte stream.
    inner: R,
    /// Number of pending copies of `last_byte` that still have to be emitted.
    count: u8,
    /// Byte returned while `count` is non-zero; run markers repeat this byte.
    last_byte: u8,
    /// Upper bound on the number of decoded bytes this reader produces.
    uncompressed_size: usize,
    /// Number of decoded bytes produced so far.
    produced_bytes: usize,
}
impl<R: io::Read> RleReader<R> {
    /// Creates a decoder that reads compressed data from `inner` and yields at
    /// most `uncompressed_size` decoded bytes.
    pub fn new(inner: R, uncompressed_size: usize) -> Self {
        Self {
            inner,
            count: 0,
            last_byte: 0,
            uncompressed_size,
            produced_bytes: 0,
        }
    }

    /// Decodes and returns the next output byte.
    ///
    /// Returns `Ok(None)` once `uncompressed_size` bytes have been produced, or
    /// when the underlying reader is exhausted outside of an escape sequence.
    ///
    /// Token layout as implemented here:
    /// - any byte other than `0x81` is a literal;
    /// - `0x81 x` with `x != 0x82` decodes to the two literals `0x81`, `x`;
    /// - `0x81 0x82 0x00` decodes to the two literals `0x81`, `0x82`;
    /// - `0x81 0x82 0x01` produces nothing;
    /// - `0x81 0x82 n` with `n >= 2` emits `n - 1` further copies of the
    ///   previously decoded byte.
    ///
    /// # Errors
    ///
    /// Propagates I/O errors from the underlying reader and reports an error
    /// when the input ends in the middle of an escape sequence.
    #[inline]
    fn read_next_byte(&mut self) -> io::Result<Option<u8>> {
        // A loop (rather than recursion) handles the "produces nothing" marker,
        // so input consisting of many such markers cannot exhaust the stack.
        loop {
            if self.produced_bytes >= self.uncompressed_size {
                return Ok(None);
            }

            // Drain pending repetitions before touching the input again.
            if self.count > 0 {
                self.count -= 1;
                self.produced_bytes += 1;
                return Ok(Some(self.last_byte));
            }

            let Some(first) = self.inner.read_byte()? else {
                return Ok(None);
            };
            if first != 0x81 {
                self.last_byte = first;
                self.produced_bytes += 1;
                return Ok(Some(first));
            }

            let Some(second) = self.inner.read_byte()? else {
                return Err(io::Error::other("Unexpected end of file after escape code"));
            };
            if second != 0x82 {
                // Not an escape after all: emit 0x81 now and queue `second`.
                // NOTE(review): a second 0x81 here is queued as plain data and
                // is not re-examined as the start of an escape sequence;
                // confirm against the format specification.
                self.count = 1;
                self.last_byte = second;
                self.produced_bytes += 1;
                return Ok(Some(0x81));
            }

            let Some(run) = self.inner.read_byte()? else {
                return Err(io::Error::other("Unexpected end of file after escape code"));
            };
            match run {
                0 => {
                    // Escaped literal pair: emit 0x81 now and queue 0x82.
                    self.last_byte = 0x82;
                    self.count = 1;
                    self.produced_bytes += 1;
                    return Ok(Some(0x81));
                }
                // Nothing to emit for this marker; decode the next token.
                1 => continue,
                n => {
                    // One copy is returned right away, `n - 2` remain queued.
                    self.count = n - 2;
                    self.produced_bytes += 1;
                    return Ok(Some(self.last_byte));
                }
            }
        }
    }
}
impl<R: io::Read> io::Read for RleReader<R> {
    /// Fills `buf` with decoded bytes, one at a time, and returns how many
    /// were stored.
    ///
    /// A return value smaller than `buf.len()` means the decoder reported the
    /// end of its output before the buffer was full.
    ///
    /// # Errors
    ///
    /// Returns the first error reported by the decoder. Bytes that were stored
    /// in `buf` before the error occurred are not reported to the caller.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        let mut filled = 0;
        while filled < buf.len() {
            let Some(decoded) = self.read_next_byte()? else {
                break;
            };
            buf[filled] = decoded;
            filled += 1;
        }
        Ok(filled)
    }
}
impl<R: io::Read> io::Seek for RleReader<R> {
    /// Repositioning the decoder is not implemented.
    ///
    /// # Errors
    ///
    /// Always returns an error of kind [`io::ErrorKind::Unsupported`] instead
    /// of panicking, so generic callers can handle the missing capability.
    fn seek(&mut self, _: io::SeekFrom) -> io::Result<u64> {
        Err(io::Error::new(
            io::ErrorKind::Unsupported,
            "seeking is not supported by RleReader",
        ))
    }

    /// Returns the number of decoded bytes produced so far.
    #[inline]
    fn stream_position(&mut self) -> io::Result<u64> {
        Ok(self.produced_bytes as u64)
    }

    /// Returns the uncompressed size the reader was created with.
    #[inline]
    fn stream_len(&mut self) -> io::Result<u64> {
        Ok(self.uncompressed_size as u64)
    }
}
/// Extension trait for pulling a single byte out of a reader.
trait ReadByte {
    /// Reads one byte from the reader.
    ///
    /// Returns `Ok(None)` when the reader reports that no more data is
    /// available, and `Err` when the underlying read fails.
    fn read_byte(&mut self) -> io::Result<Option<u8>>;
}
impl<R: io::Read> ReadByte for R {
    /// Reads exactly one byte from the reader.
    ///
    /// Returns `Ok(None)` when the reader yields zero bytes (end of input).
    ///
    /// Reads failing with [`io::ErrorKind::Interrupted`] are retried: callers
    /// invoke this in the middle of multi-byte escape sequences, where giving
    /// up would lose the bytes already consumed.
    ///
    /// # Errors
    ///
    /// Returns any other error reported by the underlying reader.
    #[inline]
    fn read_byte(&mut self) -> io::Result<Option<u8>> {
        let mut buf = [0u8];
        loop {
            match self.read(&mut buf) {
                Ok(0) => return Ok(None),
                Ok(_) => return Ok(Some(buf[0])),
                // Transient condition; the read must simply be attempted again.
                Err(err) if err.kind() == io::ErrorKind::Interrupted => continue,
                Err(err) => return Err(err),
            }
        }
    }
}
#[cfg(test)]
mod test {
    use crate::rle::RleReader;
    use std::io::{self, Read};

    /// Decodes `input` with an [`RleReader`] sized to `expected_output` and
    /// asserts that the decoded bytes match.
    fn verify_expansion(input: &[u8], expected_output: &[u8]) {
        let mut reader = RleReader::new(io::Cursor::new(input), expected_output.len());
        let mut decoded: Vec<u8> = Vec::new();
        reader.read_to_end(&mut decoded).unwrap();
        assert_eq!(decoded, expected_output);
    }

    /// Checks the `0x81 0x82 n` marker for `n` in `0..=5`: the escaped literal
    /// pair, the empty run, and runs of increasing length.
    #[test]
    fn literal_escape() {
        let cases: [(&[u8], &[u8]); 6] = [
            (b"\xAB\x81\x82\x00", b"\xAB\x81\x82"),
            (b"\xAB\x81\x82\x01", b"\xAB"),
            (b"\xAB\x81\x82\x02", b"\xAB\xAB"),
            (b"\xAB\x81\x82\x03", b"\xAB\xAB\xAB"),
            (b"\xAB\x81\x82\x04", b"\xAB\xAB\xAB\xAB"),
            (b"\xAB\x81\x82\x05", b"\xAB\xAB\xAB\xAB\xAB"),
        ];
        for &(encoded, expected) in &cases {
            verify_expansion(encoded, expected);
        }
    }
}