use alloc::vec;
use alloc::vec::Vec;
use crate::error::Error;
use crate::quantum::bits::BitReader;
use crate::quantum::model::{ArithDecoder, Model};
use crate::quantum::tables::{EXTRA_BITS, LENGTH_BASE, LENGTH_EXTRA, POSITION_BASE};
use crate::traits::{RawDecoder, RawProgress};
/// Number of decoded bytes in one Quantum frame. When this many bytes have
/// been produced, the decoder byte-aligns the input and looks for the frame
/// trailer before re-initialising the arithmetic decoder.
const FRAME_SIZE: u32 = 1 << 15;
/// Window-size exponent used by `Decoder::new` (a 2^15 = 32 KiB window).
pub(crate) const DEFAULT_WINDOW_BITS: u32 = 15;
/// A run of decoded bytes in the history window that has not yet been copied
/// to the caller's output buffer. The run starts at window index `start` and
/// may wrap past the end of the window back to index 0.
#[derive(Clone, Copy, Debug)]
struct PendingOutput {
    start: usize,
    remaining: usize,
}

/// Whether the decoder is between frames and scanning for the `0xFF` trailer
/// byte that separates them.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum TrailerState {
    /// Inside a frame (or no trailer scan in progress).
    None,
    /// Frame finished; skipping zero padding until `0xFF` is seen.
    SeekingFF,
}

/// A decoded LZ match whose bytes have not yet been copied into the window.
#[derive(Clone, Copy, Debug)]
struct PendingMatch {
    match_offset: usize,
    remaining: usize,
}
/// Streaming Quantum decompressor.
///
/// Compressed input is buffered in `input_buf` and decoded into a ring-buffer
/// history window; decoded bytes are handed to the caller through the
/// `RawDecoder` trait. All decode steps are restartable, so input may arrive in
/// arbitrarily small pieces.
pub struct Decoder {
// log2 of the history window size (validated to 10..=21 at construction).
window_bits: u32,
// 1 << window_bits.
window_size: usize,
// window_size - 1; used to map window_posn onto a window index.
window_mask: usize,
// Buffered compressed input; consumed bytes are dropped by compact_input().
input_buf: Vec<u8>,
// History ring buffer; allocated lazily (empty until ensure_window()).
window: Vec<u8>,
// Count of bytes written to the window; the write index is window_posn & window_mask.
window_posn: usize,
// True once `window` has been allocated.
initialised: bool,
bit_reader: BitReader,
arith: ArithDecoder,
// Whether the arithmetic decoder has been initialised for the current frame.
header_read: bool,
// Bytes still to decode in the current frame; at 0 the input is byte-aligned
// and the frame trailer is processed.
frame_todo: u32,
trailer_state: TrailerState,
// A decoded match not yet copied into the window.
pending_match: Option<PendingMatch>,
// A window range decoded but not yet copied to the caller's output.
pending_output: Option<PendingOutput>,
// Literal models for selectors 0..=3 (constructed as Model::new(0|64|128|192, 64);
// presumably (first symbol, alphabet size) — TODO confirm against Model::new).
model0: Model,
model1: Model,
model2: Model,
model3: Model,
// Position-slot model for selector 4 (3-byte matches).
model4: Model,
// Position-slot model for selector 5 (4-byte matches).
model5: Model,
// Position-slot model for selector 6 (variable-length matches).
model6: Model,
// Length-slot model for selector 6.
model6len: Model,
// Selector model: chooses literal model 0..=3 or match kind 4..=6.
model7: Model,
}
impl Decoder {
pub fn new() -> Self {
Self::with_window_bits(DEFAULT_WINDOW_BITS).expect("default window bits are valid")
}
pub fn with_window_bits(window_bits: u32) -> Result<Self, Error> {
if !(10..=21).contains(&window_bits) {
return Err(Error::BadHeader);
}
let window_size = 1usize << window_bits;
let i = (window_bits * 2) as usize;
let model4_size = i.min(24);
let model5_size = i.min(36);
let model6_size = i; debug_assert!(model6_size <= 42);
Ok(Self {
window_bits,
window_size,
window_mask: window_size - 1,
input_buf: Vec::new(),
window: Vec::new(),
window_posn: 0,
initialised: false,
bit_reader: BitReader::new(),
arith: ArithDecoder::new(),
header_read: false,
frame_todo: FRAME_SIZE,
trailer_state: TrailerState::None,
pending_match: None,
pending_output: None,
model0: Model::new(0, 64),
model1: Model::new(64, 64),
model2: Model::new(128, 64),
model3: Model::new(192, 64),
model4: Model::new(0, model4_size),
model5: Model::new(0, model5_size),
model6: Model::new(0, model6_size),
model6len: Model::new(0, 27),
model7: Model::new(0, 7),
})
}
fn ensure_window(&mut self) {
if !self.initialised {
self.window = vec![0u8; self.window_size];
self.initialised = true;
}
}
fn snapshot(&self) -> Snapshot {
Snapshot {
bit_reader: self.bit_reader,
arith: self.arith,
header_read: self.header_read,
model0: self.model0.clone(),
model1: self.model1.clone(),
model2: self.model2.clone(),
model3: self.model3.clone(),
model4: self.model4.clone(),
model5: self.model5.clone(),
model6: self.model6.clone(),
model6len: self.model6len.clone(),
model7: self.model7.clone(),
trailer_state: self.trailer_state,
}
}
fn restore(&mut self, s: Snapshot) {
self.bit_reader = s.bit_reader;
self.arith = s.arith;
self.header_read = s.header_read;
self.model0 = s.model0;
self.model1 = s.model1;
self.model2 = s.model2;
self.model3 = s.model3;
self.model4 = s.model4;
self.model5 = s.model5;
self.model6 = s.model6;
self.model6len = s.model6len;
self.model7 = s.model7;
self.trailer_state = s.trailer_state;
}
fn drain_pending_output(&mut self, output: &mut [u8], written: &mut usize) -> bool {
if let Some(po) = self.pending_output {
if po.remaining == 0 {
self.pending_output = None;
return false;
}
let n = po.remaining.min(output.len() - *written);
let first = (self.window_size - po.start).min(n);
output[*written..*written + first]
.copy_from_slice(&self.window[po.start..po.start + first]);
if first < n {
let rem = n - first;
output[*written + first..*written + n].copy_from_slice(&self.window[..rem]);
}
*written += n;
let new_rem = po.remaining - n;
if new_rem == 0 {
self.pending_output = None;
} else {
self.pending_output = Some(PendingOutput {
start: (po.start + n) & self.window_mask,
remaining: new_rem,
});
}
return *written == output.len();
}
false
}
fn enqueue_output(&mut self, start: usize, n: usize) {
if n == 0 {
return;
}
match self.pending_output {
Some(po) => {
let expected_end = (po.start + po.remaining) & self.window_mask;
if expected_end == start {
self.pending_output = Some(PendingOutput {
start: po.start,
remaining: po.remaining + n,
});
} else {
self.pending_output = Some(PendingOutput {
start,
remaining: n,
});
}
}
None => {
self.pending_output = Some(PendingOutput {
start,
remaining: n,
});
}
}
}
fn process_trailer(&mut self) -> Result<bool, Error> {
if self.trailer_state == TrailerState::None {
return Ok(true);
}
loop {
let b = self.bit_reader.read_bits(8, &self.input_buf)?;
if b == 0xFF {
self.trailer_state = TrailerState::None;
self.header_read = false;
self.frame_todo = FRAME_SIZE;
return Ok(true);
}
if b != 0 {
return Err(Error::Corrupt);
}
}
}
fn decode_one_packet(&mut self) -> Result<(), Error> {
let selector =
self.arith
.get_symbol(&mut self.model7, &mut self.bit_reader, &self.input_buf)?;
if selector >= 7 {
return Err(Error::Corrupt);
}
if selector < 4 {
let sym = match selector {
0 => self.arith.get_symbol(
&mut self.model0,
&mut self.bit_reader,
&self.input_buf,
)?,
1 => self.arith.get_symbol(
&mut self.model1,
&mut self.bit_reader,
&self.input_buf,
)?,
2 => self.arith.get_symbol(
&mut self.model2,
&mut self.bit_reader,
&self.input_buf,
)?,
_ => self.arith.get_symbol(
&mut self.model3,
&mut self.bit_reader,
&self.input_buf,
)?,
};
let byte = (sym & 0xFF) as u8;
self.window[self.window_posn & self.window_mask] = byte;
let start = self.window_posn & self.window_mask;
self.window_posn += 1;
self.frame_todo = self.frame_todo.wrapping_sub(1);
self.enqueue_output(start, 1);
return Ok(());
}
let (match_length, match_offset) = match selector {
4 => {
let sym = self.arith.get_symbol(
&mut self.model4,
&mut self.bit_reader,
&self.input_buf,
)? as usize;
if sym >= EXTRA_BITS.len() {
return Err(Error::Corrupt);
}
let eb = EXTRA_BITS[sym] as u32;
let extra = if eb == 0 {
0
} else {
self.bit_reader.read_many_bits(eb, &self.input_buf)?
} as usize;
let off = POSITION_BASE[sym] as usize + extra + 1;
(3usize, off)
}
5 => {
let sym = self.arith.get_symbol(
&mut self.model5,
&mut self.bit_reader,
&self.input_buf,
)? as usize;
if sym >= EXTRA_BITS.len() {
return Err(Error::Corrupt);
}
let eb = EXTRA_BITS[sym] as u32;
let extra = if eb == 0 {
0
} else {
self.bit_reader.read_many_bits(eb, &self.input_buf)?
} as usize;
let off = POSITION_BASE[sym] as usize + extra + 1;
(4usize, off)
}
6 => {
let sym_len = self.arith.get_symbol(
&mut self.model6len,
&mut self.bit_reader,
&self.input_buf,
)? as usize;
if sym_len >= LENGTH_EXTRA.len() {
return Err(Error::Corrupt);
}
let le = LENGTH_EXTRA[sym_len] as u32;
let extra_len = if le == 0 {
0
} else {
self.bit_reader.read_many_bits(le, &self.input_buf)?
} as usize;
let match_length = LENGTH_BASE[sym_len] as usize + extra_len + 5;
let sym_off = self.arith.get_symbol(
&mut self.model6,
&mut self.bit_reader,
&self.input_buf,
)? as usize;
if sym_off >= EXTRA_BITS.len() {
return Err(Error::Corrupt);
}
let eb = EXTRA_BITS[sym_off] as u32;
let extra_off = if eb == 0 {
0
} else {
self.bit_reader.read_many_bits(eb, &self.input_buf)?
} as usize;
let match_offset = POSITION_BASE[sym_off] as usize + extra_off + 1;
(match_length, match_offset)
}
_ => return Err(Error::Corrupt),
};
if match_offset == 0 || match_offset > self.window_size {
return Err(Error::InvalidDistance);
}
self.pending_match = Some(PendingMatch {
match_offset,
remaining: match_length,
});
self.frame_todo = self.frame_todo.wrapping_sub(match_length as u32);
Ok(())
}
fn copy_pending_match(&mut self) {
let Some(pm) = self.pending_match.take() else {
return;
};
let match_offset = pm.match_offset;
let mut remaining = pm.remaining;
let mut emit_start = self.window_posn & self.window_mask;
let mut emit_count = 0usize;
while remaining > 0 {
let dest = self.window_posn & self.window_mask;
let src = (self.window_posn.wrapping_sub(match_offset)) & self.window_mask;
let dst_room = self.window_size - dest;
let src_room = self.window_size - src;
let max_chunk = remaining.min(dst_room).min(src_room);
if max_chunk > 1 && match_offset >= max_chunk {
self.window.copy_within(src..src + max_chunk, dest);
self.window_posn += max_chunk;
remaining -= max_chunk;
emit_count += max_chunk;
let next = self.window_posn & self.window_mask;
if next == 0 && remaining > 0 {
self.enqueue_output(emit_start, emit_count);
emit_start = 0;
emit_count = 0;
}
continue;
}
self.window[dest] = self.window[src];
self.window_posn += 1;
remaining -= 1;
emit_count += 1;
let next = self.window_posn & self.window_mask;
if next == 0 && remaining > 0 {
self.enqueue_output(emit_start, emit_count);
emit_start = 0;
emit_count = 0;
}
}
if emit_count > 0 {
self.enqueue_output(emit_start, emit_count);
}
}
fn drain(&mut self, output: &mut [u8], written: &mut usize) -> Result<bool, Error> {
if self.drain_pending_output(output, written) {
return Ok(false);
}
if self.pending_match.is_some() {
self.copy_pending_match();
if self.drain_pending_output(output, written) {
return Ok(false);
}
}
loop {
if *written == output.len() {
return Ok(false);
}
if self.trailer_state != TrailerState::None {
let snap = self.snapshot();
match self.process_trailer() {
Ok(true) => {}
Ok(false) => unreachable!(),
Err(Error::UnexpectedEnd) => {
self.restore(snap);
return Ok(false);
}
Err(e) => return Err(e),
}
}
if !self.header_read {
let snap = self.snapshot();
match self.arith.init_frame(&mut self.bit_reader, &self.input_buf) {
Ok(()) => {
self.header_read = true;
}
Err(Error::UnexpectedEnd) => {
self.restore(snap);
return Ok(false);
}
Err(e) => return Err(e),
}
}
if self.frame_todo == 0 {
let leftover = self.bit_reader.bits_left() & 7;
self.bit_reader.remove_bits(leftover);
self.trailer_state = TrailerState::SeekingFF;
continue;
}
let snap = self.snapshot();
let saved_window_posn = self.window_posn;
let saved_frame_todo = self.frame_todo;
let saved_pending_match = self.pending_match;
match self.decode_one_packet() {
Ok(()) => {
if self.pending_match.is_some() {
self.copy_pending_match();
}
if self.drain_pending_output(output, written) {
return Ok(false);
}
}
Err(Error::UnexpectedEnd) => {
self.restore(snap);
self.window_posn = saved_window_posn;
self.frame_todo = saved_frame_todo;
self.pending_match = saved_pending_match;
return Ok(false);
}
Err(e) => return Err(e),
}
}
}
fn compact_input(&mut self) {
let bp = self.bit_reader.byte_pos();
if bp == 0 {
return;
}
self.input_buf.drain(0..bp);
self.bit_reader.rebase(bp);
}
}
/// Rollback point for one restartable decode step (see `Decoder::snapshot` /
/// `Decoder::restore`). Window contents, `window_posn`, `frame_todo` and
/// `pending_match` are deliberately not part of it.
#[derive(Clone)]
struct Snapshot {
// Input position / bit buffer state.
bit_reader: BitReader,
// Arithmetic decoder registers.
arith: ArithDecoder,
header_read: bool,
// Adaptive models; cloned because decoding a symbol updates them.
model0: Model,
model1: Model,
model2: Model,
model3: Model,
model4: Model,
model5: Model,
model6: Model,
model6len: Model,
model7: Model,
trailer_state: TrailerState,
}
impl Default for Decoder {
fn default() -> Self {
Self::new()
}
}
impl RawDecoder for Decoder {
    /// Buffers `input` and decodes as much as fits into `output`.
    ///
    /// The whole of `input` is always reported as consumed. Bytes that cannot
    /// be decoded yet stay in the internal buffer until more input arrives.
    /// This never reports `done`.
    fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
        self.ensure_window();
        self.input_buf.extend_from_slice(input);
        let mut produced = 0usize;
        self.drain(output, &mut produced)?;
        self.compact_input();
        Ok(RawProgress {
            consumed: input.len(),
            written: produced,
            done: false,
        })
    }

    /// Flags end-of-input to the bit reader and flushes remaining output.
    ///
    /// Reports `done` once nothing is queued and the call left spare room in
    /// `output`. A call that exactly fills `output` is not `done`; the caller
    /// must call again.
    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
        self.ensure_window();
        self.bit_reader.set_eof(true);
        let mut produced = 0usize;
        self.drain(output, &mut produced)?;
        self.compact_input();
        let nothing_queued = self.pending_output.is_none() && self.pending_match.is_none();
        let room_left = produced < output.len();
        Ok(RawProgress {
            consumed: 0,
            written: produced,
            done: nothing_queued && room_left,
        })
    }

    /// Returns the decoder to its freshly-constructed state.
    ///
    /// The same window size is kept, and an already-allocated window is zeroed
    /// in place rather than freed.
    fn raw_reset(&mut self) {
        self.input_buf.clear();
        if self.initialised {
            self.window.fill(0);
        }

        // Stream position and framing state.
        self.window_posn = 0;
        self.bit_reader = BitReader::new();
        self.arith = ArithDecoder::new();
        self.header_read = false;
        self.frame_todo = FRAME_SIZE;
        self.trailer_state = TrailerState::None;
        self.pending_match = None;
        self.pending_output = None;

        // Adaptive models, sized exactly as in `Decoder::with_window_bits`.
        let slots = (self.window_bits * 2) as usize;
        self.model0 = Model::new(0, 64);
        self.model1 = Model::new(64, 64);
        self.model2 = Model::new(128, 64);
        self.model3 = Model::new(192, 64);
        self.model4 = Model::new(0, slots.min(24));
        self.model5 = Model::new(0, slots.min(36));
        self.model6 = Model::new(0, slots);
        self.model6len = Model::new(0, 27);
        self.model7 = Model::new(0, 7);
    }
}