use crate::CodecError;
/// Reads individual bits from a borrowed byte slice, most significant bit first.
pub struct BitReader<'a> {
    /// Bytes being decoded.
    data: &'a [u8],
    /// Index of the byte that contains the next unread bit.
    byte_pos: usize,
    /// Offset of the next unread bit inside `data[byte_pos]`; 0 is the MSB, 7 the LSB.
    bit_pos: u8,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned on the first bit of `data`.
    #[must_use]
    pub fn new(data: &'a [u8]) -> Self {
        Self {
            data,
            byte_pos: 0,
            bit_pos: 0,
        }
    }

    /// Builds the error reported whenever a read runs past the end of the data.
    fn end_of_stream() -> CodecError {
        CodecError::InvalidData("Unexpected end of bitstream".into())
    }

    /// Reads the next bit and advances the cursor by one bit.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidData` when no unread bit is left.
    pub fn read_bit(&mut self) -> Result<bool, CodecError> {
        let byte = *self
            .data
            .get(self.byte_pos)
            .ok_or_else(Self::end_of_stream)?;
        let bit = (byte >> (7 - self.bit_pos)) & 1;

        self.bit_pos += 1;
        if self.bit_pos == 8 {
            self.bit_pos = 0;
            self.byte_pos += 1;
        }
        Ok(bit != 0)
    }

    /// Reads `n` bits (1..=32) as an unsigned value; the first bit read becomes the
    /// most significant bit of the result.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidData` when `n` is 0 or greater than 32, or when the
    /// data ends before `n` bits were read. In the latter case the bits that were
    /// available have already been consumed.
    pub fn read_bits(&mut self, n: u8) -> Result<u32, CodecError> {
        if n == 0 || n > 32 {
            return Err(CodecError::InvalidData(format!("Invalid bit count: {n}")));
        }
        let mut value = 0u32;
        for _ in 0..n {
            value = (value << 1) | u32::from(self.read_bit()?);
        }
        Ok(value)
    }

    /// Reads `n` bits (1..=32) and interprets them as a two's-complement signed value.
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`BitReader::read_bits`].
    pub fn read_signed_bits(&mut self, n: u8) -> Result<i32, CodecError> {
        let value = self.read_bits(n)?;
        // `read_bits` guarantees 1 <= n <= 32, so `shift` is in 0..=31 and neither
        // shift below can overflow (a `1 << n` mask would for n == 32).
        let shift = 32 - u32::from(n);
        // Move the field's sign bit into bit 31, then let the arithmetic right shift
        // replicate it; the `as i32` cast is a deliberate bit reinterpretation.
        Ok(((value << shift) as i32) >> shift)
    }

    /// Returns the next `n` bits (1..=32) without advancing the cursor.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidData` when `n` is 0 or greater than 32, or when
    /// fewer than `n` bits remain.
    pub fn peek_bits(&self, n: u8) -> Result<u32, CodecError> {
        // Read through a throw-away copy of the cursor so `self` stays untouched.
        let mut probe = BitReader {
            data: self.data,
            byte_pos: self.byte_pos,
            bit_pos: self.bit_pos,
        };
        probe.read_bits(n)
    }

    /// Advances the cursor by `n` bits.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidData` when fewer than `n` bits remain; the cursor
    /// is then left at the end of the data.
    pub fn skip_bits(&mut self, n: usize) -> Result<(), CodecError> {
        if n > self.bits_remaining() {
            // Same final state as consuming every remaining bit and then failing.
            self.byte_pos = self.data.len();
            self.bit_pos = 0;
            return Err(Self::end_of_stream());
        }
        let target = usize::from(self.bit_pos) + n;
        self.byte_pos += target / 8;
        // `target % 8` is always below 8, so the narrowing cast is lossless.
        self.bit_pos = (target % 8) as u8;
        Ok(())
    }

    /// Moves the cursor forward to the next byte boundary; does nothing when it is
    /// already on one.
    pub fn byte_align(&mut self) {
        if self.bit_pos != 0 {
            self.bit_pos = 0;
            self.byte_pos += 1;
        }
    }

    /// Returns the number of bits consumed so far.
    #[must_use]
    pub fn bit_position(&self) -> usize {
        self.byte_pos * 8 + usize::from(self.bit_pos)
    }

    /// Returns the number of bits that can still be read.
    #[must_use]
    pub fn bits_remaining(&self) -> usize {
        let unread_bytes = self.data.len().saturating_sub(self.byte_pos);
        (unread_bytes * 8).saturating_sub(usize::from(self.bit_pos))
    }

    /// Returns `true` once every byte has been fully consumed.
    #[must_use]
    pub fn is_eof(&self) -> bool {
        self.byte_pos >= self.data.len()
    }

    /// Byte-aligns the cursor, then scans forward for two `0x00` bytes followed by a
    /// byte whose most significant bit is set.
    ///
    /// On a match the cursor is left on the first `0x00` byte and its byte index is
    /// returned. Without a match `None` is returned and the cursor stays wherever the
    /// scan stopped (at most two bytes before the end of the data).
    pub fn find_start_code(&mut self) -> Option<usize> {
        self.byte_align();
        while self.byte_pos + 2 < self.data.len() {
            let window = &self.data[self.byte_pos..self.byte_pos + 3];
            if window[0] == 0x00 && window[1] == 0x00 && window[2] & 0x80 != 0 {
                return Some(self.byte_pos);
            }
            self.byte_pos += 1;
        }
        None
    }

    /// Reads a variable-length code and returns `(code, length_in_bits)`.
    ///
    /// NOTE(review): no code table is consulted here, so every call consumes exactly
    /// one bit and reports a length of 1 regardless of `max_bits`. This looks like a
    /// placeholder; confirm against the callers before relying on it.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidData` when `max_bits` is 0 or when no bit is left
    /// to read.
    pub fn read_vlc(&mut self, max_bits: u8) -> Result<(u32, u8), CodecError> {
        if max_bits == 0 {
            return Err(CodecError::InvalidData("VLC code too long".into()));
        }
        let code = self.read_bits(1)?;
        Ok((code, 1))
    }
}
/// Appends individual bits to a growable byte buffer, most significant bit first.
pub struct BitWriter {
    /// Bytes written so far; the last byte is only partially filled while
    /// `bit_pos != 0`, with its unused low bits set to zero.
    data: Vec<u8>,
    /// Number of bits already used in the last byte (0 means the next bit starts a
    /// new byte).
    bit_pos: u8,
}

impl BitWriter {
    /// Creates an empty writer.
    #[must_use]
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty writer whose buffer can hold `capacity` bytes before it has
    /// to reallocate.
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            data: Vec::with_capacity(capacity),
            bit_pos: 0,
        }
    }

    /// Appends a single bit.
    pub fn write_bit(&mut self, bit: bool) {
        if self.bit_pos == 0 {
            // Starting a new byte; it is pre-zeroed so only set bits need writing.
            self.data.push(0);
        }
        if bit {
            if let Some(last) = self.data.last_mut() {
                *last |= 0x80u8 >> self.bit_pos;
            }
        }
        self.bit_pos = (self.bit_pos + 1) % 8;
    }

    /// Appends the `n` least significant bits of `value`, most significant first.
    ///
    /// Nothing is written when `n` is 0 or greater than 32.
    pub fn write_bits(&mut self, value: u32, n: u8) {
        if n == 0 || n > 32 {
            return;
        }
        for shift in (0..n).rev() {
            self.write_bit((value >> shift) & 1 != 0);
        }
    }

    /// Appends `value` as an `n`-bit two's-complement field (its low `n` bits).
    ///
    /// Nothing is written when `n` is 0 or greater than 32, matching
    /// [`BitWriter::write_bits`].
    pub fn write_signed_bits(&mut self, value: i32, n: u8) {
        if n == 0 || n > 32 {
            return;
        }
        // `u32::MAX >> (32 - n)` keeps the low n bits and stays valid for n == 32,
        // where `(1 << n) - 1` would overflow the shift.
        let mask = u32::MAX >> (32 - u32::from(n));
        // The `as u32` cast deliberately reinterprets the two's-complement bits.
        self.write_bits((value as u32) & mask, n);
    }

    /// Pads the current byte with zero bits up to the next byte boundary.
    pub fn byte_align(&mut self) {
        while self.bit_pos != 0 {
            self.write_bit(false);
        }
    }

    /// Returns the bytes written so far, including a partially filled last byte.
    #[must_use]
    pub fn data(&self) -> &[u8] {
        &self.data
    }

    /// Consumes the writer and returns its buffer, zero-padded to a whole byte.
    #[must_use]
    pub fn into_vec(mut self) -> Vec<u8> {
        self.byte_align();
        self.data
    }

    /// Returns the number of bits written so far.
    #[must_use]
    pub fn bit_position(&self) -> usize {
        if self.bit_pos == 0 {
            self.data.len() * 8
        } else {
            // The last byte is only partially filled; count just its used bits.
            (self.data.len() - 1) * 8 + usize::from(self.bit_pos)
        }
    }

    /// Appends a variable-length code: the low `bits` bits of `code`.
    pub fn write_vlc(&mut self, code: u32, bits: u8) {
        self.write_bits(code, bits);
    }

    /// Pads with zero bits up to the next byte boundary.
    pub fn write_stuffing(&mut self) {
        self.byte_align();
    }

    /// Byte-aligns the output, then appends the bytes `0x00 0x00 code`.
    pub fn write_start_code(&mut self, code: u8) {
        self.byte_align();
        self.data.extend_from_slice(&[0x00, 0x00, code]);
    }
}

impl Default for BitWriter {
    /// Equivalent to [`BitWriter::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Namespace for Exp-Golomb coding helpers built on `BitReader` and `BitWriter`.
pub struct ExpGolomb;

impl ExpGolomb {
    /// Reads an unsigned Exp-Golomb value: a run of zero bits, a one bit, then as
    /// many suffix bits as there were zeros.
    ///
    /// # Errors
    ///
    /// Returns `CodecError::InvalidData` when more than 31 leading zero bits are
    /// found, and propagates the reader's error when the data ends early.
    pub fn read_ue(reader: &mut BitReader<'_>) -> Result<u32, CodecError> {
        let mut prefix_len: u8 = 0;
        while !reader.read_bit()? {
            prefix_len += 1;
            if prefix_len > 31 {
                return Err(CodecError::InvalidData("Invalid Exp-Golomb code".into()));
            }
        }
        if prefix_len == 0 {
            return Ok(0);
        }
        let suffix = reader.read_bits(prefix_len)?;
        // prefix_len <= 31 and suffix < 2^prefix_len, so the sum is at most 2^32 - 2.
        Ok((1u32 << prefix_len) - 1 + suffix)
    }

    /// Reads a signed Exp-Golomb value: odd code numbers map to positive values and
    /// even ones to negative values (1 -> 1, 2 -> -1, 3 -> 2, ...).
    ///
    /// # Errors
    ///
    /// Fails under the same conditions as [`ExpGolomb::read_ue`].
    pub fn read_se(reader: &mut BitReader<'_>) -> Result<i32, CodecError> {
        let code_num = Self::read_ue(reader)?;
        // ceil(code_num / 2); `read_ue` yields at most 2^32 - 2, so this fits in i32.
        let magnitude = (code_num / 2 + code_num % 2) as i32;
        if code_num % 2 == 1 {
            Ok(magnitude)
        } else {
            Ok(-magnitude)
        }
    }

    /// Writes `value` as an unsigned Exp-Golomb code.
    ///
    /// Every `u32` is accepted. `u32::MAX` needs a 32-bit zero prefix, which is more
    /// than [`ExpGolomb::read_ue`] accepts when reading back.
    pub fn write_ue(writer: &mut BitWriter, value: u32) {
        Self::write_code_num(writer, u64::from(value));
    }

    /// Writes `value` as a signed Exp-Golomb code: positive values use code number
    /// `2 * |value| - 1`, non-positive values use `2 * |value|`.
    ///
    /// Every `i32` is accepted. `i32::MIN` needs a 32-bit zero prefix, which is more
    /// than [`ExpGolomb::read_se`] accepts when reading back.
    pub fn write_se(writer: &mut BitWriter, value: i32) {
        // Widen to u64 so that doubling |i32::MIN| cannot overflow.
        let magnitude = u64::from(value.unsigned_abs());
        let code_num = if value > 0 {
            2 * magnitude - 1
        } else {
            2 * magnitude
        };
        Self::write_code_num(writer, code_num);
    }

    /// Emits the Exp-Golomb codeword for `code_num`, which must be at most 2^32.
    fn write_code_num(writer: &mut BitWriter, code_num: u64) {
        // The codeword is `prefix_len` zeros followed by the binary form of
        // `code_num + 1`, whose top bit doubles as the terminating one bit.
        let biased = code_num + 1;
        let prefix_len = 63 - biased.leading_zeros();
        for _ in 0..prefix_len {
            writer.write_bit(false);
        }
        writer.write_bit(true);
        if prefix_len > 0 {
            let suffix = biased - (1u64 << prefix_len);
            // For code_num <= 2^32: prefix_len <= 32 and suffix < 2^32, so both
            // narrowing casts are lossless.
            writer.write_bits(suffix as u32, prefix_len as u8);
        }
    }
}
/// Namespace for emulation-prevention helpers that keep `0x00 0x00 0x0N` (N <= 3)
/// byte sequences out of a payload by inserting `0x03` escape bytes.
pub struct StuffingHelper;

impl StuffingHelper {
    /// Returns `true` when `data[pos]` is at most `0x03` and the two bytes before it
    /// are both `0x00`.
    ///
    /// Returns `false` when `pos < 2` or when `pos` lies outside `data`.
    #[must_use]
    pub fn needs_emulation_prevention(data: &[u8], pos: usize) -> bool {
        if pos < 2 {
            return false;
        }
        // `get` yields `None` for an out-of-range `pos` instead of panicking.
        match data.get(pos - 2..=pos) {
            Some([0x00, 0x00, third]) => *third <= 0x03,
            _ => false,
        }
    }

    /// Returns a copy of `data` with a `0x03` byte inserted before every byte that is
    /// at most `0x03` and directly follows two consecutive `0x00` bytes.
    #[must_use]
    pub fn add_emulation_prevention(data: &[u8]) -> Vec<u8> {
        let mut escaped = Vec::with_capacity(data.len() + data.len() / 100);
        // Number of consecutive 0x00 bytes emitted since the last reset (0..=2).
        let mut zero_run = 0u8;
        for &byte in data {
            if zero_run == 2 && byte <= 0x03 {
                escaped.push(0x03);
                zero_run = 0;
            }
            escaped.push(byte);
            zero_run = if byte == 0x00 { zero_run + 1 } else { 0 };
        }
        escaped
    }

    /// Returns a copy of `data` with the `0x03` byte dropped from every
    /// `0x00 0x00 0x03` sequence, scanning left to right.
    #[must_use]
    pub fn remove_emulation_prevention(data: &[u8]) -> Vec<u8> {
        const ESCAPED: [u8; 3] = [0x00, 0x00, 0x03];

        let mut unescaped = Vec::with_capacity(data.len());
        let mut i = 0;
        while i < data.len() {
            if data[i..].starts_with(&ESCAPED) {
                // Keep the two zero bytes and drop the escape byte.
                unescaped.extend_from_slice(&ESCAPED[..2]);
                i += ESCAPED.len();
            } else {
                unescaped.push(data[i]);
                i += 1;
            }
        }
        unescaped
    }
}