use crate::error::CodecError;
/// A read cursor over a borrowed byte slice that yields bits most-significant-bit first.
///
/// The reader never modifies the underlying bytes; it only advances `pos`. Cloning a reader
/// yields an independent cursor over the same bytes, which is handy for look-ahead parsing.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    /// Bytes being read.
    bytes: &'a [u8],
    /// Absolute offset, in bits, of the next bit to read (bit 0 is the MSB of `bytes[0]`).
    pos: usize,
}
impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the first (most significant) bit of `bytes`.
    pub fn new(bytes: &'a [u8]) -> Self {
        Self { bytes, pos: 0 }
    }

    /// Returns the number of bits consumed so far.
    #[inline]
    pub fn bits_read(&self) -> usize {
        self.pos
    }

    /// Returns the number of bits that can still be read.
    #[inline]
    pub fn bits_remaining(&self) -> usize {
        // Saturating arithmetic keeps this well-defined even for absurdly large slices.
        self.bytes.len().saturating_mul(8).saturating_sub(self.pos)
    }

    /// Reads one bit and returns it as `0` or `1`.
    ///
    /// Bits are taken most-significant-first within each byte.
    ///
    /// # Errors
    ///
    /// [`CodecError::EndOfStream`] if no bits remain; the cursor is not moved.
    pub fn read_bit(&mut self) -> Result<u8, CodecError> {
        if self.bits_remaining() == 0 {
            return Err(CodecError::EndOfStream {
                needed: 1,
                remaining: 0,
            });
        }
        let byte = self.bytes[self.pos / 8];
        let shift = 7 - (self.pos % 8);
        self.pos += 1;
        Ok((byte >> shift) & 1)
    }

    /// Reads `n` bits (`0..=32`) and returns them as a big-endian unsigned integer.
    ///
    /// `read_bits(0)` always succeeds with `0`, even on an exhausted reader.
    ///
    /// # Errors
    ///
    /// - [`CodecError::GolombOverflow`] if `n > 32` (the result would not fit in a `u32`).
    /// - [`CodecError::EndOfStream`] if fewer than `n` bits remain. Availability is checked
    ///   before anything is consumed, so the cursor is untouched on error.
    pub fn read_bits(&mut self, n: u8) -> Result<u32, CodecError> {
        if n == 0 {
            return Ok(0);
        }
        if n > 32 {
            return Err(CodecError::GolombOverflow);
        }
        let needed = usize::from(n);
        let remaining = self.bits_remaining();
        if remaining < needed {
            return Err(CodecError::EndOfStream { needed, remaining });
        }
        let mut value: u32 = 0;
        for _ in 0..n {
            // Cannot fail: all `n` bits were shown to be available above.
            value = (value << 1) | u32::from(self.read_bit()?);
        }
        Ok(value)
    }

    /// Advances the cursor by `n` bits without decoding them.
    ///
    /// # Errors
    ///
    /// [`CodecError::EndOfStream`] if fewer than `n` bits remain; the cursor is not moved.
    pub fn skip_bits(&mut self, n: usize) -> Result<(), CodecError> {
        let remaining = self.bits_remaining();
        if remaining < n {
            return Err(CodecError::EndOfStream {
                needed: n,
                remaining,
            });
        }
        self.pos += n;
        Ok(())
    }

    /// Reads an unsigned Exp-Golomb code (`ue(v)`).
    ///
    /// The encoding is `N` zero bits, a one bit, then an `N`-bit suffix. The decoded value
    /// is `2^N - 1 + suffix`.
    ///
    /// # Errors
    ///
    /// - [`CodecError::GolombOverflow`] if the prefix has more than 32 zeros, or if the
    ///   decoded value exceeds `u32::MAX`.
    /// - [`CodecError::EndOfStream`] if the stream ends mid-code. Prefix bits already
    ///   examined stay consumed in that case.
    pub fn read_ue_v(&mut self) -> Result<u32, CodecError> {
        let mut leading_zeros: u32 = 0;
        while self.read_bit()? == 0 {
            leading_zeros += 1;
            // With 33+ zeros, even a zero suffix gives a value of at least 2^33 - 1.
            if leading_zeros > 32 {
                return Err(CodecError::GolombOverflow);
            }
        }
        if leading_zeros == 0 {
            return Ok(0);
        }
        // `leading_zeros <= 32` here, so the narrowing cast is lossless.
        let suffix = self.read_bits(leading_zeros as u8)?;
        // Computed in u64 because 2^32 - 1 + suffix can exceed u32 when leading_zeros == 32.
        let total = ((1u64 << leading_zeros) - 1) + u64::from(suffix);
        if total > u64::from(u32::MAX) {
            return Err(CodecError::GolombOverflow);
        }
        Ok(total as u32)
    }

    /// Reads a signed Exp-Golomb code (`se(v)`).
    ///
    /// The unsigned code `k` maps to `0, 1, -1, 2, -2, ...`. Odd `k` gives `+(k + 1) / 2`;
    /// even `k` gives `-(k / 2)`.
    ///
    /// # Errors
    ///
    /// - Any error from [`Self::read_ue_v`].
    /// - [`CodecError::GolombOverflow`] when `k == u32::MAX`. That code maps to `+2^31`,
    ///   which is not representable as an `i32`.
    pub fn read_se_v(&mut self) -> Result<i32, CodecError> {
        let code = self.read_ue_v()?;
        // ceil(code / 2) without overflowing u32. The signed value is computed in i64
        // because the largest magnitude, 2^31, does not fit in i32.
        let magnitude = i64::from(code / 2 + code % 2);
        let value = if code & 1 == 1 { magnitude } else { -magnitude };
        // The most negative reachable value is -(2^31 - 1), so only the upper bound can fail.
        if value > i64::from(i32::MAX) {
            return Err(CodecError::GolombOverflow);
        }
        Ok(value as i32)
    }
}
/// Converts an escaped payload (EBSP) back into raw bytes (RBSP) by removing
/// emulation-prevention bytes.
///
/// Every `0x03` that directly follows two `0x00` bytes is dropped; the two zeros are kept.
/// Scanning resumes right after a removed `0x03`, so those zeros are never reused to match
/// another pattern. The input is not otherwise validated.
pub fn rbsp_from_ebsp(ebsp: &[u8]) -> Vec<u8> {
    let mut rbsp = Vec::with_capacity(ebsp.len());
    let mut rest = ebsp;
    loop {
        rest = match rest {
            [] => break,
            [0x00, 0x00, 0x03, tail @ ..] => {
                rbsp.extend_from_slice(&[0x00, 0x00]);
                tail
            }
            [byte, tail @ ..] => {
                rbsp.push(*byte);
                tail
            }
        };
    }
    rbsp
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn read_single_bits() {
        let mut r = BitReader::new(&[0b1010_1100]);
        assert_eq!(r.read_bit().unwrap(), 1);
        assert_eq!(r.read_bit().unwrap(), 0);
        assert_eq!(r.read_bit().unwrap(), 1);
        assert_eq!(r.read_bit().unwrap(), 0);
        assert_eq!(r.read_bit().unwrap(), 1);
        assert_eq!(r.read_bit().unwrap(), 1);
        assert_eq!(r.read_bit().unwrap(), 0);
        assert_eq!(r.read_bit().unwrap(), 0);
        assert!(matches!(r.read_bit(), Err(CodecError::EndOfStream { .. })));
    }

    #[test]
    fn read_multi_bits() {
        let mut r = BitReader::new(&[0xAC, 0x0F]);
        assert_eq!(r.read_bits(4).unwrap(), 0b1010);
        assert_eq!(r.read_bits(4).unwrap(), 0b1100);
        assert_eq!(r.read_bits(8).unwrap(), 0x0F);
    }

    #[test]
    fn read_bits_spanning_byte_boundary() {
        let mut r = BitReader::new(&[0xAB, 0xCD]);
        assert_eq!(r.read_bits(12).unwrap(), 0xABC);
        assert_eq!(r.read_bits(4).unwrap(), 0xD);
    }

    #[test]
    fn read_bits_zero_width_and_too_wide() {
        // A zero-width read succeeds even on an empty reader and consumes nothing.
        let mut r = BitReader::new(&[]);
        assert_eq!(r.read_bits(0).unwrap(), 0);
        assert_eq!(r.bits_read(), 0);

        let mut r = BitReader::new(&[0xFF; 8]);
        assert!(matches!(r.read_bits(33), Err(CodecError::GolombOverflow)));
        assert_eq!(r.read_bits(32).unwrap(), u32::MAX);
    }

    #[test]
    fn end_of_stream_leaves_cursor_untouched() {
        let mut r = BitReader::new(&[0xFF]);
        assert_eq!(r.read_bits(3).unwrap(), 0b111);
        assert!(matches!(
            r.read_bits(8),
            Err(CodecError::EndOfStream {
                needed: 8,
                remaining: 5
            })
        ));
        assert_eq!(r.bits_read(), 3);
        assert_eq!(r.bits_remaining(), 5);
    }

    #[test]
    fn skip_bits_advances_and_checks_bounds() {
        let mut r = BitReader::new(&[0xAC]);
        r.skip_bits(4).unwrap();
        assert_eq!(r.read_bits(4).unwrap(), 0b1100);
        assert!(matches!(
            r.skip_bits(1),
            Err(CodecError::EndOfStream {
                needed: 1,
                remaining: 0
            })
        ));
        // Skipping nothing is always allowed.
        r.skip_bits(0).unwrap();
    }

    #[test]
    fn ue_v_decodes_known_values() {
        let mut r = BitReader::new(&[0xA6, 0x40]);
        assert_eq!(r.read_ue_v().unwrap(), 0);
        assert_eq!(r.read_ue_v().unwrap(), 1);
        assert_eq!(r.read_ue_v().unwrap(), 2);
        assert_eq!(r.read_ue_v().unwrap(), 3);
    }

    #[test]
    fn ue_v_accepts_u32_max_and_rejects_beyond() {
        // 32 zeros, a one bit, then a 32-bit zero suffix decodes to 2^32 - 1.
        let max = [0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x00];
        assert_eq!(BitReader::new(&max).read_ue_v().unwrap(), u32::MAX);
        // The same prefix with suffix 1 would be 2^32, which does not fit in a u32.
        let too_big = [0x00, 0x00, 0x00, 0x00, 0x80, 0x00, 0x00, 0x00, 0x80];
        assert!(matches!(
            BitReader::new(&too_big).read_ue_v(),
            Err(CodecError::GolombOverflow)
        ));
    }

    #[test]
    fn se_v_mapping() {
        let mut r = BitReader::new(&[0xA6, 0x42, 0x80]);
        assert_eq!(r.read_se_v().unwrap(), 0);
        assert_eq!(r.read_se_v().unwrap(), 1);
        assert_eq!(r.read_se_v().unwrap(), -1);
        assert_eq!(r.read_se_v().unwrap(), 2);
        assert_eq!(r.read_se_v().unwrap(), -2);
    }

    #[test]
    fn rbsp_strips_emulation_byte() {
        assert_eq!(rbsp_from_ebsp(&[0x00, 0x00, 0x03, 0x01]), vec![0x00, 0x00, 0x01]);
        assert_eq!(rbsp_from_ebsp(&[0x01, 0x02, 0x03]), vec![0x01, 0x02, 0x03]);
        assert_eq!(
            rbsp_from_ebsp(&[0x00, 0x00, 0x03, 0x00, 0x00, 0x03, 0xFF]),
            vec![0x00, 0x00, 0x00, 0x00, 0xFF]
        );
    }

    #[test]
    fn ue_v_overflow_guard() {
        let bytes = [0u8; 8];
        let mut r = BitReader::new(&bytes);
        assert!(matches!(r.read_ue_v(), Err(CodecError::GolombOverflow)));
    }
}