use std::{io, iter::repeat_n, ops::Range};
use bitstream_io::{BitRead, BitReader, LE};
/// Huffman decoding tree stored as a flat table of child slots.
pub(crate) struct HuffTree {
    /// Number of symbols the tree decodes. Leaf entries in `data` are tagged by
    /// adding `2 * sym_count` to the symbol index.
    sym_count: usize,
    /// Child slots, two per internal node: from node `n`, a 0 bit selects slot
    /// `n` and a 1 bit selects slot `n + 1`. Entries below `2 * sym_count` name
    /// the next internal node (0 is the root); larger entries are tagged leaves.
    data: Vec<u16>,
}
impl HuffTree {
    /// Decodes one symbol by walking the tree from node `i` (`0` is the root),
    /// consuming one bit from `reader` per level.
    ///
    /// A slot that was never filled holds 0 and so restarts the walk at the root.
    ///
    /// # Errors
    /// Returns an error if a bit cannot be read.
    pub fn read_code<R: io::Read + io::Seek>(
        &self,
        i: usize,
        reader: &mut BitReader<R, LE>,
    ) -> Result<u16, super::Error> {
        let leaf_base = (self.sym_count as u32) * 2;
        let mut node = i as u32;
        // Values at or above `leaf_base` are tagged leaves, which end the walk.
        while node < leaf_base {
            let bit = reader.read_var::<u8>(1)?;
            node = u32::from(self.data[node as usize + usize::from(bit)]);
        }
        Ok((node - leaf_base) as u16)
    }

    /// Reads a serialized tree for `sym_count` symbols from `reader`, builds its
    /// decoding table and leaves `reader` byte-aligned.
    ///
    /// Header fields, in read order: a 1-bit flag, a 2-bit value width (stored
    /// minus 2), a 3-bit length offset (stored minus 1) and a 2-bit encoding
    /// selector. Encoding 0 stores each length value as a raw `width`-bit
    /// field. Encoding 1 codes the values with a nested tree of `1 << width`
    /// symbols. A value `v` becomes code length `v + offset`, with two
    /// exceptions. The all-ones value `(1 << width) - 1` starts a repeat run.
    /// When the flag is set, `(1 << width) - 2` marks an unused
    /// (zero-length) symbol.
    ///
    /// NOTE(review): the nested tree of encoding 1 is itself read through this
    /// function and may select encoding 1 again; nesting depth is not bounded
    /// here, so crafted input can recurse deeply. Consider a depth limit.
    ///
    /// # Errors
    /// `TreeEncodingUnknown` for encoding selectors 2 and 3; otherwise any read
    /// error or `InvalidTree` reported while decoding the code lengths.
    pub fn read_from<R: io::Read + io::Seek>(
        reader: &mut BitReader<R, LE>,
        sym_count: usize,
    ) -> Result<Self, super::Error> {
        let has_zero_code = reader.read_var::<u8>(1)? != 0;
        let width = reader.read_var::<u32>(2)? + 2;
        let offset = reader.read_var::<u32>(3)? + 1;
        // All-ones `width`-bit value: introduces a run repeating the previous length.
        let repeat_code: u32 = (1 << width) - 1;
        // Value decoded as length 0. Decoded values never reach `u32::MAX`, so
        // that sentinel disables the zero-length code when the flag is clear.
        let zero_code = if has_zero_code {
            repeat_code - 1
        } else {
            u32::MAX
        };
        let symbols = match reader.read_var::<u8>(2)? {
            0b00 => read_plain_symbols(reader, sym_count, zero_code, offset, repeat_code, width),
            0b01 => read_compressed_symbols(
                reader,
                sym_count,
                1 << width,
                zero_code,
                repeat_code,
                offset,
            ),
            _ => Err(super::Error::TreeEncodingUnknown),
        }?;
        reader.byte_align();
        let paths = collect_paths(&symbols);
        let data = compile_tree(&symbols, &paths);
        Ok(Self { sym_count, data })
    }
}
fn read_plain_symbols<R: io::Read + io::Seek>(
reader: &mut BitReader<R, LE>,
sym_count: usize,
count: u32,
o: u32,
m: u32,
j: u32,
) -> Result<Vec<u8>, super::Error> {
let mut syms = Vec::with_capacity(sym_count);
loop {
if syms.len() == sym_count {
return Ok(syms);
}
match reader.read_var::<u32>(j)? {
l if count == l => syms.push(0),
l if l != m => syms.push((l + o) as u8),
_ => {
let count = reader.read_var::<u32>(j)? as usize + 3;
if syms.is_empty() || syms.len() + count > sym_count {
return Err(super::Error::InvalidTree);
}
let symbol = syms[syms.len() - 1];
syms.extend(repeat_n(symbol, count));
}
}
}
}
fn read_compressed_symbols<R: io::Read + io::Seek>(
reader: &mut BitReader<R, LE>,
sym_count: usize,
meta_len: usize,
count: u32,
m: u32,
o: u32,
) -> Result<Vec<u8>, super::Error> {
let mut syms = Vec::with_capacity(sym_count);
let meta = HuffTree::read_from(reader, meta_len)?;
loop {
if syms.len() == sym_count {
return Ok(syms);
}
match meta.read_code(0, reader)? as u32 {
l if count == l => syms.push(0),
l if l != m => syms.push((l + o) as u8),
_ => {
let count = meta.read_code(0, reader)? as usize + 3;
if syms.is_empty() || syms.len() + count > sym_count {
return Err(super::Error::InvalidTree);
}
let symbol = syms[syms.len() - 1];
syms.extend(repeat_n(symbol, count));
}
}
}
}
/// Assigns canonical Huffman codes to the code lengths in `syms` (0 = symbol
/// unused) and returns each symbol's code bit-reversed, so that bit 0 of the
/// result is the first bit consumed while decoding.
///
/// Codes are handed out in the order produced by `sorted_indices`: shorter
/// lengths first, equal lengths in that function's (unstable) order. Unused
/// symbols get path 0.
///
/// Length gaps of 32 or more (representable in the `u8` input) would overflow
/// a `u32` shift. In that case the running code becomes 0, the value a wide
/// shift truncated to 32 bits would have. The running code also wraps on
/// increment instead of panicking. Inputs that shift by less than 32 are
/// unaffected.
fn collect_paths(syms: &[u8]) -> Vec<u32> {
    let (indices, sorted_syms) = sorted_indices(syms);
    let mut paths = vec![0u32; syms.len()];
    let mut code: u32 = 0;
    let mut prev_len: u8 = 0;
    // Sorting puts the unused (zero-length) symbols first; they get no code.
    let used = sorted_syms
        .iter()
        .zip(&indices)
        .skip_while(|&(&len, _)| len == 0);
    for (&len, &symbol) in used {
        // Moving to a longer length appends zero bits to the running code.
        code = code.checked_shl(u32::from(len - prev_len)).unwrap_or(0);
        prev_len = len;
        // Reverse the low `len` bits: the tree builder walks a path from bit 0
        // upward, so the code's most significant bit must come first.
        let mut rest = code;
        let mut path = 0u32;
        for _ in 0..len {
            path = (path << 1) | (rest & 1);
            rest >>= 1;
        }
        paths[usize::from(symbol)] = path;
        code = code.wrapping_add(1);
    }
    paths
}
/// Builds the flat decoding table walked by `HuffTree::read_code`.
///
/// `syms[s]` is the code length of symbol `s` (0 = unused) and `paths[s]` its
/// code with bit 0 taken first. The table has `2 * syms.len()` slots. Slots 0
/// and 1 are the root's children for a 0 and a 1 bit, and every further
/// internal node takes the next free pair. A slot holds 0 (empty), the index
/// of a child node, or `2 * syms.len() + s` for a leaf decoding symbol `s`.
/// A code whose last bit lands on an occupied slot overwrites it.
///
/// Malformed length sets are tolerated rather than indexed out of bounds.
/// A symbol is left out when its code runs through an existing leaf, or when
/// it needs a new internal node once the table is full. Nodes already created
/// for its prefix stay. Decoding such a tree may yield wrong symbols, but it
/// does not panic.
///
/// NOTE(review): entries are `u16`, so this assumes `3 * syms.len()` fits in
/// `u16`; larger alphabets would be truncated silently.
fn compile_tree(syms: &[u8], paths: &[u32]) -> Vec<u16> {
    debug_assert_eq!(syms.len(), paths.len());
    let slot_count = syms.len() * 2;
    let mut tree_data = vec![0u16; slot_count];
    // Slots 0 and 1 belong to the root; internal nodes are allocated in pairs after it.
    let mut next_node: usize = 2;
    'symbols: for (symbol, (&len, &path)) in syms.iter().zip(paths).enumerate() {
        let mut node = 0usize;
        let mut path = path as usize;
        for depth in 0..len {
            let slot = node + (path & 1);
            path >>= 1;
            if depth + 1 == len {
                // Last bit of the code: terminate in a leaf tagged with `symbol`.
                tree_data[slot] = (slot_count + symbol) as u16;
            } else {
                let mut child = usize::from(tree_data[slot]);
                if child == 0 {
                    // First code through this prefix: allocate a node if one is left.
                    if next_node + 2 > slot_count {
                        continue 'symbols;
                    }
                    child = next_node;
                    tree_data[slot] = child as u16;
                    next_node += 2;
                } else if child >= slot_count {
                    // The prefix already ends in another symbol's leaf, so the
                    // lengths do not describe a prefix code.
                    continue 'symbols;
                }
                node = child;
            }
        }
    }
    tree_data
}
/// Returns `list` sorted ascending together with the original position of each
/// sorted element: `indices[k]` is where `sorted[k]` came from in `list`.
///
/// NOTE(review): positions are stored as `u16` and would be truncated beyond
/// 65 536 entries; callers presumably stay well below that.
fn sorted_indices(list: &[u8]) -> (Vec<u16>, Vec<u8>) {
    let mut indices: Vec<u16> = (0..list.len()).map(|i| i as u16).collect();
    let mut sorted = list.to_vec();
    custom_sorted_indices(&mut indices, &mut sorted);
    (indices, sorted)
}

/// Sorts `list` ascending in place with a quicksort, applying every swap to
/// `indices` as well so the two slices stay paired. Both slices must have the
/// same length.
///
/// The sort is not stable: the relative order of equal keys is fixed by this
/// exact partitioning scheme (e.g. `[1, 2, 2]` with positions `[0, 1, 2]`
/// ends with positions `[0, 2, 1]`).
/// NOTE(review): canonical codes are assigned in this order, so the scheme
/// presumably mirrors a reference implementation. Replacing it with
/// `sort`/`sort_unstable` could change the decoded tree; verify before
/// changing it.
fn custom_sorted_indices(indices: &mut [u16], list: &mut [u8]) {
    let mut first = 0;
    let mut last = list.len();
    while first < last {
        // Partition `first..last` around the pivot `list[first]`.
        let mut i = first;
        let mut j = last;
        loop {
            // Advance `i` to the next element >= pivot (or to `last`).
            loop {
                i += 1;
                if i >= last {
                    break;
                }
                if list[first] <= list[i] {
                    break;
                }
            }
            // Retreat `j` to the previous element <= pivot (or to `first`).
            loop {
                j -= 1;
                if j <= first {
                    break;
                }
                if list[first] >= list[j] {
                    break;
                }
            }
            if j > i {
                list.swap(i, j);
                indices.swap(i, j);
            }
            if j <= i {
                break;
            }
        }
        if first != j {
            // Move the pivot into its final slot `j`.
            list.swap(first, j);
            indices.swap(first, j);
            i = j + 1;
            // Recurse into the smaller side and keep looping on the larger one,
            // which keeps the recursion depth logarithmic.
            let range: Range<usize> = if last - i <= j - first {
                let right = i..last;
                last = j;
                right
            } else {
                let left = first..j;
                first = i;
                left
            };
            custom_sorted_indices(&mut indices[range.clone()], &mut list[range]);
        } else {
            // Every later element is greater than the pivot: it is already in place.
            first += 1;
        }
    }
}