use std::io::{BufReader, Read};
use crate::error::Result;
/// Upper bound of the 16-bit coding interval (every bit set).
const RANGE_MAX: u16 = u16::MAX;
/// Most significant bit of a 16-bit coding register.
const MSB_MASK: u16 = 1 << 15;
/// Second most significant bit; used to detect an interval straddling the midpoint.
const UNDERFLOW_MASK: u16 = 1 << 14;
/// Arithmetic decoder working on 16-bit `low`/`high`/`code` registers and
/// pulling encoded bits, most significant bit first, from a buffered reader.
pub struct ArithmeticDecoder<R>
where
    R: Read,
{
    /// Lower bound of the current coding interval (inclusive).
    low: u16,
    /// Upper bound of the current coding interval (inclusive).
    high: u16,
    /// 16-bit window of encoded input currently being decoded.
    code: u16,
    /// Buffered source of encoded bytes.
    input: BufReader<R>,
    /// Most recently fetched input byte, consumed one bit at a time.
    byte_buffer: u8,
    /// Number of bits in `byte_buffer` that have not been consumed yet.
    bits_remaining: u8,
}
impl<R: Read> ArithmeticDecoder<R> {
    /// Creates a decoder over `reader` and primes the 16-bit code register with
    /// the first two bytes of the stream, interpreted big-endian.
    ///
    /// The coding interval starts as the full range `[0, RANGE_MAX]`.
    ///
    /// # Errors
    ///
    /// Returns an error if the first two bytes cannot be read, including when
    /// the stream holds fewer than two bytes.
    #[inline]
    pub fn new(reader: R) -> Result<Self> {
        let mut input = BufReader::new(reader);
        let mut initial_bytes = [0u8; 2];
        input.read_exact(&mut initial_bytes)?;
        let code = u16::from_be_bytes(initial_bytes);
        Ok(Self {
            input,
            high: RANGE_MAX,
            low: 0,
            code,
            byte_buffer: 0,
            bits_remaining: 0,
        })
    }

    /// Returns the next input bit (most significant bit of each byte first) as
    /// `0` or `1`.
    ///
    /// When the reader reports end of stream, a zero bit is supplied instead,
    /// so the registers can keep shifting past the final encoded byte.
    ///
    /// # Errors
    ///
    /// Propagates any I/O error other than end of stream.
    #[inline]
    fn read_bit(&mut self) -> Result<u16> {
        if self.bits_remaining == 0 {
            let mut byte = [0u8; 1];
            self.byte_buffer = match self.input.read_exact(&mut byte) {
                Ok(()) => byte[0],
                // Past the end of the stream: pad with zero bits.
                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => 0,
                // A real read failure must not be decoded as if it were valid data.
                Err(e) => return Err(e.into()),
            };
            self.bits_remaining = 8;
        }
        self.bits_remaining -= 1;
        Ok(u16::from((self.byte_buffer >> self.bits_remaining) & 1))
    }

    /// Scales the current code value onto the cumulative-frequency range
    /// `[0, total)`.
    ///
    /// Typically the caller finds the symbol whose cumulative range
    /// `[cum_low, cum_high)` contains the returned value and passes that range
    /// to [`decode_update`](Self::decode_update).
    ///
    /// # Panics
    ///
    /// Panics if `total` is zero.
    #[inline]
    #[must_use]
    pub fn threshold_val(&self, total: u16) -> u16 {
        assert!(total > 0, "total frequency must be non-zero");
        let range = u32::from(self.high - self.low) + 1;
        let offset = u32::from(self.code - self.low) + 1;
        // offset <= 2^16 and total < 2^16, so the product fits in u32. While
        // `code` lies within `[low, high]`, offset <= range and the quotient is
        // strictly less than `total`.
        ((offset * u32::from(total) - 1) / range) as u16
    }

    /// Narrows the coding interval to the symbol occupying the cumulative
    /// frequency range `[cum_low, cum_high)` out of `total`, then renormalizes,
    /// shifting in new input bits as needed.
    ///
    /// # Errors
    ///
    /// Returns an `InvalidInput` I/O error (converted into the crate error type)
    /// without touching the decoder state if either of these holds:
    ///
    /// * `cum_low < cum_high <= total` does not hold. This also rules out
    ///   `total == 0`.
    /// * The symbol's share of the current interval rounds down to zero width.
    ///   This is only possible when `total` exceeds the current interval width,
    ///   which is always more than `0x4000` after renormalization.
    ///
    /// Also propagates any non-EOF I/O error raised while reading input bits.
    #[inline]
    pub fn decode_update(&mut self, cum_low: u16, cum_high: u16, total: u16) -> Result<()> {
        if cum_low >= cum_high || cum_high > total {
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "cumulative frequencies must satisfy cum_low < cum_high <= total",
            )
            .into());
        }
        let range = u32::from(self.high - self.low) + 1;
        let scale = u32::from(total);
        // Each product is at most 2^16 * (2^16 - 1), which fits in u32.
        let low_offset = range * u32::from(cum_low) / scale;
        let high_offset = range * u32::from(cum_high) / scale;
        if high_offset <= low_offset {
            // Applying an empty share would leave `high < low` and corrupt
            // every subsequent step.
            return Err(std::io::Error::new(
                std::io::ErrorKind::InvalidInput,
                "frequency total too large for the current coding interval",
            )
            .into());
        }
        // low_offset < high_offset <= range, so both new bounds stay inside the
        // old [low, high] and neither addition can overflow. `high` must be
        // computed first because it uses the old `low`.
        self.high = self.low + (high_offset - 1) as u16;
        self.low += low_offset as u16;
        self.renormalize()
    }

    /// Widens the interval until `low` and `high` differ in their top bit and
    /// no longer straddle the midpoint inside the middle half. On return the
    /// interval is wider than a quarter of the code space.
    #[inline]
    fn renormalize(&mut self) -> Result<()> {
        loop {
            if (self.high ^ self.low) & MSB_MASK == 0 {
                self.shift_out_msb()?;
            } else if self.low & UNDERFLOW_MASK != 0 && self.high & UNDERFLOW_MASK == 0 {
                self.handle_underflow()?;
            } else {
                return Ok(());
            }
        }
    }

    /// Discards the most significant bit shared by `low` and `high`. A 0 is
    /// shifted into `low`, a 1 into `high`, and the next input bit into `code`.
    #[inline]
    fn shift_out_msb(&mut self) -> Result<()> {
        self.low <<= 1;
        self.high = (self.high << 1) | 1;
        self.code = (self.code << 1) | self.read_bit()?;
        Ok(())
    }

    /// Handles the underflow case, where `low` = `01…` and `high` = `10…`.
    /// It drops the second most significant bit and doubles the interval
    /// around the midpoint.
    #[inline]
    fn handle_underflow(&mut self) -> Result<()> {
        // Equivalent to clearing bit 14 of `low` (setting it in `high`,
        // flipping it in `code`) and then shifting left by one.
        self.low = (self.low << 1) & 0x7FFF;
        self.high = (self.high << 1) | 0x8001;
        self.code = ((self.code << 1) ^ MSB_MASK) | self.read_bit()?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    /// Builds a decoder over an in-memory copy of `bytes`.
    fn decoder_over(bytes: &[u8]) -> ArithmeticDecoder<Cursor<Vec<u8>>> {
        ArithmeticDecoder::new(Cursor::new(bytes.to_vec())).unwrap()
    }

    #[test]
    fn test_initialization() {
        let decoder = decoder_over(&[0xAB, 0xCD]);
        assert_eq!(decoder.low, 0);
        assert_eq!(decoder.high, 0xFFFF);
        assert_eq!(decoder.code, 0xABCD);
    }

    #[test]
    fn test_initialization_requires_two_bytes() {
        assert!(ArithmeticDecoder::new(Cursor::new(Vec::<u8>::new())).is_err());
        assert!(ArithmeticDecoder::new(Cursor::new(vec![0xABu8])).is_err());
    }

    #[test]
    fn test_threshold_midpoint() {
        assert_eq!(decoder_over(&[0x80, 0x00]).threshold_val(256), 128);
    }

    #[test]
    fn test_threshold_boundaries() {
        assert_eq!(decoder_over(&[0x00, 0x00]).threshold_val(100), 0);
        assert_eq!(decoder_over(&[0xFF, 0xFF]).threshold_val(100), 99);
    }

    #[test]
    fn test_decode_update_shifts_out_matching_msb() {
        // Selecting the upper half [0x8000, 0xFFFF] leaves a shared MSB, which is
        // shifted out once. The stream is already exhausted, so the bit entering
        // `code` is zero padding.
        let mut decoder = decoder_over(&[0x80, 0x00]);
        decoder.decode_update(128, 256, 256).unwrap();
        assert_eq!((decoder.low, decoder.high, decoder.code), (0, 0xFFFF, 0x0000));
    }

    #[test]
    fn test_decode_update_handles_underflow() {
        // Selecting the middle half [0x4000, 0xBFFF] straddles the midpoint and
        // triggers the underflow expansion. The next input bit (1, from 0xFF)
        // enters `code`.
        let mut decoder = decoder_over(&[0x60, 0x00, 0xFF]);
        decoder.decode_update(1, 3, 4).unwrap();
        assert_eq!((decoder.low, decoder.high, decoder.code), (0, 0xFFFF, 0x4001));
    }
}