use crate::bitstream::{BitReader, BitWriter};
use crate::error::CdpError;
use crate::huffman::AdaptiveHuffmanTree;
/// Length-code table for match symbols, indexed by `symbol - 256`.
///
/// Each entry is `(base, extra_bits)`. The decoder reads `extra_bits` raw bits
/// from the stream and rebuilds the match length as `base + extra + 3`, so the
/// 16 codes together cover lengths 3 through 259 (`129 + 127 + 3`) without gaps.
pub const LENGTH_TABLE: [(u32, u32); 16] = [
(0, 0),
(1, 0),
(2, 0),
(3, 0),
(4, 0),
(5, 1),
(7, 1),
(9, 2),
(13, 2),
(17, 3),
(25, 3),
(33, 4),
(49, 4),
(65, 5),
(97, 5),
(129, 7),
];
/// Literal/length symbol that marks the end of the compressed stream.
const EOF_SYMBOL: u32 = 272;
/// Size of the literal/length alphabet: 256 literal bytes (0..=255), 16 match
/// length codes (256..=271) and the end-of-stream symbol (272).
const NUM_LIT_SYMBOLS: usize = 273;
/// Looks up the distance-coding parameters for a compression mode.
///
/// Returns `(dist_bits_total, num_dist_symbols, dist_extra_bits)`: the total
/// number of bits in a distance, the size of the distance Huffman alphabet, and
/// how many low distance bits are stored raw after the Huffman symbol.
///
/// # Errors
///
/// Returns [`CdpError::InvalidMode`] for any mode other than 0, 1 or 2.
pub fn mode_params(mode: u8) -> Result<(u32, usize, u32), CdpError> {
    // One row per supported mode, indexed by the mode number itself.
    const PARAMS: [(u32, usize, u32); 3] = [(8, 16, 4), (12, 64, 6), (14, 64, 8)];

    PARAMS
        .get(usize::from(mode))
        .copied()
        .ok_or(CdpError::InvalidMode(mode))
}
/// Decompresses a CDP stream into at most `uncomp_size` bytes.
///
/// The first byte of `comp_data` is a header: bits 4-5 select the mode (see
/// [`mode_params`]) and the low nibble is the level. Level 0 means the payload
/// is stored verbatim; any other level means the payload is a bit stream of
/// adaptive-Huffman coded literals and `(length, distance)` matches terminated
/// by `EOF_SYMBOL`.
///
/// Decoding stops at the end-of-stream symbol or once `uncomp_size` bytes have
/// been produced, whichever comes first. A match that reaches back before the
/// start of the output yields zero bytes for the missing history, and an
/// encoded distance of 0 is treated as 1.
///
/// # Errors
///
/// * [`CdpError::TruncatedInput`] if `comp_data` is empty.
/// * The error from [`mode_params`] if the header names an unsupported mode.
/// * [`CdpError::DecompressFailed`] if a symbol above `EOF_SYMBOL` is decoded.
/// * Any error propagated from the Huffman decoders.
pub fn decompress(comp_data: &[u8], uncomp_size: usize) -> Result<Vec<u8>, CdpError> {
    // `uncomp_size` is caller-supplied, so it only serves as a bounded hint for
    // the initial reservation; the vector still grows on demand past this cap.
    const MAX_PREALLOC: usize = 1 << 24;

    let (&first_byte, payload) = comp_data.split_first().ok_or(CdpError::TruncatedInput)?;
    let mode = (first_byte >> 4) & 3;
    let level = first_byte & 0xf;

    if level == 0 {
        // Stored payload. Clamping against the payload length (instead of
        // computing `1 + uncomp_size`) cannot overflow for any `uncomp_size`.
        let take = uncomp_size.min(payload.len());
        return Ok(payload[..take].to_vec());
    }

    let (dist_bits_total, num_dist_symbols, dist_extra_bits) = mode_params(mode)?;
    let dist_mask = (1u32 << dist_bits_total) - 1;
    let mut lit_tree = AdaptiveHuffmanTree::new(NUM_LIT_SYMBOLS);
    let mut dist_tree = AdaptiveHuffmanTree::new(num_dist_symbols);
    let mut br = BitReader::new(payload);
    let mut output: Vec<u8> = Vec::with_capacity(uncomp_size.min(MAX_PREALLOC));

    while output.len() < uncomp_size {
        let sym = lit_tree.decode(&mut br)?;
        if sym == EOF_SYMBOL {
            break;
        }
        if sym > EOF_SYMBOL {
            return Err(CdpError::DecompressFailed(format!("invalid symbol {sym}")));
        }
        if sym < 256 {
            output.push(sym as u8);
            continue;
        }

        // Match: length code (plus optional raw bits) first, then the distance.
        let (base, extra) = LENGTH_TABLE[(sym - 256) as usize];
        let extra_val = if extra > 0 { br.read_bits(extra) } else { 0 };
        let match_len = (base + extra_val + 3) as usize;

        let dist_sym = dist_tree.decode(&mut br)?;
        let dist_extra = br.read_bits(dist_extra_bits);
        let raw_distance = ((dist_sym << dist_extra_bits) | dist_extra) & dist_mask;
        // A distance of 0 is not meaningful; it is read as "previous byte".
        let distance = raw_distance.max(1) as usize;

        // Never write past the requested size, even if the match is longer.
        let room = uncomp_size - output.len();
        for _ in 0..match_len.min(room) {
            // Byte-at-a-time so overlapping matches (distance < length) repeat
            // freshly written bytes; history before the start reads as zero.
            let byte = match output.len().checked_sub(distance) {
                Some(src) => output[src],
                None => 0,
            };
            output.push(byte);
        }
    }

    Ok(output)
}
/// Longest match the encoder searches for; equals the largest length that
/// `LENGTH_TABLE` can represent (`129 + 127 + 3`).
const MAX_MATCH_LEN: usize = 259;
/// Shortest match the match finder reports; shorter runs are sent as literals.
const MIN_MATCH_LEN: usize = 3;
/// Number of bits in a hash-table bucket index.
const HASH_BITS: usize = 13;
/// Number of buckets in the match finder's hash table.
const HASH_SIZE: usize = 1 << HASH_BITS;
/// Maximum number of hash-chain candidates examined per input position.
const MAX_CHAIN: usize = 128;
/// Maps a match length to its length code.
///
/// Returns `(code, extra_bits, extra_value)`: the index into `LENGTH_TABLE`
/// (added to 256 to form the literal/length symbol), the number of raw bits
/// that follow the symbol, and the value to store in those bits.
///
/// `match_len` must be at least 3; smaller values underflow the subtraction.
fn encode_length(match_len: usize) -> (u32, u32, u32) {
    let biased = (match_len - 3) as u32;

    // The table is sorted by base, so the code is the last entry whose base
    // does not exceed the biased length. Entry 0 has base 0 and always qualifies.
    let code = LENGTH_TABLE
        .iter()
        .rposition(|&(base, _)| base <= biased)
        .unwrap_or(0);

    let (base, extra_bits) = LENGTH_TABLE[code];
    (code as u32, extra_bits, biased - base)
}
/// Splits a match distance into its Huffman symbol and raw low bits.
///
/// Returns `(dist_sym, extra)`, where `dist_sym` is the distance shifted right
/// by `dist_extra_bits` and `extra` holds the low `dist_extra_bits` bits.
/// The distance is truncated to 32 bits before splitting.
fn encode_distance(distance: usize, dist_extra_bits: u32) -> (u32, u32) {
    let value = distance as u32;
    (
        value >> dist_extra_bits,
        value & ((1u32 << dist_extra_bits) - 1),
    )
}
/// Hash-chain index over previously seen input positions, used by the encoder
/// to look up earlier occurrences of the next three bytes.
struct MatchFinder {
// Most recently inserted position for each hash bucket; `u32::MAX` means empty.
head: Vec<u32>,
// For the position stored in slot `pos & window_mask`: the bucket's previous
// head at the time that position was inserted (`u32::MAX` ends the chain).
prev: Vec<u32>,
// `window_size - 1`; used to fold a position into a `prev` slot index.
window_mask: usize,
}
impl MatchFinder {
    /// Creates an empty match finder whose chain ring holds `window_size` positions.
    ///
    /// `window_size` is expected to be a non-zero power of two (ring slots are
    /// selected with `pos & (window_size - 1)`) and larger than the `max_dist`
    /// later passed to `find_best`.
    fn new(window_size: usize) -> Self {
        Self {
            head: vec![u32::MAX; HASH_SIZE],
            prev: vec![u32::MAX; window_size],
            window_mask: window_size - 1,
        }
    }

    /// Hashes the three bytes starting at `pos` into a bucket index below `HASH_SIZE`.
    ///
    /// Returns 0 when fewer than three bytes remain at `pos`.
    fn hash3(data: &[u8], pos: usize) -> usize {
        if pos + 2 >= data.len() {
            return 0;
        }
        let mixed = usize::from(data[pos])
            ^ (usize::from(data[pos + 1]) << 4)
            ^ (usize::from(data[pos + 2]) << 8);
        mixed & (HASH_SIZE - 1)
    }

    /// Records `pos` as the newest occurrence of its three-byte prefix.
    ///
    /// Positions with fewer than three bytes remaining are ignored. Positions
    /// are stored as `u32`, so inputs are assumed to be smaller than 4 GiB.
    fn insert(&mut self, data: &[u8], pos: usize) {
        if pos + 2 >= data.len() {
            return;
        }
        let bucket = Self::hash3(data, pos);
        // Link the new position in front of the bucket's current chain.
        self.prev[pos & self.window_mask] = self.head[bucket];
        self.head[bucket] = pos as u32;
    }

    /// Finds the longest earlier match for the bytes at `pos`.
    ///
    /// Examines at most `MAX_CHAIN` candidates from the hash chain of the three
    /// bytes at `pos` and returns `Some((distance, length))` for the longest
    /// match of at least `MIN_MATCH_LEN` bytes within `max_dist`, preferring the
    /// nearest candidate on ties. The length is capped at `MAX_MATCH_LEN` and at
    /// the end of `data`. Returns `None` if no such match exists.
    ///
    /// Assumes every position before `pos` was inserted in increasing order (as
    /// `compress` does), so chains run from newest to oldest.
    fn find_best(&self, data: &[u8], pos: usize, max_dist: usize) -> Option<(usize, usize)> {
        if pos + 2 >= data.len() {
            return None;
        }

        let max_len = MAX_MATCH_LEN.min(data.len() - pos);
        let mut best_len = MIN_MATCH_LEN - 1;
        let mut best_dist = 0;
        let mut candidate = self.head[Self::hash3(data, pos)];

        for _ in 0..MAX_CHAIN {
            if candidate == u32::MAX {
                break;
            }
            let mp = candidate as usize;
            let dist = pos - mp;
            if dist > max_dist {
                // Chains run newest-first, so everything behind this link is
                // older still, or sits in a ring slot that has since been reused
                // by a newer position. Neither can yield a usable match, so stop
                // instead of spending the chain budget on stale links.
                break;
            }
            if dist != 0 {
                let len = data[pos..]
                    .iter()
                    .zip(&data[mp..])
                    .take(max_len)
                    .take_while(|(a, b)| a == b)
                    .count();
                if len > best_len {
                    best_len = len;
                    best_dist = dist;
                    if len == max_len {
                        // Cannot do better than the cap.
                        break;
                    }
                }
            }
            candidate = self.prev[mp & self.window_mask];
        }

        if best_len >= MIN_MATCH_LEN {
            Some((best_dist, best_len))
        } else {
            None
        }
    }
}
/// Compresses `data` into a CDP stream.
///
/// The output starts with a header byte holding `mode` in the high nibble and
/// `level` in the low nibble. With `level == 0` the input is stored verbatim
/// after the header. Any other level produces a bit stream of adaptive-Huffman
/// coded literals and `(length, distance)` matches, terminated by `EOF_SYMBOL`.
/// The level value itself does not change the search effort.
///
/// Levels above 15 do not fit the 4-bit header field and are clamped to 15 so
/// they cannot spill into the mode bits.
///
/// # Errors
///
/// Returns the error from [`mode_params`] when `level != 0` and `mode` is not a
/// supported mode. This is checked for empty input as well, so a header that
/// `decompress` would reject is never emitted.
pub fn compress(data: &[u8], mode: u8, level: u8) -> Result<Vec<u8>, CdpError> {
    if level == 0 {
        // Stored: header with a zero level nibble, then the raw bytes.
        let mut out = Vec::with_capacity(data.len() + 1);
        out.push(mode << 4);
        out.extend_from_slice(data);
        return Ok(out);
    }

    let (dist_bits_total, num_dist_symbols, dist_extra_bits) = mode_params(mode)?;
    let header = (mode << 4) | level.min(0x0f);

    // Distance 0 is not encodable, so the farthest match is 2^bits - 1 back.
    let max_dist = ((1u32 << dist_bits_total) - 1) as usize;
    let window_size = max_dist.next_power_of_two();

    let mut lit_tree = AdaptiveHuffmanTree::new(NUM_LIT_SYMBOLS);
    let mut dist_tree = AdaptiveHuffmanTree::new(num_dist_symbols);
    let mut bw = BitWriter::new();
    let mut mf = MatchFinder::new(window_size);

    let mut pos = 0;
    while pos < data.len() {
        // Search before inserting so a position never matches itself.
        let best = mf.find_best(data, pos, max_dist);
        mf.insert(data, pos);

        match best {
            Some((distance, length)) => {
                let (len_code, extra_bits, extra_val) = encode_length(length);
                lit_tree.encode(&mut bw, 256 + len_code);
                if extra_bits > 0 {
                    bw.write_bits(extra_val, extra_bits);
                }

                let (dist_sym, dist_extra) = encode_distance(distance, dist_extra_bits);
                dist_tree.encode(&mut bw, dist_sym);
                bw.write_bits(dist_extra, dist_extra_bits);

                // Index the bytes covered by the match so later searches see them.
                for covered in pos + 1..pos + length {
                    mf.insert(data, covered);
                }
                pos += length;
            }
            None => {
                lit_tree.encode(&mut bw, u32::from(data[pos]));
                pos += 1;
            }
        }
    }
    lit_tree.encode(&mut bw, EOF_SYMBOL);

    let mut out = vec![header];
    out.extend_from_slice(&bw.finish());
    Ok(out)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decompresses `packed` to `original.len()` bytes and checks the result
    /// reproduces `original`.
    fn assert_restores(packed: &[u8], original: &[u8]) {
        let restored = decompress(packed, original.len()).unwrap();
        assert_eq!(restored, original);
    }

    #[test]
    fn round_trip_hello() {
        let data = b"hello world hello hello world";
        let packed = compress(data, 2, 9).unwrap();
        assert_restores(&packed, data);
    }

    #[test]
    fn round_trip_empty() {
        let data = b"";
        let packed = compress(data, 2, 9).unwrap();
        assert_restores(&packed, data);
    }

    #[test]
    fn round_trip_single_byte() {
        let data = b"\x42";
        let packed = compress(data, 2, 9).unwrap();
        assert_restores(&packed, data);
    }

    #[test]
    fn round_trip_all_zeros() {
        let data = vec![0u8; 4096];
        let packed = compress(&data, 2, 9).unwrap();
        assert!(packed.len() < data.len() / 2, "should compress well");
        assert_restores(&packed, &data);
    }

    #[test]
    fn round_trip_incompressible() {
        let data: Vec<u8> = (0..1000).map(|i| ((i * 137 + 43) % 256) as u8).collect();
        let packed = compress(&data, 2, 9).unwrap();
        assert_restores(&packed, &data);
    }

    #[test]
    fn round_trip_across_rebuild() {
        let data: Vec<u8> = (0..40_000).map(|i| ((i * 97 + 31) % 256) as u8).collect();
        let packed = compress(&data, 2, 9).unwrap();
        assert_restores(&packed, &data);
    }

    #[test]
    fn round_trip_store_mode() {
        let data = b"stored uncompressed";
        let packed = compress(data, 2, 0).unwrap();
        // Mode 2 in the high nibble, level 0 in the low nibble.
        assert_eq!(packed[0], 0x20);
        assert_restores(&packed, data);
    }

    #[test]
    fn encode_length_table() {
        assert_eq!(encode_length(3), (0, 0, 0));
        assert_eq!(encode_length(4), (1, 0, 0));
        assert_eq!(encode_length(8), (5, 1, 0));
        assert_eq!(encode_length(9), (5, 1, 1));
    }
}