extern crate alloc;
use crate::error::Error;
use super::arena::{
Arena, STATE_SIZE, UNIT_SIZE, ctx_num_stats, ctx_set_num_stats, ctx_set_stats, ctx_set_suffix,
ctx_set_summ_freq, ctx_summ_freq, state_freq, state_set_freq, state_store, state_symbol,
swap_states,
};
use super::range_dec::{ByteSource, RangeDec};
/// Largest frequency a single state may hold after an increment; a bump that
/// would exceed it makes `Model::bump_freq` rescale the context instead.
const MAX_FREQ: u8 = 124;
/// Adaptive symbol-frequency model whose context and state records live in an
/// [`Arena`] and are addressed by byte offsets.
pub(super) struct Model {
/// Backing storage for the context record and the state records.
arena: Arena,
/// Arena offset of the context record (set to 0 by `Model::new`).
ctx_off: u32,
/// Arena offset of the first state record (set to `UNIT_SIZE` by
/// `Model::new`); further states follow at `STATE_SIZE` strides.
stats_off: u32,
}
impl Model {
const MIN_ARENA: usize = 2048;
pub fn new(order: u32, mem_size_bytes: usize) -> Result<Self, Error> {
if !(2..=16).contains(&order) {
return Err(Error::BadHeader);
}
let _ = mem_size_bytes;
let size = Self::MIN_ARENA;
let _ = order; let mut m = Self {
arena: Arena::new(size),
ctx_off: 0,
stats_off: UNIT_SIZE as u32,
};
m.restart()?;
Ok(m)
}
pub fn restart(&mut self) -> Result<(), Error> {
self.arena.clear();
ctx_set_num_stats(&mut self.arena, self.ctx_off, 256).ok_or(Error::Corrupt)?;
ctx_set_summ_freq(&mut self.arena, self.ctx_off, 256 + 1).ok_or(Error::Corrupt)?;
ctx_set_stats(&mut self.arena, self.ctx_off, self.stats_off).ok_or(Error::Corrupt)?;
ctx_set_suffix(&mut self.arena, self.ctx_off, 0).ok_or(Error::Corrupt)?;
for i in 0..256u32 {
let st = self.stats_off + i * STATE_SIZE as u32;
state_store(&mut self.arena, st, i as u8, 1, 0).ok_or(Error::Corrupt)?;
}
Ok(())
}
pub fn decode_symbol(
&mut self,
rd: &mut RangeDec,
src: &mut ByteSource<'_>,
) -> Result<u8, Error> {
let rd_snap = rd.clone();
let pos_snap = src.pos;
match self.decode_inner(rd, src) {
Ok(sym) => Ok(sym),
Err(Error::UnexpectedEnd) => {
*rd = rd_snap;
src.pos = pos_snap;
Err(Error::UnexpectedEnd)
}
Err(e) => Err(e),
}
}
fn decode_inner(&mut self, rd: &mut RangeDec, src: &mut ByteSource<'_>) -> Result<u8, Error> {
let nstats = ctx_num_stats(&self.arena, self.ctx_off).ok_or(Error::Corrupt)?;
if nstats == 0 {
return Err(Error::Corrupt);
}
let summ_freq = ctx_summ_freq(&self.arena, self.ctx_off).ok_or(Error::Corrupt)?;
let total = summ_freq as u32;
if total == 0 {
return Err(Error::Corrupt);
}
let hi_count = rd.get_threshold(total);
let mut acc: u32 = 0;
for i in 0..nstats as u32 {
let st = self.stats_off + i * STATE_SIZE as u32;
let f = state_freq(&self.arena, st).ok_or(Error::Corrupt)? as u32;
if acc + f > hi_count {
rd.decode(src, acc, f)?;
let sym = state_symbol(&self.arena, st).ok_or(Error::Corrupt)?;
self.bump_freq(i, f, summ_freq)?;
return Ok(sym);
}
acc += f;
}
Err(Error::Corrupt)
}
fn bump_freq(&mut self, i: u32, freq: u32, summ_freq: u16) -> Result<(), Error> {
let new_freq = freq + 4;
let new_summ = summ_freq as u32 + 4;
if new_freq > MAX_FREQ as u32 || new_summ > 0xFFFF {
self.rescale()?;
return Ok(());
}
let st = self.stats_off + i * STATE_SIZE as u32;
state_set_freq(&mut self.arena, st, new_freq as u8).ok_or(Error::Corrupt)?;
ctx_set_summ_freq(&mut self.arena, self.ctx_off, new_summ as u16).ok_or(Error::Corrupt)?;
if i > 0 {
let prev = self.stats_off + (i - 1) * STATE_SIZE as u32;
let prev_f = state_freq(&self.arena, prev).ok_or(Error::Corrupt)? as u32;
if new_freq > prev_f {
swap_states(&mut self.arena, st, prev).ok_or(Error::Corrupt)?;
}
}
Ok(())
}
fn rescale(&mut self) -> Result<(), Error> {
let nstats = ctx_num_stats(&self.arena, self.ctx_off).ok_or(Error::Corrupt)? as u32;
let mut new_summ: u32 = 0;
for i in 0..nstats {
let st = self.stats_off + i * STATE_SIZE as u32;
let f = state_freq(&self.arena, st).ok_or(Error::Corrupt)? as u32;
let new_f = ((f + 1) >> 1).max(1) as u8;
state_set_freq(&mut self.arena, st, new_f).ok_or(Error::Corrupt)?;
new_summ += new_f as u32;
}
ctx_set_summ_freq(&mut self.arena, self.ctx_off, new_summ.min(0xFFFF) as u16)
.ok_or(Error::Corrupt)?;
Ok(())
}
}