pub struct BitReader {
ptr: *const u8,
safe_end: *const u8,
end: *const u8,
buf: u64,
bits: u32,
}
unsafe impl Send for BitReader {}
impl BitReader {
#[inline]
pub fn new(input: &[u8]) -> Self {
let ptr = input.as_ptr();
let end = unsafe { ptr.add(input.len()) };
let safe_end = if input.len() >= 8 {
unsafe { end.sub(7) }
} else {
ptr
};
Self {
ptr,
safe_end,
end,
buf: 0,
bits: 0,
}
}
#[inline(always)]
pub unsafe fn refill(&mut self) {
debug_assert!(self.bits <= 63);
if self.ptr < self.safe_end {
unsafe {
let raw = core::ptr::read_unaligned(self.ptr as *const u64);
self.buf |= u64::from_le(raw) << (self.bits as u8);
let advance = ((63 ^ self.bits) >> 3) as usize;
self.ptr = self.ptr.add(advance);
}
self.bits |= 56;
} else {
self.refill_slow();
}
}
#[cold]
#[inline(never)]
fn refill_slow(&mut self) {
while self.bits < 56 && self.ptr < self.end {
self.buf |= (unsafe { *self.ptr } as u64) << self.bits;
self.ptr = unsafe { self.ptr.add(1) };
self.bits += 8;
}
}
#[inline(always)]
pub fn peek(&self, n: u32) -> u32 {
debug_assert!(n <= 32 && n <= self.bits);
(self.buf as u32) & ((1u32 << n) - 1)
}
#[inline(always)]
pub fn peek64(&self, n: u32) -> u64 {
debug_assert!(n <= 56 && n <= self.bits);
self.buf & ((1u64 << n) - 1)
}
#[inline(always)]
pub fn peek_at(&self, skip: u32) -> u32 {
(self.buf >> skip) as u32
}
#[inline(always)]
pub fn consume(&mut self, n: u32) {
debug_assert!(n <= 64);
let n = n.min(self.bits);
self.buf >>= n;
self.bits -= n;
}
#[inline(always)]
pub unsafe fn consume_unchecked(&mut self, n: u32) {
debug_assert!(n <= self.bits);
self.buf >>= n;
self.bits -= n;
}
#[inline(always)]
pub fn take(&mut self, n: u32) -> u32 {
let v = self.peek(n);
self.consume(n);
v
}
#[inline(always)]
pub fn extract_var(value: u64, n: u32) -> u64 {
#[cfg(target_arch = "x86_64")]
{
if cfg!(target_feature = "bmi2") {
return unsafe { core::arch::x86_64::_bzhi_u64(value, n) };
}
}
value & ((1u64 << n) - 1)
}
#[inline(always)]
pub fn bits_remaining(&self) -> u32 {
self.bits
}
#[inline(always)]
pub fn is_empty(&self) -> bool {
self.ptr >= self.end && self.bits == 0
}
#[inline(always)]
pub fn input_ptr(&self) -> *const u8 {
self.ptr
}
#[inline(always)]
pub fn input_end(&self) -> *const u8 {
self.end
}
#[inline(always)]
pub fn align_to_byte(&mut self) {
let discard = self.bits & 7;
self.consume(discard);
}
#[inline(always)]
pub fn take_u16(&mut self) -> u16 {
debug_assert!(self.bits >= 16);
let v = (self.buf as u16).to_le();
self.consume(16);
v
}
#[inline(always)]
pub fn raw_buf(&self) -> u64 {
self.buf
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates a reader over `bytes` and performs the initial refill.
    ///
    /// The caller keeps `bytes` alive for as long as the returned reader is
    /// refilled again.
    fn primed(bytes: &[u8]) -> BitReader {
        let mut reader = BitReader::new(bytes);
        // SAFETY: `bytes` is borrowed for the duration of this call.
        unsafe { reader.refill() };
        reader
    }

    /// Bits come out LSB-first: the low nibble of the first byte is read
    /// before its high nibble.
    #[test]
    fn basic_read() {
        let bytes = [0b10110100u8, 0b01101001u8, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF];
        let mut reader = primed(&bytes);
        assert!(reader.bits_remaining() >= 56);
        assert_eq!(reader.take(4), 0b0100);
        assert_eq!(reader.take(4), 0b1011);
        assert_eq!(reader.take(8), 0b01101001);
    }

    /// With plenty of input left, every refill restores at least 56 bits.
    #[test]
    fn refill_guarantees_56_bits() {
        let bytes = vec![0xAAu8; 64];
        let mut reader = primed(&bytes);
        assert!(reader.bits_remaining() >= 56);
        reader.consume(48);
        // SAFETY: `bytes` is still alive.
        unsafe { reader.refill() };
        assert!(reader.bits_remaining() >= 56);
    }

    /// Inputs shorter than eight bytes are loaded by the byte-wise path.
    #[test]
    fn small_input() {
        let bytes = [0x42u8, 0x37];
        let mut reader = primed(&bytes);
        assert!(reader.bits_remaining() >= 16);
        assert_eq!(reader.take(8), 0x42);
        assert_eq!(reader.take(8), 0x37);
    }

    /// After a partial-byte read, aligning leaves a whole number of bytes.
    #[test]
    fn align_to_byte() {
        let bytes = [0xFF; 8];
        let mut reader = primed(&bytes);
        reader.consume(3);
        reader.align_to_byte();
        assert_eq!(reader.bits_remaining() % 8, 0);
    }

    /// `extract_var` keeps exactly the requested number of low bits.
    #[test]
    fn extract_var_matches_mask() {
        assert_eq!(BitReader::extract_var(0xDEADBEEF, 8), 0xEF);
        assert_eq!(BitReader::extract_var(0xDEADBEEF, 16), 0xBEEF);
        assert_eq!(BitReader::extract_var(0xDEADBEEF, 32), 0xDEADBEEF);
    }

    /// Peeking twice yields the same bits.
    #[test]
    fn peek_does_not_consume() {
        let bytes = [0xAB; 8];
        let reader = primed(&bytes);
        let first = reader.peek(8);
        let second = reader.peek(8);
        assert_eq!(first, second);
        assert_eq!(first, 0xAB);
    }

    /// A reader over an empty slice is empty before any refill.
    #[test]
    fn empty_input() {
        let bytes: &[u8] = &[];
        let reader = BitReader::new(bytes);
        assert!(reader.is_empty());
    }

    /// Ten bytes: the first refill uses the 8-byte load, later refills go
    /// through the byte-wise path and must never push the count past 63.
    #[test]
    fn refill_slow_at_56_bits_keeps_invariant() {
        let bytes = [0x11u8, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88, 0x99, 0xAA];
        let mut reader = primed(&bytes);
        assert_eq!(reader.bits_remaining(), 56);

        // SAFETY: `bytes` is still alive.
        unsafe { reader.refill() };
        assert!(reader.bits_remaining() <= 63);
        assert_eq!(reader.take(8), 0x11);
        assert_eq!(reader.take(8), 0x22);

        // SAFETY: `bytes` is still alive.
        unsafe { reader.refill() };
        assert!(reader.bits_remaining() <= 63);
        assert_eq!(reader.take(8), 0x33);
    }
}