use crate::error::Error;
pub struct RevBitReader<'a> {
data: &'a [u8],
available: usize,
consumed: usize,
}
impl<'a> RevBitReader<'a> {
pub fn new(data: &'a [u8]) -> Result<Self, Error> {
if data.is_empty() {
return Err(Error::Corrupt);
}
let last = *data.last().unwrap();
if last == 0 {
return Err(Error::Corrupt);
}
let marker_pos = 7 - last.leading_zeros() as usize;
let available = (data.len() - 1) * 8 + marker_pos;
Ok(Self {
data,
available,
consumed: 0,
})
}
pub fn remaining(&self) -> usize {
self.available - self.consumed
}
#[allow(dead_code)]
pub fn is_empty(&self) -> bool {
self.consumed >= self.available
}
pub fn unread(&mut self, n: u32) {
let n = n as usize;
debug_assert!(self.consumed >= n);
self.consumed -= n;
}
pub fn read(&mut self, n: u32) -> Result<u64, Error> {
if n == 0 {
return Ok(0);
}
if n > 64 {
return Err(Error::Corrupt);
}
if self.consumed + n as usize > self.available {
return Err(Error::Corrupt);
}
let high_bit = self.available - 1 - self.consumed;
let low_bit = high_bit + 1 - n as usize;
let mut acc: u64 = 0;
let mut bits_read: u32 = 0;
let mut cur_bit = high_bit;
while bits_read < n {
let byte_idx = cur_bit / 8;
let bit_in_byte = cur_bit % 8; let take_from_this_byte = core::cmp::min(bit_in_byte as u32 + 1, n - bits_read);
let byte = self.data[byte_idx] as u64;
let shift_down = bit_in_byte as u32 + 1 - take_from_this_byte;
let mask = (1u64 << take_from_this_byte) - 1;
let chunk = (byte >> shift_down) & mask;
acc = (acc << take_from_this_byte) | chunk;
bits_read += take_from_this_byte;
if bits_read == n {
break;
}
cur_bit = (byte_idx * 8) - 1; }
let _ = low_bit;
self.consumed += n as usize;
Ok(acc)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A marker in bit 0 of the last byte leaves exactly the preceding bytes as payload.
    #[test]
    fn marker_in_last_byte() {
        let data = [0xAB, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        assert_eq!(r.remaining(), 8);
        let v = r.read(8).unwrap();
        assert_eq!(v, 0xAB);
        assert!(r.is_empty());
    }

    /// A marker above bit 0 contributes the bits below it, read before the earlier bytes.
    #[test]
    fn read_individual_bits() {
        let data = [0b1011_0100, 0b0000_0010];
        let mut r = RevBitReader::new(&data).unwrap();
        assert_eq!(r.remaining(), 9);
        assert_eq!(r.read(1).unwrap(), 0);
        assert_eq!(r.read(1).unwrap(), 1);
        assert_eq!(r.read(1).unwrap(), 0);
        assert_eq!(r.read(1).unwrap(), 1);
        assert_eq!(r.read(1).unwrap(), 1);
        assert_eq!(r.read(4).unwrap(), 0b0100);
    }

    #[test]
    fn empty_data_corrupt() {
        let r = RevBitReader::new(&[]);
        assert!(r.is_err());
    }

    #[test]
    fn zero_last_byte_corrupt() {
        let r = RevBitReader::new(&[0x01, 0x00]);
        assert!(r.is_err());
    }

    /// A single read may span a byte boundary; bits come out most-significant first.
    #[test]
    fn cross_byte_read() {
        let data = [0xFF, 0xA0, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        assert_eq!(r.remaining(), 16);
        let v = r.read(12).unwrap();
        assert_eq!(v, 0xA0F);
        let v = r.read(4).unwrap();
        assert_eq!(v, 0xF);
    }

    /// The widest legal read (64 bits) spans eight bytes.
    #[test]
    fn full_64_bit_read() {
        let data = [0x88, 0x77, 0x66, 0x55, 0x44, 0x33, 0x22, 0x11, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        assert_eq!(r.remaining(), 64);
        assert_eq!(r.read(64).unwrap(), 0x1122_3344_5566_7788);
        assert!(r.is_empty());
    }

    /// Requests wider than 64 bits are rejected even when enough bits remain.
    #[test]
    fn more_than_64_bits_corrupt() {
        let data = [0, 0, 0, 0, 0, 0, 0, 0, 0x02];
        let mut r = RevBitReader::new(&data).unwrap();
        assert_eq!(r.remaining(), 65);
        assert!(r.read(65).is_err());
        assert_eq!(r.remaining(), 65);
    }

    /// Asking for more bits than remain fails without consuming anything.
    #[test]
    fn over_read_corrupt_and_consumes_nothing() {
        let data = [0xAB, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        assert!(r.read(9).is_err());
        assert_eq!(r.remaining(), 8);
        assert_eq!(r.read(8).unwrap(), 0xAB);
    }

    /// `unread` makes the same bits available again.
    #[test]
    fn unread_rewinds() {
        let data = [0xAB, 0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        assert_eq!(r.read(4).unwrap(), 0xA);
        r.unread(4);
        assert_eq!(r.remaining(), 8);
        assert_eq!(r.read(8).unwrap(), 0xAB);
    }

    /// A lone marker byte yields an empty payload; zero-bit reads still succeed.
    #[test]
    fn marker_only_and_zero_bit_reads() {
        let data = [0x01];
        let mut r = RevBitReader::new(&data).unwrap();
        assert_eq!(r.remaining(), 0);
        assert!(r.is_empty());
        assert_eq!(r.read(0).unwrap(), 0);
        assert!(r.read(1).is_err());
    }
}