use alloc::vec;
use alloc::vec::Vec;
use crate::arsenic::tables::ModelParams;
use crate::error::Error;
/// Number of bits loaded into the decoder's code register at start-up.
const NUMBITS: u32 = 26;
/// Initial width of the decoding interval: 2^(NUMBITS - 1).
const ONE: i64 = 1 << (NUMBITS - 1);
/// Renormalisation threshold: the range is doubled while it is <= HALF
/// (half of `ONE`, i.e. 2^(NUMBITS - 2)).
const HALF: i64 = ONE >> 1;
/// Cursor that yields the bits of a byte slice, most significant bit first.
///
/// Running off the end never fails: the reader returns zero bits and
/// latches a flag that callers can inspect with [`BitReader::underflowed`].
pub(crate) struct BitReader<'a> {
    /// Bytes being read.
    data: &'a [u8],
    /// Index of the next bit, counted from the MSB of `data[0]`.
    pos: usize,
    /// Set once a read was attempted past the end of `data`.
    underflow: bool,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the most significant bit of `data[0]`.
    pub(crate) fn new(data: &'a [u8]) -> Self {
        Self {
            data,
            pos: 0,
            underflow: false,
        }
    }

    /// Returns the next bit (0 or 1).
    ///
    /// At end of input this returns 0, leaves the position unchanged and
    /// records the underflow.
    #[inline]
    fn next_bit(&mut self) -> i64 {
        match self.data.get(self.pos / 8) {
            Some(&byte) => {
                let shift = 7 - (self.pos % 8);
                self.pos += 1;
                i64::from((byte >> shift) & 1)
            }
            None => {
                self.underflow = true;
                0
            }
        }
    }

    /// Reports whether any read went past the end of the input.
    #[inline]
    pub(crate) fn underflowed(&self) -> bool {
        self.underflow
    }
}
/// Adaptive frequency table consumed by the range decoder.
///
/// Every symbol starts at frequency `increment`. Each decoded symbol gains
/// another `increment`; once the running total exceeds `limit`, every
/// frequency is halved (rounding up) and the total is recomputed.
pub(crate) struct Model {
    /// Offset added to a symbol index to form the decoded symbol value.
    first: u16,
    /// Initial per-symbol frequency and the step added on each use.
    increment: u32,
    /// Threshold above which the frequencies are rescaled.
    limit: u32,
    /// Per-symbol frequencies, indexed by symbol index.
    freq: Vec<u32>,
    /// Sum of all entries in `freq`.
    total: u32,
}

impl Model {
    /// Builds a model from its parameters, with all symbols at frequency `increment`.
    pub(crate) fn new(p: &ModelParams) -> Self {
        let symbols = p.num_symbols();
        let mut model = Self {
            first: p.first,
            increment: p.increment,
            limit: p.limit,
            freq: vec![0; symbols],
            total: 0,
        };
        model.reset();
        model
    }

    /// Restores every frequency to `increment` and recomputes the total.
    pub(crate) fn reset(&mut self) {
        self.freq.fill(self.increment);
        self.total = self.increment * self.freq.len() as u32;
    }

    /// Credits symbol `n` with one more `increment`. If the total then
    /// exceeds `limit`, all frequencies are halved, rounding up so that none
    /// drops from 1 to 0.
    fn adapt(&mut self, n: usize) {
        let step = self.increment;
        self.freq[n] += step;
        self.total += step;
        if self.total <= self.limit {
            return;
        }
        self.total = self.freq.iter_mut().fold(0u32, |sum, f| {
            *f = (*f + 1) >> 1;
            sum + *f
        });
    }
}
pub(crate) struct RangeDecoder<'a> {
reader: BitReader<'a>,
range: i64,
code: i64,
}
impl<'a> RangeDecoder<'a> {
pub(crate) fn new(data: &'a [u8]) -> Self {
let mut reader = BitReader::new(data);
let mut code: i64 = 0;
for _ in 0..NUMBITS {
code = (code << 1) | reader.next_bit();
}
Self {
reader,
range: ONE,
code,
}
}
#[inline]
pub(crate) fn underflowed(&self) -> bool {
self.reader.underflowed()
}
pub(crate) fn decode_index(&mut self, model: &mut Model) -> Result<usize, Error> {
let total = model.total as i64;
if total < 1 {
return Err(Error::Corrupt);
}
let r = self.range / total;
if r < 1 {
return Err(Error::Corrupt);
}
let mut f = self.code / r;
if f >= total {
f = total - 1;
}
let mut cumulative: i64 = 0;
let mut n = model.freq.len() - 1;
for (i, &fr) in model.freq.iter().enumerate() {
let next = cumulative + fr as i64;
if f < next {
n = i;
break;
}
cumulative = next;
}
let size = model.freq[n] as i64;
let low = cumulative;
let lowincr = r * low;
self.code -= lowincr;
if low + size == total {
self.range -= lowincr;
} else {
self.range = size * r;
}
let mut guard = 0u32;
while self.range <= HALF {
self.range <<= 1;
self.code = (self.code << 1) | self.reader.next_bit();
guard += 1;
if guard > NUMBITS + 2 {
return Err(Error::Corrupt);
}
}
model.adapt(n);
Ok(n)
}
#[inline]
pub(crate) fn decode_value(&mut self, model: &mut Model) -> Result<u16, Error> {
let n = self.decode_index(model)?;
Ok(model.first + n as u16)
}
pub(crate) fn decode_bits(&mut self, model: &mut Model, bits: u32) -> Result<u32, Error> {
let mut value = 0u32;
for i in 0..bits {
let bit = self.decode_index(model)? as u32;
value |= bit << i;
}
Ok(value)
}
}