/// Number of rows in the construction matrix used by `compute_codes`.
/// NOTE(review): this appears to bound the generated code lengths to 12 bits — confirm.
const MAX_CODE_LEN: usize = 12;
/// First byte written by `write_uncompressed`; `decode` treats a stream that
/// starts with this byte as a raw copy of the input rather than a coded stream.
const UNCOMPRESSED_MARKER: u8 = 0xff;
/// Byte-oriented Huffman codec: accumulates a byte histogram, derives per-byte
/// codes from it, and encodes/decodes byte slices with those codes.
#[derive(Clone)]
pub struct HuffmanCodec {
/// Code bit pattern for each byte value, right-aligned in the low `code_lens[i]` bits.
code_bits: [u16; 256],
/// Code length in bits for each byte value; 0 means the byte has no code of its own.
code_lens: [u8; 256],
/// Decode table indexed by the next 16 bits of the stream, yielding the byte value.
lookup: Box<[u8; 65536]>,
/// Byte value whose code acts as the escape prefix, or -1 when there is no escape.
esc_code: i32,
/// Length in bits of the escape code (0 when there is no escape).
esc_len: i32,
/// Occurrence count for each byte value, filled by `add_to_histogram`/`merge_histogram`.
hist: [u64; 256],
/// Whether the host is big-endian (set from `cfg!(target_endian = "big")`).
is_big_endian: bool,
}
/// Compact `Debug` output: only the escape settings are shown; the histogram
/// and code tables are left out.
impl std::fmt::Debug for HuffmanCodec {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut printer = f.debug_struct("HuffmanCodec");
        printer.field("esc_code", &self.esc_code);
        printer.field("esc_len", &self.esc_len);
        printer.finish()
    }
}
impl HuffmanCodec {
/// Creates an empty codec: zeroed histogram and tables, no escape symbol
/// (`esc_code == -1`), and the endianness flag taken from the compilation target.
pub fn new() -> Self {
    let host_is_big_endian = cfg!(target_endian = "big");
    let lookup = Box::new([0u8; 65536]);
    Self {
        hist: [0; 256],
        code_lens: [0; 256],
        code_bits: [0; 256],
        lookup,
        esc_len: 0,
        esc_code: -1,
        is_big_endian: host_is_big_endian,
    }
}
/// Counts every byte of `data` into the codec's histogram.
pub fn add_to_histogram(&mut self, data: &[u8]) {
    data.iter()
        .for_each(|&byte| self.hist[usize::from(byte)] += 1);
}
/// Adds the byte counts of `other` to this codec's histogram, slot by slot.
pub fn merge_histogram(&mut self, other: &HuffmanCodec) {
    for (mine, theirs) in self.hist.iter_mut().zip(other.hist.iter()) {
        *mine += *theirs;
    }
}
/// Builds the code tables and the decode lookup from the accumulated histogram.
///
/// Every byte value with a non-zero count receives a code. When `partial` is
/// true, the lowest byte value with a zero count also receives a code and
/// becomes the escape symbol: `encode` emits that code followed by the literal
/// 8 bits for any byte that has no code of its own (including the escape byte
/// itself).
///
/// If fewer than two symbols would receive a code, nothing can be built: all
/// tables are cleared and the codec is left without an escape, so `encode`
/// stores any non-empty input uncompressed.
pub fn build(&mut self, partial: bool) {
    // Start from a clean slate so a rebuild never mixes old and new tables,
    // including on the early return below.
    self.code_lens.fill(0);
    self.code_bits.fill(0);
    self.esc_code = -1;
    self.esc_len = 0;

    let mut codes: Vec<u8> = Vec::new();
    let mut esc_byte: Option<u8> = None;
    for byte in 0..=u8::MAX {
        if self.hist[usize::from(byte)] > 0 {
            codes.push(byte);
        } else if partial && esc_byte.is_none() {
            // First unused byte value doubles as the escape symbol.
            esc_byte = Some(byte);
            codes.push(byte);
        }
    }
    if codes.len() <= 1 {
        self.lookup.fill(0);
        return;
    }

    // Stable sort, ascending by count: ties keep byte order so the resulting
    // code assignment is deterministic.
    codes.sort_by_key(|&c| self.hist[usize::from(c)]);
    let (lengths, bits) = compute_codes(&codes, &self.hist);
    for (i, &byte_val) in codes.iter().enumerate() {
        self.code_lens[usize::from(byte_val)] = lengths[i] as u8;
        self.code_bits[usize::from(byte_val)] = bits[i];
    }

    // The lookup table must be built while the escape symbol still carries its
    // code length; otherwise its slots keep the fill value 0 and the decoder
    // only recognises the escape when the escape byte happens to be 0.
    self.build_lookup_table();

    if let Some(esc) = esc_byte {
        let slot = usize::from(esc);
        self.esc_code = i32::from(esc);
        self.esc_len = i32::from(self.code_lens[slot]);
        // A zero length makes `encode` route the escape byte itself through
        // the escape path, like any other byte without a code.
        self.code_lens[slot] = 0;
    }
}
/// Encodes `input` into `output` and returns the number of bits written.
///
/// The coded stream starts with two header bits; bytes without a code of their
/// own are written as the escape code followed by the literal 8 bits. If a byte
/// has no code and the codec has no escape, or if the coded size would exceed
/// the size of `input`, the data is stored via `write_uncompressed` instead and
/// that function's bit count is returned.
///
/// # Panics
///
/// Panics if `output` is too small for the bytes that are written.
pub fn encode(&self, input: &[u8], output: &mut [u8]) -> usize {
let ibits = input.len() * 8;
// Two header bits precede the first code (see `rem` starting at 62 below).
let mut tbits: usize = 2;
// First pass: compute the coded size and bail out to raw storage early.
for &x in input {
let n = self.code_lens[x as usize] as usize;
if n == 0 {
if self.esc_code < 0 {
return write_uncompressed(input, output);
}
// Escaped byte: escape code plus the literal byte.
tbits += self.esc_len as usize + 8;
} else {
tbits += n;
}
if tbits > ibits {
return write_uncompressed(input, output);
}
}
output[..tbits.div_ceil(8)].fill(0);
// Bit 62 of the first word flags a big-endian writer; `decode` tests it as
// `0x40` in the first byte.
let mut ocode: u64 = if self.is_big_endian {
0x4000000000000000
} else {
0
};
// `rem` is the number of still-unused low bits in the current 64-bit word.
let mut rem: i32 = 62;
let mut word_pos: usize = 0;
// Second pass: pack the codes; `ocode_push` writes each completed word.
for &x in input {
let n = self.code_lens[x as usize];
if n == 0 {
let c = self.code_bits[self.esc_code as usize] as u64;
ocode_push(
&mut ocode, &mut rem, output, &mut word_pos,
self.esc_len, c,
);
ocode_push(
&mut ocode, &mut rem, output, &mut word_pos,
8, x as u64,
);
} else {
let c = self.code_bits[x as usize] as u64;
ocode_push(
&mut ocode, &mut rem, output, &mut word_pos,
n as i32, c,
);
}
}
// Flush the used bytes of the last, partially filled word. Both branches
// write `(71 - rem) >> 3` bytes; the little-endian branch walks the native
// bytes from index 7 downwards, i.e. most significant byte first.
let ocode_bytes = ocode.to_ne_bytes();
if self.is_big_endian {
let n = ((71 - rem) >> 3) as usize;
output[word_pos..word_pos + n]
.copy_from_slice(&ocode_bytes[..n]);
} else {
let start = (7 - ((63 - rem) >> 3)) as usize;
for k in (start..=7).rev() {
output[word_pos] = ocode_bytes[k];
word_pos += 1;
}
}
// A full first word was stored in native order; on little-endian the byte
// holding the header bits is at index 7, so move it to the front where
// `decode` inspects it.
if tbits >= 64 && !self.is_big_endian {
output.swap(0, 7);
}
tbits
}
/// Decodes a stream produced by `encode` and returns the number of bytes
/// written to `output`.
///
/// `n_bits` is the bit count returned by `encode`. Decoding stops when the
/// stream is exhausted, when `output` is full, or when the next bits do not map
/// to a symbol with a code. A stream starting with `UNCOMPRESSED_MARKER` is
/// copied verbatim; the copy is truncated to the capacity of `output`, matching
/// the coded path.
///
/// Returns 0 for an empty `input` or `n_bits == 0`.
///
/// # Panics
///
/// Panics if `input` holds fewer bytes than `n_bits` announces.
pub fn decode(&self, n_bits: usize, input: &[u8], output: &mut [u8]) -> usize {
    let max_out = output.len();
    if input.is_empty() || n_bits == 0 {
        return 0;
    }
    if input[0] == UNCOMPRESSED_MARKER {
        // `saturating_sub` guards a bit count below one byte (marker only or
        // corrupt length); `min` keeps the copy within the output capacity.
        let raw_len = (n_bits / 8).saturating_sub(1).min(max_out);
        output[..raw_len].copy_from_slice(&input[1..1 + raw_len]);
        return raw_len;
    }

    // Work on a private copy padded to at least one full word.
    let n_bytes = n_bits.div_ceil(8);
    let mut buf = vec![0u8; n_bytes.max(8)];
    buf[..n_bytes].copy_from_slice(&input[..n_bytes]);

    // Header flag written by the encoder: set when the writer was big-endian.
    let in_big = buf[0] & 0x40 != 0;
    if !in_big && n_bits >= 64 {
        // Undo the encoder's swap that moved the header byte to the front.
        buf.swap(0, 7);
    }
    if in_big != self.is_big_endian {
        // NOTE(review): this also reverses the padded single word of a stream
        // shorter than 64 bits, which the encoder writes most significant byte
        // first on either endianness — confirm cross-endian behaviour.
        for chunk in buf.chunks_exact_mut(8) {
            chunk.reverse();
        }
    }

    // Load the first word into `icode`, most significant bit first.
    let mut icode: u64;
    let mut word_pos: usize;
    if n_bits < 64 {
        icode = 0;
        for (k, &b) in buf.iter().enumerate() {
            if k * 8 >= n_bits {
                break;
            }
            icode |= (b as u64) << (56 - k * 8);
        }
        word_pos = n_bytes;
    } else {
        icode = u64::from_ne_bytes(buf[..8].try_into().unwrap());
        word_pos = 8;
    }

    // Drop the two header bits; `ilen` counts the payload bits still to decode
    // and `rem` the payload bits currently held in `icode`.
    icode <<= 2;
    let mut ilen = n_bits as i64 - 2;
    let mut rem: i64 = ilen.min(62);

    // Pre-load the staging word `ncode` (`nem` valid bits) when the stream
    // continues beyond the first word.
    let mut ncode: u64 = 0;
    let mut nem: i64 = 0;
    if ilen > 62 {
        if ilen - 62 < 64 {
            // Final partial word: stored most significant byte first.
            nem = ilen - 62;
            ncode = 0;
            for k in 0..(nem as usize).div_ceil(8) {
                if word_pos + k < buf.len() {
                    ncode |= (buf[word_pos + k] as u64) << (56 - k * 8);
                }
            }
        } else if word_pos + 8 <= buf.len() {
            ncode = u64::from_ne_bytes(buf[word_pos..word_pos + 8].try_into().unwrap());
            word_pos += 8;
            nem = 64;
        }
    }

    let mut out_idx: usize = 0;
    while ilen > 0 && out_idx < max_out {
        // The top 16 bits of the window select the symbol.
        let c = self.lookup[(icode >> 48) as usize];
        if c as i32 == self.esc_code && self.esc_code >= 0 {
            // Escape: skip the escape code, then take the literal byte.
            icode_get(
                &mut icode, &mut ilen, &mut rem, &mut ncode,
                &mut nem, &buf, &mut word_pos, self.esc_len as i64,
            );
            output[out_idx] = (icode >> 56) as u8;
            icode_get(
                &mut icode, &mut ilen, &mut rem, &mut ncode,
                &mut nem, &buf, &mut word_pos, 8,
            );
        } else {
            let n = self.code_lens[c as usize] as i64;
            if n == 0 {
                // The bits do not correspond to any coded symbol: stop.
                break;
            }
            icode_get(
                &mut icode, &mut ilen, &mut rem, &mut ncode,
                &mut nem, &buf, &mut word_pos, n,
            );
            output[out_idx] = c;
        }
        out_idx += 1;
    }
    out_idx
}
/// Writes the code tables to `out` and returns the number of bytes used.
///
/// Layout: one endianness flag byte, `esc_code` and `esc_len` as native-endian
/// `i32`, then for each of the 256 byte values a length byte followed — when
/// the length is non-zero or the value is the escape symbol — by the two
/// native-endian bytes of its code. The escape symbol is stored with `esc_len`
/// as its length.
///
/// # Panics
///
/// Panics if `out` is too small; `max_serial_size()` bytes always suffice.
pub fn serialise(&self, out: &mut [u8]) -> usize {
    out[0] = u8::from(self.is_big_endian);
    out[1..5].copy_from_slice(&self.esc_code.to_ne_bytes());
    out[5..9].copy_from_slice(&self.esc_len.to_ne_bytes());

    let mut cursor: usize = 9;
    for (symbol, (&own_len, code)) in self
        .code_lens
        .iter()
        .zip(self.code_bits.iter())
        .enumerate()
    {
        let is_escape = symbol as i32 == self.esc_code;
        let stored_len = if is_escape {
            self.esc_len as u8
        } else {
            own_len
        };
        out[cursor] = stored_len;
        cursor += 1;
        if is_escape || stored_len > 0 {
            out[cursor..cursor + 2].copy_from_slice(&code.to_ne_bytes());
            cursor += 2;
        }
    }
    cursor
}
/// Reconstructs a codec from bytes produced by `serialise`.
///
/// The leading flag byte records the writer's endianness; multi-byte fields
/// are decoded accordingly, so tables written on either kind of host can be
/// read. The histogram of the returned codec is empty.
///
/// After the decode lookup is built, the escape symbol's own code length is
/// reset to 0 — the same state `build` leaves behind — so that `encode` on the
/// restored codec escapes a literal escape byte instead of emitting the bare
/// escape code.
///
/// # Panics
///
/// Panics if `input` is truncated or if the stored escape code is 256 or more.
pub fn deserialise(input: &[u8]) -> Self {
    // `new` already records the host endianness in `is_big_endian`.
    let mut codec = Self::new();
    let stored_big = input[0] != 0;

    // Fields were written in the writer's native order, which the flag names.
    let read_i32 = |at: usize| -> i32 {
        let raw = [input[at], input[at + 1], input[at + 2], input[at + 3]];
        if stored_big {
            i32::from_be_bytes(raw)
        } else {
            i32::from_le_bytes(raw)
        }
    };
    codec.esc_code = read_i32(1);
    codec.esc_len = read_i32(5);

    let mut pos: usize = 9;
    for i in 0..256 {
        let len = input[pos];
        codec.code_lens[i] = len;
        pos += 1;
        // Code bits follow only for symbols that have a code (or are the escape).
        if len > 0 || i as i32 == codec.esc_code {
            let raw = [input[pos], input[pos + 1]];
            codec.code_bits[i] = if stored_big {
                u16::from_be_bytes(raw)
            } else {
                u16::from_le_bytes(raw)
            };
            pos += 2;
        }
    }

    let esc_slot = if codec.esc_code >= 0 {
        Some(codec.esc_code as usize)
    } else {
        None
    };
    // The escape code must be present while the lookup table is built ...
    if let Some(slot) = esc_slot {
        codec.code_lens[slot] = codec.esc_len as u8;
    }
    codec.build_lookup_table();
    // ... and absent afterwards, so the encoder treats the escape byte like
    // any other byte without a code of its own.
    if let Some(slot) = esc_slot {
        codec.code_lens[slot] = 0;
    }
    codec
}
/// Upper bound, in bytes, on what `serialise` writes: a 9-byte header plus a
/// length byte and two code bytes for each of the 256 byte values.
pub const fn max_serial_size() -> usize {
    const HEADER_BYTES: usize = 1 + 4 + 4;
    const BYTES_PER_SYMBOL: usize = 1 + 2;
    HEADER_BYTES + 256 * BYTES_PER_SYMBOL
}
/// Rebuilds the 16-bit decode lookup from `code_lens` and `code_bits`.
///
/// Each symbol with a non-zero length `len` claims every table slot whose top
/// `len` bits equal its code. Entries that cannot be represented in the table
/// — a length above 16 bits or a bit pattern wider than its declared length,
/// which can only come from malformed serialised data — are skipped instead of
/// underflowing the shift or indexing past the table.
fn build_lookup_table(&mut self) {
    self.lookup.fill(0);
    for (symbol, (&len, &bits)) in self
        .code_lens
        .iter()
        .zip(self.code_bits.iter())
        .enumerate()
    {
        let len = u32::from(len);
        if len == 0 || len > 16 {
            continue;
        }
        let bits = u32::from(bits);
        if bits >> len != 0 {
            // Pattern does not fit in `len` bits; its slots would lie past the table.
            continue;
        }
        let shift = 16 - len;
        let first = (bits << shift) as usize;
        let count = 1usize << shift;
        self.lookup[first..first + count].fill(symbol as u8);
    }
}
}
impl Default for HuffmanCodec {
fn default() -> Self { Self::new() }
}
/// Appends the low `len` bits of `code` to the 64-bit accumulator `ocode`.
///
/// `rem` is the number of unused low bits left in `ocode`. When the new code
/// fills or overflows the word, the completed word is written to `output` at
/// `word_pos` in native byte order, `word_pos` advances by 8, and the bits that
/// did not fit start the next word.
///
/// # Panics
///
/// Panics if `output` has no room for the 8-byte word at `word_pos`.
fn ocode_push(
    ocode: &mut u64,
    rem: &mut i32,
    output: &mut [u8],
    word_pos: &mut usize,
    len: i32,
    code: u64,
) {
    *rem -= len;
    let left = *rem;

    // Common case: the code fits with room to spare.
    if left > 0 {
        *ocode |= code << left as u32;
        return;
    }

    // The word is complete; only the leading part of `code` belongs to it.
    let overshoot = (-left) as u32;
    *ocode |= code.checked_shr(overshoot).unwrap_or(0);
    let start = *word_pos;
    output[start..start + 8].copy_from_slice(&ocode.to_ne_bytes());
    *word_pos = start + 8;

    if left < 0 {
        // Carry the bits that spilled over into the top of the next word.
        *rem = left + 64;
        *ocode = code << *rem as u32;
    } else {
        *rem = 64;
        *ocode = 0;
    }
}
/// Consumes `n` bits from the decode window and tops the window up again.
///
/// * `icode` — the window; the next undecoded bit is its most significant bit.
/// * `ilen` — payload bits of the stream that are still undecoded.
/// * `rem` — bits of `icode` that currently hold stream data.
/// * `ncode` / `nem` — staging word and the number of bits still held in it,
///   aligned to the top of the word.
/// * `buf` / `word_pos` — stream bytes and the position of the next unread word.
///
/// The window is refilled whenever fewer than 16 bits of it are valid.
///
/// # Panics
///
/// Panics if a full 8-byte word has to be read and `buf` does not contain it.
// The decoder state is passed as individual references, hence the many arguments.
#[allow(clippy::too_many_arguments)]
fn icode_get(
icode: &mut u64,
ilen: &mut i64,
rem: &mut i64,
ncode: &mut u64,
nem: &mut i64,
buf: &[u8],
word_pos: &mut usize,
n: i64,
) {
*ilen -= n;
*icode <<= n as u32;
*rem -= n;
while *rem < 16 {
// Free bit positions in the window below the valid ones.
let z = 64 - *rem;
// Place staged bits right below the valid bits; a shift of 64 or more
// (or a negative `rem` cast to u32) yields None and contributes nothing.
*icode |= ncode.checked_shr(*rem as u32).unwrap_or(0);
if *nem > z {
// The staging word held more than fits: keep the rest for later.
*nem -= z;
*ncode <<= z as u32;
*rem = 64;
break;
} else {
// The staging word was absorbed completely.
*rem += *nem;
if *rem >= *ilen {
// The window now covers everything left in the stream.
break;
} else if *ilen - *rem < 64 {
// Fewer than 64 bits remain: read them byte by byte, most
// significant byte first.
*nem = *ilen - *rem;
*ncode = 0;
for k in 0..(*nem as usize).div_ceil(8) {
if *word_pos + k < buf.len() {
*ncode |= (buf[*word_pos + k] as u64) << (56 - k * 8);
}
}
} else {
// A full word is available: read it in native byte order.
*ncode = u64::from_ne_bytes(
buf[*word_pos..*word_pos + 8].try_into().unwrap(),
);
*word_pos += 8;
*nem = 64;
}
}
}
}
/// Computes a code length and a code bit pattern for every entry of `codes`.
///
/// `codes` lists the byte values to encode and `hist` supplies their counts
/// (a count of 0 is treated as 1). Returns `(lengths, bits)`, both parallel to
/// `codes`; each pattern is right-aligned in its `u16`.
///
/// The only caller in this file, `HuffmanCodec::build`, passes `codes` sorted
/// by ascending count; the merge loop below presumably relies on that order —
/// TODO confirm.
/// NOTE(review): the construction looks like a package-merge style
/// length-limited Huffman build bounded by `MAX_CODE_LEN` — confirm.
fn compute_codes(codes: &[u8], hist: &[u64; 256]) -> (Vec<usize>, Vec<u16>) {
let ncode = codes.len();
// Zero or one symbol: every entry gets length 1 and pattern 1.
if ncode <= 1 {
return (vec![1; ncode], vec![1; ncode]);
}
let dcode = 2 * ncode;
let countb: Vec<u64> = codes.iter().map(|&c| hist[c as usize].max(1)).collect();
let mut leng = vec![0usize; ncode];
{
// matrix[level][n] records, for the n-th item of a level's merged list,
// whether it is an original symbol (1) or a pair taken from the previous
// list (0).
let mut matrix = vec![vec![0u8; dcode]; MAX_CODE_LEN];
// `lcnt` holds the weights of the previous list, `ccnt` those being built.
let mut lcnt = vec![0u64; dcode];
let mut ccnt = vec![0u64; dcode];
lcnt[..ncode].copy_from_slice(&countb);
let mut llen: usize = ncode - 1;
for level in (1..MAX_CODE_LEN).rev() {
let mut j: usize = 0;
let mut k: usize = 0;
let mut n: usize = 0;
// Merge the symbol weights with pairs from the previous list, always
// taking the smaller weight next.
while j < ncode || k < llen {
if k >= llen
|| (j < ncode && countb[j] <= lcnt[k] + lcnt[k + 1])
{
ccnt[n] = countb[j];
matrix[level][n] = 1;
j += 1;
} else {
ccnt[n] = lcnt[k] + lcnt[k + 1];
matrix[level][n] = 0;
k += 2;
}
n += 1;
}
llen = n - 1;
std::mem::swap(&mut lcnt, &mut ccnt);
}
// Walk the rows from index 1 upwards: every symbol flag found within the
// first `span` items of a row adds one bit to that symbol's length; the
// pairs among those items determine how many items of the next row count.
let mut span: usize = 2 * (ncode - 1);
for row in matrix.iter().take(MAX_CODE_LEN).skip(1) {
let mut j: usize = 0;
for &flag in row.iter().take(span) {
if flag != 0 {
leng[j] += 1;
j += 1;
}
}
span = 2 * (span - j);
}
// The first `span` symbols receive one further bit.
for item in leng.iter_mut().take(span) {
*item += 1;
}
}
let mut bits = vec![0u16; ncode];
{
// Assign patterns in order: the first code is all ones at its length; each
// following code drops trailing zero bits, subtracts one, and is then
// extended with one-bits up to its own length.
let mut llen = leng[0] as i32;
let mut lbits: u16 = ((1u32 << leng[0]) - 1) as u16;
bits[0] = lbits;
for n in 1..ncode {
while lbits & 1 == 0 {
lbits >>= 1;
llen -= 1;
}
lbits -= 1;
while llen < leng[n] as i32 {
lbits = (lbits << 1) | 1;
llen += 1;
}
bits[n] = lbits;
}
}
(leng, bits)
}
/// Stores `input` verbatim in `output`, prefixed by `UNCOMPRESSED_MARKER`, and
/// returns the number of bits written (marker byte included).
///
/// # Panics
///
/// Panics if `output` is shorter than `input.len() + 1` bytes.
fn write_uncompressed(input: &[u8], output: &mut [u8]) -> usize {
    let payload_len = input.len();
    output[0] = UNCOMPRESSED_MARKER;
    let payload = &mut output[1..1 + payload_len];
    payload.copy_from_slice(input);
    payload_len * 8 + 8
}
// Unit tests: encode/decode round trips, the raw fallback, table
// serialisation and histogram merging.
#[cfg(test)]
mod tests {
use super::*;
// A small skewed input must compress and decode back to itself.
#[test]
fn round_trip_simple() {
let data = b"AAABBBCCDDDDDDEEEE";
let mut codec = HuffmanCodec::new();
codec.add_to_histogram(data);
codec.build(false);
let mut compressed = vec![0u8; data.len() * 2];
let n_bits = codec.encode(data, &mut compressed);
assert!(n_bits > 0);
assert!(n_bits < data.len() * 8, "should compress");
let mut decompressed = vec![0u8; data.len()];
let n_bytes = codec.decode(n_bits, &compressed, &mut decompressed);
assert_eq!(n_bytes, data.len());
assert_eq!(&decompressed[..n_bytes], &data[..]);
}
// Round trips with alphabets of increasing size, where symbol `i` occurs
// `i + 1` times.
#[test]
fn round_trip_scaling() {
for n_symbols in [5, 10, 20, 50, 100, 200, 256] {
let mut data = Vec::new();
for i in 0..n_symbols as u16 {
for _ in 0..(i as usize + 1) {
data.push(i as u8);
}
}
let mut codec = HuffmanCodec::new();
codec.add_to_histogram(&data);
codec.build(false);
let mut compressed = vec![0u8; data.len() * 2];
let n_bits = codec.encode(&data, &mut compressed);
let mut decompressed = vec![0u8; data.len()];
let n_bytes = codec.decode(n_bits, &compressed, &mut decompressed);
assert_eq!(
n_bytes,
data.len(),
"FAILED with {n_symbols} symbols: decoded {n_bytes}, expected {}",
data.len()
);
assert_eq!(
&decompressed[..n_bytes],
&data[..],
"data mismatch with {n_symbols} symbols"
);
}
}
// A codec built with `partial = true` must round-trip a byte (`X`) that was
// absent from the training data.
#[test]
fn round_trip_partial() {
let training = b"AABBCCDD";
let mut codec = HuffmanCodec::new();
codec.add_to_histogram(training);
codec.build(true);
let test_data = b"AABXCD";
let mut compressed = vec![0u8; test_data.len() * 4];
let n_bits = codec.encode(test_data, &mut compressed);
let mut decompressed = vec![0u8; test_data.len()];
let n_bytes = codec.decode(n_bits, &compressed, &mut decompressed);
assert_eq!(n_bytes, test_data.len());
assert_eq!(&decompressed[..n_bytes], &test_data[..]);
}
// A codec restored from its serialised form must decode what the original
// codec encoded.
#[test]
fn serialise_deserialise_round_trip() {
let data = b"The quick brown fox jumps over the lazy dog";
let mut codec = HuffmanCodec::new();
codec.add_to_histogram(data);
codec.build(true);
let mut serial_buf = vec![0u8; HuffmanCodec::max_serial_size()];
let serial_len = codec.serialise(&mut serial_buf);
assert!(serial_len > 0);
let codec2 = HuffmanCodec::deserialise(&serial_buf[..serial_len]);
let mut compressed = vec![0u8; data.len() * 2];
let n_bits = codec.encode(data, &mut compressed);
let mut decompressed = vec![0u8; data.len()];
let n_bytes = codec2.decode(n_bits, &compressed, &mut decompressed);
assert_eq!(n_bytes, data.len());
assert_eq!(&decompressed[..n_bytes], &data[..]);
}
// Without an escape symbol, a byte that has no code forces the raw fallback.
#[test]
fn uncompressed_fallback() {
let data = b"X";
let mut codec = HuffmanCodec::new();
codec.add_to_histogram(b"AB");
codec.build(false);
let mut compressed = vec![0u8; 16];
let n_bits = codec.encode(data, &mut compressed);
assert_eq!(compressed[0], UNCOMPRESSED_MARKER);
let mut decompressed = vec![0u8; data.len()];
let n_bytes = codec.decode(n_bits, &compressed, &mut decompressed);
assert_eq!(n_bytes, data.len());
assert_eq!(&decompressed[..n_bytes], &data[..]);
}
// Decoding an empty stream yields no bytes.
#[test]
fn empty_input() {
let codec = HuffmanCodec::new();
let mut output = vec![0u8; 16];
let n = codec.decode(0, &[], &mut output);
assert_eq!(n, 0);
}
// Merging adds the counts of the other codec's histogram.
#[test]
fn histogram_merge() {
let mut c1 = HuffmanCodec::new();
c1.add_to_histogram(b"AAAA");
let mut c2 = HuffmanCodec::new();
c2.add_to_histogram(b"BBBB");
c1.merge_histogram(&c2);
assert_eq!(c1.hist[b'A' as usize], 4);
assert_eq!(c1.hist[b'B' as usize], 4);
}
}