extern crate alloc;
use alloc::boxed::Box;
use alloc::vec::Vec;
use crate::error::Error;
use crate::traits::{RawDecoder, RawProgress};
use super::huffman::{DecodeTable, NUM_SYMBOLS, build_decode_table, unpack_lengths};
/// Number of decoded bytes a block covers before the next block must be started.
const BLOCK_OUTPUT_BYTES: usize = 1 << 16;
/// Number of most recent output bytes that are always retained as match history.
const MAX_DISTANCE: usize = 1 << 16;
/// Size in bytes of the packed code-length table that opens every block.
const TABLE_BYTES: usize = 1 << 8;
/// Minimum buffered input needed to start a block: the packed table plus the
/// two 16-bit words that prime the bit buffer.
const MIN_BLOCK_BYTES: usize = TABLE_BYTES + 2 * 2;
/// Stage of the streaming state machine driven by `raw_decode`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Phase {
    /// Waiting for the 4-byte little-endian total output length.
    Header,
    /// Decoding block data into the pending-output buffer.
    Decoding,
    /// Every output byte has been produced; pending bytes still have to be
    /// copied out to the caller.
    DrainTail,
    /// The stream is complete and fully delivered.
    Done,
}
/// Streaming decoder state.
///
/// Input is accumulated in `in_buf`, decoded bytes are staged in `decoded`
/// until the caller's output buffer has room for them, and `out_history`
/// keeps recent output as the source for match copies.
pub struct Decoder {
// Buffered input; `raw_decode` appends to it and drops the consumed prefix
// once `in_pos` exceeds 4096.
in_buf: Vec<u8>,
// Read cursor into `in_buf`.
in_pos: usize,
// Bytes produced by the decoder that have not yet been copied to the caller.
decoded: Vec<u8>,
// Index of the next byte of `decoded` to hand out.
decoded_idx: usize,
// Recent output bytes, used as the source for match copies.
out_history: Vec<u8>,
// Current stage of the state machine.
phase: Phase,
// Once set, `raw_decode` / `raw_finish` return `Error::Corrupt` until `raw_reset`.
poisoned: bool,
// Total number of output bytes, read from the 4-byte little-endian header.
total_output: u64,
// Number of output bytes produced so far.
output_emitted: u64,
// Decode table of the current block; `None` while no block is primed.
table: Option<Box<DecodeTable>>,
// Code lengths unpacked from the current block's table bytes.
lengths: [u8; NUM_SYMBOLS],
// 32-bit bit buffer; symbols are decoded from its top bits.
next_bits: u32,
// Bit-buffer fill counter: set to 16 when a block is primed, decreased as
// bits are consumed, and topped up by 16 whenever it drops below zero.
extra_bit_count: i32,
// Value of `output_emitted` at which the current block ends.
block_end_emitted: u64,
// Set by `raw_finish`; lets `decode_loop` substitute zero for a missing
// refill word after a symbol.
finishing: bool,
}
impl Decoder {
/// Creates a decoder in its initial state: nothing buffered, no block
/// primed, and waiting for the stream header.
pub fn new() -> Self {
    Self {
        // Input staging.
        in_buf: Vec::new(),
        in_pos: 0,
        // Output staging and match history.
        decoded: Vec::new(),
        decoded_idx: 0,
        out_history: Vec::new(),
        // Stream-level state.
        phase: Phase::Header,
        poisoned: false,
        finishing: false,
        total_output: 0,
        output_emitted: 0,
        // Per-block state.
        table: None,
        lengths: [0; NUM_SYMBOLS],
        next_bits: 0,
        extra_bit_count: 0,
        block_end_emitted: 0,
    }
}
fn poison(&mut self, e: Error) -> Error {
self.poisoned = true;
e
}
/// Copies as many pending decoded bytes as fit into `output` and returns how
/// many were copied. When everything pending has been handed out the staging
/// buffer is emptied and its cursor rewound.
fn drain_decoded_into(&mut self, output: &mut [u8]) -> usize {
    let pending = &self.decoded[self.decoded_idx..];
    let count = pending.len().min(output.len());
    output[..count].copy_from_slice(&pending[..count]);
    self.decoded_idx += count;
    if self.decoded_idx == self.decoded.len() {
        self.decoded.clear();
        self.decoded_idx = 0;
    }
    count
}
/// Appends one decoded byte to the pending output and to the match history,
/// and counts it in `output_emitted`.
///
/// The history always retains at least the most recent `MAX_DISTANCE` bytes
/// (or everything emitted so far, if less). It is trimmed only once it has
/// grown to twice that size: removing a single byte from the front on every
/// call would shift the whole 64 KiB window per output byte, whereas this
/// keeps the cost amortised constant.
///
/// NOTE(review): the history may transiently hold up to `2 * MAX_DISTANCE - 1`
/// bytes. This assumes match offsets never exceed `MAX_DISTANCE`; confirm
/// against the symbol alphabet.
fn emit_byte(&mut self, b: u8) {
    self.decoded.push(b);
    self.out_history.push(b);
    if self.out_history.len() >= 2 * MAX_DISTANCE {
        // Drop everything older than the last MAX_DISTANCE bytes in one move.
        let excess = self.out_history.len() - MAX_DISTANCE;
        self.out_history.drain(..excess);
    }
    self.output_emitted += 1;
}
fn try_start_block(&mut self) -> Result<bool, Error> {
let available = self.in_buf.len() - self.in_pos;
if available < MIN_BLOCK_BYTES {
return Ok(false);
}
let mut packed = [0u8; TABLE_BYTES];
packed.copy_from_slice(&self.in_buf[self.in_pos..self.in_pos + TABLE_BYTES]);
self.in_pos += TABLE_BYTES;
self.lengths = unpack_lengths(&packed);
self.table = Some(build_decode_table(&self.lengths)?);
let lo_a = self.in_buf[self.in_pos] as u32;
let hi_a = self.in_buf[self.in_pos + 1] as u32;
let lo_b = self.in_buf[self.in_pos + 2] as u32;
let hi_b = self.in_buf[self.in_pos + 3] as u32;
self.in_pos += 4;
let word_a = (hi_a << 8) | lo_a;
let word_b = (hi_b << 8) | lo_b;
self.next_bits = (word_a << 16) | word_b;
self.extra_bit_count = 16;
self.block_end_emitted = self.output_emitted + BLOCK_OUTPUT_BYTES as u64;
Ok(true)
}
/// Reads the next little-endian 16-bit word from the input buffer, or
/// returns `None` (consuming nothing) when fewer than two bytes remain.
fn read_word(&mut self) -> Option<u16> {
    let end = self.in_pos + 2;
    let pair = self.in_buf.get(self.in_pos..end)?;
    let word = u16::from_le_bytes([pair[0], pair[1]]);
    self.in_pos = end;
    Some(word)
}
/// Reads the next byte from the input buffer, or returns `None` when the
/// buffer is exhausted.
fn read_byte(&mut self) -> Option<u8> {
    let byte = self.in_buf.get(self.in_pos).copied()?;
    self.in_pos += 1;
    Some(byte)
}
/// Decodes symbols into `decoded` / `out_history` until one of the following:
/// all `total_output` bytes have been emitted (the phase becomes
/// `DrainTail`), at least `cap` undrained bytes are staged, or the next block
/// cannot be started because too little input is buffered.
///
/// # Errors
///
/// * `Error::UnexpectedEnd` when a read runs past the buffered input. The
///   state is left partially advanced; `raw_decode` restores its snapshot.
/// * `Error::Corrupt` (poisoning) when a long-length escape word is below 15.
/// * `Error::InvalidDistance` (poisoning) when a match reaches further back
///   than the available history.
/// * Any error from `try_start_block`.
///
/// # Panics
///
/// Panics if a block is still in progress but `table` is `None`.
fn decode_loop(&mut self, cap: usize) -> Result<(), Error> {
loop {
// Whole stream produced: only draining remains.
if self.output_emitted == self.total_output {
self.phase = Phase::DrainTail;
return Ok(());
}
// Enough undrained output is staged; let the caller drain first.
if self.decoded.len() - self.decoded_idx >= cap {
return Ok(());
}
// Current block exhausted: drop its table and prime the next block, or
// return and wait for more input.
if self.output_emitted >= self.block_end_emitted {
self.table = None;
if !self.try_start_block()? {
return Ok(());
}
continue;
}
let table = self
.table
.as_ref()
.expect("decode_loop entered without a primed block");
// The top 15 bits of the bit buffer index the decode table, which yields
// the symbol and the number of bits it occupies.
let idx = (self.next_bits >> (32 - 15)) as usize;
let (symbol, len) = table[idx];
let len = len as u32;
self.next_bits = self.next_bits.wrapping_shl(len);
self.extra_bit_count -= len as i32;
// Counter went negative: load the next 16-bit word just below the bits
// that are still valid.
if self.extra_bit_count < 0 {
let w: u32 = match self.read_word() {
Some(w) => w as u32,
// While finishing, a missing refill word is replaced by zero bits.
// NOTE(review): the second condition looks unreachable, because the
// loop head returns when output_emitted == total_output.
None if self.finishing || self.output_emitted == self.total_output => {
0
}
None => return Err(Error::UnexpectedEnd),
};
self.next_bits |= w << (-self.extra_bit_count);
self.extra_bit_count += 16;
}
// Symbols below 256 are literal bytes.
if (symbol as usize) < 256 {
// NOTE(review): appears unreachable for the same reason as above.
if self.output_emitted >= self.total_output {
continue;
}
self.emit_byte(symbol as u8);
continue;
}
// Match symbol: the low 4 bits select the length class, the remaining
// high bits give the number of extra distance bits.
let match_sym = symbol as usize - 256;
let length_class = (match_sym & 15) as u32;
let dist_hi = (match_sym >> 4) as u32;
let length_short_or_escape = length_class;
let mut match_length: u32 = length_short_or_escape;
// Length class 15 is an escape: one extra byte follows, and the byte
// value 255 escalates to a 16-bit word that must be at least 15.
if length_short_or_escape == 15 {
let byte = self.read_byte().ok_or(Error::UnexpectedEnd)?;
let mut len_extra = byte as u32;
if len_extra == 255 {
let word = self.read_word().ok_or(Error::UnexpectedEnd)?;
let word = word as u32;
if word < 15 {
return Err(self.poison(Error::Corrupt));
}
len_extra = word - 15;
}
match_length = len_extra + 15;
}
// All match lengths are biased by 3.
match_length += 3;
// NOTE(review): appears unreachable (see the loop head).
if symbol == 256 && self.output_emitted == self.total_output {
continue;
}
// Read `dist_hi` extra distance bits from the top of the bit buffer.
// NOTE(review): unlike the symbol refill above, this refill does not
// honour `finishing`; confirm that is intended.
let dist_low: u32 = if dist_hi == 0 {
0
} else {
let v = self.next_bits >> (32 - dist_hi);
self.next_bits = self.next_bits.wrapping_shl(dist_hi);
self.extra_bit_count -= dist_hi as i32;
if self.extra_bit_count < 0 {
let w = self.read_word().ok_or(Error::UnexpectedEnd)? as u32;
self.next_bits |= w << (-self.extra_bit_count);
self.extra_bit_count += 16;
}
v
};
let match_offset = dist_low + (1u32 << dist_hi);
if (match_offset as usize) > self.out_history.len() {
return Err(self.poison(Error::InvalidDistance));
}
// Copy byte by byte so the source may overlap bytes emitted by this same
// match; stop early at the block end or at the end of the stream.
for _ in 0..match_length {
let src = self.out_history.len() - match_offset as usize;
let b = self.out_history[src];
self.emit_byte(b);
if self.output_emitted >= self.block_end_emitted {
break;
}
if self.output_emitted == self.total_output {
break;
}
}
}
}
}
impl Default for Decoder {
fn default() -> Self {
Self::new()
}
}
impl RawDecoder for Decoder {
/// Buffers `input`, decodes as much as possible and copies decoded bytes
/// into `output`.
///
/// Returns how many input bytes were taken (`consumed`), how many output
/// bytes were produced (`written`) and whether the stream is complete
/// (`done`). If `output` has no room after handing out previously staged
/// bytes and the stream is not finished, the call returns with
/// `consumed: 0` and leaves `input` untouched; otherwise all of `input` is
/// buffered internally and reported as consumed.
///
/// Completion is reported as soon as the last byte has been delivered, even
/// when that byte exactly fills `output` or `output` is empty.
///
/// # Errors
///
/// * `Error::Corrupt` if the decoder was poisoned by an earlier error.
/// * Any error from starting a block, and any decoding error other than
///   `Error::UnexpectedEnd`; the decoder is poisoned first.
///
/// `Error::UnexpectedEnd` from `decode_loop` is not an error here: the state
/// is rolled back to its value before the attempt, and the call returns so
/// the caller can supply more input.
fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
    if self.poisoned {
        return Err(Error::Corrupt);
    }

    // Hand out bytes staged by a previous call before doing anything else.
    let mut written = self.drain_decoded_into(output);
    if self.phase == Phase::DrainTail && self.decoded.is_empty() {
        // Nothing is left to deliver, so the stream is complete even if the
        // output buffer is now full (or was empty to begin with).
        self.phase = Phase::Done;
    }
    if written == output.len() && self.phase != Phase::Done {
        return Ok(RawProgress {
            consumed: 0,
            written,
            done: false,
        });
    }

    self.in_buf.extend_from_slice(input);
    let consumed_in = input.len();

    loop {
        match self.phase {
            Phase::Header => {
                if self.in_buf.len() - self.in_pos < 4 {
                    break;
                }
                let mut len_bytes = [0u8; 4];
                len_bytes.copy_from_slice(&self.in_buf[self.in_pos..self.in_pos + 4]);
                self.in_pos += 4;
                self.total_output = u64::from(u32::from_le_bytes(len_bytes));
                // The first block is primed by the Decoding arm (no table is
                // loaded at this point).
                self.phase = if self.total_output == 0 {
                    Phase::Done
                } else {
                    Phase::Decoding
                };
            }
            Phase::Decoding => {
                if self.table.is_none() {
                    match self.try_start_block() {
                        Ok(true) => {}
                        Ok(false) => break,
                        // A bad block table is fatal; do not leave the decoder
                        // usable with a half-consumed block header.
                        Err(e) => return Err(self.poison(e)),
                    }
                }

                // Snapshot everything decode_loop may change so that running
                // out of input in the middle of a symbol can be undone.
                let snap_in_pos = self.in_pos;
                let snap_next_bits = self.next_bits;
                let snap_extra = self.extra_bit_count;
                let snap_decoded_len = self.decoded.len();
                let snap_output_emitted = self.output_emitted;
                let snap_block_end = self.block_end_emitted;
                let snap_phase = self.phase;
                let snap_table = self.table.clone();
                let snap_lengths = self.lengths;
                let snap_out_history = self.out_history.clone();

                let cap = (output.len() - written).saturating_add(BLOCK_OUTPUT_BYTES);
                match self.decode_loop(cap) {
                    Ok(()) => {}
                    Err(Error::UnexpectedEnd) => {
                        self.in_pos = snap_in_pos;
                        self.next_bits = snap_next_bits;
                        self.extra_bit_count = snap_extra;
                        self.decoded.truncate(snap_decoded_len);
                        self.output_emitted = snap_output_emitted;
                        self.block_end_emitted = snap_block_end;
                        self.phase = snap_phase;
                        self.table = snap_table;
                        self.lengths = snap_lengths;
                        self.out_history = snap_out_history;
                        break;
                    }
                    Err(e) => return Err(self.poison(e)),
                }

                if self.phase == Phase::DrainTail {
                    // Everything has been produced; the DrainTail arm delivers
                    // it and decides whether the stream is complete.
                    continue;
                }
                written += self.drain_decoded_into(&mut output[written..]);
                if written == output.len() {
                    break;
                }
                // No input consumed and nothing new staged: wait for more data.
                if snap_in_pos == self.in_pos && snap_decoded_len == self.decoded.len() {
                    break;
                }
            }
            Phase::DrainTail => {
                written += self.drain_decoded_into(&mut output[written..]);
                if self.decoded.is_empty() {
                    self.phase = Phase::Done;
                }
                // Either the stream is done or the output is full.
                break;
            }
            Phase::Done => break,
        }
    }

    // Drop the consumed input prefix once it is large enough to matter.
    if self.in_pos > 4096 {
        self.in_buf.drain(..self.in_pos);
        self.in_pos = 0;
    }

    Ok(RawProgress {
        consumed: consumed_in,
        written,
        done: self.phase == Phase::Done,
    })
}
/// Signals that no further input will arrive and decodes what is still
/// buffered into `output`.
///
/// Returns `done: true` once every output byte has been delivered. Returns
/// `done: false` only when `output` was too small; call again with more
/// room in that case.
///
/// # Errors
///
/// * `Error::Corrupt` if the decoder was poisoned by an earlier error.
/// * `Error::UnexpectedEnd` (poisoning) when the stream is incomplete and no
///   further progress is possible: either the buffered input is exhausted,
///   or decoding stalled although `output` still had room.
/// * Any error returned by `raw_decode`.
fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
    if self.poisoned {
        return Err(Error::Corrupt);
    }
    self.finishing = true;

    let mut written = self.drain_decoded_into(output);
    if !self.decoded.is_empty() {
        // Output is full and bytes are still staged.
        return Ok(RawProgress {
            consumed: 0,
            written,
            done: false,
        });
    }

    let progress = self.raw_decode(&[], &mut output[written..])?;
    written += progress.written;

    // Every byte has been produced and delivered: the stream is complete,
    // even when the last byte exactly filled `output`.
    if self.phase == Phase::DrainTail && self.decoded.is_empty() {
        self.phase = Phase::Done;
    }
    let done = self.phase == Phase::Done;

    if !done && self.decoded.is_empty() {
        let input_exhausted = self.in_buf.len() == self.in_pos;
        // Room was left in `output`, yet nothing more could be decoded. As no
        // more input will come, the stream is truncated; reporting
        // `done: false` here would make the caller retry forever.
        let stalled = written < output.len();
        if input_exhausted || stalled {
            return Err(self.poison(Error::UnexpectedEnd));
        }
    }

    Ok(RawProgress {
        consumed: 0,
        written,
        done,
    })
}
fn raw_reset(&mut self) {
self.in_buf.clear();
self.in_pos = 0;
self.decoded.clear();
self.decoded_idx = 0;
self.out_history.clear();
self.phase = Phase::Header;
self.poisoned = false;
self.total_output = 0;
self.output_emitted = 0;
self.table = None;
self.lengths = [0u8; NUM_SYMBOLS];
self.next_bits = 0;
self.extra_bit_count = 0;
self.block_end_emitted = 0;
self.finishing = false;
}
}