use alloc::vec;
use alloc::vec::Vec;
use core::fmt;
use super::{CODE_BOT, CODE_SHIFT, CODE_TOP, SYM_BITS, SYM_MAX, WINDOW_SIZE, ilog};
/// Error reported by [`RangeEncoder::finalize`] when the encoded output did not
/// fit into the frame buffer the encoder was created with.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RangeEncoderError;

impl fmt::Display for RangeEncoderError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MESSAGE: &str = "range encoder overflowed its frame buffer";
        formatter.write_str(MESSAGE)
    }
}

#[cfg(feature = "std")]
impl std::error::Error for RangeEncoderError {}
/// Range encoder writing into a fixed-size frame buffer.
///
/// Range-coded bytes are written forward from the start of `buf`, while raw
/// bits (from `encode_raw_bits`) are written backward from its end. A sticky
/// error flag is set instead of writing whenever the two regions would meet.
#[derive(Debug, Clone)]
pub struct RangeEncoder {
    /// Output frame buffer; its length is the frame capacity.
    buf: Vec<u8>,
    /// Number of range-coder bytes written at the front of `buf`.
    offs: usize,
    /// Number of raw-bit bytes written at the back of `buf`.
    end_offs: usize,
    /// Accumulator for raw bits not yet flushed to the back of `buf`; new bits
    /// are OR-ed in above the `nend_bits` bits already held.
    end_window: u32,
    /// Number of valid bits currently held in `end_window`.
    nend_bits: u32,
    /// Low end of the current coding interval.
    val: u32,
    /// Size of the current coding interval.
    rng: u32,
    /// Most recent output byte, held back because a later carry may still
    /// increment it before it is written.
    rem: Option<u8>,
    /// Number of pending `SYM_MAX` bytes following `rem`; they are written as
    /// `SYM_MAX + carry` (truncated to a byte) once the carry is known.
    ext: u32,
    /// Running bit counter read by `tell()`; initialised to 33 in `new`.
    nbits_total: u32,
    /// Sticky flag set when a byte write would overflow the buffer.
    error: bool,
}
impl RangeEncoder {
    /// Creates an encoder over a zero-filled frame buffer of `size` bytes.
    ///
    /// Range-coded bytes are written from the front of the buffer and raw bits
    /// (see [`RangeEncoder::encode_raw_bits`]) from the back.
    #[must_use]
    pub fn new(size: usize) -> Self {
        RangeEncoder {
            buf: vec![0; size],
            offs: 0,
            end_offs: 0,
            end_window: 0,
            nend_bits: 0,
            val: 0,
            rng: CODE_TOP,
            rem: None,
            ext: 0,
            // Starting value of the bit counter reported by `tell`.
            nbits_total: 33,
            error: false,
        }
    }

    /// Appends `b` to the front (range-coder) region.
    ///
    /// When the front and back regions would collide, the byte is dropped and
    /// the sticky error flag is set instead.
    fn write_byte(&mut self, b: u8) {
        if self.offs + self.end_offs >= self.buf.len() {
            self.error = true;
        } else {
            self.buf[self.offs] = b;
            self.offs += 1;
        }
    }

    /// Prepends `b` to the back (raw-bit) region, which grows toward the front.
    ///
    /// When the regions would collide, the byte is dropped and the sticky
    /// error flag is set instead.
    fn write_byte_at_end(&mut self, b: u8) {
        if self.offs + self.end_offs >= self.buf.len() {
            self.error = true;
        } else {
            self.end_offs += 1;
            let at = self.buf.len() - self.end_offs;
            self.buf[at] = b;
        }
    }

    /// Moves every complete byte held in `end_window` (lowest bits first) into
    /// the back region, leaving fewer than `SYM_BITS` bits buffered.
    fn flush_end_window(&mut self) {
        while self.nend_bits >= SYM_BITS {
            self.write_byte_at_end((self.end_window & SYM_MAX) as u8);
            self.end_window >>= SYM_BITS;
            self.nend_bits -= SYM_BITS;
        }
    }

    /// Emits output symbol `c`, propagating any carry into held-back output.
    ///
    /// Bits of `c` above `SYM_BITS` are a carry into bytes already produced.
    /// Output is delayed by one byte (`rem`) plus any run of `SYM_MAX` symbols
    /// (`ext`), because a later carry can still change them.
    fn carry_out(&mut self, c: u32) {
        if c == SYM_MAX {
            // A SYM_MAX byte would roll over on a later carry; defer it.
            self.ext += 1;
            return;
        }
        let carry = (c >> SYM_BITS) as u8;
        if let Some(rem) = self.rem.take() {
            self.write_byte(rem + carry);
        }
        if self.ext > 0 {
            // Deferred SYM_MAX bytes become SYM_MAX + carry, truncated to a byte.
            let sym = (SYM_MAX + u32::from(carry)) as u8;
            for _ in 0..self.ext {
                self.write_byte(sym);
            }
            self.ext = 0;
        }
        self.rem = Some((c & SYM_MAX) as u8);
    }

    /// Shifts completed symbols out of the interval until `rng` exceeds `CODE_BOT`.
    fn normalize(&mut self) {
        while self.rng <= CODE_BOT {
            self.carry_out(self.val >> CODE_SHIFT);
            // Keep only the low 31 bits of the shifted low end.
            self.val = (self.val << SYM_BITS) & 0x7FFF_FFFF;
            self.rng <<= SYM_BITS;
            self.nbits_total += SYM_BITS;
        }
    }

    /// Encodes a symbol occupying frequency interval `[fl, fh)` out of a total `ft`.
    ///
    /// The symbol with `fl == 0` also absorbs the truncation remainder of `rng / ft`.
    pub fn encode(&mut self, fl: u32, fh: u32, ft: u32) {
        debug_assert!(fl < fh && fh <= ft && ft <= u32::from(u16::MAX));
        let r = self.rng / ft;
        if fl > 0 {
            self.val += self.rng - r * (ft - fl);
            self.rng = r * (fh - fl);
        } else {
            self.rng -= r * (ft - fh);
        }
        self.normalize();
    }

    /// Like [`RangeEncoder::encode`] with a total of `1 << ftb`, using a shift
    /// instead of a division.
    pub fn encode_bin(&mut self, fl: u32, fh: u32, ftb: u32) {
        debug_assert!(ftb <= 16);
        let ft = 1u32 << ftb;
        debug_assert!(fl < fh && fh <= ft);
        let r = self.rng >> ftb;
        if fl > 0 {
            self.val += self.rng - r * (ft - fl);
            self.rng = r * (fh - fl);
        } else {
            self.rng -= r * (ft - fh);
        }
        self.normalize();
    }

    /// Encodes `bit`, giving `true` the top `rng >> logp` of the range, i.e.
    /// a probability of roughly `1 / 2^logp`.
    pub fn encode_bit_logp(&mut self, bit: bool, logp: u32) {
        let r = self.rng;
        let s = r >> logp;
        let r = r - s;
        if bit {
            self.val += r;
        }
        self.rng = if bit { s } else { r };
        self.normalize();
    }

    /// Encodes symbol `k` using an inverse-CDF table scaled to `1 << ftb`.
    ///
    /// Symbol `k` occupies `icdf[k - 1] - icdf[k]` units, or `(1 << ftb) - icdf[0]`
    /// for `k == 0`, so `icdf` must be non-increasing.
    ///
    /// # Panics
    ///
    /// Panics if `k` is out of bounds for `icdf`.
    pub fn encode_icdf(&mut self, k: usize, icdf: &[u8], ftb: u32) {
        let r = self.rng >> ftb;
        if k > 0 {
            self.val += self.rng - r * u32::from(icdf[k - 1]);
            self.rng = r * u32::from(icdf[k - 1] - icdf[k]);
        } else {
            self.rng -= r * u32::from(icdf[k]);
        }
        self.normalize();
    }

    /// Writes the low `bits` bits of `value` verbatim, bypassing the range coder.
    ///
    /// Raw bits are packed LSB-first into `end_window` and stored at the back of
    /// the buffer, one byte at a time.
    pub fn encode_raw_bits(&mut self, value: u32, bits: u32) {
        debug_assert!(bits > 0 && bits <= WINDOW_SIZE - SYM_BITS);
        // Test `bits` before shifting so the shift itself can never overflow.
        debug_assert!(bits >= 32 || value >> bits == 0);
        if self.nend_bits + bits > WINDOW_SIZE {
            self.flush_end_window();
        }
        self.end_window |= value << self.nend_bits;
        self.nend_bits += bits;
        self.nbits_total += bits;
    }

    /// Encodes `t` uniformly from `0..ft`.
    ///
    /// When `ilog(ft - 1)` exceeds 8, only the top 8 of those bits are range
    /// coded; the remaining low bits go through `encode_raw_bits`.
    pub fn encode_uint(&mut self, t: u32, ft: u32) {
        debug_assert!(ft > 1);
        debug_assert!(t < ft);
        let ftb = ilog(ft - 1);
        if ftb <= 8 {
            self.encode(t, t + 1, ft);
        } else {
            let low_bits = ftb - 8;
            let ft_hi = ((ft - 1) >> low_bits) + 1;
            let t_hi = t >> low_bits;
            self.encode(t_hi, t_hi + 1, ft_hi);
            self.encode_raw_bits(t & ((1 << low_bits) - 1), low_bits);
        }
    }

    /// Returns the number of whole bits used so far, computed as
    /// `nbits_total - ilog(rng)`.
    #[inline]
    #[must_use]
    pub fn tell(&self) -> u32 {
        self.nbits_total - ilog(self.rng)
    }

    /// Fractional-resolution variant of [`RangeEncoder::tell`]; the computation
    /// (and its precision) is delegated to `decoder::tell_frac`.
    #[must_use]
    pub fn tell_frac(&self) -> u32 {
        super::decoder::tell_frac(self.nbits_total, self.rng)
    }

    /// Returns the current size of the coding interval.
    #[inline]
    #[must_use]
    pub fn range_size(&self) -> u32 {
        self.rng
    }

    /// Shrinks the frame buffer to `new_size` bytes, moving the raw-bit bytes
    /// stored at the back so they stay at the end of the smaller buffer.
    ///
    /// # Panics
    ///
    /// Panics if `new_size` is smaller than the number of bytes already written,
    /// or larger than the current buffer length.
    pub fn shrink(&mut self, new_size: usize) {
        assert!(self.offs + self.end_offs <= new_size, "shrink below written data");
        let old_len = self.buf.len();
        assert!(new_size <= old_len, "shrink cannot grow the buffer");
        // `copy_within` has memmove semantics, so the overlapping ranges are safe.
        self.buf
            .copy_within(old_len - self.end_offs..old_len, new_size - self.end_offs);
        self.buf.truncate(new_size);
    }

    /// Terminates the stream and returns the finished frame buffer.
    ///
    /// It first picks a value `end` whose trailing `msk` bits are free, such that
    /// `[end, end | msk]` lies below `val + rng`. If it does not fit, it uses one
    /// more bit of precision. It then flushes that value, the held-back bytes and
    /// the remaining raw bits. Unused space between the two regions is zeroed.
    /// Raw bits short of a full byte are OR-ed into the byte just before the back
    /// region.
    ///
    /// # Errors
    ///
    /// Returns [`RangeEncoderError`] if any write overflowed the buffer, or if
    /// leftover raw bits could not be placed without corrupting range-coder data.
    pub fn finalize(mut self) -> Result<Vec<u8>, RangeEncoderError> {
        let mut l: i32 = (super::CODE_BITS - ilog(self.rng)) as i32;
        let mut msk = (CODE_TOP - 1) >> l;
        let mut end = self.val.wrapping_add(msk) & !msk;
        if (end | msk) >= self.val + self.rng {
            l += 1;
            msk >>= 1;
            end = self.val.wrapping_add(msk) & !msk;
        }
        // Emit the `l` significant bits of `end`, one symbol at a time.
        while l > 0 {
            self.carry_out(end >> CODE_SHIFT);
            end = (end << SYM_BITS) & (CODE_TOP - 1);
            l -= SYM_BITS as i32;
        }
        // Release the held-back byte and any deferred SYM_MAX run.
        if self.rem.is_some() || self.ext > 0 {
            self.carry_out(0);
        }
        self.flush_end_window();
        if !self.error {
            let gap = self.offs..self.buf.len() - self.end_offs;
            self.buf[gap].fill(0);
            if self.nend_bits > 0 {
                if self.end_offs >= self.buf.len() {
                    // No byte is left to hold the remaining raw bits.
                    self.error = true;
                } else {
                    // `l <= 0` here; `-l` is how far the last emitted symbol
                    // overshot the required bits, i.e. its low bits that are zero.
                    let spare = (-l) as u32;
                    if self.offs + self.end_offs >= self.buf.len() && spare < self.nend_bits {
                        // The target byte is also the last range-coder byte: keep
                        // only the raw bits that fit in its unused low bits.
                        self.end_window &= (1u32 << spare) - 1;
                        self.error = true;
                    }
                    let at = self.buf.len() - self.end_offs - 1;
                    self.buf[at] |= self.end_window as u8;
                }
            }
        }
        if self.error {
            Err(RangeEncoderError)
        } else {
            Ok(self.buf)
        }
    }
}