use crate::Error;
/// Entropy decoder for one frame, combining a range decoder with a raw-bit reader.
///
/// Range-coded symbols are pulled from the front of the buffer, while raw bits
/// ([`RangeDecoder::dec_bits`]) are pulled least-significant-bit first from the
/// back. The arithmetic mirrors the range decoder of RFC 6716 §4.1 (Opus).
///
/// Reads past either end of the buffer yield zero bytes. Invalid requests and
/// inconsistent data do not abort decoding; they latch a sticky flag reported by
/// [`RangeDecoder::has_error`].
#[derive(Debug)]
pub struct RangeDecoder<'a> {
    /// Frame being decoded.
    buf: &'a [u8],
    /// Index of the next byte the range coder reads from the front of `buf`.
    fwd: usize,
    /// Number of bytes the raw-bit reader has consumed from the back of `buf`.
    back: usize,
    /// Number of valid, not yet returned bits in `back_window` (always <= 7).
    back_bits_avail: u32,
    /// Raw bits read from the back of `buf`; bit 0 is the next one returned.
    back_window: u32,
    /// Low bit of the last front byte; it becomes the top bit of the next input symbol.
    rem: u32,
    /// Size of the current range; greater than `RNG_MIN` after every normalization.
    rng: u32,
    /// Decoder state compared against sub-range sizes to select symbols; always < `rng`.
    val: u32,
    /// Bits accounted to the range coder, including its initial look-ahead.
    nbits_total: u32,
    /// Bits consumed by [`RangeDecoder::dec_bits`].
    nbits_raw: u32,
    /// Sticky error flag; once set it is never cleared.
    error: bool,
}

impl<'a> RangeDecoder<'a> {
    /// The range is renormalized (one input byte per step) while `rng <= RNG_MIN`.
    const RNG_MIN: u32 = 1 << 23;

    /// Creates a decoder positioned at the start of `buf`.
    ///
    /// An empty buffer is accepted and decodes as if it were all zero bytes.
    pub fn new(buf: &'a [u8]) -> Self {
        let b0 = buf.first().map_or(0, |&b| u32::from(b));
        let mut dec = Self {
            buf,
            fwd: if buf.is_empty() { 0 } else { 1 },
            back: 0,
            back_bits_avail: 0,
            back_window: 0,
            rem: b0 & 1,
            rng: 128,
            val: 127 - (b0 >> 1),
            // The three normalization steps below add 24, giving 33; with
            // rng == 2^31 that makes a fresh decoder report tell() == 1.
            nbits_total: 9,
            nbits_raw: 0,
            error: false,
        };
        dec.normalize();
        dec
    }

    /// Returns `true` once any decoding call has latched an error.
    pub fn has_error(&self) -> bool {
        self.error
    }

    /// Returns the number of bits consumed so far (range-coded plus raw),
    /// rounded up to a whole bit. A freshly created decoder reports 1.
    pub fn tell(&self) -> u32 {
        self.nbits_total
            .saturating_sub(self.rng_bits())
            .saturating_add(self.nbits_raw)
    }

    /// Returns the number of bits consumed so far in units of 1/8 bit.
    ///
    /// `tell()` equals this value divided by 8, rounded up.
    pub fn tell_frac(&self) -> u32 {
        let lg = self.rng_bits();
        // Q15 mantissa of rng, in [1, 2). rng > RNG_MIN guarantees lg >= 24.
        let mut r_q15 = self.rng >> (lg - 16);
        let mut lg_frac = lg;
        // Squaring doubles log2 of the mantissa; the bit that spills past 2.0
        // is the next fractional bit of log2(rng). Three steps = 1/8-bit resolution.
        for _ in 0..3 {
            r_q15 = (r_q15 * r_q15) >> 15;
            let bit = r_q15 >> 16;
            lg_frac = 2 * lg_frac + bit;
            r_q15 >>= bit;
        }
        self.nbits_total
            .saturating_mul(8)
            .saturating_sub(lg_frac)
            .saturating_add(self.nbits_raw.saturating_mul(8))
    }

    /// Decodes one binary symbol whose `1` outcome has probability about `2^-logp`.
    ///
    /// `logp >= 32` gives the `1` outcome an empty sub-range, so 0 is returned
    /// and the decoder state is left unchanged.
    pub fn dec_bit_logp(&mut self, logp: u32) -> u32 {
        // `checked_shr` avoids the shift-overflow panic for logp >= 32.
        let s = self.rng.checked_shr(logp).unwrap_or(0);
        let bit = self.val < s;
        if bit {
            self.rng = s;
        } else {
            self.val -= s;
            self.rng -= s;
        }
        self.normalize();
        u32::from(bit)
    }

    /// Reads `bits` raw bits (at most 32) from the back of the buffer, LSB first.
    ///
    /// Bytes beyond the buffer read as zero. `bits == 0` returns 0 without
    /// consuming anything. `bits > 32` latches the error flag and returns 0
    /// without consuming anything.
    pub fn dec_bits(&mut self, bits: u32) -> u32 {
        if bits == 0 {
            return 0;
        }
        if bits > 32 {
            self.error = true;
            return 0;
        }
        // Accumulate in 64 bits. Before a refill, avail < bits <= 32, so at most
        // 31 + 8 = 39 bits are live. A 32-bit accumulator would shift high bits
        // of the refill byte out for bits > 25, and would overflow `>> 32`.
        let mut window = u64::from(self.back_window);
        let mut avail = self.back_bits_avail;
        while avail < bits {
            let byte = self.buf.iter().rev().nth(self.back).copied().unwrap_or(0);
            self.back = self.back.saturating_add(1);
            window |= u64::from(byte) << avail;
            avail += 8;
        }
        let mask = (1u64 << bits) - 1;
        // Masked to at most 32 bits, so the narrowing cast is lossless.
        let result = (window & mask) as u32;
        // avail < bits + 8 after a refill, so at most 7 bits remain: fits in u32.
        self.back_window = (window >> bits) as u32;
        self.back_bits_avail = avail - bits;
        self.nbits_raw = self.nbits_raw.saturating_add(bits);
        result
    }

    /// Decodes an integer in `0..ft`.
    ///
    /// `ft <= 1` yields 0 without consuming input. When `ft - 1` needs more than
    /// 8 bits, only its top 8 bits are range coded and the remainder is read with
    /// [`RangeDecoder::dec_bits`]. A reconstructed value `>= ft` latches the error
    /// flag and is clamped to `ft - 1`.
    ///
    /// # Errors
    ///
    /// Never returns `Err`; decoding problems are reported through
    /// [`RangeDecoder::has_error`].
    pub fn dec_uint(&mut self, ft: u32) -> Result<u32, Error> {
        if ft <= 1 {
            return Ok(0);
        }
        let ftb = u32::BITS - (ft - 1).leading_zeros();
        if ftb <= 8 {
            let t = self.decode(ft);
            self.dec_update(t, t + 1, ft);
            return Ok(t);
        }
        let split_bits = ftb - 8;
        let top_ft = ((ft - 1) >> split_bits) + 1;
        let t_hi = self.decode(top_ft);
        self.dec_update(t_hi, t_hi + 1, top_ft);
        let t = (t_hi << split_bits) | self.dec_bits(split_bits);
        if t < ft {
            Ok(t)
        } else {
            self.error = true;
            Ok(ft - 1)
        }
    }

    /// Returns the symbol for a total frequency of `2^ftb` without advancing
    /// the decoder state.
    ///
    /// When `2^ftb` exceeds the current range (always the case for `ftb >= 32`),
    /// 0 is returned.
    pub fn decode_bin(&mut self, ftb: u32) -> u32 {
        let s = self.rng.checked_shr(ftb).unwrap_or(0);
        if s == 0 {
            return 0;
        }
        // s > 0 implies ftb < 32, so this shift is in range.
        let ft = 1u32 << ftb;
        let approx = (self.val / s).saturating_add(1);
        ft - approx.min(ft)
    }

    /// Decodes a symbol from an inverse cumulative frequency table scaled to `2^ftb`.
    ///
    /// Symbol `k` owns frequency `icdf[k - 1] - icdf[k]`, with `icdf[-1]` taken
    /// as `2^ftb`. A well-formed table is non-increasing and ends in 0. If no entry
    /// matches (e.g. the table does not end in 0), the error flag is latched and 0
    /// is returned without advancing. `ftb >= 32` behaves like a zero sub-range
    /// scale.
    pub fn dec_icdf(&mut self, icdf: &[u8], ftb: u32) -> u32 {
        let s = self.rng.checked_shr(ftb).unwrap_or(0);
        let mut t = self.rng;
        for (k, &cell) in icdf.iter().enumerate() {
            let next = s.saturating_mul(u32::from(cell));
            if self.val >= next {
                // val < t, so next <= val < t and this cannot underflow.
                self.val -= next;
                self.rng = t - next;
                self.normalize();
                return k as u32;
            }
            t = next;
        }
        self.error = true;
        0
    }

    /// Number of significant bits in `rng`, i.e. floor(log2(rng)) + 1.
    fn rng_bits(&self) -> u32 {
        u32::BITS - self.rng.leading_zeros()
    }

    /// Returns the symbol in `0..ft` selected by the current state without advancing it.
    ///
    /// Divides by `rng / ft`, so `ft` must be nonzero and at most `rng`. Callers
    /// in this file pass at most 256 (`dec_uint`).
    fn decode(&mut self, ft: u32) -> u32 {
        let s = self.rng / ft;
        let approx = self.val / s + 1;
        ft - approx.min(ft)
    }

    /// Commits the symbol occupying `[fl, fh)` out of `ft`, as returned by `decode(ft)`.
    ///
    /// Recomputes `rng / ft`, so it must run before anything else changes `rng`.
    fn dec_update(&mut self, fl: u32, fh: u32, ft: u32) {
        let s = self.rng / ft;
        self.val -= s * (ft - fh);
        if fl > 0 {
            self.rng = s * (fh - fl);
        } else {
            self.rng -= s * (ft - fh);
        }
        self.normalize();
    }

    /// Shifts in front bytes until `rng > RNG_MIN`, 8 bits per step.
    fn normalize(&mut self) {
        while self.rng <= Self::RNG_MIN {
            let byte = match self.buf.get(self.fwd) {
                Some(&b) => {
                    self.fwd += 1;
                    u32::from(b)
                }
                None => 0,
            };
            // The input symbol straddles two bytes: the carried low bit of the
            // previous byte on top, then the upper 7 bits of this one.
            let sym = (self.rem << 7) | (byte >> 1);
            self.rem = byte & 1;
            self.rng <<= 8;
            // `val << 8` has zero low bits, so adding `255 - sym` cannot carry.
            self.val = ((self.val << 8) + (255 - sym)) & 0x7FFF_FFFF;
            self.nbits_total = self.nbits_total.saturating_add(8);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // After construction the range must already be normalized.
    #[test]
    fn init_empty_buffer_satisfies_invariant() {
        let dec = RangeDecoder::new(&[]);
        assert!(dec.rng > RangeDecoder::RNG_MIN);
        assert!(!dec.has_error());
        assert_eq!(dec.tell(), 1);
    }

    #[test]
    fn init_nonempty_buffer_holds_invariant() {
        let dec = RangeDecoder::new(&[0xAB, 0xCD, 0xEF, 0x12]);
        assert!(dec.rng > RangeDecoder::RNG_MIN);
        assert!(!dec.has_error());
        assert!(dec.tell() >= 1);
    }

    // All-zero input keeps `val` at the top of the range (favouring 0);
    // all-ones input keeps it at 0 (favouring 1).
    #[test]
    fn dec_bit_logp_bias_with_extreme_inputs() {
        let mut dec0 = RangeDecoder::new(&[0u8; 16]);
        let mut zero_count = 0;
        for _ in 0..32 {
            if dec0.dec_bit_logp(1) == 0 {
                zero_count += 1;
            }
        }
        assert!(!dec0.has_error());
        assert!(
            zero_count > 16,
            "all-zero stream should be biased toward 0: zero_count={}",
            zero_count
        );
        let mut dec1 = RangeDecoder::new(&[0xFFu8; 16]);
        let mut one_count = 0;
        for _ in 0..32 {
            if dec1.dec_bit_logp(1) == 1 {
                one_count += 1;
            }
        }
        assert!(!dec1.has_error());
        assert!(
            one_count > 16,
            "all-ones stream should be biased toward 1: one_count={}",
            one_count
        );
    }

    // Raw bits come from the last byte first, least-significant nibble first.
    #[test]
    fn dec_bits_lsb_first_from_end() {
        let mut dec = RangeDecoder::new(&[0x00, 0x00, 0xA6]);
        let lo = dec.dec_bits(4);
        let hi = dec.dec_bits(4);
        assert_eq!(lo, 0x6);
        assert_eq!(hi, 0xA);
        assert!(!dec.has_error());
    }

    #[test]
    fn dec_bits_zero_past_end_of_frame() {
        let mut dec = RangeDecoder::new(&[0xFF, 0xFF]);
        for _ in 0..4 {
            let v = dec.dec_bits(4);
            assert_eq!(v, 0xF);
        }
        // Both bytes are used up, so further raw bits must read as zero.
        let pad = dec.dec_bits(8);
        assert_eq!(pad, 0, "bits past the start of the frame must be zero");
        assert!(!dec.has_error());
    }

    #[test]
    fn dec_uint_ft_one_is_zero_no_consumption() {
        let mut dec = RangeDecoder::new(&[0x12, 0x34, 0x56]);
        let before = dec.tell();
        let v = dec.dec_uint(1).expect("ft=1 must succeed");
        let after = dec.tell();
        assert_eq!(v, 0);
        assert_eq!(after, before);
    }

    #[test]
    fn dec_uint_small_ft_in_range() {
        let mut dec = RangeDecoder::new(&[0x42, 0x18, 0xC3, 0x7F]);
        for _ in 0..8 {
            let v = dec.dec_uint(200).expect("ft=200 must succeed");
            assert!(v < 200, "v={} out of range", v);
        }
        assert!(!dec.has_error());
    }

    // Exercises the split path (range-coded top bits + raw low bits). No
    // error assertion: arbitrary input may legitimately trigger the clamp.
    #[test]
    fn dec_uint_large_ft_in_range() {
        let buf: Vec<u8> = (0..64).collect();
        let mut dec = RangeDecoder::new(&buf);
        for _ in 0..8 {
            let v = dec.dec_uint(1_000_000).expect("ft=1_000_000 must succeed");
            assert!(v < 1_000_000, "v={} out of range", v);
        }
    }

    #[test]
    fn dec_uint_ft_zero_returns_zero() {
        let mut dec = RangeDecoder::new(&[0xAA, 0xBB, 0xCC, 0xDD]);
        let before = dec.tell();
        let v = dec.dec_uint(0).expect("ft=0 must succeed");
        assert_eq!(v, 0);
        assert_eq!(dec.tell(), before);
    }

    #[test]
    fn tell_is_monotonic_across_decodes() {
        let mut dec = RangeDecoder::new(&[0x55; 8]);
        let mut prev = dec.tell();
        for _ in 0..16 {
            let _ = dec.dec_bit_logp(2);
            let now = dec.tell();
            assert!(now >= prev, "tell() went backwards: {} -> {}", prev, now);
            prev = now;
        }
    }

    // decode_bin is the power-of-two specialization of the generic decode.
    #[test]
    fn decode_bin_matches_generic_decode() {
        for &ftb in &[1u32, 4, 8, 12, 15] {
            let buf = [0x37u8, 0x91, 0xC4, 0x18, 0xA2, 0x5D, 0x6E, 0xFF];
            let mut a = RangeDecoder::new(&buf);
            let mut b = RangeDecoder::new(&buf);
            let from_bin = a.decode_bin(ftb);
            let from_generic = b.decode(1u32 << ftb);
            assert_eq!(
                from_bin, from_generic,
                "decode_bin({ftb}) != decode(1<<{ftb})"
            );
            assert!(from_bin < (1u32 << ftb), "fs={from_bin} out of range");
        }
    }

    // tell() is tell_frac() rounded up to whole bits, with raw bits included.
    #[test]
    fn tell_frac_consistent_with_tell() {
        let mut dec = RangeDecoder::new(&[0xA3, 0x7F, 0x10, 0x5C, 0xE8, 0x91, 0x42, 0xB7]);
        assert_eq!(dec.tell(), 1);
        for _ in 0..12 {
            let whole = dec.tell();
            let frac = dec.tell_frac();
            let ceil_eighths = frac.div_ceil(8);
            assert_eq!(
                ceil_eighths, whole,
                "tell()={whole} != ceil(tell_frac()={frac} / 8)={ceil_eighths}"
            );
            let _ = dec.dec_bit_logp(1);
            let _ = dec.dec_bits(2);
        }
    }

    #[test]
    fn tell_frac_initial_within_one_bit() {
        let dec = RangeDecoder::new(&[0xCC, 0xDD, 0xEE, 0xFF]);
        let frac = dec.tell_frac();
        assert!(
            (1..=8).contains(&frac),
            "tell_frac initial out of [1,8]: {frac}"
        );
        assert_eq!(frac.div_ceil(8), dec.tell());
    }

    // A two-entry icdf [1, 0] at scale 2^logp describes the same split as
    // dec_bit_logp(logp), with symbol 1 occupying the 2^-logp sub-range.
    #[test]
    fn dec_icdf_matches_dec_bit_logp_for_binary() {
        let buf = [0xDE, 0xAD, 0xBE, 0xEF, 0x10, 0x32, 0x54, 0x76];
        let logp = 3u32;
        let icdf = [1u8, 0];
        let mut a = RangeDecoder::new(&buf);
        let mut b = RangeDecoder::new(&buf);
        for _ in 0..16 {
            let via_logp = a.dec_bit_logp(logp);
            let via_icdf = b.dec_icdf(&icdf, logp);
            assert_eq!(
                via_logp, via_icdf,
                "dec_bit_logp({logp}) != dec_icdf({icdf:?}, {logp})"
            );
        }
        assert!(!a.has_error() && !b.has_error());
    }

    #[test]
    fn dec_icdf_uniform_returns_in_range() {
        let icdf = [7u8, 6, 5, 4, 3, 2, 1, 0];
        let mut dec = RangeDecoder::new(&[0x42, 0x18, 0xC3, 0x7F, 0x55, 0xAA, 0x33, 0xCC]);
        for _ in 0..16 {
            let k = dec.dec_icdf(&icdf, 3);
            assert!(k < 8, "icdf uniform returned {k} out of [0, 8)");
        }
        assert!(!dec.has_error());
    }

    #[test]
    fn dec_icdf_single_symbol_always_zero() {
        let icdf = [0u8];
        let mut dec = RangeDecoder::new(&[0x77, 0x33, 0x11, 0xAA]);
        let before_tell = dec.tell();
        for _ in 0..4 {
            let k = dec.dec_icdf(&icdf, 3);
            assert_eq!(k, 0);
        }
        assert!(dec.tell() >= before_tell);
        assert!(!dec.has_error());
    }

    #[test]
    fn tell_frac_is_monotonic() {
        let mut dec = RangeDecoder::new(&[0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88]);
        let icdf = [7u8, 6, 5, 4, 3, 2, 1, 0];
        let mut prev = dec.tell_frac();
        for i in 0..24 {
            match i % 3 {
                0 => {
                    let _ = dec.dec_bit_logp(2);
                }
                1 => {
                    let _ = dec.dec_icdf(&icdf, 3);
                }
                _ => {
                    let _ = dec.dec_bits(2);
                }
            }
            let now = dec.tell_frac();
            assert!(
                now >= prev,
                "tell_frac() went backwards: {} -> {}",
                prev,
                now
            );
            prev = now;
        }
    }

    #[test]
    fn dec_bits_zero_width_is_noop() {
        let mut dec = RangeDecoder::new(&[0x12, 0x34, 0x56]);
        let before = dec.tell();
        let v = dec.dec_bits(0);
        assert_eq!(v, 0);
        assert_eq!(dec.tell(), before);
        assert!(!dec.has_error());
    }

    #[test]
    fn dec_bits_oversize_latches_error() {
        let mut dec = RangeDecoder::new(&[0xAA, 0xBB, 0xCC, 0xDD]);
        let v = dec.dec_bits(33);
        assert_eq!(v, 0);
        assert!(dec.has_error());
    }
}