use alloc::vec::Vec;
/// Shortest match emitted as a back-reference; shorter runs are stored as literals.
const MIN_MATCH: usize = 3;
/// Width in bits of the `length - MIN_MATCH` field of a packed back-reference token.
const LENGTH_BITS: usize = 4;
/// Width in bits of the offset field of a packed back-reference token.
const WINDOW_BITS: usize = 12;
/// Number of distinct offset values; the largest usable offset is `WINDOW_SIZE - 1`.
const WINDOW_SIZE: usize = 1 << WINDOW_BITS;
/// Longest encodable match: `MIN_MATCH` plus the largest `LENGTH_BITS`-bit value.
const MAX_MATCH: usize = MIN_MATCH + ((1 << LENGTH_BITS) - 1);
/// Reasons [`decompress`] can reject a compressed stream.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum LzssError {
    /// The input ended early: the 4-byte length header is incomplete, or the
    /// payload ran out before the declared number of bytes was produced.
    Truncated,
    /// A back-reference had offset 0 or pointed before the start of the output.
    InvalidReference {
        /// Index into the compressed input of the first byte of the 2-byte token.
        at: usize,
    },
    /// A back-reference would have produced more bytes than the header declared.
    OutputOverrun,
}
/// Compresses `input` with LZSS and returns the encoded stream.
///
/// The stream begins with `input.len()` as a 4-byte little-endian `u32`, followed
/// by frames of one flag byte and up to eight tokens. Bit `k` of the flag byte
/// (least-significant first) describes token `k`:
/// - `1`: a single literal byte;
/// - `0`: a 2-byte little-endian back-reference packed as
///   `offset << LENGTH_BITS | (length - MIN_MATCH)`.
///
/// Candidate matches come from `find_longest_match`; anything shorter than
/// `MIN_MATCH` is emitted as a literal instead.
///
/// # Panics
///
/// Panics if `input.len()` exceeds `u32::MAX`, because the length header cannot
/// represent it and [`decompress`] would otherwise return a truncated result.
#[must_use]
pub fn compress(input: &[u8]) -> Vec<u8> {
    // Refuse rather than clamp: a clamped header makes the decoder stop early and
    // report success with a silently truncated output.
    let orig_len = u32::try_from(input.len())
        .expect("LZSS input must not exceed u32::MAX bytes (length header is a u32)");
    let mut out = Vec::with_capacity(input.len() / 2 + 4);
    out.extend_from_slice(&orig_len.to_le_bytes());
    if input.is_empty() {
        return out;
    }

    let mut i: usize = 0;
    let mut flag_byte: u8 = 0;
    // At most eight tokens per frame, each at most 2 bytes.
    let mut frame: Vec<u8> = Vec::with_capacity(8 * 2);
    let mut tokens_in_frame: usize = 0;
    while i < input.len() {
        let (match_off, match_len) = find_longest_match(input, i);
        if match_len >= MIN_MATCH {
            // The match search keeps offsets below 2^12 and lengths at most
            // MAX_MATCH, so both fields fit the 16-bit token.
            let packed: u16 = u16::try_from(match_off).expect("offset < 4096") << LENGTH_BITS
                | u16::try_from(match_len - MIN_MATCH).expect("length < 16");
            frame.extend_from_slice(&packed.to_le_bytes());
            i += match_len;
        } else {
            // A set flag bit marks a literal token.
            flag_byte |= 1u8 << tokens_in_frame;
            frame.push(input[i]);
            i += 1;
        }
        tokens_in_frame += 1;
        if tokens_in_frame == 8 {
            out.push(flag_byte);
            out.extend_from_slice(&frame);
            flag_byte = 0;
            frame.clear();
            tokens_in_frame = 0;
        }
    }
    // Flush a trailing, partially filled frame. Its unused flag bits stay 0; the
    // decoder never reads them because it stops once the header length is reached.
    if tokens_in_frame > 0 {
        out.push(flag_byte);
        out.extend_from_slice(&frame);
    }
    out
}
pub fn decompress(input: &[u8]) -> Result<Vec<u8>, LzssError> {
if input.len() < 4 {
return Err(LzssError::Truncated);
}
let orig_len = u32::from_le_bytes([input[0], input[1], input[2], input[3]]) as usize;
let mut out: Vec<u8> = Vec::with_capacity(orig_len);
let mut p: usize = 4;
while out.len() < orig_len {
if p >= input.len() {
return Err(LzssError::Truncated);
}
let flag = input[p];
p += 1;
for bit in 0..8 {
if out.len() >= orig_len {
break;
}
if (flag >> bit) & 1 == 1 {
if p >= input.len() {
return Err(LzssError::Truncated);
}
out.push(input[p]);
p += 1;
} else {
if p + 2 > input.len() {
return Err(LzssError::Truncated);
}
let packed = u16::from_le_bytes([input[p], input[p + 1]]);
p += 2;
let offset = (packed >> LENGTH_BITS) as usize;
let length = (packed as usize & ((1 << LENGTH_BITS) - 1)) + MIN_MATCH;
if offset == 0 || offset > out.len() {
return Err(LzssError::InvalidReference { at: p - 2 });
}
let start = out.len() - offset;
for j in 0..length {
if out.len() >= orig_len {
return Err(LzssError::OutputOverrun);
}
let byte = out[start + j];
out.push(byte);
}
}
}
}
out.truncate(orig_len);
Ok(out)
}
/// Searches the sliding window for the longest earlier copy of the bytes at `pos`.
///
/// Every start position in the preceding `WINDOW_SIZE - 1` bytes is tried, oldest
/// first. The result is `(offset, length)` with `offset = pos - start`. On ties the
/// oldest (largest-offset) candidate is kept. `length` is capped at `MAX_MATCH` and
/// at the number of bytes remaining after `pos`. It may be below `MIN_MATCH`, in
/// which case the caller should emit a literal. Returns `(0, 0)` when `pos == 0`,
/// when fewer than `MIN_MATCH` bytes remain, or when no candidate matches at all.
fn find_longest_match(input: &[u8], pos: usize) -> (usize, usize) {
    if pos == 0 {
        return (0, 0);
    }
    let remaining = input.len() - pos;
    if remaining < MIN_MATCH {
        return (0, 0);
    }
    // Bytes we are trying to find earlier in the window.
    let target = &input[pos..pos + remaining.min(MAX_MATCH)];

    let mut best: (usize, usize) = (0, 0);
    for candidate in pos.saturating_sub(WINDOW_SIZE - 1)..pos {
        let run = input[candidate..]
            .iter()
            .zip(target)
            .take_while(|(earlier, ahead)| earlier == ahead)
            .count();
        // Only a strictly longer run replaces the best, so ties favour the oldest.
        if run > best.1 {
            best = (pos - candidate, run);
            if run == target.len() {
                // The whole lookahead matched; nothing can beat it.
                break;
            }
        }
    }
    best
}
// Round-trip, compression-ratio, and malformed-stream tests for the LZSS codec.
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::string::ToString;
    use alloc::vec;

    #[test]
    fn empty_input_round_trips() {
        let c = compress(b"");
        let d = decompress(&c).unwrap();
        assert_eq!(d, b"");
    }

    #[test]
    fn empty_input_encodes_to_bare_header() {
        // Only the 4-byte little-endian length header, no frames.
        assert_eq!(compress(b""), vec![0u8, 0, 0, 0]);
    }

    #[test]
    fn single_byte_round_trips() {
        let c = compress(b"A");
        let d = decompress(&c).unwrap();
        assert_eq!(d, b"A");
    }

    #[test]
    fn all_distinct_64_bytes_no_compression() {
        let input: Vec<u8> = (0..64u8).collect();
        let c = compress(&input);
        let d = decompress(&c).unwrap();
        assert_eq!(d, input);
        assert!(c.len() > input.len(), "all-distinct → no compression");
    }

    #[test]
    fn repeated_byte_run_compresses_dramatically() {
        let input = vec![b'X'; 4096];
        let c = compress(&input);
        let d = decompress(&c).unwrap();
        assert_eq!(d, input);
        assert!(
            c.len() * 8 < input.len(),
            "4 KiB of X must compress > 8×; got {} → {}",
            input.len(),
            c.len()
        );
    }

    #[test]
    fn repeated_substring_compresses() {
        let input: Vec<u8> = b"the quick brown fox jumps "
            .iter()
            .cycle()
            .take(4096)
            .copied()
            .collect();
        let c = compress(&input);
        let d = decompress(&c).unwrap();
        assert_eq!(d, input);
        assert!(
            c.len() * 2 < input.len(),
            "repeated substring must compress ≥ 2×; got {} → {}",
            input.len(),
            c.len()
        );
    }

    #[test]
    fn max_match_18_bytes_at_boundary() {
        let mut input = Vec::with_capacity(40);
        input.extend_from_slice(b"abcdefghij");
        input.extend_from_slice(b"abcdefghij");
        input.extend_from_slice(b"abcdefghij");
        input.extend_from_slice(b"abcdefghij");
        let c = compress(&input);
        let d = decompress(&c).unwrap();
        assert_eq!(d, input);
    }

    #[test]
    fn window_wrap_around_at_4_kib_correct() {
        let mut input = vec![b'A'; 4096];
        input.push(b'B');
        input.extend_from_slice(&[b'A'; 18]);
        let c = compress(&input);
        let d = decompress(&c).unwrap();
        assert_eq!(d, input);
    }

    #[test]
    fn decode_errors_on_truncated_header() {
        let err = decompress(b"xyz").unwrap_err();
        assert_eq!(err, LzssError::Truncated);
    }

    #[test]
    fn decode_errors_on_truncated_payload() {
        let mut buf = Vec::new();
        buf.extend_from_slice(&64u32.to_le_bytes());
        let err = decompress(&buf).unwrap_err();
        assert_eq!(err, LzssError::Truncated);
    }

    #[test]
    fn decode_errors_on_back_reference_into_void() {
        // One literal, then a back-reference with offset 2 although only one byte
        // has been produced: it points before the start of the output.
        let mut buf = Vec::new();
        buf.extend_from_slice(&16u32.to_le_bytes());
        buf.push(0b0000_0001);
        buf.push(b'A');
        buf.extend_from_slice(&(2u16 << LENGTH_BITS).to_le_bytes());
        let err = decompress(&buf).unwrap_err();
        assert_eq!(err, LzssError::InvalidReference { at: 6 });
    }

    #[test]
    fn decode_errors_on_zero_offset_back_reference() {
        // Packed token 5 = offset 0, length field 5; offset 0 is never valid.
        let mut buf = Vec::new();
        buf.extend_from_slice(&16u32.to_le_bytes());
        buf.push(0b0000_0000);
        buf.extend_from_slice(&5u16.to_le_bytes());
        let err = decompress(&buf).unwrap_err();
        assert_eq!(err, LzssError::InvalidReference { at: 5 });
    }

    #[test]
    fn decode_errors_on_output_overrun() {
        // The header declares 3 bytes, but a literal plus a 3-byte back-reference
        // would produce 4.
        let mut buf = Vec::new();
        buf.extend_from_slice(&3u32.to_le_bytes());
        buf.push(0b0000_0001);
        buf.push(b'A');
        buf.extend_from_slice(&(1u16 << LENGTH_BITS).to_le_bytes());
        let err = decompress(&buf).unwrap_err();
        assert_eq!(err, LzssError::OutputOverrun);
    }

    #[test]
    fn round_trip_random_1_kib() {
        let mut input = Vec::with_capacity(1024);
        let mut state: u64 = 0xdead_beef_dead_beef;
        for _ in 0..1024 {
            state = state.wrapping_mul(6364136223846793005).wrapping_add(1);
            input.push((state >> 32) as u8);
        }
        let c = compress(&input);
        let d = decompress(&c).unwrap();
        assert_eq!(d, input);
    }

    #[test]
    fn round_trip_canonical_sql_corpus() {
        let mut input = Vec::with_capacity(64 * 1024);
        for i in 0..1024_u32 {
            input.extend_from_slice(b"INSERT INTO t VALUES (");
            input.extend_from_slice(i.to_string().as_bytes());
            input.extend_from_slice(b", 'short text');\n");
        }
        let c = compress(&input);
        let d = decompress(&c).unwrap();
        assert_eq!(d, input);
        assert!(
            c.len() * 2 < input.len(),
            "canonical SQL must compress ≥ 2×; got {} → {}",
            input.len(),
            c.len()
        );
    }

    #[test]
    fn rle_overlap_grow_into_source_region() {
        let input: Vec<u8> = b"AB".iter().cycle().take(256).copied().collect();
        let c = compress(&input);
        let d = decompress(&c).unwrap();
        assert_eq!(d, input);
    }
}