use super::context::CabacContext;
use super::tables::RANGE_TAB_LPS;
/// Binary arithmetic encoder for CABAC, structured like the HEVC reference
/// encoder: an interval (`low`, `range`) plus lazy byte output with carry
/// propagation.
///
/// Bytes that a later carry could still modify are withheld from `output`:
/// `buffered_byte` followed by `num_buffered_bytes - 1` pending `0xFF` bytes.
/// They are released when a non-`0xFF` byte is produced, or by
/// [`CabacEngine::finish`].
#[derive(Debug)]
pub struct CabacEngine {
    /// Lower end of the coding interval. Bits not yet written out live here,
    /// together with a possible carry into the withheld bytes.
    low: u32,
    /// Width of the coding interval; renormalized to at least 256 after each bin.
    range: u32,
    /// Bit budget for `low`: once it drops below 12 a byte is written out and
    /// the budget grows by 8.
    bits_left: i32,
    /// First withheld byte (a carry can still increment it).
    buffered_byte: u8,
    /// Withheld byte count: `buffered_byte` followed by `num_buffered_bytes - 1`
    /// bytes of `0xFF`. Zero until the first byte has been written out.
    num_buffered_bytes: u32,
    /// Finalized bytes that no carry can change any more.
    output: Vec<u8>,
    /// Number of bins encoded so far (decision, bypass and terminate).
    bin_counts: u32,
    /// When `Some`, one human-readable line is appended per encoded bin.
    pub trace: Option<Vec<String>>,
    /// Label inserted into every trace line.
    pub trace_label: String,
}

impl CabacEngine {
    /// A byte is written out as soon as `bits_left` falls below this value.
    const WRITE_OUT_THRESHOLD: i32 = 12;

    /// Creates an encoder in its initial state (`low = 0`, `range = 510`,
    /// nothing buffered, tracing disabled).
    pub fn new() -> Self {
        Self {
            low: 0,
            range: 510,
            bits_left: 23,
            // If the very first byte written out is 0xFF it is only counted as
            // pending (see `write_out`), so this initial value stands in for it.
            buffered_byte: 0xFF,
            num_buffered_bytes: 0,
            output: Vec::new(),
            bin_counts: 0,
            trace: None,
            trace_label: String::new(),
        }
    }

    /// Returns the number of bins encoded so far.
    pub fn bin_count(&self) -> u32 {
        self.bin_counts
    }

    /// Encodes one context-coded bin without a context index.
    ///
    /// Same as [`CabacEngine::encode_decision_with_ctx_idx`] with `u32::MAX`
    /// as the (trace-only) context index.
    #[inline]
    pub fn encode_decision(&mut self, bin: u8, ctx: &mut CabacContext) {
        self.encode_decision_with_ctx_idx(bin, ctx, u32::MAX);
    }

    /// Encodes one context-coded bin and updates the probability state in `ctx`.
    ///
    /// Any non-zero `bin` is treated as 1, matching [`CabacEngine::encode_bypass`]
    /// and [`CabacEngine::encode_terminate`]. `ctx_idx` is only used in the trace
    /// output.
    #[inline]
    pub fn encode_decision_with_ctx_idx(
        &mut self,
        bin: u8,
        ctx: &mut CabacContext,
        ctx_idx: u32,
    ) {
        // Without this, e.g. bin == 2 would be coded as LPS even when the MPS is 1.
        let bin = u8::from(bin != 0);
        let (pre_range, pre_low) = (self.range, self.low);
        let (pre_state, pre_mps) = (ctx.p_state_idx(), ctx.val_mps());

        // Bits 6..7 of the range select one of the four LPS-table columns.
        let q_idx = ((self.range >> 6) & 3) as usize;
        let range_lps = RANGE_TAB_LPS[pre_state as usize][q_idx] as u32;
        self.range -= range_lps;
        if bin == pre_mps {
            ctx.update_mps();
        } else {
            // LPS takes the upper sub-interval of width `range_lps`.
            self.low += self.range;
            self.range = range_lps;
            ctx.update_lps();
        }
        self.renormalize();
        self.bin_counts += 1;

        if let Some(tr) = self.trace.as_mut() {
            tr.push(format!(
                "ENC {}: ctx={} pre_range=0x{:x} pre_low=0x{:x} p_state_pre={} val_mps_pre={} \
                 bin={} post_range=0x{:x} post_low=0x{:x} post_state={} post_mps={}",
                self.trace_label,
                ctx_idx,
                pre_range,
                pre_low,
                pre_state,
                pre_mps,
                bin,
                self.range,
                self.low,
                ctx.p_state_idx(),
                ctx.val_mps(),
            ));
        }
    }

    /// Encodes one equiprobable (bypass) bin; any non-zero `bin` is coded as 1.
    #[inline]
    pub fn encode_bypass(&mut self, bin: u8) {
        let pre_low = self.low;
        self.low <<= 1;
        if bin != 0 {
            self.low = self.low.wrapping_add(self.range);
        }
        self.bits_left -= 1;
        self.test_and_write_out();
        self.bin_counts += 1;

        if let Some(tr) = self.trace.as_mut() {
            tr.push(format!(
                "ENC {}: BYPASS pre_low=0x{:x} bin={} post_low=0x{:x}",
                self.trace_label, pre_low, bin, self.low,
            ));
        }
    }

    /// Encodes a terminate bin (LPS sub-interval fixed at width 2); any non-zero
    /// `bin` is coded as 1.
    ///
    /// For `bin == 1` the interval is shifted by 7 bits and `range` is set to 256,
    /// which is the state [`CabacEngine::finish`] flushes from.
    pub fn encode_terminate(&mut self, bin: u8) {
        let (pre_range, pre_low) = (self.range, self.low);
        self.range -= 2;
        if bin != 0 {
            self.low = self.low.wrapping_add(self.range);
            self.low <<= 7;
            self.range = 2 << 7;
            self.bits_left -= 7;
            self.test_and_write_out();
        } else if self.range < 256 {
            // range was >= 256 before subtracting 2, so one doubling suffices.
            self.low <<= 1;
            self.range <<= 1;
            self.bits_left -= 1;
            self.test_and_write_out();
        }
        self.bin_counts += 1;

        if let Some(tr) = self.trace.as_mut() {
            tr.push(format!(
                "ENC {}: TERMINATE pre_range=0x{:x} pre_low=0x{:x} bin={} \
                 post_range=0x{:x} post_low=0x{:x}",
                self.trace_label, pre_range, pre_low, bin, self.range, self.low,
            ));
        }
    }

    /// Flushes the encoder and returns the encoded bytes.
    ///
    /// This does the following, in order:
    /// - Resolves any final carry into the withheld bytes.
    /// - Emits the `24 - bits_left` remaining bits of `low >> 8`.
    /// - Appends a single `1` stop bit.
    /// - Zero-pads to a byte boundary.
    ///
    /// In the unit tests it is always preceded by `encode_terminate(1)`.
    pub fn finish(mut self) -> Vec<u8> {
        let carry = if self.bits_left < 32 {
            self.low >> (32 - self.bits_left as u32)
        } else {
            0
        };
        if carry != 0 {
            self.flush_pending(1);
            self.low -= 1u32 << (32 - self.bits_left as u32);
        } else {
            self.flush_pending(0);
        }

        // Remaining interval bits (MSB first) followed by the stop bit.
        let cabac_bits = (24i32 - self.bits_left).max(0) as u32;
        let value = self.low >> 8;
        let tail = (0..cabac_bits)
            .rev()
            .map(|i| (value >> i) & 1)
            .chain(std::iter::once(1));

        let mut acc: u32 = 0;
        let mut acc_bits: u32 = 0;
        for bit in tail {
            acc = (acc << 1) | bit;
            acc_bits += 1;
            if acc_bits == 8 {
                self.output.push(acc as u8);
                acc = 0;
                acc_bits = 0;
            }
        }
        if acc_bits > 0 {
            // Zero-pad the final partial byte.
            self.output.push((acc << (8 - acc_bits)) as u8);
        }
        self.output
    }

    /// Doubles `range` (and `low`) until it is at least 256, then writes out a
    /// byte if the bit budget ran low.
    #[inline]
    fn renormalize(&mut self) {
        while self.range < 256 {
            self.low <<= 1;
            self.range <<= 1;
            self.bits_left -= 1;
        }
        self.test_and_write_out();
    }

    /// Calls [`CabacEngine::write_out`] when `bits_left` is below the threshold.
    #[inline]
    fn test_and_write_out(&mut self) {
        if self.bits_left < Self::WRITE_OUT_THRESHOLD {
            self.write_out();
        }
    }

    /// Moves the top byte of `low` out of the interval.
    ///
    /// The extracted byte keeps one extra bit (bit 8) that carries into the
    /// withheld bytes. A `0xFF` byte is only counted as pending, because a later
    /// carry could still turn it into `0x00`.
    fn write_out(&mut self) {
        let lead_byte = (self.low >> (24 - self.bits_left as u32)) & 0x1FF;
        self.bits_left += 8;
        self.low &= u32::MAX >> self.bits_left as u32;
        if lead_byte == 0xFF {
            self.num_buffered_bytes += 1;
        } else {
            self.flush_pending((lead_byte >> 8) as u8);
            self.num_buffered_bytes = 1;
            self.buffered_byte = (lead_byte & 0xFF) as u8;
        }
    }

    /// Releases the withheld bytes into `output`, applying `carry` (0 or 1).
    ///
    /// On a carry the first byte is incremented and the pending `0xFF` run
    /// wraps to `0x00`. Does nothing when no byte is withheld. Leaves
    /// `num_buffered_bytes` unchanged; callers reset it as needed.
    fn flush_pending(&mut self, carry: u8) {
        if self.num_buffered_bytes == 0 {
            return;
        }
        self.output.push(self.buffered_byte.wrapping_add(carry));
        let fill: u8 = if carry != 0 { 0x00 } else { 0xFF };
        let pending = (self.num_buffered_bytes - 1) as usize;
        self.output.extend(std::iter::repeat(fill).take(pending));
    }
}
impl Default for CabacEngine {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn initial_state_matches_spec() {
        let eng = CabacEngine::new();
        assert_eq!(eng.range, 510);
        assert_eq!(eng.low, 0);
        assert_eq!(eng.bin_counts, 0);
        assert_eq!(eng.bits_left, 23);
    }

    #[test]
    fn bin_count_advances_per_encode() {
        let mut eng = CabacEngine::new();
        let mut ctx = CabacContext::new(30, 0);
        eng.encode_decision(0, &mut ctx);
        eng.encode_bypass(1);
        eng.encode_terminate(0);
        assert_eq!(eng.bin_count(), 3);
    }

    #[test]
    fn encode_mps_run_then_terminate_produces_bytes() {
        let mut eng = CabacEngine::new();
        let mut ctx = CabacContext::new(60, 0);
        for _ in 0..100 {
            eng.encode_decision(0, &mut ctx);
        }
        eng.encode_terminate(1);
        let bytes = eng.finish();
        assert!(
            bytes.len() < 12,
            "expected compression below 12 bytes, got {}",
            bytes.len()
        );
    }

    #[test]
    fn encode_random_bins_then_terminate_produces_bytes() {
        let mut eng = CabacEngine::new();
        let mut ctx = CabacContext::new(0, 0);
        for i in 0..64 {
            eng.encode_decision((i & 1) as u8, &mut ctx);
        }
        eng.encode_terminate(1);
        let bytes = eng.finish();
        assert!(!bytes.is_empty());
        assert_ne!(bytes.last().copied(), Some(0));
    }

    #[test]
    fn bypass_bins_extend_output_linearly() {
        let mut eng = CabacEngine::new();
        for i in 0..16 {
            eng.encode_bypass((i & 1) as u8);
        }
        eng.encode_terminate(1);
        let bytes = eng.finish();
        assert!(bytes.len() >= 2);
    }

    #[test]
    fn terminate_zero_then_terminate_one_flushes_properly() {
        let mut eng = CabacEngine::new();
        let mut ctx = CabacContext::new(20, 0);
        for _ in 0..3 {
            eng.encode_decision(0, &mut ctx);
            eng.encode_terminate(0);
        }
        eng.encode_decision(1, &mut ctx);
        eng.encode_terminate(1);
        let bytes = eng.finish();
        assert!(!bytes.is_empty());
        assert_ne!(*bytes.last().unwrap(), 0);
    }

    #[test]
    fn single_bin_roundtrip_via_spec_decoder_pseudocode() {
        let bins_in = [0u8, 1, 1, 0, 1];
        let mut eng = CabacEngine::new();
        let mut ctx = CabacContext::new(20, 0);
        for &b in &bins_in {
            eng.encode_decision(b, &mut ctx);
        }
        eng.encode_terminate(1);
        let bytes = eng.finish();
        let mut dec = TestDecoder::new(&bytes);
        let mut ctx = CabacContext::new(20, 0);
        let bins_out: Vec<u8> = (0..5).map(|_| dec.decode_decision(&mut ctx)).collect();
        assert_eq!(bins_out, bins_in);
    }

    #[test]
    fn empty_stream_finishes_with_stop_bit_only() {
        // One remaining interval bit (0), then the stop bit, zero-padded: 0b0100_0000.
        assert_eq!(CabacEngine::new().finish(), vec![0x40]);
    }

    #[test]
    fn lone_terminate_one_produces_known_bytes() {
        let mut eng = CabacEngine::new();
        eng.encode_terminate(1);
        let bytes = eng.finish();
        // low = (510 - 2) << 7; eight bits of low >> 8 (0xFE), then stop bit + padding.
        assert_eq!(bytes, vec![0xFE, 0x80]);
        assert_eq!(TestDecoder::new(&bytes).decode_terminate(), 1);
    }

    /// One bin of the mixed round-trip test.
    #[derive(Clone, Copy, Debug)]
    enum Op {
        Decision { ctx: usize, bin: u8 },
        Bypass(u8),
        Terminate0,
    }

    /// Builds the contexts used by the mixed round-trip test; called separately
    /// for encoder and decoder so both start from identical state.
    fn fresh_contexts() -> Vec<CabacContext> {
        vec![
            CabacContext::new(0, 0),
            CabacContext::new(20, 0),
            CabacContext::new(30, 0),
            CabacContext::new(60, 0),
        ]
    }

    /// Deterministic (xorshift32) mix of decision, bypass and terminate(0) bins.
    fn pseudo_random_ops(count: usize) -> Vec<Op> {
        let mut seed: u32 = 0x9E37_79B9;
        (0..count)
            .map(|_| {
                seed ^= seed << 13;
                seed ^= seed >> 17;
                seed ^= seed << 5;
                match seed % 8 {
                    0..=4 => Op::Decision {
                        ctx: ((seed >> 3) % 4) as usize,
                        bin: u8::from((seed >> 5) % 5 == 0),
                    },
                    5 | 6 => Op::Bypass(((seed >> 3) & 1) as u8),
                    _ => Op::Terminate0,
                }
            })
            .collect()
    }

    #[test]
    fn mixed_bins_roundtrip_via_spec_decoder_pseudocode() {
        let ops = pseudo_random_ops(2000);

        let mut eng = CabacEngine::new();
        let mut enc_ctx = fresh_contexts();
        for op in &ops {
            match *op {
                Op::Decision { ctx, bin } => eng.encode_decision(bin, &mut enc_ctx[ctx]),
                Op::Bypass(bin) => eng.encode_bypass(bin),
                Op::Terminate0 => eng.encode_terminate(0),
            }
        }
        eng.encode_terminate(1);
        let bytes = eng.finish();

        let mut dec = TestDecoder::new(&bytes);
        let mut dec_ctx = fresh_contexts();
        for (i, op) in ops.iter().enumerate() {
            let (expected, got) = match *op {
                Op::Decision { ctx, bin } => (bin, dec.decode_decision(&mut dec_ctx[ctx])),
                Op::Bypass(bin) => (bin, dec.decode_bypass()),
                Op::Terminate0 => (0, dec.decode_terminate()),
            };
            assert_eq!(got, expected, "bin {} ({:?}) decoded incorrectly", i, op);
        }
        assert_eq!(dec.decode_terminate(), 1, "end-of-stream terminate bin lost");
    }

    /// Bit-serial arithmetic decoder following the spec pseudocode; bits past
    /// the end of the input read as 0.
    struct TestDecoder<'a> {
        bytes: &'a [u8],
        cod_i_offset: u32,
        cod_i_range: u32,
        byte_idx: usize,
        bit_ptr: u32,
    }

    impl<'a> TestDecoder<'a> {
        /// Initializes the range to 510 and the offset from the first 9 bits.
        fn new(bytes: &'a [u8]) -> Self {
            let mut d = Self {
                bytes,
                cod_i_offset: 0,
                cod_i_range: 510,
                byte_idx: 0,
                bit_ptr: 0,
            };
            for _ in 0..9 {
                let b = d.read_bit();
                d.cod_i_offset = (d.cod_i_offset << 1) | b;
            }
            d
        }

        /// Reads the next bit MSB-first, returning 0 once the input is exhausted.
        fn read_bit(&mut self) -> u32 {
            if self.byte_idx >= self.bytes.len() {
                return 0;
            }
            let byte = self.bytes[self.byte_idx];
            let bit = (byte >> (7 - self.bit_ptr)) & 1;
            self.bit_ptr += 1;
            if self.bit_ptr == 8 {
                self.bit_ptr = 0;
                self.byte_idx += 1;
            }
            bit as u32
        }

        /// Doubles the range until it is at least 256, shifting in new bits.
        fn renormalize(&mut self) {
            while self.cod_i_range < 256 {
                self.cod_i_range <<= 1;
                self.cod_i_offset = (self.cod_i_offset << 1) | self.read_bit();
            }
        }

        /// Decodes one context-coded bin and updates `ctx`.
        fn decode_decision(&mut self, ctx: &mut CabacContext) -> u8 {
            let p_state = ctx.p_state_idx() as usize;
            let q_idx = ((self.cod_i_range >> 6) & 3) as usize;
            let range_lps = RANGE_TAB_LPS[p_state][q_idx] as u32;
            self.cod_i_range -= range_lps;
            let bin = if self.cod_i_offset >= self.cod_i_range {
                self.cod_i_offset -= self.cod_i_range;
                self.cod_i_range = range_lps;
                let b = 1 ^ ctx.val_mps();
                ctx.update_lps();
                b
            } else {
                let b = ctx.val_mps();
                ctx.update_mps();
                b
            };
            self.renormalize();
            bin
        }

        /// Decodes one bypass (equiprobable) bin.
        fn decode_bypass(&mut self) -> u8 {
            self.cod_i_offset = (self.cod_i_offset << 1) | self.read_bit();
            if self.cod_i_offset >= self.cod_i_range {
                self.cod_i_offset -= self.cod_i_range;
                1
            } else {
                0
            }
        }

        /// Decodes one terminate bin; a 1 ends decoding without renormalizing.
        fn decode_terminate(&mut self) -> u8 {
            self.cod_i_range -= 2;
            if self.cod_i_offset >= self.cod_i_range {
                1
            } else {
                self.renormalize();
                0
            }
        }
    }
}