use super::deflate::MAX_MATCH_OFFSET;
use super::dict_decoder::DictDecoder;
use super::huffman_bit_writer::END_BLOCK_MARKER;
use crate::compat;
use crate::errors;
use crate::io as ggio;
use crate::math::bits;
use std::sync::OnceLock;
/// Size of the per-length tables (`count`, `nextcode`) used by `HuffmanDecoder::init`.
const MAX_CODE_LEN: usize = 16;
/// Largest literal/length alphabet size accepted by `Reader::read_huffman`.
pub(super) const MAX_NUM_LIT: usize = 286;
/// Largest distance alphabet size accepted by `Reader::read_huffman`; distance symbols
/// at or above this value are rejected as corrupt input in `Reader::huffman_block`.
const MAX_NUM_DIST: usize = 30;
/// Number of entries in `Reader::codebits` (the code-length alphabet).
const NUM_CODES: usize = 19;
/// Returns the process-wide decoder for fixed-Huffman blocks.
///
/// The table is built by `fixed_huffman_decoder_init` on first use and the same
/// instance is handed out on every later call.
pub(super) fn get_fixed_huffman_decoder() -> &'static HuffmanDecoder {
    static FIXED_DECODER: OnceLock<HuffmanDecoder> = OnceLock::new();
    FIXED_DECODER.get_or_init(fixed_huffman_decoder_init)
}
/// Number of low input bits used to index the primary table `HuffmanDecoder::chunks`.
const HUFFMAN_CHUNK_BITS: u32 = 9;
/// Number of entries in the primary table (`1 << HUFFMAN_CHUNK_BITS`).
const HUFFMAN_NUM_CHUNKS: u32 = 1 << HUFFMAN_CHUNK_BITS;
/// Mask selecting the code-length field (low 4 bits) of a table entry.
const HUFFMAN_COUNT_MASK: u32 = 15;
/// Right shift that extracts the value field (symbol or link-table index) of a table entry.
const HUFFMAN_VALUE_SHIFT: u32 = 4;
/// Table-driven Huffman decoder.
///
/// Every table entry packs `(value << HUFFMAN_VALUE_SHIFT) | code_length`.
///
/// * `min` - smallest non-zero code length seen by `init`; `Reader::huff_sym` starts by
///   making at least this many bits available. A value of 0 marks an unbuilt decoder.
/// * `chunks` - primary table of `HUFFMAN_NUM_CHUNKS` entries, indexed by the low
///   `HUFFMAN_CHUNK_BITS` input bits. An entry whose length field exceeds
///   `HUFFMAN_CHUNK_BITS` holds an index into `links` instead of a symbol.
/// * `links` - secondary tables for codes longer than `HUFFMAN_CHUNK_BITS` bits.
/// * `link_mask` - mask applied to the input bits above the first `HUFFMAN_CHUNK_BITS`
///   to index a secondary table.
pub(super) struct HuffmanDecoder {
min: u32, chunks: Vec<u32>, links: Vec<Vec<u32>>, link_mask: u32, }
impl HuffmanDecoder {
/// Creates an empty decoder: a zero-filled primary table, no link tables, and
/// `min == 0` (which `init` treats as "nothing to reset").
pub(super) fn new() -> Self {
    let chunks = vec![0_u32; HUFFMAN_NUM_CHUNKS as usize];
    let links: Vec<Vec<u32>> = Vec::new();
    Self {
        min: 0,
        chunks,
        links,
        link_mask: 0,
    }
}
pub(super) fn init(&mut self, lengths: &[u32]) -> bool {
const SANITY: bool = false;
if self.min != 0 {
self.min = 0;
self.chunks = vec![0; HUFFMAN_NUM_CHUNKS as usize];
self.links = Vec::new();
self.link_mask = 0;
}
let mut count = [0_u32; MAX_CODE_LEN];
let mut min: u32 = 0;
let mut max: u32 = 0;
for n in lengths {
let n = *n;
if n == 0 {
continue;
}
if min == 0 || n < min {
min = n;
}
if n > max {
max = n;
}
count[n as usize] += 1;
}
if max == 0 {
return true;
}
let mut code = 0_u32;
let mut nextcode = [0_u32; MAX_CODE_LEN];
for i in min..=max {
code <<= 1;
nextcode[i as usize] = code;
code += count[i as usize];
}
if code != 1 << (max as usize) && !(code == 1 && max == 1) {
return false;
}
self.min = min;
if max > HUFFMAN_CHUNK_BITS {
let num_links = 1_u32 << (max - HUFFMAN_CHUNK_BITS);
self.link_mask = num_links - 1;
let link = nextcode[HUFFMAN_CHUNK_BITS as usize + 1] >> 1;
self.links = vec![Vec::new(); (HUFFMAN_NUM_CHUNKS - link) as usize];
for j in link..HUFFMAN_NUM_CHUNKS {
let mut reverse = bits::reverse16(j as u16);
reverse >>= 16 - HUFFMAN_CHUNK_BITS;
let off = j - link;
if SANITY && self.chunks[reverse as usize] != 0 {
panic!("impossible: overwriting existing chunk");
}
self.chunks[reverse as usize] =
(off << HUFFMAN_VALUE_SHIFT) | (HUFFMAN_CHUNK_BITS + 1);
self.links[off as usize] = vec![0; num_links as usize];
}
}
for (i, n) in lengths.iter().enumerate() {
let n = *n;
if n == 0 {
continue;
}
let code = nextcode[n as usize];
nextcode[n as usize] += 1;
let chunk = ((i as u32) << HUFFMAN_VALUE_SHIFT) | n;
let mut reverse = bits::reverse16(code as u16) as u32;
reverse >>= 16 - n;
if n <= HUFFMAN_CHUNK_BITS {
let mut off = reverse;
while off < self.chunks.len() as u32 {
if SANITY && self.chunks[off as usize] != 0 {
panic!("impossible: overwriting existing chunk")
}
self.chunks[off as usize] = chunk;
off += 1 << n;
}
} else {
let j = reverse & (HUFFMAN_NUM_CHUNKS - 1);
if SANITY
&& (self.chunks[j as usize] & HUFFMAN_COUNT_MASK) != HUFFMAN_CHUNK_BITS + 1
{
panic!("impossible: not an indirect chunk");
}
let value = self.chunks[j as usize] >> HUFFMAN_VALUE_SHIFT;
let linktab = &mut self.links[value as usize];
reverse >>= HUFFMAN_CHUNK_BITS;
let mut off = reverse;
while off < linktab.len() as u32 {
if SANITY && linktab[off as usize] != 0 {
panic!("impossible: overwriting existing chunk");
}
linktab[off as usize] = chunk;
off += 1 << (n - HUFFMAN_CHUNK_BITS);
}
}
}
if SANITY {
for (i, chunk) in self.chunks.iter().enumerate() {
let chunk = *chunk;
if chunk == 0 {
if code == 1 && i % 2 == 1 {
continue;
}
panic!("impossible: missing chunk")
}
}
for linktab in self.links.iter() {
for chunk in linktab {
if *chunk == 0 {
panic!("impossible: missing chunk");
}
}
}
}
true
}
}
/// Selects how `Reader::huffman_block` obtains distance symbols.
#[derive(PartialEq)]
enum HDDecoder {
/// No distance table: the symbol is read as 5 bits taken directly from the input.
None,
/// Decode the symbol with `Reader::h2`.
H2,
}
/// Selects the table `Reader::huff_sym` uses for `DecoderToUse::HL`.
#[derive(PartialEq)]
enum HLDecoder {
/// Use the shared table returned by `get_fixed_huffman_decoder`.
Fixed,
/// Use `Reader::h1`.
H1,
}
/// Streaming decompressor that pulls compressed bytes from `Input` and hands out
/// decoded bytes through its `read` implementations.
pub struct Reader<Input: std::io::BufRead> {
/// Source of compressed bytes.
r: Input,
/// Number of input bytes consumed so far; reported in corrupt-input errors.
roffset: u64,
/// Bit accumulator; new input bytes are OR-ed in above the `nb` bits already held.
b: u32,
/// Number of valid bits currently held in `b`.
nb: usize,
/// First Huffman table: used for the code-length alphabet while a block header is
/// parsed, then rebuilt as the block's literal/length table.
pub(super) h1: HuffmanDecoder,
/// Distance table built by `read_huffman`.
h2: HuffmanDecoder,
// `bits`: scratch for the literal/length and distance code lengths of a block header.
// `codebits`: scratch for the code lengths of the code-length alphabet.
bits: Vec<u32>, codebits: [u32; NUM_CODES],
/// Output history window; decoded bytes are written here and stashed for `read`.
pub(super) dict: DictDecoder,
/// Scratch for the 4-byte length header of a stored block.
buf: [u8; 4],
/// Which routine `read` runs next.
step: StepFunc,
/// Where `huffman_block` resumes when it is re-entered.
step_state: StepState,
/// Set from the header of the current block when it is flagged as the last one.
final_: bool,
/// Set by `finish_block` once the final block has been completed.
end_of_stream: bool,
/// First error encountered; once set, `read` keeps reporting it.
err: Option<std::io::Error>,
/// Literal/length table selection for the current block.
hl: HLDecoder,
/// Distance table selection for the current block.
hd: HDDecoder,
/// Bytes still to be produced by the copy in progress (history copy or stored block).
copy_len: usize,
/// Backward distance of the history copy in progress.
copy_dist: usize,
}
/// Identifies the routine that `Reader::read` runs on its next iteration.
enum StepFunc {
/// Run `Reader::next_block`.
NextBlock,
/// Run `Reader::huffman_block`.
HuffmanBlock,
/// Run `Reader::copy_data`.
CopyData,
}
/// Resume point used by `Reader::huffman_block` when it is entered.
#[derive(PartialEq, Copy, Clone)]
enum StepState {
/// Start by decoding the next literal/length symbol.
StateInit,
/// Start by continuing the history copy described by `copy_len`/`copy_dist`.
StateDict,
}
/// Table selector passed to `Reader::huff_sym`.
pub(super) enum DecoderToUse {
/// Always `Reader::h1`.
H1,
/// Whatever `Reader::hl` currently selects (the fixed table or `Reader::h1`).
HL,
/// Always `Reader::h2`.
H2,
}
/// Order in which `Reader::read_huffman` stores the 3-bit code lengths it reads:
/// the i-th value read belongs to `codebits[CODE_ORDER[i]]`.
const CODE_ORDER: [usize; 19] = [
16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15,
];
impl<Input: std::io::BufRead> std::io::Read for Reader<Input> {
fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
errors::iores_to_result(crate::io::Reader::read(self, buf))
}
}
impl<'a, Input: std::io::BufRead> Reader<Input> {
/// Creates a reader over `r`, passing an empty dictionary to `new_dict`.
pub fn new(r: Input) -> Reader<Input> {
Self::new_dict(r, &[])
}
pub fn new_dict(r: Input, dict: &'a [u8]) -> Reader<Input> {
Reader {
r,
roffset: 0,
b: 0,
nb: 0,
h1: HuffmanDecoder::new(),
h2: HuffmanDecoder::new(),
bits: vec![0; MAX_NUM_LIT + MAX_NUM_DIST],
codebits: [0; NUM_CODES],
dict: DictDecoder::new(MAX_MATCH_OFFSET, dict),
buf: [0; 4],
step: StepFunc::NextBlock,
step_state: StepState::StateInit,
final_: false,
end_of_stream: false,
err: None,
hl: HLDecoder::Fixed,
hd: HDDecoder::None,
copy_len: 0,
copy_dist: 0,
}
}
/// Reads the 3-bit block header and starts decoding the block it announces.
///
/// The lowest bit is the "final block" flag and the next two bits are the block type:
/// 0 runs `data_block`, 1 decodes with the fixed tables, 2 first reads the block's own
/// tables with `read_huffman`, and 3 is rejected as corrupt input.
///
/// Any failure is recorded in `self.err`; this includes a failure of `read_huffman`,
/// so that decoding does not continue past a bad table description.
fn next_block(&mut self) {
    // 1 flag bit + 2 type bits.
    while self.nb < 1 + 2 {
        if let Err(err) = self.more_bits() {
            self.err = Some(err);
            return;
        }
    }
    self.final_ = self.b & 1 == 1;
    self.b >>= 1;
    let typ = self.b & 3;
    self.b >>= 2;
    self.nb -= 1 + 2;
    match typ {
        0 => self.data_block(),
        1 => {
            self.hl = HLDecoder::Fixed;
            self.hd = HDDecoder::None;
            self.huffman_block();
        }
        2 => match self.read_huffman() {
            Ok(()) => {
                self.hl = HLDecoder::H1;
                self.hd = HDDecoder::H2;
                self.huffman_block();
            }
            // Without recording the error, `read` would see no failure and keep
            // interpreting the remaining input as further blocks.
            Err(err) => self.err = Some(err),
        },
        _ => {
            self.err = Some(new_corrupted_input_error(self.roffset));
        }
    }
}
/// Does nothing and always returns `Ok(())`.
pub fn close(&mut self) -> std::io::Result<()> {
Ok(())
}
/// Parses the table description at the start of a dynamic-Huffman block and builds
/// the literal/length table (`h1`) and the distance table (`h2`) from it.
///
/// # Errors
///
/// Returns a corrupt-input error when the counts are out of range, when a repeat code
/// has nothing to repeat or overruns the table, or when a set of lengths is rejected by
/// `HuffmanDecoder::init`. Errors from reading the input are passed through.
fn read_huffman(&mut self) -> std::io::Result<()> {
    // 5 bits literal/length count, 5 bits distance count, 4 bits code-length count.
    const HEADER_BITS: usize = 5 + 5 + 4;
    while self.nb < HEADER_BITS {
        self.more_bits()?;
    }

    let nlit = 257 + (self.b & 0x1F) as usize;
    if nlit > MAX_NUM_LIT {
        return Err(new_corrupted_input_error(self.roffset));
    }
    self.b >>= 5;

    let ndist = 1 + (self.b & 0x1F) as usize;
    if ndist > MAX_NUM_DIST {
        return Err(new_corrupted_input_error(self.roffset));
    }
    self.b >>= 5;

    let nclen = 4 + (self.b & 0xF) as usize;
    self.b >>= 4;
    self.nb -= HEADER_BITS;

    // The first `nclen` slots (in CODE_ORDER) are transmitted as 3-bit lengths ...
    for &slot in &CODE_ORDER[..nclen] {
        while self.nb < 3 {
            self.more_bits()?;
        }
        self.codebits[slot] = self.b & 0x7;
        self.b >>= 3;
        self.nb -= 3;
    }
    // ... and the remaining slots are implicitly zero.
    for &slot in &CODE_ORDER[nclen..] {
        self.codebits[slot] = 0;
    }
    if !self.h1.init(&self.codebits[0..]) {
        return Err(new_corrupted_input_error(self.roffset));
    }

    // Decode `nlit + ndist` code lengths with the code-length table now held in `h1`.
    let total = nlit + ndist;
    let mut filled = 0;
    while filled < total {
        let sym = self.huff_sym(DecoderToUse::H1)?;
        if sym < 16 {
            // A literal code length.
            self.bits[filled] = sym as u32;
            filled += 1;
            continue;
        }

        // Run-length symbols: (minimum run, extra bits for the run, value repeated).
        let (base_run, extra_bits, value): (usize, usize, u32) = match sym {
            16 => {
                // Repeats the previous length, so there must be one.
                if filled == 0 {
                    return Err(new_corrupted_input_error(self.roffset));
                }
                (3, 2, self.bits[filled - 1])
            }
            17 => (3, 3, 0),
            18 => (11, 7, 0),
            _ => {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "unexpected length code",
                ))
            }
        };

        while self.nb < extra_bits {
            self.more_bits()?;
        }
        let run = base_run + (self.b & ((1_u32 << extra_bits) - 1)) as usize;
        self.b >>= extra_bits;
        self.nb -= extra_bits;

        if filled + run > total {
            return Err(new_corrupted_input_error(self.roffset));
        }
        self.bits[filled..filled + run].fill(value);
        filled += run;
    }

    // Build both tables; the distance table is only attempted if the first succeeds.
    if !self.h1.init(&self.bits[0..nlit]) {
        return Err(new_corrupted_input_error(self.roffset));
    }
    if !self.h2.init(&self.bits[nlit..nlit + ndist]) {
        return Err(new_corrupted_input_error(self.roffset));
    }

    // Raise the minimum read size of `h1` to the length of the end-of-block code
    // when that code is longer.
    let end_marker_len = self.bits[END_BLOCK_MARKER];
    if self.h1.min < end_marker_len {
        self.h1.min = end_marker_len;
    }
    Ok(())
}
// Decodes the symbols of a Huffman-compressed block into the history window.
//
// The routine is resumable. Whenever the window has no room left it stashes the
// pending output, records in `step`/`step_state` where to continue, and returns;
// `read` calls it again later. It also returns when the end-of-block symbol is
// decoded (after `finish_block`) or when an error has been stored in `self.err`.
fn huffman_block(&mut self) {
// The two resume points of the routine.
enum StateMachine {
ReadLiteral,
CopyHistory,
}
// Pick up where the previous invocation left off.
let mut next_step = match self.step_state {
StepState::StateInit => StateMachine::ReadLiteral,
StepState::StateDict => StateMachine::CopyHistory,
};
loop {
match next_step {
StateMachine::ReadLiteral => {
// Decode one literal/length symbol with the table selected by `self.hl`.
let v = match self.huff_sym(DecoderToUse::HL) {
Ok(v) => v,
Err(err) => {
self.err = Some(err);
return;
}
};
// `n` is the number of extra bits that follow the symbol; `length` is the
// base copy length before those extra bits are added.
let n: usize; let length: usize;
match v {
// Literal byte: emit it, pausing if the window is now full.
0..=255 => {
self.dict.write_byte(v as u8);
if self.dict.avail_write() == 0 {
self.dict.stash_flush();
self.step = StepFunc::HuffmanBlock;
self.step_state = StepState::StateInit;
return;
}
continue;
}
// End-of-block symbol.
256 => {
self.finish_block();
return;
}
// Length symbols: each group of symbols shares a number of extra bits.
257..=264 => {
length = v - (257 - 3);
n = 0;
}
265..=268 => {
length = v*2 - (265*2 - 11);
n = 1;
}
269..=272 => {
length = v*4 - (269*4 - 19);
n = 2;
}
273..=276 => {
length = v*8 - (273*8 - 35);
n = 3;
}
277..=280 => {
length = v*16 - (277*16 - 67);
n = 4;
}
281..=284 => {
length = v*32 - (281*32 - 131);
n = 5;
}
285..=285 => {
length = 258;
n = 0;
}
// Anything else is not a valid literal/length symbol.
_ => {
self.err = Some(new_corrupted_input_error(self.roffset));
return;
}
};
let mut length = length;
// Add the extra length bits, if the symbol has any.
if n > 0 {
while self.nb < n {
if let Err(err) = self.more_bits() {
self.err = Some(err);
return;
}
}
length += (self.b & ((1 << n) - 1) as u32) as usize;
self.b >>= n;
self.nb -= n;
}
// Obtain the distance symbol.
let mut dist: usize;
match self.hd {
// No distance table: take 5 bits from the input and reverse their order.
HDDecoder::None => {
while self.nb < 5 {
if let Err(err) = self.more_bits() {
self.err = Some(err);
return;
}
}
dist = bits::reverse8(((self.b & 0x1F) << 3) as u8) as usize;
self.b >>= 5;
self.nb -= 5;
}
// Otherwise decode the symbol with `h2`.
HDDecoder::H2 => {
let res = self.huff_sym(DecoderToUse::H2);
dist = match res {
Ok(v) => v,
Err(err) => {
self.err = Some(err);
return;
}
};
}
}
// Turn the distance symbol into an actual distance.
if (0..4).contains(&dist) {
// Symbols 0..=3 map directly to distances 1..=4.
dist += 1;
} else if (4..MAX_NUM_DIST).contains(&dist) {
// Larger symbols carry `nb` extra bits; the symbol's low bit supplies the
// bit just above them.
let nb = (dist - 2) >> 1;
let mut extra = (dist & 1) << nb;
while self.nb < nb {
if let Err(err) = self.more_bits() {
self.err = Some(err);
return;
}
}
extra |= (self.b & ((1 << nb) - 1) as u32) as usize;
self.b >>= nb;
self.nb -= nb;
dist = 1_usize.overflowing_shl((nb + 1) as u32).0 + 1 + extra;
} else {
self.err = Some(new_corrupted_input_error(self.roffset));
return;
}
// A distance beyond the history available in the window is corrupt input.
// The length is not checked here.
if dist > self.dict.hist_size() {
self.err = Some(new_corrupted_input_error(self.roffset));
return;
}
self.copy_len = length;
self.copy_dist = dist;
next_step = StateMachine::CopyHistory;
}
StateMachine::CopyHistory => {
// Copy from the history; `write_copy` is used only when `try_write_copy`
// copied nothing (presumably the latter is a restricted fast path - confirm
// in `DictDecoder`).
let mut cnt = self.dict.try_write_copy(self.copy_dist, self.copy_len);
if cnt == 0 {
cnt = self.dict.write_copy(self.copy_dist, self.copy_len);
}
self.copy_len -= cnt;
// Pause when the window is full or the copy is unfinished; the next call
// resumes the copy.
if self.dict.avail_write() == 0 || self.copy_len > 0 {
self.dict.stash_flush();
self.step = StepFunc::HuffmanBlock; self.step_state = StepState::StateDict;
return;
}
next_step = StateMachine::ReadLiteral;
}
}
}
}
/// Starts a stored (uncompressed) block.
///
/// Discards any bits left in the accumulator, reads the 4-byte header (a 16-bit
/// little-endian length followed by its bitwise complement), validates it, and then
/// either finishes the block at once (length 0) or begins copying via `copy_data`.
fn data_block(&mut self) {
    // Drop whatever is left in the bit accumulator.
    self.b = 0;
    self.nb = 0;

    let (got, failure) = ggio::read_full(&mut self.r, &mut self.buf[0..4]);
    self.roffset += got as u64;
    if failure.is_some() {
        self.err = failure;
        return;
    }

    let len = u16::from_le_bytes([self.buf[0], self.buf[1]]);
    let len_complement = u16::from_le_bytes([self.buf[2], self.buf[3]]);
    if len_complement != !len {
        self.err = Some(new_corrupted_input_error(self.roffset));
        return;
    }

    if len == 0 {
        self.dict.stash_flush();
        self.finish_block();
    } else {
        self.copy_len = len as usize;
        self.copy_data();
    }
}
/// Copies bytes of a stored block from the input straight into the history window.
///
/// At most `copy_len` bytes are requested per call. If the window fills up or bytes
/// remain, the output is stashed and `step` is set so that `read` calls this routine
/// again; otherwise the block is finished.
fn copy_data(&mut self) {
    let (got, failure) = {
        let dst = self.dict.write_slice(self.copy_len);
        ggio::read_full(&mut self.r, dst)
    };
    // Account for whatever was read, even when the read ended with an error.
    self.roffset += got as u64;
    self.copy_len -= got;
    self.dict.write_mark(got);
    if failure.is_some() {
        self.err = failure;
        return;
    }

    let must_pause = self.dict.avail_write() == 0 || self.copy_len > 0;
    if !must_pause {
        self.finish_block();
        return;
    }
    self.dict.stash_flush();
    self.step = StepFunc::CopyData;
}
/// Ends the current block and schedules `next_block`.
///
/// After the final block, any readable output is stashed and `end_of_stream` is set.
fn finish_block(&mut self) {
    self.step = StepFunc::NextBlock;
    if !self.final_ {
        return;
    }
    if self.dict.avail_read() > 0 {
        self.dict.stash_flush();
    }
    self.end_of_stream = true;
}
/// Reads one byte from the input and places it above the `nb` bits already held
/// in the accumulator.
///
/// # Errors
///
/// Returns the error reported by the underlying byte read; the accumulator and
/// `roffset` are left untouched in that case.
fn more_bits(&mut self) -> std::io::Result<()> {
    match compat::readers::read_byte(&mut self.r) {
        Ok(byte) => {
            self.roffset += 1;
            self.b |= (byte as u32) << self.nb;
            self.nb += 8;
            Ok(())
        }
        Err(err) => Err(err),
    }
}
/// Decodes one symbol from the input with the table selected by `decoder`.
///
/// Input bytes are pulled only as needed: the routine starts with the table's
/// minimum code length and asks for more bits whenever the looked-up entry reports a
/// longer code than is currently available.
///
/// # Errors
///
/// Returns the input error if a byte cannot be read, or a corrupt-input error (also
/// stored in `self.err`) when the bits hit an unassigned table entry.
pub(super) fn huff_sym(&mut self, decoder: DecoderToUse) -> Result<usize, std::io::Error> {
    let table = match decoder {
        DecoderToUse::H1 => &self.h1,
        DecoderToUse::H2 => &self.h2,
        DecoderToUse::HL => match self.hl {
            HLDecoder::H1 => &self.h1,
            HLDecoder::Fixed => get_fixed_huffman_decoder(),
        },
    };

    // Work on local copies of the accumulator; they are written back on every exit.
    let mut acc = self.b;
    let mut have = self.nb;
    let mut want = table.min as usize;
    loop {
        while have < want {
            let byte = match compat::readers::read_byte(&mut self.r) {
                Ok(byte) => byte,
                Err(err) => {
                    self.b = acc;
                    self.nb = have;
                    return Err(err);
                }
            };
            self.roffset += 1;
            acc |= (byte as u32) << (have & 31);
            have = have.wrapping_add(8);
        }

        // Primary lookup, followed through a link table for long codes.
        let primary = table.chunks[(acc & (HUFFMAN_NUM_CHUNKS - 1)) as usize];
        let entry = if (primary & HUFFMAN_COUNT_MASK) as usize > HUFFMAN_CHUNK_BITS as usize {
            let link_table = &table.links[(primary >> HUFFMAN_VALUE_SHIFT) as usize];
            link_table[((acc >> HUFFMAN_CHUNK_BITS) & table.link_mask) as usize]
        } else {
            primary
        };
        want = (entry & HUFFMAN_COUNT_MASK) as usize;

        if want > have {
            // The code is longer than the bits on hand; fetch more and look again.
            continue;
        }
        if want == 0 {
            // A zero length marks an unassigned entry.
            self.b = acc;
            self.nb = have;
            self.err = Some(new_corrupted_input_error(self.roffset));
            return Err(errors::copy_stdio_error(self.err.as_ref().unwrap()));
        }
        self.b = acc >> (want & 31);
        self.nb = have - want;
        return Ok((entry >> HUFFMAN_VALUE_SHIFT) as usize);
    }
}
/// Replaces the input with `r` and then resets all decoding state by calling
/// `reset_state(dict)`.
pub fn reset(&mut self, r: Input, dict: &[u8]) {
self.r = r;
self.reset_state(dict);
}
/// Resets the decoding state while keeping the current input.
///
/// The history window is rebuilt from `dict`; the bit accumulator, Huffman tables,
/// state machine, flags, pending error and copy bookkeeping return to their initial
/// values. The `bits` and `codebits` scratch buffers are left as they are.
pub fn reset_state(&mut self, dict: &[u8]) {
    // Input accounting and bit accumulator.
    self.roffset = 0;
    self.b = 0;
    self.nb = 0;

    // Tables and output window.
    self.h1 = HuffmanDecoder::new();
    self.h2 = HuffmanDecoder::new();
    self.hl = HLDecoder::Fixed;
    self.hd = HDDecoder::None;
    self.dict = DictDecoder::new(MAX_MATCH_OFFSET, dict);
    self.buf = [0_u8; 4];

    // State machine and flags.
    self.step = StepFunc::NextBlock;
    self.step_state = StepState::StateInit;
    self.final_ = false;
    self.end_of_stream = false;
    self.err = None;

    // Copy bookkeeping.
    self.copy_len = 0;
    self.copy_dist = 0;
}
/// Returns a mutable reference to the underlying input.
pub fn input_reader(&mut self) -> &mut Input {
&mut self.r
}
}
impl<Input: std::io::BufRead> crate::io::Reader for Reader<Input> {
    /// Copies decoded bytes into `p`, running the decoder as needed to produce them.
    ///
    /// Returns `(n, None)` when `n` bytes were delivered, `(0, Some(err))` once a
    /// stored error is reached with no output left, and `ggio::EOF` after the final
    /// block has been fully delivered. Output decoded before an error is always
    /// delivered before the error itself.
    ///
    /// An empty `p` while output is pending yields `(0, None)` and leaves the decoder
    /// untouched.
    fn read(&mut self, p: &mut [u8]) -> ggio::IoRes {
        loop {
            // Pending output always goes out first.
            if self.dict.stash_len() > 0 {
                let n = self.dict.stash_read(p);
                // An empty destination cannot drain the stash. Falling through would
                // advance the decoder (or report EOF/an error) while output is still
                // unread, so report "nothing copied" instead.
                if n > 0 || p.is_empty() {
                    return (n, None);
                }
            }
            if self.err.is_some() {
                return (0, errors::copy_stdio_option_error(&self.err));
            }
            if self.end_of_stream {
                return ggio::EOF;
            }
            match self.step {
                StepFunc::NextBlock => self.next_block(),
                StepFunc::HuffmanBlock => self.huffman_block(),
                StepFunc::CopyData => self.copy_data(),
            }
            // On error, stash what was decoded so far so it is delivered before the
            // error is reported.
            if self.err.is_some() && self.dict.stash_len() == 0 {
                self.dict.stash_flush();
            }
        }
    }
}
/// Builds the decoder for fixed-Huffman blocks.
///
/// The 288 symbols get these code lengths: 8 bits for 0..144, 9 bits for 144..256,
/// 7 bits for 256..280 and 8 bits for 280..288. The result of `init` is not checked.
fn fixed_huffman_decoder_init() -> HuffmanDecoder {
    let mut lengths = [0_u32; 288];
    lengths[..144].fill(8);
    lengths[144..256].fill(9);
    lengths[256..280].fill(7);
    lengths[280..].fill(8);

    let mut decoder = HuffmanDecoder::new();
    decoder.init(&lengths);
    decoder
}
/// Creates the `InvalidData` error used to report corrupt input, naming the input
/// offset before which the corruption was detected.
fn new_corrupted_input_error(offset: u64) -> std::io::Error {
    std::io::Error::new(
        std::io::ErrorKind::InvalidData,
        format!("flate: corrupt input before offset {}", offset),
    )
}