use crate::error::CodecError;
/// Number of bits of probability precision used by the CDF tables.
pub const CDF_PROB_BITS: u32 = 15;
/// Total cumulative frequency (`2^CDF_PROB_BITS` = 32768); every CDF row in this
/// file ends with this value and the range coder divides its range by it.
pub const CDF_PROB_TOP: u16 = 1 << CDF_PROB_BITS;
/// A set of cumulative-distribution rows: `CONTEXTS` rows of `ENTRIES` boundaries each.
///
/// A row with `ENTRIES` boundaries describes `ENTRIES - 1` symbols; symbol `s`
/// occupies the interval between boundary `s` and boundary `s + 1`.
pub type CdfTable<const ENTRIES: usize, const CONTEXTS: usize> = [[u16; ENTRIES]; CONTEXTS];
/// Single-context, two-symbol CDF: symbol 0 covers `[0, 20000)` and symbol 1
/// covers `[20000, 32768)`.
///
/// NOTE(review): judging by the name this models a DC-coefficient skip flag —
/// confirm against the callers.
pub const DC_COEFF_SKIP_CDF: CdfTable<3, 1> = [[
    0,
    20000,
    32768,
]];
/// Single-context, two-symbol CDF: symbol 0 covers `[0, 14000)` and symbol 1
/// covers `[14000, 32768)`.
///
/// NOTE(review): judging by the name this models an AC-coefficient skip flag —
/// confirm against the callers.
pub const AC_COEFF_SKIP_CDF: CdfTable<3, 1> = [[
    0,
    14000,
    32768,
]];
/// Single-context CDF with 16 symbols (17 boundaries). Symbol 0 covers
/// `[0, 26200)`; the remaining 15 symbols share the rest of the range.
///
/// NOTE(review): presumably indexed by transform type — confirm against the callers.
pub const TRANSFORM_TYPE_CDF: CdfTable<17, 1> = [[
    0,
    26200, 27340, 28000, 28600,
    29100, 29550, 29950, 30310,
    30640, 30950, 31240, 31520,
    31790, 32060, 32400, 32768,
]];
/// Single-context CDF with 4 symbols (5 boundaries): `[0, 16000)`,
/// `[16000, 21000)`, `[21000, 26000)` and `[26000, 32768)`.
///
/// NOTE(review): presumably indexed by partition type — confirm against the callers.
pub const PARTITION_TYPE_CDF: CdfTable<5, 1> = [[
    0,
    16000,
    21000,
    26000,
    32768,
]];
/// State of a byte-oriented range coder that can act as encoder or decoder.
///
/// A freshly constructed coder encodes; `init_from_slice` switches it to decoding.
#[derive(Debug, Clone)]
pub struct RangeCoder {
    /// Width of the current coding interval; starts at `u32::MAX`.
    range: u32,
    /// Encoder only: lower bound of the current interval inside the 32-bit window.
    low: u32,
    /// Encoder only: bytes emitted so far.
    output: Vec<u8>,
    /// Decoder only: private copy of the bitstream being decoded.
    input: Vec<u8>,
    /// Decoder only: index of the next unread byte in `input`.
    read_pos: usize,
    /// Decoder only: 32-bit window of the bitstream, kept relative to the
    /// lower bound of the current interval.
    code: u32,
    /// `true` once the coder has been initialised for decoding.
    decode_mode: bool,
}
impl RangeCoder {
    /// Renormalisation threshold: while `range` is below this value the coder
    /// moves one byte out of (encoder) or into (decoder) its 32-bit window.
    const BOT: u32 = 1 << 16;

    /// Creates a coder in encode mode with a full-width range and empty buffers.
    #[must_use]
    pub fn new() -> Self {
        Self {
            range: u32::MAX,
            low: 0,
            output: Vec::new(),
            input: Vec::new(),
            read_pos: 0,
            code: 0,
            decode_mode: false,
        }
    }

    /// Switches the coder to decode mode over a copy of `data`.
    ///
    /// The range and read position are reset and the code window is primed with
    /// the first four bytes of `data`; bytes past the end of `data` read as zero.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidBitstream` when `data` is empty.
    pub fn init_from_slice(&mut self, data: &[u8]) -> Result<(), CodecError> {
        if data.is_empty() {
            return Err(CodecError::InvalidBitstream(
                "RangeCoder: empty bitstream".into(),
            ));
        }
        self.decode_mode = true;
        self.input = data.to_vec();
        self.read_pos = 0;
        self.range = u32::MAX;
        self.code = 0;
        for _ in 0..4 {
            let byte = self.read_byte_internal();
            self.code = (self.code << 8) | u32::from(byte);
        }
        Ok(())
    }

    /// Consumes the coder and returns the bytes it has emitted.
    ///
    /// In encode mode the four bytes of `low` are appended first so the decoder
    /// can resolve the final symbols; in decode mode the output buffer is
    /// returned as it is.
    #[must_use]
    pub fn flush(mut self) -> Vec<u8> {
        if !self.decode_mode {
            // Big-endian bytes of `low` are exactly what four rounds of
            // "emit top byte, shift left by 8" would produce.
            self.output.extend_from_slice(&self.low.to_be_bytes());
        }
        self.output
    }

    /// Returns the next input byte and advances the read position, or `0x00`
    /// (without advancing) once the input is exhausted.
    fn read_byte_internal(&mut self) -> u8 {
        match self.input.get(self.read_pos) {
            Some(&byte) => {
                self.read_pos += 1;
                byte
            }
            None => 0x00,
        }
    }

    /// Adds one to the number formed by the bytes emitted so far.
    ///
    /// Called when an addition to `low` overflows the 32-bit window: the lost
    /// bit belongs to the least significant emitted byte and ripples upwards
    /// through any `0xFF` bytes. For well-formed tables the ripple is expected
    /// to stop inside the emitted bytes; if it does not, the loop simply ends.
    fn propagate_carry(&mut self) {
        for byte in self.output.iter_mut().rev() {
            let (bumped, overflowed) = byte.overflowing_add(1);
            *byte = bumped;
            if !overflowed {
                break;
            }
        }
    }

    /// Emits the top byte of `low` and widens `range` by 8 bits until `range`
    /// is at least `BOT` again.
    fn renormalize_encoder(&mut self) {
        while self.range < Self::BOT {
            self.output.push((self.low >> 24) as u8);
            self.low = self.low.wrapping_shl(8);
            self.range <<= 8;
        }
    }

    /// Shifts the next input byte into `code` and widens `range` by 8 bits
    /// until `range` is at least `BOT` again.
    fn renormalize_decoder(&mut self) {
        while self.range < Self::BOT {
            let byte = self.read_byte_internal();
            self.code = (self.code << 8) | u32::from(byte);
            self.range <<= 8;
        }
    }

    /// Encodes symbol `sym` using the cumulative boundaries in `cdf`.
    ///
    /// Symbol `s` occupies `[cdf[s], cdf[s + 1])` out of `CDF_PROB_TOP`; the last
    /// symbol additionally receives whatever part of the range is left over by
    /// the integer division, so its upper boundary is not consulted.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidParameter` when `cdf` has fewer than two
    /// entries, when `sym` is not a symbol of `cdf`, or when the interval of
    /// `sym` is empty or malformed (boundaries decreasing or above
    /// `CDF_PROB_TOP`). The coder state is left untouched in all error cases.
    fn encode_symbol_with_cdf(&mut self, sym: usize, cdf: &[u16]) -> Result<(), CodecError> {
        let n_syms = cdf.len().saturating_sub(1);
        if n_syms == 0 {
            return Err(CodecError::InvalidParameter(
                "CDF must have at least 2 entries".into(),
            ));
        }
        if sym >= n_syms {
            return Err(CodecError::InvalidParameter(format!(
                "symbol {sym} out of range for {n_syms}-symbol CDF"
            )));
        }
        let total = u32::from(CDF_PROB_TOP);
        let cum_lo = u32::from(cdf[sym]);
        let cum_hi = u32::from(cdf[sym + 1]);
        let is_last = sym + 1 == n_syms;

        // With boundaries capped at `total`, `step * boundary <= range`, so none
        // of the products below can overflow and the subtraction cannot underflow.
        let well_formed = cum_lo <= total && (is_last || (cum_lo < cum_hi && cum_hi <= total));
        let step = self.range / total;
        let (offset, width) = if well_formed {
            let offset = step * cum_lo;
            let width = if is_last {
                self.range - offset
            } else {
                step * (cum_hi - cum_lo)
            };
            (offset, width)
        } else {
            (0, 0)
        };
        // A zero-width interval would make renormalisation loop forever.
        if width == 0 {
            return Err(CodecError::InvalidParameter(format!(
                "symbol {sym} has an empty or malformed interval in the CDF"
            )));
        }

        // `low + offset` may exceed the 32-bit window; the lost bit is a carry
        // into the bytes that were already written out.
        let (low, carried) = self.low.overflowing_add(offset);
        self.low = low;
        if carried {
            self.propagate_carry();
        }
        self.range = width;
        self.renormalize_encoder();
        Ok(())
    }

    /// Decodes one symbol using the cumulative boundaries in `cdf`.
    ///
    /// The first symbol whose scaled upper boundary exceeds `code` is selected;
    /// when none does, the last symbol is selected.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidBitstream` when `cdf` has fewer than two
    /// entries or when the selected symbol's interval is empty or malformed, and
    /// `CodecError::InvalidParameter` when the symbol index does not fit in `u8`.
    /// The coder state is left untouched in all error cases.
    fn decode_symbol_with_cdf(&mut self, cdf: &[u16]) -> Result<u8, CodecError> {
        let n_syms = cdf.len().saturating_sub(1);
        if n_syms == 0 {
            return Err(CodecError::InvalidBitstream(
                "CDF must have at least 2 entries".into(),
            ));
        }
        let total = u32::from(CDF_PROB_TOP);
        let step = self.range / total;

        // Compare in u64 so an oversized table entry cannot overflow the product.
        let code = u64::from(self.code);
        let sym = cdf[1..]
            .iter()
            .position(|&hi| code < u64::from(step) * u64::from(hi))
            .unwrap_or(n_syms - 1);
        if sym > usize::from(u8::MAX) {
            return Err(CodecError::InvalidParameter(format!(
                "decoded symbol {sym} does not fit in u8"
            )));
        }

        let cum_lo = u32::from(cdf[sym]);
        let cum_hi = u32::from(cdf[sym + 1]);
        let is_last = sym + 1 == n_syms;
        let well_formed = cum_lo <= total && (is_last || (cum_lo < cum_hi && cum_hi <= total));
        let (offset, width) = if well_formed {
            let offset = step * cum_lo;
            let width = if is_last {
                self.range - offset
            } else {
                step * (cum_hi - cum_lo)
            };
            (offset, width)
        } else {
            (0, 0)
        };
        // A zero-width interval would make renormalisation loop forever.
        if width == 0 {
            return Err(CodecError::InvalidBitstream(format!(
                "symbol {sym} has an empty or malformed interval in the CDF"
            )));
        }

        self.code = self.code.wrapping_sub(offset);
        self.range = width;
        self.renormalize_decoder();
        // Cannot truncate: `sym <= u8::MAX` was checked above.
        Ok(sym as u8)
    }
}
/// Encodes `sym` with the CDF row that `table` holds for context `ctx`.
///
/// # Errors
///
/// Returns `CodecError::InvalidParameter` when `ctx` is not a valid row index of
/// `table`; otherwise forwards whatever the range coder reports for the symbol.
pub fn encode_symbol_table<const N: usize, const CTX: usize>(
    rc: &mut RangeCoder,
    sym: u8,
    ctx: usize,
    table: &CdfTable<N, CTX>,
) -> Result<(), CodecError> {
    match table.get(ctx) {
        Some(row) => rc.encode_symbol_with_cdf(usize::from(sym), row),
        None => Err(CodecError::InvalidParameter(format!(
            "context {ctx} out of range (table has {CTX} contexts)"
        ))),
    }
}
/// Decodes one symbol with the CDF row that `table` holds for context `ctx`.
///
/// # Errors
///
/// Returns `CodecError::InvalidParameter` when `ctx` is not a valid row index of
/// `table`; otherwise forwards whatever the range coder reports.
pub fn decode_symbol_table<const N: usize, const CTX: usize>(
    rc: &mut RangeCoder,
    ctx: usize,
    table: &CdfTable<N, CTX>,
) -> Result<u8, CodecError> {
    match table.get(ctx) {
        Some(row) => rc.decode_symbol_with_cdf(row),
        None => Err(CodecError::InvalidParameter(format!(
            "context {ctx} out of range (table has {CTX} contexts)"
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a decoder primed with `bytes`.
    fn decoder_over(bytes: &[u8]) -> RangeCoder {
        let mut decoder = RangeCoder::new();
        decoder.init_from_slice(bytes).expect("init");
        decoder
    }

    // ---- table shape checks -------------------------------------------------

    #[test]
    fn dc_coeff_skip_cdf_valid() {
        let cdf = &DC_COEFF_SKIP_CDF[0];
        assert_eq!(cdf[0], 0, "first CDF entry must be 0");
        let tail = *cdf.last().expect("non-empty row");
        assert_eq!(tail, CDF_PROB_TOP, "last entry must be CDF_PROB_TOP");
        assert!(
            cdf.windows(2).all(|pair| pair[0] <= pair[1]),
            "CDF must be monotonically non-decreasing"
        );
    }

    #[test]
    fn ac_coeff_skip_cdf_valid() {
        let cdf = &AC_COEFF_SKIP_CDF[0];
        assert_eq!(cdf[0], 0);
        assert_eq!(*cdf.last().expect("non-empty"), CDF_PROB_TOP);
        assert!(cdf.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    #[test]
    fn transform_type_cdf_valid() {
        let cdf = &TRANSFORM_TYPE_CDF[0];
        assert_eq!(cdf[0], 0);
        assert_eq!(*cdf.last().expect("non-empty"), CDF_PROB_TOP);
        assert_eq!(cdf.len(), 17, "16 symbols + 1 sentinel");
        assert!(cdf.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    #[test]
    fn partition_type_cdf_valid() {
        let cdf = &PARTITION_TYPE_CDF[0];
        assert_eq!(cdf[0], 0);
        assert_eq!(*cdf.last().expect("non-empty"), CDF_PROB_TOP);
        assert_eq!(cdf.len(), 5, "4 symbols + 1 sentinel");
        assert!(cdf.windows(2).all(|pair| pair[0] <= pair[1]));
    }

    // ---- single-symbol round trips -------------------------------------------

    #[test]
    fn range_coder_dc_skip_roundtrip_zero() {
        let mut encoder = RangeCoder::new();
        encode_symbol_table(&mut encoder, 0, 0, &DC_COEFF_SKIP_CDF).expect("encode sym 0");
        let mut decoder = decoder_over(&encoder.flush());
        let decoded = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode");
        assert_eq!(decoded, 0, "should decode symbol 0");
    }

    #[test]
    fn range_coder_dc_skip_roundtrip_one() {
        let mut encoder = RangeCoder::new();
        encode_symbol_table(&mut encoder, 1, 0, &DC_COEFF_SKIP_CDF).expect("encode sym 1");
        let mut decoder = decoder_over(&encoder.flush());
        let decoded = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode");
        assert_eq!(decoded, 1, "should decode symbol 1");
    }

    #[test]
    fn range_coder_partition_type_all_symbols() {
        for sym_in in 0u8..4 {
            let mut encoder = RangeCoder::new();
            encode_symbol_table(&mut encoder, sym_in, 0, &PARTITION_TYPE_CDF)
                .expect("encode partition");
            let mut decoder = decoder_over(&encoder.flush());
            let sym_out =
                decode_symbol_table(&mut decoder, 0, &PARTITION_TYPE_CDF).expect("decode");
            assert_eq!(
                sym_out, sym_in,
                "partition type {sym_in} must survive round-trip"
            );
        }
    }

    #[test]
    fn range_coder_transform_type_all_symbols() {
        for sym_in in 0u8..16 {
            let mut encoder = RangeCoder::new();
            encode_symbol_table(&mut encoder, sym_in, 0, &TRANSFORM_TYPE_CDF)
                .expect("encode tx type");
            let mut decoder = decoder_over(&encoder.flush());
            let sym_out =
                decode_symbol_table(&mut decoder, 0, &TRANSFORM_TYPE_CDF).expect("decode tx");
            assert_eq!(
                sym_out, sym_in,
                "transform type {sym_in} must survive round-trip"
            );
        }
    }

    // ---- multi-symbol round trips --------------------------------------------

    #[test]
    fn range_coder_ac_skip_roundtrip() {
        let symbols = [0u8, 1, 0, 0, 1, 1, 0, 1];
        let mut encoder = RangeCoder::new();
        for &symbol in &symbols {
            encode_symbol_table(&mut encoder, symbol, 0, &AC_COEFF_SKIP_CDF).expect("encode");
        }
        let mut decoder = decoder_over(&encoder.flush());
        for &want in &symbols {
            let got = decode_symbol_table(&mut decoder, 0, &AC_COEFF_SKIP_CDF).expect("decode");
            assert_eq!(got, want);
        }
    }

    #[test]
    fn range_coder_sequence_mixed_tables() {
        let dc_syms = [0u8, 1, 0];
        let tx_syms = [0u8, 5, 15];
        let pt_syms = [3u8, 0, 2];

        let mut encoder = RangeCoder::new();
        for ((&dc, &tx), &pt) in dc_syms.iter().zip(&tx_syms).zip(&pt_syms) {
            encode_symbol_table(&mut encoder, dc, 0, &DC_COEFF_SKIP_CDF).expect("encode dc");
            encode_symbol_table(&mut encoder, tx, 0, &TRANSFORM_TYPE_CDF).expect("encode tx");
            encode_symbol_table(&mut encoder, pt, 0, &PARTITION_TYPE_CDF).expect("encode pt");
        }

        let mut decoder = decoder_over(&encoder.flush());
        for ((&want_dc, &want_tx), &want_pt) in dc_syms.iter().zip(&tx_syms).zip(&pt_syms) {
            let dc = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode dc");
            let tx = decode_symbol_table(&mut decoder, 0, &TRANSFORM_TYPE_CDF).expect("decode tx");
            let pt = decode_symbol_table(&mut decoder, 0, &PARTITION_TYPE_CDF).expect("decode pt");
            assert_eq!(dc, want_dc);
            assert_eq!(tx, want_tx);
            assert_eq!(pt, want_pt);
        }
    }

    #[test]
    fn range_coder_long_sequence_dc_skip() {
        let symbols: Vec<u8> = (0u8..100).map(|i| i % 2).collect();
        let mut encoder = RangeCoder::new();
        for &symbol in &symbols {
            encode_symbol_table(&mut encoder, symbol, 0, &DC_COEFF_SKIP_CDF).expect("encode");
        }
        let mut decoder = decoder_over(&encoder.flush());
        for (i, &want) in symbols.iter().enumerate() {
            let got = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode");
            assert_eq!(got, want, "mismatch at symbol {i}");
        }
    }

    #[test]
    fn range_coder_all_same_symbol_zero() {
        let n = 50;
        let mut encoder = RangeCoder::new();
        for _ in 0..n {
            encode_symbol_table(&mut encoder, 0, 0, &PARTITION_TYPE_CDF).expect("encode");
        }
        let mut decoder = decoder_over(&encoder.flush());
        for i in 0..n {
            let got = decode_symbol_table(&mut decoder, 0, &PARTITION_TYPE_CDF).expect("decode");
            assert_eq!(got, 0u8, "all-zero sequence failed at index {i}");
        }
    }

    // ---- error paths and basic state -------------------------------------------

    #[test]
    fn range_coder_context_out_of_range_error() {
        let mut encoder = RangeCoder::new();
        let outcome = encode_symbol_table(&mut encoder, 0, 1, &DC_COEFF_SKIP_CDF);
        assert!(outcome.is_err(), "context 1 should be out of range");
    }

    #[test]
    fn range_coder_symbol_out_of_range_error() {
        let mut encoder = RangeCoder::new();
        let outcome = encode_symbol_table(&mut encoder, 2, 0, &DC_COEFF_SKIP_CDF);
        assert!(outcome.is_err(), "symbol 2 should be out of range");
    }

    #[test]
    fn range_coder_empty_bitstream_error() {
        let mut decoder = RangeCoder::new();
        let outcome = decoder.init_from_slice(&[]);
        assert!(outcome.is_err(), "empty bitstream must return error");
    }

    #[test]
    fn range_coder_new_is_in_encode_mode() {
        let fresh = RangeCoder::new();
        assert!(!fresh.decode_mode, "new coder should be in encode mode");
        assert_eq!(fresh.output.len(), 0, "no output yet");
    }

    #[test]
    fn range_coder_flush_produces_bytes() {
        let mut encoder = RangeCoder::new();
        encode_symbol_table(&mut encoder, 0, 0, &DC_COEFF_SKIP_CDF).expect("encode");
        let bytes = encoder.flush();
        assert!(!bytes.is_empty(), "flush must produce at least one byte");
    }

    // ---- bulk size and round-trip check ----------------------------------------

    #[test]
    fn benchmark_table_vs_scalar_estimate() {
        let symbols: Vec<u8> = (0u8..200).cycle().take(10_000).map(|x| x % 2).collect();
        let mut encoder = RangeCoder::new();
        for &symbol in &symbols {
            encode_symbol_table(&mut encoder, symbol, 0, &DC_COEFF_SKIP_CDF).expect("encode");
        }
        let bytes = encoder.flush();
        assert!(
            bytes.len() <= 2500,
            "compressed size {} should be ≤ 2500 bytes for {}-symbol DC skip stream",
            bytes.len(),
            symbols.len()
        );
        let mut decoder = decoder_over(&bytes);
        for (i, &want) in symbols.iter().enumerate() {
            let got = decode_symbol_table(&mut decoder, 0, &DC_COEFF_SKIP_CDF).expect("decode");
            assert_eq!(got, want, "bulk decode mismatch at index {i}");
        }
    }
}