use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
use crate::traits::{Algorithm, RawDecoder, RawEncoder, RawProgress};
/// Marker type for the LZW codec.
///
/// Streams use the `compress`-style `.Z` layout: magic bytes `1F 9D`, a
/// flags/max-code-width byte, then variable-width codes packed LSB-first.
#[derive(Debug, Clone, Copy, Default)]
pub struct Lzw;
impl Algorithm for Lzw {
    const NAME: &'static str = "lzw";

    type EncoderConfig = ();
    type DecoderConfig = ();
    type Encoder = Encoder;
    type Decoder = Decoder;

    /// Builds a fresh encoder; LZW has no tunable options, so the config is `()`.
    fn encoder_with(_config: ()) -> Encoder {
        Encoder::default()
    }

    /// Builds a fresh decoder; block mode and the maximum code width are read
    /// from the stream header, so the config is `()`.
    fn decoder_with(_config: ()) -> Decoder {
        Decoder::default()
    }
}
/// First byte of the `.Z` magic number.
const MAGIC_1: u8 = 0x1F;
/// Second byte of the `.Z` magic number.
const MAGIC_2: u8 = 0x9D;
/// Code width (bits) used right after the header and after every CLEAR.
const INIT_BITS: u8 = 9;
/// Widest code this implementation writes or accepts.
const MAX_BITS: u8 = 16;
/// Third header byte: block-mode flag (0x80) plus the maximum code width.
const HEADER_BYTE: u8 = 0x80 | MAX_BITS;
/// Block-mode code that resets the dictionary.
const CLEAR: u32 = 256;
/// First dictionary code assigned in block mode (256 is taken by `CLEAR`).
const FIRST: u32 = 257;
/// log2 of the encoder's hash-table size.
const HASH_BITS: u32 = 17;
/// Number of open-addressing slots in the encoder's hash table.
const HASH_SIZE: usize = 1 << HASH_BITS;
/// Mask selecting a slot index in `0..HASH_SIZE`.
const HASH_MASK: u32 = (HASH_SIZE as u32) - 1;
/// Input bytes between compression-ratio checks once the dictionary is full.
const CHECK_GAP: u64 = 10_000;

/// Maps a (prefix code, next byte) pair to the starting slot of its probe
/// sequence in the encoder's hash table.
///
/// Multiplicative (Fibonacci) hashing: the key is multiplied by an odd
/// constant and the *high* `HASH_BITS` bits of the 32-bit product are kept.
/// The low bits of `key * odd` depend only on the low bits of `key`. Masking
/// them would therefore send every prefix that agrees in its low 9 bits to
/// the same slot for a given byte, which lengthens linear-probe chains.
#[inline]
fn hash(prefix: u32, byte: u8) -> u32 {
    let key = (prefix << 8) | u32::from(byte);
    // The shift already yields a value below HASH_SIZE; the mask only documents that.
    (key.wrapping_mul(2_654_435_761) >> (u32::BITS - HASH_BITS)) & HASH_MASK
}
/// FIFO of encoded bytes waiting for room in the caller's output buffer.
///
/// Bytes are appended to `buf` and consumed starting at `head`. The consumed
/// prefix is dropped when the queue empties. It is also compacted away once
/// it is at least as long as the unread tail. Memory therefore stays
/// proportional to the unread bytes even when the caller drains through a
/// tiny output buffer and the queue never fully empties.
#[derive(Debug, Default)]
struct ByteQueue {
    /// Backing storage; bytes before `head` have already been drained.
    buf: Vec<u8>,
    /// Index of the first byte not yet drained.
    head: usize,
}
impl ByteQueue {
    /// Consumed-prefix length below which compaction is skipped, so small
    /// queues are not memmoved on every drain.
    const COMPACT_MIN: usize = 1024;

    /// Creates an empty queue.
    fn new() -> Self {
        Self {
            buf: Vec::new(),
            head: 0,
        }
    }

    /// Appends one byte at the back.
    fn push(&mut self, b: u8) {
        self.buf.push(b);
    }

    /// Number of bytes not yet drained.
    fn len(&self) -> usize {
        self.buf.len() - self.head
    }

    /// `true` when every pushed byte has been drained.
    fn is_empty(&self) -> bool {
        self.head == self.buf.len()
    }

    /// Moves up to `out.len()` bytes from the front into `out` and returns
    /// how many were copied.
    fn drain_into(&mut self, out: &mut [u8]) -> usize {
        let n = self.len().min(out.len());
        out[..n].copy_from_slice(&self.buf[self.head..self.head + n]);
        self.head += n;
        if self.head == self.buf.len() {
            self.clear();
        } else if self.head >= Self::COMPACT_MIN && self.head >= self.len() {
            // The live tail moved here is no longer than the bytes drained since
            // the previous compaction, so the copying is amortized O(1) per byte.
            self.buf.drain(..self.head);
            self.head = 0;
        }
        n
    }

    /// Discards all queued bytes.
    fn clear(&mut self) {
        self.buf.clear();
        self.head = 0;
    }
}
/// Streaming LZW compressor state.
///
/// Input bytes are matched against a dictionary kept in an open-addressing
/// hash table. Codes are packed LSB-first into `bit_acc`, and completed bytes
/// are staged in `pending` until the caller supplies output space.
#[derive(Debug)]
pub struct Encoder {
    /// Per-slot key `(prefix << 8) | byte`; meaningful only where `ht_code` is non-zero.
    ht_key: Vec<u32>,
    /// Per-slot dictionary code; 0 marks an empty slot (real entries start at `FIRST`).
    ht_code: Vec<u32>,
    /// Code the next new dictionary entry will receive.
    next_code: u32,
    /// Current code width in bits.
    nbits: u8,
    /// Emitted bits not yet moved into `pending`, LSB first.
    bit_acc: u64,
    /// Number of valid bits in `bit_acc`.
    bit_count: u8,
    /// Code of the current (longest so far) match; `u32::MAX` means no byte is held yet.
    w_code: u32,
    /// Codes emitted in the current group of eight, modulo 8.
    codes_in_group: u8,
    /// Header bytes still to be queued (counts down from 3).
    header_remaining: u8,
    /// Encoded bytes waiting for room in the caller's output buffer.
    pending: ByteQueue,
    /// Set once `raw_finish` has flushed everything.
    completed: bool,
    /// Input bytes consumed since the last dictionary reset.
    bytes_in: u64,
    /// Code bits emitted since the last dictionary reset (padding codes included).
    bits_out: u64,
    /// `bytes_in` value at which the next ratio check runs; `u64::MAX` while
    /// checks are disabled (they are armed once the dictionary is full).
    next_check: u64,
    /// Best ratio `(bytes_in << 8) / bits_out` seen since the last reset.
    best_ratio: u64,
}
impl Encoder {
    /// Creates an encoder with an empty dictionary and 9-bit codes. The
    /// three-byte header has not been written yet.
    pub fn new() -> Self {
        Self {
            ht_key: vec![0; HASH_SIZE],
            ht_code: vec![0; HASH_SIZE],
            next_code: FIRST,
            nbits: INIT_BITS,
            bit_acc: 0,
            bit_count: 0,
            w_code: u32::MAX,
            codes_in_group: 0,
            header_remaining: 3,
            pending: ByteQueue::new(),
            completed: false,
            bytes_in: 0,
            bits_out: 0,
            next_check: u64::MAX,
            best_ratio: 0,
        }
    }

    /// Empties the dictionary, returns to `INIT_BITS`-wide codes and restarts
    /// ratio monitoring. Bit-packer and header state are left untouched.
    fn reset_dict(&mut self) {
        self.ht_key.fill(0);
        self.ht_code.fill(0);
        self.next_code = FIRST;
        self.nbits = INIT_BITS;
        self.bytes_in = 0;
        self.bits_out = 0;
        self.next_check = u64::MAX;
        self.best_ratio = 0;
    }

    /// Appends `code` at the current width, LSB-first, and moves every
    /// completed byte into `pending`.
    fn emit_code(&mut self, code: u32) {
        let width = self.nbits;
        self.bit_acc |= u64::from(code) << self.bit_count;
        self.bit_count += width;
        while self.bit_count >= 8 {
            self.pending.push(self.bit_acc as u8);
            self.bit_acc >>= 8;
            self.bit_count -= 8;
        }
        self.codes_in_group = (self.codes_in_group + 1) % 8;
        self.bits_out = self.bits_out.saturating_add(u64::from(width));
    }

    /// Emits zero codes until the current group of eight codes is complete.
    /// Eight codes of `nbits` bits make exactly `nbits` bytes, so the bit
    /// packer ends up empty.
    fn pad_to_group_boundary(&mut self) {
        let missing = (8 - self.codes_in_group) % 8;
        for _ in 0..missing {
            self.emit_code(0);
        }
        debug_assert_eq!(self.bit_count, 0);
        debug_assert_eq!(self.bit_acc, 0);
    }

    /// Looks up the string `prefix + byte`. Returns `Ok(code)` when it is in
    /// the dictionary, otherwise `Err(slot)` with the empty slot where it
    /// belongs.
    fn lookup(&self, prefix: u32, byte: u8) -> Result<u32, usize> {
        let wanted = (prefix << 8) | u32::from(byte);
        let mut slot = hash(prefix, byte) as usize;
        loop {
            match self.ht_code[slot] {
                0 => return Err(slot),
                code if self.ht_key[slot] == wanted => return Ok(code),
                // Linear probing, wrapping at the end of the table.
                _ => slot = (slot + 1) & (HASH_SIZE - 1),
            }
        }
    }

    /// Stores the entry `prefix + byte -> code` in `slot`, which should come from `lookup`.
    fn insert(&mut self, slot: usize, prefix: u32, byte: u8, code: u32) {
        self.ht_key[slot] = (prefix << 8) | u32::from(byte);
        self.ht_code[slot] = code;
    }

    /// Queues whichever header bytes have not been queued yet.
    fn ensure_header(&mut self) {
        const HEADER: [u8; 3] = [MAGIC_1, MAGIC_2, HEADER_BYTE];
        while self.header_remaining > 0 {
            let idx = HEADER.len() - usize::from(self.header_remaining);
            self.pending.push(HEADER[idx]);
            self.header_remaining -= 1;
        }
    }

    /// Runs after each new entry. While codes can still grow, it widens once
    /// `next_code` no longer fits the current width, padding the current group
    /// first. At full width, it arms the periodic ratio check as soon as the
    /// dictionary is full.
    fn maybe_widen(&mut self) {
        if self.nbits >= MAX_BITS {
            let table_full = self.next_code >= (1u32 << MAX_BITS);
            if table_full && self.next_check == u64::MAX {
                self.next_check = self.bytes_in.saturating_add(CHECK_GAP);
            }
            return;
        }
        if self.next_code > (1u32 << self.nbits) {
            self.pad_to_group_boundary();
            self.nbits += 1;
        }
    }

    /// Compares the current ratio with the best seen since the last reset. If
    /// it is no worse, it records it and schedules the next check. Otherwise
    /// it emits CLEAR, pads the group and starts a fresh dictionary.
    fn check_ratio(&mut self) {
        let ratio = match self.bits_out {
            0 => u64::MAX,
            bits => (self.bytes_in << 8) / bits,
        };
        if ratio < self.best_ratio {
            self.emit_code(CLEAR);
            self.pad_to_group_boundary();
            self.reset_dict();
        } else {
            self.best_ratio = ratio;
            self.next_check = self.bytes_in.saturating_add(CHECK_GAP);
        }
    }
}
impl Default for Encoder {
fn default() -> Self {
Self::new()
}
}
impl RawEncoder for Encoder {
    /// Compresses as much of `input` as the caller's output space allows.
    ///
    /// The three header bytes are queued on the first call. Encoded bytes are
    /// staged in `pending` and copied into `output` as room allows. After
    /// `output` fills, input is still consumed until at least 64 bytes are
    /// staged; then the call returns so the caller can drain them. `done` is
    /// always `false` here, because the stream is terminated by `raw_finish`.
    fn raw_encode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
        self.ensure_header();
        let mut consumed = 0usize;
        let mut written = 0usize;
        if !self.pending.is_empty() {
            written += self.pending.drain_into(&mut output[written..]);
        }
        while consumed < input.len() {
            // Back-pressure: stop once staged bytes exceed the remaining room by 64.
            if self.pending.len() >= output.len().saturating_sub(written) + 64 {
                break;
            }
            let b = input[consumed];
            if self.w_code == u32::MAX {
                // No match held yet (stream start): the byte becomes the current match.
                self.w_code = b as u32;
                consumed += 1;
                self.bytes_in = self.bytes_in.saturating_add(1);
                continue;
            }
            match self.lookup(self.w_code, b) {
                Ok(existing) => {
                    // Match extends: keep growing it.
                    self.w_code = existing;
                    consumed += 1;
                    self.bytes_in = self.bytes_in.saturating_add(1);
                }
                Err(slot) => {
                    // Match cannot grow: emit it, record match+b while codes
                    // remain, then restart matching from b.
                    let prefix = self.w_code;
                    self.emit_code(prefix);
                    if self.next_code < (1u32 << MAX_BITS) {
                        self.insert(slot, prefix, b, self.next_code);
                        self.next_code += 1;
                    }
                    self.maybe_widen();
                    self.w_code = b as u32;
                    consumed += 1;
                    self.bytes_in = self.bytes_in.saturating_add(1);
                    if self.bytes_in >= self.next_check {
                        self.check_ratio();
                    }
                }
            }
            if !self.pending.is_empty() && written < output.len() {
                written += self.pending.drain_into(&mut output[written..]);
            }
        }
        if !self.pending.is_empty() && written < output.len() {
            written += self.pending.drain_into(&mut output[written..]);
        }
        Ok(RawProgress {
            consumed,
            written,
            done: false,
        })
    }

    /// Terminates the stream: emits the pending match, flushes the partial
    /// last byte and drains staged bytes into `output`.
    ///
    /// Returns `done: true` once everything has been written. If `output` was
    /// too small, call again with more room; later calls after completion
    /// report `done: true` with nothing written.
    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.completed {
            return Ok(RawProgress {
                consumed: 0,
                written: 0,
                done: true,
            });
        }
        self.ensure_header();
        if self.w_code != u32::MAX {
            let c = self.w_code;
            self.emit_code(c);
            self.w_code = u32::MAX;
        }
        if self.bit_count > 0 {
            // Remaining high bits of the final byte are zero.
            self.pending.push(self.bit_acc as u8);
            self.bit_acc = 0;
            self.bit_count = 0;
            self.codes_in_group = 0;
        }
        let mut written = 0usize;
        if !self.pending.is_empty() {
            written += self.pending.drain_into(&mut output[written..]);
        }
        let done = self.pending.is_empty();
        if done {
            self.completed = true;
        }
        Ok(RawProgress {
            consumed: 0,
            written,
            done,
        })
    }

    /// Returns the encoder to its freshly constructed state.
    fn raw_reset(&mut self) {
        // Dictionary, code width and ratio monitoring share the mid-stream CLEAR reset.
        self.reset_dict();
        // Bit packer, current match, header and output staging.
        self.bit_acc = 0;
        self.bit_count = 0;
        self.w_code = u32::MAX;
        self.codes_in_group = 0;
        self.header_remaining = 3;
        self.pending.clear();
        self.completed = false;
    }
}
/// Streaming LZW decompressor state.
///
/// Input bytes are shifted LSB-first into `bit_acc` and peeled off as codes.
/// Each code is expanded into `emit_buf`, which is copied to the caller's
/// output before the next code is read.
#[derive(Debug)]
pub struct Decoder {
    /// Header bytes consumed so far (0..=3).
    header_pos: u8,
    /// Header flag 0x80: when set, code 256 is CLEAR and new entries start at `FIRST`.
    block_mode: bool,
    /// Maximum code width announced by the header.
    maxbits: u8,
    /// Dictionary: code of the string each entry extends.
    prefix: Vec<u16>,
    /// Dictionary: final byte of each entry's string.
    suffix: Vec<u8>,
    /// Code the next dictionary entry will receive.
    next_code: u32,
    /// Current code width in bits.
    nbits: u8,
    /// Input bits not yet consumed as a code, LSB first.
    bit_acc: u64,
    /// Number of valid bits in `bit_acc`.
    bit_count: u8,
    /// Previously decoded code; `u32::MAX` when there is none (stream start or after a reset).
    prev_code: u32,
    /// First byte of the most recently decoded string.
    finchar: u8,
    /// Codes read in the current group of eight, modulo 8.
    codes_in_group: u8,
    /// Decoded bytes not yet copied to the caller.
    emit_buf: Vec<u8>,
    /// Index of the first undelivered byte in `emit_buf`.
    emit_head: usize,
    /// Scratch space used to reverse a prefix chain.
    stack: Vec<u8>,
    /// Set once `raw_finish` has reported completion.
    completed: bool,
}
impl Decoder {
pub fn new() -> Self {
let max_size = 1usize << MAX_BITS;
Self {
header_pos: 0,
block_mode: true,
maxbits: MAX_BITS,
prefix: vec![0u16; max_size],
suffix: vec![0u8; max_size],
next_code: FIRST,
nbits: INIT_BITS,
bit_acc: 0,
bit_count: 0,
prev_code: u32::MAX,
finchar: 0,
codes_in_group: 0,
emit_buf: Vec::new(),
emit_head: 0,
stack: Vec::with_capacity(max_size),
completed: false,
}
}
fn reset_dict(&mut self) {
self.next_code = if self.block_mode { FIRST } else { 256 };
self.nbits = INIT_BITS;
self.prev_code = u32::MAX;
self.codes_in_group = 0;
}
fn try_read_code(&mut self, input: &[u8], in_cursor: &mut usize) -> Option<u32> {
let need = self.nbits as u32;
while self.bit_count < need as u8 {
if *in_cursor >= input.len() {
return None;
}
self.bit_acc |= (input[*in_cursor] as u64) << self.bit_count;
self.bit_count += 8;
*in_cursor += 1;
}
let mask = if need == 64 {
u64::MAX
} else {
(1u64 << need) - 1
};
let code = (self.bit_acc & mask) as u32;
self.bit_acc >>= need;
self.bit_count -= need as u8;
self.codes_in_group = (self.codes_in_group + 1) & 7;
Some(code)
}
fn skip_to_group_boundary(&mut self, input: &[u8], in_cursor: &mut usize) -> bool {
while self.codes_in_group != 0 {
if self.try_read_code(input, in_cursor).is_none() {
return false;
}
}
true
}
fn ensure_header(&mut self, input: &[u8], in_cursor: &mut usize) -> Result<bool, Error> {
while self.header_pos < 3 && *in_cursor < input.len() {
let b = input[*in_cursor];
*in_cursor += 1;
match self.header_pos {
0 => {
if b != MAGIC_1 {
return Err(Error::BadHeader);
}
}
1 => {
if b != MAGIC_2 {
return Err(Error::BadHeader);
}
}
2 => {
let mb = b & 0x1F;
self.block_mode = (b & 0x80) != 0;
if !(INIT_BITS..=MAX_BITS).contains(&mb) {
return Err(Error::Unsupported);
}
self.maxbits = mb;
self.next_code = if self.block_mode { FIRST } else { 256 };
}
_ => unreachable!(),
}
self.header_pos += 1;
}
Ok(self.header_pos >= 3)
}
fn decode_string_to_emit_buf(&mut self, mut code: u32) {
self.stack.clear();
while code >= 256 {
self.stack.push(self.suffix[code as usize]);
code = self.prefix[code as usize] as u32;
}
let first = code as u8;
self.finchar = first;
self.emit_buf.push(first);
while let Some(b) = self.stack.pop() {
self.emit_buf.push(b);
}
}
fn drain_emit(&mut self, out: &mut [u8]) -> usize {
let available = self.emit_buf.len() - self.emit_head;
let n = available.min(out.len());
out[..n].copy_from_slice(&self.emit_buf[self.emit_head..self.emit_head + n]);
self.emit_head += n;
if self.emit_head == self.emit_buf.len() {
self.emit_buf.clear();
self.emit_head = 0;
}
n
}
}
impl Default for Decoder {
fn default() -> Self {
Self::new()
}
}
impl RawDecoder for Decoder {
fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
let mut in_cursor = 0usize;
let mut written = 0usize;
if self.header_pos < 3 {
self.ensure_header(input, &mut in_cursor)?;
}
loop {
if self.emit_head < self.emit_buf.len() {
written += self.drain_emit(&mut output[written..]);
if self.emit_head < self.emit_buf.len() {
return Ok(RawProgress {
consumed: in_cursor,
written,
done: false,
});
}
}
if self.header_pos < 3 {
return Ok(RawProgress {
consumed: in_cursor,
written,
done: false,
});
}
let bump_threshold = if self.nbits < self.maxbits {
(1u32 << self.nbits) - 1
} else {
u32::MAX
};
if self.next_code > bump_threshold && self.nbits < self.maxbits {
if !self.skip_to_group_boundary(input, &mut in_cursor) {
return Ok(RawProgress {
consumed: in_cursor,
written,
done: false,
});
}
self.nbits += 1;
}
let code = match self.try_read_code(input, &mut in_cursor) {
Some(c) => c,
None => {
return Ok(RawProgress {
consumed: in_cursor,
written,
done: false,
});
}
};
if self.block_mode && code == CLEAR {
if !self.skip_to_group_boundary(input, &mut in_cursor) {
return Ok(RawProgress {
consumed: in_cursor,
written,
done: false,
});
}
self.reset_dict();
continue;
}
if self.prev_code == u32::MAX {
if code >= 256 {
return Err(Error::Corrupt);
}
self.finchar = code as u8;
self.emit_buf.push(code as u8);
self.prev_code = code;
continue;
}
if code > self.next_code {
return Err(Error::Corrupt);
}
if code == self.next_code {
let prev = self.prev_code;
self.decode_string_to_emit_buf(prev);
self.emit_buf.push(self.finchar);
} else {
self.decode_string_to_emit_buf(code);
}
if self.next_code < (1u32 << self.maxbits) {
let nc = self.next_code as usize;
self.prefix[nc] = self.prev_code as u16;
self.suffix[nc] = self.finchar;
self.next_code += 1;
}
self.prev_code = code;
}
}
fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
if self.completed {
return Ok(RawProgress {
consumed: 0,
written: 0,
done: true,
});
}
let mut written = 0usize;
if self.emit_head < self.emit_buf.len() {
written += self.drain_emit(&mut output[written..]);
if self.emit_head < self.emit_buf.len() {
return Ok(RawProgress {
consumed: 0,
written,
done: false,
});
}
}
if self.header_pos < 3 && self.header_pos != 0 {
return Err(Error::UnexpectedEnd);
}
self.completed = true;
Ok(RawProgress {
consumed: 0,
written,
done: true,
})
}
fn raw_reset(&mut self) {
self.header_pos = 0;
self.block_mode = true;
self.maxbits = MAX_BITS;
self.next_code = FIRST;
self.nbits = INIT_BITS;
self.bit_acc = 0;
self.bit_count = 0;
self.prev_code = u32::MAX;
self.finchar = 0;
self.codes_in_group = 0;
self.emit_buf.clear();
self.emit_head = 0;
self.stack.clear();
self.completed = false;
}
}