use crate::error::Error;
pub struct RevBitReader<'a> {
data: &'a [u8],
available: usize,
consumed: usize,
acc: u64,
bits_in_acc: u32,
bytes_left: usize,
}
impl<'a> RevBitReader<'a> {
pub fn new(data: &'a [u8]) -> Result<Self, Error> {
if data.is_empty() {
return Err(Error::Corrupt);
}
let last = *data.last().unwrap();
if last == 0 {
return Err(Error::Corrupt);
}
let marker_pos = 7 - last.leading_zeros() as usize;
let available = (data.len() - 1) * 8 + marker_pos;
let mut acc: u64 = 0;
let mut bits_in_acc: u32 = 0;
if marker_pos > 0 {
let payload = (last as u64) & ((1u64 << marker_pos) - 1);
acc = payload << (64 - marker_pos as u32);
bits_in_acc = marker_pos as u32;
}
Ok(Self {
data,
available,
consumed: 0,
acc,
bits_in_acc,
bytes_left: data.len() - 1,
})
}
pub fn remaining(&self) -> usize {
self.available - self.consumed
}
#[allow(dead_code)]
pub fn is_empty(&self) -> bool {
self.consumed >= self.available
}
pub fn unread(&mut self, n: u32) {
let n_usize = n as usize;
debug_assert!(self.consumed >= n_usize);
self.consumed -= n_usize;
self.reseed_from_consumed();
}
fn reseed_from_consumed(&mut self) {
let next_bit = self.available - 1 - self.consumed;
let next_byte = next_bit / 8;
let bit_in_byte = (next_bit % 8) as u32; let high_byte_val = self.data[next_byte] as u64;
let take = bit_in_byte + 1;
let payload = high_byte_val & ((1u64 << take) - 1);
self.acc = payload << (64 - take);
self.bits_in_acc = take;
self.bytes_left = next_byte;
}
#[inline]
fn refill(&mut self) {
while self.bits_in_acc <= 56 && self.bytes_left > 0 {
let byte = self.data[self.bytes_left - 1] as u64;
self.acc |= byte << (56 - self.bits_in_acc);
self.bits_in_acc += 8;
self.bytes_left -= 1;
}
}
pub fn read(&mut self, n: u32) -> Result<u64, Error> {
if n == 0 {
return Ok(0);
}
if n > 64 {
return Err(Error::Corrupt);
}
if self.consumed + n as usize > self.available {
return Err(Error::Corrupt);
}
if n <= 56 {
if self.bits_in_acc < n {
self.refill();
}
let result = self.acc >> (64 - n);
self.acc <<= n;
self.bits_in_acc -= n;
self.consumed += n as usize;
Ok(result)
} else {
let high_n = 56u32;
let low_n = n - 56;
if self.bits_in_acc < high_n {
self.refill();
}
let high = self.acc >> (64 - high_n);
self.acc <<= high_n;
self.bits_in_acc -= high_n;
if self.bits_in_acc < low_n {
self.refill();
}
let low = self.acc >> (64 - low_n);
self.acc <<= low_n;
self.bits_in_acc -= low_n;
self.consumed += n as usize;
Ok((high << low_n) | low)
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A final byte of 0x01 is a bare sentinel: all 8 bits of the first byte
    /// are payload.
    #[test]
    fn marker_in_last_byte() {
        let data = [0xAB, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        assert_eq!(r.remaining(), 8);
        let v = r.read(8).unwrap();
        assert_eq!(v, 0xAB);
        assert!(r.is_empty());
    }

    /// Sentinel in bit 7: the seven bits below it are payload.
    #[test]
    fn sentinel_in_top_bit() {
        let data = [0xFF, 0x80];
        let mut r = RevBitReader::new(&data).unwrap();
        assert_eq!(r.remaining(), 15);
        assert_eq!(r.read(7).unwrap(), 0);
        assert_eq!(r.read(8).unwrap(), 0xFF);
        assert!(r.is_empty());
    }

    /// The payload bit below the sentinel comes first, then data[0] from bit 7
    /// down.
    #[test]
    fn read_individual_bits() {
        let data = [0b1011_0100, 0b0000_0010];
        let mut r = RevBitReader::new(&data).unwrap();
        assert_eq!(r.remaining(), 9);
        assert_eq!(r.read(1).unwrap(), 0);
        assert_eq!(r.read(1).unwrap(), 1);
        assert_eq!(r.read(1).unwrap(), 0);
        assert_eq!(r.read(1).unwrap(), 1);
        assert_eq!(r.read(1).unwrap(), 1);
        assert_eq!(r.read(4).unwrap(), 0b0100);
    }

    #[test]
    fn empty_data_corrupt() {
        let r = RevBitReader::new(&[]);
        assert!(r.is_err());
    }

    #[test]
    fn zero_last_byte_corrupt() {
        let r = RevBitReader::new(&[0x01, 0x00]);
        assert!(r.is_err());
    }

    #[test]
    fn cross_byte_read() {
        let data = [0xFF, 0xA0, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        assert_eq!(r.remaining(), 16);
        let v = r.read(12).unwrap();
        assert_eq!(v, 0xA0F);
        let v = r.read(4).unwrap();
        assert_eq!(v, 0xF);
    }

    #[test]
    fn wide_read_64_bits() {
        let data = [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        assert_eq!(r.remaining(), 64);
        let v = r.read(64).unwrap();
        assert_eq!(v, 0xEFCD_AB89_6745_2301);
        assert!(r.is_empty());
    }

    #[test]
    fn wide_read_60_bits_then_4() {
        let data = [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        let high = r.read(60).unwrap();
        let low = r.read(4).unwrap();
        let combined = (high << 4) | low;
        assert_eq!(combined, 0xEFCD_AB89_6745_2301);
    }

    /// 57 bits is the narrowest read that takes the split (wide) path.
    #[test]
    fn wide_read_57_bits_then_7() {
        let data = [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        let high = r.read(57).unwrap();
        let low = r.read(7).unwrap();
        assert_eq!((high << 7) | low, 0xEFCD_AB89_6745_2301);
        assert!(r.is_empty());
    }

    /// A read that overruns the stream fails without consuming anything.
    #[test]
    fn read_past_end_is_error_and_non_consuming() {
        let data = [0xAB, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        assert!(r.read(9).is_err());
        assert_eq!(r.remaining(), 8);
        assert_eq!(r.read(8).unwrap(), 0xAB);
        assert!(r.read(1).is_err());
    }

    #[test]
    fn oversized_read_is_error() {
        let data = [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x23, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        assert!(r.read(65).is_err());
        assert_eq!(r.remaining(), 72);
    }

    #[test]
    fn zero_width_read() {
        let data = [0xAB, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        assert_eq!(r.read(0).unwrap(), 0);
        assert_eq!(r.remaining(), 8);
    }

    #[test]
    fn unread_round_trip() {
        let data = [0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        let first12 = r.read(12).unwrap();
        assert_eq!(first12, 0xEFC);
        r.unread(4);
        let nibble = r.read(4).unwrap();
        assert_eq!(nibble, 0xC);
        let next8 = r.read(8).unwrap();
        assert_eq!(next8, 0xDA);
    }
}