extern crate alloc;
use alloc::vec::Vec;
use crate::error::Error;
use crate::traits::{RawDecoder, RawProgress};
use super::bits::BitReader;
use super::huffman::Huffman;
use super::tables;
use super::window;
/// Output limit (256 MiB) used by `decode_payload` when the caller gives no expected length.
const DEFAULT_OUTPUT_CAP: usize = 256 * 1024 * 1024;
/// Number of code lengths read for each in-stream literal/length Huffman code.
const LITLEN_SYMBOLS: usize = tables::LITLEN_SYMBOLS;
/// Literal/length symbol that ends decoding of the payload.
const EOS_SYMBOL: u32 = 0x140;
pub(crate) fn decode_payload(input: &[u8], expected_len: Option<usize>) -> Result<Vec<u8>, Error> {
if let Some(0) = expected_len {
return Ok(Vec::new());
}
if input.is_empty() {
return Err(Error::UnexpectedEnd);
}
let cap = expected_len.unwrap_or(DEFAULT_OUTPUT_CAP);
let mut out: Vec<u8> = Vec::new();
if let Some(n) = expected_len {
out.reserve(n.min(1 << 20));
}
let mut reader = BitReader::new(input);
let control = reader.read_bits(8)? as u8;
let high = control >> 4;
let (code_a, code_b, offset_code) = if high == 0 {
let alias = control & 0x08 != 0;
let offset_syms = (control & 0x07) as usize + 10;
let meta = Huffman::from_codes(&tables::META_CODE_VALUES, &tables::META_CODE_LENGTHS)?;
let a_lengths = read_code_lengths(&mut reader, &meta, LITLEN_SYMBOLS)?;
let code_a = Huffman::from_lengths(&a_lengths)?;
let code_b = if alias {
Huffman::from_lengths(&a_lengths)?
} else {
let b_lengths = read_code_lengths(&mut reader, &meta, LITLEN_SYMBOLS)?;
Huffman::from_lengths(&b_lengths)?
};
let off_lengths = read_code_lengths(&mut reader, &meta, offset_syms)?;
let offset_code = Huffman::from_lengths(&off_lengths)?;
(code_a, code_b, offset_code)
} else if (1..=5).contains(&high) {
let idx = (high - 1) as usize;
let code_a = Huffman::from_lengths(tables::PREDEFINED_FIRST[idx])?;
let code_b = Huffman::from_lengths(tables::PREDEFINED_SECOND[idx])?;
let offset_code = Huffman::from_lengths(tables::PREDEFINED_OFFSET[idx])?;
(code_a, code_b, offset_code)
} else {
return Err(Error::Corrupt);
};
let mut use_a = true;
loop {
if out.len() >= cap {
break;
}
let litlen = if use_a {
code_a.decode(&mut reader)?
} else {
code_b.decode(&mut reader)?
};
if litlen <= 0xFF {
window::emit_literal(&mut out, litlen as u8);
use_a = true;
continue;
}
if litlen == EOS_SYMBOL {
break;
}
let length: usize = if litlen <= 0x13D {
(litlen as usize - 0x100) + 3
} else if litlen == 0x13E {
reader.read_bits(10)? as usize + 65
} else if litlen == 0x13F {
reader.read_bits(15)? as usize + 65
} else {
return Err(Error::Corrupt);
};
let b = offset_code.decode(&mut reader)?;
let distance: usize = if b == 0 {
1
} else if b == 1 {
2
} else {
let extra = reader.read_bits(b - 1)? as usize;
(1usize << (b - 1))
.checked_add(extra)
.and_then(|v| v.checked_add(1))
.ok_or(Error::Corrupt)?
};
if out
.len()
.checked_add(length)
.map(|t| t > cap)
.unwrap_or(true)
{
return Err(Error::Corrupt);
}
window::emit_match(&mut out, distance, length)?;
use_a = false;
}
if let Some(n) = expected_len
&& out.len() != n
{
return Err(Error::Corrupt);
}
Ok(out)
}
fn read_code_lengths(
reader: &mut BitReader<'_>,
meta: &Huffman,
count: usize,
) -> Result<Vec<u8>, Error> {
let mut lengths: Vec<u8> = Vec::with_capacity(count);
let mut acc: i32 = 0;
while lengths.len() < count {
let v = meta.decode(reader)?;
let mut extra: usize = 0;
match v {
0..=30 => acc = v as i32 + 1,
31 => acc = -1,
32 => acc = acc.checked_add(1).ok_or(Error::Corrupt)?,
33 => acc = acc.checked_sub(1).ok_or(Error::Corrupt)?,
34 => {
if reader.read_bit()? == 1 {
extra = 1;
}
}
35 => extra = reader.read_bits(3)? as usize + 2,
36 => extra = reader.read_bits(6)? as usize + 10,
_ => return Err(Error::InvalidHuffmanTree),
}
let len_byte: u8 = if acc >= 1 { acc as u8 } else { 0 };
for _ in 0..=extra {
if lengths.len() == count {
break;
}
lengths.push(len_byte);
}
}
Ok(lengths)
}
/// Lifecycle of a [`Decoder`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Accumulating input; nothing has been decoded yet.
    Buffering,
    /// The payload has been decoded; decoded bytes are being copied out to the caller.
    Draining,
    /// Every decoded byte has been handed out.
    Done,
}
/// One-shot decoder exposed through the incremental `RawDecoder` interface.
///
/// Input passed to `raw_decode` is buffered in full. The payload is decoded in a single pass
/// on the first `raw_finish` call. The decoded bytes are then copied out across subsequent
/// `raw_decode` / `raw_finish` calls.
#[derive(Debug)]
pub struct Decoder {
    /// Exact decoded length the payload must produce, if known (set by `with_len`).
    expected_len: Option<usize>,
    /// Input bytes accumulated while in `State::Buffering`.
    input: Vec<u8>,
    /// Decoded payload produced by the one-shot decode.
    output: Vec<u8>,
    /// Number of bytes of `output` already handed to the caller.
    out_cursor: usize,
    /// Current lifecycle stage.
    state: State,
    /// Set when decoding failed; every call then returns an error until `raw_reset`.
    poisoned: bool,
}
impl Default for Decoder {
    /// Equivalent to [`Decoder::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl Decoder {
    /// Creates a decoder for a payload whose decoded length is not known in advance.
    ///
    /// The decoded output is bounded by the module's 256 MiB default output cap.
    pub const fn new() -> Self {
        Self::with_expected(None)
    }

    /// Creates a decoder whose payload must decode to exactly `len` bytes.
    pub const fn with_len(len: usize) -> Self {
        Self::with_expected(Some(len))
    }

    /// Shared constructor body for [`Decoder::new`] and [`Decoder::with_len`].
    const fn with_expected(expected_len: Option<usize>) -> Self {
        Self {
            expected_len,
            input: Vec::new(),
            output: Vec::new(),
            out_cursor: 0,
            state: State::Buffering,
            poisoned: false,
        }
    }

    /// Decodes the buffered input in one shot and moves to [`State::Draining`].
    ///
    /// Does nothing unless the decoder is still buffering. On success the buffered input is
    /// released, since no later state reads it again. On failure the state and buffers are
    /// left unchanged and the error is returned.
    fn decode_now(&mut self) -> Result<(), Error> {
        if self.state != State::Buffering {
            return Ok(());
        }
        // `decode_payload` itself returns an empty buffer for an expected length of zero.
        self.output = decode_payload(&self.input, self.expected_len)?;
        self.input = Vec::new();
        self.state = State::Draining;
        Ok(())
    }

    /// Copies as many pending decoded bytes as fit into `output`.
    ///
    /// Never consumes input. Switches to [`State::Done`] once every decoded byte has been
    /// handed out (immediately, if the decoded payload is empty).
    fn drain(&mut self, output: &mut [u8]) -> RawProgress {
        let pending = &self.output[self.out_cursor..];
        let take = pending.len().min(output.len());
        output[..take].copy_from_slice(&pending[..take]);
        self.out_cursor += take;
        let done = self.out_cursor >= self.output.len();
        if done {
            self.state = State::Done;
        }
        RawProgress {
            consumed: 0,
            written: take,
            done,
        }
    }
}
impl RawDecoder for Decoder {
    /// Accepts more input, or hands out decoded bytes once the payload has been decoded.
    ///
    /// - While buffering, all of `input` is appended to the internal buffer and reported as
    ///   consumed; nothing is written.
    /// - While draining, `input` is ignored and pending decoded bytes are copied into
    ///   `output`.
    /// - When done, no progress is made and `done` is reported.
    ///
    /// # Errors
    ///
    /// [`Error::Corrupt`] if an earlier `raw_finish` failed and the decoder has not been reset.
    fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.poisoned {
            return Err(Error::Corrupt);
        }
        let progress = match self.state {
            State::Buffering => {
                self.input.extend_from_slice(input);
                RawProgress {
                    consumed: input.len(),
                    written: 0,
                    done: false,
                }
            }
            State::Draining => self.drain(output),
            State::Done => RawProgress {
                consumed: 0,
                written: 0,
                done: true,
            },
        };
        Ok(progress)
    }

    /// Decodes the buffered payload on the first call, then copies decoded bytes into `output`.
    ///
    /// # Errors
    ///
    /// Returns the decoding error if the one-shot decode fails, and marks the decoder
    /// poisoned. While poisoned, every call returns [`Error::Corrupt`] until `raw_reset`.
    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.poisoned {
            return Err(Error::Corrupt);
        }
        if self.state == State::Buffering {
            // A failed decode is sticky: remember it so later calls keep failing.
            self.decode_now().inspect_err(|_| self.poisoned = true)?;
        }
        Ok(match self.state {
            State::Done => RawProgress {
                consumed: 0,
                written: 0,
                done: true,
            },
            State::Buffering | State::Draining => self.drain(output),
        })
    }

    /// Returns the decoder to the buffering stage with empty buffers and no poison flag.
    ///
    /// The configured expected length is kept, and the buffers keep their capacity.
    fn raw_reset(&mut self) {
        self.poisoned = false;
        self.state = State::Buffering;
        self.out_cursor = 0;
        self.output.clear();
        self.input.clear();
    }
}