extern crate alloc;
use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
use crate::traits::{RawDecoder, RawProgress};
/// Longest code length, in bits, accepted in a transmitted tree.
const MAX_CODE_LEN: usize = 16;
/// Size of the circular history window in bytes (8 KiB).
///
/// Must stay a power of two, because window positions are wrapped with
/// `& (WINDOW_SIZE - 1)`. It must also cover the longest match distance,
/// `(63 << 7 | 127) + 1 = 8192`.
const WINDOW_SIZE: usize = 1 << 13;
/// Stream header size: one flags byte followed by a little-endian `u32` uncompressed length.
const HEADER_LEN: usize = 1 + 4;
/// Number of symbols in the literal tree (one per byte value).
const LIT_SYMBOLS: usize = 1 << 8;
/// Number of symbols in the match-length tree.
const LEN_SYMBOLS: usize = 1 << 6;
/// Number of symbols in the distance (high bits) tree.
const DIST_SYMBOLS: usize = 1 << 6;
#[derive(Debug, Clone)]
struct ShannonFano {
lookup: Vec<(u16, u8)>,
num_symbols: u16,
}
impl ShannonFano {
fn from_lengths(lens: &[u8]) -> Result<Self, Error> {
let mut len_count = [0u32; MAX_CODE_LEN + 1];
for &l in lens {
if l == 0 || l as usize > MAX_CODE_LEN {
return Err(Error::InvalidHuffmanTree);
}
len_count[l as usize] += 1;
}
let mut avail: i64 = 2;
for &c in len_count.iter().take(MAX_CODE_LEN + 1).skip(1) {
avail -= c as i64;
if avail < 0 {
return Err(Error::InvalidHuffmanTree);
}
avail *= 2;
}
if avail != 0 {
return Err(Error::InvalidHuffmanTree);
}
let mut next_code = [0u32; MAX_CODE_LEN + 2];
let mut code = 0u32;
for bits in 1..=MAX_CODE_LEN {
code = (code + len_count[bits - 1]) << 1;
next_code[bits] = code;
}
let mut lookup: Vec<(u16, u8)> = vec![(0, 0); 1 << MAX_CODE_LEN];
for (sym, &l) in lens.iter().enumerate() {
let l = l as usize;
let canonical = next_code[l];
next_code[l] += 1;
let lsb = reverse_bits(canonical, l as u32);
let step = 1usize << l;
let mut idx = lsb as usize;
while idx < (1 << MAX_CODE_LEN) {
lookup[idx] = (sym as u16, l as u8);
idx += step;
}
}
Ok(Self {
lookup,
num_symbols: lens.len() as u16,
})
}
#[inline]
fn decode(&self, bits: u16) -> Result<(u16, u8), Error> {
let idx = !bits as usize & 0xFFFF;
let (sym, l) = self.lookup[idx];
if l == 0 || sym >= self.num_symbols {
return Err(Error::InvalidHuffmanTree);
}
Ok((sym, l))
}
}
/// Reverses the order of the low `n` bits of `v`.
///
/// For `n <= 32`, bits of `v` at position `n` and above are ignored. For example,
/// `reverse_bits(0b011, 3) == 0b110`. This turns an MSB-first canonical code into the
/// LSB-first order in which bits are read from the stream.
const fn reverse_bits(mut v: u32, n: u32) -> u32 {
    let mut reversed = 0u32;
    let mut left = n;
    while left > 0 {
        // Move the lowest remaining input bit into the output's new lowest position.
        reversed = (reversed << 1) | (v & 1);
        v >>= 1;
        left -= 1;
    }
    reversed
}
/// Parsed stream header: a flags byte followed by the little-endian uncompressed length.
#[derive(Debug, Clone, Copy)]
struct Header {
    /// Flag bit 0: match distances carry 7 raw low bits instead of 6.
    large_window: bool,
    /// Flag bit 1: a literal tree is transmitted and matches are at least 3 bytes long.
    lit_tree: bool,
    /// Number of bytes the stream decodes to.
    uncompressed_len: u32,
}

impl Header {
    /// Parses the fixed-size header.
    ///
    /// # Errors
    ///
    /// Returns [`Error::BadHeader`] if any reserved flag bit (bits 2 through 7) is set.
    fn parse(buf: &[u8; HEADER_LEN]) -> Result<Self, Error> {
        const LARGE_WINDOW: u8 = 1 << 0;
        const LIT_TREE: u8 = 1 << 1;

        let flags = buf[0];
        if (flags & !(LARGE_WINDOW | LIT_TREE)) != 0 {
            return Err(Error::BadHeader);
        }

        let mut len_bytes = [0u8; 4];
        len_bytes.copy_from_slice(&buf[1..5]);
        Ok(Header {
            large_window: (flags & LARGE_WINDOW) != 0,
            lit_tree: (flags & LIT_TREE) != 0,
            uncompressed_len: u32::from_le_bytes(len_bytes),
        })
    }

    /// Number of raw low bits in each match distance: 7 with a large window, else 6.
    fn dist_low_bits(self) -> u8 {
        6 + u8::from(self.large_window)
    }

    /// Shortest match length: 3 when a literal tree is present, else 2.
    fn min_len(self) -> u16 {
        2 + u16::from(self.lit_tree)
    }
}
/// Which part of the stream the decoder expects next.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum Phase {
    /// Collecting the fixed-size header bytes.
    AwaitHeader,
    /// Waiting for the literal tree (only entered when the header's `lit_tree` flag is set).
    AwaitLitTree,
    /// Waiting for the match-length tree.
    AwaitLenTree,
    /// Waiting for the distance tree.
    AwaitDistTree,
    /// Decoding literals and matches until `output_left` reaches zero.
    Decode,
    /// All output produced; only pending bytes may remain to be copied out.
    Done,
}
/// Streaming decoder state.
///
/// Compressed input is buffered in `in_buf` and read bit by bit. Decoded bytes go into the
/// circular `window`, which serves both as the match history and as the staging area for
/// output that has not been handed to the caller yet.
pub struct Decoder {
    // Stream structure.
    /// Which part of the stream is expected next.
    phase: Phase,
    /// Header bytes collected so far; the header may arrive split across calls.
    header_buf: [u8; HEADER_LEN],
    /// Number of valid bytes in `header_buf`.
    header_pos: u8,
    /// Parsed header; `None` until every header byte has arrived.
    header: Option<Header>,
    /// Literal tree; only read when the header's `lit_tree` flag is set.
    lit_tree: Option<ShannonFano>,
    /// Match-length tree.
    len_tree: Option<ShannonFano>,
    /// Distance (high bits) tree.
    dist_tree: Option<ShannonFano>,

    // Input side.
    /// Buffered compressed input. Bytes before `bit_pos / 8` have been consumed.
    in_buf: Vec<u8>,
    /// Read cursor into `in_buf`, in bits; bits are consumed LSB-first within each byte.
    bit_pos: usize,

    // Output side.
    /// `WINDOW_SIZE`-byte circular buffer of recent output, including pending bytes.
    window: Vec<u8>,
    /// Index in `window` where the next decoded byte is written.
    window_pos: usize,
    /// Index in `window` of the oldest byte not yet copied to the caller.
    pending_start: usize,
    /// Number of decoded bytes not yet copied to the caller.
    pending_len: usize,
    /// Bytes still to be decoded; initialized from the header's uncompressed length.
    output_left: u32,
}
impl Default for Decoder {
fn default() -> Self {
Self::new()
}
}
impl Decoder {
    /// Creates a decoder positioned at the start of a stream.
    ///
    /// The `WINDOW_SIZE`-byte history window is allocated up front and zero-filled.
    pub fn new() -> Self {
        Self {
            phase: Phase::AwaitHeader,
            header_buf: [0u8; HEADER_LEN],
            header_pos: 0,
            header: None,
            in_buf: Vec::new(),
            bit_pos: 0,
            lit_tree: None,
            len_tree: None,
            dist_tree: None,
            window: vec![0u8; WINDOW_SIZE],
            window_pos: 0,
            pending_start: 0,
            pending_len: 0,
            output_left: 0,
        }
    }
    /// Appends `input` to the buffered input.
    ///
    /// Once at least 1024 whole bytes in front of the read cursor have been consumed, they
    /// are removed and `bit_pos` is rebased, so consumed input is not kept indefinitely.
    fn ingest(&mut self, input: &[u8]) {
        self.in_buf.extend_from_slice(input);
        let consumed_bytes = self.bit_pos / 8;
        // Compacting only in batches amortizes the memmove performed by `Vec::drain`.
        if consumed_bytes >= 1024 {
            self.in_buf.drain(..consumed_bytes);
            self.bit_pos -= consumed_bytes * 8;
        }
    }
    /// Returns the next `n` bits without consuming them, or `None` if fewer than `n` bits
    /// are buffered.
    ///
    /// Bits are read LSB-first: the bit at `bit_pos` ends up in bit 0 of the result.
    /// Callers in this file request at most 8 bits at a time.
    fn peek_bits(&self, n: u32) -> Option<u32> {
        let total_bits = self.in_buf.len() * 8;
        if self.bit_pos + n as usize > total_bits {
            return None;
        }
        let mut acc: u64 = 0;
        let byte_idx = self.bit_pos / 8;
        let off = (self.bit_pos % 8) as u32;
        // Number of bytes spanned by `n` bits that start `off` bits into `byte_idx`.
        let take = (n + off).div_ceil(8);
        for i in 0..take {
            let b = self.in_buf[byte_idx + i as usize];
            acc |= (b as u64) << (i * 8);
        }
        acc >>= off;
        // For n >= 32 the `as u32` cast below already discards the higher bits.
        if n < 32 {
            acc &= (1u64 << n) - 1;
        }
        Some(acc as u32)
    }
    /// Returns up to the next 16 bits (LSB-first) without consuming them.
    ///
    /// If fewer than 16 bits are buffered, the missing high bits are zero. The result is 0
    /// when nothing is buffered. Callers must compare any code length decoded from this
    /// value against `bits_available` before consuming it.
    fn peek16(&self) -> u16 {
        let total_bits = self.in_buf.len() * 8;
        let avail = total_bits.saturating_sub(self.bit_pos);
        let need = 16.min(avail);
        if need == 0 {
            return 0;
        }
        let byte_idx = self.bit_pos / 8;
        let off = (self.bit_pos % 8) as u32;
        let mut acc: u32 = 0;
        // Three bytes always cover 16 bits at any bit offset (7 + 16 <= 24).
        let max = (self.in_buf.len() - byte_idx).min(3);
        for i in 0..max {
            acc |= (self.in_buf[byte_idx + i] as u32) << (i * 8);
        }
        acc >>= off;
        acc as u16
    }
    /// Number of buffered bits not yet consumed.
    fn bits_available(&self) -> usize {
        let total_bits = self.in_buf.len() * 8;
        total_bits.saturating_sub(self.bit_pos)
    }
    /// Consumes `n` bits; the caller has already checked they are available.
    fn advance(&mut self, n: u32) {
        self.bit_pos += n as usize;
    }
    /// Captures the read position so a partially decoded item can be retried.
    fn snapshot(&self) -> usize {
        self.bit_pos
    }
    /// Restores a read position captured by `snapshot`.
    fn rollback(&mut self, snap: usize) {
        self.bit_pos = snap;
    }
    /// Stores one decoded byte in the window and queues it for output.
    ///
    /// Also decrements `output_left`. Callers ensure it is non-zero beforehand; otherwise
    /// the `u32` subtraction would underflow.
    fn emit_byte(&mut self, b: u8) {
        self.window[self.window_pos] = b;
        self.window_pos = (self.window_pos + 1) & (WINDOW_SIZE - 1);
        self.pending_len += 1;
        self.output_left -= 1;
    }
    /// Copies as many pending bytes as fit into `out`, oldest first, and returns the count.
    ///
    /// The pending region may wrap around the end of the circular window, so it is copied
    /// in contiguous chunks.
    fn drain(&mut self, out: &mut [u8]) -> usize {
        if self.pending_len == 0 || out.is_empty() {
            return 0;
        }
        let n = self.pending_len.min(out.len());
        let mut wrote = 0usize;
        while wrote < n {
            // Stop each chunk at the physical end of the window.
            let chunk = (WINDOW_SIZE - self.pending_start).min(n - wrote);
            out[wrote..wrote + chunk]
                .copy_from_slice(&self.window[self.pending_start..self.pending_start + chunk]);
            self.pending_start = (self.pending_start + chunk) & (WINDOW_SIZE - 1);
            wrote += chunk;
        }
        self.pending_len -= n;
        n
    }
    /// Tries to read one transmitted tree holding `num_symbols` code lengths.
    ///
    /// Layout:
    /// - one byte holding the number of following bytes minus one;
    /// - one byte per run: low nibble = code length - 1, high nibble = run length - 1.
    ///
    /// Returns `Ok(None)` without consuming anything if the whole tree is not buffered
    /// yet. The read position must be byte-aligned.
    ///
    /// # Errors
    ///
    /// - [`Error::Corrupt`] if the runs do not cover exactly `num_symbols` symbols.
    /// - Any tree-construction error from [`ShannonFano::from_lengths`].
    fn try_read_tree(&mut self, num_symbols: usize) -> Result<Option<ShannonFano>, Error> {
        debug_assert_eq!(self.bit_pos & 7, 0);
        let byte_idx = self.bit_pos / 8;
        let avail = self.in_buf.len() - byte_idx;
        if avail < 1 {
            return Ok(None);
        }
        let count = self.in_buf[byte_idx] as usize + 1;
        if avail < 1 + count {
            return Ok(None);
        }
        let mut lens = vec![0u8; num_symbols];
        let mut sym = 0usize;
        for i in 0..count {
            let pair = self.in_buf[byte_idx + 1 + i];
            let bits = (pair & 0x0F) + 1;
            let run = ((pair >> 4) & 0x0F) as usize + 1;
            if sym + run > num_symbols {
                return Err(Error::Corrupt);
            }
            for _ in 0..run {
                lens[sym] = bits;
                sym += 1;
            }
        }
        if sym != num_symbols {
            return Err(Error::Corrupt);
        }
        let tree = ShannonFano::from_lengths(&lens)?;
        // Consume the count byte plus all run bytes only once the tree is known to be valid.
        self.bit_pos += 8 * (1 + count);
        Ok(Some(tree))
    }
    /// Decodes one literal or one match and writes its bytes to the window.
    ///
    /// Returns `Ok(false)`, with the read position restored, when the buffered input ends
    /// inside the item, so the item can be retried after more input arrives.
    ///
    /// Item layout (bits read LSB-first):
    /// - selector bit `1`: a literal. It is coded with the literal tree if one was
    ///   transmitted, otherwise sent as 8 raw bits.
    /// - selector bit `0`: a match, made of:
    ///   - `dist_low_bits()` raw low distance bits;
    ///   - the high distance bits, decoded with the distance tree;
    ///   - a length symbol, decoded with the length tree. Symbol 63 is followed by 8 extra
    ///     bits that are added to it, and `min_len()` is then added.
    ///
    ///   The distance is `((high << low_bits) | low) + 1`.
    ///
    /// Must only be called in the `Decode` phase with `output_left > 0`. Both callers check
    /// this; otherwise a literal would underflow `output_left` in `emit_byte`.
    ///
    /// # Errors
    ///
    /// - [`Error::Corrupt`] if a match is longer than the remaining output.
    /// - Lookup errors from [`ShannonFano::decode`].
    fn try_step(&mut self) -> Result<bool, Error> {
        let snap = self.snapshot();
        let hdr = self.header.expect("header must be set in Decode phase");
        if self.bits_available() < 1 {
            return Ok(false);
        }
        let sel = self.peek_bits(1).unwrap();
        self.advance(1);
        if sel == 1 {
            if let Some(ref tree) = self.lit_tree {
                if self.bits_available() < 1 {
                    self.rollback(snap);
                    return Ok(false);
                }
                let bits = self.peek16();
                let (sym, used) = tree.decode(bits)?;
                // `peek16` pads with zeros; only trust the code if all its bits were real.
                if (used as usize) > self.bits_available() {
                    self.rollback(snap);
                    return Ok(false);
                }
                self.advance(used as u32);
                self.emit_byte(sym as u8);
            } else {
                if self.bits_available() < 8 {
                    self.rollback(snap);
                    return Ok(false);
                }
                let b = self.peek_bits(8).unwrap() as u8;
                self.advance(8);
                self.emit_byte(b);
            }
            return Ok(true);
        }
        // Match: raw low distance bits first.
        let bdl = hdr.dist_low_bits() as u32;
        if self.bits_available() < bdl as usize {
            self.rollback(snap);
            return Ok(false);
        }
        let dist_low = self.peek_bits(bdl).unwrap();
        self.advance(bdl);
        // High distance bits via the distance tree.
        let dist_tree = self
            .dist_tree
            .as_ref()
            .expect("dist_tree set before Decode phase");
        let bits = self.peek16();
        let (dist_hi, dist_used) = dist_tree.decode(bits)?;
        if (dist_used as usize) > self.bits_available() {
            self.rollback(snap);
            return Ok(false);
        }
        self.advance(dist_used as u32);
        // Length symbol via the length tree.
        let len_tree = self
            .len_tree
            .as_ref()
            .expect("len_tree set before Decode phase");
        let bits = self.peek16();
        let (len_sym, len_used) = len_tree.decode(bits)?;
        if (len_used as usize) > self.bits_available() {
            self.rollback(snap);
            return Ok(false);
        }
        self.advance(len_used as u32);
        let mut len = len_sym as u32;
        // The last length symbol is followed by an 8-bit extension.
        if len_sym == 63 {
            if self.bits_available() < 8 {
                self.rollback(snap);
                return Ok(false);
            }
            let extra = self.peek_bits(8).unwrap();
            self.advance(8);
            len += extra;
        }
        len += hdr.min_len() as u32;
        // dist_hi < 64 and there are at most 7 low bits, so dist <= 8192 == WINDOW_SIZE and
        // the window index computation below cannot underflow.
        let dist = ((dist_hi as u32) << bdl) | dist_low;
        let dist = (dist + 1) as usize;
        if len > self.output_left {
            return Err(Error::Corrupt);
        }
        // Copy byte by byte: the source may overlap the bytes being written (dist < len).
        // The window starts zero-filled, so distances reaching before the start of output
        // copy zero bytes.
        for _ in 0..len {
            let src = (self.window_pos + WINDOW_SIZE - dist) & (WINDOW_SIZE - 1);
            let b = self.window[src];
            self.emit_byte(b);
        }
        Ok(true)
    }
}
impl RawDecoder for Decoder {
fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
let in_len = input.len();
let mut written = 0usize;
self.ingest(input);
loop {
if self.pending_len > 0 {
written += self.drain(&mut output[written..]);
if written == output.len() && self.pending_len > 0 {
return Ok(RawProgress {
consumed: in_len,
written,
done: false,
});
}
}
match self.phase {
Phase::AwaitHeader => {
let avail_bytes = self.in_buf.len() - self.bit_pos / 8;
let need = HEADER_LEN - self.header_pos as usize;
let take = need.min(avail_bytes);
let byte_idx = self.bit_pos / 8;
for i in 0..take {
self.header_buf[self.header_pos as usize] = self.in_buf[byte_idx + i];
self.header_pos += 1;
}
self.bit_pos += take * 8;
if (self.header_pos as usize) < HEADER_LEN {
return Ok(RawProgress {
consumed: in_len,
written,
done: false,
});
}
let hdr = Header::parse(&self.header_buf)?;
self.output_left = hdr.uncompressed_len;
self.header = Some(hdr);
self.phase = if hdr.lit_tree {
Phase::AwaitLitTree
} else {
Phase::AwaitLenTree
};
}
Phase::AwaitLitTree => match self.try_read_tree(LIT_SYMBOLS)? {
Some(t) => {
self.lit_tree = Some(t);
self.phase = Phase::AwaitLenTree;
}
None => {
return Ok(RawProgress {
consumed: in_len,
written,
done: false,
});
}
},
Phase::AwaitLenTree => match self.try_read_tree(LEN_SYMBOLS)? {
Some(t) => {
self.len_tree = Some(t);
self.phase = Phase::AwaitDistTree;
}
None => {
return Ok(RawProgress {
consumed: in_len,
written,
done: false,
});
}
},
Phase::AwaitDistTree => match self.try_read_tree(DIST_SYMBOLS)? {
Some(t) => {
self.dist_tree = Some(t);
self.phase = if self.output_left == 0 {
Phase::Done
} else {
Phase::Decode
};
}
None => {
return Ok(RawProgress {
consumed: in_len,
written,
done: false,
});
}
},
Phase::Decode => {
if self.output_left == 0 {
self.phase = Phase::Done;
continue;
}
if !self.try_step()? {
return Ok(RawProgress {
consumed: in_len,
written,
done: false,
});
}
}
Phase::Done => {
return Ok(RawProgress {
consumed: in_len,
written,
done: self.pending_len == 0,
});
}
}
}
}
fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
let mut written = 0usize;
if self.pending_len > 0 {
written += self.drain(output);
}
loop {
if self.pending_len > 0 {
written += self.drain(&mut output[written..]);
if written == output.len() && self.pending_len > 0 {
return Ok(RawProgress {
consumed: 0,
written,
done: false,
});
}
}
match self.phase {
Phase::Done => {
return Ok(RawProgress {
consumed: 0,
written,
done: self.pending_len == 0,
});
}
Phase::Decode => {
if self.output_left == 0 {
self.phase = Phase::Done;
continue;
}
match self.try_step() {
Ok(true) => continue,
Ok(false) => return Err(Error::UnexpectedEnd),
Err(e) => return Err(e),
}
}
_ => return Err(Error::UnexpectedEnd),
}
}
}
fn raw_reset(&mut self) {
self.phase = Phase::AwaitHeader;
self.header_buf = [0u8; HEADER_LEN];
self.header_pos = 0;
self.header = None;
self.in_buf.clear();
self.bit_pos = 0;
self.lit_tree = None;
self.len_tree = None;
self.dist_tree = None;
for b in self.window.iter_mut() {
*b = 0;
}
self.window_pos = 0;
self.pending_start = 0;
self.pending_len = 0;
self.output_left = 0;
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A length/distance tree in which all 64 symbols get a 6-bit code.
    ///
    /// The first byte is `count - 1 = 3`. It is followed by four bytes, each encoding
    /// `run - 1 = 15` (high nibble) and `bits - 1 = 5` (low nibble).
    const FLAT_64_TREE: [u8; 5] = [3, 0xF5, 0xF5, 0xF5, 0xF5];

    /// Raw literals "abc": per literal, a `1` selector bit then 8 data bits, LSB-first.
    const ABC: [u8; 4] = [0xC3, 0x8A, 0x1D, 0x03];

    /// "abc" as literals, then a match with distance 3 and length 3, laid out as:
    /// - selector bit `0`;
    /// - 6 raw low distance bits, holding 2 (distance - 1);
    /// - the inverted, bit-reversed distance code for symbol 0;
    /// - the inverted, bit-reversed length code for symbol 1 (1 + minimum length 2 = 3).
    const ABCABC: [u8; 6] = [0xC3, 0x8A, 0x1D, 0x23, 0xFC, 0x1F];

    /// Assembles a stream with a small window, no literal tree and flat length/distance
    /// trees, followed by `data`.
    fn flat_stream(uncompressed_len: u32, data: &[u8]) -> Vec<u8> {
        let mut s = Vec::new();
        s.push(0u8);
        s.extend_from_slice(&uncompressed_len.to_le_bytes());
        s.extend_from_slice(&FLAT_64_TREE);
        s.extend_from_slice(&FLAT_64_TREE);
        s.extend_from_slice(data);
        s
    }

    #[test]
    fn header_round_trip() {
        let buf = [0b11u8, 0xAA, 0xBB, 0xCC, 0xDD];
        let h = Header::parse(&buf).unwrap();
        assert!(h.large_window);
        assert!(h.lit_tree);
        assert_eq!(h.uncompressed_len, 0xDDCC_BBAA);
        assert_eq!(h.dist_low_bits(), 7);
        assert_eq!(h.min_len(), 3);
    }

    #[test]
    fn header_rejects_reserved_bits() {
        let buf = [0b1000u8, 0, 0, 0, 0];
        assert!(matches!(Header::parse(&buf), Err(Error::BadHeader)));
    }

    #[test]
    fn reverse_bits_reverses_low_bits_only() {
        assert_eq!(reverse_bits(0b011, 3), 0b110);
        assert_eq!(reverse_bits(0b1011, 3), 0b110);
        assert_eq!(reverse_bits(1, 6), 32);
    }

    #[test]
    fn shannon_fano_single_symbol_tree_rejected() {
        let lens = [1u8];
        assert!(ShannonFano::from_lengths(&lens).is_err());
    }

    #[test]
    fn shannon_fano_rejects_invalid_lengths() {
        assert!(ShannonFano::from_lengths(&[0u8; 0]).is_err());
        // Over-subscribed: three 1-bit codes.
        assert!(ShannonFano::from_lengths(&[1u8, 1, 1]).is_err());
        // Incomplete: one 2-bit slot left unused.
        assert!(ShannonFano::from_lengths(&[1u8, 2]).is_err());
        assert!(ShannonFano::from_lengths(&[0u8, 1]).is_err());
        assert!(ShannonFano::from_lengths(&[17u8, 1]).is_err());
        assert!(ShannonFano::from_lengths(&[2u8, 2, 2, 2]).is_ok());
    }

    #[test]
    fn shannon_fano_two_symbols_length_one() {
        let lens = [1u8, 1u8];
        let t = ShannonFano::from_lengths(&lens).unwrap();
        let (s0, u0) = t.decode(0).unwrap();
        let (s1, u1) = t.decode(1).unwrap();
        assert_eq!(u0, 1);
        assert_eq!(u1, 1);
        assert_ne!(s0, s1);
    }

    #[test]
    fn shannon_fano_codes_are_canonical_inverted_lsb_first() {
        // Canonical codes: sym0 = "0", sym1 = "10", sym2 = "11" (MSB-first).
        // On the wire each bit is inverted and the first bit sits in bit 0.
        let t = ShannonFano::from_lengths(&[1u8, 2, 2]).unwrap();
        assert_eq!(t.decode(0b01).unwrap(), (0, 1));
        assert_eq!(t.decode(0b10).unwrap(), (1, 2));
        assert_eq!(t.decode(0b00).unwrap(), (2, 2));
    }

    #[test]
    fn decodes_raw_literals_in_one_call() {
        let input = flat_stream(3, &ABC);
        let mut dec = Decoder::new();
        let mut out = [0u8; 8];
        let p = dec.raw_decode(&input, &mut out).unwrap();
        assert!(p.done);
        assert_eq!(p.consumed, input.len());
        assert_eq!(&out[..p.written], b"abc");
    }

    #[test]
    fn decodes_back_reference() {
        let input = flat_stream(6, &ABCABC);
        let mut dec = Decoder::new();
        let mut out = [0u8; 8];
        let p = dec.raw_decode(&input, &mut out).unwrap();
        assert!(p.done);
        assert_eq!(&out[..p.written], b"abcabc");
    }

    #[test]
    fn decodes_byte_at_a_time_into_small_buffers() {
        let input = flat_stream(6, &ABCABC);
        let mut dec = Decoder::new();
        let mut decoded = Vec::new();
        let mut done = false;
        for &b in &input {
            let mut out = [0u8; 2];
            let p = dec.raw_decode(&[b], &mut out).unwrap();
            assert_eq!(p.consumed, 1);
            decoded.extend_from_slice(&out[..p.written]);
            done = p.done;
        }
        // The final match yields 3 bytes but only 2 fit per call; flush the remainder.
        for _ in 0..4 {
            if done {
                break;
            }
            let mut out = [0u8; 2];
            let p = dec.raw_finish(&mut out).unwrap();
            decoded.extend_from_slice(&out[..p.written]);
            done = p.done;
        }
        assert!(done);
        assert_eq!(decoded, b"abcabc");
    }

    #[test]
    fn truncated_data_reports_unexpected_end() {
        let input = flat_stream(3, &ABC[..1]);
        let mut dec = Decoder::new();
        let mut out = [0u8; 8];
        let p = dec.raw_decode(&input, &mut out).unwrap();
        assert!(!p.done);
        assert_eq!(p.written, 0);
        assert!(matches!(dec.raw_finish(&mut out), Err(Error::UnexpectedEnd)));
    }

    #[test]
    fn tree_covering_too_many_symbols_is_corrupt() {
        // Five runs of 16 symbols = 80 lengths for the 64-symbol length tree.
        let input = [0u8, 1, 0, 0, 0, 4, 0xF5, 0xF5, 0xF5, 0xF5, 0xF5];
        let mut dec = Decoder::new();
        let mut out = [0u8; 8];
        assert!(matches!(dec.raw_decode(&input, &mut out), Err(Error::Corrupt)));
    }

    #[test]
    fn reset_allows_decoding_another_stream() {
        let input = flat_stream(6, &ABCABC);
        let mut dec = Decoder::new();
        let mut out = [0u8; 8];
        dec.raw_decode(&input, &mut out).unwrap();
        dec.raw_reset();
        let mut again = [0u8; 8];
        let p = dec.raw_decode(&input, &mut again).unwrap();
        assert!(p.done);
        assert_eq!(&again[..p.written], b"abcabc");
    }
}