/// Errors produced by [`BitReader`] operations.
#[derive(Debug, thiserror::Error, PartialEq, Eq)]
pub enum BitReaderError {
/// An operation needed more bits than it could obtain.
///
/// `read_bits` returns this when fewer than the requested bits remain.
/// `count_leading_ones` also returns it when a run of one-bits exceeds its
/// cap of 64, even if more input is available.
#[error("bit reader out of input: needed {needed} bits, had {available}")]
OutOfInput {
/// Number of bits the failing operation asked for.
needed: u32,
/// Number of unread bits the reader reported when the error was built.
available: u32,
},
/// A single read asked for more than 32 bits; the payload is the requested width.
#[error("bit reader: cannot read {0} bits in a single read (max 32)")]
TooManyBits(u32),
}
/// Cursor over a borrowed byte slice that hands out bits most-significant first.
///
/// `Debug` is derived so the cursor can be logged; `Clone` is derived so a
/// caller can snapshot the position and restart from it later.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    /// Bytes being read; the reader never modifies them.
    data: &'a [u8],
    /// Index into `data` of the byte that holds the next unread bit.
    byte_pos: usize,
    /// Offset of the next unread bit inside the current byte (0 is the most
    /// significant bit; the value stays in `0..=7`).
    bit_pos: u8,
}
impl<'a> BitReader<'a> {
    /// Longest run of one-bits that `count_leading_ones` will accept.
    const MAX_UNARY_RUN: u32 = 64;

    /// Creates a reader positioned at the most significant bit of the first
    /// byte of `data`.
    #[must_use]
    pub fn new(data: &'a [u8]) -> Self {
        Self {
            data,
            byte_pos: 0,
            bit_pos: 0,
        }
    }

    /// Returns the number of unread bits, clamped to `u32::MAX`.
    ///
    /// The count is computed in 64-bit arithmetic and clamped once at the
    /// end, so very large inputs report `u32::MAX` instead of a value
    /// produced by truncating the slice length or the byte position.
    #[must_use]
    pub fn bits_remaining(&self) -> u32 {
        // Widening cast: `usize` is at most 64 bits on supported targets.
        let unread_bytes = self.data.len().saturating_sub(self.byte_pos) as u64;
        let unread_bits = unread_bytes
            .saturating_mul(8)
            .saturating_sub(u64::from(self.bit_pos));
        // Already clamped to the `u32` range, so this cast cannot truncate.
        unread_bits.min(u64::from(u32::MAX)) as u32
    }

    /// Returns `true` when no unread bits are left.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.bits_remaining() == 0
    }

    /// Reads `n` bits and returns them right-aligned, first bit read in the
    /// most significant position of the result.
    ///
    /// Reading zero bits returns `Ok(0)` and does not move the reader.
    ///
    /// # Errors
    ///
    /// * [`BitReaderError::TooManyBits`] if `n` is greater than 32.
    /// * [`BitReaderError::OutOfInput`] if fewer than `n` bits remain.
    ///
    /// Both checks happen before any bit is consumed, so a failed call leaves
    /// the reader where it was.
    pub fn read_bits(&mut self, n: u32) -> Result<u32, BitReaderError> {
        if n == 0 {
            return Ok(0);
        }
        if n > 32 {
            return Err(BitReaderError::TooManyBits(n));
        }
        let available = self.bits_remaining();
        if available < n {
            return Err(BitReaderError::OutOfInput {
                needed: n,
                available,
            });
        }

        let mut value: u32 = 0;
        for _ in 0..n {
            // In bounds: the check above guarantees `n` more bits exist.
            let byte = self.data[self.byte_pos];
            let bit = (byte >> (7 - self.bit_pos)) & 1;
            value = (value << 1) | u32::from(bit);
            self.bit_pos += 1;
            if self.bit_pos == 8 {
                self.bit_pos = 0;
                self.byte_pos += 1;
            }
        }
        Ok(value)
    }

    /// Reads one bit and returns it as `0` or `1`.
    ///
    /// # Errors
    ///
    /// [`BitReaderError::OutOfInput`] if no bits remain.
    pub fn read_bit(&mut self) -> Result<u32, BitReaderError> {
        self.read_bits(1)
    }

    /// Counts consecutive one-bits up to the first zero bit and consumes the
    /// ones together with that terminating zero.
    ///
    /// # Errors
    ///
    /// [`BitReaderError::OutOfInput`] if the input ends before a zero bit is
    /// found, or if more than 64 one-bits are read in a row.
    ///
    /// On error the reader is moved back to where it was before the call, so
    /// a failed call consumes nothing (the same guarantee `read_bits` gives).
    /// The `available` field of the error still describes the point at which
    /// scanning stopped.
    pub fn count_leading_ones(&mut self) -> Result<u32, BitReaderError> {
        let start = (self.byte_pos, self.bit_pos);
        let outcome = self.scan_unary_run();
        if outcome.is_err() {
            self.byte_pos = start.0;
            self.bit_pos = start.1;
        }
        outcome
    }

    /// Scans a run of one-bits ended by a zero bit, advancing the reader as it goes.
    fn scan_unary_run(&mut self) -> Result<u32, BitReaderError> {
        let mut count = 0u32;
        loop {
            if self.read_bit()? == 0 {
                return Ok(count);
            }
            count += 1;
            if count > Self::MAX_UNARY_RUN {
                return Err(BitReaderError::OutOfInput {
                    needed: count + 1,
                    available: self.bits_remaining(),
                });
            }
        }
    }

    /// Skips the unread bits of a partially consumed byte so the next read
    /// starts at a byte boundary; does nothing if already aligned.
    pub fn align_to_byte(&mut self) {
        if self.bit_pos != 0 {
            self.byte_pos += 1;
            self.bit_pos = 0;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// 0xB4 is 0b1011_0100; single-bit reads must yield those digits left to right.
    #[test]
    fn read_single_bits_msb_first() {
        let bytes = [0xB4u8];
        let mut reader = BitReader::new(&bytes);
        for &expected in &[1u32, 0, 1, 1, 0, 1, 0, 0] {
            assert_eq!(reader.read_bit(), Ok(expected));
        }
        assert!(reader.is_empty());
    }

    /// A 12-bit read spans both bytes and leaves the low nibble of the second.
    #[test]
    fn read_bits_packs_correctly_across_byte_boundary() {
        let bytes = [0xABu8, 0xCD];
        let mut reader = BitReader::new(&bytes);
        assert_eq!(reader.read_bits(12), Ok(0xABC));
        assert_eq!(reader.bits_remaining(), 4);
        assert_eq!(reader.read_bits(4), Ok(0xD));
    }

    /// A zero-width read yields 0 and does not advance the reader.
    #[test]
    fn read_zero_bits_returns_zero() {
        let bytes = [0xFFu8];
        let mut reader = BitReader::new(&bytes);
        assert_eq!(reader.read_bits(0), Ok(0));
        assert_eq!(reader.bits_remaining(), 8);
    }

    /// Widths above 32 are rejected with the requested width in the error.
    #[test]
    fn read_too_many_bits_errors() {
        let bytes = [0u8; 4];
        let mut reader = BitReader::new(&bytes);
        assert_eq!(reader.read_bits(33), Err(BitReaderError::TooManyBits(33)));
    }

    /// Reading past the end reports `OutOfInput` instead of panicking.
    #[test]
    fn out_of_input_errors_clean() {
        let bytes = [0xFFu8];
        let mut reader = BitReader::new(&bytes);
        assert!(reader.read_bits(8).is_ok());
        match reader.read_bit() {
            Err(BitReaderError::OutOfInput { .. }) => {}
            other => panic!("expected OutOfInput, got {:?}", other),
        }
    }

    /// 0xE0 is 0b1110_0000: three ones, a terminating zero, then four zero bits.
    #[test]
    fn count_leading_ones_counts_unary_prefix() {
        let bytes = [0xE0u8];
        let mut reader = BitReader::new(&bytes);
        assert_eq!(reader.count_leading_ones(), Ok(3));
        assert_eq!(reader.read_bits(4), Ok(0));
    }

    /// A leading zero bit gives an empty run.
    #[test]
    fn count_leading_ones_zero_when_first_bit_zero() {
        let bytes = [0x00u8];
        let mut reader = BitReader::new(&bytes);
        assert_eq!(reader.count_leading_ones(), Ok(0));
    }

    /// Aligning after a partial read drops the rest of that byte.
    #[test]
    fn align_to_byte_advances_to_boundary() {
        let bytes = [0xFFu8, 0x00];
        let mut reader = BitReader::new(&bytes);
        assert!(reader.read_bits(3).is_ok());
        reader.align_to_byte();
        assert_eq!(reader.bits_remaining(), 8);
        assert_eq!(reader.read_bits(8), Ok(0));
    }
}