use gamut_bitstream::BitWriter;
use gamut_core::{Error, Result};
const CLEAR: u32 = 256;
const EOI: u32 = 257;
const FIRST: u32 = 258;
const MAX_WIDTH: u32 = 12;
const RESET_AT: u32 = 4094;
#[must_use]
pub fn encode(data: &[u8]) -> Vec<u8> {
use std::collections::HashMap;
let mut out = BitWriter::new();
let mut width = 9u32;
out.put_bits(CLEAR, width);
let Some((&first, rest)) = data.split_first() else {
out.put_bits(EOI, width);
return out.into_bytes();
};
let mut table: HashMap<(u32, u8), u32> = HashMap::new();
let mut next_code = FIRST;
let mut omega = u32::from(first);
for &k in rest {
if let Some(&code) = table.get(&(omega, k)) {
omega = code;
} else {
out.put_bits(omega, width);
table.insert((omega, k), next_code);
next_code += 1;
if next_code == (1 << width) && width < MAX_WIDTH {
width += 1;
}
if next_code == RESET_AT {
out.put_bits(CLEAR, width);
table.clear();
next_code = FIRST;
width = 9;
}
omega = u32::from(k);
}
}
out.put_bits(omega, width);
out.put_bits(EOI, width);
out.into_bytes()
}
struct BitReader<'a> {
data: &'a [u8],
pos: usize,
}
impl BitReader<'_> {
fn read(&mut self, n: u32) -> Option<u32> {
let mut value = 0u32;
for _ in 0..n {
let byte = *self.data.get(self.pos / 8)?;
value = (value << 1) | u32::from((byte >> (7 - (self.pos % 8))) & 1);
self.pos += 1;
}
Some(value)
}
}
fn init_table() -> Vec<Vec<u8>> {
let mut table: Vec<Vec<u8>> = (0..=255u32).map(|b| vec![b as u8]).collect();
table.push(Vec::new()); table.push(Vec::new()); table
}
pub fn decode(data: &[u8], expected: usize) -> Result<Vec<u8>> {
let mut reader = BitReader { data, pos: 0 };
let mut out = Vec::with_capacity(expected.min(1 << 16));
let mut table = init_table();
let mut width = 9u32;
let mut prev: Option<u32> = None;
while let Some(code) = reader.read(width) {
if code == EOI {
break;
}
if code == CLEAR {
table = init_table();
width = 9;
prev = None;
continue;
}
let entry = if (code as usize) < table.len() {
table[code as usize].clone()
} else if code as usize == table.len() {
let p = prev.ok_or(Error::InvalidInput("LZW: code before ClearCode"))? as usize;
let mut s = table[p].clone();
s.push(table[p][0]);
s
} else {
return Err(Error::InvalidInput("LZW: code out of range"));
};
out.extend_from_slice(&entry);
if let Some(p) = prev {
let mut s = table[p as usize].clone();
s.push(entry[0]);
table.push(s);
if table.len() == ((1 << width) - 1) as usize && width < MAX_WIDTH {
width += 1;
}
}
prev = Some(code);
}
if out.len() < expected {
return Err(Error::InvalidInput(
"LZW: decoded fewer bytes than expected",
));
}
out.truncate(expected);
Ok(out)
}
#[cfg(test)]
mod tests {
use super::*;
fn roundtrip(data: &[u8]) {
let enc = encode(data);
let dec = decode(&enc, data.len()).expect("decode");
assert_eq!(dec, data);
}
#[test]
fn roundtrips_varied_inputs() {
roundtrip(&[]);
roundtrip(&[42]);
roundtrip(&[1, 2, 3, 4, 5]);
roundtrip(&[7, 7, 7, 7, 7, 7, 7, 7]);
roundtrip(b"TOBEORNOTTOBEORTOBEORNOT");
let big: Vec<u8> = (0..20000u32)
.map(|i| (i.wrapping_mul(2654435761) >> 13) as u8)
.collect();
roundtrip(&big);
roundtrip(&vec![0xABu8; 10000]);
}
}