use std::{
cmp::Ordering,
io::{self, Read},
};
use bitstream_io::BitRead2;
/// Errors produced while decoding an LZW stream.
#[derive(Debug, thiserror::Error)]
pub enum Error {
/// An I/O error from the underlying reader, forwarded unchanged.
#[error(transparent)]
Io(#[from] io::Error),
/// A code was read that does not refer to any entry the decoder could know about.
#[error("Invalid LZW code encountered")]
InvalidCode,
/// More codes were encountered than the decoder accepts.
// NOTE(review): nothing in the visible code constructs this variant — confirm whether it is still needed.
#[error("Too many LZW codes encountered")]
TooManyCodes,
}
impl From<Error> for io::Error {
fn from(val: Error) -> Self {
match val {
Error::Io(error) => error,
Error::InvalidCode => io::Error::other("Invalid LZW code encountered"),
Error::TooManyCodes => io::Error::other("Too many LZW codes encountered"),
}
}
}
/// LZW code table with room for `N` codes, stored as a parent-linked tree.
///
/// Each code is described by the byte it contributes (`values`) and the code whose
/// expansion precedes that byte (`parents`). Expanding a code means following the
/// parent links until `usize::MAX` is reached.
pub struct LzwTree<const N: usize> {
// Number of codes currently defined; also the index the next new code is stored at.
symbol_count: usize,
// Code handled by the previous `advance` call, or `usize::MAX` when there is none yet.
previous_symbol: usize,
// Current width of a code in bits (starts at 9 and grows with the table).
symbol_size: u32,
// Parent code of every entry; `usize::MAX` marks an entry without a parent.
parents: [usize; N],
// Byte stored for every entry.
values: [u8; N],
// Scratch space holding the expansion returned by the latest `advance` call.
buffer: Vec<u8>,
}
impl<const N: usize> Default for LzwTree<N> {
    /// Builds a table holding the 256 single-byte codes plus the reserved code 256.
    ///
    /// # Panics
    ///
    /// Panics when `N <= 256`: the table starts with 257 defined codes
    /// (`symbol_count == 257`), so a smaller array would be indexed out of bounds
    /// as soon as the table is used.
    fn default() -> Self {
        assert!(N > 256, "Invalid LZWTree configuration");
        Self {
            symbol_count: 256 + 1,
            previous_symbol: usize::MAX,
            symbol_size: 9,
            parents: [usize::MAX; N],
            // Codes 0..=255 map to their own byte value; every other slot starts at 0.
            values: std::array::from_fn(|code| u8::try_from(code).unwrap_or(0)),
            buffer: Vec::with_capacity(1024),
        }
    }
}
impl<const N: usize> LzwTree<N> {
fn reset(&mut self) {
self.symbol_count = 256 + 1;
self.previous_symbol = usize::MAX;
self.symbol_size = 9;
}
fn advance(&mut self, symbol: usize) -> Result<&[u8], Error> {
if self.previous_symbol == usize::MAX {
if symbol >= self.symbol_count {
return Err(Error::InvalidCode);
}
self.previous_symbol = symbol;
} else {
let value = match symbol.cmp(&self.symbol_count) {
Ordering::Less => self.find_first_byte(symbol),
Ordering::Equal => self.find_first_byte(self.previous_symbol),
Ordering::Greater => return Err(Error::InvalidCode),
};
let parent = self.previous_symbol;
self.previous_symbol = symbol;
if !self.full() {
self.parents[self.symbol_count] = parent;
self.values[self.symbol_count] = value;
self.symbol_count += 1;
if !self.full() && (self.symbol_count & (self.symbol_count - 1)) == 0 {
self.symbol_size += 1;
}
} else {
log::warn!("Ignore overflowing code table, hopefully the block ends soon…");
}
}
let n = self.output_len();
if n > self.buffer.len() {
self.buffer = vec![0u8; n];
}
let mut i = n;
let mut symbol = self.previous_symbol;
loop {
match symbol {
usize::MAX => return Ok(&self.buffer[0..n]),
_ => {
self.buffer[i - 1] = self.values[symbol];
symbol = self.parents[symbol];
i -= 1;
}
}
}
}
fn find_first_byte(&mut self, mut symbol: usize) -> u8 {
assert_ne!(symbol, usize::MAX);
loop {
match self.parents[symbol] {
usize::MAX => return self.values[symbol],
_ => symbol = self.parents[symbol],
}
}
}
fn full(&self) -> bool {
self.symbol_count == N
}
fn output_len(&self) -> usize {
let mut n = 0;
let mut symbol = self.previous_symbol;
loop {
match symbol {
usize::MAX => return n,
_ => {
n += 1;
symbol = self.parents[symbol]
}
}
}
}
}
/// Reader that decodes an LZW-compressed bit stream from `R` and exposes the
/// decoded bytes through `io::Read`.
pub struct LzwReader<R: io::Read> {
// Set once the first chunk has been decoded by `read`.
initialized: bool,
// Source of the variable-width codes, read with `bitstream_io::LittleEndian`.
inner: bitstream_io::BitReader<R, bitstream_io::LittleEndian>,
// Code table; fixed at 0x4000 entries.
tree: LzwTree<0x4000>,
// Number of codes read since the last end-of-block code (256).
symbol_counter: u32,
// Bytes of the most recently decoded chunk.
buffer: Vec<u8>,
// Index of the next byte of `buffer` that has not been handed to the caller yet.
buffer_pos: usize,
// Offset in the decoded stream; compared with `uncompressed_size` in `read` and
// reported by `stream_position`.
position: u64,
// Size of the decoded data as supplied to `new`; reported by `stream_len`.
uncompressed_size: u64,
}
impl<R: io::Read> LzwReader<R> {
    /// Creates a decoder reading compressed data from `inner`.
    ///
    /// `uncompressed_size` is the size of the decoded data; it is reported by
    /// `stream_len` and compared against the current position in `read`.
    pub fn new(inner: R, uncompressed_size: u64) -> Self {
        Self {
            initialized: false,
            inner: bitstream_io::BitReader::<_, bitstream_io::LittleEndian>::new(inner),
            tree: Default::default(),
            symbol_counter: 0,
            buffer: Vec::new(),
            buffer_pos: 0,
            position: 0,
            uncompressed_size,
        }
    }

    /// Consumes the decoder and returns the wrapped reader.
    pub fn into_inner(self) -> R {
        self.inner.into_reader()
    }

    /// Reads codes until one yields output and returns the bytes it expands to.
    ///
    /// Code 256 marks the end of a block: the remaining codes of the current group of
    /// eight are skipped, the code table is reset, and decoding continues with the
    /// next code.
    ///
    /// # Errors
    ///
    /// Returns `Error::Io` when the underlying reader fails (including running out of
    /// data) and whatever error `LzwTree::advance` reports for the code.
    fn decode_chunk(&mut self) -> Result<&[u8], Error> {
        loop {
            // Only the counter's value modulo 8 is ever used, so wrapping is harmless
            // and avoids an overflow panic on extremely long blocks.
            self.symbol_counter = self.symbol_counter.wrapping_add(1);
            let code: u16 = self.inner.read(self.tree.symbol_size)?;
            if code != 256 {
                return self.tree.advance(usize::from(code));
            }

            log::info!("End of block found");
            let codes_in_group = self.symbol_counter % 8;
            if codes_in_group != 0 {
                // Skip the unused codes that pad the group, at the current code width.
                self.inner
                    .skip(self.tree.symbol_size * (8 - codes_in_group))?;
            }
            self.tree.reset();
            self.symbol_counter = 0;
        }
    }
}
impl<R: io::Read> io::Read for LzwReader<R> {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
if self.position >= self.uncompressed_size {
return Ok(0);
}
if !self.initialized {
self.initialized = true;
self.buffer = self.decode_chunk()?.to_vec();
self.buffer_pos = 0;
}
for (idx, byte) in buf.iter_mut().enumerate() {
if self.buffer_pos < self.buffer.len() {
*byte = self.buffer[self.buffer_pos];
self.buffer_pos += 1;
continue;
}
self.buffer = match self.decode_chunk() {
Ok(buf) => buf.to_vec(),
Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(idx),
Err(e) => return Err(e.into()),
};
self.buffer_pos = 0;
if self.buffer_pos < self.buffer.len() {
*byte = self.buffer[self.buffer_pos];
self.buffer_pos += 1;
continue;
}
self.position += idx as u64;
return Ok(idx);
}
Ok(buf.len())
}
}
impl<R: io::Read> io::Seek for LzwReader<R> {
    /// Moves forward in the decoded stream by decoding and discarding data, and
    /// returns the new position measured from the start.
    ///
    /// `SeekFrom::End` is resolved against `uncompressed_size`. Seeking to the current
    /// position is a no-op. If the stream ends before the target is reached, the
    /// position actually reached is returned.
    ///
    /// # Errors
    ///
    /// * `InvalidInput` when the requested position is negative or overflows.
    /// * `Unsupported` when the target lies before the current position; decoded data
    ///   cannot be revisited.
    /// * Any error reported by `read` while skipping.
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        let target = match pos {
            io::SeekFrom::Start(offset) => Some(offset),
            io::SeekFrom::Current(delta) => self.position.checked_add_signed(delta),
            io::SeekFrom::End(delta) => self.uncompressed_size.checked_add_signed(delta),
        }
        .ok_or_else(|| {
            io::Error::new(
                io::ErrorKind::InvalidInput,
                "invalid seek to a negative or overflowing position",
            )
        })?;

        if target < self.position {
            return Err(io::Error::new(
                io::ErrorKind::Unsupported,
                "cannot seek backwards in an LZW stream",
            ));
        }

        // Skip through a small fixed buffer instead of allocating the whole distance.
        let start = self.position;
        let mut skipped = 0u64;
        let mut scratch = [0u8; 4096];
        while start + skipped < target {
            let left = target - start - skipped;
            let want = usize::try_from(left).map_or(scratch.len(), |n| n.min(scratch.len()));
            let got = self.read(&mut scratch[..want])?;
            if got == 0 {
                // End of the decoded stream reached before the target.
                break;
            }
            skipped += got as u64;
        }

        // Derive the position from the bytes actually skipped so the result does not
        // depend on how `read` maintains the counter.
        self.position = start + skipped;
        Ok(self.position)
    }

    /// Returns the current position in the decoded stream.
    #[inline]
    fn stream_position(&mut self) -> io::Result<u64> {
        Ok(self.position)
    }

    /// Returns the uncompressed size supplied when the reader was created.
    #[inline]
    fn stream_len(&mut self) -> io::Result<u64> {
        Ok(self.uncompressed_size)
    }
}