/// Next probability-state index after decoding the *least* probable symbol:
/// `TRANS_IDX_LPS[s]` is the state that follows state `s` on an LPS.
///
/// Consumed by `CABAC_TRANS_LPS`, which additionally flips the MPS when an LPS
/// is decoded in state 0. Entry 63 maps to itself.
/// NOTE(review): values appear to match the `transIdxLps` table of the
/// H.264/H.265 CABAC specification — confirm against the spec when editing.
#[rustfmt::skip]
const TRANS_IDX_LPS: [u8; 64] = [
0, 0, 1, 2, 2, 4, 4, 5,
6, 7, 8, 9, 9, 11, 11, 12,
13, 13, 15, 15, 16, 16, 18, 18,
19, 19, 21, 21, 22, 22, 23, 24,
24, 25, 26, 26, 27, 27, 28, 29,
29, 30, 30, 30, 31, 32, 32, 33,
33, 33, 34, 34, 35, 35, 35, 36,
36, 36, 37, 37, 37, 38, 38, 63,
];
/// Next probability-state index after decoding the *most* probable symbol.
///
/// States `0..=61` advance by one; states 62 and 63 map to themselves, so an
/// MPS can never move a context past 62 or out of 63.
const TRANS_IDX_MPS: [u8; 64] = {
    let mut table = [0u8; 64];
    let mut state = 0usize;
    while state < 64 {
        table[state] = if state < 62 {
            state as u8 + 1
        } else {
            state as u8
        };
        state += 1;
    }
    table
};
/// Width of the LPS sub-interval, indexed as `RANGE_TAB_LPS[state][q_range_idx]`.
///
/// `state` is the 6-bit probability state and `q_range_idx` is bits 6..=7 of the
/// current 9-bit range, i.e. `(range >> 6) & 3` (see `decode_decision`).
/// The last row (state 63) is all 2s.
/// NOTE(review): values appear to match `rangeTabLps` from the H.264/H.265 CABAC
/// specification — confirm against the spec when editing.
#[rustfmt::skip]
const RANGE_TAB_LPS: [[u8; 4]; 64] = [
[128, 176, 208, 240],
[128, 167, 197, 227],
[128, 158, 187, 216],
[123, 150, 178, 205],
[116, 142, 169, 195],
[111, 135, 160, 185],
[105, 128, 152, 175],
[100, 122, 144, 166],
[ 95, 116, 137, 158],
[ 90, 110, 130, 150],
[ 85, 104, 123, 142],
[ 81, 99, 117, 135],
[ 77, 94, 111, 128],
[ 73, 89, 105, 122],
[ 69, 85, 100, 116],
[ 66, 80, 95, 110],
[ 62, 76, 90, 104],
[ 59, 72, 86, 99],
[ 56, 69, 81, 94],
[ 53, 65, 77, 89],
[ 51, 62, 73, 85],
[ 48, 59, 69, 80],
[ 46, 56, 66, 76],
[ 43, 53, 63, 72],
[ 41, 50, 59, 69],
[ 39, 48, 56, 65],
[ 37, 45, 54, 62],
[ 35, 43, 51, 59],
[ 33, 41, 48, 56],
[ 32, 39, 46, 53],
[ 30, 37, 43, 50],
[ 29, 35, 41, 48],
[ 27, 33, 39, 45],
[ 26, 31, 37, 43],
[ 24, 30, 35, 41],
[ 23, 28, 33, 39],
[ 22, 27, 32, 37],
[ 21, 26, 30, 35],
[ 20, 24, 29, 33],
[ 19, 23, 27, 31],
[ 18, 22, 26, 30],
[ 17, 21, 25, 28],
[ 16, 20, 23, 27],
[ 15, 19, 22, 25],
[ 14, 18, 21, 24],
[ 14, 17, 20, 23],
[ 13, 16, 19, 22],
[ 12, 15, 18, 21],
[ 12, 14, 17, 20],
[ 11, 14, 16, 19],
[ 11, 13, 15, 18],
[ 10, 12, 15, 17],
[ 10, 12, 14, 16],
[ 9, 11, 13, 15],
[ 9, 11, 12, 14],
[ 8, 10, 12, 14],
[ 8, 9, 11, 13],
[ 7, 9, 11, 12],
[ 7, 9, 10, 12],
[ 7, 8, 10, 11],
[ 6, 8, 9, 11],
[ 6, 7, 9, 10],
[ 6, 7, 8, 9],
[ 2, 2, 2, 2],
];
/// Packed MPS-path transition table.
///
/// Indexed by `state | (mps << 6)`; each entry is the packed context after an MPS
/// has been decoded. Only the state index advances (via `TRANS_IDX_MPS`); the MPS
/// bit (bit 6) is carried through unchanged.
static CABAC_TRANS_MPS: [u8; 128] = {
    let mut table = [0u8; 128];
    let mut packed = 0usize;
    while packed < 128 {
        let state = packed & 63;
        let mps_bit = (packed & 64) as u8;
        table[packed] = TRANS_IDX_MPS[state] | mps_bit;
        packed += 1;
    }
    table
};
/// Packed LPS-path transition table.
///
/// Indexed by `state | (mps << 6)`; each entry is the packed context after an LPS
/// has been decoded. The state index follows `TRANS_IDX_LPS`, and an LPS decoded
/// in state 0 additionally swaps the MPS (bit 6 is toggled).
static CABAC_TRANS_LPS: [u8; 128] = {
    let mut table = [0u8; 128];
    let mut packed = 0usize;
    while packed < 128 {
        let state = packed & 63;
        let mut mps_bit = (packed & 64) as u8;
        if state == 0 {
            // An LPS at the least-confident state swaps which symbol is most probable.
            mps_bit ^= 64;
        }
        table[packed] = TRANS_IDX_LPS[state] | mps_bit;
        packed += 1;
    }
    table
};
/// Adaptive probability model for one context-coded bin.
///
/// The decoder treats the pair as a 7-bit packed value `state | (mps << 6)`
/// (see `packed` / `unpack`), so `state` is expected in `0..=63` and `mps` in `0..=1`.
#[derive(Debug, Clone)]
pub struct ContextModel {
// Probability state index into the 64-entry transition / range tables.
pub state: u8,
// Value (0 or 1) of the currently most probable symbol.
pub mps: u8,
}
impl ContextModel {
    /// Creates a context model initialized from `init_value` at a slice QP of 26.
    pub fn new(init_value: u8) -> Self {
        let mut ctx = ContextModel { state: 0, mps: 0 };
        ctx.init(26, init_value);
        ctx
    }

    /// Initializes (or re-initializes) this context from an 8-bit `init_value` at `slice_qp`.
    ///
    /// The high nibble of `init_value` selects the slope `m = 5 * (init_value >> 4) - 45`
    /// and the low nibble the offset `n = ((init_value & 15) << 3) - 16`. With
    /// `qp = clamp(slice_qp, 0, 51)`, the pre-state `((m * qp) >> 4) + n` is clamped to
    /// `1..=126`. Pre-states up to 63 give `mps = 0` and `state = 63 - pre`. Larger
    /// pre-states give `mps = 1` and `state = pre - 64`.
    ///
    /// This is the H.265 context-variable initialization (clause 9.3.2.2). The clamped QP
    /// enters the product unbiased, exactly as `Clip3(0, 51, SliceQpY)` does in the spec.
    pub fn init(&mut self, slice_qp: i32, init_value: u8) {
        let slope = i32::from(init_value >> 4) * 5 - 45;
        let offset = (i32::from(init_value & 15) << 3) - 16;
        let qp = slice_qp.clamp(0, 51);
        // `>>` on i32 is an arithmetic shift (rounds toward -inf), matching the spec's `>>`.
        let pre_ctx_state = (((slope * qp) >> 4) + offset).clamp(1, 126);
        if pre_ctx_state <= 63 {
            self.state = (63 - pre_ctx_state) as u8;
            self.mps = 0;
        } else {
            self.state = (pre_ctx_state - 64) as u8;
            self.mps = 1;
        }
    }

    /// Returns the context as `state | (mps << 6)`, the index layout of the transition tables.
    ///
    /// Fields are masked to 6 and 1 bits, so the result is always below 128.
    #[inline(always)]
    pub fn packed(&self) -> u8 {
        (self.state & 63) | ((self.mps & 1) << 6)
    }

    /// Inverse of [`ContextModel::packed`]: splits bits 0..=5 into `state` and bit 6 into `mps`.
    #[inline(always)]
    pub fn unpack(packed: u8) -> Self {
        ContextModel {
            state: packed & 63,
            mps: (packed >> 6) & 1,
        }
    }
}

/// Binary arithmetic decoder over a byte slice, with an MSB-first bit reservoir.
pub struct CabacDecoder<'a> {
    /// Input bytes being decoded.
    data: &'a [u8],
    /// Index of the next byte of `data` to load into `bit_buf`.
    offset: usize,
    /// Bit reservoir; the low `bits_left` bits are unread, most significant first.
    bit_buf: u32,
    /// Number of unread bits held in `bit_buf` (at most 32).
    bits_left: u32,
    /// Current interval width; 510 after (re)initialization and kept in `256..=510`
    /// between bins by renormalization.
    range: u32,
    /// Current offset of the decoded value within the interval (always `< range`).
    value: u32,
}

impl<'a> CabacDecoder<'a> {
    /// Starts decoding at the beginning of `data`: `range = 510` and `value` = first 9 bits.
    pub fn new(data: &'a [u8]) -> Self {
        let mut dec = CabacDecoder {
            data,
            offset: 0,
            bit_buf: 0,
            bits_left: 0,
            range: 510,
            value: 0,
        };
        dec.value = dec.read_bits(9);
        dec
    }

    /// Loads whole bytes into `bit_buf` until it holds more than 24 unread bits or the
    /// input is exhausted.
    #[inline(always)]
    fn refill(&mut self) {
        // The loop condition proves `offset < len`, so plain indexing is safe and the
        // bounds check is typically optimized away. No `unsafe` is needed.
        while self.bits_left <= 24 && self.offset < self.data.len() {
            self.bit_buf = (self.bit_buf << 8) | u32::from(self.data[self.offset]);
            self.offset += 1;
            self.bits_left += 8;
        }
    }

    /// Reads one bit, or 0 once the input is exhausted.
    #[inline(always)]
    fn read_bit(&mut self) -> u32 {
        if self.bits_left == 0 {
            self.refill();
            if self.bits_left == 0 {
                return 0;
            }
        }
        self.bits_left -= 1;
        (self.bit_buf >> self.bits_left) & 1
    }

    /// Reads `n` bits MSB-first (`1 <= n <= 31`), padding with zeros past the end of input.
    #[inline(always)]
    fn read_bits(&mut self, n: u32) -> u32 {
        if self.bits_left < n {
            self.refill();
        }
        if self.bits_left >= n {
            self.bits_left -= n;
            (self.bit_buf >> self.bits_left) & ((1u32 << n) - 1)
        } else {
            // Near end of input: drain the remaining bits one at a time, then zeros.
            let mut val = 0u32;
            for _ in 0..n {
                val = (val << 1) | self.read_bit();
            }
            val
        }
    }

    /// Doubles `range` (shifting in new bits to `value`) until `range >= 256`.
    #[inline(always)]
    fn renormalize(&mut self) {
        if self.range >= 256 {
            return;
        }
        // `range` fits in 9 bits, so `leading_zeros() - 23` is the shift that sets bit 8.
        let shift = self.range.leading_zeros() - 23;
        self.range <<= shift;
        self.value = (self.value << shift) | self.read_bits(shift);
    }

    /// Decodes one context-coded bin and adapts `ctx` toward the decoded symbol.
    ///
    /// Out-of-range field values in `ctx` are masked (`state & 63`, `mps & 1`) rather than
    /// trusted. The fields are public, so they cannot be assumed valid when indexing the
    /// lookup tables.
    #[inline(always)]
    pub fn decode_decision(&mut self, ctx: &mut ContextModel) -> bool {
        let state = usize::from(ctx.state & 63);
        let mps = ctx.mps & 1;
        let q_range_idx = ((self.range >> 6) & 3) as usize;
        let range_lps = u32::from(RANGE_TAB_LPS[state][q_range_idx]);
        self.range -= range_lps;

        // Branch-free select between the MPS and LPS sub-intervals.
        let is_lps = self.value >= self.range;
        let lps_mask = 0u32.wrapping_sub(u32::from(is_lps));
        self.value -= self.range & lps_mask;
        self.range = (self.range & !lps_mask) | (range_lps & lps_mask);

        // `packed < 128` because `state < 64` and `mps <= 1`.
        let packed = state | (usize::from(mps) << 6);
        let trans_mps = CABAC_TRANS_MPS[packed];
        let trans_lps = CABAC_TRANS_LPS[packed];
        let new_packed = trans_mps ^ ((trans_mps ^ trans_lps) & (lps_mask as u8));
        ctx.state = new_packed & 63;
        ctx.mps = (new_packed >> 6) & 1;

        self.renormalize();
        (mps ^ u8::from(is_lps)) != 0
    }

    /// Decodes one equiprobable (bypass) bin; `range` is left unchanged.
    #[inline(always)]
    pub fn decode_bypass(&mut self) -> bool {
        self.value = (self.value << 1) | self.read_bit();
        let is_one = (self.value >= self.range) as u32;
        self.value -= self.range & 0u32.wrapping_sub(is_one);
        is_one != 0
    }

    /// Decodes a terminating bin.
    ///
    /// The range shrinks by 2. Returns `true` without renormalizing when `value` lands in
    /// the top 2-wide interval. Otherwise renormalizes and returns `false`.
    #[inline(always)]
    pub fn decode_terminate(&mut self) -> bool {
        self.range -= 2;
        if self.value >= self.range {
            true
        } else {
            self.renormalize();
            false
        }
    }

    /// Number of input bytes that still contain at least one unread bit.
    ///
    /// This counts the bytes not yet loaded, plus every byte still (partially) buffered
    /// in the reservoir. The reservoir can hold up to 4 bytes.
    pub fn bytes_remaining(&self) -> usize {
        let buffered = ((self.bits_left + 7) / 8) as usize;
        self.data.len().saturating_sub(self.offset) + buffered
    }

    /// Restarts arithmetic decoding at `byte_offset` (clamped to the input length).
    ///
    /// This discards buffered bits, resets `range` to 510 and reloads `value` from 9 fresh bits.
    pub fn reinit_at_offset(&mut self, byte_offset: usize) {
        self.offset = byte_offset.min(self.data.len());
        self.bit_buf = 0;
        self.bits_left = 0;
        self.range = 510;
        self.value = self.read_bits(9);
    }

    /// Byte offset just past the partially consumed byte, i.e. loaded bytes minus the
    /// whole bytes still buffered.
    pub fn current_byte_offset(&self) -> usize {
        let buffered_bytes = (self.bits_left / 8) as usize;
        self.offset.saturating_sub(buffered_bytes)
    }

    /// Drops the unread bits of the partially consumed byte so the next read starts on a
    /// byte boundary. Only the bit reader is affected, not `range`/`value`.
    pub fn byte_align(&mut self) {
        let discard = self.bits_left % 8;
        if discard > 0 {
            self.bits_left -= discard;
        }
    }

    /// Decodes a truncated-Rice value with context-coded prefix bins and bypass suffix bins.
    ///
    /// Prefix bin `i` uses `ctx[min(i, ctx.len() - 1)]`. The prefix is unary with at most
    /// `c_max >> c_rice_param` ones. A saturated (all-ones) prefix has no suffix and yields
    /// `c_max`, since the suffix is present only when `cMax > symbolVal` (H.265 9.3.3.2).
    /// Otherwise a `c_rice_param`-bit suffix follows.
    ///
    /// # Panics
    /// Panics if `ctx` is empty and at least one prefix bin must be decoded.
    #[inline]
    pub fn decode_tr(&mut self, ctx: &mut [ContextModel], c_max: u32, c_rice_param: u32) -> u32 {
        let prefix_max = c_max >> c_rice_param;
        let last_ctx = ctx.len().saturating_sub(1);
        let mut prefix = 0u32;
        while prefix < prefix_max {
            let ctx_idx = (prefix as usize).min(last_ctx);
            if !self.decode_decision(&mut ctx[ctx_idx]) {
                break;
            }
            prefix += 1;
        }
        if prefix == prefix_max {
            // Saturated prefix: symbolVal == cMax, and no suffix bins were coded.
            return c_max;
        }
        let suffix = if c_rice_param > 0 {
            self.decode_fl_bypass(c_rice_param)
        } else {
            0
        };
        // Cannot exceed c_max: (prefix + 1) << c_rice_param <= c_max here.
        (prefix << c_rice_param) + suffix
    }

    /// Decodes an `n_bits`-bit fixed-length value from bypass bins, MSB first.
    #[inline(always)]
    pub fn decode_fl(&mut self, n_bits: u32) -> u32 {
        self.decode_fl_bypass(n_bits)
    }

    /// Reads `n_bits` bypass bins MSB-first. Bits beyond 32 shift out of the result.
    #[inline(always)]
    fn decode_fl_bypass(&mut self, n_bits: u32) -> u32 {
        let mut val = 0u32;
        for _ in 0..n_bits {
            val = (val << 1) | u32::from(self.decode_bypass());
        }
        val
    }

    /// Decodes a unary value: counts context-coded `1` bins, stopping at a `0` or at `max`.
    ///
    /// Bin `i` uses `ctx[min(i, ctx.len() - 1)]`.
    ///
    /// # Panics
    /// Panics if `ctx` is empty and `max > 0`.
    #[inline]
    pub fn decode_unary(&mut self, ctx: &mut [ContextModel], max: u32) -> u32 {
        let mut val = 0u32;
        while val < max {
            let ctx_idx = val.min((ctx.len() as u32).saturating_sub(1)) as usize;
            if self.decode_decision(&mut ctx[ctx_idx]) {
                val += 1;
            } else {
                break;
            }
        }
        val
    }

    /// Decodes a k-th order Exp-Golomb value from bypass bins.
    ///
    /// A unary prefix of `leading` ones (capped at 32) is followed by `leading + k` suffix
    /// bits. The result is `((1 << leading) - 1) << k` plus the suffix. With a 32-bin
    /// prefix (malformed input) only the suffix bits are returned.
    #[inline]
    pub fn decode_eg(&mut self, k: u32) -> u32 {
        let mut leading = 0u32;
        while self.decode_bypass() {
            leading += 1;
            if leading > 31 {
                break;
            }
        }
        let suffix_len = leading + k;
        let suffix = self.decode_fl_bypass(suffix_len);
        if leading >= 32 {
            return suffix;
        }
        ((1u32 << leading) - 1).wrapping_shl(k).wrapping_add(suffix)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn context_init_default_qp() {
        let ctx = ContextModel::new(154);
        assert_eq!(ctx.state, 0);
        assert_eq!(ctx.mps, 1);
    }
    #[test]
    fn context_init_high_state() {
        // slope 30, offset 104 at QP 26: (30 * 26 >> 4) + 104 = 152 -> clamped 126 -> state 62.
        let ctx = ContextModel::new(255);
        assert_eq!(ctx.state, 62);
        assert_eq!(ctx.mps, 1);
    }
    #[test]
    fn context_init_low_state() {
        let ctx = ContextModel::new(0);
        assert_eq!(ctx.state, 62);
        assert_eq!(ctx.mps, 0);
    }
    #[test]
    fn context_init_follows_hevc_formula() {
        // 0x6F: slope -15, offset 104; ((-15 * 26) >> 4) + 104 = -25 + 104 = 79 -> state 15, MPS 1.
        let ctx = ContextModel::new(111);
        assert_eq!((ctx.state, ctx.mps), (15, 1));
        let mut ctx = ContextModel::new(255);
        ctx.init(0, 255);
        assert_eq!((ctx.state, ctx.mps), (40, 1));
        // QP below 0 clamps to 0.
        ctx.init(-5, 255);
        assert_eq!((ctx.state, ctx.mps), (40, 1));
        // QP above 51 clamps to 51: (30 * 51 >> 4) + 104 = 199 -> 126 -> state 62.
        ctx.init(60, 255);
        assert_eq!((ctx.state, ctx.mps), (62, 1));
    }
    #[test]
    fn context_state_transitions() {
        let mut ctx = ContextModel { state: 0, mps: 0 };
        let old_mps = ctx.mps;
        ctx.state = TRANS_IDX_MPS[ctx.state as usize];
        assert_eq!(ctx.state, 1);
        assert_eq!(ctx.mps, old_mps);
        ctx.state = 0;
        ctx.mps = 0;
        if ctx.state == 0 {
            ctx.mps = 1 - ctx.mps;
        }
        ctx.state = TRANS_IDX_LPS[ctx.state as usize];
        assert_eq!(ctx.state, 0);
        assert_eq!(ctx.mps, 1);
    }
    #[test]
    fn context_init_reinit() {
        let mut ctx = ContextModel::new(154);
        let orig_state = ctx.state;
        let orig_mps = ctx.mps;
        ctx.init(26, 154);
        assert_eq!(ctx.state, orig_state);
        assert_eq!(ctx.mps, orig_mps);
        ctx.init(40, 154);
        assert_eq!(ctx.state, 0);
        assert_eq!(ctx.mps, 1);
    }
    #[test]
    fn context_pack_roundtrip_masks() {
        let ctx = ContextModel { state: 200, mps: 3 };
        let back = ContextModel::unpack(ctx.packed());
        assert_eq!((back.state, back.mps), (200 & 63, 1));
    }
    #[test]
    fn cabac_init() {
        let data = [0x00, 0x00, 0x01, 0xFF];
        let dec = CabacDecoder::new(&data);
        assert_eq!(dec.range, 510);
        assert_eq!(dec.value, 0);
    }
    #[test]
    fn cabac_bypass_known_pattern() {
        let data = [0x00u8, 0x59, 0x00];
        let mut dec = CabacDecoder::new(&data);
        assert_eq!(dec.value, 0);
        assert_eq!(dec.range, 510);
        let mut bits = Vec::new();
        for _ in 0..8 {
            bits.push(dec.decode_bypass());
        }
        assert_eq!(bits.len(), 8);
    }
    #[test]
    fn cabac_terminate_no_end() {
        let data = [0x00u8; 4];
        let mut dec = CabacDecoder::new(&data);
        assert!(!dec.decode_terminate());
    }
    #[test]
    fn cabac_terminate_end() {
        let data = [0xFE, 0x00];
        let mut dec = CabacDecoder::new(&data);
        assert_eq!(dec.value, 508);
        assert!(dec.decode_terminate());
    }
    #[test]
    fn cabac_decision_basic() {
        let data = [0x00u8; 8];
        let mut dec = CabacDecoder::new(&data);
        let mut ctx = ContextModel::new(154);
        let mut results = Vec::new();
        for _ in 0..10 {
            results.push(dec.decode_decision(&mut ctx));
        }
        assert_eq!(results.len(), 10);
    }
    #[test]
    fn cabac_decision_masks_out_of_range_context() {
        let data = [0x00u8; 8];
        let mut dec = CabacDecoder::new(&data);
        // Treated as state 8, MPS 1; an all-zero stream decodes the MPS.
        let mut ctx = ContextModel { state: 200, mps: 3 };
        assert!(dec.decode_decision(&mut ctx));
        assert_eq!((ctx.state, ctx.mps), (9, 1));
    }
    #[test]
    fn cabac_fl_decode() {
        let data = [0x00u8; 4];
        let mut dec = CabacDecoder::new(&data);
        let val = dec.decode_fl(3);
        assert_eq!(val, 0);
    }
    #[test]
    fn cabac_unary_decode() {
        let data = [0x00u8; 8];
        let mut dec = CabacDecoder::new(&data);
        let mut ctx = [ContextModel::new(154)];
        let val = dec.decode_unary(&mut ctx, 5);
        assert!(val <= 5);
    }
    #[test]
    fn cabac_eg_decode() {
        let data = [0x00u8; 8];
        let mut dec = CabacDecoder::new(&data);
        let val = dec.decode_eg(0);
        assert_eq!(val, 0);
    }
    #[test]
    fn range_tab_lps_sanity() {
        assert_eq!(RANGE_TAB_LPS[0][0], 128);
        assert_eq!(RANGE_TAB_LPS[0][3], 240);
        assert_eq!(RANGE_TAB_LPS[63][0], 2);
        assert_eq!(RANGE_TAB_LPS[63][3], 2);
        assert_eq!(RANGE_TAB_LPS[12][0], 77);
        assert_eq!(RANGE_TAB_LPS[12][3], 128);
    }
    #[test]
    fn trans_tables_sanity() {
        for i in 0..63 {
            assert!(TRANS_IDX_MPS[i] >= i as u8);
        }
        for i in 0..63 {
            assert!(TRANS_IDX_LPS[i] <= i as u8);
        }
        assert_eq!(TRANS_IDX_MPS[63], 63);
        assert_eq!(TRANS_IDX_LPS[63], 63);
    }
    #[test]
    fn cabac_decode_tr_basic() {
        let data = [0x00u8; 8];
        let mut dec = CabacDecoder::new(&data);
        let mut ctx = [ContextModel::new(154), ContextModel::new(154)];
        let val = dec.decode_tr(&mut ctx, 4, 0);
        assert!(val <= 4);
    }
    #[test]
    fn cabac_decode_tr_saturated_prefix_has_no_suffix() {
        // All-zero stream + MPS=1 context: every decision decodes as 1.
        let data = [0x00u8; 8];
        let mut dec = CabacDecoder::new(&data);
        let mut ctx = [ContextModel::new(154)];
        assert_eq!(dec.decode_tr(&mut ctx, 8, 1), 8);

        // Reference: exactly 4 prefix decisions and no bypass suffix bin.
        let mut reference = CabacDecoder::new(&data);
        let mut ref_ctx = ContextModel::new(154);
        for _ in 0..4 {
            assert!(reference.decode_decision(&mut ref_ctx));
        }
        assert_eq!(
            (dec.range, dec.value, dec.bits_left, dec.offset),
            (reference.range, reference.value, reference.bits_left, reference.offset)
        );
        assert_eq!((ctx[0].state, ctx[0].mps), (ref_ctx.state, ref_ctx.mps));
    }
    #[test]
    fn cabac_decode_tr_short_prefix_reads_suffix() {
        // All-zero stream + MPS=0 context: the first decision decodes as 0.
        let data = [0x00u8; 8];
        let mut dec = CabacDecoder::new(&data);
        let mut ctx = [ContextModel { state: 0, mps: 0 }];
        assert_eq!(dec.decode_tr(&mut ctx, 8, 1), 0);

        let mut reference = CabacDecoder::new(&data);
        let mut ref_ctx = ContextModel { state: 0, mps: 0 };
        assert!(!reference.decode_decision(&mut ref_ctx));
        assert!(!reference.decode_bypass());
        assert_eq!(
            (dec.range, dec.value, dec.bits_left, dec.offset),
            (reference.range, reference.value, reference.bits_left, reference.offset)
        );
    }
    #[test]
    fn bytes_remaining_counts_buffered_bytes() {
        let data = [0x00u8; 8];
        let dec = CabacDecoder::new(&data);
        // 9 bits consumed: byte 0 fully, byte 1 partially -> bytes 1..=7 still hold unread bits.
        assert_eq!(dec.bytes_remaining(), 7);
        let empty: [u8; 0] = [];
        assert_eq!(CabacDecoder::new(&empty).bytes_remaining(), 0);
    }
    #[test]
    fn cabac_bypass_long_sequence() {
        let data: Vec<u8> = (0..32).collect();
        let mut dec = CabacDecoder::new(&data);
        let mut count_true = 0u32;
        for _ in 0..100 {
            if dec.decode_bypass() {
                count_true += 1;
            }
        }
        assert!(count_true <= 100);
    }
}