use crate::byte_order::ByteOrder;
use crate::error::BitReadWriteError;
use crate::traits::{BitPeek, BitRead};
use std::io::{BufReader, Read};
/// Reads values bit by bit from an underlying byte source.
///
/// Bytes are pulled through a `BufReader` into a 64-bit staging buffer.
/// `byte_order` decides whether bits are handed out MSB-first (`BigEndian`)
/// or LSB-first (`LittleEndian`).
pub struct BitReader<R: Read> {
    /// Bit order used when staging bytes and when extracting bits.
    byte_order: ByteOrder,
    /// Buffered source of bytes.
    inner: BufReader<R>,
    /// Staged bits: MSB-aligned in big-endian mode, LSB-aligned in little-endian mode.
    bits_buffer: u64,
    /// Number of valid bits currently staged in `bits_buffer`.
    bits_in_buffer: usize,
}
impl<R: Read> BitReader<R> {
    /// Creates a bit reader over `inner` that uses big-endian (MSB-first) bit order.
    pub fn new(inner: R) -> Self {
        Self::with_byte_order(ByteOrder::BigEndian, inner)
    }

    /// Creates a bit reader over `inner` that uses the given `byte_order`.
    ///
    /// `inner` is wrapped in a `BufReader`, and the bit buffer starts out empty.
    pub fn with_byte_order(byte_order: ByteOrder, inner: R) -> Self {
        let buffered = BufReader::new(inner);
        Self {
            inner: buffered,
            byte_order,
            bits_buffer: 0,
            bits_in_buffer: 0,
        }
    }
}
impl<R: Read> BitReader<R> {
    /// Tops up the bit buffer from the underlying reader until it holds at least
    /// `n` bits. Bytes are only ever loaded whole.
    ///
    /// In big-endian mode staged bits are MSB-aligned: the next stream bit is
    /// bit 63. In little-endian mode they are LSB-aligned: the next stream bit
    /// is bit 0.
    ///
    /// # Errors
    /// * `BitReadWriteError::InvalidBitCount(n)` if `n` bits can never be staged
    ///   at once. Because the buffer is refilled a byte at a time, while a
    ///   partial byte is pending it can hold at most `56 + bits_in_buffer % 8`
    ///   bits. This is checked before any byte is consumed.
    /// * `BitReadWriteError::UnexpectedEof` if the reader ends before the
    ///   required bytes arrive.
    /// * Any other I/O error reported by the underlying reader.
    fn put_into_bits_buffer(&mut self, n: usize) -> std::io::Result<()> {
        // Never load more whole bytes than fit into the free part of the buffer.
        let bytes_free = (64 - self.bits_in_buffer) / 8;
        let capacity = self.bits_in_buffer + 8 * bytes_free;
        if n > capacity {
            return Err(BitReadWriteError::InvalidBitCount(n).into());
        }

        let bits_missing = n.saturating_sub(self.bits_in_buffer);
        let bytes_needed = (bits_missing + 7) / 8;
        if bytes_needed == 0 {
            return Ok(());
        }

        let mut storage = [0u8; 8];
        let bytes = &mut storage[..bytes_needed];
        // A single `read` may legitimately return fewer bytes than requested
        // without being at EOF (e.g. at a `BufReader` buffer boundary), so
        // `read_exact` is used. Its EOF is reported as the crate's own error.
        self.inner
            .read_exact(bytes)
            .map_err(|err| -> std::io::Error {
                if err.kind() == std::io::ErrorKind::UnexpectedEof {
                    BitReadWriteError::UnexpectedEof.into()
                } else {
                    err
                }
            })?;

        for &byte in bytes.iter() {
            // `bits_in_buffer <= 56` here because of the `bytes_free` bound above.
            let shift = match self.byte_order {
                // Place the byte directly below the bits already staged.
                ByteOrder::BigEndian => 56 - self.bits_in_buffer,
                // Place the byte directly above the bits already staged.
                ByteOrder::LittleEndian => self.bits_in_buffer,
            };
            self.bits_buffer |= u64::from(byte) << shift;
            self.bits_in_buffer += 8;
        }
        Ok(())
    }

    /// Returns the next `n` staged bits in the low `n` bits of a `u64`.
    ///
    /// In big-endian mode the first stream bit is the most significant bit of
    /// the result; in little-endian mode it is the least significant. When
    /// `take` is true the bits are also consumed; otherwise the buffer is left
    /// untouched.
    ///
    /// # Errors
    /// `BitReadWriteError::InvalidBitCount(n)` if `n` is 0 or exceeds the number
    /// of staged bits. Call `put_into_bits_buffer(n)` first.
    fn get_from_bits_buffer(&mut self, n: usize, take: bool) -> std::io::Result<u64> {
        if n == 0 || n > self.bits_in_buffer {
            return Err(BitReadWriteError::InvalidBitCount(n).into());
        }
        // 1 <= n <= 64 here, so `64 - n` is a valid shift amount.
        let unused = (64 - n) as u32;
        let value = match self.byte_order {
            ByteOrder::BigEndian => self.bits_buffer >> unused,
            ByteOrder::LittleEndian => self.bits_buffer & (u64::MAX >> unused),
        };
        if take {
            // `checked_sh*` yields `None` for a full 64-bit shift, which empties the buffer.
            let consumed = n as u32;
            let shifted = match self.byte_order {
                ByteOrder::BigEndian => self.bits_buffer.checked_shl(consumed),
                ByteOrder::LittleEndian => self.bits_buffer.checked_shr(consumed),
            };
            self.bits_buffer = shifted.unwrap_or(0);
            self.bits_in_buffer -= n;
        }
        Ok(value)
    }
}
impl<R: Read> BitReader<R> {
    /// Returns `true` when no partial byte is pending in the bit buffer.
    ///
    /// The buffer is only ever refilled with whole bytes, so this is the case
    /// exactly when the number of bits consumed so far is a multiple of eight.
    pub fn is_byte_aligned(&self) -> bool {
        let pending_partial = self.bits_in_buffer & 0b111;
        pending_partial == 0
    }
}
impl<R: Read> BitRead for BitReader<R> {
type Output = u64;
fn read_bits(&mut self, n: usize) -> std::io::Result<Self::Output> {
if n == 0 || n > 64 {
return Err(BitReadWriteError::InvalidBitCount(n).into());
}
self.put_into_bits_buffer(n)?;
self.get_from_bits_buffer(n, true)
}
}
impl<R: Read> Read for BitReader<R> {
    /// Reads whole bytes into `buf`.
    ///
    /// Any byte-aligned bits still staged in the bit buffer are drained first.
    /// The rest of `buf` is then filled from the underlying reader.
    ///
    /// # Errors
    /// Returns `BitReadWriteError::UnalignedAccess` when a partial byte is
    /// pending, because byte reads would then straddle a byte boundary.
    /// An error from the underlying reader is returned only if no staged bytes
    /// were copied in this call. Otherwise the copied bytes are returned and the
    /// error is dropped for this call; a persistent error is expected to surface
    /// again on the next call.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        if self.bits_in_buffer == 0 {
            return self.inner.read(buf);
        }
        if self.bits_in_buffer % 8 != 0 {
            return Err(BitReadWriteError::UnalignedAccess.into());
        }
        let mut written = 0;
        while self.bits_in_buffer >= 8 && written < buf.len() {
            buf[written] = self.get_from_bits_buffer(8, true)? as u8;
            written += 1;
        }
        if written < buf.len() {
            match self.inner.read(&mut buf[written..]) {
                Ok(count) => written += count,
                // Bytes already drained from the bit buffer cannot be put back,
                // so report them now instead of losing them behind the error.
                Err(_) if written > 0 => {}
                Err(err) => return Err(err),
            }
        }
        Ok(written)
    }
}
/// A `BitReader` wrapper that can also look at upcoming bits without
/// consuming them.
pub struct PeekableBitReader<R: Read> {
    /// Reader whose bit buffer is shared by reads and peeks.
    inner: BitReader<R>,
}
impl<R: Read> PeekableBitReader<R> {
    /// Creates a peekable reader over `inner` using big-endian (MSB-first) bit order.
    pub fn new(inner: R) -> Self {
        let reader = BitReader::new(inner);
        Self { inner: reader }
    }

    /// Creates a peekable reader over `inner` using **little-endian**
    /// (LSB-first) bit order.
    ///
    /// NOTE: despite its name, this constructor takes no `ByteOrder` argument
    /// and always selects `ByteOrder::LittleEndian`. Use `new` for big-endian.
    pub fn with_byte_order(inner: R) -> Self {
        let reader = BitReader::with_byte_order(ByteOrder::LittleEndian, inner);
        Self { inner: reader }
    }
}
impl<R: Read> BitRead for PeekableBitReader<R> {
type Output = u64;
fn read_bits(&mut self, n: usize) -> std::io::Result<Self::Output> {
self.inner.read_bits(n)
}
}
impl<R: Read> BitPeek for PeekableBitReader<R> {
    type Output = u64;

    /// Returns the next `n` bits (1..=64) without consuming them. The bits are
    /// ordered as `read_bits` would return them.
    ///
    /// # Errors
    /// `BitReadWriteError::InvalidBitCount(n)` in three cases:
    /// * `n` is 0.
    /// * `n` is greater than 64.
    /// * `n` exceeds what the 64-bit lookahead buffer can stage at the current
    ///   position. While a partial byte is pending, the limit is
    ///   `56 + bits % 8` bits.
    ///
    /// Also returns any error from refilling the buffer (e.g. EOF).
    fn peek_bits(&mut self, n: usize) -> std::io::Result<Self::Output> {
        if n == 0 || n > 64 {
            return Err(BitReadWriteError::InvalidBitCount(n).into());
        }
        // A peek cannot be split into smaller reads the way `read_bits` can.
        // Reject requests that could never be staged instead of returning zero
        // padding bits that did not come from the stream.
        let held = self.inner.bits_in_buffer;
        let max_peek = held + 8 * ((64 - held) / 8);
        if n > max_peek {
            return Err(BitReadWriteError::InvalidBitCount(n).into());
        }
        self.inner.put_into_bits_buffer(n)?;
        self.inner.get_from_bits_buffer(n, false)
    }
}
/// A `BitReader` wrapper whose `read_bits` accepts bit counts larger than 64.
/// It returns the bits as a sequence of `u64` chunks.
pub struct BulkBitReader<R: Read> {
    /// Reader that supplies each chunk of up to 64 bits.
    inner: BitReader<R>,
}
impl<R: Read> BulkBitReader<R> {
pub fn new(inner: R) -> Self {
Self {
inner: BitReader::new(inner),
}
}
pub fn with_endianness(endianness: ByteOrder, inner: R) -> Self {
Self {
inner: BitReader::with_byte_order(endianness, inner),
}
}
}
impl<R: Read> BitRead for BulkBitReader<R> {
    type Output = Vec<u64>;

    /// Reads `n` bits (any positive count) as consecutive chunks in stream order.
    /// Every chunk holds 64 bits except the last, which carries whatever remains
    /// (1..=64 bits) in its low bits.
    ///
    /// # Errors
    /// `BitReadWriteError::InvalidBitCount(0)` when `n` is 0. Otherwise, the
    /// first error from the underlying reader; chunks read before it remain
    /// consumed.
    fn read_bits(&mut self, n: usize) -> std::io::Result<Self::Output> {
        if n == 0 {
            return Err(BitReadWriteError::InvalidBitCount(n).into());
        }
        let reader = &mut self.inner;
        (0..n)
            .step_by(64)
            .map(|offset| reader.read_bits((n - offset).min(64)))
            .collect()
    }
}