use crate::error::Error;
#[derive(Debug, Clone, Copy, Default)]
pub struct BitReader {
byte_pos: usize,
acc: u64,
nbits: u32,
}
impl BitReader {
pub const fn new() -> Self {
Self {
byte_pos: 0,
acc: 0,
nbits: 0,
}
}
#[allow(dead_code)]
pub const fn byte_pos(&self) -> usize {
self.byte_pos
}
fn refill(&mut self, want: u32, input: &[u8]) -> Result<(), Error> {
debug_assert!(want <= 64);
while self.nbits < want {
if self.byte_pos >= input.len() {
return Err(Error::UnexpectedEnd);
}
let b = input[self.byte_pos] as u64;
self.byte_pos += 1;
self.acc |= b << (56 - self.nbits);
self.nbits += 8;
}
Ok(())
}
pub fn read_bits(&mut self, n: u32, input: &[u8]) -> Result<u32, Error> {
if n == 0 {
return Ok(0);
}
debug_assert!(n <= 32);
self.refill(n, input)?;
let v = ((self.acc >> (64 - n)) & ((1u64 << n) - 1)) as u32;
self.acc <<= n;
self.nbits -= n;
Ok(v)
}
pub fn peek_up_to(&mut self, n: u32, input: &[u8]) -> (u32, u32) {
debug_assert!(n <= 32);
let _ = self.refill(n, input);
let have = self.nbits.min(n);
if have == 0 {
return (0, 0);
}
let v = (self.acc >> (64 - n)) as u32;
let shift = n - have;
let masked = (v >> shift) << shift;
(masked, have)
}
pub fn drop_bits(&mut self, n: u32) {
debug_assert!(n <= self.nbits);
self.acc <<= n;
self.nbits -= n;
}
#[allow(dead_code)]
pub fn bits_remaining(&self, input: &[u8]) -> u64 {
self.nbits as u64 + (input.len().saturating_sub(self.byte_pos) as u64) * 8
}
#[allow(dead_code)]
pub fn reset(&mut self) {
self.byte_pos = 0;
self.acc = 0;
self.nbits = 0;
}
}
#[cfg(test)]
mod tests {
    //! Unit tests for `BitReader`: MSB-first reads, short-input handling,
    //! non-consuming peeks, and reader-state bookkeeping.
    use super::*;

    #[test]
    fn read_bits_msb_first() {
        let input = [0xCA, 0x53];
        let mut r = BitReader::new();
        assert_eq!(r.read_bits(4, &input).unwrap(), 0xC);
        assert_eq!(r.read_bits(4, &input).unwrap(), 0xA);
        assert_eq!(r.read_bits(8, &input).unwrap(), 0x53);
        assert_eq!(r.read_bits(1, &input), Err(Error::UnexpectedEnd));
    }

    #[test]
    fn read_bits_crossing_byte() {
        let input = [0xAB, 0xCD];
        let mut r = BitReader::new();
        assert_eq!(r.read_bits(12, &input).unwrap(), 0xABC);
        assert_eq!(r.read_bits(4, &input).unwrap(), 0xD);
    }

    #[test]
    fn read_zero_bits_is_noop() {
        let input = [];
        let mut r = BitReader::new();
        assert_eq!(r.read_bits(0, &input).unwrap(), 0);
    }

    #[test]
    fn read_full_32_bits() {
        let input = [0xDE, 0xAD, 0xBE, 0xEF, 0x01];
        let mut r = BitReader::new();
        assert_eq!(r.read_bits(32, &input).unwrap(), 0xDEAD_BEEF);
        assert_eq!(r.read_bits(8, &input).unwrap(), 0x01);
    }

    #[test]
    fn failed_read_keeps_buffered_bits() {
        // A read that runs past the end must not lose the bytes it loaded.
        let input = [0xAB, 0xCD];
        let mut r = BitReader::new();
        assert_eq!(r.read_bits(20, &input), Err(Error::UnexpectedEnd));
        assert_eq!(r.bits_remaining(&input), 16);
        assert_eq!(r.read_bits(16, &input).unwrap(), 0xABCD);
    }

    #[test]
    fn peek_up_to_handles_short_input() {
        let input = [0xF0];
        let mut r = BitReader::new();
        let (v, have) = r.peek_up_to(15, &input);
        assert_eq!(have, 8);
        assert_eq!(v, 0xF0 << 7);
        assert_eq!(r.read_bits(8, &input).unwrap(), 0xF0);
    }

    #[test]
    fn peek_up_to_zero_bits_loads_nothing() {
        let input = [0xFF];
        let mut r = BitReader::new();
        assert_eq!(r.peek_up_to(0, &input), (0, 0));
        assert_eq!(r.byte_pos(), 0);
    }

    #[test]
    fn peek_does_not_consume_and_drop_does() {
        let input = [0b1011_0110];
        let mut r = BitReader::new();
        assert_eq!(r.peek_up_to(3, &input), (0b101, 3));
        assert_eq!(r.peek_up_to(3, &input), (0b101, 3));
        r.drop_bits(3);
        assert_eq!(r.read_bits(5, &input).unwrap(), 0b1_0110);
    }

    #[test]
    fn bits_remaining_tracks_buffered_and_unread() {
        let input = [0xFF, 0xFF, 0xFF];
        let mut r = BitReader::new();
        assert_eq!(r.bits_remaining(&input), 24);
        r.read_bits(4, &input).unwrap();
        assert_eq!(r.bits_remaining(&input), 20);
        r.read_bits(12, &input).unwrap();
        assert_eq!(r.bits_remaining(&input), 8);
    }

    #[test]
    fn reset_rewinds_to_start() {
        let input = [0xAB];
        let mut r = BitReader::new();
        r.read_bits(4, &input).unwrap();
        assert_eq!(r.byte_pos(), 1);
        r.reset();
        assert_eq!(r.byte_pos(), 0);
        assert_eq!(r.read_bits(8, &input).unwrap(), 0xAB);
    }
}