extern crate alloc;
use alloc::vec::Vec;
use crate::error::Error;
use crate::traits::{RawDecoder, RawProgress};
use super::model::Model;
use super::range_dec::{ByteSource, RangeDec};
/// Size of the stream header: three one-byte parameters (order, memory size
/// in MiB, restoration) followed by the 8-byte little-endian output length.
const HEADER_LEN: usize = 3 + 8;
/// Number of bytes that must be buffered before range-decoder initialisation
/// is attempted.
const RANGE_INIT_LEN: usize = 5;
/// Header length value meaning "no declared length": output is not capped.
const UNKNOWN_LEN: u64 = !0;
/// Stage of the decoder's state machine; stages are entered in declaration order.
#[derive(Copy, Clone, Eq, PartialEq)]
enum Phase {
    /// Waiting for the fixed-size stream header.
    Header,
    /// Header parsed; waiting for the bytes that prime the range decoder.
    RangeInit,
    /// Decoding symbols into the output queue.
    Body,
    /// The declared output length has been produced.
    Done,
}
/// Streaming decoder: parses an 11-byte header, primes a range decoder, then
/// decodes output one symbol at a time through a context `Model`.
///
/// Compressed input is buffered internally and decoded bytes are queued until
/// the caller supplies room for them (see the `RawDecoder` impl).
pub struct Decoder {
    /// Buffered compressed input; bytes before `in_committed` are already used.
    in_buf: Vec<u8>,
    /// Offset in `in_buf` of the first byte not consumed by a completed step.
    in_committed: usize,
    /// Decoded bytes waiting to be copied out to the caller.
    decoded: Vec<u8>,
    /// Index in `decoded` of the next byte to hand out.
    decoded_idx: usize,
    /// Current stage of the state machine.
    phase: Phase,
    /// Set after any error; later calls fail with `Error::Corrupt` until reset.
    poisoned: bool,
    /// Model order from the header (validated to 2..=16).
    order: u32,
    /// Model memory budget from the header, in MiB (validated to 1..=255).
    mem_mb: u32,
    /// `restoration` byte from the header (validated to 0..=2).
    restoration: u8,
    /// Declared uncompressed length, or `UNKNOWN_LEN` when not bounded.
    expected_len: u64,
    /// Number of bytes decoded so far.
    produced_len: u64,
    /// Context model; `None` until the header has been parsed.
    model: Option<Model>,
    /// Range decoder; its `pos` is an offset into `in_buf`.
    range_dec: RangeDec,
}
impl Decoder {
pub fn new() -> Self {
Self {
in_buf: Vec::new(),
in_committed: 0,
decoded: Vec::new(),
decoded_idx: 0,
phase: Phase::Header,
poisoned: false,
order: 0,
mem_mb: 0,
restoration: 0,
expected_len: 0,
produced_len: 0,
model: None,
range_dec: RangeDec::new(),
}
}
fn poison(&mut self, e: Error) -> Error {
self.poisoned = true;
e
}
fn step(&mut self) -> Result<bool, Error> {
match self.phase {
Phase::Header => self.try_header(),
Phase::RangeInit => self.try_range_init(),
Phase::Body => self.try_body(),
Phase::Done => Ok(false),
}
}
fn try_header(&mut self) -> Result<bool, Error> {
if self.in_buf.len() < self.in_committed + HEADER_LEN {
return Ok(false);
}
let off = self.in_committed;
let h = &self.in_buf[off..off + HEADER_LEN];
let order = h[0] as u32;
let mem_mb = h[1] as u32;
let restoration = h[2];
if !(2..=16).contains(&order) {
return Err(self.poison(Error::BadHeader));
}
if !(1..=255).contains(&mem_mb) {
return Err(self.poison(Error::BadHeader));
}
if restoration > 2 {
return Err(self.poison(Error::BadHeader));
}
let len = u64::from_le_bytes(h[3..11].try_into().unwrap());
self.order = order;
self.mem_mb = mem_mb;
self.restoration = restoration;
self.expected_len = len;
self.in_committed += HEADER_LEN;
let mem_bytes = (mem_mb as usize).saturating_mul(1024 * 1024);
self.model = Some(Model::new(order, mem_bytes).map_err(|e| self.poison(e))?);
self.phase = Phase::RangeInit;
Ok(true)
}
fn try_range_init(&mut self) -> Result<bool, Error> {
if self.in_buf.len() < self.in_committed + RANGE_INIT_LEN {
return Ok(false);
}
self.range_dec.pos = self.in_committed;
match self.range_dec.init(&self.in_buf) {
Ok(true) => {
self.in_committed = self.range_dec.pos;
self.phase = Phase::Body;
Ok(true)
}
Ok(false) => Ok(false), Err(e) => Err(self.poison(e)),
}
}
fn try_body(&mut self) -> Result<bool, Error> {
let model = match self.model.as_mut() {
Some(m) => m,
None => return Err(self.poison(Error::Corrupt)),
};
if self.expected_len != UNKNOWN_LEN && self.produced_len >= self.expected_len {
if self.range_dec.is_finished_ok() {
self.phase = Phase::Done;
return Ok(true);
}
self.phase = Phase::Done;
return Ok(true);
}
let mut src = ByteSource::new(&self.in_buf, self.range_dec.pos);
let mut progressed = false;
loop {
if self.decoded.len() - self.decoded_idx > 4096 {
break;
}
let rd_pre = self.range_dec.clone();
let pos_pre = src.pos;
match model.decode_symbol(&mut self.range_dec, &mut src) {
Ok(sym) => {
self.decoded.push(sym);
self.produced_len += 1;
progressed = true;
if self.expected_len != UNKNOWN_LEN && self.produced_len >= self.expected_len {
break;
}
}
Err(Error::UnexpectedEnd) => {
self.range_dec = rd_pre;
src.pos = pos_pre;
break;
}
Err(e) => return Err(self.poison(e)),
}
}
self.range_dec.pos = src.pos;
self.in_committed = self.range_dec.pos;
Ok(progressed)
}
fn drain(&mut self, output: &mut [u8], written: &mut usize) {
let avail = self.decoded.len() - self.decoded_idx;
let space = output.len() - *written;
let n = avail.min(space);
if n > 0 {
output[*written..*written + n]
.copy_from_slice(&self.decoded[self.decoded_idx..self.decoded_idx + n]);
self.decoded_idx += n;
*written += n;
}
if self.decoded_idx == self.decoded.len() {
self.decoded.clear();
self.decoded_idx = 0;
}
}
}
impl Default for Decoder {
fn default() -> Self {
Self::new()
}
}
impl RawDecoder for Decoder {
    /// Buffers `input`, decodes as much as possible, and copies decoded bytes
    /// into `output`.
    ///
    /// All of `input` is appended to the internal buffer in one go, so
    /// `consumed` is either `0` (the call returned before looking at input,
    /// e.g. output was already full or the stream had already finished) or
    /// `input.len()`. `done` is `true` once the stream is finished and every
    /// decoded byte has been written out.
    ///
    /// Any decoding error poisons the decoder; subsequent calls fail with
    /// `Error::Corrupt` until `raw_reset` is called.
    fn raw_decode(&mut self, input: &[u8], output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.poisoned {
            return Err(Error::Corrupt);
        }
        let mut consumed = 0usize;
        let mut written = 0usize;
        // Flush output queued by an earlier call before decoding anything new.
        self.drain(output, &mut written);
        if written == output.len() && self.decoded_idx < self.decoded.len() {
            return Ok(RawProgress {
                consumed,
                written,
                done: false,
            });
        }
        loop {
            if matches!(self.phase, Phase::Done) && self.decoded_idx == self.decoded.len() {
                return Ok(RawProgress {
                    consumed,
                    written,
                    done: true,
                });
            }
            if consumed < input.len() {
                self.in_buf.extend_from_slice(&input[consumed..]);
                consumed = input.len();
            }
            let progressed = self.step()?;
            self.drain(output, &mut written);
            // Discard committed input once more than 1 MiB of it has built up;
            // decoding never rewinds before `in_committed`.
            if self.in_committed > 1 << 20 {
                let off = self.in_committed;
                self.in_buf.drain(..off);
                self.in_committed = 0;
                self.range_dec.pos = self.range_dec.pos.saturating_sub(off);
            }
            // Finished and fully flushed: report completion from the loop head.
            if matches!(self.phase, Phase::Done) && self.decoded_idx == self.decoded.len() {
                continue;
            }
            // Stop when output is full (queued bytes go out on the next call) or
            // when no further progress is possible without more input. Every
            // input byte appended above stays reported as consumed: it already
            // lives in `in_buf`, and handing one back would make the caller feed
            // it a second time and corrupt the compressed stream.
            if written == output.len() || !progressed {
                return Ok(RawProgress {
                    consumed,
                    written,
                    done: false,
                });
            }
        }
    }

    /// Flushes the stream without supplying more input.
    ///
    /// Returns `done: true` once the stream is finished and fully written out.
    /// If `output` runs out of space first, returns `done: false` so the
    /// caller can call again with more room. If decoding stalls for lack of
    /// input before the stream is finished, the decoder is poisoned and
    /// `Error::UnexpectedEnd` is returned.
    fn raw_finish(&mut self, output: &mut [u8]) -> Result<RawProgress, Error> {
        if self.poisoned {
            return Err(Error::Corrupt);
        }
        let p = self.raw_decode(&[], output)?;
        if matches!(self.phase, Phase::Done) && self.decoded_idx == self.decoded.len() {
            return Ok(RawProgress {
                consumed: 0,
                written: p.written,
                done: true,
            });
        }
        // A full output buffer is not truncation: decoded bytes may still be
        // queued, or decoding may be able to continue on the next call.
        let backlog = self.decoded_idx < self.decoded.len();
        let filled = !output.is_empty() && p.written == output.len();
        if backlog || filled {
            return Ok(RawProgress {
                consumed: 0,
                written: p.written,
                done: false,
            });
        }
        Err(self.poison(Error::UnexpectedEnd))
    }

    /// Returns the decoder to its initial state so a new stream can be decoded.
    ///
    /// Input and output buffers are cleared rather than replaced, so their
    /// allocations are reused.
    fn raw_reset(&mut self) {
        self.in_buf.clear();
        self.in_committed = 0;
        self.decoded.clear();
        self.decoded_idx = 0;
        self.phase = Phase::Header;
        self.poisoned = false;
        self.order = 0;
        self.mem_mb = 0;
        self.restoration = 0;
        self.expected_len = 0;
        self.produced_len = 0;
        self.model = None;
        self.range_dec = RangeDec::new();
    }
}