use crate::error::Error;
/// Upper bound on the Huffman alphabet size; it sizes the per-table
/// `limit`/`base`/`perm` and code-length arrays and bounds decoded indices.
const BZ_MAX_ALPHA_SIZE: usize = 258;
/// Number of leading `base`/`limit` entries that `create_decode_tables`
/// clears and accumulates over.
const BZ_MAX_CODE_LEN: usize = 23;
/// Maximum number of Huffman tables held per block.
const BZ_N_GROUPS: usize = 6;
/// Number of symbols decoded with one selector before the next one is used.
const BZ_G_SIZE: usize = 50;
/// Largest selector count accepted from a block header.
const BZ_MAX_SELECTORS: usize = 18002;
/// Length of the byte buffer that backs the move-to-front list.
const MTFA_SIZE: usize = 4096;
/// Number of entries in each segment of the move-to-front list.
const MTFL_SIZE: usize = 16;
/// Capacity, in entries, of the per-block working buffer.
const BLOCK_SIZE: usize = 900_000;
/// Run symbol that adds one times the current power of two to a run length.
const BZ_RUNA: i32 = 0;
/// Run symbol that adds two times the current power of two to a run length.
const BZ_RUNB: i32 = 1;
/// Cursor that hands out bits from a byte slice, most significant bit first.
struct BitReader<'a> {
    /// Bytes being consumed.
    data: &'a [u8],
    /// Index of the next byte to load into `buf`.
    pos: usize,
    /// Shift register; its low `live` bits are the bits not yet handed out.
    buf: u32,
    /// Number of unread bits currently held in `buf`.
    live: i32,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the first bit of `data`.
    fn new(data: &'a [u8]) -> Self {
        BitReader {
            data,
            pos: 0,
            buf: 0,
            live: 0,
        }
    }

    /// Returns the next `n` bits as an integer, first bit read in the most
    /// significant position.
    ///
    /// # Errors
    ///
    /// Fails with "unexpected end of input" when the slice runs out before
    /// `n` bits are available.
    fn get_bits(&mut self, n: i32) -> Result<i32, Error> {
        // Refill one byte at a time until the request can be satisfied.
        while self.live < n {
            let byte = match self.data.get(self.pos) {
                Some(&b) => b,
                None => return Err(fail("unexpected end of input")),
            };
            self.buf = (self.buf << 8) | u32::from(byte);
            self.live += 8;
            self.pos += 1;
        }
        // The requested bits sit just above the bits that stay unread.
        self.live -= n;
        let mask = (1u32 << n) - 1;
        Ok(((self.buf >> self.live) & mask) as i32)
    }

    /// Reads a single bit.
    #[inline]
    fn get_bit(&mut self) -> Result<i32, Error> {
        self.get_bits(1)
    }

    /// Reads eight bits.
    #[inline]
    fn get_u8(&mut self) -> Result<i32, Error> {
        self.get_bits(8)
    }
}
/// Fills the canonical Huffman decoding tables for one set of code lengths.
///
/// * `perm` receives the symbols ordered by code length (ties in symbol order).
/// * `limit[l]` receives the largest code value of length `l`.
/// * `base[l]` receives the amount to subtract from a length-`l` code to get
///   its index into `perm`.
///
/// Only the first `alpha_size` entries of `length` are considered, and only
/// lengths in `min_len..=max_len` are assigned codes. The first
/// `BZ_MAX_CODE_LEN` entries of `limit` and `base` are overwritten.
fn create_decode_tables(
    limit: &mut [i32],
    base: &mut [i32],
    perm: &mut [i32],
    length: &[u8],
    min_len: i32,
    max_len: i32,
    alpha_size: usize,
) {
    // Symbols sorted by code length.
    let mut next_slot = 0usize;
    for code_len in min_len..=max_len {
        for (symbol, &sym_len) in length.iter().take(alpha_size).enumerate() {
            if i32::from(sym_len) == code_len {
                perm[next_slot] = symbol as i32;
                next_slot += 1;
            }
        }
    }

    // Histogram of code lengths, stored one slot up so that the running sum
    // below leaves base[l] = number of symbols with length < l.
    base.iter_mut()
        .take(BZ_MAX_CODE_LEN)
        .for_each(|entry| *entry = 0);
    for &sym_len in length.iter().take(alpha_size) {
        let bucket = usize::from(sym_len) + 1;
        if bucket < BZ_MAX_CODE_LEN {
            base[bucket] += 1;
        }
    }
    let mut running = base[0];
    for entry in base[1..BZ_MAX_CODE_LEN].iter_mut() {
        *entry += running;
        running = *entry;
    }

    // Largest code value per length; the code counter doubles at each length.
    limit
        .iter_mut()
        .take(BZ_MAX_CODE_LEN)
        .for_each(|entry| *entry = 0);
    let mut code: i32 = 0;
    for code_len in min_len..=max_len {
        let k = code_len as usize;
        code += base[k + 1] - base[k];
        limit[k] = code - 1;
        code <<= 1;
    }

    // Turn the cumulative counts into code-to-perm offsets.
    for code_len in (min_len + 1)..=max_len {
        let k = code_len as usize;
        base[k] = ((limit[k - 1] + 1) << 1) - base[k];
    }
}
/// Follows one link of the chain stored in `tt`.
///
/// Loads the entry at index `*t_pos`, returns its low byte and stores the
/// remaining upper bits (the next index) back into `*t_pos`.
#[inline]
fn bz_get_fast(tt: &[u32], t_pos: &mut u32) -> u8 {
    let entry = tt[*t_pos as usize];
    *t_pos = entry >> 8;
    // Truncation keeps exactly the low eight bits.
    entry as u8
}
/// Decoding state for the Huffman-coded symbol stream of one block.
struct HuffmanTables {
    /// For each group of symbols, the index of the table used to decode it.
    selector: Vec<u8>,
    /// Per table: the shortest code length, i.e. how many bits are read
    /// before the first comparison against `limit`.
    min_lens: [i32; BZ_N_GROUPS],
    /// Per table, indexed by code length: the largest code value of that length.
    limit: [[i32; BZ_MAX_ALPHA_SIZE]; BZ_N_GROUPS],
    /// Per table: symbols ordered by code length.
    perm: [[i32; BZ_MAX_ALPHA_SIZE]; BZ_N_GROUPS],
    /// Per table, indexed by code length: the value subtracted from a code to
    /// obtain its index into `perm`.
    base: [[i32; BZ_MAX_ALPHA_SIZE]; BZ_N_GROUPS],
    /// Number of selectors available; decoding fails once `group_no` reaches it.
    n_selectors: i32,
    /// Index of the selector currently in use (initialised to -1 by the block decoder).
    group_no: i32,
    /// Symbols left in the current group; 0 makes the next read advance `group_no`.
    group_pos: i32,
}
/// Decompresses a sequence of blocks into at most `max_output` bytes.
///
/// Each block is introduced by the byte `0x31`; the byte `0x17` ends the
/// stream. Decoding also stops, with the result truncated to `max_output`
/// bytes, as soon as that many bytes have been produced.
///
/// # Errors
///
/// Returns `Error::DecompressionFailed` when `compressed` is empty, when a
/// block header is neither `0x31` nor `0x17`, when the input ends early, or
/// when a block fails to decode.
pub fn decompress_bzip2(compressed: &[u8], max_output: usize) -> Result<Vec<u8>, Error> {
    if compressed.is_empty() {
        return Err(fail("empty input"));
    }

    let mut bits = BitReader::new(compressed);
    // Reserve no more than one block's worth up front.
    let mut decoded = Vec::with_capacity(BLOCK_SIZE.min(max_output));

    loop {
        match bits.get_u8()? {
            0x17 => return Ok(decoded),
            0x31 => {}
            other => {
                return Err(fail(&format!(
                    "invalid block header 0x{:02X} (expected 0x31 or 0x17)",
                    other
                )));
            }
        }

        decompress_block(&mut bits, &mut decoded, max_output)?;

        if decoded.len() >= max_output {
            decoded.truncate(max_output);
            return Ok(decoded);
        }
    }
}
fn decompress_block(
reader: &mut BitReader<'_>,
output: &mut Vec<u8>,
max_output: usize,
) -> Result<(), Error> {
let b0 = reader.get_u8()?;
let b1 = reader.get_u8()?;
let b2 = reader.get_u8()?;
let orig_ptr = (b0 << 16) | (b1 << 8) | b2;
if orig_ptr < 0 || orig_ptr > (10 + BLOCK_SIZE as i32) {
return Err(fail(&format!("origPtr out of range: {}", orig_ptr)));
}
let mut in_use16 = [false; 16];
for item in &mut in_use16 {
*item = reader.get_bit()? == 1;
}
let mut in_use = [false; 256];
for (i, &group_used) in in_use16.iter().enumerate() {
if group_used {
for j in 0..16 {
in_use[i * 16 + j] = reader.get_bit()? == 1;
}
}
}
let mut seq_to_unseq = [0u8; 256];
let mut n_in_use: usize = 0;
for (qi, &used) in in_use.iter().enumerate() {
if used {
seq_to_unseq[n_in_use] = qi as u8;
n_in_use += 1;
}
}
if n_in_use == 0 {
return Err(fail("no symbols in use"));
}
let alpha_size = n_in_use + 2;
let n_groups = reader.get_bits(3)?;
if !(2..=6).contains(&n_groups) {
return Err(fail(&format!("nGroups out of range: {}", n_groups)));
}
let n_groups = n_groups as usize;
let n_selectors = reader.get_bits(15)?;
if n_selectors < 1 {
return Err(fail("nSelectors < 1"));
}
let n_selectors = n_selectors as usize;
if n_selectors > BZ_MAX_SELECTORS {
return Err(fail(&format!("nSelectors too large: {}", n_selectors)));
}
let mut selector_mtf = vec![0u8; n_selectors];
for sel in selector_mtf.iter_mut() {
let mut j = 0;
loop {
let bit = reader.get_bit()?;
if bit == 0 {
break;
}
j += 1;
if j >= n_groups {
return Err(fail("selector MTF value >= nGroups"));
}
}
*sel = j as u8;
}
let mut selector = vec![0u8; n_selectors];
{
let mut pos = [0u8; BZ_N_GROUPS];
for (v, p) in pos.iter_mut().enumerate().take(n_groups) {
*p = v as u8;
}
for i in 0..n_selectors {
let v = selector_mtf[i] as usize;
let tmp = pos[v];
for k in (1..=v).rev() {
pos[k] = pos[k - 1];
}
pos[0] = tmp;
selector[i] = tmp;
}
}
let mut len = [[0u8; BZ_MAX_ALPHA_SIZE]; BZ_N_GROUPS];
for table in len.iter_mut().take(n_groups) {
let mut curr = reader.get_bits(5)?;
for slot in table.iter_mut().take(alpha_size) {
loop {
if !(1..=20).contains(&curr) {
return Err(fail(&format!("code length out of range: {}", curr)));
}
let bit = reader.get_bit()?;
if bit == 0 {
break;
}
let bit2 = reader.get_bit()?;
if bit2 == 0 {
curr += 1;
} else {
curr -= 1;
}
}
*slot = curr as u8;
}
}
let mut huff = HuffmanTables {
selector,
min_lens: [0i32; BZ_N_GROUPS],
limit: [[0i32; BZ_MAX_ALPHA_SIZE]; BZ_N_GROUPS],
perm: [[0i32; BZ_MAX_ALPHA_SIZE]; BZ_N_GROUPS],
base: [[0i32; BZ_MAX_ALPHA_SIZE]; BZ_N_GROUPS],
n_selectors: n_selectors as i32,
group_no: -1,
group_pos: 0,
};
for (t, len_t) in len.iter().enumerate().take(n_groups) {
let mut min_len = 32i32;
let mut max_len = 0i32;
for &l in len_t.iter().take(alpha_size) {
let l = l as i32;
if l > max_len {
max_len = l;
}
if l < min_len {
min_len = l;
}
}
create_decode_tables(
&mut huff.limit[t],
&mut huff.base[t],
&mut huff.perm[t],
len_t,
min_len,
max_len,
alpha_size,
);
huff.min_lens[t] = min_len;
}
let eob = (n_in_use + 1) as i32;
let nblock_max = BLOCK_SIZE;
let mut unzftab = [0i32; 256];
let mut mtfa = [0u8; MTFA_SIZE];
let mut mtfbase = [0usize; 256 / MTFL_SIZE];
{
let mut kk = MTFA_SIZE - 1;
for ii in (0..(256 / MTFL_SIZE)).rev() {
for jj in (0..MTFL_SIZE).rev() {
mtfa[kk] = (ii * MTFL_SIZE + jj) as u8;
kk = kk.wrapping_sub(1);
}
mtfbase[ii] = kk.wrapping_add(1);
}
}
let mut tt = vec![0u32; nblock_max];
let mut nblock: usize = 0;
let mut next_sym = get_mtf_val(reader, &mut huff)?;
loop {
if next_sym == eob {
break;
}
if next_sym == BZ_RUNA || next_sym == BZ_RUNB {
let mut es: i32 = -1;
let mut n_power: i32 = 1;
while next_sym == BZ_RUNA || next_sym == BZ_RUNB {
if next_sym == BZ_RUNA {
es += n_power;
}
n_power <<= 1;
if next_sym == BZ_RUNB {
es += n_power;
}
next_sym = get_mtf_val(reader, &mut huff)?;
}
es += 1;
let uc = seq_to_unseq[mtfa[mtfbase[0]] as usize];
unzftab[uc as usize] += es;
let es = es as usize;
if nblock + es > nblock_max {
return Err(fail("block overflow during RLE expansion"));
}
for _ in 0..es {
tt[nblock] = uc as u32;
nblock += 1;
}
continue;
}
if nblock >= nblock_max {
return Err(fail("block overflow"));
}
let uc = mtf_decode(next_sym, &mut mtfa, &mut mtfbase)?;
let unseq = seq_to_unseq[uc as usize];
unzftab[unseq as usize] += 1;
tt[nblock] = unseq as u32;
nblock += 1;
next_sym = get_mtf_val(reader, &mut huff)?;
}
if orig_ptr < 0 || (orig_ptr as usize) >= nblock {
return Err(fail(&format!(
"origPtr {} out of range for nblock {}",
orig_ptr, nblock
)));
}
let mut cftab = [0i32; 257];
cftab[0] = 0;
for i in 1..=256 {
cftab[i] = unzftab[i - 1] + cftab[i - 1];
}
if cftab[256] != nblock as i32 {
return Err(fail(&format!(
"cftab inconsistency: cftab[256]={} but nblock={}",
cftab[256], nblock
)));
}
for i in 0..nblock {
let uc = (tt[i] & 0xff) as usize;
tt[cftab[uc] as usize] |= (i as u32) << 8;
cftab[uc] += 1;
}
let mut t_pos = tt[orig_ptr as usize] >> 8;
let mut nblock_used: usize = 0;
let mut k0 = bz_get_fast(&tt, &mut t_pos);
nblock_used += 1;
let mut state_out_len: i32 = 0;
let mut state_out_ch: u8 = 0;
while nblock_used <= nblock {
if output.len() >= max_output {
return Ok(());
}
if state_out_len > 0 {
let to_emit = state_out_len as usize;
let remaining = max_output - output.len();
let emit_count = to_emit.min(remaining);
for _ in 0..emit_count {
output.push(state_out_ch);
}
state_out_len -= emit_count as i32;
if state_out_len > 0 || output.len() >= max_output {
return Ok(());
}
continue;
}
state_out_ch = k0;
let mut count = 1;
if nblock_used < nblock {
k0 = bz_get_fast(&tt, &mut t_pos);
nblock_used += 1;
if k0 != state_out_ch {
output.push(state_out_ch);
continue;
}
count = 2;
if nblock_used < nblock {
k0 = bz_get_fast(&tt, &mut t_pos);
nblock_used += 1;
if k0 != state_out_ch {
output.push(state_out_ch);
if output.len() < max_output {
output.push(state_out_ch);
}
continue;
}
count = 3;
if nblock_used < nblock {
k0 = bz_get_fast(&tt, &mut t_pos);
nblock_used += 1;
if k0 != state_out_ch {
for _ in 0..3 {
if output.len() < max_output {
output.push(state_out_ch);
}
}
continue;
}
count = 4;
if nblock_used < nblock {
k0 = bz_get_fast(&tt, &mut t_pos);
nblock_used += 1;
state_out_len = k0 as i32 + count;
if nblock_used < nblock {
k0 = bz_get_fast(&tt, &mut t_pos);
nblock_used += 1;
}
continue;
}
}
}
}
for _ in 0..count {
if output.len() < max_output {
output.push(state_out_ch);
}
}
}
while state_out_len > 0 && output.len() < max_output {
output.push(state_out_ch);
state_out_len -= 1;
}
Ok(())
}
/// Decodes the next Huffman symbol from `reader`.
///
/// Every `BZ_G_SIZE` symbols the next entry of `huff.selector` picks the
/// table to use. The code is read starting at the table's minimum length and
/// extended one bit at a time until it is no larger than the `limit` for its
/// length; the symbol is then looked up through `base` and `perm`.
///
/// # Errors
///
/// Fails when the selectors are exhausted, the input ends, the code grows
/// past 20 bits, or the computed table index is out of range.
fn get_mtf_val(reader: &mut BitReader<'_>, huff: &mut HuffmanTables) -> Result<i32, Error> {
    // Move on to the next selector once the current group is used up.
    if huff.group_pos == 0 {
        huff.group_no += 1;
        if huff.group_no >= huff.n_selectors {
            return Err(fail("ran out of selectors"));
        }
        huff.group_pos = BZ_G_SIZE as i32;
    }
    huff.group_pos -= 1;

    let table = usize::from(huff.selector[huff.group_no as usize]);
    let mut code_len = huff.min_lens[table];
    let mut code = reader.get_bits(code_len)?;

    while code_len <= 20 {
        let len_idx = code_len as usize;
        if code <= huff.limit[table][len_idx] {
            let slot = code - huff.base[table][len_idx];
            return if (0..BZ_MAX_ALPHA_SIZE as i32).contains(&slot) {
                Ok(huff.perm[table][slot as usize])
            } else {
                Err(fail("Huffman decoded index out of range"))
            };
        }
        // Not a complete code yet: append one more bit.
        code_len += 1;
        code = (code << 1) | reader.get_bit()?;
    }
    Err(fail("Huffman code length exceeds 20"))
}
fn mtf_decode(
next_sym: i32,
mtfa: &mut [u8; MTFA_SIZE],
mtfbase: &mut [usize; 256 / MTFL_SIZE],
) -> Result<u8, Error> {
let nn = (next_sym - 1) as usize;
if nn < MTFL_SIZE {
let pp = mtfbase[0];
let uc = mtfa[pp + nn];
let mut pos = nn;
while pos > 0 {
mtfa[pp + pos] = mtfa[pp + pos - 1];
pos -= 1;
}
mtfa[pp] = uc;
Ok(uc)
} else {
let lno_init = nn / MTFL_SIZE;
let off = nn % MTFL_SIZE;
let mut pp = mtfbase[lno_init] + off;
let uc = mtfa[pp];
while pp > mtfbase[lno_init] {
mtfa[pp] = mtfa[pp - 1];
pp -= 1;
}
mtfbase[lno_init] += 1;
let mut lno = lno_init;
while lno > 0 {
mtfbase[lno] -= 1;
mtfa[mtfbase[lno]] = mtfa[mtfbase[lno - 1] + MTFL_SIZE - 1];
lno -= 1;
}
mtfbase[0] -= 1;
mtfa[mtfbase[0]] = uc;
if mtfbase[0] == 0 {
let mut kk = MTFA_SIZE - 1;
for ii in (0..(256 / MTFL_SIZE)).rev() {
for jj in (0..MTFL_SIZE).rev() {
mtfa[kk] = mtfa[mtfbase[ii] + jj];
kk = kk.wrapping_sub(1);
}
mtfbase[ii] = kk.wrapping_add(1);
}
}
Ok(uc)
}
}
/// Wraps `detail` in the `DecompressionFailed` error used throughout this
/// decoder, tagged with the method name "bzip2".
fn fail(detail: &str) -> Error {
    let detail = String::from(detail);
    Error::DecompressionFailed {
        method: "bzip2",
        detail,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty slice is rejected.
    #[test]
    fn empty_input_fails() {
        assert!(decompress_bzip2(&[], 1024).is_err());
    }

    /// A first byte that is neither 0x31 nor 0x17 produces a bzip2
    /// `DecompressionFailed` error naming the bad block header.
    #[test]
    fn invalid_block_header_fails() {
        match decompress_bzip2(&[0xFF], 1024) {
            Err(Error::DecompressionFailed { method, detail }) => {
                assert_eq!(method, "bzip2");
                assert!(detail.contains("invalid block header"));
            }
            _ => panic!("expected DecompressionFailed"),
        }
    }

    /// A stream holding only the end marker decodes to no bytes.
    #[test]
    fn end_of_stream_produces_empty() {
        let output = decompress_bzip2(&[0x17], 1024).unwrap();
        assert!(output.is_empty());
    }

    /// Symbols come out of `perm` ordered by code length.
    #[test]
    fn create_decode_tables_basic() {
        let code_lengths = [2u8, 1, 2];
        let mut limit = [0i32; BZ_MAX_ALPHA_SIZE];
        let mut base = [0i32; BZ_MAX_ALPHA_SIZE];
        let mut perm = [0i32; BZ_MAX_ALPHA_SIZE];
        create_decode_tables(&mut limit, &mut base, &mut perm, &code_lengths, 1, 2, 3);
        assert_eq!(&perm[..3], &[1, 0, 2]);
    }

    /// Bits are handed out most significant first, across byte boundaries.
    #[test]
    fn bit_reader_reads_bits() {
        let data = [0b10110000, 0b01010000];
        let mut r = BitReader::new(&data);
        let reads = [(4, 0b1011), (4, 0b0000), (1, 0), (1, 1), (1, 0), (1, 1)];
        for &(width, expected) in reads.iter() {
            assert_eq!(r.get_bits(width).unwrap(), expected);
        }
    }

    /// Reading past the last byte is an error.
    #[test]
    fn bit_reader_eof() {
        let data = [0xFF];
        let mut r = BitReader::new(&data);
        assert!(r.get_bits(8).is_ok());
        assert!(r.get_bits(1).is_err());
    }
}