use crate::codec::h264::cabac::context::{initialize_contexts, CabacContext, CabacInitSlot};
use crate::codec::h264::cabac::neighbor::CabacNeighborContext;
use super::engine::{CabacDecodeEngine, DecodeError};
/// CABAC decoding state for a single slice.
///
/// Bundles the arithmetic decoding engine (which borrows the slice data for
/// `'a`), the 1024-entry context-variable table, and the neighbor context
/// built from the picture width in macroblocks.
pub struct CabacDecoder<'a> {
    /// Arithmetic decoding engine reading the slice data.
    pub engine: CabacDecodeEngine<'a>,
    /// Context variables indexed by ctxIdx (`0..1024`), produced by
    /// [`initialize_contexts`] from the slice's init slot and `slice_qp_y`.
    pub contexts: Box<[CabacContext; 1024]>,
    /// Neighbor state built by [`CabacNeighborContext::new`] from `mb_width`
    /// and the init slot; presumably consulted when deriving context index
    /// increments — confirm against its users.
    pub neighbors: CabacNeighborContext,
}

impl<'a> CabacDecoder<'a> {
    /// Creates a decoder for one slice's CABAC data.
    ///
    /// * `bytes` – slice data handed to the arithmetic engine; borrowed for `'a`.
    /// * `slot` – context-initialization selector, passed to both
    ///   [`initialize_contexts`] and [`CabacNeighborContext::new`].
    /// * `slice_qp_y` – QP used by [`initialize_contexts`] to set the initial
    ///   context states.
    /// * `mb_width` – picture width in macroblocks, forwarded to the neighbor
    ///   context.
    ///
    /// # Errors
    ///
    /// Returns the [`DecodeError`] produced by [`CabacDecodeEngine::new`] when
    /// the engine cannot be initialized from `bytes`.
    pub fn new_slice(
        bytes: &'a [u8],
        slot: CabacInitSlot,
        slice_qp_y: i32,
        mb_width: usize,
    ) -> Result<Self, DecodeError> {
        // Initialize the engine first: it is the only fallible step, so a bad
        // payload is rejected before the boxed context table and the neighbor
        // state are built only to be discarded.
        let engine = CabacDecodeEngine::new(bytes)?;
        let contexts = Box::new(initialize_contexts(slot, slice_qp_y));
        let neighbors = CabacNeighborContext::new(mb_width, slot);
        Ok(Self {
            engine,
            contexts,
            neighbors,
        })
    }

    /// Decodes one bin through the engine's decision (context-coded) path,
    /// using and updating the context variable at `ctx_idx`.
    ///
    /// `ctx_idx` is also forwarded to the engine alongside the context.
    ///
    /// # Panics
    ///
    /// Panics if `ctx_idx >= 1024` (out of range for the context table).
    #[inline]
    pub(crate) fn decode_dec(&mut self, ctx_idx: u32) -> Result<u8, DecodeError> {
        // `contexts` and `engine` are distinct fields, so both can be borrowed
        // mutably at the same time.
        let ctx = &mut self.contexts[ctx_idx as usize];
        self.engine.decode_decision_with_ctx_idx(ctx, ctx_idx)
    }

    /// Decodes one bin through the engine's bypass path (no context variable
    /// is involved).
    #[inline]
    pub fn decode_bypass(&mut self) -> Result<u8, DecodeError> {
        self.engine.decode_bypass()
    }

    /// Decodes one bin through the engine's terminate path.
    #[inline]
    pub fn decode_terminate(&mut self) -> Result<u8, DecodeError> {
        self.engine.decode_terminate()
    }

    /// Returns the bin counter reported by the engine.
    #[inline]
    pub fn bin_count(&self) -> u32 {
        self.engine.bin_count()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::codec::h264::cabac::CabacEncoder;

    /// A freshly created I-slice decoder exposes the full context table, with
    /// ctxIdx 276 in its fixed initial state.
    #[test]
    fn new_slice_initializes_1024_contexts() {
        let mut enc = CabacEncoder::new_slice(CabacInitSlot::ISI, 26, 4);
        enc.engine.encode_terminate(1);
        let bytes = enc.finish();
        let dec = CabacDecoder::new_slice(&bytes, CabacInitSlot::ISI, 26, 4).expect("init");
        assert_eq!(dec.contexts.len(), 1024);
        // H.264 clause 9.3.1.2 gives ctxIdx 276 the fixed state
        // pStateIdx = 63, valMPS = 0 instead of a table-derived one.
        assert_eq!(dec.contexts[276].p_state_idx(), 63);
        assert_eq!(dec.contexts[276].val_mps(), 0);
    }

    /// Encoder and decoder initialized with the same P-slice slot and QP must
    /// start from identical context states for every ctxIdx.
    #[test]
    fn p_slice_initializes_with_correct_slot() {
        let mut enc = CabacEncoder::new_slice(CabacInitSlot::PIdc1, 30, 8);
        // Snapshot taken before any encoding so it reflects the freshly
        // initialized encoder contexts.
        let enc_ctx_snapshot: Vec<(u8, u8)> = enc
            .contexts
            .iter()
            .map(|c| (c.p_state_idx(), c.val_mps()))
            .collect();
        enc.engine.encode_terminate(1);
        let bytes = enc.finish();
        let dec = CabacDecoder::new_slice(&bytes, CabacInitSlot::PIdc1, 30, 8).expect("init");

        // Check sizes explicitly so `zip` below cannot silently skip entries.
        assert_eq!(
            dec.contexts.len(),
            enc_ctx_snapshot.len(),
            "context table size mismatch",
        );
        for (i, (ctx, &expected)) in dec.contexts.iter().zip(&enc_ctx_snapshot).enumerate() {
            assert_eq!(
                (ctx.p_state_idx(), ctx.val_mps()),
                expected,
                "ctxIdx {i} initial state mismatch",
            );
        }
    }
}