use super::bitstream::{EpByteMap, RbspReader};
use super::tables::{decode_coeff_token, decode_vlc, run_before_table, total_zeros_table_for};
use super::H264Error;
/// Kind of bitstream element an [`EmbeddablePosition`] points into.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EmbedDomain {
/// Sign flag of a trailing-one coefficient (|level| == 1); a set bit means negative.
T1Sign,
/// Second-to-last bit of a `level_suffix`. In `decode_cavlc_block` this is bit 1
/// of `level_code`, so flipping it moves the coefficient magnitude by one.
LevelSuffixMag,
/// Last bit of a `level_suffix`, i.e. the least-significant bit of
/// `level_code`, which selects the coefficient's sign.
LevelSuffixSign,
/// NOTE(review): presumably the least-significant bit of a motion-vector
/// difference; nothing in this file produces it — confirm against its producer.
MvdLsb,
}
/// One bit of the raw (emulation-prevention-escaped) NAL payload that is a
/// candidate for embedding, plus context about the element it belongs to
/// (for residual bits, the coefficient).
#[derive(Debug, Clone)]
pub struct EmbeddablePosition {
/// Byte index into the raw NAL data, translated from the RBSP byte index
/// through `EpByteMap::rbsp_to_raw`.
pub raw_byte_offset: usize,
/// Bit within that byte; `check_ep_conflict` treats 0 as the most significant bit.
pub bit_offset: u8,
/// Which syntax element the bit belongs to.
pub domain: EmbedDomain,
/// Scan-order index (into `CavlcBlock::coeffs`) of the coefficient the bit belongs to.
pub scan_pos: u8,
/// Decoded level of that coefficient.
pub coeff_value: i32,
/// True when flipping the bit is considered unsafe: the flip would create or
/// destroy an emulation-prevention pattern, or (for `LevelSuffixMag`) it could
/// change the `suffix_length` used to decode later coefficients.
pub ep_conflict: bool,
/// Block index; `decode_cavlc_block` leaves it at 0 (presumably set by the caller).
pub block_idx: u32,
/// Frame index; `decode_cavlc_block` leaves it at 0 (presumably set by the caller).
pub frame_idx: u16,
/// Macroblock index; `decode_cavlc_block` leaves it at 0 (presumably set by the caller).
pub mb_idx: u32,
}
/// Result of decoding one CAVLC residual block.
#[derive(Debug, Clone)]
pub struct CavlcBlock {
/// Number of non-zero coefficients (`TotalCoeff` decoded from `coeff_token`).
pub total_coeffs: u8,
/// Number of trailing ±1 coefficients (`TrailingOnes` decoded from `coeff_token`).
pub trailing_ones: u8,
/// Coefficient levels indexed by scan position; entries at or beyond the
/// block's `max_coeffs` are never written and stay 0.
pub coeffs: [i32; 16],
}
pub fn decode_cavlc_block(
reader: &mut RbspReader<'_>,
nc: i8,
ep_map: &EpByteMap,
raw_data: &[u8],
max_coeffs: u8,
) -> Result<(CavlcBlock, Vec<EmbeddablePosition>), H264Error> {
let mut positions = Vec::new();
let (total_coeffs, trailing_ones) = decode_coeff_token(reader, nc)?;
if total_coeffs == 0 {
return Ok((
CavlcBlock {
total_coeffs: 0,
trailing_ones: 0,
coeffs: [0; 16],
},
positions,
));
}
let mut levels = vec![0i32; total_coeffs as usize];
for i in 0..trailing_ones as usize {
let rbsp_byte = reader.byte_pos();
let rbsp_bit = reader.bit_pos();
let sign_bit = reader.read_bit()?;
levels[i] = if sign_bit { -1 } else { 1 };
if rbsp_byte < ep_map.rbsp_to_raw.len() {
let raw_byte = ep_map.rbsp_to_raw[rbsp_byte];
let ep_conflict = check_ep_conflict(raw_data, raw_byte, rbsp_bit);
positions.push(EmbeddablePosition {
raw_byte_offset: raw_byte,
bit_offset: rbsp_bit,
domain: EmbedDomain::T1Sign,
scan_pos: 0, coeff_value: levels[i],
ep_conflict,
block_idx: 0, frame_idx: 0, mb_idx: 0, });
}
}
let mut suffix_length: u8 = if total_coeffs > 10 && trailing_ones < 3 {
1
} else {
0
};
for i in trailing_ones as usize..total_coeffs as usize {
let mut level_prefix = 0u32;
loop {
if reader.read_bit()? {
break;
}
level_prefix += 1;
if level_prefix > 28 {
return Err(H264Error::CavlcError(
"level_prefix overflow (>28)".into(),
));
}
}
let level_suffix_size = if level_prefix == 14 && suffix_length == 0 {
4
} else if level_prefix >= 15 {
(level_prefix - 3) as u8
} else {
suffix_length
};
let level_suffix = if level_suffix_size > 0 {
let suffix_start_byte = reader.byte_pos();
let suffix_start_bit = reader.bit_pos();
let val = reader.read_bits(level_suffix_size)?;
if suffix_length >= 1 && suffix_start_byte < ep_map.rbsp_to_raw.len() {
let _raw_byte = ep_map.rbsp_to_raw[suffix_start_byte];
let sign_bit_byte_pos = suffix_start_byte + (suffix_start_bit as usize + level_suffix_size as usize - 1) / 8;
let sign_bit_bit_pos = (suffix_start_bit + level_suffix_size - 1) % 8;
if sign_bit_byte_pos < ep_map.rbsp_to_raw.len() {
let sign_raw = ep_map.rbsp_to_raw[sign_bit_byte_pos];
positions.push(EmbeddablePosition {
raw_byte_offset: sign_raw,
bit_offset: sign_bit_bit_pos,
domain: EmbedDomain::LevelSuffixSign,
scan_pos: 0, coeff_value: 0, ep_conflict: check_ep_conflict(raw_data, sign_raw, sign_bit_bit_pos),
block_idx: 0,
frame_idx: 0,
mb_idx: 0,
});
}
if suffix_length >= 2 && level_suffix_size >= 2 {
let mag_bit_byte_pos = suffix_start_byte + (suffix_start_bit as usize + level_suffix_size as usize - 2) / 8;
let mag_bit_bit_pos = (suffix_start_bit + level_suffix_size - 2) % 8;
if mag_bit_byte_pos < ep_map.rbsp_to_raw.len() {
let mag_raw = ep_map.rbsp_to_raw[mag_bit_byte_pos];
positions.push(EmbeddablePosition {
raw_byte_offset: mag_raw,
bit_offset: mag_bit_bit_pos,
domain: EmbedDomain::LevelSuffixMag,
scan_pos: 0,
coeff_value: 0,
ep_conflict: check_ep_conflict(raw_data, mag_raw, mag_bit_bit_pos),
block_idx: 0,
frame_idx: 0,
mb_idx: 0,
});
}
}
}
val
} else {
0
};
let mut level_code =
(level_prefix.min(15) << suffix_length) + level_suffix;
if level_prefix >= 15 && suffix_length == 0 {
level_code += 15;
}
if level_prefix >= 16 {
level_code += (1 << (level_prefix - 3)) - 4096;
}
if i == trailing_ones as usize && trailing_ones < 3 {
level_code += 2;
}
let level = if level_code & 1 == 0 {
(level_code as i32 + 2) / 2
} else {
-((level_code as i32 + 1) / 2)
};
levels[i] = level;
let abs_level = level.unsigned_abs();
let active_sl = if suffix_length == 0 { 1 } else { suffix_length };
let thresholds: [u32; 6] = [3, 6, 12, 24, 48, u32::MAX];
let threshold = thresholds[(active_sl as usize - 1).min(5)];
let orig_exceeds = abs_level > threshold;
let would_cross = {
let plus = abs_level + 1;
let minus = abs_level.saturating_sub(1);
(plus > threshold) != orig_exceeds || (minus > threshold) != orig_exceeds
};
for pos in positions.iter_mut().rev() {
if pos.coeff_value == 0
&& (pos.domain == EmbedDomain::LevelSuffixSign
|| pos.domain == EmbedDomain::LevelSuffixMag)
{
pos.coeff_value = level;
if pos.domain == EmbedDomain::LevelSuffixMag && would_cross {
pos.ep_conflict = true;
}
} else {
break;
}
}
if suffix_length == 0 {
suffix_length = 1;
}
if suffix_length < 6 && abs_level > thresholds[suffix_length as usize - 1] {
suffix_length += 1;
}
}
let total_zeros = if total_coeffs < max_coeffs {
let table = total_zeros_table_for(total_coeffs, max_coeffs);
if table.is_empty() {
0
} else {
decode_vlc(reader, table)?
}
} else {
0 };
let mut runs = vec![0u8; total_coeffs as usize];
let mut zeros_left = total_zeros;
for i in 0..total_coeffs as usize - 1 {
if zeros_left == 0 {
break;
}
let table = run_before_table(zeros_left);
runs[i] = decode_vlc(reader, table)?;
zeros_left = zeros_left.saturating_sub(runs[i]);
}
if total_coeffs > 0 {
runs[total_coeffs as usize - 1] = zeros_left;
}
let mut coeffs = [0i32; 16];
let mut scan_positions = vec![0u8; total_coeffs as usize];
let mut coeff_num: i32 = -1;
for i in (0..total_coeffs as usize).rev() {
coeff_num += runs[i] as i32 + 1;
if coeff_num < 0 || coeff_num >= max_coeffs as i32 {
return Err(H264Error::CavlcError(format!(
"invalid run_before: coeff_num={coeff_num}"
)));
}
scan_positions[i] = coeff_num as u8;
coeffs[coeff_num as usize] = levels[i];
}
let mut t1_idx = 0usize;
let mut level_suffix_idx = trailing_ones as usize;
for epos in positions.iter_mut() {
match epos.domain {
EmbedDomain::T1Sign => {
if t1_idx < trailing_ones as usize {
epos.scan_pos = scan_positions[t1_idx];
t1_idx += 1;
}
}
EmbedDomain::LevelSuffixSign | EmbedDomain::LevelSuffixMag => {
if level_suffix_idx < total_coeffs as usize {
epos.scan_pos = scan_positions[level_suffix_idx];
if epos.domain == EmbedDomain::LevelSuffixSign {
level_suffix_idx += 1;
}
}
}
EmbedDomain::MvdLsb => {
}
}
}
Ok((
CavlcBlock {
total_coeffs,
trailing_ones,
coeffs,
},
positions,
))
}
/// Reports whether flipping one bit of the raw (escaped) NAL payload would
/// change its emulation-prevention structure.
///
/// The check works as follows:
/// 1. The bit at `bit_offset` (0 = most significant) of `raw_data[byte_offset]`
///    is hypothetically flipped.
/// 2. Every three-byte window containing that byte is checked for the pattern
///    `00 00 0x` with `x <= 3`.
/// 3. If the flip makes that pattern appear or disappear in any window, the
///    bit is reported as conflicting.
///
/// Out-of-range positions are conservatively reported as conflicting. That
/// covers `byte_offset` past the end of `raw_data` and `bit_offset > 7`.
/// Windows that would extend past the end of `raw_data` are not examined.
pub(crate) fn check_ep_conflict(raw_data: &[u8], byte_offset: usize, bit_offset: u8) -> bool {
    // `bit_offset > 7` used to underflow `7 - bit_offset`: a panic in debug
    // builds, a flip of the wrong bit in release builds.
    if byte_offset >= raw_data.len() || bit_offset > 7 {
        return true;
    }
    let flipped = raw_data[byte_offset] ^ (0x80u8 >> bit_offset);
    let is_ep_pattern = |w: [u8; 3]| w[0] == 0 && w[1] == 0 && w[2] <= 3;

    // Windows start between two bytes before the target byte and the target
    // byte itself, and must fit entirely inside `raw_data`.
    (byte_offset.saturating_sub(2)..=byte_offset)
        .filter(|&start| start + 2 < raw_data.len())
        .any(|start| {
            let original = [raw_data[start], raw_data[start + 1], raw_data[start + 2]];
            let mut modified = original;
            modified[byte_offset - start] = flipped;
            is_ep_pattern(original) != is_ep_pattern(modified)
        })
}
#[cfg(test)]
mod tests {
use super::*;
use crate::codec::h264::bitstream::EpByteMap;
/// Minimal MSB-first bit writer used to hand-assemble CAVLC bitstreams.
struct BitWriter {
data: Vec<u8>,
current: u8,
bit_pos: u8,
}
impl BitWriter {
fn new() -> Self {
Self { data: Vec::new(), current: 0, bit_pos: 0 }
}
/// Appends one bit at the next position (MSB first); full bytes move into `data`.
fn write_bit(&mut self, val: bool) {
if val { self.current |= 1 << (7 - self.bit_pos); }
self.bit_pos += 1;
if self.bit_pos == 8 { self.data.push(self.current); self.current = 0; self.bit_pos = 0; }
}
/// Appends the low `n` bits of `val`, most significant first.
fn write_bits(&mut self, val: u32, n: u8) {
for i in (0..n).rev() { self.write_bit((val >> i) & 1 != 0); }
}
/// Flushes a partially filled byte, padding the remaining bits with zeros.
fn align(&mut self) {
if self.bit_pos > 0 { self.data.push(self.current); self.current = 0; self.bit_pos = 0; }
}
}
/// EP map for data without emulation-prevention bytes: RBSP byte i is raw byte i.
fn identity_ep_map(len: usize) -> EpByteMap {
EpByteMap {
rbsp_to_raw: (0..len).collect(),
}
}
#[test]
fn decode_empty_block() {
// A leading `1` bit as coeff_token (nC = 0); the test expects TotalCoeff = 0.
let data = [0b1000_0000];
let ep_map = identity_ep_map(data.len());
let mut reader = RbspReader::new(&data);
let (block, positions) = decode_cavlc_block(&mut reader, 0, &ep_map, &data, 16).unwrap();
assert_eq!(block.total_coeffs, 0);
assert_eq!(block.trailing_ones, 0);
assert!(positions.is_empty());
}
#[test]
fn decode_single_trailing_one() {
let mut bits = BitWriter::new();
// coeff_token `01`, expected to decode as TotalCoeff = 1, TrailingOnes = 1 (nC = 0).
bits.write_bits(0b01, 2);
// Trailing-one sign flag 0 => +1.
bits.write_bit(false);
// total_zeros `1`; the assertions below expect zero preceding zeros, so the
// single coefficient lands at scan position 0.
bits.write_bit(true);
bits.align();
let ep_map = identity_ep_map(bits.data.len());
let mut reader = RbspReader::new(&bits.data);
let (block, positions) =
decode_cavlc_block(&mut reader, 0, &ep_map, &bits.data, 16).unwrap();
assert_eq!(block.total_coeffs, 1);
assert_eq!(block.trailing_ones, 1);
assert_eq!(block.coeffs[0], 1);
let t1_positions: Vec<_> = positions
.iter()
.filter(|p| p.domain == EmbedDomain::T1Sign)
.collect();
assert_eq!(t1_positions.len(), 1);
assert_eq!(t1_positions[0].coeff_value, 1);
}
#[test]
fn decode_two_trailing_ones_with_signs() {
let mut bits = BitWriter::new();
// coeff_token `001` (expected TotalCoeff = 2, TrailingOnes = 2), then sign
// flags 1 (=> -1) and 0 (=> +1), then total_zeros `111` (expected: 0 zeros).
bits.write_bits(0b001, 3); bits.write_bit(true); bits.write_bit(false); bits.write_bits(0b111, 3);
bits.align();
let ep_map = identity_ep_map(bits.data.len());
let mut reader = RbspReader::new(&bits.data);
let (block, positions) =
decode_cavlc_block(&mut reader, 0, &ep_map, &bits.data, 16).unwrap();
assert_eq!(block.total_coeffs, 2);
assert_eq!(block.trailing_ones, 2);
let t1_positions: Vec<_> = positions
.iter()
.filter(|p| p.domain == EmbedDomain::T1Sign)
.collect();
assert_eq!(t1_positions.len(), 2);
// Positions are reported in decode order: the first sign flag read comes first.
assert_eq!(t1_positions[0].coeff_value, -1);
assert_eq!(t1_positions[1].coeff_value, 1);
}
#[test]
fn ep_conflict_detection_safe() {
// Flipping the MSB of 0x34 gives 0xB4; no `00 00` pair exists before or after.
let data = [0x12, 0x34, 0x56];
assert!(!check_ep_conflict(&data, 1, 0)); }
#[test]
fn ep_conflict_detection_creates_pattern() {
// Bit 5 of 0x04 is the 0x04 bit itself; clearing it yields `00 00 00`.
let data = [0x00, 0x00, 0x04];
assert!(check_ep_conflict(&data, 2, 5)); }
#[test]
fn ep_conflict_detection_destroys_ep_byte() {
// Bit 6 turns 0x03 into 0x01: still `00 00 0x` with x <= 3, so not reported.
// Bit 4 turns 0x03 into 0x0B, which breaks the pattern and is reported.
let data = [0x00, 0x00, 0x03];
assert!(!check_ep_conflict(&data, 2, 6));
assert!(check_ep_conflict(&data, 2, 4)); }
}