#![cfg_attr(docsrs, doc(cfg(feature = "huffman")))]
use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
use crate::traits::{Algorithm, RawDecoder, RawEncoder, RawProgress};
/// Longest codeword, in bits, that the encoder assigns and the decoder accepts.
///
/// Per-length lookup tables in this module are sized 16 and indexed by code length, and a
/// literal byte in the serialized length table (`0x00..=0x0F`) can carry at most this value.
const MAX_CODE_LEN: u8 = 15;

/// Marker type for the canonical Huffman codec implemented in this module.
///
/// Neither direction is configurable: both configuration types are `()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct Huffman;

impl Algorithm for Huffman {
    const NAME: &'static str = "huffman";

    type Encoder = Encoder;
    type Decoder = Decoder;
    type EncoderConfig = ();
    type DecoderConfig = ();

    /// Returns a fresh encoder with nothing buffered; the unit config carries no options.
    fn encoder_with(_config: ()) -> Encoder {
        Encoder::default()
    }

    /// Returns a fresh decoder with nothing buffered; the unit config carries no options.
    fn decoder_with(_config: ()) -> Decoder {
        Decoder::default()
    }
}
/// Appends `v` to `out` as an unsigned LEB128 varint.
///
/// Seven payload bits go in each byte, least-significant group first. The high bit is set
/// on every byte except the last, so zero encodes as the single byte `0x00`.
fn write_varint(out: &mut Vec<u8>, mut v: u64) {
    while v >= 0x80 {
        // More groups follow: emit the low seven bits with the continuation flag.
        out.push((v as u8 & 0x7F) | 0x80);
        v >>= 7;
    }
    // Final group, continuation flag clear.
    out.push(v as u8);
}
/// Decodes an unsigned LEB128 varint from the start of `buf`.
///
/// Returns the decoded value and the number of bytes consumed. Bytes after the
/// terminating byte (the first one with its high bit clear) are not examined.
///
/// # Errors
///
/// Returns [`Error::Corrupt`] if:
///
/// * `buf` ends before a terminating byte;
/// * the encoding runs past ten bytes;
/// * the tenth byte carries payload bits that do not fit in a `u64`.
fn read_varint(buf: &[u8]) -> Result<(u64, usize), Error> {
    let mut v: u64 = 0;
    let mut shift = 0u32;
    for (i, &byte) in buf.iter().enumerate() {
        if shift >= 64 {
            return Err(Error::Corrupt);
        }
        let chunk = u64::from(byte & 0x7F);
        // At shift 63 only the lowest payload bit still fits. Any higher bit would be
        // shifted out silently and yield a wrong value, so treat it as corruption.
        if (chunk << shift) >> shift != chunk {
            return Err(Error::Corrupt);
        }
        v |= chunk << shift;
        if byte & 0x80 == 0 {
            return Ok((v, i + 1));
        }
        shift += 7;
    }
    Err(Error::Corrupt)
}
fn encode_lengths(lengths: &[u8; 256], out: &mut Vec<u8>) {
let mut i = 0usize;
while i < 256 {
let val = lengths[i];
let mut run = 1usize;
while i + run < 256 && lengths[i + run] == val {
run += 1;
}
i += run;
if val == 0 {
while run >= 19 {
let k = (run - 19).min(255);
out.push(0xF0);
out.push(k as u8);
run -= k + 19;
}
emit_short(out, 0, run);
} else {
while run >= 19 {
let k = (run - 19).min(255);
out.push(0xF1);
out.push(val);
out.push(k as u8);
run -= k + 19;
}
emit_short(out, val, run);
}
}
}
/// Writes `count` copies of code length `val` using the short-run tokens.
///
/// * Zeros: `0xF2 k` covers `k + 3` entries (3..=18).
/// * Lengths 1..=14: one byte `(val << 4) | (n - 3)` covers `n` entries (3..=18).
/// * Length 15, and any tail shorter than 3 entries: one literal byte per entry.
fn emit_short(out: &mut Vec<u8>, val: u8, count: usize) {
    let mut left = count;
    while left > 0 {
        // A packed token needs at least 3 entries and a value that fits in a nibble
        // without colliding with the 0xF_ escape bytes.
        if left < 3 || val > 14 {
            out.push(val);
            left -= 1;
            continue;
        }
        let n = left.min(18);
        if val == 0 {
            out.extend_from_slice(&[0xF2, (n - 3) as u8]);
        } else {
            out.push((val << 4) | (n - 3) as u8);
        }
        left -= n;
    }
}
fn decode_lengths(buf: &[u8]) -> Result<([u8; 256], usize), Error> {
let mut lengths = [0u8; 256];
let mut pos = 0usize; let mut i = 0usize;
while pos < 256 {
let c = *buf.get(i).ok_or(Error::Corrupt)?;
i += 1;
match c {
0x00..=0x0F => {
lengths[pos] = c;
pos += 1;
}
0xF0 => {
let k = *buf.get(i).ok_or(Error::Corrupt)? as usize;
i += 1;
let count = k + 19;
if pos + count > 256 {
return Err(Error::Corrupt);
}
pos += count;
}
0xF1 => {
let val = *buf.get(i).ok_or(Error::Corrupt)?;
i += 1;
let k = *buf.get(i).ok_or(Error::Corrupt)? as usize;
i += 1;
if val == 0 || val > MAX_CODE_LEN {
return Err(Error::Corrupt);
}
let count = k + 19;
if pos + count > 256 {
return Err(Error::Corrupt);
}
for slot in &mut lengths[pos..pos + count] {
*slot = val;
}
pos += count;
}
0xF2 => {
let k = *buf.get(i).ok_or(Error::Corrupt)? as usize;
i += 1;
let count = k + 3;
if pos + count > 256 {
return Err(Error::Corrupt);
}
pos += count;
}
_ => {
let val = c >> 4;
let count = (c & 0x0F) as usize + 3;
if pos + count > 256 {
return Err(Error::Corrupt);
}
for slot in &mut lengths[pos..pos + count] {
*slot = val;
}
pos += count;
}
}
}
Ok((lengths, i))
}
/// Computes optimal prefix-code lengths of at most `max_length` bits using the
/// package-merge algorithm.
///
/// * Symbols with frequency 0 get length 0.
/// * A lone symbol gets length 1, so it still has a codeword.
/// * Two or more symbols receive a complete prefix code (Kraft sum exactly 1).
///
/// # Panics
///
/// Panics if more than `2^max_length` symbols have a non-zero frequency. No prefix code
/// with that length limit exists in that case.
fn length_limited_lengths(freqs: &[u32; 256], max_length: u8) -> [u8; 256] {
    let mut out = [0u8; 256];
    // (frequency, symbol) for every symbol that actually occurs.
    let mut coins: Vec<(u32, u16)> = freqs
        .iter()
        .enumerate()
        .filter_map(|(i, &f)| if f > 0 { Some((f, i as u16)) } else { None })
        .collect();
    let n = coins.len();
    if n == 0 {
        return out;
    }
    if n == 1 {
        out[coins[0].1 as usize] = 1;
        return out;
    }
    // Kraft inequality: n codewords of at most L bits require n <= 2^L. Check up front so
    // the failure is a clear message rather than an out-of-range slice index below.
    let limit = 1u64.checked_shl(u32::from(max_length)).unwrap_or(u64::MAX);
    assert!(
        n as u64 <= limit,
        "{} symbols cannot be coded with at most {} bits per code",
        n,
        max_length
    );
    // Stable sort: equal frequencies keep ascending symbol order, so results are
    // deterministic.
    coins.sort_by_key(|&(f, _)| f);
    #[derive(Clone, Copy)]
    enum Kind {
        Coin(u16),
        Pair(u32, u32),
    }
    struct Elem {
        cost: u64,
        kind: Kind,
    }
    // Arena holding every coin and package; `Pair` refers to its children by index.
    let mut pool: Vec<Elem> = Vec::with_capacity(n * (max_length as usize) * 2 + 8);
    let mut current: Vec<u32> = Vec::with_capacity(2 * n);
    for &(f, sym) in &coins {
        pool.push(Elem {
            cost: f as u64,
            kind: Kind::Coin(sym),
        });
        current.push((pool.len() - 1) as u32);
    }
    // Each pass builds the list for the next-shallower level. It pairs adjacent items
    // (an odd one out is dropped) and merges those packages, by cost, with a fresh copy
    // of the coins.
    for _ in 1..max_length {
        let mut packages: Vec<u32> = Vec::with_capacity(current.len() / 2);
        let mut i = 0;
        while i + 1 < current.len() {
            let a = current[i];
            let b = current[i + 1];
            let cost = pool[a as usize].cost + pool[b as usize].cost;
            pool.push(Elem {
                cost,
                kind: Kind::Pair(a, b),
            });
            packages.push((pool.len() - 1) as u32);
            i += 2;
        }
        let coin_start = pool.len();
        for &(f, sym) in &coins {
            pool.push(Elem {
                cost: f as u64,
                kind: Kind::Coin(sym),
            });
        }
        let fresh: Vec<u32> = (coin_start..pool.len()).map(|i| i as u32).collect();
        let mut merged: Vec<u32> = Vec::with_capacity(fresh.len() + packages.len());
        let (mut ci, mut pi) = (0usize, 0usize);
        while ci < fresh.len() && pi < packages.len() {
            // Ties prefer the coin, keeping the merge order deterministic.
            if pool[fresh[ci] as usize].cost <= pool[packages[pi] as usize].cost {
                merged.push(fresh[ci]);
                ci += 1;
            } else {
                merged.push(packages[pi]);
                pi += 1;
            }
        }
        merged.extend_from_slice(&fresh[ci..]);
        merged.extend_from_slice(&packages[pi..]);
        current = merged;
    }
    // The cheapest 2n - 2 items of the final list form the solution. A symbol's code
    // length is the number of its coins reachable from those items.
    let pick = 2 * n - 2;
    let mut stack: Vec<u32> = Vec::with_capacity(32);
    for &root in &current[..pick] {
        stack.clear();
        stack.push(root);
        while let Some(idx) = stack.pop() {
            match pool[idx as usize].kind {
                Kind::Coin(sym) => out[sym as usize] += 1,
                Kind::Pair(a, b) => {
                    stack.push(a);
                    stack.push(b);
                }
            }
        }
    }
    out
}
/// Assigns canonical Huffman codewords, in the style of RFC 1951 §3.2.2, from code lengths.
///
/// Shorter codes are numerically smaller. Within one length, codes increase with the
/// symbol value. Entries for symbols of length 0 are left as 0.
fn canonical_codes(lengths: &[u8; 256]) -> [u16; 256] {
    let mut per_length = [0u32; 16];
    lengths
        .iter()
        .filter(|&&len| len != 0)
        .for_each(|&len| per_length[usize::from(len)] += 1);

    // The first code of each length follows all codes of the previous length, then
    // gains one bit.
    let mut next_code = [0u32; 16];
    let mut running = 0u32;
    for bits in 1..16 {
        running = (running + per_length[bits - 1]) << 1;
        next_code[bits] = running;
    }

    let mut codes = [0u16; 256];
    for (code, &len) in codes.iter_mut().zip(lengths.iter()) {
        if len != 0 {
            let slot = &mut next_code[usize::from(len)];
            *code = *slot as u16;
            *slot += 1;
        }
    }
    codes
}
struct CanonicalTable {
counts: [u16; 16],
first_code: [u32; 16],
first_idx: [u16; 16],
symbols: Vec<u16>,
max_length: u8,
single: Option<u16>,
}
impl CanonicalTable {
fn from_lengths(lengths: &[u8; 256]) -> Result<Self, Error> {
let mut counts = [0u16; 16];
let mut max_length = 0u8;
let mut present = 0usize;
for &len in lengths.iter() {
if len > MAX_CODE_LEN {
return Err(Error::Corrupt);
}
if len > 0 {
counts[len as usize] += 1;
present += 1;
if len > max_length {
max_length = len;
}
}
}
if present == 0 {
return Err(Error::Corrupt);
}
let single = if present == 1 {
if counts[1] != 1 {
return Err(Error::Corrupt);
}
let sym = lengths
.iter()
.position(|&l| l > 0)
.expect("present == 1 guarantees one nonzero length") as u16;
Some(sym)
} else {
None
};
let mut kraft: u32 = 0;
for l in 1..=15u32 {
kraft += (counts[l as usize] as u32) << (15 - l);
}
if single.is_none() && kraft != (1 << 15) {
return Err(Error::Corrupt);
}
let mut first_code = [0u32; 16];
let mut first_idx = [0u16; 16];
let mut code: u32 = 0;
let mut idx: u16 = 0;
for l in 1..=15usize {
code <<= 1;
first_code[l] = code;
first_idx[l] = idx;
code += counts[l] as u32;
idx += counts[l];
}
let mut symbols = vec![0u16; present];
let mut next = first_idx;
for (sym, &len) in lengths.iter().enumerate() {
if len > 0 {
symbols[next[len as usize] as usize] = sym as u16;
next[len as usize] += 1;
}
}
Ok(Self {
counts,
first_code,
first_idx,
symbols,
max_length,
single,
})
}
}
/// Packs codewords MSB-first into a growing byte buffer.
struct BitWriter {
    /// Completed bytes.
    out: Vec<u8>,
    /// Byte under construction; its low `nbits` bits hold pending data.
    cur: u8,
    /// Number of pending bits in `cur` (0..=7 between calls).
    nbits: u8,
}

impl BitWriter {
    fn new() -> Self {
        BitWriter {
            out: Vec::new(),
            cur: 0,
            nbits: 0,
        }
    }

    /// Appends the low `len` bits of `code`, starting with the most significant of them.
    fn write(&mut self, code: u16, len: u8) {
        for shift in (0..len).rev() {
            self.cur = (self.cur << 1) | ((code >> shift) & 1) as u8;
            self.nbits += 1;
            if self.nbits == 8 {
                self.out.push(self.cur);
                self.nbits = 0;
                self.cur = 0;
            }
        }
    }

    /// Flushes any partial byte, zero-filling its unused low bits, and returns the bytes.
    fn finish(mut self) -> Vec<u8> {
        if self.nbits != 0 {
            let padded = self.cur << (8 - self.nbits);
            self.out.push(padded);
        }
        self.out
    }
}
/// Yields single bits MSB-first from a byte slice.
struct BitReader<'a> {
    buf: &'a [u8],
    /// Index of the byte currently being read.
    byte: usize,
    /// Bits of `buf[byte]` already consumed (0..=7).
    bit: u8,
}

impl<'a> BitReader<'a> {
    fn new(buf: &'a [u8]) -> Self {
        BitReader { buf, byte: 0, bit: 0 }
    }

    /// Returns the next bit as 0 or 1, or `None` once every byte has been consumed.
    fn read_bit(&mut self) -> Option<u8> {
        let current = *self.buf.get(self.byte)?;
        let value = (current >> (7 - self.bit)) & 1;
        if self.bit == 7 {
            self.bit = 0;
            self.byte += 1;
        } else {
            self.bit += 1;
        }
        Some(value)
    }
}
/// Compresses `input` into a self-contained stream.
///
/// The stream starts with the input length as a varint. For non-empty input it continues
/// with:
///
/// 1. the run-length-encoded code-length table (see [`encode_lengths`]);
/// 2. every input byte's canonical codeword, packed MSB-first with the final byte
///    zero-padded.
fn encode_stream(input: &[u8]) -> Vec<u8> {
    let mut out = Vec::new();
    write_varint(&mut out, input.len() as u64);
    if input.is_empty() {
        return out;
    }
    let mut freqs = [0u32; 256];
    for &b in input {
        // Saturate instead of overflowing. Above 4 GiB of one byte value, `+= 1` panics in
        // debug builds. In release it wraps, and a count that lands on exactly 0 leaves a
        // present byte without a codeword, silently corrupting the output. A saturated
        // count still yields a valid (if slightly suboptimal) code.
        let slot = &mut freqs[b as usize];
        *slot = slot.saturating_add(1);
    }
    let lengths = length_limited_lengths(&freqs, MAX_CODE_LEN);
    encode_lengths(&lengths, &mut out);
    let codes = canonical_codes(&lengths);
    let mut bw = BitWriter::new();
    for &b in input {
        let s = b as usize;
        bw.write(codes[s], lengths[s]);
    }
    out.extend_from_slice(&bw.finish());
    out
}
fn decode_stream(input: &[u8]) -> Result<Vec<u8>, Error> {
let (orig_len, vlen) = read_varint(input)?;
let orig_len = orig_len as usize;
let mut rest = &input[vlen..];
if orig_len == 0 {
return Ok(Vec::new());
}
let (lengths, consumed) = decode_lengths(rest)?;
rest = &rest[consumed..];
let table = CanonicalTable::from_lengths(&lengths)?;
let mut out = Vec::with_capacity(orig_len);
if let Some(sym) = table.single {
out.resize(orig_len, sym as u8);
return Ok(out);
}
let mut reader = BitReader::new(rest);
let max = table.max_length as u32;
while out.len() < orig_len {
let mut code: u32 = 0;
let mut matched = false;
for length in 1..=max {
let bit = reader.read_bit().ok_or(Error::UnexpectedEnd)? as u32;
code = (code << 1) | bit;
let count = table.counts[length as usize] as u32;
if count > 0 {
let first = table.first_code[length as usize];
if code >= first && code < first + count {
let sym_idx = table.first_idx[length as usize] as u32 + (code - first);
out.push(table.symbols[sym_idx as usize] as u8);
matched = true;
break;
}
}
}
if !matched {
return Err(Error::Corrupt);
}
}
Ok(out)
}
/// Buffering Huffman encoder.
///
/// The code table depends on byte frequencies over the whole input, so `raw_encode` only
/// buffers. The compressed stream is built on the first `raw_finish` call and then handed
/// out across as many calls as the caller's output buffers require.
#[derive(Debug)]
pub struct Encoder {
    /// Everything passed to `raw_encode` so far.
    input: Vec<u8>,
    /// The compressed stream, once finishing has started.
    output: Vec<u8>,
    /// How many bytes of `output` have already been copied out.
    cursor: usize,
    /// Whether `output` has been computed from `input`.
    finalized: bool,
}

impl Encoder {
    /// Creates an encoder with nothing buffered.
    pub const fn new() -> Self {
        Encoder {
            input: Vec::new(),
            output: Vec::new(),
            cursor: 0,
            finalized: false,
        }
    }

    /// Copies as much unsent output as fits into `output` and advances the cursor.
    fn drain(&mut self, output: &mut [u8]) -> RawProgress {
        let pending = &self.output[self.cursor..];
        let written = pending.len().min(output.len());
        output[..written].copy_from_slice(&pending[..written]);
        self.cursor += written;
        RawProgress {
            consumed: 0,
            written,
            done: self.cursor >= self.output.len(),
        }
    }
}

impl Default for Encoder {
    fn default() -> Self {
        Self::new()
    }
}

impl RawEncoder for Encoder {
    /// Buffers all of `input`; nothing is written until finishing.
    fn raw_encode(&mut self, input: &[u8], _output: &mut [u8]) -> Result<RawProgress, Error> {
        self.input.extend_from_slice(input);
        Ok(RawProgress {
            consumed: input.len(),
            written: 0,
            done: false,
        })
    }

    /// Compresses the buffered input on the first call, then copies it out. `done` is
    /// true once every compressed byte has been handed over.
    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
        if !self.finalized {
            self.output = encode_stream(&self.input);
            self.finalized = true;
        }
        Ok(self.drain(output))
    }

    /// Discards buffered input and output so the encoder can start a new stream.
    fn raw_reset(&mut self) {
        self.finalized = false;
        self.cursor = 0;
        self.input.clear();
        self.output.clear();
    }
}
/// Buffering Huffman decoder.
///
/// The whole compressed stream is collected by `raw_decode`. Decoding happens on the first
/// `raw_finish` call, after which the result is handed out across as many calls as the
/// caller's output buffers require.
#[derive(Debug)]
pub struct Decoder {
    /// Compressed bytes received so far.
    input: Vec<u8>,
    /// The decoded data, once decoding has run.
    output: Vec<u8>,
    /// How many bytes of `output` have already been copied out.
    cursor: usize,
    /// Whether `output` has been produced from `input`.
    decoded: bool,
}

impl Decoder {
    /// Creates a decoder with nothing buffered.
    pub const fn new() -> Self {
        Decoder {
            input: Vec::new(),
            output: Vec::new(),
            cursor: 0,
            decoded: false,
        }
    }

    /// Copies as much undelivered output as fits into `output` and advances the cursor.
    fn drain(&mut self, output: &mut [u8]) -> RawProgress {
        let pending = &self.output[self.cursor..];
        let written = pending.len().min(output.len());
        output[..written].copy_from_slice(&pending[..written]);
        self.cursor += written;
        RawProgress {
            consumed: 0,
            written,
            done: self.cursor >= self.output.len(),
        }
    }
}

impl Default for Decoder {
    fn default() -> Self {
        Self::new()
    }
}

impl RawDecoder for Decoder {
    /// Before decoding, buffers all of `input`. Afterwards, ignores `input` and continues
    /// delivering decoded bytes.
    fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.decoded {
            return Ok(self.drain(output));
        }
        self.input.extend_from_slice(input);
        Ok(RawProgress {
            consumed: input.len(),
            written: 0,
            done: false,
        })
    }

    /// Decodes the buffered stream on the first call (propagating any decode error), then
    /// copies decoded bytes out.
    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
        if !self.decoded {
            self.output = decode_stream(&self.input)?;
            self.decoded = true;
        }
        Ok(self.drain(output))
    }

    /// Discards buffered input and output so the decoder can start a new stream.
    fn raw_reset(&mut self) {
        self.decoded = false;
        self.cursor = 0;
        self.input.clear();
        self.output.clear();
    }
}
#[cfg(test)]
mod tests;