use crate::cabac_tables::{HEVC_CONTEXTS, INIT_VALUES, LPS_RANGE, MLPS_STATE, NORM_SHIFT};
use crate::slice::SliceType;
/// Number of bitstream bits pulled in per refill (the refill routines read two bytes and
/// advance the read position by 2).
const CABAC_BITS: u32 = 16;
/// Mask of the low `CABAC_BITS` bits of `low`; when `low & CABAC_MASK` becomes zero the
/// buffered input bits are used up and a refill is performed.
const CABAC_MASK: u32 = u32::MAX >> (32 - CABAC_BITS);
/// Left-shift amount applied to `range` (and `low`) after a context-coded decision.
///
/// This is a direct lookup of `range` in `NORM_SHIFT`. Presumably the table yields the
/// shift that brings `range` back to at least 0x100. The table contents are not visible
/// here, so confirm against `cabac_tables`.
#[inline]
fn norm_shift(range: u32) -> u32 {
    let entry = NORM_SHIFT[range as usize];
    entry as u32
}
/// Width of the LPS sub-interval for the current `range` and packed context `state`.
///
/// `range & 0xC0` keeps bits 6..=7 of `range`. Doubling that selects a row offset of 0,
/// 128, 256 or 384 in `LPS_RANGE`, and `state` indexes within the row.
#[inline]
fn lps_range_lookup(range: u32, state: u8) -> u32 {
    let row_offset = ((range & 0xC0) as usize) << 1;
    let column = usize::from(state);
    LPS_RANGE[row_offset + column] as u32
}
/// CABAC arithmetic decoder reading from a byte slice.
///
/// `low` holds the decoder offset scaled by `1 << (CABAC_BITS + 1)`, with not-yet-consumed
/// input bits buffered underneath it. `range` holds the current interval width, starting
/// at `0x1FE`. Bytes past the end of `data` are read as zero.
pub struct CabacReader<'a> {
    /// Scaled arithmetic-decoder offset plus buffered input bits.
    low: u32,
    /// Current interval width; renormalised after every decision.
    range: u32,
    /// Source bytes; also returned by `rbsp`.
    data: &'a [u8],
    /// Index of the next byte to fetch. It keeps advancing even past `data.len()`.
    pos: usize,
}

impl<'a> CabacReader<'a> {
    /// Interval width after (re)initialisation.
    const INITIAL_RANGE: u32 = 0x1FE;

    /// Creates a decoder primed with `data[byte_offset]` and `data[byte_offset + 1]`.
    ///
    /// # Panics
    ///
    /// Panics if fewer than two bytes are available at `byte_offset`.
    pub fn new(data: &'a [u8], byte_offset: usize) -> Self {
        assert!(
            byte_offset + 2 <= data.len(),
            "CABAC init needs at least 2 bytes"
        );
        CabacReader {
            low: Self::initial_low(data, byte_offset),
            range: Self::INITIAL_RANGE,
            data,
            pos: byte_offset + 2,
        }
    }

    /// Builds the initial `low` from the two bytes at `byte_offset`.
    ///
    /// The first byte lands at bits 18..=25 and the second at bits 10..=17. Bit 9 is set
    /// so that `low & CABAC_MASK` stays non-zero until the buffered bits have been
    /// shifted out. The caller must have checked that two bytes are available.
    #[inline]
    fn initial_low(data: &[u8], byte_offset: usize) -> u32 {
        ((data[byte_offset] as u32) << 18)
            .wrapping_add((data[byte_offset + 1] as u32) << 10)
            .wrapping_add(1 << 9)
    }

    /// Byte at `idx`, or 0 when `idx` is past the end of the data.
    #[inline]
    fn byte_or_zero(&self, idx: usize) -> u8 {
        self.data.get(idx).copied().unwrap_or(0)
    }

    /// The next two input bytes, positioned at bits 9..=16 and 1..=8 respectively.
    #[inline]
    fn next_pair(&self) -> u32 {
        let b0 = self.byte_or_zero(self.pos) as u32;
        let b1 = self.byte_or_zero(self.pos + 1) as u32;
        (b0 << 9) | (b1 << 1)
    }

    /// Refills after a multi-bit renormalisation.
    ///
    /// The new bytes are shifted left by `trailing_zeros(low) - CABAC_BITS`, so they line
    /// up with however far past bit 16 the lowest set bit of `low` has moved.
    #[inline]
    fn refill2(&mut self) {
        let i = self.low.trailing_zeros().wrapping_sub(CABAC_BITS);
        let x = self.next_pair().wrapping_sub(CABAC_MASK);
        self.low = self.low.wrapping_add(x << i);
        self.pos += 2;
    }

    /// Refills after a single-bit shift, when `low & CABAC_MASK` has just become zero.
    #[inline]
    fn refill(&mut self) {
        // Subtracting CABAC_MASK clears the exhausted bit at position 16 and sets bit 0.
        self.low = self
            .low
            .wrapping_add(self.next_pair())
            .wrapping_sub(CABAC_MASK);
        self.pos += 2;
    }

    /// Decodes one context-coded bin and updates its packed context state in place.
    ///
    /// `*state` uses the packing produced by `init_state`: bit 0 is the MPS value and the
    /// remaining bits hold the probability state index. Returns the decoded bin (0 or 1).
    #[inline]
    pub fn decode_bin(&mut self, state: &mut u8) -> u32 {
        let s = *state;
        let range_lps = lps_range_lookup(self.range, s);
        self.range -= range_lps;
        // All ones when `low` lies beyond the scaled MPS sub-interval (LPS path), zero
        // otherwise. It selects the results below without branching.
        let lps_mask =
            (((self.range << (CABAC_BITS + 1)).wrapping_sub(self.low)) as i32 >> 31) as u32;
        // LPS path only: move `low` into the LPS sub-interval and make it the new range.
        self.low = self
            .low
            .wrapping_sub((self.range << (CABAC_BITS + 1)) & lps_mask);
        self.range = self
            .range
            .wrapping_add(range_lps.wrapping_sub(self.range) & lps_mask);
        // On the LPS path `s_signed == !s`. For s <= 127, `128 + s_signed` therefore
        // indexes MLPS_STATE[0..=127] (LPS) instead of [128..] (MPS), and bit 0 flips.
        let s_signed = (s as i32) ^ (lps_mask as i32);
        *state = MLPS_STATE[(128 + s_signed) as usize];
        let bit = (s_signed & 1) as u32;
        let shift = norm_shift(self.range);
        self.range <<= shift;
        self.low = self.low.wrapping_shl(shift);
        // The shift may be several bits, so use the position-aware refill.
        if self.low & CABAC_MASK == 0 {
            self.refill2();
        }
        bit
    }

    /// Decodes one bypass bin (no context state involved). Returns 0 or 1.
    #[inline]
    pub fn decode_bypass(&mut self) -> u32 {
        self.low = self.low.wrapping_add(self.low);
        if self.low & CABAC_MASK == 0 {
            self.refill();
        }
        let range = self.range << (CABAC_BITS + 1);
        if self.low < range {
            0
        } else {
            self.low = self.low.wrapping_sub(range);
            1
        }
    }

    /// Reads `n` bypass bins, most significant first.
    ///
    /// If `n > 32`, the earliest bins are shifted out of the `u32` result.
    pub fn decode_bypass_bits(&mut self, n: u8) -> u32 {
        let mut val = 0u32;
        for _ in 0..n {
            val = (val << 1) | self.decode_bypass();
        }
        val
    }

    /// Decodes a terminating bin.
    ///
    /// Returns 1, without renormalising, when `low` is at or above the scaled `range - 2`.
    /// Otherwise it renormalises by at most one bit and returns 0.
    pub fn decode_terminate(&mut self) -> u32 {
        self.range -= 2;
        if self.low < self.range << (CABAC_BITS + 1) {
            // 1 when range dropped below 0x100, else 0.
            let shift = (self.range.wrapping_sub(0x100)) >> 31;
            self.range <<= shift;
            self.low = self.low.wrapping_shl(shift);
            if self.low & CABAC_MASK == 0 {
                self.refill();
            }
            0
        } else {
            1
        }
    }

    /// Index of the next byte the decoder would fetch. It may exceed `rbsp().len()` once
    /// the decoder has read past the end of the data.
    pub fn position(&self) -> usize {
        self.pos
    }

    /// The byte slice this decoder reads from.
    ///
    /// The slice is borrowed for the reader's data lifetime, not for `&self`, so it can
    /// be held while the reader is mutated (e.g. across `reinit_at`).
    pub fn rbsp(&self) -> &'a [u8] {
        self.data
    }

    /// Byte index where raw (PCM) data begins, accounting for bytes already buffered in `low`.
    ///
    /// It steps back one byte if bit 0 of `low` is set, and another if any of bits 0..=8
    /// are set. Like `position`, the result may exceed the data length on truncated input.
    pub fn pcm_byte_position(&self) -> usize {
        let mut ptr = self.pos;
        if self.low & 0x1 != 0 {
            ptr -= 1;
        }
        if self.low & 0x1FF != 0 {
            ptr -= 1;
        }
        ptr
    }

    /// Restarts decoding at `byte_offset`, exactly as a fresh `new(data, byte_offset)` would.
    ///
    /// # Panics
    ///
    /// Panics if fewer than two bytes are available at `byte_offset`.
    pub fn reinit_at(&mut self, byte_offset: usize) {
        assert!(
            byte_offset + 2 <= self.data.len(),
            "CABAC reinit needs at least 2 bytes"
        );
        self.low = Self::initial_low(self.data, byte_offset);
        self.range = Self::INITIAL_RANGE;
        self.pos = byte_offset + 2;
    }
}
/// Computes the packed initial context state `2 * pStateIdx + valMps` for one context.
///
/// `init_value` supplies a slope (high nibble) and an offset (low nibble). `slice_qp` is
/// clamped to 0..=51. The result is always in 0..=125.
pub fn init_state(init_value: u8, slice_qp: i32) -> u8 {
    let slope = i32::from(init_value >> 4) * 5 - 45;
    let offset = (i32::from(init_value & 0x0F) << 3) - 16;
    let qp_clipped = slice_qp.clamp(0, 51);

    // 2 * preCtxState - 127: negative values mean MPS = 0, non-negative mean MPS = 1.
    let signed_pre = 2 * (((slope * qp_clipped) >> 4) + offset) - 127;

    // Folding the negative half with bitwise NOT gives 2 * (63 - preCtxState) on that side.
    let folded = if signed_pre < 0 { !signed_pre } else { signed_pre };

    // Saturate to pStateIdx 62 while keeping the MPS bit (124 has bit 0 clear).
    let packed = if folded > 124 { 124 | (folded & 1) } else { folded };
    packed as u8
}
/// Initialises each `state[i]` from `init_values[i]` via `init_state` at `slice_qp`.
///
/// # Panics
///
/// Panics if `state` is shorter than `init_values`. Debug builds additionally require
/// the two lengths to match exactly. In release builds, any extra trailing `state`
/// entries are left untouched.
pub fn init_states(init_values: &[u8], slice_qp: i32, state: &mut [u8]) {
    debug_assert_eq!(init_values.len(), state.len());
    // Slicing once up front hoists the bounds check out of the loop.
    let targets = &mut state[..init_values.len()];
    for (slot, &iv) in targets.iter_mut().zip(init_values) {
        *slot = init_state(iv, slice_qp);
    }
}
/// Selects the row of initial values (initType) for a slice.
///
/// The default rows are I -> 0, P -> 1, B -> 2, derived as `2 - slice_type as i32`. For
/// P and B slices, `cabac_init_flag` swaps rows 1 and 2 (XOR with 3); I slices ignore
/// the flag.
pub fn init_type_for_slice(slice_type: SliceType, cabac_init_flag: bool) -> usize {
    let swap_tables = cabac_init_flag && slice_type != SliceType::I;
    let default_type = 2 - slice_type as i32;
    let chosen = if swap_tables {
        default_type ^ 3
    } else {
        default_type
    };
    chosen as usize
}
/// Table of packed CABAC context states, one byte per context.
pub struct CabacContexts {
    /// Packed state for every context, in the format produced by `init_state`.
    pub state: [u8; HEVC_CONTEXTS],
}

impl CabacContexts {
    /// Builds the context table for a slice.
    ///
    /// Picks the `INIT_VALUES` row given by `init_type_for_slice(slice_type,
    /// cabac_init_flag)`, then converts every entry with `init_state` at `slice_qp`.
    pub fn init(slice_qp: i32, slice_type: SliceType, cabac_init_flag: bool) -> Self {
        let init_values = &INIT_VALUES[init_type_for_slice(slice_type, cabac_init_flag)];
        let mut state = [0u8; HEVC_CONTEXTS];
        for (idx, slot) in state.iter_mut().enumerate() {
            *slot = init_state(init_values[idx], slice_qp);
        }
        CabacContexts { state }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decode_bypass_round_trip() {
        let data = [0xB4, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        let mut cabac = CabacReader::new(&data, 0);
        let expected = [1, 0, 1, 1, 0, 1, 0, 0];
        for &b in &expected {
            assert_eq!(cabac.decode_bypass(), b, "bypass bin mismatch");
        }
    }

    #[test]
    fn test_decode_bypass_bits_matches_loop() {
        let data = [0xCA, 0xFE, 0x12, 0x34, 0x56, 0x78, 0x9A, 0xBC];
        let mut a = CabacReader::new(&data, 0);
        let mut b = CabacReader::new(&data, 0);
        let n: u8 = 12;
        let v1 = a.decode_bypass_bits(n);
        let mut v2 = 0u32;
        for _ in 0..n {
            v2 = (v2 << 1) | b.decode_bypass();
        }
        assert_eq!(v1, v2);
        // An n-bit read must fit in the low n bits of the result.
        assert_eq!(v1 >> (32 - n), 0);
    }

    #[test]
    fn test_decode_terminate_low_offset_returns_zero() {
        // low = 0x200 is far below (0x1FE - 2) << 17, so decoding continues.
        let data = [0x00u8; 8];
        let mut cabac = CabacReader::new(&data, 0);
        assert_eq!(cabac.decode_terminate(), 0);
        // 0x1FC is still >= 0x100, so no renormalisation shift is applied.
        assert_eq!(cabac.range, 0x1FC);
    }

    #[test]
    fn test_decode_terminate_high_offset_returns_one() {
        // low = 0x3FFFE00 >= 0x1FC << 17 = 0x3F80000, so the terminate bin is 1.
        let data = [0xFFu8; 8];
        let mut cabac = CabacReader::new(&data, 0);
        assert_eq!(cabac.decode_terminate(), 1);
    }

    #[test]
    fn test_reads_past_end_are_zero_filled() {
        let data = [0x00u8, 0x00];
        let mut cabac = CabacReader::new(&data, 0);
        for _ in 0..64 {
            assert_eq!(cabac.decode_bypass(), 0);
        }
        // Refills kept advancing past the end without panicking.
        assert!(cabac.position() > data.len());
    }

    #[test]
    fn test_init_state_in_range() {
        for &iv in &[0u8, 0x10, 0x77, 0x88, 0xC9, 0xFE, 0xFF] {
            for &qp in &[-10i32, 0, 1, 26, 51, 60] {
                let s = init_state(iv, qp);
                assert!(s <= 125, "state {} out of range for iv={} qp={}", s, iv, qp);
            }
        }
    }

    #[test]
    fn test_init_known_states_at_qp25() {
        use crate::cabac_tables::ctx;
        let ctxs = CabacContexts::init(25, SliceType::I, false);
        assert_eq!(ctxs.state[ctx::SPLIT_CODING_UNIT_FLAG], 1);
        assert_eq!(ctxs.state[ctx::PREV_INTRA_LUMA_PRED_FLAG], 0);
        assert_eq!(ctxs.state[ctx::CBF_LUMA], 33);
        assert_eq!(ctxs.state[ctx::INTRA_CHROMA_PRED_MODE], 12);
    }

    #[test]
    fn test_init_type_for_slice() {
        assert_eq!(init_type_for_slice(SliceType::I, false), 0);
        assert_eq!(init_type_for_slice(SliceType::I, true), 0);
        assert_eq!(init_type_for_slice(SliceType::P, false), 1);
        assert_eq!(init_type_for_slice(SliceType::P, true), 2);
        assert_eq!(init_type_for_slice(SliceType::B, false), 2);
        assert_eq!(init_type_for_slice(SliceType::B, true), 1);
    }

    #[test]
    fn test_init_state_matches_spec_formula() {
        for iv in 0u8..=255 {
            for &qp in &[0, 13, 26, 37, 51] {
                let slope_idx = (iv >> 4) as i32;
                let offset_idx = (iv & 0x0F) as i32;
                let m = slope_idx * 5 - 45;
                let n = (offset_idx << 3) - 16;
                let pre = (((m * qp) >> 4) + n).clamp(1, 126);
                let (p_state_idx, val_mps): (i32, i32) = if pre <= 63 {
                    (63 - pre, 0)
                } else {
                    (pre - 64, 1)
                };
                let expected_packed = (2 * p_state_idx + val_mps) as u8;
                let actual = init_state(iv, qp);
                assert_eq!(
                    actual, expected_packed,
                    "init_state(iv={iv:#04x}, qp={qp}) packed mismatch: got {actual}, want {expected_packed}",
                );
            }
        }
    }

    #[test]
    fn test_reinit_at_matches_fresh_new() {
        let data = [0xDE, 0xAD, 0xBE, 0xEF, 0x01, 0x23, 0x45, 0x67];
        let mut a = CabacReader::new(&data, 0);
        let _ = a.decode_bypass_bits(4);
        a.reinit_at(4);
        let b = CabacReader::new(&data, 4);
        assert_eq!(a.low, b.low);
        assert_eq!(a.range, b.range);
        assert_eq!(a.pos, b.pos);
    }
}