use gamut_core::{Error, Result};
/// Widest field a single `BitReader::read_bits` / `BitWriter::write_bits` call may
/// transfer: the bit width of the `u32` those calls exchange.
const MAX_BITS_PER_OP: u32 = u32::BITS;
/// Least-significant-bit-first reader over a borrowed byte slice (the bit order
/// used by the VP8L bitstream).
///
/// Whole bytes are shifted into a 64-bit accumulator only when a read needs them,
/// and bits are handed out starting from bit 0 of each byte.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    /// The complete input.
    data: &'a [u8],
    /// Index of the first byte of `data` not yet loaded into `acc`.
    byte_pos: usize,
    /// Buffered bits; the next bit to hand out is bit 0.
    acc: u64,
    /// How many low bits of `acc` are valid.
    bits_in_acc: u32,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the first bit of `data`.
    #[must_use]
    pub fn new(data: &'a [u8]) -> Self {
        Self {
            acc: 0,
            bits_in_acc: 0,
            byte_pos: 0,
            data,
        }
    }

    /// Reads an `n`-bit field and returns it right-aligned in a `u32`; the first
    /// bit read lands in bit 0 of the result.
    ///
    /// `n == 0` returns `Ok(0)` without consuming anything.
    ///
    /// # Errors
    ///
    /// Returns `Error::InvalidInput` if `n` exceeds [`MAX_BITS_PER_OP`], or if the
    /// input ends before `n` bits are available. In the latter case the bytes that
    /// were already loaded stay buffered, so a later, narrower read can still
    /// return them.
    pub fn read_bits(&mut self, n: u32) -> Result<u32> {
        if n > MAX_BITS_PER_OP {
            return Err(Error::InvalidInput("VP8L: bit read width out of range"));
        }
        if n == 0 {
            return Ok(0);
        }

        // Buffer whole bytes until at least `n` bits are available. Fewer than `n`
        // bits are present before each refill step, so at most n - 1 + 8 <= 39
        // bits are ever live and the u64 accumulator cannot overflow.
        while self.bits_in_acc < n {
            let next = match self.data.get(self.byte_pos) {
                Some(&byte) => u64::from(byte),
                None => {
                    return Err(Error::InvalidInput("VP8L: unexpected end of bitstream"));
                }
            };
            self.acc |= next << self.bits_in_acc;
            self.byte_pos += 1;
            self.bits_in_acc += 8;
        }

        // `n` is in 1..=32 here, so the shift amount 64 - n is in 32..=63.
        let low_n_bits = u64::MAX >> (64 - n);
        // Masked to at most 32 bits, so the narrowing cast is lossless.
        let field = (self.acc & low_n_bits) as u32;
        self.acc >>= n;
        self.bits_in_acc -= n;
        Ok(field)
    }

    /// Reads a single bit, returning `0` or `1`.
    ///
    /// # Errors
    ///
    /// Fails with `Error::InvalidInput` when no input bits remain.
    pub fn read_bit(&mut self) -> Result<u32> {
        self.read_bits(1)
    }

    /// Number of bits handed out so far: bits loaded from `data` minus bits still
    /// buffered.
    #[must_use]
    pub fn bits_consumed(&self) -> usize {
        let loaded_bits = self.byte_pos * 8;
        loaded_bits - self.bits_in_acc as usize
    }

    /// Returns `true` once every bit of the input has been consumed.
    #[must_use]
    pub fn is_exhausted(&self) -> bool {
        self.bits_in_acc == 0 && self.data.get(self.byte_pos).is_none()
    }
}
#[derive(Debug, Clone, Default)]
pub struct BitWriter {
buf: Vec<u8>,
acc: u64,
bits_in_acc: u32,
}
impl BitWriter {
#[must_use]
pub fn new() -> Self {
Self::default()
}
pub fn write_bits(&mut self, value: u32, n: u32) {
if n == 0 || n > MAX_BITS_PER_OP {
return;
}
let mask = (1u64 << n) - 1;
self.acc |= (u64::from(value) & mask) << self.bits_in_acc;
self.bits_in_acc += n;
while self.bits_in_acc >= 8 {
self.buf.push((self.acc & 0xff) as u8);
self.acc >>= 8;
self.bits_in_acc -= 8;
}
}
#[must_use]
pub fn bit_len(&self) -> usize {
self.buf.len() * 8 + self.bits_in_acc as usize
}
#[must_use]
pub fn finish(mut self) -> Vec<u8> {
if self.bits_in_acc > 0 {
self.buf.push((self.acc & 0xff) as u8);
}
self.buf
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn reads_lsb_first_within_a_byte() {
        // 0xB2 = 0b1011_0010, consumed from bit 0 upwards.
        let mut r = BitReader::new(&[0xB2]);
        assert_eq!(r.read_bits(1).unwrap(), 0);
        assert_eq!(r.read_bits(1).unwrap(), 1);
        assert_eq!(r.read_bits(2).unwrap(), 0b00);
        assert_eq!(r.read_bits(4).unwrap(), 0b1011);
    }

    #[test]
    fn read_bit_yields_individual_bits_lsb_first() {
        // 0xA5 = 0b1010_0101.
        let mut r = BitReader::new(&[0xA5]);
        let bits: Vec<u32> = (0..8).map(|_| r.read_bit().unwrap()).collect();
        assert_eq!(bits, [1, 0, 1, 0, 0, 1, 0, 1]);
        assert!(r.is_exhausted());
    }

    #[test]
    fn multi_bit_read_equals_single_bit_reads() {
        let data = [0x9D, 0x01, 0x2A, 0xFF];
        let mut a = BitReader::new(&data);
        let mut b = BitReader::new(&data);
        for _ in 0..10 {
            let combined = a.read_bits(2).unwrap();
            let lo = b.read_bits(1).unwrap();
            let hi = b.read_bits(1).unwrap();
            assert_eq!(combined, lo | (hi << 1));
        }
    }

    #[test]
    fn read_zero_bits_is_noop() {
        let mut r = BitReader::new(&[0xFF]);
        assert_eq!(r.read_bits(0).unwrap(), 0);
        assert_eq!(r.bits_consumed(), 0);
        assert_eq!(r.read_bits(8).unwrap(), 0xFF);
    }

    #[test]
    fn full_width_read_at_unaligned_offset() {
        // Exercises the widest accumulator state: 4 leftover bits + 4 fresh bytes.
        let mut r = BitReader::new(&[0xFF; 5]);
        assert_eq!(r.read_bits(4).unwrap(), 0xF);
        assert_eq!(r.read_bits(32).unwrap(), 0xFFFF_FFFF);
        assert_eq!(r.bits_consumed(), 36);
    }

    #[test]
    fn crosses_byte_boundaries() {
        let mut w = BitWriter::new();
        w.write_bits(0x2f, 8);
        w.write_bits(16383, 14);
        w.write_bits(0, 14);
        let bytes = w.finish();
        let mut r = BitReader::new(&bytes);
        assert_eq!(r.read_bits(8).unwrap(), 0x2f);
        assert_eq!(r.read_bits(14).unwrap(), 16383);
        assert_eq!(r.read_bits(14).unwrap(), 0);
    }

    #[test]
    fn out_of_data_is_invalid_input_not_panic() {
        let mut r = BitReader::new(&[0xAB]);
        assert_eq!(r.read_bits(8).unwrap(), 0xAB);
        assert!(matches!(r.read_bits(1), Err(Error::InvalidInput(_))));
        let mut empty = BitReader::new(&[]);
        assert!(matches!(empty.read_bits(1), Err(Error::InvalidInput(_))));
    }

    #[test]
    fn failed_read_keeps_buffered_bits() {
        // The 9-bit read fails after loading the only byte; that byte must still
        // be readable afterwards.
        let mut r = BitReader::new(&[0x5A]);
        assert!(matches!(r.read_bits(9), Err(Error::InvalidInput(_))));
        assert_eq!(r.bits_consumed(), 0);
        assert_eq!(r.read_bits(8).unwrap(), 0x5A);
        assert_eq!(r.bits_consumed(), 8);
        assert!(r.is_exhausted());
    }

    #[test]
    fn oversized_read_is_rejected() {
        let mut r = BitReader::new(&[0; 8]);
        assert!(matches!(r.read_bits(33), Err(Error::InvalidInput(_))));
    }

    #[test]
    fn writer_round_trips_varied_widths() {
        let fields: &[(u32, u32)] = &[
            (0, 1),
            (1, 1),
            (0b101, 3),
            (0, 0),
            (0x3FFF, 14),
            (0xABCD, 16),
            (0x00FF_FF00, 24),
            (0xDEAD_BEEF, 32),
            (7, 3),
        ];
        let mut w = BitWriter::new();
        for &(v, n) in fields {
            w.write_bits(v, n);
        }
        let total_bits: usize = fields.iter().map(|&(_, n)| n as usize).sum();
        assert_eq!(w.bit_len(), total_bits);
        let bytes = w.finish();
        let mut r = BitReader::new(&bytes);
        for &(v, n) in fields {
            let masked = if n == 0 {
                0
            } else {
                v & ((1u64 << n) - 1) as u32
            };
            assert_eq!(r.read_bits(n).unwrap(), masked, "field {v:#x}/{n}");
        }
    }

    #[test]
    fn partial_final_byte_is_zero_padded() {
        let mut w = BitWriter::new();
        w.write_bits(0b1, 1);
        let bytes = w.finish();
        assert_eq!(bytes, vec![0b0000_0001]);
    }

    #[test]
    fn empty_writer_finishes_empty() {
        let w = BitWriter::new();
        assert_eq!(w.bit_len(), 0);
        assert!(w.finish().is_empty());
    }

    #[test]
    fn consumed_and_exhausted_track_position() {
        let mut r = BitReader::new(&[0xFF, 0x0F]);
        assert!(!r.is_exhausted());
        r.read_bits(12).unwrap();
        assert_eq!(r.bits_consumed(), 12);
        assert!(!r.is_exhausted());
        r.read_bits(4).unwrap();
        assert_eq!(r.bits_consumed(), 16);
        assert!(r.is_exhausted());
    }
}