use alloc::boxed::Box;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
use crate::traits::{RawDecoder, RawProgress};
use super::bits::BitReader;
use super::filters::apply_e8_filter;
use super::huffman::Huffman;
use super::tables::{
DICT_DEFAULT_SIZE, HUFF_TABLE_SIZE, LENGTH_BASE, LENGTH_EXTRA_BITS, LENGTH_SIZE,
LOW_OFFSET_SIZE, MAIN_SIZE, OFFSET_BASE, OFFSET_EXTRA_BITS, OFFSET_SIZE, PRECODE_SIZE,
SHORT_BASE, SHORT_EXTRA_BITS,
};
/// Whole-stream decoder: input is buffered until `raw_finish`, decoded in a single pass,
/// and the decoded bytes are then handed out across as many calls as the caller needs.
pub struct Decoder {
    /// Current lifecycle stage (buffering input, draining output, or finished).
    state: State,
    /// Set after a decode failure; every later call returns `Error::Corrupt` until
    /// `raw_reset` clears it.
    poisoned: bool,
    /// Number of decoded bytes to produce; `u64::MAX` when built with `Decoder::new`.
    unpack_size: u64,
    /// Whether `apply_e8_filter` runs over the decoded output.
    e8_enabled: bool,
    /// Flag forwarded to `apply_e8_filter` when the filter is enabled.
    e8_translate_e9: bool,
    /// Decoded bytes waiting to be copied out to the caller.
    out_buf: Vec<u8>,
    /// How many bytes of `out_buf` have already been copied out.
    out_drained: usize,
}
/// Lifecycle of a [`Decoder`].
enum State {
    /// Still accepting compressed bytes; `input` holds everything received so far.
    Buffering { input: Vec<u8> },
    /// Decoding has run; the result is being copied out of `Decoder::out_buf`.
    Draining,
    /// All decoded output has been delivered.
    Done,
}
impl Decoder {
    /// Creates a decoder with no output limit (`unpack_size` is `u64::MAX`).
    pub fn new() -> Self {
        Self::with_unpack_size(u64::MAX)
    }

    /// Creates a decoder that stops producing output once `n` bytes have been decoded.
    pub fn with_unpack_size(n: u64) -> Self {
        Self {
            state: State::Buffering { input: Vec::new() },
            out_buf: Vec::new(),
            out_drained: 0,
            unpack_size: n,
            e8_enabled: false,
            e8_translate_e9: false,
            poisoned: false,
        }
    }

    /// Enables the `apply_e8_filter` post-pass on the decoded output; `translate_e9` is
    /// forwarded to that filter.
    pub fn with_e8_filter(self, translate_e9: bool) -> Self {
        Self {
            e8_enabled: true,
            e8_translate_e9: translate_e9,
            ..self
        }
    }

    /// Marks the decoder unusable and returns `e` as the error.
    fn poison<T>(&mut self, e: Error) -> Result<T, Error> {
        self.poisoned = true;
        Err(e)
    }

    /// Copies as many pending decoded bytes as fit into `output` and returns the count.
    ///
    /// Once everything in `out_buf` has been delivered, the buffer is released and the
    /// state becomes `State::Done`.
    fn drain_into(&mut self, output: &mut [u8]) -> usize {
        let pending = &self.out_buf[self.out_drained..];
        let count = pending.len().min(output.len());
        output[..count].copy_from_slice(&pending[..count]);
        self.out_drained += count;
        if self.out_drained == self.out_buf.len() {
            self.out_buf.clear();
            self.out_drained = 0;
            self.state = State::Done;
        }
        count
    }
}
impl Default for Decoder {
fn default() -> Self {
Self::new()
}
}
impl RawDecoder for Decoder {
    /// While buffering, takes all of `input`; while draining, copies pending output into
    /// `output`; once done, does nothing. This call never reports `done`.
    fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.poisoned {
            return Err(Error::Corrupt);
        }
        let (consumed, written) = match &mut self.state {
            State::Buffering { input: pending } => {
                pending.extend_from_slice(input);
                (input.len(), 0)
            }
            State::Draining => (0, self.drain_into(output)),
            State::Done => (0, 0),
        };
        Ok(RawProgress {
            consumed,
            written,
            done: false,
        })
    }

    /// On the first call, decodes everything buffered so far (a failure poisons the
    /// decoder). Every call then copies as much decoded output as fits into `output`,
    /// reporting `done` once all of it has been delivered.
    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.poisoned {
            return Err(Error::Corrupt);
        }
        if let State::Buffering { input } = &mut self.state {
            let compressed = core::mem::take(input);
            let decoded = match run_decode(
                compressed,
                self.unpack_size,
                self.e8_enabled,
                self.e8_translate_e9,
            ) {
                Ok(bytes) => bytes,
                Err(e) => return self.poison(e),
            };
            self.out_buf = decoded;
            self.out_drained = 0;
            self.state = State::Draining;
        }
        let written = self.drain_into(output);
        Ok(RawProgress {
            consumed: 0,
            written,
            done: matches!(self.state, State::Done),
        })
    }

    /// Discards all buffered input and output and clears the poisoned flag; the size
    /// limit and filter configuration are kept.
    fn raw_reset(&mut self) {
        self.poisoned = false;
        self.out_drained = 0;
        self.out_buf.clear();
        self.state = State::Buffering { input: Vec::new() };
    }
}
/// Decodes a complete compressed buffer and returns the decoded bytes.
///
/// An `unpack_size` of zero short-circuits to empty output without reading `input`.
/// Otherwise, the first block header is parsed and symbols are expanded into a fresh
/// `DICT_DEFAULT_SIZE` window. `apply_e8_filter` then runs over the result if enabled.
fn run_decode(
    input: Vec<u8>,
    unpack_size: u64,
    e8_enabled: bool,
    e8_translate_e9: bool,
) -> Result<Vec<u8>, Error> {
    if unpack_size == 0 {
        return Ok(Vec::new());
    }

    let mut bits = BitReader::new();
    bits.feed_slice(&input);

    let mut ctx = Box::new(RunCtx {
        bits,
        lengths: vec![0u8; HUFF_TABLE_SIZE],
        main: None,
        offset: None,
        low_offset: None,
        length: None,
        old_offsets: [1u32, 1, 1, 1],
        last_offset: 0,
        last_length: 0,
        last_low_offset: 0,
        num_low_offset_repeats: 0,
        out: Vec::new(),
        window: vec![0u8; DICT_DEFAULT_SIZE],
        window_pos: 0,
        unpack_size,
    });

    parse_block_header(&mut ctx)?;
    expand(&mut ctx)?;

    let mut decoded = core::mem::take(&mut ctx.out);
    if e8_enabled {
        apply_e8_filter(&mut decoded, 0, e8_translate_e9);
    }
    Ok(decoded)
}
/// Mutable state for one decoding run.
struct RunCtx {
    /// Bit-level view of the compressed input.
    bits: BitReader,
    /// Code lengths for the main, offset, low-offset and length tables, stored back to
    /// back. They are kept across block headers so "keep table" headers can apply deltas.
    lengths: Vec<u8>,
    /// Main (literal / control / match-start) Huffman table.
    main: Option<Box<Huffman>>,
    /// Offset-slot Huffman table.
    offset: Option<Box<Huffman>>,
    /// Huffman table for the low bits of large offsets.
    low_offset: Option<Box<Huffman>>,
    /// Length-slot Huffman table used by the recent-offset match symbols.
    length: Option<Box<Huffman>>,
    /// Most recently used match offsets, newest first.
    old_offsets: [u32; 4],
    /// Offset of the previous match, replayed by symbol 258.
    last_offset: u32,
    /// Length of the previous match; 0 until the first match is decoded.
    last_length: u32,
    /// Last explicitly decoded low-offset value, reused by low-offset repeats.
    last_low_offset: u32,
    /// Remaining matches that reuse `last_low_offset` without decoding it.
    num_low_offset_repeats: u32,
    /// Decoded output produced so far.
    out: Vec<u8>,
    /// Circular history buffer that matches copy from.
    window: Vec<u8>,
    /// Next write position in `window`.
    window_pos: usize,
    /// Output size at which decoding stops.
    unpack_size: u64,
}
impl RunCtx {
    /// Appends one byte to the output and records it in the circular window.
    fn emit_literal(&mut self, b: u8) {
        self.out.push(b);
        let pos = self.window_pos;
        self.window[pos] = b;
        self.window_pos = (pos + 1) % self.window.len();
    }

    /// Copies `length` bytes starting `offset` bytes back in the window, one byte at a
    /// time so overlapping copies repeat. Copying stops early once `unpack_size` bytes
    /// have been produced.
    ///
    /// # Errors
    /// `Error::InvalidDistance` when `offset` is 0 or larger than the window.
    fn emit_match(&mut self, offset: u32, length: u32) -> Result<(), Error> {
        let size = self.window.len();
        let distance = offset as usize;
        if offset == 0 || distance > size {
            return Err(Error::InvalidDistance);
        }
        for _ in 0..length {
            let byte = self.window[(self.window_pos + size - distance) % size];
            self.out.push(byte);
            self.window[self.window_pos] = byte;
            self.window_pos = (self.window_pos + 1) % size;
            if self.done() {
                break;
            }
        }
        Ok(())
    }

    /// True once the output has reached `unpack_size` bytes.
    fn done(&self) -> bool {
        self.unpack_size <= self.out.len() as u64
    }
}
fn parse_block_header(ctx: &mut RunCtx) -> Result<(), Error> {
ctx.bits.byte_align();
let is_ppmd = ctx.bits.read_bits(1)?;
if is_ppmd != 0 {
return Err(Error::Unsupported);
}
let keep_table = ctx.bits.read_bits(1)? != 0;
if !keep_table {
for slot in ctx.lengths.iter_mut() {
*slot = 0;
}
}
let mut precode = [0u8; PRECODE_SIZE];
let mut i = 0usize;
while i < PRECODE_SIZE {
let v = ctx.bits.read_bits(4)? as u8;
if v == 0x0F {
let runcount = ctx.bits.read_bits(4)? as u8;
if runcount == 0 {
precode[i] = 0x0F;
i += 1;
} else {
let n = (runcount as usize) + 2;
let mut k = 0;
while k < n && i < PRECODE_SIZE {
precode[i] = 0;
i += 1;
k += 1;
}
}
} else {
precode[i] = v;
i += 1;
}
}
let pre_tree = Huffman::from_lengths(&precode)?;
let mut idx = 0usize;
while idx < HUFF_TABLE_SIZE {
let sym = pre_tree.decode(&mut ctx.bits)?;
if sym < 16 {
ctx.lengths[idx] = ((ctx.lengths[idx] as u16 + sym) & 0xF) as u8;
idx += 1;
} else if sym == 16 {
if idx == 0 {
return Err(Error::Corrupt);
}
let n = (ctx.bits.read_bits(3)? as usize) + 3;
let prev = ctx.lengths[idx - 1];
for _ in 0..n {
if idx >= HUFF_TABLE_SIZE {
break;
}
ctx.lengths[idx] = prev;
idx += 1;
}
} else if sym == 17 {
if idx == 0 {
return Err(Error::Corrupt);
}
let n = (ctx.bits.read_bits(7)? as usize) + 11;
let prev = ctx.lengths[idx - 1];
for _ in 0..n {
if idx >= HUFF_TABLE_SIZE {
break;
}
ctx.lengths[idx] = prev;
idx += 1;
}
} else if sym == 18 {
let n = (ctx.bits.read_bits(3)? as usize) + 3;
for _ in 0..n {
if idx >= HUFF_TABLE_SIZE {
break;
}
ctx.lengths[idx] = 0;
idx += 1;
}
} else if sym == 19 {
let n = (ctx.bits.read_bits(7)? as usize) + 11;
for _ in 0..n {
if idx >= HUFF_TABLE_SIZE {
break;
}
ctx.lengths[idx] = 0;
idx += 1;
}
} else {
return Err(Error::Corrupt);
}
}
ctx.main = Some(Box::new(Huffman::from_lengths(&ctx.lengths[..MAIN_SIZE])?));
ctx.offset = Some(Box::new(Huffman::from_lengths(
&ctx.lengths[MAIN_SIZE..MAIN_SIZE + OFFSET_SIZE],
)?));
ctx.low_offset = Some(Box::new(Huffman::from_lengths(
&ctx.lengths[MAIN_SIZE + OFFSET_SIZE..MAIN_SIZE + OFFSET_SIZE + LOW_OFFSET_SIZE],
)?));
ctx.length = Some(Box::new(Huffman::from_lengths(
&ctx.lengths[MAIN_SIZE + OFFSET_SIZE + LOW_OFFSET_SIZE
..MAIN_SIZE + OFFSET_SIZE + LOW_OFFSET_SIZE + LENGTH_SIZE],
)?));
Ok(())
}
fn expand(ctx: &mut RunCtx) -> Result<(), Error> {
loop {
if ctx.done() {
return Ok(());
}
let main_tree = ctx.main.as_ref().ok_or(Error::InvalidHuffmanTree)?;
let sym = match main_tree.decode(&mut ctx.bits) {
Ok(s) => s,
Err(Error::UnexpectedEnd) => {
return Ok(());
}
Err(e) => return Err(e),
};
if sym < 256 {
ctx.emit_literal(sym as u8);
continue;
}
match sym {
256 => {
let new_table = ctx.bits.read_bits(1)? != 0;
if new_table {
parse_block_header(ctx)?;
} else {
return Ok(());
}
}
257 => {
return Err(Error::Unsupported);
}
258 => {
if ctx.last_length == 0 {
return Err(Error::Corrupt);
}
let (o, l) = (ctx.last_offset, ctx.last_length);
ctx.emit_match(o, l)?;
}
259..=262 => {
let idx = (sym - 259) as usize;
let offs = ctx.old_offsets[idx];
let length_tree = ctx.length.as_ref().ok_or(Error::InvalidHuffmanTree)?;
let lensym = length_tree.decode(&mut ctx.bits)? as usize;
if lensym >= LENGTH_BASE.len() {
return Err(Error::Corrupt);
}
let lbase = LENGTH_BASE[lensym] as u32 + 2;
let lbits = LENGTH_EXTRA_BITS[lensym] as u32;
let extra = if lbits > 0 {
ctx.bits.read_bits(lbits)?
} else {
0
};
let length = lbase + extra;
promote_offset(ctx, idx, offs);
ctx.last_offset = offs;
ctx.last_length = length;
ctx.emit_match(offs, length)?;
}
263..=270 => {
let idx = (sym - 263) as usize;
let sbase = SHORT_BASE[idx];
let sbits = SHORT_EXTRA_BITS[idx] as u32;
let extra = if sbits > 0 {
ctx.bits.read_bits(sbits)?
} else {
0
};
let offs = sbase + extra + 1;
let length: u32 = 2;
ctx.old_offsets[3] = ctx.old_offsets[2];
ctx.old_offsets[2] = ctx.old_offsets[1];
ctx.old_offsets[1] = ctx.old_offsets[0];
ctx.old_offsets[0] = offs;
ctx.last_offset = offs;
ctx.last_length = length;
ctx.emit_match(offs, length)?;
}
271..=298 => {
let idx = (sym - 271) as usize;
if idx >= LENGTH_BASE.len() {
return Err(Error::Corrupt);
}
let lbase = LENGTH_BASE[idx] as u32 + 3;
let lbits = LENGTH_EXTRA_BITS[idx] as u32;
let lextra = if lbits > 0 {
ctx.bits.read_bits(lbits)?
} else {
0
};
let mut length = lbase + lextra;
let offset_tree = ctx.offset.as_ref().ok_or(Error::InvalidHuffmanTree)?;
let osym = offset_tree.decode(&mut ctx.bits)? as usize;
if osym >= OFFSET_BASE.len() {
return Err(Error::Corrupt);
}
let mut offs = OFFSET_BASE[osym] + 1;
let obits = OFFSET_EXTRA_BITS[osym] as u32;
if osym > 9 {
if obits > 4 {
let high = ctx.bits.read_bits(obits - 4)?;
offs = offs.wrapping_add(high << 4);
}
if ctx.num_low_offset_repeats > 0 {
ctx.num_low_offset_repeats -= 1;
offs = offs.wrapping_add(ctx.last_low_offset);
} else {
let low_tree = ctx.low_offset.as_ref().ok_or(Error::InvalidHuffmanTree)?;
let lowsym = low_tree.decode(&mut ctx.bits)?;
if lowsym == 16 {
ctx.num_low_offset_repeats = 15;
offs = offs.wrapping_add(ctx.last_low_offset);
} else {
offs = offs.wrapping_add(lowsym as u32);
ctx.last_low_offset = lowsym as u32;
}
}
} else if obits > 0 {
let extra = ctx.bits.read_bits(obits)?;
offs = offs.wrapping_add(extra);
}
if offs >= 0x4_0000 {
length += 1;
}
if offs >= 0x2000 {
length += 1;
}
ctx.old_offsets[3] = ctx.old_offsets[2];
ctx.old_offsets[2] = ctx.old_offsets[1];
ctx.old_offsets[1] = ctx.old_offsets[0];
ctx.old_offsets[0] = offs;
ctx.last_offset = offs;
ctx.last_length = length;
ctx.emit_match(offs, length)?;
}
_ => return Err(Error::Corrupt),
}
}
}
/// Moves the recent offset at slot `idx` to the front, writing `offs` into slot 0.
///
/// Slots `0..idx` each shift one place toward the back; slots after `idx` are untouched.
/// Panics if `idx` is not a valid slot (`idx >= 4`).
fn promote_offset(ctx: &mut RunCtx, idx: usize, offs: u32) {
    let head = &mut ctx.old_offsets[..=idx];
    head.rotate_right(1);
    head[0] = offs;
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::traits::Decoder as _;
    extern crate std;
    use std::vec;

    /// Builds a context with no tables or input, for exercising offset bookkeeping.
    fn ctx_with_offsets(old_offsets: [u32; 4]) -> RunCtx {
        RunCtx {
            bits: BitReader::new(),
            lengths: vec![],
            main: None,
            offset: None,
            low_offset: None,
            length: None,
            old_offsets,
            last_offset: 0,
            last_length: 0,
            last_low_offset: 0,
            num_low_offset_repeats: 0,
            out: vec![],
            window: vec![0u8; 16],
            window_pos: 0,
            unpack_size: 0,
        }
    }

    #[test]
    fn unpack_size_zero_is_immediate_done() {
        let mut dec = Decoder::with_unpack_size(0);
        let mut out = [0u8; 8];
        let (p, status) = dec.finish(&mut out).unwrap();
        assert_eq!(p.written, 0);
        assert!(matches!(status, crate::Status::StreamEnd));
    }

    #[test]
    fn constructors_set_limits_and_filter_flags() {
        let dec = Decoder::new().with_e8_filter(true);
        assert_eq!(dec.unpack_size, u64::MAX);
        assert!(dec.e8_enabled);
        assert!(dec.e8_translate_e9);
        assert!(!dec.poisoned);

        let dec = Decoder::with_unpack_size(42);
        assert_eq!(dec.unpack_size, 42);
        assert!(!dec.e8_enabled);
    }

    #[test]
    fn drain_into_delivers_output_in_chunks() {
        let mut dec = Decoder::new();
        dec.out_buf = vec![1, 2, 3, 4, 5];
        dec.state = State::Draining;
        let mut chunk = [0u8; 2];

        assert_eq!(dec.drain_into(&mut chunk), 2);
        assert_eq!(chunk, [1, 2]);
        assert!(matches!(dec.state, State::Draining));

        assert_eq!(dec.drain_into(&mut chunk), 2);
        assert_eq!(chunk, [3, 4]);

        assert_eq!(dec.drain_into(&mut chunk), 1);
        assert_eq!(chunk[0], 5);
        assert!(matches!(dec.state, State::Done));
        assert!(dec.out_buf.is_empty());
        assert_eq!(dec.out_drained, 0);
    }

    #[test]
    fn promote_offset_rotates_correctly() {
        let mut ctx = ctx_with_offsets([10, 20, 30, 40]);
        promote_offset(&mut ctx, 2, 30);
        assert_eq!(ctx.old_offsets, [30, 10, 20, 40]);
    }

    #[test]
    fn promote_offset_handles_first_and_last_slot() {
        let mut ctx = ctx_with_offsets([10, 20, 30, 40]);
        promote_offset(&mut ctx, 0, 10);
        assert_eq!(ctx.old_offsets, [10, 20, 30, 40]);
        promote_offset(&mut ctx, 3, 40);
        assert_eq!(ctx.old_offsets, [40, 10, 20, 30]);
    }
}