use std::io::{self, Read, Seek};
use byteorder::ReadBytesExt;
use crate::bits::Bitset;
/// A reader that yields individual bits (most significant bit of each byte first) from an
/// underlying byte stream.
///
/// At most one partially consumed byte is buffered. Whole bytes can still be read through
/// the [`Read`] implementation, which continues from the current bit offset.
///
/// No trait bound is placed on `R` here; the `impl` blocks that need `Read`/`Seek` require
/// them.
#[derive(Debug)]
pub struct BitReader<R> {
    /// Underlying byte source.
    inner: R,
    /// The byte currently being consumed; only its low `bits_left` bits are still unread.
    buf: u8,
    /// Number of unread bits remaining in `buf`; expected to stay within `0..=8`, where `0`
    /// means the reader is byte-aligned.
    bits_left: u8,
}
impl<R: Read> BitReader<R> {
    /// Creates a bit reader over `inner`, starting byte-aligned with nothing buffered.
    pub const fn new(inner: R) -> Self {
        Self {
            inner,
            buf: 0,
            bits_left: 0,
        }
    }

    /// Returns `true` when no partially consumed byte is buffered, i.e. the next read starts
    /// at a byte boundary of the underlying reader.
    #[inline]
    pub const fn is_byte_aligned(&self) -> bool {
        self.bits_left == 0
    }

    /// Reads one bit; bits are taken from each byte most-significant first.
    ///
    /// # Errors
    ///
    /// Returns any error from fetching the next byte of the underlying reader
    /// (`UnexpectedEof` once it is exhausted).
    #[inline]
    pub fn read_bit(&mut self) -> io::Result<bool> {
        if self.bits_left == 0 {
            self.buf = self.inner.read_u8()?;
            self.bits_left = 8;
        }
        // The unread bits are the low `bits_left` bits of `buf`, so the next one sits at
        // MSB-relative index `8 - bits_left`.
        let bit = self.buf.get_bit_msb(8 - self.bits_left as u32);
        self.bits_left -= 1;
        Ok(bit)
    }

    /// Reads `n` bits (most significant first) and returns them right-aligned in a `u64`.
    ///
    /// Reading zero bits always succeeds and consumes nothing.
    ///
    /// # Panics
    ///
    /// Panics if `n > 64`.
    ///
    /// # Errors
    ///
    /// Returns any error from the underlying reader. If this call had consumed at most 8
    /// bits when the error occurred, those bits are pushed back into the internal buffer,
    /// so they can be recovered with [`take_leftover`](Self::take_leftover) or a later read.
    /// If more than 8 bits had been consumed, they cannot fit in the one-byte buffer. In
    /// that case they are discarded and the reader is left byte-aligned.
    pub fn read_bits(&mut self, mut n: u32) -> io::Result<u64> {
        if n == 0 {
            return Ok(0);
        }
        assert!(n <= 64);
        let mut value: u64 = 0;
        let mut bit_count: u32 = 0;
        while n > 0 {
            if self.bits_left == 0 {
                match self.inner.read_u8() {
                    Ok(byte) => {
                        self.buf = byte;
                        self.bits_left = 8;
                    }
                    Err(e) => {
                        if bit_count <= 8 {
                            // `value` holds exactly the `bit_count` bits consumed so far in
                            // its low bits, matching the buffer layout (unread bits are the
                            // low `bits_left` bits of `buf`).
                            self.buf = value as u8;
                            self.bits_left = bit_count as u8;
                        } else {
                            // Whole bytes were taken from `inner` and cannot be pushed back;
                            // storing more than 8 in `bits_left` would corrupt every other
                            // method, so drop them instead.
                            self.buf = 0;
                            self.bits_left = 0;
                        }
                        return Err(e);
                    }
                }
            }
            // Take as many of the requested bits as the current byte still holds.
            let take = n.min(self.bits_left as u32);
            let shift = self.bits_left as u32 - take;
            let chunk = (self.buf as u64) >> shift;
            let mask = (1u64 << take) - 1;
            value = (value << take) | (chunk & mask);
            self.bits_left -= take as u8;
            bit_count += take;
            n -= take;
        }
        Ok(value)
    }

    /// Reads `n` bits as a two's-complement value and sign-extends it to `i64`.
    ///
    /// `n == 0` returns `0` without consuming anything.
    ///
    /// # Panics
    ///
    /// Panics if `n > 64` (via [`read_bits`](Self::read_bits)).
    ///
    /// # Errors
    ///
    /// Same as [`read_bits`](Self::read_bits).
    #[inline]
    pub fn read_signed(&mut self, n: u8) -> io::Result<i64> {
        debug_assert!((0..=64).contains(&n));
        if n == 0 {
            return Ok(0);
        }
        let raw = self.read_bits(u32::from(n))? as i64;
        // Move the field's sign bit up to bit 63, then arithmetic-shift back down.
        let sh = 64 - n;
        Ok((raw << sh) >> sh)
    }

    /// Removes and returns the buffered, not-yet-read bits as `(bits, count)`, or `None` if
    /// the reader is byte-aligned.
    ///
    /// `bits` is the low `count` bits of the buffered byte, as extracted by
    /// `get_bit_range_lsb`. Afterwards the reader is byte-aligned.
    #[inline]
    pub fn take_leftover(&mut self) -> Option<(u8, u8)> {
        if self.bits_left == 0 {
            return None;
        }
        let count = std::mem::take(&mut self.bits_left);
        let bits = std::mem::take(&mut self.buf).get_bit_range_lsb(0, count.into());
        Some((bits, count))
    }

    /// Drops any buffered bits so the next read starts at the next byte of the underlying
    /// reader.
    #[inline]
    pub fn discard_leftover(&mut self) {
        self.bits_left = 0;
    }

    /// Consumes the bit reader and returns the underlying reader; buffered bits are lost.
    #[inline]
    pub fn into_inner(self) -> R {
        self.inner
    }
}
impl<R: Read> Read for BitReader<R> {
    /// Reads whole bytes starting at the current bit position.
    ///
    /// When the reader is byte-aligned, this forwards straight to the inner reader.
    /// Otherwise, each output byte is the buffered leftover bits followed by the high bits
    /// of the next inner byte. The low bits of the last inner byte consumed become the new
    /// leftover, so `bits_left` is the same after the call as before.
    ///
    /// The misaligned path issues a single `read` on the inner reader. It may therefore
    /// return fewer bytes than `buf.len()`, as the `Read` contract allows. Every inner byte
    /// it consumes is reflected in the returned count.
    ///
    /// # Errors
    ///
    /// Errors from the inner reader are returned before any state changes, so no bits are
    /// lost. If the reader is misaligned and the inner reader reports end of stream, this
    /// returns [`io::ErrorKind::UnexpectedEof`], because the leftover bits do not form a
    /// whole byte. Those bits stay buffered and can still be read bit-wise.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        if self.bits_left == 0 {
            return self.inner.read(buf);
        }

        let shift = u32::from(self.bits_left);
        // Work in u16 so a full 8-bit leftover (possible after an EOF inside `read_bits`)
        // does not overflow the shifts below.
        let low_mask: u16 = (1 << shift) - 1;

        let n = self.inner.read(buf)?;
        if n == 0 {
            return Err(io::ErrorKind::UnexpectedEof.into());
        }

        // Re-align in place: each output byte is `carry` (shift bits) followed by the top
        // `8 - shift` bits of the incoming byte; its low `shift` bits carry over.
        let mut carry = u16::from(self.buf) & low_mask;
        for byte in &mut buf[..n] {
            let incoming = u16::from(*byte);
            *byte = (((carry << 8) | incoming) >> shift) as u8;
            carry = incoming & low_mask;
        }
        self.buf = carry as u8;
        Ok(n)
    }
}
impl<R: Read + Seek> Seek for BitReader<R> {
    /// Seeks the underlying reader. On success, any buffered bits are dropped, so reading
    /// resumes on a byte boundary at the new position.
    ///
    /// `SeekFrom::Current` offsets are relative to the underlying reader's position, i.e. the
    /// byte after any partially consumed byte. If the seek fails, buffered bits are kept.
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        let new_pos = self.inner.seek(pos)?;
        self.bits_left = 0;
        Ok(new_pos)
    }

    /// Returns the underlying reader's position without discarding buffered bits.
    ///
    /// The default implementation calls `seek(SeekFrom::Current(0))`. That would clear the
    /// partially consumed byte as a side effect of what should be a read-only query.
    fn stream_position(&mut self) -> io::Result<u64> {
        self.inner.stream_position()
    }

    /// Seeks relative to the underlying reader's position using its own `seek_relative`.
    /// Buffered bits are dropped only if the seek succeeds.
    fn seek_relative(&mut self, offset: i64) -> io::Result<()> {
        self.inner.seek_relative(offset)?;
        self.bits_left = 0;
        Ok(())
    }
}
// The bit-level tests compare against `true`/`false` literals on purpose, so each run of
// assertions reads like the bit pattern it checks.
#[allow(clippy::bool_assert_comparison)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Cursor, ErrorKind, Read};

    /// Builds a bit reader over an in-memory byte slice.
    fn create_reader(data: &[u8]) -> BitReader<Cursor<&[u8]>> {
        BitReader::new(Cursor::new(data))
    }

    #[test]
    fn read_single_bits() {
        let mut reader = create_reader(&[0b1011_0010]);
        assert_eq!(reader.read_bit().unwrap(), true);
        assert_eq!(reader.read_bit().unwrap(), false);
        assert_eq!(reader.read_bit().unwrap(), true);
        assert_eq!(reader.read_bit().unwrap(), true);
        assert_eq!(reader.read_bit().unwrap(), false);
        assert_eq!(reader.read_bit().unwrap(), false);
        assert_eq!(reader.read_bit().unwrap(), true);
        assert_eq!(reader.read_bit().unwrap(), false);
    }

    #[test]
    fn read_bits_within_byte() {
        let mut reader = create_reader(&[0b1101_0101]);
        assert_eq!(reader.read_bits(4).unwrap(), 0b1101);
        assert_eq!(reader.read_bits(3).unwrap(), 0b010);
        assert_eq!(reader.read_bits(1).unwrap(), 0b1);
    }

    #[test]
    fn read_bits_across_byte_boundary() {
        let mut reader = create_reader(&[0xF0, 0xAC]);
        assert_eq!(reader.read_bits(4).unwrap(), 0b1111);
        assert_eq!(reader.read_bits(8).unwrap(), 0b0000_1010);
        assert_eq!(reader.read_bits(4).unwrap(), 0b1100);
    }

    #[test]
    fn read_bits_64_bits() {
        let data = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00];
        let mut reader = create_reader(&data);
        assert_eq!(reader.read_bits(64).unwrap(), u64::MAX);
        assert_eq!(reader.read_bits(8).unwrap(), 0x00);
    }

    #[test]
    fn read_bits_zero_consumes_nothing() {
        let mut reader = create_reader(&[]);
        assert_eq!(reader.read_bits(0).unwrap(), 0);
        assert!(reader.is_byte_aligned());
    }

    #[test]
    fn read_signed_sign_extends() {
        let mut reader = create_reader(&[0b1110_0111, 0x80]);
        assert_eq!(reader.read_signed(4).unwrap(), -2);
        assert_eq!(reader.read_signed(4).unwrap(), 7);
        assert_eq!(reader.read_signed(0).unwrap(), 0);
        assert_eq!(reader.read_signed(8).unwrap(), -128);
    }

    #[test]
    fn byte_alignment_tracking() {
        let mut reader = create_reader(&[0xFF]);
        assert!(reader.is_byte_aligned());
        reader.read_bits(3).unwrap();
        assert!(!reader.is_byte_aligned());
        reader.read_bits(5).unwrap();
        assert!(reader.is_byte_aligned());
    }

    #[test]
    fn discard_leftover() {
        let mut reader = create_reader(&[0b1111_1111, 0b0101_0101]);
        assert_eq!(reader.read_bits(3).unwrap(), 0b111);
        reader.discard_leftover();
        assert_eq!(reader.read_bits(8).unwrap(), 0b0101_0101);
    }

    #[test]
    fn read_trait_passthrough() {
        let data = [0x01, 0x02, 0x03, 0x04];
        let mut reader = create_reader(&data);
        let mut buf = [0u8; 3];
        let bytes_read = reader.read(&mut buf).unwrap();
        assert_eq!(bytes_read, 3);
        assert_eq!(buf, [0x01, 0x02, 0x03]);
        let cursor = reader.into_inner();
        assert_eq!(cursor.position(), 3);
    }

    #[test]
    fn read_trait_misaligned() {
        let data = [0b1111_0000, 0b1010_1111, 0b0000_1111];
        let mut reader = create_reader(&data);
        assert_eq!(reader.read_bits(4).unwrap(), 0b1111);
        let mut buf = [0u8; 2];
        let bytes_read = reader.read(&mut buf).unwrap();
        assert_eq!(bytes_read, 2);
        assert_eq!(buf[0], 0b0000_1010);
        assert_eq!(buf[1], 0b1111_0000);
    }

    #[test]
    fn read_bit_eof() {
        let mut reader = create_reader(&[0xAA]);
        assert_eq!(reader.read_bits(8).unwrap(), 0xAA);
        assert_eq!(
            reader.read_bit().unwrap_err().kind(),
            ErrorKind::UnexpectedEof
        );
        assert_eq!(reader.take_leftover(), None);
    }

    #[test]
    fn read_trait_partial_eof() {
        let data = [0b1111_0000, 0b1010_1111];
        let mut reader = create_reader(&data);
        reader.read_bits(4).unwrap();
        let mut buf = [0u8; 2];
        let bytes_read = reader.read(&mut buf).unwrap();
        assert_eq!(bytes_read, 1);
        assert_eq!(buf[0], 0b0000_1010);
        assert_eq!(reader.take_leftover().unwrap(), (0b1111, 4));
    }

    #[test]
    fn seek() {
        let mut reader = create_reader(&[0b1011_0010, 0b0101_0101]);
        assert_eq!(reader.read_bit().unwrap(), true);
        assert_eq!(reader.read_bit().unwrap(), false);
        assert_eq!(reader.read_bit().unwrap(), true);
        assert_eq!(reader.read_bit().unwrap(), true);
        reader.seek(io::SeekFrom::Start(1)).unwrap();
        assert_eq!(reader.read_bit().unwrap(), false);
        assert_eq!(reader.read_bit().unwrap(), true);
        assert_eq!(reader.read_bit().unwrap(), false);
        assert_eq!(reader.read_bit().unwrap(), true);
        assert_eq!(reader.read_bit().unwrap(), false);
        assert_eq!(reader.read_bit().unwrap(), true);
    }

    #[test]
    fn seek_relative_clears_leftover() {
        let mut reader = create_reader(&[0b1000_0000, 0xAA, 0x55]);
        assert!(reader.read_bit().unwrap());
        reader.seek_relative(1).unwrap();
        assert!(reader.is_byte_aligned());
        assert_eq!(reader.read_bits(8).unwrap(), 0x55);
    }
}