sit-algos 0.3.0

Implementation of decompression algorithms used by StuffIt Expander and related applications
Documentation
use std::{io, iter::repeat_n, ops::Range};

use bitstream_io::{BitRead, BitReader, LE};

/// A prefix-code decoding table stored as a flat array of `u16` entries.
///
/// The table is walked two entries at a time: an entry smaller than
/// `2 * sym_count` is the offset of the next pair to visit, while any other
/// entry is a leaf holding `symbol + 2 * sym_count`.
pub(crate) struct HuffTree {
    /// Flattened node table; built with `2 * sym_count` entries.
    data: Vec<u16>,
    /// Number of symbols the table was built for.
    sym_count: usize,
}

impl HuffTree {
    pub fn read_code<R: io::Read + io::Seek>(
        &self,
        i: usize,
        reader: &mut BitReader<R, LE>,
    ) -> Result<u16, super::Error> {
        let mut i = i as u32;
        let count = (self.sym_count as u32) * 2;
        while i < count {
            let idx = i as usize + reader.read_var::<u8>(1)? as usize;
            i = self.data[idx] as u32;
        }
        Ok((i - count) as u16)
    }

    pub fn read_from<R: io::Read + io::Seek>(
        reader: &mut BitReader<R, LE>,
        sym_count: usize,
    ) -> Result<Self, super::Error> {
        let k = reader.read_var::<u8>(1)?;
        let j = reader.read_var::<u32>(2)? + 2;
        let o = reader.read_var::<u32>(3)? + 1;

        let m = (1 << j) - 1;
        let count = if k != 0 {
            m - 1
        } else {
            (-1i32).cast_unsigned()
        };

        let symbols = match reader.read_var::<u8>(2)? {
            0b00 => read_plain_symbols(reader, sym_count, count, o, m, j),
            0b01 => read_compressed_symbols(reader, sym_count, 1 << j, count, m, o),
            _ => Err(super::Error::TreeEncodingUnknown),
        }?;
        reader.byte_align();

        let paths = collect_paths(&symbols);
        let data = compile_tree(&symbols, &paths);

        Ok(Self { sym_count, data })
    }
}

/// Reads `sym_count` code lengths stored as raw `j`-bit values.
///
/// Each value read is interpreted as follows:
/// * equal to `count`: the symbol has length 0,
/// * different from `m`: the symbol has length `value + o`,
/// * equal to `m`: a second `j`-bit value follows, and the previous length is
///   repeated `value + 3` more times.
///
/// # Errors
///
/// Propagates bit-reader failures. Returns `Error::InvalidTree` when a repeat
/// appears before any length was read or would exceed `sym_count` entries.
fn read_plain_symbols<R: io::Read + io::Seek>(
    reader: &mut BitReader<R, LE>,
    sym_count: usize,
    count: u32,
    o: u32,
    m: u32,
    j: u32,
) -> Result<Vec<u8>, super::Error> {
    let mut lengths: Vec<u8> = Vec::with_capacity(sym_count);

    while lengths.len() < sym_count {
        let value = reader.read_var::<u32>(j)?;

        if value == count {
            lengths.push(0);
        } else if value != m {
            lengths.push((value + o) as u8);
        } else {
            // Repeat marker: the run length follows in the stream.
            let run = reader.read_var::<u32>(j)? as usize + 3;
            let Some(&previous) = lengths.last() else {
                return Err(super::Error::InvalidTree);
            };
            if lengths.len() + run > sym_count {
                return Err(super::Error::InvalidTree);
            }
            lengths.extend(repeat_n(previous, run));
        }
    }

    Ok(lengths)
}

/// Reads `sym_count` code lengths whose values are themselves prefix-coded.
///
/// A nested tree with `meta_len` symbols is read first; every value is then
/// decoded through it and interpreted as follows:
/// * equal to `count`: the symbol has length 0,
/// * different from `m`: the symbol has length `value + o`,
/// * equal to `m`: a second decoded value follows, and the previous length is
///   repeated `value + 3` more times.
///
/// # Errors
///
/// Propagates failures from reading the nested tree or decoding a value. Returns
/// `Error::InvalidTree` when a repeat appears before any length was read or would
/// exceed `sym_count` entries.
fn read_compressed_symbols<R: io::Read + io::Seek>(
    reader: &mut BitReader<R, LE>,
    sym_count: usize,
    meta_len: usize,
    count: u32,
    m: u32,
    o: u32,
) -> Result<Vec<u8>, super::Error> {
    let mut lengths: Vec<u8> = Vec::with_capacity(sym_count);
    // TODO: This recursion on tree reading will probably crash the program on malicious input
    let meta = HuffTree::read_from(reader, meta_len)?;

    while lengths.len() < sym_count {
        let value = u32::from(meta.read_code(0, reader)?);

        if value == count {
            lengths.push(0);
        } else if value != m {
            lengths.push((value + o) as u8);
        } else {
            // Repeat marker: the run length is the next decoded value.
            let run = usize::from(meta.read_code(0, reader)?) + 3;
            let Some(&previous) = lengths.last() else {
                return Err(super::Error::InvalidTree);
            };
            if lengths.len() + run > sym_count {
                return Err(super::Error::InvalidTree);
            }
            lengths.extend(repeat_n(previous, run));
        }
    }

    Ok(lengths)
}

/// Assigns a bit path to every symbol from its code length.
///
/// `syms[s]` is the code length of symbol `s`; the returned vector holds, at the
/// same index, the code of that symbol with its low `syms[s]` bits reversed.
/// Symbols of length 0 keep the path 0.
///
/// Codes are handed out in the order produced by `sorted_indices`, which is
/// expected to sort lengths ascending. A running counter is incremented after
/// each symbol and shifted left whenever the length grows.
///
/// The shift amount comes straight from the lengths, so it may be 32 or more for
/// malformed input. In that case every bit is shifted out and the counter becomes
/// 0, instead of panicking in overflow-checked builds.
fn collect_paths(syms: &[u8]) -> Vec<u32> {
    let (indices, sorted_syms) = sorted_indices(syms);
    let mut paths = vec![0u32; syms.len()];

    // Skip the leading run of zero-length (unused) symbols.
    let first_used = sorted_syms
        .iter()
        .position(|&len| len != 0)
        .unwrap_or(syms.len());

    let mut code: u32 = 0;
    for i in first_used..syms.len() {
        let len = sorted_syms[i];
        if i != 0 {
            let growth = u32::from(len - sorted_syms[i - 1]);
            code = code.checked_shl(growth).unwrap_or(0);
        }

        // Reverse the low `len` bits of the code.
        let mut remaining = code;
        let mut reversed = 0u32;
        for _ in 0..len {
            reversed = (reversed << 1) | (remaining & 1);
            remaining >>= 1;
        }
        paths[usize::from(indices[i])] = reversed;

        code += 1;
    }

    paths
}

/// Builds the flat decoding table from code lengths and their bit paths.
///
/// The table has `2 * syms.len()` entries. Entries 0 and 1 form the root pair;
/// further pairs are allocated at offsets 2, 4, ... as paths require them.
/// For each symbol the path is consumed from its lowest bit: every bit selects
/// one entry of the current pair. The entry reached by the last bit receives
/// `2 * syms.len() + symbol`; earlier entries hold the offset of the next pair.
/// Symbols with length 0 contribute nothing.
///
/// # Panics
///
/// Indexing is unchecked, so lengths and paths that need more pairs than the
/// table holds (for example a lone symbol with a length above 1) panic.
fn compile_tree(syms: &[u8], paths: &[u32]) -> Vec<u16> {
    let leaf_base = syms.len() * 2;
    let mut table = vec![0u16; leaf_base];
    let mut next_free: u16 = 2;

    for (symbol, &len) in syms.iter().enumerate() {
        let mut slot = 0usize;
        let mut remaining = paths[symbol] as usize;

        for depth in 0..len {
            slot += remaining & 1;
            remaining >>= 1;

            if depth + 1 == len {
                // Last bit of the code: store the leaf marker.
                table[slot] = (leaf_base + symbol) as u16;
            } else {
                // Inner step: allocate the child pair on first visit, then descend.
                if table[slot] == 0 {
                    table[slot] = next_free;
                    next_free += 2;
                }
                slot = table[slot] as usize;
            }
        }
    }

    table
}

/// Returns a sorted copy of `list` together with the original position of each
/// sorted element.
///
/// The first vector starts as `0, 1, 2, ...` and is permuted by
/// `custom_sorted_indices` in lock-step with the copied values, so entry `n`
/// is the index in `list` of the `n`-th value of the second vector.
fn sorted_indices(list: &[u8]) -> (Vec<u16>, Vec<u8>) {
    let mut values = list.to_vec();
    let mut order: Vec<u16> = (0..list.len()).map(|position| position as u16).collect();

    custom_sorted_indices(&mut order, &mut values);

    (order, values)
}

/// Sorts `list` in ascending order in place and applies the same swaps to `indices`.
///
/// This is a hand-written partition sort that uses `list[first]` as the pivot.
/// After each partition it recurses into the smaller side and keeps looping on
/// the other side by adjusting `first` / `last`.
///
/// NOTE(review): `collect_paths` assigns codes in the order produced here, so the
/// resulting order of equal values presumably has to stay exactly as this
/// procedure yields it. Confirm before replacing it with a library sort.
fn custom_sorted_indices(indices: &mut [u16], list: &mut [u8]) {
    // Half-open window `first..last` that still needs sorting.
    let mut first = 0;
    let mut last = list.len();

    while first < last {
        let mut i = first;
        let mut j = last;

        // Partition the window around the pivot `list[first]`.
        loop {
            // Advance `i` to the next element that is not smaller than the pivot
            // (or to the end of the window).
            loop {
                i += 1;
                if i >= last {
                    break;
                }
                if list[first] <= list[i] {
                    break;
                }
            }

            // Retreat `j` to the previous element that is not larger than the pivot
            // (or down to the pivot itself).
            loop {
                j -= 1;
                if j <= first {
                    break;
                }
                if list[first] >= list[j] {
                    break;
                }
            }

            // The scans have not crossed yet: exchange the misplaced pair.
            if j > i {
                list.swap(i, j);
                indices.swap(i, j);
            }

            // The scans met or crossed: partitioning is done.
            if j <= i {
                break;
            }
        }

        if first != j {
            // Move the pivot to its final position `j`.
            list.swap(first, j);
            indices.swap(first, j);

            i = j + 1;

            // Recurse into the smaller side and continue the loop on the other one.
            let range: Range<usize>;
            if last - i <= j - first {
                range = i..last;
                last = j;
            } else {
                range = first..j;
                first = i;
            };
            custom_sorted_indices(&mut indices[range.clone()], &mut list[range]);
        } else {
            // The pivot already sits at the start of the window; shrink past it.
            first += 1;
        }
    }
}