use std::io::{self, Read};
use crate::laz::arithmetic_model::{
ArithmeticBitModel,
ArithmeticSymbolModel,
BIT_LENGTH_SHIFT,
SYMBOL_LENGTH_SHIFT,
};
/// Full-width coding interval: the decoder starts with an interval length
/// covering the whole 32-bit range.
pub const AC_MAX_LENGTH: u32 = 0xFFFF_FFFF;
/// Renormalisation threshold (2^24): whenever the interval length drops
/// below this, the decoder shifts in one more input byte at a time until
/// the invariant `length >= AC_MIN_LENGTH` is restored (see `renorm`).
pub const AC_MIN_LENGTH: u32 = 0x0100_0000;
/// Incremental arithmetic decoder over any byte source `R`.
///
/// Holds the classic range-coder state pair: the current code `value`
/// read from the stream and the current interval `length`. Call
/// `read_init_bytes` once before decoding anything.
pub struct ArithmeticDecoder<R: Read> {
    // Underlying byte source; bytes are consumed lazily during renormalisation.
    input: R,
    // Current code value, relative to the base of the current interval.
    value: u32,
    // Current interval length; kept >= AC_MIN_LENGTH by renormalisation
    // after each decode step.
    length: u32,
}
impl<R: Read> ArithmeticDecoder<R> {
    /// Creates a decoder over `input`.
    ///
    /// The interval starts at its maximum width and no bytes are consumed
    /// until [`read_init_bytes`](Self::read_init_bytes) is called.
    pub fn new(input: R) -> Self {
        Self {
            input,
            value: 0,
            length: AC_MAX_LENGTH,
        }
    }

    /// Primes the decoder by reading the first four stream bytes as the
    /// initial big-endian code value.
    ///
    /// # Errors
    /// Propagates whatever `read_exact` reports (e.g. `UnexpectedEof`).
    pub fn read_init_bytes(&mut self) -> io::Result<()> {
        let mut init = [0u8; 4];
        self.input.read_exact(&mut init)?;
        // Same as the manual shift/OR chain, but self-describing.
        self.value = u32::from_be_bytes(init);
        Ok(())
    }

    /// Decodes one adaptively-modelled bit.
    ///
    /// Splits the current interval at the model's scaled zero-probability:
    /// the lower sub-interval encodes `false`, the upper encodes `true`.
    /// The model is updated afterwards so it stays in sync with the encoder.
    pub fn decode_bit(&mut self, model: &mut ArithmeticBitModel) -> io::Result<bool> {
        // Width of the sub-interval assigned to a zero bit.
        let x = model.zero_probability() * (self.length >> BIT_LENGTH_SHIFT);
        let one = self.value >= x;
        if one {
            // Keep the upper part of the interval.
            self.value -= x;
            self.length -= x;
        } else {
            // Keep the lower part.
            self.length = x;
        }
        if self.length < AC_MIN_LENGTH {
            self.renorm()?;
        }
        model.observe_bit(one);
        Ok(one)
    }

    /// Decodes one symbol from an adaptive multi-symbol model by binary
    /// search over the model's cumulative distribution.
    pub fn decode_symbol(&mut self, model: &mut ArithmeticSymbolModel) -> io::Result<u32> {
        // `y` tracks the upper bound of the chosen sub-interval; it starts
        // as the full current length in case the highest symbol is picked.
        let mut y = self.length;
        self.length >>= SYMBOL_LENGTH_SHIFT;
        if self.length == 0 {
            // Defensive: presumably unreachable while the renormalisation
            // invariant holds (length >= AC_MIN_LENGTH before the shift) —
            // depends on the value of SYMBOL_LENGTH_SHIFT, declared elsewhere.
            return Err(io::Error::new(
                io::ErrorKind::UnexpectedEof,
                "arithmetic decode interval collapsed",
            ));
        }
        // `x`/`y` bracket the sub-interval of the decoded symbol;
        // `symbol`/`n` bracket the candidate symbol range.
        let mut x = 0u32;
        let mut symbol = 0u32;
        let mut n = model.symbols();
        let mut k = n >> 1;
        loop {
            // Scaled cumulative frequency of the midpoint candidate.
            let z = self.length * model.cdf_at(k);
            if z > self.value {
                // Decoded symbol lies below `k`.
                n = k;
                y = z;
            } else {
                // Decoded symbol is `k` or above.
                symbol = k;
                x = z;
            }
            k = (symbol + n) >> 1;
            if k == symbol {
                break;
            }
        }
        // Narrow the interval to the chosen symbol's sub-range.
        self.value -= x;
        self.length = y - x;
        if self.length < AC_MIN_LENGTH {
            self.renorm()?;
        }
        model.observe_symbol(symbol);
        Ok(symbol)
    }

    /// Reads a single raw (model-free, equiprobable) bit.
    ///
    /// # Errors
    /// `InvalidData` if the decoded symbol is out of range, which can only
    /// happen on a corrupted stream; otherwise any I/O error from renorm.
    pub fn read_bit(&mut self) -> io::Result<u32> {
        self.length >>= 1;
        let symbol = self.value / self.length;
        self.value -= self.length * symbol;
        if self.length < AC_MIN_LENGTH {
            self.renorm()?;
        }
        // A well-formed stream keeps value < length, so the quotient is 0
        // or 1; anything else signals corruption.
        if symbol >= 2 {
            return Err(Self::corrupt_stream_error());
        }
        Ok(symbol)
    }

    /// Reads `bits` raw (model-free) bits, `1..=32`.
    ///
    /// Reads wider than 19 bits are split into a 16-bit low part followed
    /// by the remaining high part, to keep the interval arithmetic within
    /// `u32` without overflow.
    ///
    /// # Errors
    /// `InvalidData` on a corrupted stream (decoded value exceeds the bit
    /// width); otherwise any I/O error from renormalisation.
    pub fn read_bits(&mut self, mut bits: u32) -> io::Result<u32> {
        debug_assert!(bits > 0 && bits <= 32);
        if bits > 19 {
            let low = u32::from(self.read_short()?);
            bits -= 16;
            let high = self.read_bits(bits)? << 16;
            Ok(high | low)
        } else {
            self.length >>= bits;
            let symbol = self.value / self.length;
            self.value -= self.length * symbol;
            if self.length < AC_MIN_LENGTH {
                self.renorm()?;
            }
            // On a valid stream the quotient fits in `bits` bits; larger
            // values can only come from corrupted input. Shift is safe:
            // bits <= 19 here.
            if symbol >= (1u32 << bits) {
                return Err(Self::corrupt_stream_error());
            }
            Ok(symbol)
        }
    }

    /// Reads a raw 32-bit unsigned integer.
    pub fn read_int(&mut self) -> io::Result<u32> {
        self.read_bits(32)
    }

    /// Reads a raw 64-bit unsigned integer as two 32-bit halves,
    /// low half first.
    pub fn read_int64(&mut self) -> io::Result<u64> {
        let lo = u64::from(self.read_int()?);
        let hi = u64::from(self.read_int()?);
        Ok((hi << 32) | lo)
    }

    /// Reads a raw 16-bit unsigned integer (helper for wide `read_bits`).
    fn read_short(&mut self) -> io::Result<u16> {
        self.length >>= 16;
        let symbol = self.value / self.length;
        self.value -= self.length * symbol;
        if self.length < AC_MIN_LENGTH {
            self.renorm()?;
        }
        // Valid streams yield a quotient that fits in 16 bits; reject
        // anything larger instead of silently truncating it in the cast.
        if symbol > u32::from(u16::MAX) {
            return Err(Self::corrupt_stream_error());
        }
        Ok(symbol as u16)
    }

    /// Renormalises the interval: shifts in input bytes until
    /// `length >= AC_MIN_LENGTH` again.
    fn renorm(&mut self) -> io::Result<()> {
        while self.length < AC_MIN_LENGTH {
            let mut b = [0u8; 1];
            self.input.read_exact(&mut b)?;
            self.value = (self.value << 8) | u32::from(b[0]);
            self.length <<= 8;
        }
        Ok(())
    }

    /// Error for a decoded raw symbol exceeding its bit width — only
    /// possible when the underlying stream is corrupted or truncated.
    fn corrupt_stream_error() -> io::Error {
        io::Error::new(
            io::ErrorKind::InvalidData,
            "arithmetic decoder: corrupted stream",
        )
    }

    /// Returns a shared reference to the underlying reader.
    pub fn get_ref(&self) -> &R {
        &self.input
    }

    /// Returns a mutable reference to the underlying reader.
    pub fn get_mut(&mut self) -> &mut R {
        &mut self.input
    }

    /// Consumes the decoder, returning the underlying reader.
    pub fn into_inner(self) -> R {
        self.input
    }
}