extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
use crate::traits::{RawDecoder, RawProgress};
use super::bits::BitReader;
use super::bwt::bwt_inverse;
use super::crc::Crc32;
use super::huffman::{DecodeTable, MAX_CODE_LEN};
use super::mtf::mtf_inverse_reduced;
use super::rle::rle1_inverse;
/// 48-bit marker that `try_block_or_end` expects in front of every compressed block.
const BLOCK_MAGIC: u64 = 0x3141_5926_5359;
/// 48-bit marker that `try_block_or_end` treats as the end of the block
/// sequence; the 32-bit stream CRC is read right after it.
const STREAM_END_MAGIC: u64 = 0x1772_4538_5090;
/// Position of the decoder within the stream; `Decoder::step` dispatches on it.
#[derive(Clone, Copy, PartialEq, Eq)]
enum Phase {
/// Waiting for the 4-byte stream header (`BZh` followed by a level digit).
Header,
/// Waiting for a 48-bit block magic or stream-end magic.
BlockOrEnd,
/// A decoded block is buffered and is still being copied to the caller.
DrainDecoded,
/// Stream-end magic seen; waiting for the 32-bit combined CRC.
StreamCrc,
/// Combined CRC verified; nothing is left to parse.
Done,
}
/// Incremental decoder state: buffered compressed input, the most recently
/// decoded block, and the parsing phase.
pub struct Decoder {
/// Compressed bytes received so far that have not been discarded yet.
in_buf: Vec<u8>,
/// Length of the prefix of `in_buf` that is fully parsed; `raw_decode`
/// drops this prefix once it grows past 1 MiB.
in_committed_bytes: usize,
/// Bit offset into `in_buf` at which parsing resumes.
bit_pos: usize,
/// Output of the most recently decoded block, waiting to be drained.
decoded: Vec<u8>,
/// Index of the first byte of `decoded` not yet handed to the caller.
decoded_idx: usize,
/// Running combination of the per-block CRCs, checked against the stream trailer.
combined_crc: u32,
/// Level digit (1-9) parsed from the stream header; 0 before the header is read.
level: u8,
/// Current parsing phase.
phase: Phase,
/// Set once an unrecoverable error was reported; later calls fail immediately.
poisoned: bool,
}
impl Decoder {
/// Creates a decoder with empty buffers that is waiting for the stream header.
pub fn new() -> Self {
    let in_buf = Vec::new();
    let decoded = Vec::new();
    Self {
        phase: Phase::Header,
        poisoned: false,
        level: 0,
        combined_crc: 0,
        in_buf,
        in_committed_bytes: 0,
        bit_pos: 0,
        decoded,
        decoded_idx: 0,
    }
}
/// Flags the decoder as poisoned and returns the error unchanged, so callers
/// can write `Err(self.poison(e))`.
fn poison(&mut self, e: Error) -> Error {
    let failure = e;
    self.poisoned = true;
    failure
}
/// Runs the handler for the current phase once.
///
/// Returns `Ok(true)` when the handler advanced the decoder and `Ok(false)`
/// when nothing could be done: either the handler needs more input, or the
/// phase (`DrainDecoded`, `Done`) has no parsing work attached to it.
fn step(&mut self) -> Result<bool, Error> {
    match self.phase {
        // These phases are driven by `drain` / the caller, not by parsing.
        Phase::DrainDecoded | Phase::Done => Ok(false),
        Phase::Header => self.try_header(),
        Phase::BlockOrEnd => self.try_block_or_end(),
        Phase::StreamCrc => self.try_stream_crc(),
    }
}
/// Parses the 4-byte stream header (`BZh` plus a level digit `1`-`9`).
///
/// Returns `Ok(false)` while fewer than four uncommitted bytes are buffered,
/// and `Ok(true)` after the header has been accepted and the phase moved to
/// `BlockOrEnd`.
///
/// # Errors
///
/// Poisons the decoder and returns `Error::BadHeader` if the signature or
/// the level digit is invalid.
fn try_header(&mut self) -> Result<bool, Error> {
    let start = self.in_committed_bytes;
    // Copy the four header bytes out so the buffer borrow ends before `poison`.
    let Some(&[m0, m1, m2, level_digit]) = self.in_buf.get(start..start + 4) else {
        return Ok(false);
    };
    let signature_ok = [m0, m1, m2] == *b"BZh";
    let level_ok = (b'1'..=b'9').contains(&level_digit);
    if !signature_ok || !level_ok {
        return Err(self.poison(Error::BadHeader));
    }
    self.level = level_digit - b'0';
    let body_start = start + 4;
    self.in_committed_bytes = body_start;
    self.bit_pos = body_start * 8;
    self.phase = Phase::BlockOrEnd;
    Ok(true)
}
fn try_block_or_end(&mut self) -> Result<bool, Error> {
let available_bits = self.in_buf.len() * 8 - self.bit_pos;
if available_bits < 48 {
return Ok(false);
}
let snapshot = self.bit_pos;
let mut br = BitReader::new_at(&self.in_buf, self.bit_pos);
let magic = br.read_bits_48()?;
if magic == BLOCK_MAGIC {
self.bit_pos = br.position();
match self.decode_block_payload() {
Ok(()) => {
self.phase = Phase::DrainDecoded;
Ok(true)
}
Err(Error::UnexpectedEnd) => {
self.bit_pos = snapshot;
Ok(false)
}
Err(e) => Err(self.poison(e)),
}
} else if magic == STREAM_END_MAGIC {
self.bit_pos = br.position();
self.phase = Phase::StreamCrc;
Ok(true)
} else {
Err(self.poison(Error::BadHeader))
}
}
/// Reads the 32-bit stream CRC, checks it against the combined block CRC and
/// finishes the stream.
///
/// Returns `Ok(false)` while fewer than 32 bits are buffered. On success the
/// read position is rounded up to the next byte boundary, that position is
/// committed, and the phase becomes `Done`.
///
/// # Errors
///
/// Poisons the decoder and returns `Error::ChecksumMismatch` if the stored
/// CRC differs from the computed one.
fn try_stream_crc(&mut self) -> Result<bool, Error> {
    if self.in_buf.len() * 8 - self.bit_pos < 32 {
        return Ok(false);
    }
    let (stored, end_bit) = {
        let mut reader = BitReader::new_at(&self.in_buf, self.bit_pos);
        let value = reader.read_bits(32)?;
        let end: usize = reader.position();
        (value, end)
    };
    if stored != self.combined_crc {
        return Err(self.poison(Error::ChecksumMismatch));
    }
    // Skip the padding bits up to the next whole byte.
    let aligned = (end_bit + 7) & !7;
    self.bit_pos = aligned;
    self.in_committed_bytes = aligned / 8;
    self.phase = Phase::Done;
    Ok(true)
}
fn decode_block_payload(&mut self) -> Result<(), Error> {
let mut br = BitReader::new_at(&self.in_buf, self.bit_pos);
let stored_crc = br.read_bits(32)?;
let randomized = br.read_bit()?;
if randomized != 0 {
return Err(Error::Unsupported);
}
let origin = br.read_bits(24)?;
let stripe_top = br.read_bits(16)?;
let mut alphabet: Vec<u8> = Vec::with_capacity(64);
for stripe in 0..16 {
let stripe_used = stripe_top & (1 << (15 - stripe)) != 0;
if !stripe_used {
continue;
}
let mask = br.read_bits(16)?;
for byte in 0..16 {
if mask & (1 << (15 - byte)) != 0 {
alphabet.push(((stripe << 4) | byte) as u8);
}
}
}
if alphabet.is_empty() {
return Err(Error::Corrupt);
}
let num_used = alphabet.len();
let alpha_size = num_used + 2;
let num_tables = br.read_bits(3)? as usize;
if !(2..=6).contains(&num_tables) {
return Err(Error::Corrupt);
}
let num_selectors = br.read_bits(15)? as usize;
if num_selectors == 0 || num_selectors > 18002 {
return Err(Error::Corrupt);
}
let mut mtf_list: Vec<u8> = (0..num_tables as u8).collect();
let mut selectors: Vec<u8> = Vec::with_capacity(num_selectors);
for _ in 0..num_selectors {
let mut pos = 0;
loop {
if pos >= num_tables {
return Err(Error::Corrupt);
}
let bit = br.read_bit()?;
if bit == 0 {
break;
}
pos += 1;
}
let v = mtf_list.remove(pos);
selectors.push(v);
mtf_list.insert(0, v);
}
let mut tables: Vec<DecodeTable> = Vec::with_capacity(num_tables);
for _ in 0..num_tables {
let mut cur = br.read_bits(5)? as i32;
if !(1..=(MAX_CODE_LEN as i32)).contains(&cur) {
return Err(Error::Corrupt);
}
let mut lens = vec![0u8; alpha_size];
for symbol_len in lens.iter_mut().take(alpha_size) {
loop {
let b = br.read_bit()?;
if b == 0 {
break;
}
let dir = br.read_bit()?;
if dir == 0 {
cur += 1;
} else {
cur -= 1;
}
if !(1..=(MAX_CODE_LEN as i32)).contains(&cur) {
return Err(Error::Corrupt);
}
}
*symbol_len = cur as u8;
}
tables.push(DecodeTable::from_lengths(&lens)?);
}
let eob = (alpha_size - 1) as u16;
let mut mtf_indices: Vec<u8> = Vec::new();
let mut group_idx = 0usize;
let mut symbols_in_group = 0usize;
let mut zero_run: u32 = 0;
let mut zero_weight: u32 = 1;
loop {
if symbols_in_group == 0 {
if group_idx >= num_selectors {
return Err(Error::Corrupt);
}
symbols_in_group = 50;
}
let sel = selectors[group_idx] as usize;
if sel >= num_tables {
return Err(Error::Corrupt);
}
let tbl = &tables[sel];
let s = tbl.decode_symbol(&mut br)?;
symbols_in_group -= 1;
if symbols_in_group == 0 {
group_idx += 1;
}
if s == eob {
break;
}
if s <= 1 {
let contrib = if s == 0 { 1 } else { 2 };
zero_run = zero_run.saturating_add(contrib * zero_weight);
zero_weight = zero_weight.saturating_mul(2);
if zero_run as usize > 900_000 * 9 + 1024 {
return Err(Error::Corrupt);
}
} else {
if zero_run > 0 {
mtf_indices.extend(core::iter::repeat_n(0u8, zero_run as usize));
zero_run = 0;
zero_weight = 1;
}
let idx = (s - 1) as usize;
if idx >= num_used {
return Err(Error::Corrupt);
}
mtf_indices.push(idx as u8);
}
}
if zero_run > 0 {
mtf_indices.extend(core::iter::repeat_n(0u8, zero_run as usize));
}
let l_column = mtf_inverse_reduced(&mtf_indices, &alphabet);
if origin as usize >= l_column.len() {
return Err(Error::Corrupt);
}
let bwt = bwt_inverse(&l_column, origin);
let raw = rle1_inverse(&bwt);
let mut crc = Crc32::new();
crc.update(&raw);
if crc.value() != stored_crc {
return Err(Error::ChecksumMismatch);
}
self.combined_crc = self.combined_crc.rotate_left(1) ^ stored_crc;
self.decoded.extend_from_slice(&raw);
self.bit_pos = br.position();
self.in_committed_bytes = self.bit_pos / 8;
Ok(())
}
/// Copies as many pending decoded bytes as fit into `output[*written..]`,
/// advancing both `decoded_idx` and `*written`.
///
/// Once every decoded byte has been handed out the buffer is cleared and,
/// if the decoder was in `DrainDecoded`, it moves back to `BlockOrEnd`.
fn drain(&mut self, output: &mut [u8], written: &mut usize) {
    let pending = &self.decoded[self.decoded_idx..];
    let dest = &mut output[*written..];
    let n = pending.len().min(dest.len());
    if n > 0 {
        dest[..n].copy_from_slice(&pending[..n]);
        self.decoded_idx += n;
        *written += n;
    }
    if self.decoded_idx != self.decoded.len() {
        return;
    }
    self.decoded.clear();
    self.decoded_idx = 0;
    if self.phase == Phase::DrainDecoded {
        self.phase = Phase::BlockOrEnd;
    }
}
}
impl Default for Decoder {
fn default() -> Self {
Self::new()
}
}
impl RawDecoder for Decoder {
/// Buffers `input` and decodes as much of the stream as `output` has room for.
///
/// Unless the stream had already finished before this call, every byte of
/// `input` is copied into the internal buffer, so `consumed` equals
/// `input.len()`. `written` is the number of bytes stored at the front of
/// `output`. `done` is `true` only once the stream CRC has been verified and
/// every decoded byte has been handed out.
///
/// # Errors
///
/// Returns `Error::Corrupt` if an earlier call poisoned the decoder, and
/// otherwise forwards any error produced by `step`.
fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
if self.poisoned {
return Err(Error::Corrupt);
}
let mut consumed = 0usize;
let mut written = 0usize;
// Hand out bytes left over from a block decoded during an earlier call.
self.drain(output, &mut written);
loop {
// Finished: trailer verified and nothing left to drain.
if matches!(self.phase, Phase::Done) && self.decoded_idx == self.decoded.len() {
return Ok(RawProgress {
consumed,
written,
done: true,
});
}
// Take ownership of the whole input in one go (only on the first iteration).
if consumed < input.len() {
self.in_buf.extend_from_slice(&input[consumed..]);
consumed = input.len();
}
let progressed = self.step()?;
self.drain(output, &mut written);
// Drop the already-parsed prefix once it exceeds 1 MiB and rebase the offsets.
if self.in_committed_bytes > 1 << 20 {
let off = self.in_committed_bytes;
self.in_buf.drain(..off);
self.bit_pos -= off * 8;
self.in_committed_bytes = 0;
}
// Let the check at the top of the loop produce the final result.
if matches!(self.phase, Phase::Done) {
continue;
}
// Output is full and decoded bytes are still pending: caller must supply more space.
if written == output.len() && self.decoded_idx < self.decoded.len() {
return Ok(RawProgress {
consumed,
written,
done: false,
});
}
// Stalled with all input buffered: caller must supply more input.
if !progressed {
if consumed >= input.len() {
return Ok(RawProgress {
consumed,
written,
done: false,
});
}
}
// Output is exactly full; stop even though more could be decoded.
if written == output.len() {
return Ok(RawProgress {
consumed,
written,
done: false,
});
}
}
}
/// Signals that no further input will arrive and flushes remaining output.
///
/// Returns `done: true` once the stream trailer has been verified and all
/// decoded bytes have been written. Returns `done: false` when `output`
/// filled up first; call again with more space to continue.
///
/// NOTE(review): returning `done: false` from here assumes the `RawDecoder`
/// contract allows repeated `raw_finish` calls — confirm against the trait
/// documentation. Previously a too-small `output` was reported as
/// `Error::UnexpectedEnd` and poisoned an otherwise intact stream.
///
/// # Errors
///
/// * `Error::Corrupt` if the decoder is already poisoned.
/// * `Error::UnexpectedEnd` (poisoning the decoder) if the buffered input
///   ends before the stream is complete.
/// * Any error forwarded from `raw_decode`.
fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
    if self.poisoned {
        return Err(Error::Corrupt);
    }
    let empty: [u8; 0] = [];
    let p = self.raw_decode(&empty, output)?;
    let fully_drained = self.decoded_idx == self.decoded.len();
    if matches!(self.phase, Phase::Done) && fully_drained {
        return Ok(RawProgress {
            consumed: 0,
            written: p.written,
            done: true,
        });
    }
    // Tell "ran out of output space" apart from "ran out of input". Decoded
    // bytes still pending always means the output was the limit. A full,
    // non-empty output means this call made progress and may simply have been
    // cut short. An empty output with nothing pending proves neither, so it
    // falls through to the error as before.
    let output_filled = !output.is_empty() && p.written == output.len();
    if !fully_drained || output_filled {
        return Ok(RawProgress {
            consumed: 0,
            written: p.written,
            done: false,
        });
    }
    Err(self.poison(Error::UnexpectedEnd))
}
/// Returns the decoder to its freshly constructed state so it can decode a
/// new stream. Both buffers are emptied with `clear`.
fn raw_reset(&mut self) {
    // Parsing state.
    self.phase = Phase::Header;
    self.poisoned = false;
    self.level = 0;
    self.combined_crc = 0;
    // Input side.
    self.in_buf.clear();
    self.in_committed_bytes = 0;
    self.bit_pos = 0;
    // Output side.
    self.decoded.clear();
    self.decoded_idx = 0;
}
}