use std::io::Read;
use byteorder::ReadBytesExt;
use crate::models;
use crate::models::DM_LENGTH_SHIFT;
/// Largest interval length; a freshly created or reset decoder starts with this length.
pub const AC_MAX_LENGTH: u32 = u32::MAX;
/// Interval length below which the decoder renormalizes by shifting in more input bytes.
pub const AC_MIN_LENGTH: u32 = 1 << 24;
/// Arithmetic decoder that pulls compressed bytes from an underlying reader.
pub struct ArithmeticDecoder<T>
where
    T: Read,
{
    /// Reader supplying the compressed bytes.
    in_stream: T,
    /// Current code value read from the stream; decoding compares it against
    /// sub-interval bounds derived from `length`.
    value: u32,
    /// Width of the current coding interval; renormalization shifts in bytes
    /// whenever it drops below `AC_MIN_LENGTH`.
    length: u32,
}
impl<T: Read> ArithmeticDecoder<T> {
    /// Creates a decoder over `in_stream` with the interval at its maximum length.
    ///
    /// `value` stays zero until [`read_init_bytes`](Self::read_init_bytes) loads
    /// it from the stream.
    pub fn new(in_stream: T) -> Self {
        Self {
            in_stream,
            value: 0,
            length: AC_MAX_LENGTH,
        }
    }

    /// Restores `value` to zero and `length` to `AC_MAX_LENGTH`.
    ///
    /// The underlying stream is not touched.
    pub fn reset(&mut self) {
        self.value = 0;
        self.length = AC_MAX_LENGTH;
    }

    /// Loads the initial code value from the next four stream bytes (big-endian).
    ///
    /// # Errors
    /// Propagates the error from `read_exact`, e.g. `UnexpectedEof` when fewer
    /// than four bytes remain.
    pub fn read_init_bytes(&mut self) -> std::io::Result<()> {
        let mut bytes = [0u8; 4];
        self.in_stream.read_exact(&mut bytes)?;
        self.value = u32::from_be_bytes(bytes);
        Ok(())
    }

    /// Decodes one bit (returned as 0 or 1) using the adaptive bit `model`.
    ///
    /// `model.bit_0_prob` selects the share of the interval, in units of
    /// `length >> BM_LENGTH_SHIFT`, assigned to a 0 bit. A decoded 0 increments
    /// `model.bit_0_count`. Every call decrements `model.bits_until_update`, and
    /// `model.update()` runs when that counter reaches zero.
    ///
    /// # Errors
    /// Propagates I/O errors raised while renormalizing.
    pub fn decode_bit(&mut self, model: &mut models::ArithmeticBitModel) -> std::io::Result<u32> {
        let x = model.bit_0_prob * (self.length >> models::BM_LENGTH_SHIFT);
        let sym = self.value >= x;
        if !sym {
            self.length = x;
            model.bit_0_count += 1;
        } else {
            self.value -= x;
            self.length -= x;
        }
        if self.length < AC_MIN_LENGTH {
            self.renorm_dec_interval()?;
        }
        model.bits_until_update -= 1;
        if model.bits_until_update == 0 {
            model.update();
        }
        Ok(u32::from(sym))
    }

    /// Decodes one symbol using the cumulative distribution in `model`.
    ///
    /// When `model.decoder_table` is non-empty it narrows the search range
    /// before bisecting. Otherwise the decoder bisects over all
    /// `model.symbols` entries. After decoding, `model.symbol_count[sym]` is
    /// incremented and `model.symbols_until_update` decremented; `model.update()`
    /// runs when the latter reaches zero.
    ///
    /// # Errors
    /// * `UnexpectedEof` if the scaled interval length is zero.
    /// * I/O errors raised while renormalizing.
    pub fn decode_symbol(&mut self, model: &mut models::ArithmeticModel) -> std::io::Result<u32> {
        // Upper bound of the decoded symbol's sub-interval. It keeps the full,
        // unshifted length when the decoded symbol is the last one.
        let mut y = self.length;
        // Both search strategies work on the length scaled down by
        // DM_LENGTH_SHIFT. Doing the shift once lets the zero guard cover both.
        self.length >>= DM_LENGTH_SHIFT;
        if self.length == 0 {
            return Err(std::io::Error::new(
                std::io::ErrorKind::UnexpectedEof,
                "ArithmeticDecoder length is zero",
            ));
        }

        let (sym, x) = if !model.decoder_table.is_empty() {
            // Table-accelerated path: the table yields a candidate range
            // [sym, n), which is bisected against the cumulative distribution.
            let dv = self.value / self.length;
            let t = (dv >> model.table_shift) as usize;
            let mut sym = model.decoder_table[t];
            let mut n = model.decoder_table[t + 1] + 1;
            while n > sym + 1 {
                let k = (sym + n) >> 1;
                if model.distribution[k as usize] > dv {
                    n = k;
                } else {
                    sym = k;
                }
            }
            let x = model.distribution[sym as usize] * self.length;
            if sym != model.last_symbol {
                y = model.distribution[sym as usize + 1] * self.length;
            }
            (sym, x)
        } else {
            // No table: bisect over every symbol, comparing scaled interval
            // boundaries directly against `value`.
            let mut sym = 0;
            let mut x = 0;
            let mut n = model.symbols;
            let mut k = n >> 1;
            loop {
                let z = self.length * model.distribution[k as usize];
                if z > self.value {
                    n = k;
                    y = z;
                } else {
                    sym = k;
                    x = z;
                }
                k = (sym + n) >> 1;
                if k == sym {
                    break;
                }
            }
            (sym, x)
        };

        self.value -= x;
        self.length = y - x;
        if self.length < AC_MIN_LENGTH {
            self.renorm_dec_interval()?;
        }
        model.symbol_count[sym as usize] += 1;
        model.symbols_until_update -= 1;
        if model.symbols_until_update == 0 {
            model.update();
        }
        Ok(sym)
    }

    /// Decodes a single equiprobable bit (0 or 1).
    ///
    /// # Errors
    /// * `InvalidData` if the decoded value is not 0 or 1.
    /// * I/O errors raised while renormalizing.
    pub fn read_bit(&mut self) -> std::io::Result<u32> {
        self.read_raw_bits(1)
    }

    /// Decodes `bits` equiprobable bits (1..=32) as an unsigned integer.
    ///
    /// More than 19 bits are split: the low 16 bits are decoded first, then
    /// the remaining high bits.
    ///
    /// # Errors
    /// * `InvalidData` if a decoded chunk exceeds its bit width.
    /// * I/O errors raised while renormalizing.
    pub fn read_bits(&mut self, bits: u32) -> std::io::Result<u32> {
        debug_assert!(bits > 0 && (bits <= 32));
        if bits > 19 {
            let lower = u32::from(self.read_short()?);
            let upper = self.read_bits(bits - 16)? << 16;
            Ok(upper | lower)
        } else {
            self.read_raw_bits(bits)
        }
    }

    // No caller of `read_byte` is visible in this impl; the allow keeps the
    // 8-bit reader alongside `read_short` without a dead-code warning.
    #[allow(dead_code)]
    fn read_byte(&mut self) -> std::io::Result<u8> {
        let sym = self.read_raw_bits(8)?;
        // Lossless: read_raw_bits guarantees sym < 1 << 8.
        Ok(sym as u8)
    }

    fn read_short(&mut self) -> std::io::Result<u16> {
        let sym = self.read_raw_bits(16)?;
        // Lossless: read_raw_bits guarantees sym < 1 << 16.
        Ok(sym as u16)
    }

    /// Decodes a 32-bit value as two 16-bit halves, low half first.
    pub fn read_int(&mut self) -> std::io::Result<u32> {
        let lower_int = u32::from(self.read_short()?);
        let upper_int = u32::from(self.read_short()?);
        Ok(upper_int << 16 | lower_int)
    }

    /// Decodes a 64-bit value as two 32-bit halves, low half first.
    pub fn read_int_64(&mut self) -> std::io::Result<u64> {
        let lower_int = u64::from(self.read_int()?);
        let upper_int = u64::from(self.read_int()?);
        Ok((upper_int << 32) | lower_int)
    }

    /// Decodes a value uniformly distributed over `[0, 1 << bits)`, with `bits <= 19`.
    ///
    /// This is the shared core of `read_bit`, `read_bits`, `read_byte` and
    /// `read_short`. A result outside the range can only come from a decoder
    /// state that does not match the input (e.g. a corrupt stream). It is
    /// reported as `InvalidData` instead of being returned or truncated.
    fn read_raw_bits(&mut self, bits: u32) -> std::io::Result<u32> {
        self.length >>= bits;
        let sym = self.value / self.length;
        self.value -= self.length * sym;
        if self.length < AC_MIN_LENGTH {
            self.renorm_dec_interval()?;
        }
        if sym >= (1u32 << bits) {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "ArithmeticDecoder decoded symbol out of range",
            ));
        }
        Ok(sym)
    }

    /// Shifts input bytes into `value` until `length >= AC_MIN_LENGTH`.
    ///
    /// Each iteration shifts `value` and `length` left by 8 bits and appends
    /// one stream byte to `value`. If `length` is zero it never grows, so this
    /// keeps reading until the stream returns an error.
    fn renorm_dec_interval(&mut self) -> std::io::Result<()> {
        loop {
            self.value = (self.value << 8) | u32::from(self.in_stream.read_u8()?);
            self.length <<= 8;
            if self.length >= AC_MIN_LENGTH {
                break;
            }
        }
        Ok(())
    }

    /// Mutable access to the underlying reader.
    pub fn get_mut(&mut self) -> &mut T {
        &mut self.in_stream
    }

    /// Shared access to the underlying reader.
    pub fn get_ref(&self) -> &T {
        &self.in_stream
    }

    /// Consumes the decoder and returns the underlying reader.
    pub fn into_inner(self) -> T {
        self.in_stream
    }
}