/// MSB-first bit cursor over a borrowed byte slice, used by the CAVLC parser below.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    /// Bytes being read; never modified.
    pub(crate) data: &'a [u8],
    /// Index into `data` of the byte holding the next unread bit.
    pub(crate) byte_pos: usize,
    /// Offset (0 = most significant bit) of the next unread bit within `data[byte_pos]`.
    pub(crate) bit_pos: u8,
}
impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the most significant bit of `data[0]`.
    pub fn new(data: &'a [u8]) -> Self {
        Self {
            data,
            byte_pos: 0,
            bit_pos: 0,
        }
    }

    /// Returns the number of unread bits; 0 once the cursor is at or past the end.
    pub fn bits_remaining(&self) -> usize {
        if self.byte_pos >= self.data.len() {
            return 0;
        }
        (self.data.len() - self.byte_pos) * 8 - self.bit_pos as usize
    }

    /// Reads `n` bits MSB-first and returns them right-aligned in a `u32`.
    ///
    /// `n == 0` yields `Some(0)`. Returns `None` without moving the cursor when
    /// `n > 32` (the value would not fit in a `u32`) or when fewer than `n` bits
    /// remain.
    pub fn read_bits(&mut self, n: u8) -> Option<u32> {
        if n == 0 {
            return Some(0);
        }
        if n > 32 || self.bits_remaining() < usize::from(n) {
            return None;
        }
        let mut value = 0u32;
        let mut remaining = n;
        // 1) Drain the rest of a partially consumed byte.
        if self.bit_pos > 0 {
            let avail = 8 - self.bit_pos;
            let take = remaining.min(avail);
            let shift = avail - take;
            // `take` is at most 7 here (bit_pos >= 1), so the shift cannot overflow.
            let mask = (1u8 << take) - 1;
            value = u32::from((self.data[self.byte_pos] >> shift) & mask);
            self.bit_pos += take;
            if self.bit_pos >= 8 {
                self.bit_pos = 0;
                self.byte_pos += 1;
            }
            remaining -= take;
        }
        // 2) The cursor is now byte-aligned (or nothing is left): copy whole bytes.
        while remaining >= 8 {
            value = (value << 8) | u32::from(self.data[self.byte_pos]);
            self.byte_pos += 1;
            remaining -= 8;
        }
        // 3) Fewer than 8 bits left: take them one at a time.
        while remaining > 0 {
            let bit = (self.data[self.byte_pos] >> (7 - self.bit_pos)) & 1;
            value = (value << 1) | u32::from(bit);
            self.bit_pos += 1;
            if self.bit_pos == 8 {
                self.bit_pos = 0;
                self.byte_pos += 1;
            }
            remaining -= 1;
        }
        Some(value)
    }

    /// Returns what `read_bits(n)` would return, without advancing this reader.
    pub fn peek_bits(&self, n: u8) -> Option<u32> {
        let mut probe = BitReader {
            data: self.data,
            byte_pos: self.byte_pos,
            bit_pos: self.bit_pos,
        };
        probe.read_bits(n)
    }

    /// Advances the cursor by `n` bits without reading them.
    ///
    /// No bounds check is made; moving past the end leaves `bits_remaining() == 0`.
    pub fn consume(&mut self, n: u8) {
        let total = self.byte_pos * 8 + self.bit_pos as usize + n as usize;
        self.byte_pos = total / 8;
        self.bit_pos = (total % 8) as u8;
    }

    /// Reads an unsigned Exp-Golomb code `ue(v)`: `z` leading zeros, a 1, then a
    /// `z`-bit suffix, giving `2^z - 1 + suffix`.
    ///
    /// Returns `None` on end of data or if more than 31 leading zeros are seen.
    pub fn read_ue(&mut self) -> Option<u32> {
        let mut leading_zeros = 0u32;
        loop {
            let bit = self.read_bits(1)?;
            if bit == 1 {
                break;
            }
            leading_zeros += 1;
            if leading_zeros > 31 {
                return None;
            }
        }
        if leading_zeros == 0 {
            return Some(0);
        }
        let suffix = self.read_bits(leading_zeros as u8)?;
        Some((1 << leading_zeros) - 1 + suffix)
    }

    /// Reads a signed Exp-Golomb code `se(v)`, mapping code numbers
    /// 0, 1, 2, 3, 4, ... to 0, 1, -1, 2, -2, ...
    pub fn read_se(&mut self) -> Option<i32> {
        let code = self.read_ue()?;
        let value = code.div_ceil(2) as i32;
        if code % 2 == 0 {
            Some(-value)
        } else {
            Some(value)
        }
    }
}
/// Syntax elements recovered from one CAVLC residual block by `decode_cavlc_block`.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct CavlcResult {
    /// Number of non-zero coefficients (TotalCoeff).
    pub total_coeffs: usize,
    /// Number of leading ±1 levels coded with a sign bit only (TrailingOnes).
    pub trailing_ones: usize,
    /// Levels in the order they are coded (trailing ones first); slots past
    /// `total_coeffs` are 0.
    pub levels: [i32; 16],
    /// The decoded `total_zeros` syntax element (0 when it is not coded).
    pub total_zeros: usize,
    /// `runs[i]` is the zero run coded for `levels[i]`; entry `total_coeffs - 1`
    /// receives whatever zeros are left over.
    pub runs: [usize; 16],
}
/// One coeff_token codeword: the low `length` bits of `pattern` (first bit most
/// significant) encode the pair (`total_coeffs`, `trailing_ones`).
struct CoeffTokenEntry {
    pattern: u32,
    length: u8,
    total_coeffs: u8,
    trailing_ones: u8,
}

// Table-row shorthand: `ct!(pattern, length, total_coeffs, trailing_ones)`.
macro_rules! ct {
    ($pat:expr, $len:expr, $tc:expr, $t1:expr) => {
        CoeffTokenEntry {
            pattern: $pat,
            length: $len,
            total_coeffs: $tc,
            trailing_ones: $t1,
        }
    };
}

// coeff_token codewords from H.264 Table 9-5, one table per nC range.
// Only rows with TotalCoeff <= 4 are listed. The full code is prefix-free, so a
// codeword for a larger TotalCoeff matches no row here and the lookup fails
// instead of mis-decoding.

// 0 <= nC < 2
const COEFF_TOKEN_NC_0_1: &[CoeffTokenEntry] = &[
    ct!(0b1, 1, 0, 0),
    ct!(0b000101, 6, 1, 0),
    ct!(0b01, 2, 1, 1),
    ct!(0b00000111, 8, 2, 0),
    ct!(0b000100, 6, 2, 1),
    ct!(0b001, 3, 2, 2),
    ct!(0b000000111, 9, 3, 0),
    ct!(0b00000110, 8, 3, 1),
    ct!(0b0000101, 7, 3, 2),
    ct!(0b00011, 5, 3, 3),
    ct!(0b0000000111, 10, 4, 0),
    ct!(0b000000110, 9, 4, 1),
    ct!(0b00000101, 8, 4, 2),
    ct!(0b000011, 6, 4, 3),
];

// 2 <= nC < 4
const COEFF_TOKEN_NC_2_3: &[CoeffTokenEntry] = &[
    ct!(0b11, 2, 0, 0),
    ct!(0b001011, 6, 1, 0),
    ct!(0b10, 2, 1, 1),
    ct!(0b000111, 6, 2, 0),
    ct!(0b00111, 5, 2, 1),
    ct!(0b011, 3, 2, 2),
    ct!(0b0000111, 7, 3, 0),
    ct!(0b001010, 6, 3, 1),
    ct!(0b001001, 6, 3, 2),
    ct!(0b0101, 4, 3, 3),
    ct!(0b00000111, 8, 4, 0),
    ct!(0b000110, 6, 4, 1),
    ct!(0b000101, 6, 4, 2),
    ct!(0b0100, 4, 4, 3),
];

// 4 <= nC < 8
const COEFF_TOKEN_NC_4_7: &[CoeffTokenEntry] = &[
    ct!(0b1111, 4, 0, 0),
    ct!(0b001111, 6, 1, 0),
    ct!(0b1110, 4, 1, 1),
    ct!(0b001011, 6, 2, 0),
    ct!(0b01111, 5, 2, 1),
    ct!(0b1101, 4, 2, 2),
    ct!(0b001000, 6, 3, 0),
    ct!(0b01100, 5, 3, 1),
    ct!(0b01110, 5, 3, 2),
    ct!(0b1100, 4, 3, 3),
    ct!(0b0001111, 7, 4, 0),
    ct!(0b01010, 5, 4, 1),
    ct!(0b01011, 5, 4, 2),
    ct!(0b1011, 4, 4, 3),
];
/// Decodes coeff_token for `nc >= 8`, which uses a 6-bit fixed-length code.
///
/// The code is laid out as `xxxxyy`: TotalCoeff is `xxxx + 1` and TrailingOnes is
/// `yy`, except `0b000011`, which denotes an empty block. Returns
/// `(total_coeffs, trailing_ones)`, or `None` if fewer than 6 bits remain.
fn coeff_token_nc_8plus(reader: &mut BitReader) -> Option<(u8, u8)> {
    match reader.read_bits(6)? {
        0b00_0011 => Some((0, 0)),
        code => {
            let total_coeffs = (code >> 2) as u8 + 1;
            // `yy` is at most 3; it is clamped to TotalCoeff so the pair stays consistent.
            let trailing_ones = ((code & 0b11) as u8).min(total_coeffs);
            Some((total_coeffs, trailing_ones))
        }
    }
}
/// Picks the coeff_token VLC table for the predicted non-zero count `nc`.
///
/// Negative `nc`, and `nc >= 8` (handled by `coeff_token_nc_8plus`), have no
/// table here and get an empty slice.
fn select_coeff_token_table(nc: i32) -> &'static [CoeffTokenEntry] {
    if (0..2).contains(&nc) {
        COEFF_TOKEN_NC_0_1
    } else if (2..4).contains(&nc) {
        COEFF_TOKEN_NC_2_3
    } else if (4..8).contains(&nc) {
        COEFF_TOKEN_NC_4_7
    } else {
        &[]
    }
}
/// Reads one coeff_token and returns `(total_coeffs, trailing_ones)`.
///
/// `nc >= 8` uses the fixed-length code. Otherwise the table for `nc` is searched
/// by peeking up to its longest codeword (fewer bits near the end of the data).
/// Returns `None` if no table row matches, including when `nc` has no table.
fn read_coeff_token(reader: &mut BitReader, nc: i32) -> Option<(u8, u8)> {
    if nc >= 8 {
        return coeff_token_nc_8plus(reader);
    }
    let table = select_coeff_token_table(nc);
    let max_len = table.iter().map(|e| e.length).max().unwrap_or(1);
    // Clamp in usize before narrowing. `bits_remaining() as u8` would wrap for
    // buffers of 32+ bytes (256 bits -> 0) and make lookups fail or mis-decode.
    let peek_len = reader.bits_remaining().min(usize::from(max_len)) as u8;
    let peeked = reader.peek_bits(peek_len)?;
    let entry = table
        .iter()
        .find(|e| e.length <= peek_len && (peeked >> (peek_len - e.length)) == e.pattern)?;
    reader.consume(entry.length);
    Some((entry.total_coeffs, entry.trailing_ones))
}
/// One VLC codeword: the low `length` bits of `pattern` (first bit most
/// significant) decode to `value`.
struct VlcEntry {
    pattern: u32,
    length: u8,
    value: u8,
}

// Table-row shorthand: `vlc!(pattern, length, value)`.
macro_rules! vlc {
    ($pat:expr, $len:expr, $val:expr) => {
        VlcEntry {
            pattern: $pat,
            length: $len,
            value: $val,
        }
    };
}

// total_zeros codewords for 4x4 blocks, one table per TotalCoeff (tzVlcIndex),
// from H.264 Tables 9-7 and 9-8. Row `i` of each table decodes to total_zeros = i.
const TOTAL_ZEROS_TC1: &[VlcEntry] = &[
    vlc!(0b1, 1, 0),
    vlc!(0b011, 3, 1),
    vlc!(0b010, 3, 2),
    vlc!(0b0011, 4, 3),
    vlc!(0b0010, 4, 4),
    vlc!(0b00011, 5, 5),
    vlc!(0b00010, 5, 6),
    vlc!(0b000011, 6, 7),
    vlc!(0b000010, 6, 8),
    vlc!(0b0000011, 7, 9),
    vlc!(0b0000010, 7, 10),
    vlc!(0b00000011, 8, 11),
    vlc!(0b00000010, 8, 12),
    vlc!(0b000000011, 9, 13),
    vlc!(0b000000010, 9, 14),
    vlc!(0b000000001, 9, 15),
];
const TOTAL_ZEROS_TC2: &[VlcEntry] = &[
    vlc!(0b111, 3, 0),
    vlc!(0b110, 3, 1),
    vlc!(0b101, 3, 2),
    vlc!(0b100, 3, 3),
    vlc!(0b011, 3, 4),
    vlc!(0b0101, 4, 5),
    vlc!(0b0100, 4, 6),
    vlc!(0b0011, 4, 7),
    vlc!(0b0010, 4, 8),
    vlc!(0b00011, 5, 9),
    vlc!(0b00010, 5, 10),
    vlc!(0b000011, 6, 11),
    vlc!(0b000010, 6, 12),
    vlc!(0b000001, 6, 13),
    vlc!(0b000000, 6, 14),
];
const TOTAL_ZEROS_TC3: &[VlcEntry] = &[
    vlc!(0b0101, 4, 0),
    vlc!(0b111, 3, 1),
    vlc!(0b110, 3, 2),
    vlc!(0b101, 3, 3),
    vlc!(0b0100, 4, 4),
    vlc!(0b0011, 4, 5),
    vlc!(0b100, 3, 6),
    vlc!(0b011, 3, 7),
    vlc!(0b0010, 4, 8),
    vlc!(0b00011, 5, 9),
    vlc!(0b00010, 5, 10),
    vlc!(0b000001, 6, 11),
    vlc!(0b00001, 5, 12),
    vlc!(0b000000, 6, 13),
];
const TOTAL_ZEROS_TC4: &[VlcEntry] = &[
    vlc!(0b00011, 5, 0),
    vlc!(0b111, 3, 1),
    vlc!(0b0101, 4, 2),
    vlc!(0b0100, 4, 3),
    vlc!(0b110, 3, 4),
    vlc!(0b101, 3, 5),
    vlc!(0b100, 3, 6),
    vlc!(0b0011, 4, 7),
    vlc!(0b011, 3, 8),
    vlc!(0b0010, 4, 9),
    vlc!(0b00010, 5, 10),
    vlc!(0b00001, 5, 11),
    vlc!(0b00000, 5, 12),
];
const TOTAL_ZEROS_TC5: &[VlcEntry] = &[
    vlc!(0b0101, 4, 0),
    vlc!(0b0100, 4, 1),
    vlc!(0b0011, 4, 2),
    vlc!(0b111, 3, 3),
    vlc!(0b110, 3, 4),
    vlc!(0b101, 3, 5),
    vlc!(0b100, 3, 6),
    vlc!(0b011, 3, 7),
    vlc!(0b0010, 4, 8),
    vlc!(0b00001, 5, 9),
    vlc!(0b0001, 4, 10),
    vlc!(0b00000, 5, 11),
];
const TOTAL_ZEROS_TC6: &[VlcEntry] = &[
    vlc!(0b000001, 6, 0),
    vlc!(0b00001, 5, 1),
    vlc!(0b111, 3, 2),
    vlc!(0b110, 3, 3),
    vlc!(0b101, 3, 4),
    vlc!(0b100, 3, 5),
    vlc!(0b011, 3, 6),
    vlc!(0b010, 3, 7),
    vlc!(0b0001, 4, 8),
    vlc!(0b001, 3, 9),
    vlc!(0b000000, 6, 10),
];
const TOTAL_ZEROS_TC7: &[VlcEntry] = &[
    vlc!(0b000001, 6, 0),
    vlc!(0b00001, 5, 1),
    vlc!(0b101, 3, 2),
    vlc!(0b100, 3, 3),
    vlc!(0b011, 3, 4),
    vlc!(0b11, 2, 5),
    vlc!(0b010, 3, 6),
    vlc!(0b0001, 4, 7),
    vlc!(0b001, 3, 8),
    vlc!(0b000000, 6, 9),
];
const TOTAL_ZEROS_TC8: &[VlcEntry] = &[
    vlc!(0b000001, 6, 0),
    vlc!(0b0001, 4, 1),
    vlc!(0b00001, 5, 2),
    vlc!(0b011, 3, 3),
    vlc!(0b11, 2, 4),
    vlc!(0b10, 2, 5),
    vlc!(0b010, 3, 6),
    vlc!(0b001, 3, 7),
    vlc!(0b000000, 6, 8),
];
const TOTAL_ZEROS_TC9: &[VlcEntry] = &[
    vlc!(0b000001, 6, 0),
    vlc!(0b000000, 6, 1),
    vlc!(0b0001, 4, 2),
    vlc!(0b11, 2, 3),
    vlc!(0b10, 2, 4),
    vlc!(0b001, 3, 5),
    vlc!(0b01, 2, 6),
    vlc!(0b00001, 5, 7),
];
const TOTAL_ZEROS_TC10: &[VlcEntry] = &[
    vlc!(0b00001, 5, 0),
    vlc!(0b00000, 5, 1),
    vlc!(0b001, 3, 2),
    vlc!(0b11, 2, 3),
    vlc!(0b10, 2, 4),
    vlc!(0b01, 2, 5),
    vlc!(0b0001, 4, 6),
];
const TOTAL_ZEROS_TC11: &[VlcEntry] = &[
    vlc!(0b0000, 4, 0),
    vlc!(0b0001, 4, 1),
    vlc!(0b001, 3, 2),
    vlc!(0b010, 3, 3),
    vlc!(0b1, 1, 4),
    vlc!(0b011, 3, 5),
];
const TOTAL_ZEROS_TC12: &[VlcEntry] = &[
    vlc!(0b0000, 4, 0),
    vlc!(0b0001, 4, 1),
    vlc!(0b01, 2, 2),
    vlc!(0b1, 1, 3),
    vlc!(0b001, 3, 4),
];
const TOTAL_ZEROS_TC13: &[VlcEntry] = &[
    vlc!(0b000, 3, 0),
    vlc!(0b001, 3, 1),
    vlc!(0b1, 1, 2),
    vlc!(0b01, 2, 3),
];
const TOTAL_ZEROS_TC14: &[VlcEntry] = &[
    vlc!(0b00, 2, 0),
    vlc!(0b01, 2, 1),
    vlc!(0b1, 1, 2),
];
const TOTAL_ZEROS_TC15: &[VlcEntry] = &[
    vlc!(0b0, 1, 0),
    vlc!(0b1, 1, 1),
];

/// Returns the total_zeros table for a 4x4 block with `total_coeffs` non-zero
/// coefficients, or an empty slice when total_zeros is not coded (0 or 16+).
fn total_zeros_table(total_coeffs: usize) -> &'static [VlcEntry] {
    match total_coeffs {
        1 => TOTAL_ZEROS_TC1,
        2 => TOTAL_ZEROS_TC2,
        3 => TOTAL_ZEROS_TC3,
        4 => TOTAL_ZEROS_TC4,
        5 => TOTAL_ZEROS_TC5,
        6 => TOTAL_ZEROS_TC6,
        7 => TOTAL_ZEROS_TC7,
        8 => TOTAL_ZEROS_TC8,
        9 => TOTAL_ZEROS_TC9,
        10 => TOTAL_ZEROS_TC10,
        11 => TOTAL_ZEROS_TC11,
        12 => TOTAL_ZEROS_TC12,
        13 => TOTAL_ZEROS_TC13,
        14 => TOTAL_ZEROS_TC14,
        15 => TOTAL_ZEROS_TC15,
        _ => &[],
    }
}
/// Decodes one codeword from `table` and returns its `value`.
///
/// Peeks up to the table's longest codeword (fewer bits near the end of the data)
/// and consumes only the matched codeword. Returns `None` when nothing matches.
fn read_vlc(reader: &mut BitReader, table: &[VlcEntry]) -> Option<u8> {
    let max_len = table.iter().map(|e| e.length).max().unwrap_or(1);
    // Clamp in usize before narrowing. `bits_remaining() as u8` would wrap for
    // buffers of 32+ bytes (256 bits -> 0) and make lookups fail or mis-decode.
    let peek_len = reader.bits_remaining().min(usize::from(max_len)) as u8;
    let peeked = reader.peek_bits(peek_len)?;
    let entry = table
        .iter()
        .find(|e| e.length <= peek_len && (peeked >> (peek_len - e.length)) == e.pattern)?;
    reader.consume(entry.length);
    Some(entry.value)
}
/// Reads one `run_before` codeword (H.264 Table 9-10), whose code table depends
/// on `zeros_left`, the number of zeros not yet assigned to a coefficient.
///
/// Returns the run length, or `None` if the data ends or (for `zeros_left > 6`)
/// the escape code runs past the longest listed codeword (run 14).
fn read_run_before(reader: &mut BitReader, zeros_left: usize) -> Option<usize> {
    match zeros_left {
        0 => Some(0),
        // 1 -> 0, 0 -> 1
        1 => Some(if reader.read_bits(1)? == 1 { 0 } else { 1 }),
        // 1 -> 0, 01 -> 1, 00 -> 2
        2 => {
            if reader.read_bits(1)? == 1 {
                Some(0)
            } else {
                Some(if reader.read_bits(1)? == 1 { 1 } else { 2 })
            }
        }
        // 11 -> 0, 10 -> 1, 01 -> 2, 00 -> 3
        3 => Some(3 - reader.read_bits(2)? as usize),
        // 11 -> 0, 10 -> 1, 01 -> 2, 001 -> 3, 000 -> 4
        4 => match reader.read_bits(2)? {
            0b11 => Some(0),
            0b10 => Some(1),
            0b01 => Some(2),
            _ => Some(if reader.read_bits(1)? == 1 { 3 } else { 4 }),
        },
        // 11 -> 0, 10 -> 1, 011 -> 2, 010 -> 3, 001 -> 4, 000 -> 5
        5 => match reader.read_bits(2)? {
            0b11 => Some(0),
            0b10 => Some(1),
            prefix => {
                let code = (prefix << 1) | reader.read_bits(1)?;
                Some(5 - code as usize)
            }
        },
        // 11 -> 0, 000 -> 1, 001 -> 2, 011 -> 3, 010 -> 4, 101 -> 5, 100 -> 6
        6 => match reader.read_bits(2)? {
            0b11 => Some(0),
            prefix => {
                let code = (prefix << 1) | reader.read_bits(1)?;
                match code {
                    0b000 => Some(1),
                    0b001 => Some(2),
                    0b011 => Some(3),
                    0b010 => Some(4),
                    0b101 => Some(5),
                    0b100 => Some(6),
                    // Not reachable: a 2-bit prefix other than 0b11 yields codes 0..=5.
                    _ => None,
                }
            }
        },
        // 111 -> 0, 110 -> 1, ..., 001 -> 6; then 0001 -> 7, 00001 -> 8, ...,
        // 00000000001 -> 14.
        _ => {
            let code = reader.read_bits(3)?;
            if code != 0 {
                return Some(7 - code as usize);
            }
            let mut run = 7;
            while reader.read_bits(1)? == 0 {
                run += 1;
                if run > 14 {
                    return None;
                }
            }
            Some(run)
        }
    }
}
/// Decodes one CAVLC-coded residual block (H.264 `residual_block_cavlc`,
/// §7.3.5.3.2 and §9.2) from `reader`.
///
/// `nc` is the predicted non-zero count used to choose the coeff_token table.
///
/// On success:
/// * `levels[..total_coeffs]` holds the levels in coded order (trailing ones first).
/// * `runs[i]` is the run of zeros coded for `levels[i]`.
/// * The final coefficient receives the zeros still left.
///
/// Returns `None` if the data runs out, a codeword is not recognised, or the
/// decoded values are inconsistent (e.g. a run longer than the zeros remaining).
pub fn decode_cavlc_block(reader: &mut BitReader, nc: i32) -> Option<CavlcResult> {
    // Blocks handled here hold at most 16 coefficients; total_zeros is coded only below that.
    const MAX_COEFFS: usize = 16;

    let (tc, t1) = read_coeff_token(reader, nc)?;
    let total_coeffs = usize::from(tc);
    let trailing_ones = usize::from(t1);
    if total_coeffs > MAX_COEFFS || trailing_ones > total_coeffs.min(3) {
        return None;
    }

    let mut levels = [0i32; 16];
    let mut runs = [0usize; 16];
    if total_coeffs == 0 {
        return Some(CavlcResult {
            total_coeffs: 0,
            trailing_ones: 0,
            levels,
            total_zeros: 0,
            runs,
        });
    }

    // Trailing ones carry only a sign bit: 0 => +1, 1 => -1.
    for level in &mut levels[..trailing_ones] {
        *level = if reader.read_bits(1)? == 0 { 1 } else { -1 };
    }

    // Remaining levels: unary level_prefix, then level_suffix, with an adaptive suffixLength.
    let mut suffix_length: u32 = if total_coeffs > 10 && trailing_ones < 3 { 1 } else { 0 };
    for i in trailing_ones..total_coeffs {
        let mut level_prefix = 0u32;
        while reader.read_bits(1)? == 0 {
            level_prefix += 1;
            if level_prefix > 20 {
                return None;
            }
        }

        // Escape codes widen the suffix:
        // * prefix 14 with suffixLength 0 reads 4 suffix bits;
        // * prefix >= 15 reads (prefix - 3) suffix bits.
        let level_suffix_size = if level_prefix == 14 && suffix_length == 0 {
            4
        } else if level_prefix >= 15 {
            level_prefix - 3
        } else {
            suffix_length
        };
        // The prefix contributes min(prefix, 15) << suffixLength; it is shifted by
        // suffixLength, not by the suffix size read above.
        let mut level_code = (level_prefix.min(15) << suffix_length) as i32;
        if level_suffix_size > 0 {
            level_code += reader.read_bits(level_suffix_size as u8)? as i32;
        }
        if level_prefix >= 15 && suffix_length == 0 {
            level_code += 15;
        }
        if level_prefix >= 16 {
            level_code += (1 << (level_prefix - 3)) - 4096;
        }
        // The spec offsets the first non-trailing-one level by 2 when fewer than
        // 3 trailing ones were coded.
        if i == trailing_ones && trailing_ones < 3 {
            level_code += 2;
        }
        let level = if level_code % 2 == 0 {
            (level_code + 2) >> 1
        } else {
            (-level_code - 1) >> 1
        };
        levels[i] = level;

        if suffix_length == 0 {
            suffix_length = 1;
        }
        if suffix_length < 6 && level.unsigned_abs() > (3 << (suffix_length - 1)) {
            suffix_length += 1;
        }
    }

    let total_zeros = if total_coeffs < MAX_COEFFS {
        let table = total_zeros_table(total_coeffs);
        // total_zeros is only defined through these VLC tables; without one it
        // cannot be parsed.
        if table.is_empty() {
            return None;
        }
        usize::from(read_vlc(reader, table)?)
    } else {
        0
    };
    if total_coeffs + total_zeros > MAX_COEFFS {
        return None;
    }

    let mut zeros_left = total_zeros;
    for run in &mut runs[..total_coeffs - 1] {
        if zeros_left == 0 {
            break;
        }
        let run_before = read_run_before(reader, zeros_left)?;
        if run_before > zeros_left {
            return None;
        }
        *run = run_before;
        zeros_left -= run_before;
    }
    runs[total_coeffs - 1] = zeros_left;

    Some(CavlcResult {
        total_coeffs,
        trailing_ones,
        levels,
        total_zeros,
        runs,
    })
}
/// Allocates a zero-filled vector of `block_size` coefficients and fills it via
/// `expand_cavlc_to_coefficients_into`.
pub fn expand_cavlc_to_coefficients(result: &CavlcResult, block_size: usize) -> Vec<i32> {
    let mut out = Vec::with_capacity(block_size);
    out.resize(block_size, 0);
    expand_cavlc_to_coefficients_into(result, &mut out);
    out
}
/// Places the levels of a decoded CAVLC block into `coeffs`, indexed by scan position.
///
/// This follows H.264 §7.3.5.3.2. Levels are walked from the last decoded one
/// (`levels[total_coeffs - 1]`) back to `levels[0]`. Each level `i` is placed
/// `runs[i]` zeros after the previously placed one, starting at index 0.
///
/// Only the non-zero positions are written; every other entry of `coeffs` is left
/// untouched, so pass a zeroed slice. Positions beyond `coeffs.len()` (corrupt
/// runs) are ignored.
pub fn expand_cavlc_to_coefficients_into(result: &CavlcResult, coeffs: &mut [i32]) {
    let count = result
        .total_coeffs
        .min(result.levels.len())
        .min(result.runs.len());
    // `next` is the first scan position not yet passed.
    let mut next = 0usize;
    for i in (0..count).rev() {
        let pos = next.saturating_add(result.runs[i]);
        let Some(slot) = coeffs.get_mut(pos) else {
            break;
        };
        *slot = result.levels[i];
        next = pos + 1;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bit_reader_basic() {
        // 0xB4 = 0b1011_0100, read one bit at a time, MSB first.
        let data = [0xB4];
        let mut r = BitReader::new(&data);
        for expected in [1, 0, 1, 1, 0, 1, 0, 0] {
            assert_eq!(r.read_bits(1), Some(expected));
        }
        assert_eq!(r.read_bits(1), None);
    }

    #[test]
    fn test_bit_reader_exp_golomb_unsigned() {
        // Codewords 1 | 010 | 011 | 00100 | 00101 => 0, 1, 2, 3, 4.
        let data = [0xA6, 0x42, 0x80];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_ue(), Some(0));
        assert_eq!(r.read_ue(), Some(1));
        assert_eq!(r.read_ue(), Some(2));
        assert_eq!(r.read_ue(), Some(3));
        assert_eq!(r.read_ue(), Some(4));
    }

    #[test]
    fn test_bit_reader_exp_golomb_signed() {
        // Same code numbers 0..=4 mapped through se(v).
        let data = [0xA6, 0x42, 0x80];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_se(), Some(0));
        assert_eq!(r.read_se(), Some(1));
        assert_eq!(r.read_se(), Some(-1));
        assert_eq!(r.read_se(), Some(2));
        assert_eq!(r.read_se(), Some(-2));
    }

    #[test]
    fn test_bit_reader_multi_byte() {
        let data = [0xFF, 0xF0];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_bits(12), Some(0xFFF));
        assert_eq!(r.read_bits(4), Some(0x0));
    }

    #[test]
    fn test_cavlc_all_zeros() {
        // nC = 0: coeff_token "1" => TotalCoeff 0.
        let data = [0x80];
        let mut r = BitReader::new(&data);
        let result = decode_cavlc_block(&mut r, 0).unwrap();
        assert_eq!(result.total_coeffs, 0);
        assert_eq!(result.trailing_ones, 0);
        assert_eq!(result.levels, [0; 16]);
        assert_eq!(result.total_zeros, 0);
        assert_eq!(result.runs, [0; 16]);
    }

    #[test]
    fn test_expand_coefficients() {
        // Placed from the last decoded level upwards:
        // runs[2] = 1 zero, then 2; runs[1] = 0, then -3; runs[0] = 1 zero, then 5.
        let result = CavlcResult {
            total_coeffs: 3,
            trailing_ones: 0,
            levels: {
                let mut l = [0i32; 16];
                l[0] = 5;
                l[1] = -3;
                l[2] = 2;
                l
            },
            total_zeros: 2,
            runs: {
                let mut r = [0usize; 16];
                r[0] = 1;
                r[1] = 0;
                r[2] = 1;
                r
            },
        };
        let coeffs = expand_cavlc_to_coefficients(&result, 8);
        assert_eq!(coeffs, [0, 2, -3, 0, 5, 0, 0, 0]);
    }

    #[test]
    fn test_expand_single_dc_coefficient() {
        let mut levels = [0i32; 16];
        levels[0] = 7;
        let result = CavlcResult {
            total_coeffs: 1,
            trailing_ones: 0,
            levels,
            total_zeros: 0,
            runs: [0; 16],
        };
        let coeffs = expand_cavlc_to_coefficients(&result, 16);
        assert_eq!(coeffs[0], 7);
        assert!(coeffs[1..].iter().all(|&c| c == 0));
    }

    #[test]
    fn test_expand_out_of_range_is_ignored() {
        let mut levels = [0i32; 16];
        levels[0] = 3;
        levels[1] = 4;
        let mut runs = [0usize; 16];
        runs[0] = 5;
        let result = CavlcResult {
            total_coeffs: 2,
            trailing_ones: 0,
            levels,
            total_zeros: 5,
            runs,
        };
        assert_eq!(expand_cavlc_to_coefficients(&result, 4), vec![4, 0, 0, 0]);
    }

    #[test]
    fn test_bits_remaining() {
        let data = [0xAB, 0xCD];
        let mut r = BitReader::new(&data);
        assert_eq!(r.bits_remaining(), 16);
        let _ = r.read_bits(5);
        assert_eq!(r.bits_remaining(), 11);
        let _ = r.read_bits(11);
        assert_eq!(r.bits_remaining(), 0);
    }

    #[test]
    fn test_expand_empty_block() {
        let result = CavlcResult {
            total_coeffs: 0,
            trailing_ones: 0,
            levels: [0; 16],
            total_zeros: 0,
            runs: [0; 16],
        };
        let coeffs = expand_cavlc_to_coefficients(&result, 16);
        assert_eq!(coeffs, vec![0; 16]);
    }

    #[test]
    fn test_cavlc_nc2_all_zeros() {
        // nC = 2: coeff_token "11" => TotalCoeff 0.
        let data = [0xC0];
        let mut r = BitReader::new(&data);
        let result = decode_cavlc_block(&mut r, 2).unwrap();
        assert_eq!(result.total_coeffs, 0);
        assert_eq!(result.trailing_ones, 0);
    }
}