use crate::bits::traits::BitRead;
use crate::error::{Error, Result};
use std::io::Read;
/// Least-significant-bit-first bit reader layered over any [`Read`] source.
///
/// Bytes pulled from the underlying reader are accumulated in a 64-bit
/// buffer. The next unread bit of the stream is always bit 0 of that buffer,
/// and every bit at or above `bits_available` is zero.
pub struct BitReader<R: Read> {
    /// Underlying byte source.
    reader: R,
    /// Bit accumulator; bit 0 is the next bit handed out.
    buffer: u64,
    /// Number of valid bits currently held in `buffer` (0..=64).
    bits_available: u8,
    /// Total bytes pulled from `reader`, including bytes still sitting in `buffer`.
    bytes_read: u64,
}

impl<R: Read> BitReader<R> {
    /// Creates a bit reader with an empty buffer on top of `reader`.
    pub fn new(reader: R) -> Self {
        Self {
            reader,
            buffer: 0,
            bits_available: 0,
            bytes_read: 0,
        }
    }

    /// Ensures that at least `n` bits are buffered, reading from the underlying
    /// reader when necessary.
    ///
    /// `n` must not exceed 57: whenever a byte is appended fewer than 57 bits
    /// are buffered, so the new byte always fits in the 64-bit accumulator.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UnexpectedEof`] when the source ends before `n` bits
    /// are available, and [`Error::Io`] for any other I/O failure.
    fn fill_buffer(&mut self, n: u8) -> Result<()> {
        debug_assert!(n <= 57, "Cannot request more than 57 bits at once");
        if self.bits_available >= n {
            return Ok(());
        }

        // Fast path: a single bulk read for as many whole bytes as still fit.
        if self.bits_available <= 56 {
            let bytes_to_read = usize::from((64 - self.bits_available) / 8);
            let mut bulk_buf = [0u8; 8];
            match self.reader.read(&mut bulk_buf[..bytes_to_read]) {
                // Nothing was returned; the loop below reports the EOF.
                Ok(0) => {}
                Ok(bytes_read) => {
                    for &byte in &bulk_buf[..bytes_read] {
                        self.buffer |= u64::from(byte) << self.bits_available;
                        self.bits_available += 8;
                    }
                    self.bytes_read += bytes_read as u64;
                    if self.bits_available >= n {
                        return Ok(());
                    }
                }
                // Interrupted: fall through to the byte-wise loop, which uses
                // `read_exact` (its default implementation retries interrupts).
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
                Err(e) => return Err(Error::Io(e)),
            }
        }

        // Slow path: top up one byte at a time until the request is satisfied.
        while self.bits_available < n {
            let mut byte = [0u8; 1];
            match self.reader.read_exact(&mut byte) {
                Ok(()) => {
                    self.buffer |= u64::from(byte[0]) << self.bits_available;
                    self.bits_available += 8;
                    self.bytes_read += 1;
                }
                Err(e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
                    return Err(Error::UnexpectedEof);
                }
                Err(e) => return Err(Error::Io(e)),
            }
        }
        Ok(())
    }

    /// Reads `n` bits (`n <= 32`) and removes them from the stream.
    ///
    /// The first bit read ends up in bit 0 of the result. `n == 0` returns 0
    /// without touching the underlying reader.
    ///
    /// # Errors
    ///
    /// Propagates [`Error::UnexpectedEof`] / [`Error::Io`] from refilling the
    /// buffer; in that case no bits are consumed.
    pub fn read_bits(&mut self, n: u8) -> Result<u32> {
        debug_assert!(n <= 32, "Cannot read more than 32 bits at once");
        let value = self.peek_bits(n)?;
        self.consume_bits(n);
        Ok(value)
    }

    /// Returns the next `n` bits (`n <= 32`) without consuming them.
    ///
    /// May read from the underlying reader to refill the buffer.
    ///
    /// # Errors
    ///
    /// Propagates [`Error::UnexpectedEof`] / [`Error::Io`] from refilling the
    /// buffer.
    #[inline]
    pub fn peek_bits(&mut self, n: u8) -> Result<u32> {
        debug_assert!(n <= 32, "Cannot peek more than 32 bits at once");
        if n == 0 {
            return Ok(0);
        }
        self.fill_buffer(n)?;
        let mask = (1u64 << n) - 1;
        Ok((self.buffer & mask) as u32)
    }

    /// Discards `n` already-buffered bits.
    ///
    /// Callers must not request more than [`bits_available`](Self::bits_available)
    /// bits; debug builds assert this. If the assertion is compiled out, the
    /// request is clamped so the bit counter cannot wrap around.
    #[inline]
    pub fn consume_bits(&mut self, n: u8) {
        debug_assert!(n <= self.bits_available, "Cannot consume more bits than available");
        let n = n.min(self.bits_available);
        // A `u64` cannot be shifted by 64, yet draining a completely full
        // buffer is legal and must leave no stale bits behind.
        self.buffer = self.buffer.checked_shr(u32::from(n)).unwrap_or(0);
        self.bits_available -= n;
    }

    /// Reads a single bit and returns it as a `bool` (`true` for 1).
    ///
    /// # Errors
    ///
    /// Same as [`read_bits`](Self::read_bits).
    #[inline]
    pub fn read_bit(&mut self) -> Result<bool> {
        self.read_bits(1).map(|bit| bit != 0)
    }

    /// Skips the remaining bits of a partially consumed byte.
    ///
    /// The buffer only ever grows by whole bytes, so `bits_available % 8` is
    /// exactly the number of unread bits left in the current byte.
    pub fn align_to_byte(&mut self) {
        let discard = self.bits_available % 8;
        if discard > 0 {
            self.buffer >>= discard;
            self.bits_available -= discard;
        }
    }

    /// Aligns to the next byte boundary and reads one byte.
    ///
    /// # Errors
    ///
    /// Same as [`read_bits`](Self::read_bits).
    pub fn read_byte(&mut self) -> Result<u8> {
        self.align_to_byte();
        self.read_bits(8).map(|v| v as u8)
    }

    /// Aligns to the next byte boundary and reads a little-endian `u16`.
    ///
    /// # Errors
    ///
    /// Same as [`read_bits`](Self::read_bits).
    pub fn read_u16_le(&mut self) -> Result<u16> {
        let mut raw = [0u8; 2];
        self.read_bytes(&mut raw)?;
        Ok(u16::from_le_bytes(raw))
    }

    /// Aligns to the next byte boundary and reads a little-endian `u32`.
    ///
    /// # Errors
    ///
    /// Same as [`read_bits`](Self::read_bits).
    pub fn read_u32_le(&mut self) -> Result<u32> {
        let mut raw = [0u8; 4];
        self.read_bytes(&mut raw)?;
        Ok(u32::from_le_bytes(raw))
    }

    /// Aligns to the next byte boundary and fills `buf` with the following bytes.
    ///
    /// # Errors
    ///
    /// Same as [`read_bits`](Self::read_bits). On error the leading part of
    /// `buf` may already have been overwritten.
    pub fn read_bytes(&mut self, buf: &mut [u8]) -> Result<()> {
        self.align_to_byte();
        for slot in buf.iter_mut() {
            *slot = self.read_bits(8)? as u8;
        }
        Ok(())
    }

    /// Number of bytes pulled from the underlying reader so far.
    ///
    /// This includes bytes that are buffered but not yet consumed, so it can
    /// run ahead of the logical read position.
    pub fn bytes_read(&self) -> u64 {
        self.bytes_read
    }

    /// Number of bits currently buffered and not yet consumed.
    pub fn bits_available(&self) -> u8 {
        self.bits_available
    }

    /// Consumes the bit reader and returns the underlying reader.
    ///
    /// Bytes already pulled into the internal buffer are dropped; they are not
    /// handed back to the returned reader.
    pub fn into_inner(self) -> R {
        self.reader
    }
}
impl<R: Read> BitRead for BitReader<R> {
#[inline]
fn fill_buffer(&mut self, n: u8) -> Result<()> {
self.fill_buffer(n)
}
#[inline]
fn read_bits(&mut self, n: u8) -> Result<u32> {
self.read_bits(n)
}
#[inline]
fn peek_bits(&mut self, n: u8) -> Result<u32> {
self.peek_bits(n)
}
#[inline]
fn consume_bits(&mut self, n: u8) {
self.consume_bits(n)
}
fn align_to_byte(&mut self) {
self.align_to_byte()
}
fn bytes_read(&self) -> u64 {
self.bytes_read
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Bits are delivered least-significant first within each byte.
    #[test]
    fn test_read_bits() {
        // 0xD3 == 0b1101_0011
        let bytes = [0xD3u8, 0xAA];
        let mut bits = BitReader::new(&bytes[..]);

        let low3 = bits.read_bits(3).unwrap();
        assert_eq!(low3, 0b011);

        let high5 = bits.read_bits(5).unwrap();
        assert_eq!(high5, 0b11010);

        let second_byte = bits.read_bits(8).unwrap();
        assert_eq!(second_byte, 0xAA);
    }

    /// Single-bit reads walk a byte from bit 0 up to bit 7.
    #[test]
    fn test_read_bit() {
        let bytes = [0b10110001u8];
        let mut bits = BitReader::new(&bytes[..]);

        // 0b1011_0001 listed from the least significant bit upwards.
        let expected = [true, false, false, false, true, true, false, true];
        for &want in &expected {
            assert_eq!(bits.read_bit().unwrap(), want);
        }
    }

    /// Aligning after a partial read skips the rest of the current byte.
    #[test]
    fn test_align_to_byte() {
        let bytes = [0xFFu8, 0xAB];
        let mut bits = BitReader::new(&bytes[..]);

        bits.read_bits(3).unwrap();
        bits.align_to_byte();

        assert_eq!(bits.read_bits(8).unwrap(), 0xAB);
    }

    /// Two bytes are combined low byte first.
    #[test]
    fn test_read_u16_le() {
        let bytes = [0x34u8, 0x12];
        let mut bits = BitReader::new(&bytes[..]);

        assert_eq!(bits.read_u16_le().unwrap(), 0x1234);
    }

    /// Four bytes are combined low byte first.
    #[test]
    fn test_read_u32_le() {
        let bytes = [0x78u8, 0x56, 0x34, 0x12];
        let mut bits = BitReader::new(&bytes[..]);

        assert_eq!(bits.read_u32_le().unwrap(), 0x12345678);
    }

    /// A read wider than eight bits pulls its upper bits from the next byte.
    #[test]
    fn test_cross_byte_boundary() {
        let bytes = [0xFFu8, 0x00];
        let mut bits = BitReader::new(&bytes[..]);

        assert_eq!(bits.read_bits(12).unwrap(), 0x0FF);
    }
}