use std::io::{self, Error, Read};
use bitstream_io::{BitRead, BitReader, Endianness, LittleEndian};
/// Width in bits of a code at the start of a stream; `hwunshrink` initialises
/// `code_size` to this value.
const MIN_CODE_SIZE: u8 = 9;
/// Widest code width allowed; `read_code` rejects a request to grow past it.
const MAX_CODE_SIZE: u8 = 13;
/// Largest code value that fits in `MAX_CODE_SIZE` bits.
const MAX_CODE: usize = (1 << MAX_CODE_SIZE) - 1;
/// Escape code: when `read_code` reads it, the next code-sized value is
/// interpreted as a control command rather than as data.
const CONTROL_CODE: usize = 256;
/// Control command: widen subsequent codes by one bit.
const INC_CODE_SIZE: u16 = 1;
/// Control command: run `unshrink_partial_clear` on the code table.
const PARTIAL_CLEAR: u16 = 2;
/// Capacity of `CodeQueue::codes`: room for every code in
/// `CONTROL_CODE + 1..=MAX_CODE` plus one trailing `None` end marker.
const FREE_CODE_QUEUE_SIZE: usize = MAX_CODE - CONTROL_CODE + 1;
/// Sentinel stored in `Codetab::len` while the length of an entry's string
/// is not known yet.
const UNKNOWN_LEN: u16 = u16::MAX;
/// Queue of codes that are free to be assigned to new table entries.
struct CodeQueue {
/// Index into `codes` of the next code to hand out.
next_idx: usize,
/// Free codes in hand-out order; the first `None` marks the end of the queue.
codes: [Option<u16>; FREE_CODE_QUEUE_SIZE],
}
impl CodeQueue {
    /// Creates a queue holding every assignable code, `CONTROL_CODE + 1` through
    /// `MAX_CODE` inclusive, in ascending order.
    ///
    /// `codes` has one slot more than there are assignable codes, so the final
    /// slot stays `None` and acts as the end-of-queue marker.
    fn new() -> Self {
        let mut codes = [None; FREE_CODE_QUEUE_SIZE];
        // `MAX_CODE` itself is assignable and must be offered as well; stopping
        // one short would leave the last code permanently unusable.
        let assignable = (CONTROL_CODE as u16 + 1)..=(MAX_CODE as u16);
        for (slot, code) in codes.iter_mut().zip(assignable) {
            *slot = Some(code);
        }
        Self { next_idx: 0, codes }
    }

    /// Returns the next free code without removing it, or `None` when the
    /// queue is exhausted.
    fn next(&self) -> Option<u16> {
        self.codes.get(self.next_idx).copied().flatten()
    }

    /// Returns the next free code and advances past it, or returns `None`
    /// (leaving the queue untouched) when the queue is exhausted.
    fn remove_next(&mut self) -> Option<u16> {
        let code = self.next()?;
        self.next_idx += 1;
        Some(code)
    }
}
/// One entry of the code table: the string for a code is the string of
/// `prefix_code` followed by `ext_byte`.
#[derive(Clone, Debug, Copy, Default)]
struct Codetab {
/// Position in the output buffer where this code's string was last written.
last_dst_pos: usize,
/// Code of the prefix string; `None` means the entry is unused or was cleared.
/// Literal entries (0..=255) are initialised to refer to themselves.
prefix_code: Option<u16>,
/// Length of the string in bytes, or `UNKNOWN_LEN` if not determined yet.
len: u16,
/// Byte appended to the prefix string to form this code's string.
ext_byte: u8,
}
impl Codetab {
    /// Builds the initial code table: entries 0..=255 are one-byte literals
    /// (each its own prefix, length 1), every other entry is left at its
    /// default, i.e. without a prefix code.
    pub fn create_new() -> [Self; MAX_CODE + 1] {
        let mut table = [Self::default(); MAX_CODE + 1];
        for byte in 0..=u8::MAX {
            table[usize::from(byte)] = Self {
                last_dst_pos: 0,
                prefix_code: Some(u16::from(byte)),
                len: 1,
                ext_byte: byte,
            };
        }
        table
    }
}
/// Performs a partial clear of the code table.
///
/// Every code above `CONTROL_CODE` that no valid entry above `CONTROL_CODE`
/// uses as its prefix is invalidated (`prefix_code = None`) and placed, in
/// ascending order, into `queue`, which is then rewound to its start. Codes
/// that are still used as a prefix keep their entries.
///
/// `codetab` must hold `MAX_CODE + 1` entries; indexing panics otherwise.
fn unshrink_partial_clear(codetab: &mut [Codetab], queue: &mut CodeQueue) {
    // Mark every code that some valid non-literal entry builds upon.
    let mut is_prefix = [false; MAX_CODE + 1];
    for entry in codetab.iter().take(MAX_CODE + 1).skip(CONTROL_CODE + 1) {
        if let Some(prefix_code) = entry.prefix_code {
            is_prefix[prefix_code as usize] = true;
        }
    }

    // Free the unmarked codes. The range is inclusive: `MAX_CODE` is a regular
    // code and must be cleared and re-queued like any other, otherwise its
    // stale entry would survive the clear and never be reassigned.
    let mut free_count = 0;
    for code in (CONTROL_CODE + 1)..=MAX_CODE {
        if !is_prefix[code] {
            codetab[code].prefix_code = None;
            queue.codes[free_count] = Some(code as u16);
            free_count += 1;
        }
    }

    // End-of-queue marker; `codes` has one slot more than there are codes,
    // so this index is always in bounds.
    queue.codes[free_count] = None;
    queue.next_idx = 0;
}
/// Reads the next data code from the bit stream, executing any control
/// commands that precede it.
///
/// A value equal to `CONTROL_CODE` is followed by a command of the same
/// width: `INC_CODE_SIZE` widens `code_size` by one bit, `PARTIAL_CLEAR`
/// runs `unshrink_partial_clear` on `codetab` and `queue`.
///
/// Returns `Ok(Some(code))` for a data code, and `Ok(None)` when a control
/// code was read but the command following it could not be read.
///
/// # Errors
///
/// Propagates the reader's error when a code cannot be read, and returns
/// `InvalidData` for an unknown command or for `INC_CODE_SIZE` when the code
/// size is already `MAX_CODE_SIZE`.
fn read_code<T: Read, E: Endianness>(
    is: &mut BitReader<T, E>,
    code_size: &mut u8,
    codetab: &mut [Codetab],
    queue: &mut CodeQueue,
) -> io::Result<Option<u16>> {
    // Iterate instead of recursing: the number of consecutive control codes
    // is dictated by the (untrusted) input, so the stack depth must not
    // depend on it.
    loop {
        let code = is.read_var::<u16>(u32::from(*code_size))?;
        if code != CONTROL_CODE as u16 {
            return Ok(Some(code));
        }

        let Ok(control_code) = is.read_var::<u16>(u32::from(*code_size)) else {
            return Ok(None);
        };
        match control_code {
            INC_CODE_SIZE => {
                if *code_size >= MAX_CODE_SIZE {
                    return Err(io::Error::new(
                        io::ErrorKind::InvalidData,
                        "tried to increase code size when already at maximum",
                    ));
                }
                *code_size += 1;
            }
            PARTIAL_CLEAR => unshrink_partial_clear(codetab, queue),
            _ => {
                return Err(io::Error::new(
                    io::ErrorKind::InvalidData,
                    format!("Invalid control code {}", control_code),
                ));
            }
        }
    }
}
/// Appends `len` bytes of `dst`, starting at `start`, to the end of `dst` and
/// returns the first byte copied.
///
/// The copy is done byte by byte so the source range may run past the
/// current end of `dst` (the KwKwK case, where a string ends with a byte
/// that is only produced by the copy itself).
///
/// Returns `InvalidData` instead of panicking when `start` does not refer to
/// already-written output.
fn copy_from_prev_pos(dst: &mut Vec<u8>, start: usize, len: usize) -> io::Result<u8> {
    let Some(&first) = dst.get(start) else {
        return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid code"));
    };
    dst.reserve(len);
    for i in start..start + len {
        let byte = dst[i];
        dst.push(byte);
    }
    Ok(first)
}

/// Appends the string represented by `code` to `dst` and returns the first
/// byte of that string.
///
/// * `prev_code` / `first_byte` – the previously output code and the first
///   byte of its string; used to materialise a prefix that is about to be
///   added to the table (KwKwK).
/// * `codetab` – the code table; for a string whose length was unknown, the
///   entry's `len` and `last_dst_pos` are filled in.
/// * `queue` – only peeked, to learn which code would be assigned next.
///
/// # Errors
///
/// Returns `InvalidData` when `code` has no prefix or refers to itself, when
/// its prefix is not (and is not about to become) a valid entry with a known
/// length, or when a recorded position lies outside the output written so far.
fn output_code(
    dst: &mut Vec<u8>,
    code: u16,
    prev_code: u16,
    codetab: &mut [Codetab],
    queue: &CodeQueue,
    first_byte: u8,
) -> io::Result<u8> {
    debug_assert!(code <= MAX_CODE as u16 && code != CONTROL_CODE as u16);

    // Literal byte.
    if code <= u16::from(u8::MAX) {
        dst.push(code as u8);
        return Ok(code as u8);
    }

    // Reject unused/cleared entries and self-referential ones.
    let entry = codetab[code as usize];
    let prefix_code = match entry.prefix_code {
        Some(prefix) if prefix != code => prefix,
        _ => return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid code")),
    };

    // Common case: the string was output before, copy it from there.
    if entry.len != UNKNOWN_LEN {
        return copy_from_prev_pos(dst, entry.last_dst_pos, usize::from(entry.len));
    }

    // The string's length is unknown: resolve it through its prefix.
    debug_assert!(prefix_code as usize > CONTROL_CODE);
    if Some(prefix_code) == queue.next() {
        // KwKwK: the prefix is the code about to be added, i.e. the previous
        // string extended with its own first byte. Only the table entry is
        // created here; the extra byte is produced by the overlapping copy
        // below, so nothing is pushed to `dst` at this point (pushing it here
        // would emit a stray byte in front of the string).
        codetab[prefix_code as usize] = Codetab {
            prefix_code: Some(prev_code),
            ext_byte: first_byte,
            len: codetab[prev_code as usize].len + 1,
            last_dst_pos: codetab[prev_code as usize].last_dst_pos,
        };
    } else if codetab[prefix_code as usize].prefix_code.is_none() {
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "invalid prefix code",
        ));
    }

    let prefix = codetab[prefix_code as usize];
    if prefix.len == UNKNOWN_LEN {
        // The prefix string itself is still indeterminate; its `len` is the
        // sentinel, not a byte count, so there is nothing valid to copy.
        return Err(io::Error::new(
            io::ErrorKind::InvalidData,
            "invalid prefix code",
        ));
    }

    // Output the prefix string followed by the extension byte.
    let last_dst_pos = dst.len();
    let first = copy_from_prev_pos(dst, prefix.last_dst_pos, usize::from(prefix.len))?;
    dst.push(entry.ext_byte);

    // The string now has a known length and position.
    debug_assert!(prev_code != code);
    codetab[code as usize].len = prefix.len + 1;
    codetab[code as usize].last_dst_pos = last_dst_pos;
    Ok(first)
}
/// Decompresses the shrunk data in `src`, appending the result to `dst`.
///
/// Decoding continues until `dst.len()` reaches `uncompressed_size`. If the very
/// first code cannot be obtained (`read_code` yields `Ok(None)`), the function
/// returns `Ok(())` without producing output.
///
/// # Errors
///
/// * `InvalidData` – the first code is not a literal, a code is invalid, or the
///   previous code is no longer valid when it has to be extended.
/// * `UnexpectedEof` – after the loop, `dst.len()` differs from `uncompressed_size`.
/// * Any error returned by `read_code` for the first code, or by `output_code`.
fn hwunshrink(src: &[u8], uncompressed_size: usize, dst: &mut Vec<u8>) -> io::Result<()> {
dst.reserve(uncompressed_size);
let mut codetab = Codetab::create_new();
let mut queue = CodeQueue::new();
let mut is = BitReader::endian(src, LittleEndian);
let mut code_size = MIN_CODE_SIZE;
// The first code is handled separately because there is no previous code yet.
let Some(curr_code) = read_code(&mut is, &mut code_size, &mut codetab, &mut queue)? else {
return Ok(());
};
debug_assert!(curr_code != CONTROL_CODE as u16);
if curr_code > u16::from(u8::MAX) {
return Err(io::Error::new(
io::ErrorKind::InvalidData,
"the first code must be a literal",
));
}
// Emit the first literal and remember where it was written.
let mut first_byte = curr_code as u8; codetab[curr_code as usize].last_dst_pos = dst.len();
dst.push(first_byte);
let mut prev_code = curr_code;
while dst.len() < uncompressed_size {
// NOTE(review): every error from `read_code` (including `InvalidData` for a
// bad control code) only ends the loop here; since `dst` is still short, it
// is then reported below as `UnexpectedEof`.
let curr_opt = match read_code(&mut is, &mut code_size, &mut codetab, &mut queue) {
Ok(c) => c,
Err(_) => break, };
let Some(curr_code) = curr_opt else {
return Err(Error::new(io::ErrorKind::InvalidData, "invalid code"));
};
// Start position of the string about to be output for `curr_code`.
let dst_pos = dst.len();
// KwKwK: the code is used before it has been added to the table. Define it
// as the previous string extended with that string's first byte.
if Some(curr_code) == queue.next() {
if codetab[prev_code as usize].prefix_code.is_none() {
return Err(Error::new(
io::ErrorKind::InvalidData,
"previous code no longer valid",
));
}
debug_assert!(curr_code != prev_code);
codetab[curr_code as usize] = Codetab {
prefix_code: Some(prev_code),
ext_byte: first_byte,
len: codetab[prev_code as usize].len + 1,
last_dst_pos: codetab[prev_code as usize].last_dst_pos,
};
}
// Output the current string; `first_byte` becomes its first byte.
first_byte = output_code(dst, curr_code, prev_code, &mut codetab, &queue, first_byte)?;
// If a free code is left, assign it to the previous string extended with
// the first byte of the current string.
if let Some(new_code) = queue.remove_next() {
let prev_entry = codetab[prev_code as usize];
codetab[new_code as usize] = Codetab {
prefix_code: Some(prev_code),
ext_byte: first_byte,
last_dst_pos: prev_entry.last_dst_pos,
// If the previous entry has no prefix code (it was invalidated), the new
// string's length cannot be determined yet.
len: if prev_entry.prefix_code.is_none() {
UNKNOWN_LEN
} else {
prev_entry.len + 1
},
};
}
// Record where the current string was written for later copies.
codetab[curr_code as usize].last_dst_pos = dst_pos;
prev_code = curr_code;
}
// The loop ends on a read error or once enough bytes exist; anything other
// than exactly `uncompressed_size` bytes is an error.
if dst.len() != uncompressed_size {
return Err(io::Error::new(
io::ErrorKind::UnexpectedEof,
"unexpected end of compressed stream",
));
}
Ok(())
}
/// `Read` adapter that decodes a shrunk stream.
///
/// The whole compressed input is read and decoded on the first call to
/// `read`; later calls serve bytes from the decoded buffer.
#[derive(Debug)]
pub struct ShrinkDecoder<R> {
/// Source of the compressed bytes.
compressed_reader: R,
/// Set once the first `read` has started consuming `compressed_reader`.
stream_read: bool,
/// Number of bytes the decoded data is expected to have.
uncompressed_size: u64,
/// Decoded data.
stream: Vec<u8>,
/// Offset into `stream` of the next byte to hand out.
read_pos: usize,
}
impl<R: Read> ShrinkDecoder<R> {
    /// Creates a decoder over `inner` that expects `uncompressed_size`
    /// decoded bytes. Nothing is read from `inner` until the first `read`.
    pub fn new(inner: R, uncompressed_size: u64) -> Self {
        Self {
            stream: Vec::new(),
            read_pos: 0,
            stream_read: false,
            uncompressed_size,
            compressed_reader: inner,
        }
    }

    /// Consumes the decoder and hands back the underlying reader.
    pub fn into_inner(self) -> R {
        let Self {
            compressed_reader, ..
        } = self;
        compressed_reader
    }
}
impl<R: Read> Read for ShrinkDecoder<R> {
    /// On the first call, reads the entire compressed input and decodes it
    /// with `hwunshrink`; every call then copies as many decoded bytes as fit
    /// into `buf` and returns that count (0 once everything was handed out).
    ///
    /// The "already decoded" flag is set before decoding starts, so a failed
    /// first call is not retried by later calls.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if !self.stream_read {
            self.stream_read = true;
            let mut compressed = Vec::new();
            self.compressed_reader.read_to_end(&mut compressed)?;
            let expected_len = self.uncompressed_size as usize;
            self.stream.reserve(expected_len);
            hwunshrink(&compressed, expected_len, &mut self.stream)?;
        }

        // Everything from `read_pos` onwards has not been handed out yet.
        let pending = self.stream.get(self.read_pos..).unwrap_or(&[]);
        let count = pending.len().min(buf.len());
        buf[..count].copy_from_slice(&pending[..count]);
        self.read_pos += count;
        Ok(count)
    }
}
#[cfg(test)]
mod tests {
use crate::legacy::shrink::hwunshrink;
// Expected decoded text. NOTE(review): the name suggests this is the worked
// example from "figure 5" of an LZW reference — confirm the source.
const LZW_FIG5: &[u8; 17] = b"ababcbababaaaaaaa";
// Shrunk form of `LZW_FIG5`, fed to `hwunshrink` below.
const LZW_FIG5_SHRUNK: [u8; 12] = [
0x61, 0xc4, 0x04, 0x1c, 0x23, 0xb0, 0x60, 0x98, 0x83, 0x08, 0xc3, 0x00,
];
// Decoding the shrunk bytes must reproduce the original text exactly.
#[test]
fn test_unshrink_lzw_fig5() {
let mut dst = Vec::with_capacity(LZW_FIG5.len());
hwunshrink(&LZW_FIG5_SHRUNK, LZW_FIG5.len(), &mut dst).unwrap();
assert_eq!(dst.as_slice(), LZW_FIG5);
}
}