use crate::bits::traits::BitRead;
use crate::error::{Error, Result};
/// Bit reader over an in-memory byte slice.
///
/// Bytes are consumed in slice order and bits within each byte LSB-first. Whole bytes are
/// staged in a 64-bit buffer so most reads avoid touching the slice.
///
/// Cloning is cheap (it copies the cursor and buffer, and shares the borrowed slice), which
/// makes it easy to checkpoint a position and backtrack.
#[derive(Debug, Clone)]
pub struct SliceBitReader<'a> {
    /// The complete input being read.
    data: &'a [u8],
    /// Index of the next byte to load from `data` into `buffer`.
    pos: usize,
    /// Staged bits; bit 0 is the next bit to be returned.
    buffer: u64,
    /// Number of valid low bits in `buffer`.
    bits_available: u8,
}
impl<'a> SliceBitReader<'a> {
    /// Creates a reader positioned at the first bit of `data`.
    pub fn new(data: &'a [u8]) -> Self {
        Self { data, pos: 0, buffer: 0, bits_available: 0 }
    }

    /// Loads as many whole bytes from the slice into the bit buffer as will fit.
    ///
    /// New bytes are placed above the bits already buffered, in little-endian order. The
    /// buffer is capped at 63 bits, so at most 7 bytes are loaded per call. Nothing happens
    /// when no whole byte fits or when the slice is exhausted.
    #[inline(always)]
    fn refill(&mut self) {
        if self.bits_available > 56 {
            return;
        }
        // Capping at 63 (not 64) bits keeps every shift below the word width.
        let bytes_can_consume = usize::from((63 - self.bits_available) / 8);
        let bytes_remaining = self.data.len().saturating_sub(self.pos);
        let bytes_to_read = bytes_can_consume.min(bytes_remaining);
        if bytes_to_read == 0 {
            return;
        }

        let mut word = [0u8; 8];
        let raw = if bytes_remaining >= 8 {
            // Fast path: a fixed-size 8-byte copy compiles to one unaligned load without any
            // `unsafe`; `bytes_remaining >= 8` keeps the range in bounds. Only the first
            // `bytes_to_read` bytes are kept. Because `bytes_to_read <= 7`, the mask shift is
            // at most 56.
            word.copy_from_slice(&self.data[self.pos..self.pos + 8]);
            u64::from_le_bytes(word) & ((1u64 << (bytes_to_read * 8)) - 1)
        } else {
            // Near the end of the slice: copy just the remaining bytes; the rest stay zero.
            word[..bytes_to_read].copy_from_slice(&self.data[self.pos..self.pos + bytes_to_read]);
            u64::from_le_bytes(word)
        };

        self.buffer |= raw << self.bits_available;
        self.pos += bytes_to_read;
        self.bits_available += (bytes_to_read * 8) as u8;
    }

    /// Returns the index of the next byte that will be loaded from the slice.
    ///
    /// This includes bytes already loaded into the bit buffer but not yet consumed. It can
    /// therefore be ahead of the logical read position; use [`Self::bit_position`] for that.
    pub fn position(&self) -> usize {
        self.pos
    }

    /// Returns the logical read position as `(byte_index, bit_offset)`.
    ///
    /// `bit_offset` (0..8) is the number of bits already consumed from `byte_index`. Bytes
    /// that are buffered but not yet consumed are subtracted out.
    pub fn bit_position(&self) -> (usize, u8) {
        let buffered_bytes = (self.bits_available / 8) as usize;
        let remaining_bits = self.bits_available % 8;
        let effective_byte = self.pos - buffered_bytes;
        if remaining_bits > 0 {
            // A partially consumed byte is still buffered; it sits just before the
            // fully buffered ones.
            (effective_byte - 1, 8 - remaining_bits)
        } else {
            (effective_byte, 0)
        }
    }

    /// Moves the reader to the start of byte `pos`, discarding any buffered bits.
    ///
    /// `pos` is not validated. At or beyond the end of the slice nothing can be loaded.
    pub fn set_position(&mut self, pos: usize) {
        self.pos = pos;
        self.buffer = 0;
        self.bits_available = 0;
    }

    /// Moves the reader to bit `bit_offset` (0..8) of byte `byte_pos`.
    ///
    /// The byte is loaded immediately and its first `bit_offset` bits are discarded. If
    /// fewer than `bit_offset` bits can be loaded, the offset is not applied. That only
    /// happens when `byte_pos` is at or past the end of the slice.
    pub fn set_bit_position(&mut self, byte_pos: usize, bit_offset: u8) {
        debug_assert!(bit_offset < 8);
        self.pos = byte_pos;
        self.buffer = 0;
        self.bits_available = 0;
        if bit_offset > 0 {
            self.refill();
            if self.bits_available >= bit_offset {
                self.buffer >>= bit_offset;
                self.bits_available -= bit_offset;
            }
        }
    }

    /// Returns `true` when the byte holding the next unread bit has index `>= byte_pos`.
    ///
    /// This is equivalent to `bit_position().0 >= byte_pos`.
    pub fn past_position(&self, byte_pos: usize) -> bool {
        // Round up so a partially consumed byte counts as the current byte.
        let buffered_bytes = ((self.bits_available + 7) / 8) as usize;
        self.pos.saturating_sub(buffered_bytes) >= byte_pos
    }
}
impl<'a> BitRead for SliceBitReader<'a> {
    /// Ensures at least `n` bits are buffered, refilling from the slice when necessary.
    ///
    /// Returns `Error::UnexpectedEof` if the slice cannot supply enough bits. The buffer holds
    /// at most 63 bits and is refilled in whole bytes. A request for more than 56 bits can
    /// therefore fail even while data remains.
    #[inline(always)]
    fn fill_buffer(&mut self, n: u8) -> Result<()> {
        if self.bits_available >= n {
            return Ok(());
        }
        self.refill();
        if self.bits_available >= n {
            Ok(())
        } else {
            Err(Error::UnexpectedEof)
        }
    }

    /// Reads and consumes the next `n` bits (at most 32), LSB-first.
    ///
    /// `n == 0` returns `0` without touching the input. Returns `Error::UnexpectedEof` if
    /// fewer than `n` bits remain; in that case nothing is consumed.
    #[inline(always)]
    fn read_bits(&mut self, n: u8) -> Result<u32> {
        debug_assert!(n <= 32, "Cannot read more than 32 bits at once");
        if n == 0 {
            return Ok(0);
        }
        self.fill_buffer(n)?;
        let mask = (1u64 << n) - 1;
        let result = (self.buffer & mask) as u32;
        self.buffer >>= n;
        self.bits_available -= n;
        Ok(result)
    }

    /// Returns the next `n` bits (at most 32) without consuming them.
    ///
    /// Returns `Error::UnexpectedEof` if fewer than `n` bits remain.
    #[inline(always)]
    fn peek_bits(&mut self, n: u8) -> Result<u32> {
        debug_assert!(n <= 32, "Cannot peek more than 32 bits at once");
        if n == 0 {
            return Ok(0);
        }
        self.fill_buffer(n)?;
        let mask = (1u64 << n) - 1;
        Ok((self.buffer & mask) as u32)
    }

    /// Discards `n` already-buffered bits, typically after a successful `peek_bits`.
    ///
    /// Consuming more than is buffered is a caller bug. Debug builds assert on it. Release
    /// builds clamp to the buffered amount. Without the clamp, `bits_available` would wrap
    /// around and every later read would "succeed" with zero bits instead of reporting EOF.
    #[inline(always)]
    fn consume_bits(&mut self, n: u8) {
        debug_assert!(n <= self.bits_available, "Cannot consume more bits than available");
        let n = n.min(self.bits_available);
        self.buffer >>= n;
        self.bits_available -= n;
    }

    /// Skips the remaining bits of the current partially consumed byte, if any.
    #[inline]
    fn align_to_byte(&mut self) {
        // The buffer is filled in whole bytes. The bits beyond a multiple of 8 are therefore
        // exactly those left in the current partial byte.
        let discard = self.bits_available % 8;
        if discard > 0 {
            self.buffer >>= discard;
            self.bits_available -= discard;
        }
    }

    /// Number of bytes loaded from the slice so far.
    ///
    /// This includes bytes still held unconsumed in the bit buffer, so it is the internal
    /// cursor rather than the logical read position.
    fn bytes_read(&self) -> u64 {
        self.pos as u64
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_read_bits() {
        let data = [0xD3_u8, 0xAA, 0, 0, 0, 0, 0, 0];
        let mut reader = SliceBitReader::new(&data);
        assert_eq!(reader.read_bits(3).unwrap(), 0b011);
        assert_eq!(reader.read_bits(5).unwrap(), 0b11010);
        assert_eq!(reader.read_bits(8).unwrap(), 0xAA);
    }

    #[test]
    fn test_read_bit() {
        let data = [0b1011_0001_u8, 0, 0, 0, 0, 0, 0, 0];
        let mut reader = SliceBitReader::new(&data);
        // Bits come out LSB-first.
        let expected = [true, false, false, false, true, true, false, true];
        for (i, &bit) in expected.iter().enumerate() {
            assert_eq!(reader.read_bit().unwrap(), bit, "bit {}", i);
        }
    }

    #[test]
    fn test_peek_consume() {
        let data = [0xFF_u8, 0x00, 0, 0, 0, 0, 0, 0];
        let mut reader = SliceBitReader::new(&data);
        let peeked = reader.peek_bits(10).unwrap();
        assert_eq!(peeked, 0x0FF);
        reader.consume_bits(4);
        let next = reader.read_bits(8).unwrap();
        assert_eq!(next, 0x0F);
    }

    #[test]
    fn test_align_to_byte() {
        let data = [0xFF_u8, 0xAB, 0, 0, 0, 0, 0, 0];
        let mut reader = SliceBitReader::new(&data);
        reader.read_bits(3).unwrap();
        reader.align_to_byte();
        assert_eq!(reader.read_bits(8).unwrap(), 0xAB);
    }

    #[test]
    fn test_cross_byte_boundary() {
        let data = [0xFF_u8, 0x00, 0, 0, 0, 0, 0, 0];
        let mut reader = SliceBitReader::new(&data);
        assert_eq!(reader.read_bits(12).unwrap(), 0x0FF);
    }

    #[test]
    fn test_u16_le() {
        let data = [0x34_u8, 0x12, 0, 0, 0, 0, 0, 0];
        let mut reader = SliceBitReader::new(&data);
        assert_eq!(reader.read_u16_le().unwrap(), 0x1234);
    }

    #[test]
    fn test_u32_le() {
        let data = [0x78_u8, 0x56, 0x34, 0x12, 0, 0, 0, 0];
        let mut reader = SliceBitReader::new(&data);
        assert_eq!(reader.read_u32_le().unwrap(), 0x1234_5678);
    }

    #[test]
    fn test_short_slice_uses_byte_path() {
        // Fewer than 8 bytes: refill must fall back to loading the tail byte by byte.
        let data = [0x34_u8, 0x12, 0xCD];
        let mut reader = SliceBitReader::new(&data);
        assert_eq!(reader.read_bits(16).unwrap(), 0x1234);
        assert_eq!(reader.read_bits(8).unwrap(), 0xCD);
    }

    #[test]
    fn test_refill_across_fast_and_tail_paths() {
        let data: Vec<u8> = (0..16).collect();
        let mut reader = SliceBitReader::new(&data);
        assert_eq!(reader.read_bits(32).unwrap(), 0x0302_0100);
        assert_eq!(reader.read_bits(32).unwrap(), 0x0706_0504);
        assert_eq!(reader.read_bits(32).unwrap(), 0x0B0A_0908);
        assert_eq!(reader.read_bits(32).unwrap(), 0x0F0E_0D0C);
        assert!(reader.read_bits(1).is_err());
    }

    #[test]
    fn test_eof_does_not_consume() {
        let data = [0xAB_u8];
        let mut reader = SliceBitReader::new(&data);
        assert!(matches!(reader.read_bits(9), Err(Error::UnexpectedEof)));
        assert_eq!(reader.read_bits(8).unwrap(), 0xAB);
        assert!(matches!(reader.read_bits(1), Err(Error::UnexpectedEof)));
    }

    #[test]
    fn test_set_bit_position_round_trip() {
        let data = [0x00_u8, 0b1010_1000, 0xFF, 0, 0, 0, 0, 0, 0, 0];
        let mut reader = SliceBitReader::new(&data);
        reader.set_bit_position(1, 3);
        assert_eq!(reader.bit_position(), (1, 3));
        assert_eq!(reader.read_bits(5).unwrap(), 0b10101);
        assert_eq!(reader.bit_position(), (2, 0));
        assert_eq!(reader.read_bits(8).unwrap(), 0xFF);
    }

    #[test]
    fn test_past_position_partial_byte() {
        let data = [0xFF_u8; 16];
        let mut reader = SliceBitReader::new(&data);
        reader.read_bits(3).unwrap();
        let (byte, bit) = reader.bit_position();
        assert_eq!((byte, bit), (0, 3));
        assert!(!reader.past_position(1));
    }

    #[test]
    fn test_past_position_full_bytes() {
        let data = [0xFF_u8; 16];
        let mut reader = SliceBitReader::new(&data);
        reader.read_bits(16).unwrap();
        assert!(reader.past_position(2));
        assert!(!reader.past_position(3));
    }
}