use crate::bitreader::BitReader;
use crate::bwt;
use crate::huffman::HuffmanTree;
use crate::mtf::MtfDecoder;
/// Sanity ceiling applied while decoding a block; a pending run longer than
/// this is rejected as corrupt.
// NOTE(review): presumably the largest bzip2 block size (level 9 x 100_000 bytes) — confirm.
const MAX_BLOCKSIZE: usize = 900_000;
/// Error produced when a block cannot be decoded.
///
/// The wrapped string is a fixed, human-readable description of what was
/// truncated or invalid in the block.
#[derive(Debug)]
pub struct BlockError(pub &'static str);
impl std::fmt::Display for BlockError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
write!(f, "bzip2 block error: {}", self.0)
}
}
// Marker impl: lets `BlockError` be used wherever a `std::error::Error` is
// expected. All trait methods keep their default implementations.
impl std::error::Error for BlockError {}
pub fn decode_block(reader: &mut BitReader<'_>, max_blocksize: u32) -> Result<Vec<u8>, BlockError> {
let _crc = reader.read_u32(32)
.ok_or(BlockError("block CRC truncated"))?;
let randomised = reader.read_bit()
.ok_or(BlockError("randomised flag truncated"))?;
if randomised {
return Err(BlockError("randomised blocks not supported"));
}
let orig_ptr = reader.read_u32(24)
.ok_or(BlockError("orig_ptr truncated"))? as usize;
let mut used_bytes: Vec<u8> = Vec::new();
let mut ranges_present = [false; 16];
for range in &mut ranges_present {
*range = reader.read_bit()
.ok_or(BlockError("symbol range truncated"))?;
}
for (range_idx, &present) in ranges_present.iter().enumerate() {
if !present { continue; }
for sub in 0..16u8 {
if reader.read_bit().ok_or(BlockError("symbol bitmap truncated"))? {
used_bytes.push(range_idx as u8 * 16 + sub);
}
}
}
if used_bytes.is_empty() {
return Err(BlockError("no symbols in block"));
}
let n_symbols = used_bytes.len() + 2;
let n_groups = reader.read_u8(3)
.ok_or(BlockError("huffman groups truncated"))?;
if n_groups < 2 || n_groups > 6 {
return Err(BlockError("invalid number of huffman groups"));
}
let n_selectors = reader.read_u16(15)
.ok_or(BlockError("selectors_used truncated"))? as usize;
let mut selectors = Vec::with_capacity(n_selectors);
let mut sel_mtf = MtfDecoder::new();
for _ in 0..n_selectors {
let mut trees = 0u8;
while reader.read_bit().ok_or(BlockError("selector bit truncated"))? {
trees += 1;
if trees >= n_groups {
return Err(BlockError("selector tree index too large"));
}
}
selectors.push(sel_mtf.decode(trees));
}
let mut trees: Vec<HuffmanTree> = Vec::with_capacity(n_groups as usize);
for _ in 0..n_groups {
let mut length = reader.read_u8(5)
.ok_or(BlockError("huffman start length truncated"))? as i32;
let mut lengths = Vec::with_capacity(n_symbols);
for _ in 0..n_symbols {
loop {
if length < 1 || length > 20 {
return Err(BlockError("huffman code length out of range"));
}
if !reader.read_bit().ok_or(BlockError("length adjust bit1 truncated"))? {
break;
}
if reader.read_bit().ok_or(BlockError("length adjust bit2 truncated"))? {
length -= 1;
} else {
length += 1;
}
}
lengths.push(length as u8);
}
trees.push(HuffmanTree::from_lengths(&lengths)
.map_err(|_| BlockError("invalid huffman tree"))?);
}
let mut tt: Vec<u32> = Vec::with_capacity(max_blocksize as usize);
let mut c = [0u32; 256];
let mut byte_symbols = [0u8; 256];
byte_symbols[..used_bytes.len()].copy_from_slice(&used_bytes);
let mut mtf = MtfDecoder::with_symbols(byte_symbols);
let mut sel_idx: usize = 0;
let mut decoded_in_group: usize = 0;
let mut current_tree = trees.get(
*selectors.first().ok_or(BlockError("no selectors"))? as usize
).ok_or(BlockError("selector out of range"))?;
let mut repeat: u32 = 0;
let mut repeat_power: u32 = 0;
let eob_symbol = (n_symbols - 1) as u16;
loop {
if decoded_in_group == 50 {
sel_idx += 1;
let sel = *selectors.get(sel_idx)
.ok_or(BlockError("ran out of selectors"))? as usize;
current_tree = trees.get(sel)
.ok_or(BlockError("selector out of range"))?;
decoded_in_group = 0;
}
let sym = current_tree.decode(reader)
.ok_or(BlockError("huffman bitstream truncated"))?;
decoded_in_group += 1;
if sym < 2 {
if repeat == 0 {
repeat_power = 1;
}
repeat += repeat_power << sym;
repeat_power <<= 1;
if repeat as usize > MAX_BLOCKSIZE {
return Err(BlockError("repeat count too large"));
}
continue;
}
if repeat > 0 {
let b = mtf.first();
if tt.len() + repeat as usize > max_blocksize as usize {
return Err(BlockError("data exceeds block size"));
}
let new_len = tt.len() + repeat as usize;
tt.resize(new_len, u32::from(b));
c[b as usize] += repeat;
repeat = 0;
}
if sym == eob_symbol {
break;
}
let b = mtf.decode((sym - 1) as u8);
if tt.len() >= max_blocksize as usize {
return Err(BlockError("data exceeds block size"));
}
tt.push(u32::from(b));
c[b as usize] += 1;
}
if orig_ptr >= tt.len() {
return Err(BlockError("orig_ptr out of bounds"));
}
let mut t_pos = bwt::inverse_bwt(&mut tt, orig_ptr, c);
let mut output = Vec::with_capacity(tt.len());
let mut last_byte: i16 = -1;
let mut byte_repeats: u8 = 0;
let n = tt.len();
let tt_ptr = tt.as_ptr();
for _ in 0..n {
let entry = unsafe { *tt_ptr.add(t_pos as usize) };
let b = entry as u8;
t_pos = entry >> 8;
if byte_repeats == 3 {
let count = b as usize;
let lb = last_byte as u8;
for _ in 0..count {
output.push(lb);
}
byte_repeats = 0;
last_byte = -1;
continue;
}
if last_byte == i16::from(b) {
byte_repeats += 1;
} else {
byte_repeats = 0;
}
last_byte = i16::from(b);
output.push(b);
}
Ok(output)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decodes the first block of the bundled `hello.bz2` fixture and checks
    /// it against the known plaintext.
    #[test]
    fn decode_known_block() {
        let expected = b"Hello, World!\n";
        let archive = include_bytes!("../test_data/hello.bz2");

        // Stream header: "BZh" signature followed by an ASCII level digit.
        assert_eq!(&archive[..3], b"BZh");
        let level_digit = archive[3] - b'0';
        let block_limit = 100_000 * u32::from(level_digit);

        // Skip the four header bytes (32 bits), then expect the block magic.
        let mut bits = BitReader::from_bit_offset(archive, 32);
        let block_magic = bits.read_u64(48).unwrap();
        assert_eq!(block_magic, crate::BLOCK_MAGIC, "expected block magic");

        let decoded = decode_block(&mut bits, block_limit).unwrap();
        assert_eq!(&decoded, expected);
    }
}