//! sit-algos 0.3.0
//!
//! Implementation of decompression algorithms used by StuffIt Expander and related applications.
use std::{
    cmp::Ordering,
    io::{self, Read},
};

use bitstream_io::BitRead2;

/// Errors produced while decoding an LZW stream.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// Failure reported by the underlying reader, including running out of input
    /// (`io::ErrorKind::UnexpectedEof`).
    #[error(transparent)]
    Io(#[from] io::Error),
    /// A code referred to a dictionary entry that does not exist.
    #[error("Invalid LZW code encountered")]
    InvalidCode,
    /// NOTE(review): not constructed by the decoder code in this file; presumably
    /// reserved for dictionary overflow — confirm before relying on it.
    #[error("Too many LZW codes encountered")]
    TooManyCodes,
}

impl From<Error> for io::Error {
    fn from(val: Error) -> Self {
        match val {
            Error::Io(error) => error,
            Error::InvalidCode => io::Error::other("Invalid LZW code encountered"),
            Error::TooManyCodes => io::Error::other("Too many LZW codes encountered"),
        }
    }
}

/// Dictionary for a variable-width LZW decoder holding at most `N` entries.
///
/// Each entry stores a link to its parent entry plus the final byte of its string, so a
/// string is recovered by walking the `parents` chain back to a root entry.
pub struct LzwTree<const N: usize> {
    /// Number of codes currently defined; also the code the next new entry receives.
    symbol_count: usize,
    /// Code decoded by the previous `advance` call, or `usize::MAX` right after a reset.
    previous_symbol: usize,
    /// Current code width in bits.
    symbol_size: u32,

    /// Parent code of each entry; `usize::MAX` marks a root (single-byte) entry.
    parents: [usize; N],
    /// Final byte of the string each entry represents.
    values: [u8; N],
    /// Scratch space holding the most recently decoded string.
    buffer: Vec<u8>,
}

impl<const N: usize> Default for LzwTree<N> {
    /// Creates a tree containing only the 256 single-byte root entries, using 9-bit codes.
    ///
    /// # Panics
    ///
    /// Panics if `N <= 256`. Codes 0..=255 are literals and code 256 is reserved, so the
    /// table starts with 257 codes. With a smaller `N` the table would never be considered
    /// full, and the first new entry would be written out of bounds.
    fn default() -> Self {
        assert!(N > 256, "Invalid LZWTree configuration");

        // Root entries 0..=255 map to their own byte value; the rest start zeroed.
        let mut values = [0u8; N];
        for (slot, byte) in values.iter_mut().zip(0..=u8::MAX) {
            *slot = byte;
        }

        Self {
            symbol_count: 256 + 1,
            symbol_size: 9,
            previous_symbol: usize::MAX,
            parents: [usize::MAX; N],
            buffer: Vec::with_capacity(1024),
            values,
        }
    }
}

impl<const N: usize> LzwTree<N> {
    fn reset(&mut self) {
        self.symbol_count = 256 + 1;
        self.previous_symbol = usize::MAX;
        self.symbol_size = 9;
    }

    fn advance(&mut self, symbol: usize) -> Result<&[u8], Error> {
        if self.previous_symbol == usize::MAX {
            if symbol >= self.symbol_count {
                return Err(Error::InvalidCode);
            }

            self.previous_symbol = symbol;
        } else {
            let value = match symbol.cmp(&self.symbol_count) {
                Ordering::Less => self.find_first_byte(symbol),
                Ordering::Equal => self.find_first_byte(self.previous_symbol),
                Ordering::Greater => return Err(Error::InvalidCode),
            };

            let parent = self.previous_symbol;
            self.previous_symbol = symbol;

            if !self.full() {
                self.parents[self.symbol_count] = parent;
                self.values[self.symbol_count] = value;
                self.symbol_count += 1;

                if !self.full() && (self.symbol_count & (self.symbol_count - 1)) == 0 {
                    self.symbol_size += 1;
                }
            } else {
                log::warn!("Ignore overflowing code table, hopefully the block ends soon…");
            }
        }

        let n = self.output_len();
        if n > self.buffer.len() {
            self.buffer = vec![0u8; n];
        }

        let mut i = n;
        let mut symbol = self.previous_symbol;
        loop {
            match symbol {
                usize::MAX => return Ok(&self.buffer[0..n]),
                _ => {
                    self.buffer[i - 1] = self.values[symbol];
                    symbol = self.parents[symbol];
                    i -= 1;
                }
            }
        }
    }

    fn find_first_byte(&mut self, mut symbol: usize) -> u8 {
        assert_ne!(symbol, usize::MAX);
        loop {
            match self.parents[symbol] {
                usize::MAX => return self.values[symbol],
                _ => symbol = self.parents[symbol],
            }
        }
    }

    fn full(&self) -> bool {
        self.symbol_count == N
    }

    fn output_len(&self) -> usize {
        let mut n = 0;
        let mut symbol = self.previous_symbol;
        loop {
            match symbol {
                usize::MAX => return n,
                _ => {
                    n += 1;
                    symbol = self.parents[symbol]
                }
            }
        }
    }
}

/// Streaming LZW decompressor that exposes the decoded bytes through `io::Read`.
pub struct LzwReader<R: io::Read> {
    /// Set by `read` when it decodes the first string of the stream.
    initialized: bool,
    /// Compressed input, read as little-endian bit fields.
    inner: bitstream_io::BitReader<R, bitstream_io::LittleEndian>,
    /// Code dictionary with room for 0x4000 entries.
    tree: LzwTree<0x4000>,
    /// Codes read since the last end-of-block code; used to work out how many padding
    /// bits to skip after one.
    symbol_counter: u32,
    /// Decoded bytes that have not all been handed out yet.
    buffer: Vec<u8>,
    /// Index of the next byte of `buffer` to hand out.
    buffer_pos: usize,

    /// Offset in the decompressed stream, as reported by `stream_position`.
    position: u64,
    /// Declared size of the decompressed data; `read` returns `Ok(0)` once `position`
    /// reaches it, and `stream_len` reports it.
    uncompressed_size: u64,
}

impl<R: io::Read> LzwReader<R> {
    pub fn new(inner: R, uncompressed_size: u64) -> Self {
        Self {
            initialized: false,
            inner: bitstream_io::BitReader::<_, bitstream_io::LittleEndian>::new(inner),
            tree: Default::default(),
            symbol_counter: 0,
            buffer: Vec::new(),
            buffer_pos: 0,

            position: 0,
            uncompressed_size,
        }
    }

    pub fn into_inner(self) -> R {
        self.inner.into_reader()
    }

    fn decode_chunk(&mut self) -> Result<&[u8], Error> {
        loop {
            self.symbol_counter += 1;
            match self.inner.read(self.tree.symbol_size)? {
                256u16 => {
                    log::info!("End of block found");
                    if !self.symbol_counter.is_multiple_of(8) {
                        self.inner
                            .skip(self.tree.symbol_size * (8 - (self.symbol_counter % 8)))?;
                    }
                    self.tree.reset();
                    self.symbol_counter = 0;
                }
                symbol => return self.tree.advance(symbol as usize),
            }
        }
    }
}

impl<R: io::Read> io::Read for LzwReader<R> {
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.position >= self.uncompressed_size {
            return Ok(0);
        }

        if !self.initialized {
            self.initialized = true;
            self.buffer = self.decode_chunk()?.to_vec();
            self.buffer_pos = 0;
        }

        for (idx, byte) in buf.iter_mut().enumerate() {
            // Copy already decoded data
            if self.buffer_pos < self.buffer.len() {
                *byte = self.buffer[self.buffer_pos];
                self.buffer_pos += 1;
                continue;
            }

            // Decode another chunk
            self.buffer = match self.decode_chunk() {
                Ok(buf) => buf.to_vec(),
                Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(idx),
                Err(e) => return Err(e.into()),
            };
            self.buffer_pos = 0;

            // Copy more data if we have it, otherwise report number of copied bytes thus far
            if self.buffer_pos < self.buffer.len() {
                *byte = self.buffer[self.buffer_pos];
                self.buffer_pos += 1;
                continue;
            }

            self.position += idx as u64;
            return Ok(idx);
        }

        Ok(buf.len())
    }
}

impl<R: io::Read> io::Seek for LzwReader<R> {
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        Ok(match pos {
            io::SeekFrom::Current(0) => todo!(),
            io::SeekFrom::Current(n) if n < 0 => todo!(),
            io::SeekFrom::Current(x) => {
                let mut buf = vec![0u8; x as usize];
                self.read(&mut buf)? as u64
            }
            io::SeekFrom::End(_) => todo!(),
            io::SeekFrom::Start(n) if n > self.position => {
                self.seek(io::SeekFrom::Current(n as i64 - self.position as i64))?
            }
            _ => todo!(),
        })
    }

    #[inline]
    fn stream_position(&mut self) -> io::Result<u64> {
        Ok(self.position)
    }

    #[inline]
    fn stream_len(&mut self) -> io::Result<u64> {
        Ok(self.uncompressed_size)
    }
}