use crate::bitreader::BitReader;
use crate::bwt;
use crate::huffman::HuffmanTree;
use crate::mtf::MtfDecoder;
use std::cell::RefCell;
/// Largest block size a bzip2 stream can declare: level 9, i.e. 9 × 100 000 bytes.
const MAX_BLOCKSIZE: usize = 9 * 100_000;
/// Upper bound on the number of selectors in a block: one per 50-symbol group
/// of a maximal block, plus 2 (the same bound as the reference implementation's
/// `BZ_MAX_SELECTORS`).
const MAX_SELECTORS: usize = MAX_BLOCKSIZE / 50 + 2;
thread_local! {
    /// Per-thread scratch buffer for the `tt` array. It is handed out by
    /// `take_tt_buffer` and handed back by `return_tt_buffer`, so successive
    /// blocks decoded on one thread can reuse a single allocation.
    static TT_BUF: RefCell<Vec<u32>> = const { RefCell::new(Vec::new()) };
}

/// Takes this thread's cached `tt` buffer, empties it, and ensures it has room
/// for at least `capacity` elements.
///
/// The thread-local slot holds an empty `Vec` until `return_tt_buffer` puts a
/// buffer back.
fn take_tt_buffer(capacity: usize) -> Vec<u32> {
    let mut scratch = TT_BUF.take();
    scratch.clear();
    // With `len() == 0`, this only allocates when the cached capacity is too small.
    scratch.reserve(capacity);
    scratch
}

/// Stores `buf` as this thread's cached `tt` buffer, dropping whatever the
/// slot held before.
fn return_tt_buffer(buf: Vec<u32>) {
    TT_BUF.set(buf);
}
/// Error returned when a bzip2 block cannot be decoded.
///
/// The payload is a short static description of the failure, such as a
/// truncated field or an out-of-range header value. The type is a plain
/// `&'static str` wrapper, so it is cheap to copy and can be compared in tests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BlockError(pub &'static str);

impl std::fmt::Display for BlockError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "bzip2 block error: {}", self.0)
    }
}

impl std::error::Error for BlockError {}
pub fn decode_block(reader: &mut BitReader<'_>, max_blocksize: u32) -> Result<Vec<u8>, BlockError> {
let _crc = reader.read_u32(32)
.ok_or(BlockError("block CRC truncated"))?;
let randomised = reader.read_bit()
.ok_or(BlockError("randomised flag truncated"))?;
if randomised {
return Err(BlockError("randomised blocks not supported"));
}
let orig_ptr = reader.read_u32(24)
.ok_or(BlockError("orig_ptr truncated"))? as usize;
let mut used_bytes = [0u8; 256];
let mut n_used: usize = 0;
let mut ranges_present = [false; 16];
for range in &mut ranges_present {
*range = reader.read_bit()
.ok_or(BlockError("symbol range truncated"))?;
}
for (range_idx, &present) in ranges_present.iter().enumerate() {
if !present { continue; }
for sub in 0..16u8 {
if reader.read_bit().ok_or(BlockError("symbol bitmap truncated"))? {
used_bytes[n_used] = range_idx as u8 * 16 + sub;
n_used += 1;
}
}
}
if n_used == 0 {
return Err(BlockError("no symbols in block"));
}
let n_symbols = n_used + 2;
let n_groups = reader.read_u8(3)
.ok_or(BlockError("huffman groups truncated"))?;
if n_groups < 2 || n_groups > 6 {
return Err(BlockError("invalid number of huffman groups"));
}
let n_selectors = reader.read_u16(15)
.ok_or(BlockError("selectors_used truncated"))? as usize;
if n_selectors > MAX_SELECTORS {
return Err(BlockError("too many selectors"));
}
let mut selectors = [0u8; MAX_SELECTORS];
let mut sel_mtf = MtfDecoder::new();
for i in 0..n_selectors {
let mut trees = 0u8;
while reader.read_bit().ok_or(BlockError("selector bit truncated"))? {
trees += 1;
if trees >= n_groups {
return Err(BlockError("selector tree index too large"));
}
}
selectors[i] = sel_mtf.decode(trees);
}
let mut trees = [const { HuffmanTree::empty() }; 6];
let mut n_trees: usize = 0;
for _ in 0..n_groups {
let mut length = reader.read_u8(5)
.ok_or(BlockError("huffman start length truncated"))? as i32;
let mut lengths = [0u8; 258];
for j in 0..n_symbols {
loop {
if length < 1 || length > 20 {
return Err(BlockError("huffman code length out of range"));
}
if !reader.read_bit().ok_or(BlockError("length adjust bit1 truncated"))? {
break;
}
if reader.read_bit().ok_or(BlockError("length adjust bit2 truncated"))? {
length -= 1;
} else {
length += 1;
}
}
lengths[j] = length as u8;
}
trees[n_trees] = HuffmanTree::from_lengths(&lengths[..n_symbols])
.map_err(|_| BlockError("invalid huffman tree"))?;
n_trees += 1;
}
let mut tt: Vec<u32> = take_tt_buffer(max_blocksize as usize);
let mut c = [0u32; 256];
let mut byte_symbols = [0u8; 256];
byte_symbols[..n_used].copy_from_slice(&used_bytes[..n_used]);
let mut mtf = MtfDecoder::with_symbols(byte_symbols);
let mut sel_idx: usize = 0;
let mut current_tree = &trees[
selectors[0] as usize
];
let mut repeat: u32 = 0;
let mut repeat_power: u32 = 0;
let eob_symbol = (n_symbols - 1) as u16;
'outer: loop {
for _ in 0..50 {
let sym = current_tree.decode(reader)
.ok_or(BlockError("huffman bitstream truncated"))?;
if sym < 2 {
if repeat == 0 {
repeat_power = 1;
}
repeat += repeat_power << sym;
repeat_power <<= 1;
if repeat as usize > MAX_BLOCKSIZE {
return Err(BlockError("repeat count too large"));
}
continue;
}
if repeat > 0 {
let b = mtf.first();
if tt.len() + repeat as usize > max_blocksize as usize {
return Err(BlockError("data exceeds block size"));
}
let new_len = tt.len() + repeat as usize;
tt.resize(new_len, u32::from(b));
c[b as usize] += repeat;
repeat = 0;
}
if sym == eob_symbol {
break 'outer;
}
let b = mtf.decode((sym - 1) as u8);
if tt.len() >= max_blocksize as usize {
return Err(BlockError("data exceeds block size"));
}
tt.push(u32::from(b));
c[b as usize] += 1;
}
sel_idx += 1;
if sel_idx >= n_selectors {
return Err(BlockError("ran out of selectors"));
}
let sel = selectors[sel_idx] as usize;
if sel >= n_trees {
return Err(BlockError("selector out of range"));
}
current_tree = &trees[sel];
}
if orig_ptr >= tt.len() {
return Err(BlockError("orig_ptr out of bounds"));
}
let mut t_pos = bwt::inverse_bwt(&mut tt, orig_ptr, c);
let n = tt.len();
let out_cap = n + n / 4;
let mut output = Vec::<u8>::with_capacity(out_cap);
let mut out_len: usize = 0;
let mut last_byte: u8 = 0;
let mut has_last = false;
let mut byte_repeats: u8 = 0;
let tt_ptr = tt.as_ptr();
for _ in 0..n {
let entry = unsafe { *tt_ptr.add(t_pos as usize) };
let b = entry as u8;
t_pos = entry >> 8;
let next_entry = unsafe { *tt_ptr.add(t_pos as usize) };
#[cfg(target_arch = "x86_64")]
unsafe {
std::arch::x86_64::_mm_prefetch(
tt_ptr.add((next_entry >> 8) as usize) as *const i8,
std::arch::x86_64::_MM_HINT_T0,
);
}
if byte_repeats == 3 {
let count = b as usize;
if out_len + count > output.capacity() {
unsafe { output.set_len(out_len); }
output.reserve(count);
}
unsafe {
std::ptr::write_bytes(output.as_mut_ptr().add(out_len), last_byte, count);
}
out_len += count;
byte_repeats = 0;
has_last = false;
continue;
}
if has_last && last_byte == b {
byte_repeats += 1;
} else {
byte_repeats = 0;
}
last_byte = b;
has_last = true;
unsafe { *output.as_mut_ptr().add(out_len) = b; }
out_len += 1;
}
unsafe { output.set_len(out_len); }
return_tt_buffer(tt);
Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes the single block of a tiny fixture stream and checks the text.
    #[test]
    fn decode_known_block() {
        let compressed = include_bytes!("../test_data/hello.bz2");
        let want = b"Hello, World!\n";

        // Stream header: "BZh" followed by an ASCII level digit giving the
        // block size in units of 100 000 bytes.
        let (signature, after_signature) = compressed.split_at(3);
        assert_eq!(signature, b"BZh");
        let level = u32::from(after_signature[0] - b'0');
        let max_blocksize = level * 100_000;

        // The 48-bit block magic starts right after the 4-byte stream header.
        let header_bits = 4 * 8;
        let mut reader = BitReader::from_bit_offset(compressed, header_bits);
        let block_magic = reader.read_u64(48).unwrap();
        assert_eq!(block_magic, crate::BLOCK_MAGIC, "expected block magic");

        let got = decode_block(&mut reader, max_blocksize).unwrap();
        assert_eq!(&got, want);
    }
}