use super::tables::{CODE_LENGTH_ORDER, DISTANCE_TABLE, LENGTH_TABLE};
use super::tokens::{CodeLengths, LZ77Block, LZ77Token};
use crate::bits::{BitRead, BitReader, SliceBitReader};
use crate::error::{Error, Result};
use crate::huffman::HuffmanDecoder;
use std::io::Read;
/// Incremental parser that turns a DEFLATE bit stream into [`LZ77Block`]s.
///
/// Wraps a bit-level reader and remembers whether the final block of the
/// current stream has already been parsed.
pub struct DeflateParser<B>
where
    B: BitRead,
{
    /// Bit-level input source the blocks are read from.
    bits: B,
    /// Set once a block flagged as final has been parsed.
    finished: bool,
}
impl<R: Read> DeflateParser<BitReader<R>> {
    /// Creates a parser that pulls its input from `reader` through a [`BitReader`].
    ///
    /// The parser starts out not finished, so the first `parse_block` call
    /// reads a block header from the input.
    pub fn new(reader: R) -> Self {
        let bits = BitReader::new(reader);
        Self { bits, finished: false }
    }
}
impl<'a> DeflateParser<SliceBitReader<'a>> {
    /// Creates a parser over an in-memory buffer.
    ///
    /// A [`SliceBitReader`] is built over `data`, and `offset` is handed to its
    /// `set_position` before any block is parsed.
    // NOTE(review): whether `offset` counts bits or bytes is decided by
    // `SliceBitReader::set_position`, which is not visible here — confirm.
    pub fn from_slice(data: &'a [u8], offset: usize) -> Self {
        let mut reader = SliceBitReader::new(data);
        reader.set_position(offset);
        Self { bits: reader, finished: false }
    }
}
/// Reads a dynamic-Huffman block header and builds the decoders it describes.
///
/// The header is consumed in this order: the three count fields (5, 5 and 4
/// bits), then `hclen` 3-bit lengths for the code-length alphabet (stored in
/// `CODE_LENGTH_ORDER` order), and finally `hlit + hdist` run-length-coded
/// code lengths.
///
/// Returns `(literal_lengths, distance_lengths, literal_decoder,
/// distance_decoder)`. The distance decoder is `None` when every distance code
/// length is zero.
///
/// # Errors
///
/// * Any error raised by the bit reader or by `HuffmanDecoder`.
/// * [`Error::HuffmanIncomplete`] when repeat code 16 has no previous length to
///   copy, or when a run would write past `hlit + hdist` entries.
/// * [`Error::InvalidHuffmanSymbol`] when the code-length decoder yields a
///   symbol above 18.
// The four-element tuple in the return type is what trips `type_complexity`.
#[allow(clippy::type_complexity)]
fn parse_dynamic_huffman_tables_inner<B: BitRead>(
    bits: &mut B,
) -> Result<(Vec<u8>, Vec<u8>, HuffmanDecoder, Option<HuffmanDecoder>)> {
    let hlit = bits.read_bits(5)? as usize + 257;
    let hdist = bits.read_bits(5)? as usize + 1;
    let hclen = bits.read_bits(4)? as usize + 4;

    // Lengths for the code-length alphabet arrive in a permuted order.
    let mut code_length_lengths = [0u8; 19];
    for &slot in &CODE_LENGTH_ORDER[..hclen] {
        code_length_lengths[slot] = bits.read_bits(3)? as u8;
    }
    let code_length_decoder = HuffmanDecoder::from_code_lengths(&code_length_lengths)?;

    let total_codes = hlit + hdist;
    let mut all_lengths: Vec<u8> = Vec::with_capacity(total_codes);
    while all_lengths.len() < total_codes {
        let sym = code_length_decoder.decode(bits)?;
        // Every symbol boils down to "append `fill` `run` times".
        let (run, fill) = match sym {
            0..=15 => (1, sym as u8),
            16 => {
                // The repeat count is read before the previous length is looked up.
                let run = bits.read_bits(2)? as usize + 3;
                let prev = *all_lengths.last().ok_or(Error::HuffmanIncomplete)?;
                (run, prev)
            }
            17 => (bits.read_bits(3)? as usize + 3, 0),
            18 => (bits.read_bits(7)? as usize + 11, 0),
            _ => return Err(Error::InvalidHuffmanSymbol(sym)),
        };
        let filled = all_lengths.len() + run;
        if filled > total_codes {
            return Err(Error::HuffmanIncomplete);
        }
        all_lengths.resize(filled, fill);
    }

    let (literal_part, distance_part) = all_lengths.split_at(hlit);
    let literal_lengths = literal_part.to_vec();
    let distance_lengths = distance_part.to_vec();

    let lit_decoder = HuffmanDecoder::from_code_lengths(&literal_lengths)?;
    let dist_decoder = if distance_lengths.iter().any(|&len| len != 0) {
        Some(HuffmanDecoder::from_code_lengths(&distance_lengths)?)
    } else {
        None
    };
    Ok((literal_lengths, distance_lengths, lit_decoder, dist_decoder))
}
/// Reads a dynamic-Huffman block header from `bits` and returns only the decoders.
///
/// Thin wrapper over `parse_dynamic_huffman_tables_inner` that discards the raw
/// code-length vectors. The second element is `None` when the header declares
/// no non-zero distance code length.
///
/// # Errors
///
/// Propagates every error from `parse_dynamic_huffman_tables_inner`.
pub fn parse_dynamic_huffman_tables<B: BitRead>(
    bits: &mut B,
) -> Result<(HuffmanDecoder, Option<HuffmanDecoder>)> {
    parse_dynamic_huffman_tables_inner(bits).map(
        |(_literal_lengths, _distance_lengths, lit_decoder, dist_decoder)| {
            (lit_decoder, dist_decoder)
        },
    )
}
impl<B: BitRead> DeflateParser<B> {
pub fn parse_block(&mut self) -> Result<Option<LZ77Block>> {
if self.finished {
return Ok(None);
}
let is_final = self.bits.read_bit()?;
let block_type = self.bits.read_bits(2)? as u8;
let block = match block_type {
0 => self.parse_stored_block(is_final)?,
1 => self.parse_fixed_block(is_final)?,
2 => self.parse_dynamic_block(is_final)?,
_ => return Err(Error::InvalidBlockType(block_type)),
};
if is_final {
self.finished = true;
}
Ok(Some(block))
}
/// Parses a stored (uncompressed) block into one literal token per byte.
///
/// The reader is aligned to a byte boundary, `len` and `nlen` are read with
/// `read_u16_le`, and `len` payload bytes follow. The token list is terminated
/// with [`LZ77Token::EndOfBlock`] and the block is tagged with type 0.
///
/// # Errors
///
/// [`Error::StoredBlockLengthMismatch`] when `nlen` is not the bitwise
/// complement of `len`, plus any bit-reader error.
fn parse_stored_block(&mut self, is_final: bool) -> Result<LZ77Block> {
    self.bits.align_to_byte();
    let len = self.bits.read_u16_le()?;
    let nlen = self.bits.read_u16_le()?;
    if nlen != !len {
        return Err(Error::StoredBlockLengthMismatch { len, nlen });
    }

    let byte_count = len as usize;
    // One extra slot for the trailing end-of-block marker.
    let mut tokens = Vec::with_capacity(byte_count + 1);
    while tokens.len() < byte_count {
        let byte = self.bits.read_bits(8)? as u8;
        tokens.push(LZ77Token::Literal(byte));
    }
    tokens.push(LZ77Token::EndOfBlock);
    Ok(LZ77Block::new(tokens, is_final, 0))
}
/// Parses a fixed-Huffman block (type 1).
///
/// Builds the fixed literal/length and distance decoders, decodes symbols up
/// to end-of-block, and wraps the tokens in an [`LZ77Block`] tagged with type 1.
///
/// # Errors
///
/// Propagates any error from `decode_symbols`.
fn parse_fixed_block(&mut self, is_final: bool) -> Result<LZ77Block> {
    self.decode_symbols(
        &HuffmanDecoder::fixed_literal_length(),
        &HuffmanDecoder::fixed_distance(),
    )
    .map(|tokens| LZ77Block::new(tokens, is_final, 1))
}
/// Parses a dynamic-Huffman block (type 2).
///
/// Reads the code-length header, decodes the symbols with the resulting
/// decoders, and returns an [`LZ77Block`] tagged with type 2 whose
/// `code_lengths` holds the raw literal and distance code lengths.
///
/// # Errors
///
/// Propagates any error from header parsing or symbol decoding.
fn parse_dynamic_block(&mut self, is_final: bool) -> Result<LZ77Block> {
    let (literal_lengths, distance_lengths, lit_decoder, dist_decoder) =
        parse_dynamic_huffman_tables_inner(&mut self.bits)?;
    let code_lengths = CodeLengths { literal_lengths, distance_lengths };

    let tokens = self.decode_symbols_with_optional_dist(&lit_decoder, dist_decoder.as_ref())?;

    let mut block = LZ77Block::new(tokens, is_final, 2);
    block.code_lengths = Some(code_lengths);
    Ok(block)
}
#[inline(never)] fn decode_symbols(
&mut self,
lit_decoder: &HuffmanDecoder,
dist_decoder: &HuffmanDecoder,
) -> Result<Vec<LZ77Token>> {
let mut tokens = Vec::with_capacity(8192);
loop {
let sym = lit_decoder.decode(&mut self.bits)?;
if sym <= 255 {
tokens.push(LZ77Token::Literal(sym as u8));
continue;
}
if sym == 256 {
tokens.push(LZ77Token::EndOfBlock);
break;
}
if sym > 285 {
return Err(Error::InvalidLengthCode(sym));
}
let len_idx = (sym - 257) as usize;
let (base_len, extra_bits) = unsafe { *LENGTH_TABLE.get_unchecked(len_idx) };
let extra = if extra_bits > 0 { self.bits.read_bits(extra_bits)? } else { 0 };
let length = base_len + extra as u16;
let dist_sym = dist_decoder.decode(&mut self.bits)?;
if dist_sym > 29 {
return Err(Error::InvalidDistanceCode(dist_sym));
}
let (base_dist, dist_extra_bits) =
unsafe { *DISTANCE_TABLE.get_unchecked(dist_sym as usize) };
let dist_extra =
if dist_extra_bits > 0 { self.bits.read_bits(dist_extra_bits)? } else { 0 };
let distance = base_dist + dist_extra as u16;
tokens.push(LZ77Token::Copy { length, distance });
}
Ok(tokens)
}
fn decode_symbols_with_optional_dist(
&mut self,
lit_decoder: &HuffmanDecoder,
dist_decoder: Option<&HuffmanDecoder>,
) -> Result<Vec<LZ77Token>> {
if let Some(dist_dec) = dist_decoder {
return self.decode_symbols(lit_decoder, dist_dec);
}
let mut tokens = Vec::with_capacity(8192);
loop {
let sym = lit_decoder.decode(&mut self.bits)?;
match sym {
0..=255 => tokens.push(LZ77Token::Literal(sym as u8)),
256 => {
tokens.push(LZ77Token::EndOfBlock);
break;
}
257..=285 => {
return Err(Error::InvalidDistanceCode(0));
}
_ => return Err(Error::InvalidLengthCode(sym)),
}
}
Ok(tokens)
}
/// Forwards to the underlying bit reader's `bytes_read()` counter.
pub fn bytes_read(&self) -> u64 {
self.bits.bytes_read()
}
/// Returns `true` once a block flagged as final has been parsed and the parser
/// has not been re-armed by `read_trailer_and_check_next`.
pub fn is_finished(&self) -> bool {
self.finished
}
/// Consumes the parser and returns the underlying bit reader.
pub fn into_inner(self) -> B {
self.bits
}
pub fn read_trailer_and_check_next(&mut self) -> Result<bool> {
if !self.finished {
return Err(Error::Internal("Cannot read trailer before DEFLATE is finished".into()));
}
self.bits.align_to_byte();
let _crc32 = self.bits.read_u32_le()?;
let _isize = self.bits.read_u32_le()?;
match self.bits.read_bits(8) {
Ok(b1) => {
match self.bits.read_bits(8) {
Ok(b2) => {
if b1 == 0x1f && b2 == 0x8b {
let method = self.bits.read_bits(8)? as u8;
if method != 8 {
return Err(Error::UnsupportedCompressionMethod(method));
}
let flags = self.bits.read_bits(8)? as u8;
let _mtime = self.bits.read_u32_le()?;
let _xfl = self.bits.read_bits(8)?;
let _os = self.bits.read_bits(8)?;
const FEXTRA: u8 = 1 << 2;
const FNAME: u8 = 1 << 3;
const FCOMMENT: u8 = 1 << 4;
const FHCRC: u8 = 1 << 1;
if flags & FEXTRA != 0 {
let xlen = self.bits.read_u16_le()?;
for _ in 0..xlen {
self.bits.read_bits(8)?;
}
}
if flags & FNAME != 0 {
loop {
if self.bits.read_bits(8)? == 0 {
break;
}
}
}
if flags & FCOMMENT != 0 {
loop {
if self.bits.read_bits(8)? == 0 {
break;
}
}
}
if flags & FHCRC != 0 {
let _hcrc = self.bits.read_u16_le()?;
}
self.finished = false;
Ok(true)
} else {
Err(Error::InvalidGzipMagic(((b2 as u16) << 8) | (b1 as u16)))
}
}
Err(Error::UnexpectedEof) => Ok(false), Err(e) => Err(e),
}
}
Err(Error::UnexpectedEof) => Ok(false), Err(e) => Err(e),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::io::Cursor;
/// A final stored block carrying "Hello" yields five literals and an end marker.
#[test]
fn test_parse_stored_block() {
    // Block header byte, then LEN = 5 and NLEN = !5, then the payload.
    let mut data = vec![0b0000_0001, 0x05, 0x00, 0xFA, 0xFF];
    data.extend_from_slice(b"Hello");
    let mut parser = DeflateParser::new(Cursor::new(data));

    let block = parser.parse_block().unwrap().unwrap();

    assert!(block.is_final);
    assert_eq!(block.block_type, 0);
    assert_eq!(block.tokens.len(), 6);
    for (index, &byte) in b"Hello".iter().enumerate() {
        assert_eq!(block.tokens[index], LZ77Token::Literal(byte));
    }
    assert_eq!(block.tokens[5], LZ77Token::EndOfBlock);
}
/// Blocks produced by a real encoder add up to the original input length.
#[test]
fn test_parse_real_gzip() {
    use std::io::Write;

    // Despite the test name, the input comes from `DeflateEncoder` and is fed
    // straight to `parse_block`, with no gzip header handling involved.
    let mut encoder =
        flate2::write::DeflateEncoder::new(Vec::new(), flate2::Compression::default());
    encoder.write_all(b"Hello, World!").unwrap();
    let compressed = encoder.finish().unwrap();

    let mut parser = DeflateParser::new(Cursor::new(compressed));
    let mut total_size = 0;
    loop {
        let block = match parser.parse_block().unwrap() {
            Some(block) => block,
            None => break,
        };
        total_size += block.uncompressed_size();
        if block.is_final {
            break;
        }
    }

    assert_eq!(total_size, 13);
}
/// The dynamic-Huffman header of an encoder-produced stream parses on its own.
#[test]
fn test_parse_dynamic_header_only() {
    use std::io::Write;

    let input: Vec<u8> = (0u8..=127).cycle().take(512).collect();
    let mut encoder =
        flate2::write::DeflateEncoder::new(Vec::new(), flate2::Compression::default());
    encoder.write_all(&input).unwrap();
    let compressed = encoder.finish().unwrap();

    let mut bits = SliceBitReader::new(&compressed);
    let _bfinal = bits.read_bit().unwrap();
    // The test expects the encoder to have picked a dynamic block (type 2).
    assert_eq!(bits.read_bits(2).unwrap(), 2);

    let (lit_decoder, _dist_decoder) = parse_dynamic_huffman_tables(&mut bits).unwrap();
    assert!(!lit_decoder.is_empty());
}
}