use alloc::vec::Vec;
use crate::error::Error;
use crate::traits::{RawDecoder, RawProgress};
use super::MAX_DISTANCE;
/// Largest value accepted for a 32-bit extended match length (2^28).
///
/// Anything above this is rejected as corrupt before the match is scheduled.
const SANITY_MATCH_LEN: u32 = 0x1000_0000;
/// Progress of the decoder through the stream header and payload.
///
/// `Debug` is derived so the phase can appear in diagnostics, matching the
/// other private state types in this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum HeaderPhase {
    /// Still collecting the 8 header bytes; `idx` bytes have been stored so far.
    Reading { idx: u8 },
    /// Header parsed; decoding continues until `target` output bytes exist.
    Active { target: u64 },
    /// The target number of bytes has been produced (or the target was zero).
    Done,
}
/// Streaming decoder state.
///
/// The stream starts with an 8-byte little-endian output size, followed by
/// 32-bit flag words that select, bit by bit from the top, between literal
/// bytes and 16-bit match symbols.
pub struct Decoder {
// Raw header bytes collected so far; decoded as a little-endian `u64`.
header_buf: [u8; 8],
// Which part of the stream is currently being processed.
header: HeaderPhase,
// Buffered input; bytes before `pos` have already been consumed.
buf: Vec<u8>,
// Read cursor into `buf`.
pos: usize,
// Recently produced output, used as the source for match copies. Trimmed
// back to `MAX_DISTANCE` bytes once it exceeds twice that.
out_history: Vec<u8>,
// Total number of output bytes emitted since the last reset.
produced: u64,
// Current flag word; the next flag to consume is the top bit.
flags: u32,
// Number of unconsumed flags left in `flags`.
flag_bits_left: u8,
// Match whose copy was interrupted (e.g. the output slice filled up).
pending_match: Option<PendingMatch>,
// Literal already read from input but not yet written to the output.
pending_literal: Option<u8>,
// High nibble saved from an earlier length byte, to be used by the next
// extended match length instead of reading a fresh byte.
half_byte: Option<u8>,
// Set after any decode error; further calls fail until `raw_reset`.
poisoned: bool,
}
/// A back-reference copy that still has bytes left to emit.
#[derive(Debug, Clone, Copy)]
struct PendingMatch {
// How far back in the produced output the copy source lies (at least 1).
distance: u32,
// Number of bytes still to be copied.
remaining: u32,
}
impl Default for Decoder {
fn default() -> Self {
Self::new()
}
}
impl Decoder {
    /// Creates a decoder positioned at the very start of a stream, i.e.
    /// expecting the 8-byte header next.
    pub const fn new() -> Self {
        Self {
            header_buf: [0; 8],
            header: HeaderPhase::Reading { idx: 0 },
            buf: Vec::new(),
            pos: 0,
            out_history: Vec::new(),
            produced: 0,
            flags: 0,
            flag_bits_left: 0,
            pending_match: None,
            pending_literal: None,
            half_byte: None,
            poisoned: false,
        }
    }

    /// Reclaims the already-consumed prefix of the input buffer.
    ///
    /// Removing the prefix moves every unread byte, so it is only done when
    /// the consumed part is at least as large as the unread tail (and past a
    /// fixed threshold). That keeps the total copying linear in the input
    /// size even when a large input is buffered and drained through small
    /// output slices. A fully consumed buffer is reset without any copying.
    fn compact_buf(&mut self) {
        const THRESHOLD: usize = 64 * 1024;
        let unread = self.buf_avail();
        if unread == 0 {
            self.buf.clear();
            self.pos = 0;
        } else if self.pos >= THRESHOLD && self.pos >= unread {
            self.buf.drain(..self.pos);
            self.pos = 0;
        }
    }

    /// Bounds the match history: once it exceeds `2 * MAX_DISTANCE` bytes,
    /// only the most recent `MAX_DISTANCE` bytes are kept.
    fn trim_history(&mut self) {
        let len = self.out_history.len();
        if len > 2 * MAX_DISTANCE {
            self.out_history.drain(..len - MAX_DISTANCE);
        }
    }

    /// Writes one byte to `output[*written]`, records it in the history and
    /// advances both `written` and the produced-byte counter.
    ///
    /// The caller must ensure `*written < output.len()`.
    fn emit_byte(&mut self, byte: u8, output: &mut [u8], written: &mut usize) {
        self.out_history.push(byte);
        output[*written] = byte;
        *written += 1;
        self.produced += 1;
    }

    /// Number of buffered input bytes not yet consumed.
    #[inline]
    fn buf_avail(&self) -> usize {
        self.buf.len() - self.pos
    }

    /// Reads a little-endian `u16` located `offset` bytes past the read
    /// cursor without consuming it; `None` if the buffer is too short.
    #[inline]
    fn peek_u16(&self, offset: usize) -> Option<u16> {
        let start = self.pos + offset;
        let bytes = self.buf.get(start..start + 2)?;
        Some(u16::from_le_bytes([bytes[0], bytes[1]]))
    }

    /// Reads a little-endian `u32` located `offset` bytes past the read
    /// cursor without consuming it; `None` if the buffer is too short.
    #[inline]
    fn peek_u32(&self, offset: usize) -> Option<u32> {
        let start = self.pos + offset;
        let bytes = self.buf.get(start..start + 4)?;
        Some(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]))
    }
}
/// Deferred update to the decoder's saved length nibble.
///
/// The length reader only borrows the decoder, so it reports the change it
/// wants and the caller applies it once the whole symbol has been accepted.
enum HalfByteOp {
// The saved nibble was not touched.
Unchanged,
// The saved nibble was used and must be forgotten.
Clear,
// A fresh byte was read: its low nibble was used, save this high nibble.
Set(u8),
}
impl Decoder {
fn apply_half_byte_op(&mut self, op: HalfByteOp) {
match op {
HalfByteOp::Unchanged => {}
HalfByteOp::Clear => self.half_byte = None,
HalfByteOp::Set(high) => self.half_byte = Some(high),
}
}
fn drain(&mut self, output: &mut [u8], written: &mut usize, at_eof: bool) -> Result<(), Error> {
let target = match self.header {
HeaderPhase::Active { target } => target,
HeaderPhase::Done => return Ok(()),
HeaderPhase::Reading { .. } => return Ok(()),
};
loop {
if self.produced >= target {
self.header = HeaderPhase::Done;
return Ok(());
}
if let Some(b) = self.pending_literal.take() {
if *written == output.len() {
self.pending_literal = Some(b);
return Ok(());
}
self.emit_byte(b, output, written);
self.trim_history();
if self.produced >= target {
self.header = HeaderPhase::Done;
return Ok(());
}
continue;
}
if let Some(mut pm) = self.pending_match.take() {
while pm.remaining > 0 && *written < output.len() {
if (pm.distance as usize) > self.out_history.len() {
return Err(Error::InvalidDistance);
}
let src = self.out_history.len() - pm.distance as usize;
let b = self.out_history[src];
self.emit_byte(b, output, written);
pm.remaining -= 1;
if self.produced >= target {
self.header = HeaderPhase::Done;
return Ok(());
}
}
self.trim_history();
if pm.remaining > 0 {
self.pending_match = Some(pm);
return Ok(());
}
continue;
}
if self.flag_bits_left == 0 {
if self.buf_avail() < 4 {
if at_eof {
return Err(Error::UnexpectedEnd);
}
return Ok(());
}
let f = self.peek_u32(0).expect("peek bounded by buf_avail check");
self.flags = f;
self.flag_bits_left = 32;
self.pos += 4;
self.compact_buf();
}
let bit = self.flags & 0x8000_0000 != 0;
if !bit {
if self.buf_avail() < 1 {
if at_eof {
return Err(Error::UnexpectedEnd);
}
return Ok(());
}
let b = self.buf[self.pos];
self.pos += 1;
self.compact_buf();
self.flags <<= 1;
self.flag_bits_left -= 1;
if *written == output.len() {
self.pending_literal = Some(b);
return Ok(());
}
self.emit_byte(b, output, written);
self.trim_history();
if self.produced >= target {
self.header = HeaderPhase::Done;
return Ok(());
}
continue;
}
if self.buf_avail() < 2 {
if at_eof {
return Err(Error::UnexpectedEnd);
}
return Ok(());
}
let sym = self.peek_u16(0).expect("buf_avail >= 2 checked");
let distance = ((u32::from(sym) >> 3) + 1) as usize;
let lc = u32::from(sym & 0x7);
let sym_consumed = 2usize;
let (length, len_consumed, hb_op) =
match try_read_length_at(self, self.pos + sym_consumed, lc)? {
Some(v) => v,
None => {
if at_eof {
return Err(Error::UnexpectedEnd);
}
return Ok(());
}
};
if length < 3 {
return Err(Error::Corrupt);
}
if distance > MAX_DISTANCE {
return Err(Error::InvalidDistance);
}
if (distance as u64) > self.produced {
return Err(Error::InvalidDistance);
}
self.pos += sym_consumed + len_consumed;
self.compact_buf();
self.apply_half_byte_op(hb_op);
self.flags <<= 1;
self.flag_bits_left -= 1;
let mut pm = PendingMatch {
distance: distance as u32,
remaining: length,
};
while pm.remaining > 0 && *written < output.len() {
if (pm.distance as usize) > self.out_history.len() {
return Err(Error::InvalidDistance);
}
let src = self.out_history.len() - pm.distance as usize;
let b = self.out_history[src];
self.emit_byte(b, output, written);
pm.remaining -= 1;
if self.produced >= target {
self.header = HeaderPhase::Done;
return Ok(());
}
}
self.trim_history();
if pm.remaining > 0 {
self.pending_match = Some(pm);
return Ok(());
}
}
}
}
fn try_read_length_at(
dec: &Decoder,
start: usize,
base_lc: u32,
) -> Result<Option<(u32, usize, HalfByteOp)>, Error> {
if base_lc < 7 {
return Ok(Some((base_lc + 3, 0, HalfByteOp::Unchanged)));
}
let avail = dec.buf.len().saturating_sub(start);
let (hb, hb_consumed, hb_op) = match dec.half_byte {
Some(b) => (u32::from(b), 0, HalfByteOp::Clear),
None => {
if avail < 1 {
return Ok(None);
}
let b = dec.buf[start];
let low = u32::from(b & 0x0F);
let high = b >> 4;
(low, 1, HalfByteOp::Set(high))
}
};
if hb < 15 {
return Ok(Some((hb + 10, hb_consumed, hb_op)));
}
if avail < hb_consumed + 1 {
return Ok(None);
}
let b8 = dec.buf[start + hb_consumed];
if b8 < 255 {
return Ok(Some((u32::from(b8) + 25, hb_consumed + 1, hb_op)));
}
if avail < hb_consumed + 1 + 2 {
return Ok(None);
}
let w = u32::from(u16::from_le_bytes([
dec.buf[start + hb_consumed + 1],
dec.buf[start + hb_consumed + 2],
]));
if w != 0 {
if w < 22 {
return Err(Error::Corrupt);
}
return Ok(Some((w + 3, hb_consumed + 1 + 2, hb_op)));
}
if avail < hb_consumed + 1 + 2 + 4 {
return Ok(None);
}
let dw = u32::from_le_bytes([
dec.buf[start + hb_consumed + 1 + 2],
dec.buf[start + hb_consumed + 1 + 2 + 1],
dec.buf[start + hb_consumed + 1 + 2 + 2],
dec.buf[start + hb_consumed + 1 + 2 + 3],
]);
if dw < 22 {
return Err(Error::Corrupt);
}
if dw > SANITY_MATCH_LEN {
return Err(Error::Corrupt);
}
Ok(Some((dw + 3, hb_consumed + 1 + 2 + 4, hb_op)))
}
impl RawDecoder for Decoder {
    /// Feeds `input` to the decoder and writes as much output as fits.
    ///
    /// All of `input` is always reported as consumed: header bytes are
    /// stored, payload bytes are buffered internally until they can be
    /// decoded. Input supplied after the stream has completed is discarded
    /// rather than buffered, so trailing data cannot grow the buffer.
    ///
    /// # Errors
    /// Returns the decode error and poisons the decoder; every later call
    /// fails with `Error::Corrupt` until `raw_reset`.
    fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.poisoned {
            return Err(Error::Corrupt);
        }
        let mut written = 0usize;
        let mut consumed = 0usize;

        if let HeaderPhase::Reading { idx } = self.header {
            let have = usize::from(idx);
            let take = 8usize.saturating_sub(have).min(input.len());
            self.header_buf[have..have + take].copy_from_slice(&input[..take]);
            consumed = take;
            let filled = have + take;
            if filled < 8 {
                // `filled` is below 8 here, so the narrowing is lossless.
                self.header = HeaderPhase::Reading { idx: filled as u8 };
                return Ok(RawProgress {
                    consumed,
                    written,
                    done: false,
                });
            }
            let target = u64::from_le_bytes(self.header_buf);
            self.header = if target == 0 {
                HeaderPhase::Done
            } else {
                HeaderPhase::Active { target }
            };
        }

        // Only an active stream can make use of further input.
        if matches!(self.header, HeaderPhase::Active { .. }) {
            self.buf.extend_from_slice(&input[consumed..]);
            if let Err(e) = self.drain(output, &mut written, false) {
                self.poisoned = true;
                return Err(e);
            }
        }

        Ok(RawProgress {
            consumed: input.len(),
            written,
            done: matches!(self.header, HeaderPhase::Done),
        })
    }

    /// Signals end of input and flushes remaining output into `output`.
    ///
    /// May need to be called repeatedly while `done` is false and output is
    /// still pending.
    ///
    /// # Errors
    /// `Error::UnexpectedEnd` if the stream is truncated (including a
    /// partially read header), or any error from decoding the buffered data.
    /// The decoder is poisoned on error.
    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.poisoned {
            return Err(Error::Corrupt);
        }
        let mut written = 0usize;

        if matches!(self.header, HeaderPhase::Active { .. }) {
            if let Err(e) = self.drain(output, &mut written, true) {
                self.poisoned = true;
                return Err(e);
            }
        }

        let done = matches!(self.header, HeaderPhase::Done);
        let nothing_pending = self.pending_match.is_none() && self.pending_literal.is_none();
        // Not finished, nothing queued, and no progress despite output room.
        let stalled = !done && nothing_pending && written == 0 && !output.is_empty();
        let partial_header = matches!(self.header, HeaderPhase::Reading { idx } if idx > 0);
        if stalled || partial_header {
            self.poisoned = true;
            return Err(Error::UnexpectedEnd);
        }

        Ok(RawProgress {
            consumed: 0,
            written,
            done,
        })
    }

    /// Returns the decoder to its initial state so a new stream can be
    /// decoded. Buffers are cleared in place, keeping their allocations.
    fn raw_reset(&mut self) {
        self.header_buf = [0; 8];
        self.header = HeaderPhase::Reading { idx: 0 };
        self.buf.clear();
        self.pos = 0;
        self.out_history.clear();
        self.produced = 0;
        self.flags = 0;
        self.flag_bits_left = 0;
        self.pending_match = None;
        self.pending_literal = None;
        self.half_byte = None;
        self.poisoned = false;
    }
}