use super::rc_qs_model::RCQsModel;
/// Range decoder that pulls its compressed bytes from a borrowed slice.
///
/// Running out of input never panics: missing bytes are read as zero and the
/// sticky `error` flag is raised instead.
#[derive(Debug, Clone)]
pub struct RangeDecoder<'a> {
    /// Compressed bytes being decoded.
    input: &'a [u8],
    /// Index into `input` of the next byte to consume.
    pos: usize,
    /// Lower bound of the current coding interval.
    low: u32,
    /// Width of the current coding interval.
    range: u32,
    /// Sliding 32-bit window holding the most recently consumed input bytes.
    code: u32,
    /// Set once a read was attempted past the end of `input`.
    error: bool,
}
impl<'a> RangeDecoder<'a> {
    /// Creates a decoder positioned at the start of `input`.
    ///
    /// No bytes are consumed here; call [`init`](Self::init) before decoding
    /// to load the first four bytes into the code window.
    pub fn new(input: &'a [u8]) -> Self {
        Self {
            input,
            pos: 0,
            low: 0,
            range: 0xFFFF_FFFF,
            code: 0,
            error: false,
        }
    }

    /// Returns `true` once any read has been attempted past the end of the
    /// input. The flag is sticky until [`init`](Self::init) clears it.
    pub fn error(&self) -> bool {
        self.error
    }

    /// Number of input bytes consumed so far.
    pub fn bytes_read(&self) -> usize {
        self.pos
    }

    /// Clears the error flag and loads the next four input bytes into the
    /// code window.
    pub fn init(&mut self) {
        self.error = false;
        self.get(4);
    }

    /// Decodes a single bit by splitting the current interval in half.
    #[inline]
    pub fn decode_bit(&mut self) -> bool {
        self.range >>= 1;
        let split = self.low.wrapping_add(self.range);
        let bit = self.code >= split;
        if bit {
            // The code lies in the upper half: move the interval base up.
            self.low = split;
        }
        self.normalize();
        bit
    }

    /// Decodes one symbol using `model`.
    ///
    /// The model first rescales `range`, then receives the scaled code offset
    /// and reports back the interval start and width it selected, together
    /// with the symbol that is returned to the caller.
    #[inline]
    pub fn decode_with_model(&mut self, model: &mut RCQsModel) -> u32 {
        model.normalize(&mut self.range);
        // NOTE(review): divides by `range` as left by `model.normalize`; this
        // presumably never reaches zero -- confirm against the model.
        let mut l = self.code.wrapping_sub(self.low) / self.range;
        let mut r = 0u32;
        let symbol = model.decode(&mut l, &mut r);
        self.low = self.low.wrapping_add(self.range.wrapping_mul(l));
        self.range = self.range.wrapping_mul(r);
        self.normalize();
        symbol
    }

    /// Decodes an `n`-bit unsigned integer, least-significant 16-bit chunk
    /// first.
    ///
    /// Returns 0 without consuming input when `n <= 0`. When `n > 32` all
    /// `n` bits are still consumed from the stream, but chunks that would
    /// land above bit 31 are discarded instead of triggering a shift
    /// overflow.
    #[inline]
    pub fn decode_uint(&mut self, n: i32) -> u32 {
        if n <= 0 {
            return 0;
        }
        let mut value = 0u32;
        let mut shift = 0u32;
        let mut remaining = n;
        while remaining > 16 {
            let chunk = self.decode_shift(16);
            // `checked_shl` yields `None` once the chunk lies entirely above
            // the result width; wrapping add keeps malformed streams from
            // panicking in debug builds.
            value = value.wrapping_add(chunk.checked_shl(shift).unwrap_or(0));
            shift += 16;
            remaining -= 16;
        }
        let chunk = self.decode_shift(remaining);
        value.wrapping_add(chunk.checked_shl(shift).unwrap_or(0))
    }

    /// Decodes an `n`-bit unsigned integer into a `u64`, least-significant
    /// 16-bit chunk first.
    ///
    /// Returns 0 without consuming input when `n <= 0`. When `n > 64` all
    /// `n` bits are still consumed from the stream, but chunks that would
    /// land above bit 63 are discarded instead of triggering a shift
    /// overflow.
    pub fn decode_ulong(&mut self, n: i32) -> u64 {
        if n <= 0 {
            return 0;
        }
        let mut value = 0u64;
        let mut shift = 0u32;
        let mut remaining = n;
        while remaining > 16 {
            let chunk = u64::from(self.decode_shift(16));
            value = value.wrapping_add(chunk.checked_shl(shift).unwrap_or(0));
            shift += 16;
            remaining -= 16;
        }
        let chunk = u64::from(self.decode_shift(remaining));
        value.wrapping_add(chunk.checked_shl(shift).unwrap_or(0))
    }

    /// Decodes one chunk by shrinking the interval by `n` bits and reading
    /// off which slot the code falls into.
    ///
    /// Callers in this impl pass `1..=16`.
    #[inline]
    fn decode_shift(&mut self, n: i32) -> u32 {
        self.range >>= n as u32;
        // NOTE(review): relies on `range` still being non-zero after the
        // shift, which `normalize` appears to guarantee for n <= 16 -- confirm.
        let slot = self.code.wrapping_sub(self.low) / self.range;
        self.low = self.low.wrapping_add(self.range.wrapping_mul(slot));
        self.normalize();
        slot
    }

    /// Renormalizes the interval, pulling in fresh input bytes as needed.
    #[inline]
    fn normalize(&mut self) {
        // While both ends of the interval share their top byte, that byte is
        // settled: shift it out and read one more input byte.
        while ((self.low ^ self.low.wrapping_add(self.range)) >> 24) == 0 {
            self.get(1);
            self.range <<= 8;
        }
        // The interval straddles a top-byte boundary but has become too
        // narrow: read two bytes and stretch the range up to the 2^32 wrap
        // point.
        if (self.range >> 16) == 0 {
            self.get(2);
            self.range = 0u32.wrapping_sub(self.low);
        }
    }

    /// Shifts `n` input bytes into `code`, shifting `low` left in step.
    #[inline]
    fn get(&mut self, n: i32) {
        for _ in 0..n {
            self.code <<= 8;
            self.code |= u32::from(self.get_byte());
            self.low <<= 8;
        }
    }

    /// Returns the next input byte, or 0 with the error flag raised when the
    /// input is exhausted.
    #[inline]
    fn get_byte(&mut self) -> u8 {
        match self.input.get(self.pos) {
            Some(&byte) => {
                self.pos += 1;
                byte
            }
            None => {
                self.error = true;
                0
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Decoding past the end of the input must raise the error flag.
    #[test]
    fn eof_sets_error() {
        // Exactly the four bytes `init` consumes, so priming succeeds cleanly.
        let bytes = [0u8; 4];
        let mut decoder = RangeDecoder::new(&bytes);
        decoder.init();
        assert!(!decoder.error());

        // Keep decoding well beyond what four bytes can supply.
        (0..100).for_each(|_| {
            decoder.decode_bit();
        });
        assert!(decoder.error());
    }
}