use crate::codec::h264::cabac::context::CabacContext;
use crate::codec::h264::cabac::tables::RANGE_TAB_LPS;
/// Errors reported by the CABAC decoding engine.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum DecodeError {
    /// The input ended before the requested data could be read.
    UnexpectedEof,
}

impl std::fmt::Display for DecodeError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            DecodeError::UnexpectedEof => f.write_str("unexpected end of CABAC bitstream"),
        }
    }
}

// Lets callers propagate the error with `?` into `Box<dyn Error>` and similar
// error-erasing containers.
impl std::error::Error for DecodeError {}
/// State of a CABAC arithmetic decoding engine reading from a borrowed byte slice.
///
/// Bits are pulled from `bytes` one at a time, most significant bit first.
#[derive(Debug)]
pub struct CabacDecodeEngine<'a> {
    /// Input bytes the engine decodes from.
    bytes: &'a [u8],
    /// Index of the byte that holds the next unread bit.
    byte_idx: usize,
    /// Position of the next unread bit inside that byte (0 is the MSB, 7 the LSB).
    bit_ptr: u32,
    /// Arithmetic-decoder offset register.
    cod_i_offset: u32,
    /// Arithmetic-decoder range register; starts at 510.
    cod_i_range: u32,
    /// Number of bins decoded so far (regular, bypass and terminate).
    bin_counts: u32,
    /// When `Some`, every decoded bin appends one diagnostic line here.
    pub trace: Option<Vec<String>>,
    /// Label inserted into each trace line.
    pub trace_label: String,
}
impl<'a> CabacDecodeEngine<'a> {
    /// Creates a decoding engine over `bytes`.
    ///
    /// The range register starts at 510 and the offset register is loaded with
    /// the first nine bits of the input, most significant bit first.
    ///
    /// # Errors
    ///
    /// Forwards any error from the bit reader. As written, the reader yields
    /// zero bits past the end of the input instead of failing.
    pub fn new(bytes: &'a [u8]) -> Result<Self, DecodeError> {
        let mut engine = Self {
            trace_label: String::new(),
            trace: None,
            bin_counts: 0,
            cod_i_range: 510,
            cod_i_offset: 0,
            bit_ptr: 0,
            byte_idx: 0,
            bytes,
        };
        // Accumulate the nine priming bits MSB-first into the offset register.
        let primed = (0..9).try_fold(0u32, |acc, _| {
            engine.read_bit().map(|bit| (acc << 1) | bit)
        })?;
        engine.cod_i_offset = primed;
        Ok(engine)
    }

    /// Returns how many bins (regular, bypass and terminate) were decoded so far.
    #[inline]
    pub fn bin_count(&self) -> u32 {
        self.bin_counts
    }

    /// Returns the number of input bytes touched so far; a byte that has been
    /// only partially read counts as consumed.
    #[inline]
    pub fn bytes_consumed(&self) -> usize {
        let partially_read = usize::from(self.bit_ptr > 0);
        self.byte_idx + partially_read
    }

    /// Decodes one context-coded bin, reporting `u32::MAX` as the context index
    /// in trace output.
    #[inline]
    pub fn decode_decision(&mut self, ctx: &mut CabacContext) -> Result<u8, DecodeError> {
        self.decode_decision_with_ctx_idx(ctx, u32::MAX)
    }

    /// Decodes one context-coded bin with `ctx` and updates the context state.
    ///
    /// `ctx_idx` does not influence decoding; it is only written into the trace
    /// line that is emitted when tracing is enabled.
    pub fn decode_decision_with_ctx_idx(
        &mut self,
        ctx: &mut CabacContext,
        ctx_idx: u32,
    ) -> Result<u8, DecodeError> {
        // Snapshot everything the trace line reports about the pre-decode state.
        let range_before = self.cod_i_range;
        let offset_before = self.cod_i_offset;
        let state_before = ctx.p_state_idx();
        let mps_before = ctx.val_mps();

        // The LPS sub-range is looked up by probability state and by bits 6..7
        // of the current range.
        let row = ctx.p_state_idx() as usize;
        let column = ((range_before >> 6) & 3) as usize;
        let lps_range = RANGE_TAB_LPS[row][column] as u32;
        let mps_range = range_before - lps_range;

        let bin = if offset_before < mps_range {
            // Offset falls inside the MPS sub-range.
            self.cod_i_range = mps_range;
            let symbol = ctx.val_mps();
            ctx.update_mps();
            symbol
        } else {
            // Offset falls inside the LPS sub-range: rebase it and shrink the range.
            self.cod_i_offset = offset_before - mps_range;
            self.cod_i_range = lps_range;
            let symbol = 1 ^ ctx.val_mps();
            ctx.update_lps();
            symbol
        };

        self.renormalize_d()?;
        self.bin_counts += 1;

        if let Some(log) = &mut self.trace {
            let line = format!(
                "DEC {}: ctx={} pre_range=0x{:x} pre_offset=0x{:x} p_state_pre={} val_mps_pre={} \
                 bin={} post_range=0x{:x} post_offset=0x{:x} post_state={} post_mps={}",
                self.trace_label,
                ctx_idx,
                range_before,
                offset_before,
                state_before,
                mps_before,
                bin,
                self.cod_i_range,
                self.cod_i_offset,
                ctx.p_state_idx(),
                ctx.val_mps(),
            );
            log.push(line);
        }
        Ok(bin)
    }

    /// Decodes one bypass bin: shifts a fresh bit into the offset and compares
    /// the result against the (unchanged) range.
    #[inline]
    pub fn decode_bypass(&mut self) -> Result<u8, DecodeError> {
        let offset_before = self.cod_i_offset;
        let incoming = self.read_bit()?;
        let shifted = (offset_before << 1) | incoming;

        // `checked_sub` succeeds exactly when `shifted >= cod_i_range`.
        let (bin, offset_after) = match shifted.checked_sub(self.cod_i_range) {
            Some(rebased) => (1, rebased),
            None => (0, shifted),
        };
        self.cod_i_offset = offset_after;
        self.bin_counts += 1;

        if let Some(log) = &mut self.trace {
            let line = format!(
                "DEC {}: BYPASS pre_offset=0x{:x} bin={} post_offset=0x{:x}",
                self.trace_label, offset_before, bin, self.cod_i_offset,
            );
            log.push(line);
        }
        Ok(bin)
    }

    /// Decodes a terminate bin.
    ///
    /// The range is reduced by 2 first. A result of 1 leaves the registers as
    /// they are (no renormalization); a result of 0 renormalizes.
    pub fn decode_terminate(&mut self) -> Result<u8, DecodeError> {
        let range_before = self.cod_i_range;
        let offset_before = self.cod_i_offset;

        self.cod_i_range = range_before - 2;
        let terminated = self.cod_i_offset >= self.cod_i_range;
        if !terminated {
            self.renormalize_d()?;
        }
        let bin = u8::from(terminated);
        self.bin_counts += 1;

        if let Some(log) = &mut self.trace {
            let line = format!(
                "DEC {}: TERMINATE pre_range=0x{:x} pre_offset=0x{:x} bin={} \
                 post_range=0x{:x} post_offset=0x{:x}",
                self.trace_label,
                range_before,
                offset_before,
                bin,
                self.cod_i_range,
                self.cod_i_offset,
            );
            log.push(line);
        }
        Ok(bin)
    }

    /// Reads the next input bit, MSB first within each byte.
    ///
    /// Once the input is exhausted this yields 0 without advancing the read
    /// position.
    #[inline]
    fn read_bit(&mut self) -> Result<u32, DecodeError> {
        let current = match self.bytes.get(self.byte_idx) {
            Some(&byte) => byte,
            None => return Ok(0),
        };
        let bit = u32::from((current >> (7 - self.bit_ptr)) & 1);

        // Step to the next bit, rolling over to the next byte after the LSB.
        if self.bit_ptr == 7 {
            self.bit_ptr = 0;
            self.byte_idx += 1;
        } else {
            self.bit_ptr += 1;
        }
        Ok(bit)
    }

    /// Doubles the range until it is at least 256, shifting one new input bit
    /// into the offset for every doubling.
    #[inline]
    fn renormalize_d(&mut self) -> Result<(), DecodeError> {
        while self.cod_i_range < 256 {
            self.cod_i_range <<= 1;
            let incoming = self.read_bit()?;
            self.cod_i_offset = (self.cod_i_offset << 1) | incoming;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    //! Round-trip tests: bins are encoded with `CabacEngine` and decoded again
    //! with `CabacDecodeEngine`.

    use super::*;
    use crate::codec::h264::cabac::engine::CabacEngine;

    /// Encodes `bins` as regular (context-coded) bins followed by a terminate
    /// bin, decodes them with a context initialised identically, and asserts
    /// that the decoded bins match.
    fn roundtrip_decisions(bins: &[u8], p_state: u8, val_mps: u8) {
        let mut enc = CabacEngine::new();
        let mut ctx_e = CabacContext::new(p_state, val_mps);
        for &b in bins {
            enc.encode_decision(b, &mut ctx_e);
        }
        enc.encode_terminate(1);
        let bytes = enc.finish();
        let mut dec = CabacDecodeEngine::new(&bytes).expect("init");
        let mut ctx_d = CabacContext::new(p_state, val_mps);
        let out: Vec<u8> = (0..bins.len())
            .map(|_| dec.decode_decision(&mut ctx_d).expect("decode"))
            .collect();
        assert_eq!(out, bins, "regular-bin roundtrip");
    }

    /// Construction loads exactly the first nine input bits into the offset
    /// register and leaves the range at 510.
    #[test]
    fn engine_init_reads_9_bits() {
        let bytes = [0b0101_0101, 0b0101_0101, 0xFF];
        let dec = CabacDecodeEngine::new(&bytes).expect("init");
        assert_eq!(dec.cod_i_offset, 0b0_0101_0101_0);
        assert_eq!(dec.cod_i_range, 510);
    }

    #[test]
    fn roundtrip_single_bin_zero() {
        roundtrip_decisions(&[0], 30, 0);
    }

    #[test]
    fn roundtrip_single_bin_one() {
        roundtrip_decisions(&[1], 30, 0);
    }

    #[test]
    fn roundtrip_alternating_bins() {
        let bins: Vec<u8> = (0..32).map(|i| (i & 1) as u8).collect();
        roundtrip_decisions(&bins, 0, 0);
    }

    #[test]
    fn roundtrip_biased_mps_run() {
        let bins = vec![0u8; 100];
        roundtrip_decisions(&bins, 60, 0);
    }

    /// Uses a small linear congruential generator so the "random" bins are
    /// reproducible.
    #[test]
    fn roundtrip_random_bins() {
        let mut s: u32 = 0x1234_5678;
        let bins: Vec<u8> = (0..64)
            .map(|_| {
                s = s.wrapping_mul(1664525).wrapping_add(1013904223);
                (s & 1) as u8
            })
            .collect();
        roundtrip_decisions(&bins, 20, 1);
    }

    /// Interleaves regular and bypass bins. One context is carried across both
    /// regular runs on the encoder side and on the decoder side.
    #[test]
    fn roundtrip_mixed_decision_and_bypass() {
        let regular = [0u8, 1, 1, 0];
        let bypass = [1u8, 0, 1, 1, 0, 0, 1, 0];
        let regular2 = [1u8, 0, 1, 1];
        let mut enc = CabacEngine::new();
        let mut ctx = CabacContext::new(20, 0);
        for &b in &regular {
            enc.encode_decision(b, &mut ctx);
        }
        for &b in &bypass {
            enc.encode_bypass(b);
        }
        for &b in &regular2 {
            enc.encode_decision(b, &mut ctx);
        }
        enc.encode_terminate(1);
        let bytes = enc.finish();
        let mut dec = CabacDecodeEngine::new(&bytes).expect("init");
        let mut ctx_d = CabacContext::new(20, 0);
        let out_regular: Vec<u8> = (0..regular.len())
            .map(|_| dec.decode_decision(&mut ctx_d).unwrap())
            .collect();
        assert_eq!(out_regular, regular);
        let out_bypass: Vec<u8> = (0..bypass.len())
            .map(|_| dec.decode_bypass().unwrap())
            .collect();
        assert_eq!(out_bypass, bypass);
        let out_regular2: Vec<u8> = (0..regular2.len())
            .map(|_| dec.decode_decision(&mut ctx_d).unwrap())
            .collect();
        assert_eq!(out_regular2, regular2);
    }

    /// Terminate bins equal to 0 may be followed by further bins; the final
    /// terminate bin is 1.
    #[test]
    fn roundtrip_terminate_zero_then_one() {
        let mut enc = CabacEngine::new();
        let mut ctx_e = CabacContext::new(20, 0);
        for _ in 0..3 {
            enc.encode_decision(0, &mut ctx_e);
            enc.encode_terminate(0);
        }
        enc.encode_decision(1, &mut ctx_e);
        enc.encode_terminate(1);
        let bytes = enc.finish();
        let mut dec = CabacDecodeEngine::new(&bytes).expect("init");
        let mut ctx_d = CabacContext::new(20, 0);
        for _ in 0..3 {
            assert_eq!(dec.decode_decision(&mut ctx_d).unwrap(), 0);
            assert_eq!(dec.decode_terminate().unwrap(), 0);
        }
        assert_eq!(dec.decode_decision(&mut ctx_d).unwrap(), 1);
        assert_eq!(dec.decode_terminate().unwrap(), 1);
    }

    /// Eight regular bins, one bypass bin and one terminate bin count as ten
    /// bins on both the encoder and the decoder.
    #[test]
    fn bin_count_matches_encoder() {
        let bins = [0u8, 1, 1, 0, 1, 0, 0, 1];
        let mut enc = CabacEngine::new();
        let mut ctx = CabacContext::new(30, 0);
        for &b in &bins {
            enc.encode_decision(b, &mut ctx);
        }
        enc.encode_bypass(1);
        enc.encode_terminate(1);
        let enc_bin_count = enc.bin_count();
        let bytes = enc.finish();
        let mut dec = CabacDecodeEngine::new(&bytes).expect("init");
        let mut ctx_d = CabacContext::new(30, 0);
        for _ in 0..bins.len() {
            dec.decode_decision(&mut ctx_d).unwrap();
        }
        dec.decode_bypass().unwrap();
        dec.decode_terminate().unwrap();
        assert_eq!(enc_bin_count, 10);
        assert_eq!(dec.bin_count(), 10);
    }

    /// Reading past the end of the input yields zero bits rather than an
    /// error, so construction over an empty slice succeeds.
    #[test]
    fn unexpected_eof_returns_zero_per_spec() {
        let bytes: [u8; 0] = [];
        let dec = CabacDecodeEngine::new(&bytes);
        let dec = dec.expect("init succeeds with all-zero reads");
        assert_eq!(dec.cod_i_offset, 0);
    }

    /// Two decoders over identical streams produce identical bins when driven
    /// in lock-step.
    #[test]
    fn decode_engine_separate_state_from_encoder() {
        let bytes_a = {
            let mut enc = CabacEngine::new();
            let mut ctx = CabacContext::new(20, 0);
            for _ in 0..10 {
                enc.encode_decision(0, &mut ctx);
            }
            enc.encode_terminate(1);
            enc.finish()
        };
        let bytes_b = {
            let mut enc = CabacEngine::new();
            let mut ctx = CabacContext::new(20, 0);
            for _ in 0..10 {
                enc.encode_decision(0, &mut ctx);
            }
            enc.encode_terminate(1);
            enc.finish()
        };
        assert_eq!(bytes_a, bytes_b);
        let mut d1 = CabacDecodeEngine::new(&bytes_a).unwrap();
        let mut d2 = CabacDecodeEngine::new(&bytes_b).unwrap();
        let mut c1 = CabacContext::new(20, 0);
        let mut c2 = CabacContext::new(20, 0);
        for _ in 0..10 {
            assert_eq!(
                d1.decode_decision(&mut c1).unwrap(),
                d2.decode_decision(&mut c2).unwrap(),
            );
        }
    }
}