use std::io::{self, Read, Seek, Write};
use byteorder::ReadBytesExt;
use crate::bits::Bitset;
/// Bit-granular reading on top of a byte-oriented [`Read`] source.
///
/// The method descriptions below follow the behaviour of the `BitReader` implementation in
/// this file; other implementors are expected to match it.
pub trait BitRead: Read {
    // --- position ---------------------------------------------------------------------------

    /// Returns `true` when the reader sits on a byte boundary, i.e. no partially consumed
    /// byte is buffered.
    fn is_byte_aligned(&self) -> bool;

    // --- reading ----------------------------------------------------------------------------

    /// Reads a single bit.
    fn read_bit(&mut self) -> io::Result<bool>;

    /// Reads `n` bits and returns them right-aligned in a `u64`; the first bit read ends up
    /// as the most significant bit of the result.
    fn read_bits(&mut self, n: u32) -> io::Result<u64>;

    /// Reads an `n`-bit value and sign-extends it to an `i64`.
    fn read_signed(&mut self, n: u8) -> io::Result<i64>;

    // --- leftover handling ------------------------------------------------------------------

    /// Removes the unread remainder of the buffered byte and returns it as `(bits, count)`,
    /// or `None` when the reader is byte aligned.
    fn take_leftover(&mut self) -> Option<(u8, u8)>;

    /// Drops the unread remainder of the buffered byte so the reader becomes byte aligned.
    fn discard_leftover(&mut self);
}
/// A [`Read`] adapter that can hand out individual bits of the wrapped reader.
///
/// One byte at a time is pulled from `inner` into `buf`; `bits_left` records how many bits of
/// that byte have not been consumed yet. Field order is significant for the derived `Debug`
/// output.
#[derive(Debug)]
pub struct BitReader<R: Read> {
/// The wrapped byte source.
inner: R,
/// The most recently buffered byte; its low `bits_left` bits are still unread.
buf: u8,
/// Number of unread bits in `buf`; zero means the reader is byte aligned.
bits_left: u8,
}
impl<R: Read> BitReader<R> {
    /// Wraps `inner` in a reader that starts byte aligned with nothing buffered.
    pub const fn new(inner: R) -> Self {
        Self { inner, buf: 0, bits_left: 0 }
    }

    /// Consumes the adapter and returns the wrapped reader.
    ///
    /// Buffered bits that have not been consumed are dropped.
    #[inline]
    pub fn into_inner(self) -> R {
        let Self { inner, .. } = self;
        inner
    }
}
impl<R: Read> Read for BitReader<R> {
    /// Fills `buf` with bytes starting at the current bit position.
    ///
    /// When the reader is byte aligned the call is forwarded to the inner reader unchanged.
    /// Otherwise each output byte is assembled from eight bits via `read_bits`.
    ///
    /// # Errors
    ///
    /// In the misaligned case an `UnexpectedEof` that occurs after at least one byte was
    /// produced ends the read early and the short count is returned. Any other error, or an
    /// `UnexpectedEof` before the first byte, is returned to the caller.
    #[inline]
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if buf.is_empty() {
            return Ok(0);
        }
        if self.bits_left == 0 {
            // Aligned: no bit shuffling needed, let the inner reader do the work.
            return self.inner.read(buf);
        }

        let mut filled = 0;
        for slot in buf.iter_mut() {
            match self.read_bits(8) {
                Ok(bits) => {
                    *slot = bits as u8;
                    filled += 1;
                }
                // Running dry after producing something is a short read, not a failure.
                Err(err) if filled > 0 && err.kind() == io::ErrorKind::UnexpectedEof => break,
                Err(err) => return Err(err),
            }
        }
        Ok(filled)
    }
}
impl<R: Read + Seek> Seek for BitReader<R> {
    /// Seeks the inner reader to `pos`, discarding any buffered bits first so the reader is
    /// byte aligned afterwards.
    #[inline]
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        let Self { inner, bits_left, .. } = self;
        *bits_left = 0;
        inner.seek(pos)
    }

    /// Seeks the inner reader by `offset` bytes relative to its current position, discarding
    /// any buffered bits first.
    #[inline]
    fn seek_relative(&mut self, offset: i64) -> io::Result<()> {
        let Self { inner, bits_left, .. } = self;
        *bits_left = 0;
        inner.seek_relative(offset)
    }
}
impl<R: Read> BitRead for BitReader<R> {
    /// Returns `true` when no partially consumed byte is buffered.
    #[inline]
    fn is_byte_aligned(&self) -> bool {
        self.bits_left == 0
    }

    /// Reads the next bit, pulling a fresh byte from the inner reader when the buffer is empty.
    ///
    /// # Errors
    ///
    /// Propagates the error of the inner reader when a new byte is needed and cannot be read.
    #[inline]
    fn read_bit(&mut self) -> io::Result<bool> {
        if self.bits_left == 0 {
            self.buf = self.inner.read_u8()?;
            self.bits_left = 8;
        }
        // Index 0 is used for the first bit of a fresh byte, 7 for the last one.
        let bit = self.buf.get_bit_msb(8 - u32::from(self.bits_left));
        self.bits_left -= 1;
        Ok(bit)
    }

    /// Reads `n` bits and returns them right-aligned in a `u64`; earlier bits end up in more
    /// significant positions. `n == 0` returns `0` without touching the reader.
    ///
    /// # Errors
    ///
    /// If the inner reader fails part-way through, the bits consumed so far are put back into
    /// the buffer so they remain available through `take_leftover`. At most 8 bits fit in the
    /// buffer; when more than 8 bits had been consumed only the most recent 8 are kept.
    ///
    /// # Panics
    ///
    /// Panics if `n > 64`.
    #[inline]
    fn read_bits(&mut self, mut n: u32) -> io::Result<u64> {
        if n == 0 {
            return Ok(0);
        }
        assert!(n <= 64);
        let mut value: u64 = 0;
        let mut bit_count: u32 = 0;
        while n > 0 {
            if self.bits_left == 0 {
                match self.inner.read_u8() {
                    Ok(byte) => {
                        self.buf = byte;
                        self.bits_left = 8;
                    }
                    Err(e) => {
                        // Put the consumed bits back. `buf` holds a single byte, so the
                        // count is clamped to 8; without the clamp `bits_left` could exceed
                        // 8 and break the `8 - bits_left` / `bits_left - take` arithmetic
                        // of subsequent reads.
                        self.buf = value as u8;
                        self.bits_left = bit_count.min(8) as u8;
                        return Err(e);
                    }
                }
            }
            // Take as many bits as are both wanted and available in the buffered byte.
            let take = n.min(u32::from(self.bits_left));
            // Unread bits occupy the low `bits_left` bits; the next ones are at the top of
            // that range.
            let shift = u32::from(self.bits_left) - take;
            let chunk = u64::from(self.buf) >> shift;
            let mask = (1u64 << take) - 1;
            value = (value << take) | (chunk & mask);
            self.bits_left -= take as u8;
            bit_count += take;
            n -= take;
        }
        Ok(value)
    }

    /// Reads an `n`-bit value and sign-extends it to an `i64`. `n == 0` returns `0`.
    ///
    /// # Errors
    ///
    /// Propagates any error from `read_bits`.
    ///
    /// # Panics
    ///
    /// Panics if `n > 64` (checked here in debug builds and by `read_bits` otherwise).
    #[inline]
    fn read_signed(&mut self, n: u8) -> io::Result<i64> {
        debug_assert!((0..=64).contains(&n));
        if n == 0 {
            return Ok(0);
        }
        let raw = self.read_bits(u32::from(n))? as i64;
        // Move the value's top bit into bit 63, then shift back arithmetically so that bit is
        // replicated into all higher positions.
        let sh = 64 - n;
        Ok((raw << sh) >> sh)
    }

    /// Removes the unread remainder of the buffered byte and returns it with its bit count,
    /// leaving the reader byte aligned. Returns `None` when nothing is buffered.
    ///
    /// The bits are extracted with `get_bit_range_lsb(0, bits_left)` — presumably the low
    /// `bits_left` bits of the buffer; confirm against `crate::bits::Bitset`.
    #[inline]
    fn take_leftover(&mut self) -> Option<(u8, u8)> {
        if self.bits_left == 0 {
            return None;
        }
        let count = std::mem::take(&mut self.bits_left);
        let bits = std::mem::take(&mut self.buf).get_bit_range_lsb(0, count);
        Some((bits, count))
    }

    /// Forgets the unread remainder of the buffered byte so the reader becomes byte aligned.
    #[inline]
    fn discard_leftover(&mut self) {
        self.bits_left = 0;
    }
}
/// Bit-granular writing on top of a byte-oriented [`Write`] sink.
pub trait BitWrite: Write {
    /// Returns `true` when the writer sits on a byte boundary, i.e. no partially filled byte
    /// is pending.
    fn is_byte_aligned(&self) -> bool;

    /// Emits a pending partially filled byte, if any, so the writer becomes byte aligned.
    fn flush_bits(&mut self) -> io::Result<()>;

    /// Writes the low `n` bits of `value`, most significant of those bits first.
    fn write_bits(&mut self, value: u64, n: u32) -> io::Result<()>;

    /// Writes a single bit.
    #[inline]
    fn write_bit(&mut self, bit: bool) -> io::Result<()> {
        self.write_bits(u64::from(bit), 1)
    }

    /// Writes the low `n` bits of the two's-complement representation of `value`.
    #[inline]
    fn write_signed(&mut self, value: i64, n: u8) -> io::Result<()> {
        // `1u64 << 64` would overflow, so the full-width mask is spelled out.
        let mask = if n == 64 { u64::MAX } else { (1u64 << n) - 1 };
        self.write_bits(value as u64 & mask, u32::from(n))
    }
}

/// A [`Write`] adapter that can emit individual bits into the wrapped writer.
///
/// Bits accumulate in `buf` from the most significant end; a byte is forwarded to `inner` as
/// soon as all eight bits are filled.
pub struct BitWriter<W: Write> {
    /// The wrapped byte sink.
    inner: W,
    /// The byte currently being assembled; unused low bits are zero.
    buf: u8,
    /// Number of bits already placed in `buf` (always below 8 between calls).
    bits_used: u8,
}

impl<W: Write> BitWriter<W> {
    /// Wraps `inner` in a writer that starts byte aligned with nothing pending.
    pub const fn new(inner: W) -> Self {
        Self {
            inner,
            buf: 0,
            bits_used: 0,
        }
    }

    /// Consumes the adapter and returns the wrapped writer.
    ///
    /// Bits of a partially filled byte that were not emitted with `flush_bits` are dropped.
    pub fn into_inner(self) -> W {
        self.inner
    }
}

impl<W: Write> Write for BitWriter<W> {
    /// Writes `data` at the current bit position.
    ///
    /// When byte aligned the call is forwarded to the inner writer unchanged; otherwise every
    /// byte is pushed through `write_bits` and the full length is reported.
    #[inline]
    fn write(&mut self, data: &[u8]) -> io::Result<usize> {
        if self.bits_used == 0 {
            return self.inner.write(data);
        }
        for &byte in data {
            self.write_bits(u64::from(byte), 8)?;
        }
        Ok(data.len())
    }

    /// Flushes the inner writer only; a partially filled byte stays pending until
    /// `flush_bits` is called.
    #[inline]
    fn flush(&mut self) -> io::Result<()> {
        self.inner.flush()
    }
}

impl<W: Write> BitWrite for BitWriter<W> {
    /// Returns `true` when no partially filled byte is pending.
    #[inline]
    fn is_byte_aligned(&self) -> bool {
        self.bits_used == 0
    }

    /// Emits the pending byte with its unused low bits left as zero; does nothing when the
    /// writer is already byte aligned.
    #[inline]
    fn flush_bits(&mut self) -> io::Result<()> {
        if self.bits_used > 0 {
            self.inner.write_all(&[self.buf])?;
            self.buf = 0;
            self.bits_used = 0;
        }
        Ok(())
    }

    /// Writes the low `n` bits of `value`, most significant of those bits first. Bits of
    /// `value` above bit `n - 1` are ignored.
    ///
    /// # Errors
    ///
    /// Propagates the error of the inner writer when a completed byte cannot be written.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `n > 64`.
    #[inline]
    fn write_bits(&mut self, value: u64, mut n: u32) -> io::Result<()> {
        debug_assert!(n <= 64);
        while n > 0 {
            let space = 8 - u32::from(self.bits_used);
            let take = n.min(space);
            let shift = n - take;
            // `take` is in 1..=8 here, so this selects exactly the low `take` bits. Without a
            // real mask, stray high bits of `value` would be OR-ed into bits of `buf` that
            // were already written.
            let mask = u8::MAX >> (8 - take);
            let chunk = ((value >> shift) as u8) & mask;
            self.buf |= chunk << (space - take);
            self.bits_used += take as u8;
            n -= take;
            if self.bits_used == 8 {
                self.inner.write_all(&[self.buf])?;
                self.buf = 0;
                self.bits_used = 0;
            }
        }
        Ok(())
    }
}
#[allow(clippy::bool_assert_comparison)]
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::{Cursor, ErrorKind, Read};

    /// Builds a `BitReader` over an in-memory byte slice.
    fn create_reader(data: &[u8]) -> BitReader<Cursor<&[u8]>> {
        let cursor = Cursor::new(data);
        BitReader::new(cursor)
    }

    /// Single bits are delivered starting with the most significant bit of the byte.
    #[test]
    fn read_single_bits() {
        let mut reader = create_reader(&[0b1011_0010]);
        let expected = [true, false, true, true, false, false, true, false];
        for want in expected {
            assert_eq!(reader.read_bit().unwrap(), want);
        }
    }

    /// Several short reads that together consume exactly one byte.
    #[test]
    fn read_bits_within_byte() {
        let mut reader = create_reader(&[0b1101_0101]);
        for (width, want) in [(4u32, 0b1101u64), (3, 0b010), (1, 0b1)] {
            assert_eq!(reader.read_bits(width).unwrap(), want);
        }
    }

    /// A read that starts in one byte and finishes in the next.
    #[test]
    fn read_bits_across_byte_boundary() {
        let mut reader = create_reader(&[0xF0, 0xAC]);
        for (width, want) in [(4u32, 0b1111u64), (8, 0b0000_1010), (4, 0b1100)] {
            assert_eq!(reader.read_bits(width).unwrap(), want);
        }
    }

    /// The maximum width of 64 bits is supported and leaves the reader usable.
    #[test]
    fn read_bits_64_bits() {
        let mut data = [0xFFu8; 9];
        data[8] = 0x00;
        let mut reader = create_reader(&data);
        let wide = reader.read_bits(64).unwrap();
        assert_eq!(wide, u64::MAX);
        let tail = reader.read_bits(8).unwrap();
        assert_eq!(tail, 0x00);
    }

    /// Discarding the remainder skips to the next byte.
    #[test]
    fn discard_leftover() {
        let mut reader = create_reader(&[0b1111_1111, 0b0101_0101]);
        let head = reader.read_bits(3).unwrap();
        assert_eq!(head, 0b111);
        reader.discard_leftover();
        let next_byte = reader.read_bits(8).unwrap();
        assert_eq!(next_byte, 0b0101_0101);
    }

    /// An aligned `Read::read` goes straight to the inner reader.
    #[test]
    fn read_trait_passthrough() {
        let data = [0x01, 0x02, 0x03, 0x04];
        let mut reader = create_reader(&data);
        let mut buf = [0u8; 3];
        let bytes_read = reader.read(&mut buf).unwrap();
        assert_eq!(bytes_read, 3);
        assert_eq!(buf, [0x01, 0x02, 0x03]);
        let cursor = reader.into_inner();
        assert_eq!(cursor.position(), 3);
    }

    /// A misaligned `Read::read` stitches bytes together across byte boundaries.
    #[test]
    fn read_trait_misaligned() {
        let data = [0b1111_0000, 0b1010_1111, 0b0000_1111];
        let mut reader = create_reader(&data);
        assert_eq!(reader.read_bits(4).unwrap(), 0b1111);
        let mut buf = [0u8; 2];
        let bytes_read = reader.read(&mut buf).unwrap();
        assert_eq!(bytes_read, 2);
        assert_eq!(buf[0], 0b0000_1010);
        assert_eq!(buf[1], 0b1111_0000);
    }

    /// Reading past the end reports `UnexpectedEof` and leaves nothing buffered.
    #[test]
    fn read_bit_eof() {
        let mut reader = create_reader(&[0xAA]);
        assert_eq!(reader.read_bits(8).unwrap(), 0xAA);
        let err = reader.read_bit().unwrap_err();
        assert_eq!(err.kind(), ErrorKind::UnexpectedEof);
        assert_eq!(reader.take_leftover(), None);
    }

    /// A misaligned read that runs dry returns a short count and keeps the trailing bits.
    #[test]
    fn read_trait_partial_eof() {
        let data = [0b1111_0000, 0b1010_1111];
        let mut reader = create_reader(&data);
        reader.read_bits(4).unwrap();
        let mut buf = [0u8; 2];
        let bytes_read = reader.read(&mut buf).unwrap();
        assert_eq!(bytes_read, 1);
        assert_eq!(buf[0], 0b0000_1010);
        let leftover = reader.take_leftover().unwrap();
        assert_eq!(leftover, (0b1111, 4));
    }

    /// Seeking drops buffered bits and continues from the requested byte.
    #[test]
    fn seek() {
        let mut reader = create_reader(&[0b1011_0010, 0b0101_0101]);
        for want in [true, false, true, true] {
            assert_eq!(reader.read_bit().unwrap(), want);
        }
        reader.seek(io::SeekFrom::Start(1)).unwrap();
        for want in [false, true, false, true, false, true] {
            assert_eq!(reader.read_bit().unwrap(), want);
        }
    }
}
#[cfg(test)]
mod test_bit_writer {
    use super::*;
    use crate::bit_io::{BitRead, BitReader};
    use std::io::Cursor;

    /// Builds a writer that collects its output in a `Vec<u8>`.
    fn make_writer() -> BitWriter<Vec<u8>> {
        let sink = Vec::new();
        BitWriter::new(sink)
    }

    /// Eight single bits form one byte, first bit in the most significant position.
    #[test]
    fn write_single_bits() {
        let mut w = make_writer();
        let pattern = [true, false, true, true, false, false, true, false];
        for bit in pattern {
            w.write_bit(bit).unwrap();
        }
        assert_eq!(w.into_inner(), [0xB2]);
    }

    /// Several short writes that together fill exactly one byte.
    #[test]
    fn write_bits_within_byte() {
        let mut w = make_writer();
        for (value, width) in [(0b1101u64, 4u32), (0b010, 3), (0b1, 1)] {
            w.write_bits(value, width).unwrap();
        }
        assert_eq!(w.into_inner(), [0xD5]);
    }

    /// A write that starts in one byte and finishes in the next.
    #[test]
    fn write_bits_across_byte_boundary() {
        let mut w = make_writer();
        for (value, width) in [(0b1111u64, 4u32), (0b0000_1010, 8), (0b1100, 4)] {
            w.write_bits(value, width).unwrap();
        }
        assert_eq!(w.into_inner(), [0xF0, 0xAC]);
    }

    /// The maximum width of 64 bits is supported.
    #[test]
    fn write_u64_max() {
        let mut w = make_writer();
        w.write_bits(u64::MAX, 64).unwrap();
        let bytes = w.into_inner();
        assert_eq!(bytes, [0xFF; 8]);
    }

    /// Flushing a partial byte fills the unused low bits with zeros.
    #[test]
    fn flush_bits_zero_pads_remaining_bits() {
        let mut w = make_writer();
        w.write_bits(0b101, 3).unwrap();
        w.flush_bits().unwrap();
        let bytes = w.into_inner();
        assert_eq!(bytes, [0xA0]);
    }

    /// Flushing an aligned writer emits nothing extra.
    #[test]
    fn flush_bits_is_noop_when_aligned() {
        let mut w = make_writer();
        w.write_bits(0xFF, 8).unwrap();
        w.flush_bits().unwrap();
        let bytes = w.into_inner();
        assert_eq!(bytes, [0xFF]);
    }

    /// A negative value is written as its two's-complement bit pattern.
    #[test]
    fn write_signed_negative_value() {
        let mut w = make_writer();
        w.write_signed(-1, 8).unwrap();
        let bytes = w.into_inner();
        assert_eq!(bytes, [0xFF]);
    }

    /// A positive value is written unchanged.
    #[test]
    fn write_signed_positive_value() {
        let mut w = make_writer();
        w.write_signed(127, 8).unwrap();
        let bytes = w.into_inner();
        assert_eq!(bytes, [0x7F]);
    }

    /// Signed writes narrower than a byte keep only the requested number of bits.
    #[test]
    fn write_signed_5bit_negative() {
        let mut w = make_writer();
        w.write_signed(-1, 5).unwrap();
        w.flush_bits().unwrap();
        let bytes = w.into_inner();
        assert_eq!(bytes, [0xF8]);
    }

    /// Fields of mixed widths survive a write/read round trip.
    #[test]
    fn round_trip_arbitrary_bit_pattern() {
        let fields = [(0b1011u64, 4u32), (0b00101, 5), (0b111, 3), (0b0100, 4)];

        let mut w = make_writer();
        for (value, width) in fields {
            w.write_bits(value, width).unwrap();
        }
        let bytes = w.into_inner();

        let mut r = BitReader::new(Cursor::new(bytes));
        for (value, width) in fields {
            assert_eq!(r.read_bits(width).unwrap(), value);
        }
    }

    /// Signed values, including both range limits, survive a write/read round trip.
    #[test]
    fn round_trip_signed_values() {
        let values: &[(i64, u8)] = &[
            (0, 5),
            (1, 5),
            (-1, 5),
            (15, 5),
            (-16, 5),
            (127, 8),
            (-128, 8),
        ];
        for &(value, bits) in values {
            let bytes = {
                let mut w = make_writer();
                w.write_signed(value, bits).unwrap();
                w.flush_bits().unwrap();
                w.into_inner()
            };
            let mut r = BitReader::new(Cursor::new(bytes));
            let decoded = r.read_signed(bits).unwrap();
            assert_eq!(
                decoded, value,
                "round-trip failed for value={value} bits={bits}"
            );
        }
    }
}