use std::convert::TryInto;
/// Decompresses a raw DEFLATE stream held entirely in `input`.
///
/// Blocks are decoded one after another until a block whose final-block bit is
/// set has been processed. Each block starts with a 1-bit final flag and a 2-bit
/// type: `0` = stored, `1` = fixed Huffman codes, `2` = dynamic Huffman codes.
///
/// # Errors
///
/// Returns `"invalid DEFLATE block type"` for block type `3`, and forwards any
/// error reported by the bit reader or by the per-block decoders.
pub fn inflate(input: &[u8]) -> Result<Vec<u8>, &'static str> {
    let mut reader = BitReader::new(input);
    let mut output = Vec::new();
    let mut last_block = false;

    while !last_block {
        // Header bits are read in stream order: final flag first, then the type.
        last_block = reader.read_bits(1)? == 1;
        match reader.read_bits(2)? {
            0 => inflate_stored(&mut reader, &mut output)?,
            1 => {
                let litlen = fixed_litlen_table();
                let dist = fixed_dist_table();
                inflate_huffman(&mut reader, &mut output, &litlen, &dist)?;
            }
            2 => {
                let (litlen, dist) = read_dynamic_tables(&mut reader)?;
                inflate_huffman(&mut reader, &mut output, &litlen, &dist)?;
            }
            _ => return Err("invalid DEFLATE block type"),
        }
    }

    Ok(output)
}
/// Cursor over a byte slice that hands out bits starting with the least
/// significant bit of each byte.
pub struct BitReader<'a> {
    /// Bytes being read.
    data: &'a [u8],
    /// Index of the byte that holds the next unread bit.
    byte_pos: usize,
    /// Position of the next unread bit inside `data[byte_pos]` (0 = least significant).
    bit_pos: u8,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the first bit of `data`.
    pub fn new(data: &'a [u8]) -> Self {
        BitReader {
            data,
            byte_pos: 0,
            bit_pos: 0,
        }
    }

    /// Reads `n` bits and returns them as an integer whose bit `i` is the
    /// `i`-th bit read.
    ///
    /// # Errors
    ///
    /// Returns `"DEFLATE: unexpected EOF"` if the data runs out; bits read
    /// before the failure stay consumed.
    pub fn read_bits(&mut self, n: u8) -> Result<u32, &'static str> {
        debug_assert!(n <= 32);
        let mut value = 0u32;
        for shift in 0..n {
            let byte = *self
                .data
                .get(self.byte_pos)
                .ok_or("DEFLATE: unexpected EOF")?;
            value |= u32::from((byte >> self.bit_pos) & 1) << shift;
            self.step_one_bit();
        }
        Ok(value)
    }

    /// Returns the next `n` bits without consuming them, packed so that the
    /// first bit in the stream is the most significant bit of the result.
    ///
    /// Returns `None` when fewer than `n` bits remain.
    pub fn peek_msb_bits(&self, n: usize) -> Option<u32> {
        // Walk a private copy of the cursor so the reader itself is untouched.
        let mut cursor = (self.byte_pos, self.bit_pos);
        let mut value = 0u32;
        for shift in (0..n).rev() {
            let (byte_idx, bit_idx) = cursor;
            let byte = *self.data.get(byte_idx)?;
            value |= u32::from((byte >> bit_idx) & 1) << shift;
            let next_bit = bit_idx + 1;
            cursor = if next_bit == 8 {
                (byte_idx + 1, 0)
            } else {
                (byte_idx, next_bit)
            };
        }
        Some(value)
    }

    /// Advances the cursor by `n` bits without checking against the end of the data.
    pub fn consume_bits(&mut self, n: usize) {
        let bit_offset = usize::from(self.bit_pos) + n;
        self.bit_pos = (bit_offset % 8) as u8;
        self.byte_pos += bit_offset / 8;
    }

    /// Skips the unread remainder of a partially consumed byte, if any.
    fn align_to_byte(&mut self) {
        if self.bit_pos == 0 {
            return;
        }
        self.byte_pos += 1;
        self.bit_pos = 0;
    }

    /// Returns the next `n` whole bytes and advances past them.
    ///
    /// The cursor is expected to be byte-aligned (checked in debug builds only).
    ///
    /// # Errors
    ///
    /// Returns `"DEFLATE: unexpected EOF (bytes)"` if fewer than `n` bytes
    /// remain; the cursor is not moved in that case.
    fn read_bytes(&mut self, n: usize) -> Result<&'a [u8], &'static str> {
        debug_assert_eq!(self.bit_pos, 0);
        let data: &'a [u8] = self.data;
        let end = self.byte_pos + n;
        let chunk = data
            .get(self.byte_pos..end)
            .ok_or("DEFLATE: unexpected EOF (bytes)")?;
        self.byte_pos = end;
        Ok(chunk)
    }

    /// Moves the cursor forward by exactly one bit.
    fn step_one_bit(&mut self) {
        self.bit_pos += 1;
        if self.bit_pos == 8 {
            self.bit_pos = 0;
            self.byte_pos += 1;
        }
    }
}
/// Decodes one stored (uncompressed) block and appends its payload to `out`.
///
/// The reader is first advanced to the next byte boundary. The block then
/// consists of a little-endian 16-bit `LEN`, a little-endian 16-bit `NLEN`
/// that must be the bitwise complement of `LEN`, and `LEN` raw bytes.
///
/// # Errors
///
/// Returns `"DEFLATE: stored LEN/NLEN mismatch"` when the two header fields
/// are not complements, and forwards the reader's error when the header or
/// payload is truncated.
fn inflate_stored(br: &mut BitReader, out: &mut Vec<u8>) -> Result<(), &'static str> {
    br.align_to_byte();

    let (len_field, nlen_field) = br.read_bytes(4)?.split_at(2);
    let len = u16::from_le_bytes(len_field.try_into().unwrap());
    let nlen = u16::from_le_bytes(nlen_field.try_into().unwrap());
    if len != !nlen {
        return Err("DEFLATE: stored LEN/NLEN mismatch");
    }

    let payload = br.read_bytes(usize::from(len))?;
    out.extend_from_slice(payload);
    Ok(())
}
/// Canonical Huffman decoding table built from per-symbol code lengths.
///
/// Codes are assigned in canonical order: shorter codes first, and within one
/// length in increasing symbol order. Decoding first tries a direct lookup on
/// the next `FAST_PATH_BITS` bits and falls back to a bit-by-bit walk for
/// longer codes or when fewer than `FAST_PATH_BITS` bits remain in the input.
pub struct HuffmanTable {
    /// `base_code[len]` is the numerically smallest code of length `len`.
    base_code: [u32; 16],
    /// `base_sym[len]` is the index in `symbols` of the first symbol of length `len`.
    base_sym: [u32; 16],
    /// Symbols with a non-zero code length, ordered by (length, symbol value).
    symbols: Vec<u16>,
    /// Direct lookup keyed by the next `FAST_PATH_BITS` bits (first bit most
    /// significant). An entry is `(len << 16) | symbol`; `0` marks "no code of
    /// at most `FAST_PATH_BITS` bits starts with these bits".
    fast_path: [u32; 512],
    /// Longest code length in use; `0` for a table without any code.
    max_len: usize,
}

/// Number of leading bits resolved by one lookup in `HuffmanTable::fast_path`.
const FAST_PATH_BITS: usize = 9;

impl HuffmanTable {
    /// Builds a table from `lengths`, where `lengths[sym]` is the code length
    /// of symbol `sym` in bits and `0` means the symbol is unused.
    ///
    /// A slice with no non-zero length yields an empty table on which every
    /// `decode` call fails. Incomplete length sets (unused code space) are
    /// accepted; the unused bit patterns simply fail to decode.
    ///
    /// # Errors
    ///
    /// * `"DEFLATE: Huffman code length exceeds 15"` if any length is above 15.
    /// * `"DEFLATE: over-subscribed Huffman code lengths"` if the lengths ask
    ///   for more codes than a prefix code can provide. Without this check the
    ///   canonical codes overflow their bit width and the fast-table fill would
    ///   index out of bounds on attacker-controlled input.
    pub fn from_lengths(lengths: &[u8]) -> Result<Self, &'static str> {
        let max_len = usize::from(lengths.iter().copied().max().unwrap_or(0));
        if max_len == 0 {
            return Ok(Self {
                base_code: [0; 16],
                base_sym: [0; 16],
                symbols: Vec::new(),
                fast_path: [0; 1 << FAST_PATH_BITS],
                max_len: 0,
            });
        }
        if max_len > 15 {
            return Err("DEFLATE: Huffman code length exceeds 15");
        }

        // count[len] = number of symbols whose code is `len` bits long.
        let mut count = [0u32; 16];
        for &len in lengths.iter().filter(|&&len| len > 0) {
            count[usize::from(len)] += 1;
        }

        // Kraft check: each extra bit doubles the free code space, and every
        // code of that length uses up one slot. Running out means the set of
        // lengths cannot be realised as a prefix code.
        let mut unused_codes = 1u32;
        for &used in &count[1..=max_len] {
            unused_codes = (unused_codes << 1)
                .checked_sub(used)
                .ok_or("DEFLATE: over-subscribed Huffman code lengths")?;
        }

        // First canonical code and first symbol slot for every length.
        let mut base_code = [0u32; 16];
        let mut base_sym = [0u32; 16];
        for bits in 1..=max_len {
            base_code[bits] = (base_code[bits - 1] + count[bits - 1]) << 1;
            base_sym[bits] = base_sym[bits - 1] + count[bits - 1];
        }

        // Place the symbols in canonical order.
        let total = count.iter().sum::<u32>() as usize;
        let mut symbols = vec![0u16; total];
        let mut next = base_sym;
        for (sym, &len) in lengths.iter().enumerate() {
            if len > 0 {
                let slot = &mut next[usize::from(len)];
                symbols[*slot as usize] = sym as u16;
                *slot += 1;
            }
        }

        // Every code of at most FAST_PATH_BITS bits owns all table slots whose
        // leading bits equal that code. The Kraft check above guarantees
        // `code < 2^len`, so the filled range always stays inside the table.
        let mut fast_path = [0u32; 1 << FAST_PATH_BITS];
        for len in 1..=max_len.min(FAST_PATH_BITS) {
            let shift = FAST_PATH_BITS - len;
            let first_sym = base_sym[len] as usize;
            let same_len = &symbols[first_sym..first_sym + count[len] as usize];
            for (code, &sym) in (base_code[len]..).zip(same_len) {
                let entry = ((len as u32) << 16) | u32::from(sym);
                let start = (code << shift) as usize;
                fast_path[start..start + (1usize << shift)].fill(entry);
            }
        }

        Ok(Self {
            base_code,
            base_sym,
            symbols,
            fast_path,
            max_len,
        })
    }

    /// Decodes the next symbol from `br`, consuming exactly the bits of its code
    /// on success.
    ///
    /// # Errors
    ///
    /// Returns `"DEFLATE: no Huffman code matched"` when the upcoming bits are
    /// not a code of this table, and forwards the reader's error if the input
    /// ends in the middle of a code.
    pub fn decode(&self, br: &mut BitReader) -> Result<u16, &'static str> {
        // Fast path: one lookup on the next FAST_PATH_BITS bits, when available.
        if let Some(window) = br.peek_msb_bits(FAST_PATH_BITS) {
            let entry = self.fast_path[window as usize];
            if entry != 0 {
                br.consume_bits((entry >> 16) as usize);
                return Ok((entry & 0xFFFF) as u16);
            }
        }

        // Slow path: extend the code one bit at a time and test whether it falls
        // into the range of codes assigned to the current length.
        let mut code = 0u32;
        for len in 1..=self.max_len {
            code = (code << 1) | br.read_bits(1)?;
            let first = self.base_code[len];
            // `base_code[len + 1] >> 1` equals `first + count[len]`. At the
            // maximum length the symbol-index bound below does the limiting.
            let end = if len < self.max_len {
                self.base_code[len + 1] >> 1
            } else {
                code + 1
            };
            if (first..end).contains(&code) {
                let index = (self.base_sym[len] + (code - first)) as usize;
                if let Some(&sym) = self.symbols.get(index) {
                    return Ok(sym);
                }
            }
        }
        Err("DEFLATE: no Huffman code matched")
    }
}
/// Builds the literal/length table used by fixed-Huffman blocks.
///
/// Code lengths by symbol range: 0..=143 use 8 bits, 144..=255 use 9 bits,
/// 256..=279 use 7 bits and 280..=287 use 8 bits.
fn fixed_litlen_table() -> HuffmanTable {
    // Start from the 8-bit default and overwrite the two ranges that differ.
    let mut code_lengths = [8u8; 288];
    code_lengths[144..256].fill(9);
    code_lengths[256..280].fill(7);
    HuffmanTable::from_lengths(&code_lengths).unwrap()
}
/// Builds the distance table used by fixed-Huffman blocks: 30 symbols, each
/// with a 5-bit code.
fn fixed_dist_table() -> HuffmanTable {
    const FIXED_DIST_LENGTHS: [u8; 30] = [5; 30];
    HuffmanTable::from_lengths(&FIXED_DIST_LENGTHS).unwrap()
}
/// Maps a length symbol (257..=285) to `(base length, number of extra bits)`.
///
/// # Errors
///
/// Returns `"DEFLATE: bad length symbol"` for any symbol outside 257..=285.
fn length_base_extra(sym: u16) -> Result<(u16, u8), &'static str> {
    // One `(base, extra_bits)` row per symbol, starting at symbol 257.
    const LENGTH_TABLE: [(u16, u8); 29] = [
        (3, 0),
        (4, 0),
        (5, 0),
        (6, 0),
        (7, 0),
        (8, 0),
        (9, 0),
        (10, 0),
        (11, 1),
        (13, 1),
        (15, 1),
        (17, 1),
        (19, 2),
        (23, 2),
        (27, 2),
        (31, 2),
        (35, 3),
        (43, 3),
        (51, 3),
        (59, 3),
        (67, 4),
        (83, 4),
        (99, 4),
        (115, 4),
        (131, 5),
        (163, 5),
        (195, 5),
        (227, 5),
        (258, 0),
    ];

    sym.checked_sub(257)
        .and_then(|row| LENGTH_TABLE.get(usize::from(row)))
        .copied()
        .ok_or("DEFLATE: bad length symbol")
}
/// Maps a distance symbol (0..=29) to `(base distance, number of extra bits)`.
///
/// # Errors
///
/// Returns `"DEFLATE: bad distance symbol"` for any symbol above 29.
fn distance_base_extra(sym: u16) -> Result<(u16, u8), &'static str> {
    // One `(base, extra_bits)` row per distance symbol.
    const DISTANCE_TABLE: [(u16, u8); 30] = [
        (1, 0),
        (2, 0),
        (3, 0),
        (4, 0),
        (5, 1),
        (7, 1),
        (9, 2),
        (13, 2),
        (17, 3),
        (25, 3),
        (33, 4),
        (49, 4),
        (65, 5),
        (97, 5),
        (129, 6),
        (193, 6),
        (257, 7),
        (385, 7),
        (513, 8),
        (769, 8),
        (1025, 9),
        (1537, 9),
        (2049, 10),
        (3073, 10),
        (4097, 11),
        (6145, 11),
        (8193, 12),
        (12289, 12),
        (16385, 13),
        (24577, 13),
    ];

    DISTANCE_TABLE
        .get(usize::from(sym))
        .copied()
        .ok_or("DEFLATE: bad distance symbol")
}
/// Decodes the body of one Huffman-coded block, appending the produced bytes
/// to `out`, until the end-of-block symbol (256) is read.
///
/// Symbols below 256 are literal bytes. Symbols above 256 start a
/// back-reference: a length (base plus extra bits) followed by a distance
/// symbol decoded with `dist` (base plus extra bits). The referenced bytes are
/// copied from the already produced output.
///
/// # Errors
///
/// Returns `"DEFLATE: back-reference past start"` when a distance reaches
/// before the beginning of `out`, and forwards errors from symbol decoding,
/// the length/distance lookups and the bit reader.
fn inflate_huffman(
    br: &mut BitReader,
    out: &mut Vec<u8>,
    litlen: &HuffmanTable,
    dist: &HuffmanTable,
) -> Result<(), &'static str> {
    loop {
        let sym = litlen.decode(br)?;
        match sym {
            0..=255 => out.push(sym as u8),
            256 => break,
            _ => {
                let (len_base, len_extra_bits) = length_base_extra(sym)?;
                let length = usize::from(len_base) + br.read_bits(len_extra_bits)? as usize;

                let dist_sym = dist.decode(br)?;
                let (dist_base, dist_extra_bits) = distance_base_extra(dist_sym)?;
                let distance = usize::from(dist_base) + br.read_bits(dist_extra_bits)? as usize;

                if distance > out.len() {
                    return Err("DEFLATE: back-reference past start");
                }

                // Copy one byte at a time: when `length > distance` the source
                // range overlaps bytes that this very copy is producing.
                let first = out.len() - distance;
                for src in first..first + length {
                    let byte = out[src];
                    out.push(byte);
                }
            }
        }
    }
    Ok(())
}
/// Order in which a dynamic block header stores the code lengths of the
/// code-length alphabet: the `i`-th 3-bit value read by `read_dynamic_tables`
/// is the length for symbol `CODE_LEN_ORDER[i]`.
const CODE_LEN_ORDER: [usize; 19] = [
16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15,
];
/// Reads the header of a dynamic-Huffman block and builds its two tables.
///
/// The header holds three counts (`hlit` = 5 bits + 257, `hdist` = 5 bits + 1,
/// `hclen` = 4 bits + 4), then `hclen` 3-bit lengths for the code-length
/// alphabet in `CODE_LEN_ORDER`, then `hlit + hdist` code lengths encoded with
/// that alphabet. In that encoding symbols 0..=15 are literal lengths, 16
/// repeats the previous length 3..=6 times, 17 writes 3..=10 zeros and 18
/// writes 11..=138 zeros.
///
/// Returns `(literal/length table, distance table)`.
///
/// # Errors
///
/// Fails when a repeat has no previous length, when a run would exceed the
/// announced number of lengths, when an unknown code-length symbol is decoded,
/// or when the bit reader or table construction reports an error.
fn read_dynamic_tables(br: &mut BitReader) -> Result<(HuffmanTable, HuffmanTable), &'static str> {
    let hlit = br.read_bits(5)? as usize + 257;
    let hdist = br.read_bits(5)? as usize + 1;
    let hclen = br.read_bits(4)? as usize + 4;

    // Lengths of the code-length alphabet arrive in a fixed permuted order;
    // entries that are not transmitted stay zero.
    let mut code_len_lens = [0u8; 19];
    for &slot in CODE_LEN_ORDER.iter().take(hclen) {
        code_len_lens[slot] = br.read_bits(3)? as u8;
    }
    let code_len_table = HuffmanTable::from_lengths(&code_len_lens)?;

    // Literal/length and distance lengths form one continuous sequence, so a
    // run may cross the boundary between the two alphabets.
    let total = hlit + hdist;
    let mut lens = vec![0u8; total];
    let mut filled = 0;
    while filled < total {
        let sym = code_len_table.decode(br)?;
        let (value, run, overflow_msg) = match sym {
            0..=15 => {
                lens[filled] = sym as u8;
                filled += 1;
                continue;
            }
            16 => {
                if filled == 0 {
                    return Err("DEFLATE: repeat-prev at start");
                }
                let previous = lens[filled - 1];
                (
                    previous,
                    br.read_bits(2)? as usize + 3,
                    "DEFLATE: repeat overflow",
                )
            }
            17 => (
                0,
                br.read_bits(3)? as usize + 3,
                "DEFLATE: zero-repeat overflow",
            ),
            18 => (
                0,
                br.read_bits(7)? as usize + 11,
                "DEFLATE: long-zero-repeat overflow",
            ),
            _ => return Err("DEFLATE: bad code-length symbol"),
        };

        let end = filled + run;
        if end > total {
            return Err(overflow_msg);
        }
        lens[filled..end].fill(value);
        filled = end;
    }

    let (litlen_lens, dist_lens) = lens.split_at(hlit);
    let litlen = HuffmanTable::from_lengths(litlen_lens)?;
    let dist = HuffmanTable::from_lengths(dist_lens)?;
    Ok((litlen, dist))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `payload` as one stored block. The first byte carries the
    /// final-block flag in bit 0 and block type 0 in the following bits.
    fn stored_block(last: bool, payload: &[u8]) -> Vec<u8> {
        let len = payload.len() as u16;
        let mut block = vec![u8::from(last)];
        block.extend_from_slice(&len.to_le_bytes());
        block.extend_from_slice(&(!len).to_le_bytes());
        block.extend_from_slice(payload);
        block
    }

    /// Single bits come out least-significant bit first.
    #[test]
    fn bitreader_lsb_first() {
        let mut br = BitReader::new(&[0xA5]);
        let mut bits = Vec::new();
        for _ in 0..8 {
            bits.push(br.read_bits(1).unwrap());
        }
        assert_eq!(bits, [1, 0, 1, 0, 0, 1, 0, 1]);
    }

    /// A multi-bit read spanning two bytes is assembled little-endian.
    #[test]
    fn bitreader_multi_bit() {
        let mut br = BitReader::new(&[0x34, 0x12]);
        let value = br.read_bits(16).unwrap();
        assert_eq!(value, 0x1234);
    }

    /// One final stored block round-trips its payload.
    #[test]
    fn inflate_stored_block() {
        let data = stored_block(true, b"Hello");
        assert_eq!(inflate(&data).unwrap(), b"Hello");
    }

    /// Lengths 3,3,3,3,3,2,4,4 give symbol 5 the code `00`.
    #[test]
    fn canonical_huffman_rfc_example() {
        let lens = [3u8, 3, 3, 3, 3, 2, 4, 4];
        let table = HuffmanTable::from_lengths(&lens).unwrap();
        let mut br = BitReader::new(&[0b0000_0000]);
        assert_eq!(table.decode(&mut br).unwrap(), 5);
    }

    /// Fixed-Huffman block containing two literals.
    #[test]
    fn inflate_fixed_huffman_real_zlib_payload() {
        let payload = [0xF3u8, 0xC8, 0x04, 0x00];
        assert_eq!(inflate(&payload).unwrap(), b"Hi");
    }

    /// A non-final stored block followed by a final one are concatenated.
    #[test]
    fn inflate_two_stored_blocks() {
        let mut data = stored_block(false, b"abc");
        data.extend(stored_block(true, b"def"));
        assert_eq!(inflate(&data).unwrap(), b"abcdef");
    }

    /// Fixed-Huffman block whose tail is produced by an overlapping back-reference.
    #[test]
    fn inflate_lz77_back_reference() {
        let payload = [0x4Bu8, 0x4C, 0x4A, 0x4E, 0x04, 0x23, 0x00];
        assert_eq!(inflate(&payload).unwrap(), b"abcabcabc");
    }

    /// Fixed-Huffman block with a longer literal run.
    #[test]
    fn inflate_hello_world() {
        let payload = [
            0xF3u8, 0x48, 0xCD, 0xC9, 0xC9, 0xD7, 0x51, 0x28, 0xCF, 0x2F, 0xCA, 0x49, 0x51, 0x04,
            0x00,
        ];
        assert_eq!(inflate(&payload).unwrap(), b"Hello, world!");
    }
}