use crate::error::{WemError, WemResult};
use std::io::Read;
/// A source of individual bits, consumed least-significant-bit first.
pub trait BitRead {
    /// Reads one bit from the source.
    fn read_bit(&mut self) -> WemResult<bool>;

    /// Returns the number of bits consumed so far.
    fn total_bits_read(&self) -> u64;

    /// Reads `count` bits and packs them into a `u32`; the first bit read lands in bit 0.
    ///
    /// A `count` of 0 consumes nothing and yields 0.
    ///
    /// # Errors
    ///
    /// Fails when `count` exceeds 32 (before consuming anything), or propagates the first
    /// error returned by [`BitRead::read_bit`]; bits read before that error stay consumed.
    fn read_bits(&mut self, count: u8) -> WemResult<u32> {
        if count > 32 {
            return Err(WemError::parse("Cannot read more than 32 bits at once"));
        }
        (0..count).try_fold(0u32, |acc, shift| -> WemResult<u32> {
            let bit = self.read_bit()?;
            Ok(acc | (u32::from(bit) << shift))
        })
    }
}
/// Bit-level reader over any byte stream, yielding bits least-significant-bit first.
///
/// Bytes are pulled from the wrapped reader one at a time as bits are requested, so an
/// unbuffered source (such as a `File`) is best wrapped in `std::io::BufReader` first.
///
/// The `Read` requirement lives on the `impl` blocks rather than on the struct itself, so
/// naming `BitReader<R>` in generic code does not force extra bounds.
#[derive(Debug)]
pub struct BitReader<R> {
    /// The underlying byte source.
    reader: R,
    /// The byte whose bits are currently being handed out.
    current_byte: u8,
    /// Index (0..8) of the next bit to take from `current_byte`; 0 means a fresh byte
    /// must be fetched on the next read.
    bit_pos: u8,
    /// Number of bytes successfully pulled from `reader` so far.
    bytes_read: u64,
}
impl<R: Read> BitReader<R> {
pub fn new(reader: R) -> Self {
Self {
reader,
current_byte: 0,
bit_pos: 0,
bytes_read: 0,
}
}
pub fn total_bits_read(&self) -> u64 {
if self.bit_pos == 0 {
self.bytes_read * 8
} else {
(self.bytes_read - 1) * 8 + self.bit_pos as u64
}
}
pub fn read_bit(&mut self) -> WemResult<bool> {
if self.bit_pos == 0 {
let mut buf = [0u8; 1];
self.reader
.read_exact(&mut buf)
.map_err(|_| WemError::end_of_stream("Out of bits"))?;
self.current_byte = buf[0];
self.bytes_read += 1;
}
let bit = (self.current_byte & (1 << self.bit_pos)) != 0;
self.bit_pos += 1;
if self.bit_pos == 8 {
self.bit_pos = 0;
}
Ok(bit)
}
pub fn read_bits(&mut self, count: u8) -> WemResult<u32> {
BitRead::read_bits(self, count)
}
}
impl<R: Read> BitRead for BitReader<R> {
    /// Forwards to the inherent [`BitReader::total_bits_read`].
    fn total_bits_read(&self) -> u64 {
        // Type-qualified path resolves to the inherent method first, so this is not recursive.
        BitReader::<R>::total_bits_read(self)
    }

    /// Forwards to the inherent [`BitReader::read_bit`].
    fn read_bit(&mut self) -> WemResult<bool> {
        BitReader::<R>::read_bit(self)
    }
}
/// Bit-level reader over an in-memory byte slice, yielding bits least-significant-bit first.
///
/// Cloning is cheap (a slice reference plus a position) and produces an independent cursor,
/// which makes look-ahead straightforward.
#[derive(Debug, Clone)]
pub struct BitSliceReader<'a> {
    /// The bytes being read.
    data: &'a [u8],
    /// Index of the byte holding the next bit.
    byte_pos: usize,
    /// Index (0..8) of the next bit within `data[byte_pos]`.
    bit_pos: u8,
}
impl<'a> BitSliceReader<'a> {
pub fn new(data: &'a [u8]) -> Self {
Self {
data,
byte_pos: 0,
bit_pos: 0,
}
}
pub fn total_bits_read(&self) -> u64 {
self.byte_pos as u64 * 8 + self.bit_pos as u64
}
pub fn read_bit(&mut self) -> WemResult<bool> {
if self.byte_pos >= self.data.len() {
return Err(WemError::end_of_stream("Out of bits"));
}
let bit = (self.data[self.byte_pos] & (1 << self.bit_pos)) != 0;
self.bit_pos += 1;
if self.bit_pos == 8 {
self.bit_pos = 0;
self.byte_pos += 1;
}
Ok(bit)
}
}
impl BitRead for BitSliceReader<'_> {
    /// Forwards to the inherent [`BitSliceReader::total_bits_read`].
    fn total_bits_read(&self) -> u64 {
        // Type-qualified path resolves to the inherent method first, so this is not recursive.
        BitSliceReader::total_bits_read(self)
    }

    /// Forwards to the inherent [`BitSliceReader::read_bit`].
    fn read_bit(&mut self) -> WemResult<bool> {
        BitSliceReader::read_bit(self)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn test_read_bits_lsb_first() {
        let data = [0b10110100u8, 0b11001010u8];
        let mut reader = BitReader::new(Cursor::new(data));
        assert_eq!(reader.read_bits(4).unwrap(), 0b0100);
        assert_eq!(reader.read_bits(4).unwrap(), 0b1011);
        assert_eq!(reader.read_bits(8).unwrap(), 0b11001010);
    }

    #[test]
    fn test_slice_reader() {
        let data = [0b10110100u8];
        let mut reader = BitSliceReader::new(&data);
        // Bits of 0b10110100, least-significant first.
        let expected = [false, false, true, false, true, true, false, true];
        for (i, &want) in expected.iter().enumerate() {
            assert_eq!(reader.read_bit().unwrap(), want, "bit {}", i);
        }
    }

    #[test]
    fn test_read_zero_bits() {
        let data = [0xFF];
        let mut reader = BitReader::new(Cursor::new(data));
        assert_eq!(reader.read_bits(0).unwrap(), 0);
        assert_eq!(reader.total_bits_read(), 0);
    }

    #[test]
    fn test_read_single_bit() {
        let data = [0b00000001u8];
        let mut reader = BitReader::new(Cursor::new(data));
        assert!(reader.read_bit().unwrap());
        assert!(!reader.read_bit().unwrap());
    }

    #[test]
    fn test_total_bits_read() {
        let data = [0xFF, 0xFF, 0xFF];
        let mut reader = BitReader::new(Cursor::new(data));
        reader.read_bits(5).unwrap();
        assert_eq!(reader.total_bits_read(), 5);
        reader.read_bits(7).unwrap();
        assert_eq!(reader.total_bits_read(), 12);
        reader.read_bit().unwrap();
        assert_eq!(reader.total_bits_read(), 13);
    }

    #[test]
    fn test_total_bits_read_at_byte_boundary() {
        let data = [0x00u8, 0x00];
        let mut reader = BitReader::new(Cursor::new(data));
        reader.read_bits(8).unwrap();
        assert_eq!(reader.total_bits_read(), 8);
        reader.read_bit().unwrap();
        assert_eq!(reader.total_bits_read(), 9);
    }

    #[test]
    fn test_read_across_byte_boundary() {
        let data = [0xAB, 0xCD];
        let mut reader = BitReader::new(Cursor::new(data));
        let value = reader.read_bits(12).unwrap();
        assert_eq!(value, 0xDAB);
    }

    #[test]
    fn test_read_full_32_bits() {
        let data = [0x78, 0x56, 0x34, 0x12];
        let mut reader = BitReader::new(Cursor::new(data));
        let value = reader.read_bits(32).unwrap();
        assert_eq!(value, 0x12345678);
    }

    #[test]
    fn test_read_more_than_32_bits_is_rejected() {
        let data = [0xFFu8; 5];
        let mut reader = BitReader::new(Cursor::new(data));
        assert!(reader.read_bits(33).is_err());
        // The count is validated before any bit is consumed.
        assert_eq!(reader.total_bits_read(), 0);

        let mut slice_reader = BitSliceReader::new(&data);
        assert!(slice_reader.read_bits(33).is_err());
        assert_eq!(slice_reader.total_bits_read(), 0);
    }

    #[test]
    fn test_slice_reader_total_bits() {
        let data = [0xFF, 0xFF];
        let mut reader = BitSliceReader::new(&data);
        reader.read_bits(3).unwrap();
        assert_eq!(reader.total_bits_read(), 3);
        reader.read_bits(10).unwrap();
        assert_eq!(reader.total_bits_read(), 13);
    }

    #[test]
    fn test_slice_reader_read_bits() {
        let data = [0x34, 0x12];
        let mut reader = BitSliceReader::new(&data);
        let value = reader.read_bits(16).unwrap();
        assert_eq!(value, 0x1234);
    }

    #[test]
    fn test_stream_and_slice_readers_agree() {
        let data = [0x5Au8, 0xC3, 0x0F, 0xF0, 0x99];
        let mut stream = BitReader::new(Cursor::new(data));
        let mut slice = BitSliceReader::new(&data);
        // Widths sum to exactly 40 bits, i.e. the whole input.
        for &width in &[3u8, 7, 1, 12, 9, 8] {
            assert_eq!(
                stream.read_bits(width).unwrap(),
                slice.read_bits(width).unwrap()
            );
            assert_eq!(stream.total_bits_read(), slice.total_bits_read());
        }
    }

    #[test]
    fn test_read_past_end_of_stream() {
        let data = [0xFF];
        let mut reader = BitReader::new(Cursor::new(data));
        reader.read_bits(8).unwrap();
        let result = reader.read_bit();
        assert!(result.is_err());
    }

    #[test]
    fn test_slice_reader_past_end() {
        let data = [0xFF];
        let mut reader = BitSliceReader::new(&data);
        for _ in 0..8 {
            reader.read_bit().unwrap();
        }
        let result = reader.read_bit();
        assert!(result.is_err());
    }

    #[test]
    fn test_failed_read_does_not_advance_position() {
        let data = [0xFFu8];
        let mut reader = BitReader::new(Cursor::new(data));
        reader.read_bits(8).unwrap();
        assert!(reader.read_bit().is_err());
        assert_eq!(reader.total_bits_read(), 8);

        let mut slice_reader = BitSliceReader::new(&data);
        slice_reader.read_bits(8).unwrap();
        assert!(slice_reader.read_bit().is_err());
        assert_eq!(slice_reader.total_bits_read(), 8);
    }
}