use crate::error::{CodecError, CodecResult};
use std::collections::HashMap;
/// Widest code, in bits, that the encoder or decoder will ever switch to.
const MAX_CODE_SIZE: u8 = 12;
/// Number of distinct codes that fit in `MAX_CODE_SIZE` bits; also the table capacity.
const MAX_CODES: usize = 1usize << MAX_CODE_SIZE;
pub struct LzwDecoder {
min_code_size: u8,
code_size: u8,
clear_code: u16,
eoi_code: u16,
next_code: u16,
table: Vec<Vec<u8>>,
prev_code: Option<u16>,
output: Vec<u8>,
}
impl LzwDecoder {
pub fn new(min_code_size: u8) -> CodecResult<Self> {
if !(2..=8).contains(&min_code_size) {
return Err(CodecError::InvalidParameter(format!(
"Invalid LZW min code size: {}",
min_code_size
)));
}
let clear_code = 1 << min_code_size;
let eoi_code = clear_code + 1;
let code_size = min_code_size + 1;
Ok(Self {
min_code_size,
code_size,
clear_code,
eoi_code,
next_code: eoi_code + 1,
table: Vec::with_capacity(MAX_CODES),
prev_code: None,
output: Vec::new(),
})
}
fn init_table(&mut self) {
self.table.clear();
let table_size = 1 << self.min_code_size;
for i in 0..table_size {
self.table.push(vec![i as u8]);
}
self.table.push(Vec::new()); self.table.push(Vec::new());
self.next_code = self.eoi_code + 1;
self.code_size = self.min_code_size + 1;
self.prev_code = None;
}
pub fn decompress(&mut self, data: &[u8], output_size: usize) -> CodecResult<Vec<u8>> {
self.output.clear();
self.output.reserve(output_size);
self.init_table();
let mut bit_reader = BitReader::new(data);
loop {
let code = match bit_reader.read_bits(self.code_size) {
Some(c) => c,
None => break, };
if code == self.clear_code {
self.init_table();
continue;
}
if code == self.eoi_code {
break;
}
if let Some(sequence) = self.get_sequence(code)? {
self.output.extend_from_slice(&sequence);
if let Some(prev) = self.prev_code {
if self.next_code < MAX_CODES as u16 {
let mut new_entry = self.get_sequence(prev)?.ok_or_else(|| {
CodecError::InvalidData(format!("LZW prev_code {} not in table", prev))
})?;
new_entry.push(sequence[0]);
self.table.push(new_entry);
self.next_code += 1;
if self.next_code >= (1 << self.code_size) && self.code_size < MAX_CODE_SIZE
{
self.code_size += 1;
}
}
}
self.prev_code = Some(code);
} else {
return Err(CodecError::InvalidData(format!(
"Invalid LZW code: {}",
code
)));
}
}
Ok(self.output.clone())
}
fn get_sequence(&self, code: u16) -> CodecResult<Option<Vec<u8>>> {
let code_idx = code as usize;
if code_idx < self.table.len() {
Ok(Some(self.table[code_idx].clone()))
} else if code == self.next_code {
if let Some(prev) = self.prev_code {
let mut sequence = self.table[prev as usize].clone();
sequence.push(sequence[0]);
Ok(Some(sequence))
} else {
Ok(None)
}
} else {
Ok(None)
}
}
}
/// Variable-width LZW encoder.
///
/// Writes a leading clear code and then LSB-first packed codes. The code width
/// starts at `min_code_size + 1` bits and grows up to `MAX_CODE_SIZE` bits. A
/// clear code is emitted whenever the table fills, and a trailing
/// end-of-information code ends the stream.
pub struct LzwEncoder {
    /// Width of a literal symbol; every input byte must be `< 1 << min_code_size`.
    min_code_size: u8,
    /// Current width, in bits, of the codes being written.
    code_size: u8,
    /// Code that resets the table (`1 << min_code_size`).
    clear_code: u16,
    /// Code that terminates the stream (`clear_code + 1`).
    eoi_code: u16,
    /// Code the next table entry will receive.
    next_code: u16,
    /// Multi-byte strings, keyed by (code of the string without its last byte, last byte).
    /// Single-byte strings are implicit: byte `b` has code `b`.
    table: HashMap<(u16, u8), u16>,
}

impl LzwEncoder {
    /// Creates an encoder whose literal symbols are `min_code_size` bits wide.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidParameter` unless `min_code_size` is in `2..=8`.
    pub fn new(min_code_size: u8) -> CodecResult<Self> {
        if !(2..=8).contains(&min_code_size) {
            return Err(CodecError::InvalidParameter(format!(
                "Invalid LZW min code size: {}",
                min_code_size
            )));
        }
        let clear_code = 1 << min_code_size;
        let eoi_code = clear_code + 1;
        let code_size = min_code_size + 1;
        Ok(Self {
            min_code_size,
            code_size,
            clear_code,
            eoi_code,
            next_code: eoi_code + 1,
            table: HashMap::with_capacity(MAX_CODES),
        })
    }

    /// Drops all learned strings and restores the initial code width.
    fn init_table(&mut self) {
        self.table.clear();
        self.next_code = self.eoi_code + 1;
        self.code_size = self.min_code_size + 1;
    }

    /// Compresses `data` into a complete code stream.
    ///
    /// The stream starts with a clear code and ends with the end-of-information code.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidData` if any byte is `>= 1 << min_code_size`.
    /// Such a byte has no literal code. Previously it was silently dropped or
    /// mis-encoded.
    pub fn compress(&mut self, data: &[u8]) -> CodecResult<Vec<u8>> {
        self.init_table();
        let mut bit_writer = BitWriter::new();
        bit_writer.write_bits(self.clear_code.into(), self.code_size);

        // Code of the longest string matched so far; `None` only before the first byte.
        let mut current: Option<u16> = None;
        for &byte in data {
            let symbol = u16::from(byte);
            if symbol >= self.clear_code {
                return Err(CodecError::InvalidData(format!(
                    "LZW input byte {} exceeds symbol range for min code size {}",
                    byte, self.min_code_size
                )));
            }
            let prefix = match current {
                Some(prefix) => prefix,
                None => {
                    current = Some(symbol);
                    continue;
                }
            };
            if let Some(&code) = self.table.get(&(prefix, byte)) {
                current = Some(code);
                continue;
            }
            bit_writer.write_bits(prefix.into(), self.code_size);
            self.add_entry(prefix, byte);
            current = Some(symbol);
            if usize::from(self.next_code) >= MAX_CODES {
                bit_writer.write_bits(self.clear_code.into(), self.code_size);
                self.init_table();
            }
        }
        if let Some(code) = current {
            bit_writer.write_bits(code.into(), self.code_size);
        }
        bit_writer.write_bits(self.eoi_code.into(), self.code_size);
        Ok(bit_writer.finish())
    }

    /// Assigns `next_code` to the string `prefix + byte` if the table has room.
    fn add_entry(&mut self, prefix: u16, byte: u8) {
        if usize::from(self.next_code) < MAX_CODES {
            self.table.insert((prefix, byte), self.next_code);
            self.next_code += 1;
            // Widen once the largest assigned code (`next_code - 1`) no longer fits.
            if self.next_code > (1 << self.code_size) && self.code_size < MAX_CODE_SIZE {
                self.code_size += 1;
            }
        }
    }
}
/// Pulls LSB-first bit fields out of a borrowed byte slice.
struct BitReader<'a> {
    /// Bytes being read.
    data: &'a [u8],
    /// Index of the byte currently being consumed.
    byte_pos: usize,
    /// How many low-order bits of `data[byte_pos]` have already been consumed.
    bit_pos: u8,
}

impl<'a> BitReader<'a> {
    /// Positions the reader on the first bit of `data`.
    fn new(data: &'a [u8]) -> Self {
        BitReader {
            bit_pos: 0,
            byte_pos: 0,
            data,
        }
    }

    /// Reads the next `n` bits as an unsigned value, least significant bit first.
    ///
    /// Returns `None` when `n > 16`, in which case nothing is consumed. Also
    /// returns `None` when the input ends before `n` bits were read; any bits
    /// that were available are still consumed. `n == 0` yields `Some(0)`.
    fn read_bits(&mut self, n: u8) -> Option<u16> {
        if n > 16 {
            return None;
        }
        let mut value: u16 = 0;
        let mut filled: u8 = 0;
        while filled < n {
            let current = *self.data.get(self.byte_pos)?;
            let take = (8 - self.bit_pos).min(n - filled);
            let chunk = (u16::from(current) >> self.bit_pos) & ((1u16 << take) - 1);
            value |= chunk << filled;
            filled += take;
            self.bit_pos += take;
            if self.bit_pos == 8 {
                self.byte_pos += 1;
                self.bit_pos = 0;
            }
        }
        Some(value)
    }
}
/// Packs bit fields LSB-first into a growable byte buffer.
struct BitWriter {
    /// Bytes that have been completely filled.
    data: Vec<u8>,
    /// Partially filled byte awaiting more bits.
    current_byte: u8,
    /// Number of low-order bits of `current_byte` already used.
    bit_pos: u8,
}

impl BitWriter {
    /// Creates an empty writer.
    fn new() -> Self {
        BitWriter {
            bit_pos: 0,
            current_byte: 0,
            data: Vec::new(),
        }
    }

    /// Appends the low `n` bits of `value`, least significant bit first.
    fn write_bits(&mut self, value: u32, n: u8) {
        let mut pending = value;
        let mut left = n;
        while left > 0 {
            let step = (8 - self.bit_pos).min(left);
            let piece = (pending & ((1u32 << step) - 1)) as u8;
            self.current_byte |= piece << self.bit_pos;
            pending >>= step;
            left -= step;
            self.bit_pos += step;
            if self.bit_pos == 8 {
                self.data.push(self.current_byte);
                self.bit_pos = 0;
                self.current_byte = 0;
            }
        }
    }

    /// Flushes any partially filled byte (zero-padded in its high bits) and returns the buffer.
    fn finish(self) -> Vec<u8> {
        let BitWriter {
            mut data,
            current_byte,
            bit_pos,
        } = self;
        if bit_pos != 0 {
            data.push(current_byte);
        }
        data
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compresses `original`, decompresses the result and returns `(compressed, decompressed)`.
    fn roundtrip(min_code_size: u8, original: &[u8]) -> (Vec<u8>, Vec<u8>) {
        let mut encoder = LzwEncoder::new(min_code_size).expect("should succeed");
        let mut decoder = LzwDecoder::new(min_code_size).expect("should succeed");
        let compressed = encoder.compress(original).expect("should succeed");
        let decompressed = decoder
            .decompress(&compressed, original.len())
            .expect("should succeed");
        (compressed, decompressed)
    }

    #[test]
    fn test_lzw_roundtrip() {
        let original = b"TOBEORNOTTOBEORTOBEORNOT";
        let (_, decompressed) = roundtrip(8, original);
        assert_eq!(original, decompressed.as_slice());
    }

    #[test]
    fn test_lzw_simple() {
        let original = vec![0u8, 1, 2, 3, 0, 1, 2, 3];
        let (_, decompressed) = roundtrip(2, &original);
        assert_eq!(original, decompressed);
    }

    #[test]
    fn test_lzw_repeated_pattern() {
        let original = vec![1u8; 1000];
        let (compressed, decompressed) = roundtrip(8, &original);
        assert_eq!(original, decompressed);
        assert!(compressed.len() < original.len());
    }

    #[test]
    fn test_lzw_empty_input() {
        let (_, decompressed) = roundtrip(8, &[]);
        assert!(decompressed.is_empty());
    }

    #[test]
    fn test_lzw_table_reset_roundtrip() {
        // Pseudo-random bytes rarely repeat, so nearly every byte emits a code and
        // the 4096-entry table fills (triggering a clear code) repeatedly.
        let mut state: u32 = 0x1234_5678;
        let original: Vec<u8> = (0..20_000)
            .map(|_| {
                state ^= state << 13;
                state ^= state >> 17;
                state ^= state << 5;
                (state >> 24) as u8
            })
            .collect();
        let (_, decompressed) = roundtrip(8, &original);
        assert_eq!(original, decompressed);
    }

    #[test]
    fn test_lzw_small_alphabet_long_input() {
        let original: Vec<u8> = (0..10_000u32).map(|i| ((i * 7 + i / 5) % 4) as u8).collect();
        let (_, decompressed) = roundtrip(2, &original);
        assert_eq!(original, decompressed);
    }

    #[test]
    fn test_decoder_rejects_unknown_code() {
        // With min_code_size 2: literals 0..=3, clear 4, EOI 5, next free code 6.
        // Code 7 is none of these, so decoding must fail.
        let mut writer = BitWriter::new();
        writer.write_bits(7, 3);
        let data = writer.finish();
        let mut decoder = LzwDecoder::new(2).expect("should succeed");
        assert!(decoder.decompress(&data, 0).is_err());
    }

    #[test]
    fn test_bit_reader_writer() {
        let mut writer = BitWriter::new();
        writer.write_bits(0b101, 3);
        writer.write_bits(0b110, 3);
        writer.write_bits(0b1111, 4);
        let data = writer.finish();
        let mut reader = BitReader::new(&data);
        assert_eq!(reader.read_bits(3), Some(0b101));
        assert_eq!(reader.read_bits(3), Some(0b110));
        assert_eq!(reader.read_bits(4), Some(0b1111));
    }

    #[test]
    fn test_invalid_min_code_size() {
        assert!(LzwEncoder::new(1).is_err());
        assert!(LzwEncoder::new(9).is_err());
        assert!(LzwDecoder::new(1).is_err());
        assert!(LzwDecoder::new(9).is_err());
    }
}