use alloc::boxed::Box;
use alloc::collections::VecDeque;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
use crate::traits::{RawDecoder, RawProgress};
use super::bits::BitBuf;
use super::filters::{Filter, FilterKind};
use super::huffman::Huffman;
/// Symbol count of the code that transmits all other code lengths.
pub(crate) const HUFF_BC: usize = 20;
/// Symbol count of the main code (literals, filter/repeat markers, length slots).
pub(crate) const HUFF_NC: usize = 306;
/// Symbol count of the distance-slot code.
pub(crate) const HUFF_DC: usize = 64;
/// Symbol count of the code for the low four distance bits.
pub(crate) const HUFF_LDC: usize = 16;
/// Symbol count of the length-slot code used with cached distances.
pub(crate) const HUFF_RC: usize = 44;
/// Number of code lengths sent in a table block, laid out as NC, DC, LDC, RC.
pub(crate) const HUFF_TABLE_SIZE: usize = HUFF_NC + HUFF_DC + HUFF_LDC + HUFF_RC;

/// Smallest accepted dictionary window (128 KiB).
const MIN_WINDOW_SIZE: usize = 1 << 17;
/// Largest accepted dictionary window (1 GiB).
const MAX_WINDOW_SIZE: usize = 1 << 30;
/// Window used when the caller does not choose one (1 MiB).
const DEFAULT_WINDOW_SIZE: usize = 1 << 20;
pub struct Decoder {
state: State,
poisoned: bool,
input: VecDeque<u8>,
unpack_total: u64,
unpack_so_far: u64,
window_size: usize,
window: Vec<u8>,
window_pos: usize,
dist_cache: [u32; 4],
last_len: u32,
tables: Option<Box<Tables>>,
out_queue: VecDeque<u8>,
pending_filters: Vec<Filter>,
ready: VecDeque<u8>,
out_queue_start: u64,
}
#[derive(Debug)]
struct Tables {
nc: Huffman, dc: Huffman, ldc: Huffman, rc: Huffman, }
enum State {
BlockHeader,
InBlock { bits: Box<BitBuf>, last_block: bool },
Done,
}
impl Decoder {
pub fn new() -> Self {
Self::with_unpack_size_and_window(u64::MAX, DEFAULT_WINDOW_SIZE)
}
pub fn with_unpack_size(n: u64) -> Self {
Self::with_unpack_size_and_window(n, DEFAULT_WINDOW_SIZE)
}
pub fn with_window_size(window_size: usize) -> Self {
Self::with_unpack_size_and_window(u64::MAX, window_size)
}
pub fn with_unpack_size_and_window(unpack: u64, window_size: usize) -> Self {
let ws = window_size
.clamp(MIN_WINDOW_SIZE, MAX_WINDOW_SIZE)
.next_power_of_two();
Self {
state: State::BlockHeader,
poisoned: false,
input: VecDeque::new(),
unpack_total: unpack,
unpack_so_far: 0,
window_size: ws,
window: vec![0u8; ws],
window_pos: 0,
dist_cache: [0; 4],
last_len: 0,
tables: None,
out_queue: VecDeque::new(),
pending_filters: Vec::new(),
ready: VecDeque::new(),
out_queue_start: 0,
}
}
fn poison(&mut self, e: Error) -> Error {
self.poisoned = true;
e
}
}
impl Default for Decoder {
fn default() -> Self {
Self::new()
}
}
impl RawDecoder for Decoder {
fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
if self.poisoned {
return Err(Error::Corrupt);
}
let mut consumed = 0usize;
let mut written = 0usize;
loop {
while written < output.len()
&& let Some(b) = self.ready.pop_front()
{
output[written] = b;
written += 1;
}
if self.unpack_so_far == self.unpack_total
&& self.ready.is_empty()
&& self.out_queue.is_empty()
{
self.state = State::Done;
}
if matches!(self.state, State::Done) {
return Ok(RawProgress {
consumed,
written,
done: false,
});
}
if written == output.len() && !self.ready.is_empty() {
return Ok(RawProgress {
consumed,
written,
done: false,
});
}
while consumed < input.len() {
self.input.push_back(input[consumed]);
consumed += 1;
}
let progressed = match self.step() {
Ok(p) => p,
Err(e) => return Err(self.poison(e)),
};
if !progressed {
return Ok(RawProgress {
consumed,
written,
done: false,
});
}
}
}
fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
if self.poisoned {
return Err(Error::Corrupt);
}
let mut written = 0usize;
while written < output.len()
&& let Some(b) = self.ready.pop_front()
{
output[written] = b;
written += 1;
}
if self.unpack_so_far == self.unpack_total
&& self.ready.is_empty()
&& self.out_queue.is_empty()
{
self.state = State::Done;
}
match &self.state {
State::Done => Ok(RawProgress {
consumed: 0,
written,
done: true,
}),
State::BlockHeader if self.input.is_empty() && self.unpack_total == 0 => {
self.state = State::Done;
Ok(RawProgress {
consumed: 0,
written,
done: true,
})
}
_ => {
if written == 0 {
Err(self.poison(Error::UnexpectedEnd))
} else {
Ok(RawProgress {
consumed: 0,
written,
done: false,
})
}
}
}
}
fn raw_reset(&mut self) {
self.state = State::BlockHeader;
self.poisoned = false;
self.input.clear();
self.unpack_so_far = 0;
for b in &mut self.window {
*b = 0;
}
self.window_pos = 0;
self.dist_cache = [0; 4];
self.last_len = 0;
self.tables = None;
self.out_queue.clear();
self.pending_filters.clear();
self.ready.clear();
self.out_queue_start = 0;
}
}
impl Decoder {
fn step(&mut self) -> Result<bool, Error> {
match &mut self.state {
State::Done => Ok(false),
State::BlockHeader => {
if self.input.len() < 2 {
return Ok(false);
}
let flags = self.input[0];
let byte_count = ((flags >> 3) & 7) as usize;
if byte_count > 2 {
return Err(Error::Corrupt);
}
let header_len = 2 + byte_count + 1;
if self.input.len() < header_len {
return Ok(false);
}
let bit_size = (flags & 7) + 1; let last_block = (flags & 0x40) != 0;
let table_present = (flags & 0x80) != 0;
let cksum_byte = self.input[1];
let mut size_bytes = [0u8; 3];
for (i, sb) in size_bytes.iter_mut().enumerate().take(byte_count + 1) {
*sb = self.input[2 + i];
}
let computed_cksum = 0x5A ^ flags ^ size_bytes[0] ^ size_bytes[1] ^ size_bytes[2];
if computed_cksum != cksum_byte {
return Err(Error::BadHeader);
}
let block_size =
u32::from_le_bytes([size_bytes[0], size_bytes[1], size_bytes[2], 0]) as usize;
if block_size == 0 {
return Err(Error::Corrupt);
}
if self.input.len() < header_len + block_size {
return Ok(false);
}
for _ in 0..header_len {
self.input.pop_front();
}
let mut block_bytes = Vec::with_capacity(block_size);
for _ in 0..block_size {
if let Some(b) = self.input.pop_front() {
block_bytes.push(b);
}
}
let mut bits = Box::new(BitBuf::new());
bits.reset(&block_bytes, bit_size);
if table_present {
let tables = self.read_tables(&mut bits)?;
self.tables = Some(Box::new(tables));
}
if self.tables.is_none() {
return Err(Error::Corrupt);
}
self.state = State::InBlock { bits, last_block };
Ok(true)
}
State::InBlock { bits, last_block } => {
let last_block = *last_block;
let mut bits_owned = core::mem::replace(bits, Box::new(BitBuf::new()));
let res = self.decode_in_block(&mut bits_owned);
if let State::InBlock { bits: slot, .. } = &mut self.state {
*slot = bits_owned;
}
let made_progress = res?;
let at_end = match &self.state {
State::InBlock { bits, .. } => bits.at_end(),
_ => true,
};
if at_end {
if last_block {
self.state = State::BlockHeader;
if self.unpack_total == u64::MAX {
self.unpack_total = self.unpack_so_far + self.out_queue.len() as u64;
}
} else {
self.state = State::BlockHeader;
}
self.flush_ready_through_filters();
return Ok(true);
}
Ok(made_progress)
}
}
}
fn decode_in_block(&mut self, bits: &mut BitBuf) -> Result<bool, Error> {
let mut progressed = false;
loop {
if bits.at_end() {
return Ok(progressed);
}
if self.unpack_so_far + self.out_queue.len() as u64 >= self.unpack_total {
return Ok(progressed);
}
if self.out_queue.len() > 1 << 20 {
self.flush_ready_through_filters();
}
let t = self.tables.as_ref().ok_or(Error::Corrupt)?;
let num = t.nc.decode(bits)?;
progressed = true;
match num {
0..=255 => {
self.emit_literal(num as u8);
}
256 => {
let filter = read_filter(bits, self.window_size, self.cur_out_pos())?;
self.pending_filters.push(filter);
}
257 => {
if self.last_len == 0 {
return Err(Error::Corrupt);
}
let dist = self.dist_cache[0];
self.emit_match(self.last_len, dist)?;
}
258..=261 => {
let idx = (num - 258) as usize;
let dist = self.dist_cache[idx];
for j in (1..=idx).rev() {
self.dist_cache[j] = self.dist_cache[j - 1];
}
self.dist_cache[0] = dist;
let len_sym = t.rc.decode(bits)?;
let length = decode_length(bits, len_sym)?;
self.emit_match(length, dist)?;
}
262..=305 => {
let len_sym = (num - 262) as u32;
let length = decode_length(bits, len_sym as u16)?;
let dist_slot = t.dc.decode(bits)?;
let dist = decode_distance(bits, dist_slot, &t.ldc)?;
let adj_len = adjust_length(length, dist);
self.dist_cache[3] = self.dist_cache[2];
self.dist_cache[2] = self.dist_cache[1];
self.dist_cache[1] = self.dist_cache[0];
self.dist_cache[0] = dist;
self.emit_match(adj_len, dist)?;
}
_ => return Err(Error::Corrupt),
}
if self.out_queue.len() >= 4096 {
self.flush_ready_through_filters();
}
}
}
fn emit_literal(&mut self, b: u8) {
self.window[self.window_pos] = b;
self.window_pos = (self.window_pos + 1) % self.window_size;
self.out_queue.push_back(b);
}
fn emit_match(&mut self, length: u32, dist: u32) -> Result<(), Error> {
if dist == 0 || dist as usize > self.window_size {
return Err(Error::InvalidDistance);
}
if length < 2 {
return Err(Error::Corrupt);
}
let ws = self.window_size;
for _ in 0..length {
let src = (self.window_pos + ws - dist as usize) % ws;
let b = self.window[src];
self.window[self.window_pos] = b;
self.window_pos = (self.window_pos + 1) % ws;
self.out_queue.push_back(b);
if self.unpack_so_far + self.out_queue.len() as u64 >= self.unpack_total {
break;
}
}
self.last_len = length;
Ok(())
}
fn cur_out_pos(&self) -> u64 {
self.out_queue_start + self.out_queue.len() as u64
}
fn flush_ready_through_filters(&mut self) {
let mut pos = self.out_queue_start;
self.pending_filters.sort_by_key(|f| f.start);
loop {
let next_filter_start = self.pending_filters.first().map(|f| f.start);
match next_filter_start {
None => {
let drained = self.out_queue.len() as u64;
while let Some(b) = self.out_queue.pop_front() {
self.ready.push_back(b);
}
self.out_queue_start += drained;
self.unpack_so_far += drained;
return;
}
Some(s) if s > pos => {
let n = (s - pos) as usize;
let avail = self.out_queue.len();
let take = n.min(avail);
for _ in 0..take {
if let Some(b) = self.out_queue.pop_front() {
self.ready.push_back(b);
}
}
pos += take as u64;
self.out_queue_start = pos;
self.unpack_so_far += take as u64;
if (take as u64) < n as u64 {
return;
}
}
Some(_) => {
let f = self.pending_filters[0];
let end = f.start + f.length as u64;
let buf_end = pos + self.out_queue.len() as u64;
if end > buf_end {
return;
}
if f.start < pos {
return;
}
let leading = (f.start - pos) as usize;
for _ in 0..leading {
if let Some(b) = self.out_queue.pop_front() {
self.ready.push_back(b);
}
}
pos += leading as u64;
self.unpack_so_far += leading as u64;
self.out_queue_start = pos;
let length = f.length as usize;
let mut region = Vec::with_capacity(length);
for _ in 0..length {
if let Some(b) = self.out_queue.pop_front() {
region.push(b);
}
}
let ok = super::filters::apply(&f, &mut region);
if ok.is_err() {
}
for &b in ®ion {
self.ready.push_back(b);
}
pos += length as u64;
self.unpack_so_far += length as u64;
self.out_queue_start = pos;
self.pending_filters.remove(0);
}
}
}
}
fn read_tables(&mut self, bits: &mut BitBuf) -> Result<Tables, Error> {
let mut bc_lens = [0u8; HUFF_BC];
let mut i = 0;
while i < HUFF_BC {
let n = bits.read(4)? as u8;
if n < 15 {
bc_lens[i] = n;
i += 1;
} else {
let m = bits.read(4)? as u8;
if m == 0 {
bc_lens[i] = 15;
i += 1;
} else {
let run = (m as usize) + 2;
let end = (i + run).min(HUFF_BC);
while i < end {
bc_lens[i] = 0;
i += 1;
}
}
}
}
let bc = Huffman::from_lengths(&bc_lens)?;
if bc.is_empty() {
return Err(Error::InvalidHuffmanTree);
}
let mut table = vec![0u8; HUFF_TABLE_SIZE];
let mut idx = 0;
while idx < HUFF_TABLE_SIZE {
let sym = bc.decode(bits)?;
match sym {
0..=15 => {
table[idx] = sym as u8;
idx += 1;
}
16 => {
if idx == 0 {
return Err(Error::Corrupt);
}
let n = bits.read(3)? as usize + 3;
let prev = table[idx - 1];
let end = (idx + n).min(HUFF_TABLE_SIZE);
while idx < end {
table[idx] = prev;
idx += 1;
}
}
17 => {
if idx == 0 {
return Err(Error::Corrupt);
}
let n = bits.read(7)? as usize + 11;
let prev = table[idx - 1];
let end = (idx + n).min(HUFF_TABLE_SIZE);
while idx < end {
table[idx] = prev;
idx += 1;
}
}
18 => {
let n = bits.read(3)? as usize + 3;
let end = (idx + n).min(HUFF_TABLE_SIZE);
while idx < end {
table[idx] = 0;
idx += 1;
}
}
_ => {
let n = bits.read(7)? as usize + 11;
let end = (idx + n).min(HUFF_TABLE_SIZE);
while idx < end {
table[idx] = 0;
idx += 1;
}
}
}
}
let nc = Huffman::from_lengths(&table[0..HUFF_NC])?;
let dc = Huffman::from_lengths(&table[HUFF_NC..HUFF_NC + HUFF_DC])?;
let ldc = Huffman::from_lengths(&table[HUFF_NC + HUFF_DC..HUFF_NC + HUFF_DC + HUFF_LDC])?;
let rc = Huffman::from_lengths(&table[HUFF_NC + HUFF_DC + HUFF_LDC..HUFF_TABLE_SIZE])?;
if nc.is_empty() {
return Err(Error::InvalidHuffmanTree);
}
Ok(Tables { nc, dc, ldc, rc })
}
}
/// Converts a length slot into a match length, reading the slot's extra bits from `bits`.
///
/// Slots 0..=7 map directly to lengths 2..=9 and read nothing. A larger slot `s` reads
/// `s / 4 - 1` extra bits, added to the base `2 + ((4 | (s & 3)) << (s / 4 - 1))`.
fn decode_length(bits: &mut BitBuf, code: u16) -> Result<u32, Error> {
    let slot = u32::from(code);
    if slot < 8 {
        return Ok(2 + slot);
    }
    // For slot >= 8 this is always at least one bit.
    let extra_bits = slot / 4 - 1;
    let base = 2 + ((4 | (slot & 3)) << extra_bits);
    Ok(base + bits.read(extra_bits)?)
}
/// Adds one to `length` for each threshold (0x100, 0x2000, 0x40000) that `dist` strictly
/// exceeds.
fn adjust_length(length: u32, dist: u32) -> u32 {
    const THRESHOLDS: [u32; 3] = [0x100, 0x2000, 0x4_0000];
    THRESHOLDS
        .iter()
        .filter(|&&limit| dist > limit)
        .fold(length, |len, _| len + 1)
}
fn decode_distance(bits: &mut BitBuf, dist_slot: u16, ldc: &Huffman) -> Result<u32, Error> {
let mut dist: u32;
let dbits: u32;
if dist_slot < 4 {
dbits = 0;
dist = 1 + dist_slot as u32;
} else {
dbits = (dist_slot as u32 / 2) - 1;
dist = 1 + ((2 | (dist_slot as u32 & 1)) << dbits);
}
if dbits > 0 {
if dbits >= 4 {
let high_extra = dbits - 4;
if high_extra > 0 {
let high = bits.read(high_extra)?;
dist += high << 4;
}
let low = ldc.decode(bits)? as u32;
dist += low;
} else {
let extra = bits.read(dbits)?;
dist += extra;
}
}
Ok(dist)
}
/// Reads the filter record that follows main-code symbol 256.
///
/// Layout: a start offset and a length (both via `read_filter_uint`), then a 3-bit type; the
/// delta filter (type 0) adds a 5-bit `channels - 1`. `cur_pos` is the absolute output
/// position at which the record was read, so the returned `start` is absolute too.
///
/// # Errors
/// `Error::Corrupt` for a length outside `4..=0x40_0000` or larger than half the window,
/// `Error::Unsupported` for type codes 4..=7.
fn read_filter(bits: &mut BitBuf, window_size: usize, cur_pos: u64) -> Result<Filter, Error> {
    let offset = read_filter_uint(bits)?;
    let length = read_filter_uint(bits)?;
    let length_ok = (4..=0x40_0000).contains(&length) && length as usize <= window_size / 2;
    if !length_ok {
        return Err(Error::Corrupt);
    }
    let kind = match bits.read(3)? {
        0 => FilterKind::Delta {
            channels: bits.read(5)? as u8 + 1,
        },
        1 => FilterKind::X86Call,
        2 => FilterKind::X86CallJmp,
        3 => FilterKind::Arm,
        _ => return Err(Error::Unsupported),
    };
    Ok(Filter {
        start: cur_pos + u64::from(offset),
        length,
        kind,
    })
}
/// Reads a variable-width integer: a 2-bit count `c`, then `c + 1` bytes (8 bits each),
/// least significant byte first.
fn read_filter_uint(bits: &mut BitBuf) -> Result<u32, Error> {
    let byte_count = bits.read(2)? + 1;
    let mut value = 0u32;
    for shift in (0..byte_count).map(|i| i * 8) {
        value |= bits.read(8)? << shift;
    }
    Ok(value)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Slots below 8 read no extra bits: length = slot + 2.
    #[test]
    fn length_decoding_small_codes() {
        let mut br = BitBuf::new();
        br.reset(&[0; 4], 8);
        assert_eq!(decode_length(&mut br, 0).unwrap(), 2);
        assert_eq!(decode_length(&mut br, 7).unwrap(), 9);
    }

    /// Slot 8 reads one extra bit on top of a base of 2 + (4 << 1) = 10.
    #[test]
    fn length_decoding_large_codes_consume_extra_bits() {
        let mut br = BitBuf::new();
        br.reset(&[0b1000_0000], 8);
        assert_eq!(decode_length(&mut br, 8).unwrap(), 11);
    }

    /// With all-zero extra bits the result is just the slot's base length.
    #[test]
    fn length_decoding_multi_bit_slots_use_base_length() {
        let mut br = BitBuf::new();
        br.reset(&[0; 4], 8);
        // Slot 11: one extra bit, base 2 + (7 << 1) = 16.
        assert_eq!(decode_length(&mut br, 11).unwrap(), 16);
        // Slot 12: two extra bits, base 2 + (4 << 2) = 18.
        assert_eq!(decode_length(&mut br, 12).unwrap(), 18);
    }

    /// Each threshold is exclusive: a distance equal to it does not lengthen the match.
    #[test]
    fn distance_adjustment_rules() {
        assert_eq!(adjust_length(3, 0x10), 3);
        assert_eq!(adjust_length(3, 0x100), 3);
        assert_eq!(adjust_length(3, 0x101), 4);
        assert_eq!(adjust_length(3, 0x2000), 4);
        assert_eq!(adjust_length(3, 0x2001), 5);
        assert_eq!(adjust_length(3, 0x4_0000), 5);
        assert_eq!(adjust_length(3, 0x4_0001), 6);
    }
}