use alloc::vec::Vec;
/// Packs bits, most significant bit first, into bytes appended to a borrowed `Vec<u8>`.
///
/// Completed bytes are pushed as soon as eight bits have accumulated. A trailing partial
/// byte is only emitted by [`BitWriter::finish`]; if the writer is dropped without calling
/// it, those pending bits are discarded.
pub struct BitWriter<'a> {
    // Destination; bytes are appended after whatever the vector already contains.
    buf: &'a mut Vec<u8>,
    // Bits of the byte under construction, right-aligned (newest bit in the LSB).
    current: u8,
    // How many bits of `current` are filled (always 0..=7 between calls).
    bits: u8,
}

impl<'a> BitWriter<'a> {
    /// Creates a writer that appends to `buf`, leaving its existing bytes untouched.
    pub fn new(buf: &'a mut Vec<u8>) -> Self {
        BitWriter {
            buf,
            current: 0,
            bits: 0,
        }
    }

    /// Appends one bit to the stream.
    #[inline]
    pub fn write_bit(&mut self, bit: bool) {
        let shifted = self.current << 1;
        self.current = if bit { shifted | 1 } else { shifted };
        self.bits += 1;
        if self.bits < 8 {
            return;
        }
        // Eight bits collected: emit the byte and start a fresh one.
        self.buf.push(self.current);
        self.current = 0;
        self.bits = 0;
    }

    /// Appends the low `count` bits of `value`, highest of those bits first.
    ///
    /// `count` must not exceed 32; this is checked only in debug builds.
    #[inline]
    pub fn write_bits(&mut self, value: u32, count: u8) {
        debug_assert!(count <= 32, "count={count} exceeds u32 width");
        (0..count)
            .rev()
            .map(|shift| (value >> shift) & 1 == 1)
            .for_each(|bit| self.write_bit(bit));
    }

    /// Flushes any pending partial byte, padding its unused low bits with zeros.
    ///
    /// Does nothing when the stream already ends on a byte boundary.
    pub fn finish(self) {
        if self.bits == 0 {
            return;
        }
        // Left-align the pending bits so the padding lands in the least significant positions.
        let padded = self.current << (8 - self.bits);
        self.buf.push(padded);
    }

    /// Number of bits in the buffer so far, counting whole bytes already in `buf` (including
    /// any that were present before this writer was created) plus the pending partial byte.
    #[cfg(test)]
    pub(crate) fn bit_count(&self) -> usize {
        let whole_bytes = self.buf.len();
        whole_bytes * 8 + usize::from(self.bits)
    }
}
/// Reads bits, most significant bit first, from a byte slice (the layout produced by
/// `BitWriter`).
///
/// The cursor is the pair `(byte_pos, bit_pos)`: `byte_pos` indexes the byte being consumed
/// and `bit_pos` is the next bit to read inside it, counting from 7 (MSB) down to 0 (LSB).
pub struct BitReader<'a> {
    // Source bytes.
    buf: &'a [u8],
    // Byte currently being consumed; reaches `buf.len()` once the input is exhausted.
    byte_pos: usize,
    // Next bit to read within `buf[byte_pos]` (7 = MSB); reset to 7 whenever `byte_pos` advances.
    bit_pos: u8,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the most significant bit of `buf[0]`.
    pub fn new(buf: &'a [u8]) -> Self {
        Self {
            buf,
            byte_pos: 0,
            bit_pos: 7,
        }
    }

    /// Moves the cursor to the bit immediately after position `bit` of the current byte,
    /// rolling over to the MSB of the next byte when `bit` is the LSB.
    #[inline]
    fn advance_past(&mut self, bit: u8) {
        if bit == 0 {
            self.byte_pos += 1;
            self.bit_pos = 7;
        } else {
            self.bit_pos = bit - 1;
        }
    }

    /// Reads the next bit.
    ///
    /// Returns `None` without moving the cursor once the input is exhausted.
    #[inline]
    pub fn read_bit(&mut self) -> Option<bool> {
        let byte = *self.buf.get(self.byte_pos)?;
        let pos = self.bit_pos;
        self.advance_past(pos);
        Some((byte >> pos) & 1 != 0)
    }

    /// Reads `count` bits as an unsigned integer; the first bit read is the most significant.
    ///
    /// `count` must be at most 32 (checked in debug builds, matching `BitWriter::write_bits`);
    /// larger values would silently lose high-order bits. Returns `None` if the input runs
    /// out part-way, in which case the bits read before that point remain consumed.
    #[inline]
    pub fn read_bits(&mut self, count: u8) -> Option<u32> {
        debug_assert!(count <= 32, "count={count} exceeds u32 width");
        let mut value = 0u32;
        for _ in 0..count {
            value = (value << 1) | u32::from(self.read_bit()?);
        }
        Some(value)
    }

    /// Reads a unary-coded value: consumes `0` bits up to and including the next `1` bit
    /// and returns how many zeros preceded it.
    ///
    /// Scans a byte at a time using `leading_zeros` instead of bit by bit. Returns `None`
    /// when no `1` bit remains; the remaining input is consumed in that case. The count
    /// saturates at `u32::MAX`.
    #[inline]
    pub fn read_unary(&mut self) -> Option<u32> {
        let mut zeros: u32 = 0;

        // Partially consumed byte: only its low `bit_pos + 1` bits are still unread.
        if self.bit_pos < 7 {
            // bit_pos < 7 only after a read that stayed within this byte, so the index is valid.
            let byte = self.buf[self.byte_pos];
            let unread = self.bit_pos + 1;
            let live = byte & ((1u8 << unread) - 1);
            if live != 0 {
                // Terminator lies in this byte; zeros are the bits strictly above it.
                let terminator_bit = 7u8 - live.leading_zeros() as u8;
                let run = u32::from(self.bit_pos - terminator_bit);
                self.advance_past(terminator_bit);
                return Some(run);
            }
            zeros = u32::from(unread);
            self.byte_pos += 1;
            self.bit_pos = 7;
        }

        // Byte-aligned from here on (bit_pos == 7): each all-zero byte contributes 8.
        while let Some(&byte) = self.buf.get(self.byte_pos) {
            if byte == 0 {
                zeros = zeros.saturating_add(8);
                self.byte_pos += 1;
                continue;
            }
            let lz = byte.leading_zeros();
            zeros = zeros.saturating_add(lz);
            self.advance_past(7u8 - lz as u8);
            return Some(zeros);
        }
        None
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;

    /// Runs `emit` against a writer appending to `seed`, flushes it, and returns the bytes.
    fn encode_onto(mut seed: Vec<u8>, emit: impl FnOnce(&mut BitWriter<'_>)) -> Vec<u8> {
        let mut writer = BitWriter::new(&mut seed);
        emit(&mut writer);
        writer.finish();
        seed
    }

    /// Same as `encode_onto`, starting from an empty buffer.
    fn encode(emit: impl FnOnce(&mut BitWriter<'_>)) -> Vec<u8> {
        encode_onto(Vec::new(), emit)
    }

    #[test]
    fn roundtrip_single_bits() {
        let pattern = [true, false, true, true, false, false, true, false, true];
        let bytes = encode(|w| pattern.iter().for_each(|&bit| w.write_bit(bit)));
        let mut reader = BitReader::new(&bytes);
        for &want in pattern.iter() {
            assert_eq!(reader.read_bit(), Some(want));
        }
    }

    #[test]
    fn roundtrip_multi_bit_values() {
        let fields: &[(u32, u8)] = &[(0b10110, 5), (0, 3), (0xAB, 8), (1, 1), (0x1F, 5)];
        let bytes = encode(|w| {
            for &(value, width) in fields {
                w.write_bits(value, width);
            }
        });
        let mut reader = BitReader::new(&bytes);
        for &(want, width) in fields {
            assert_eq!(reader.read_bits(width), Some(want));
        }
    }

    #[test]
    fn byte_padding_is_zero() {
        let bytes = encode(|w| w.write_bits(0b101, 3));
        assert_eq!(bytes.len(), 1);
        // The three written bits, then five zero padding bits.
        assert_eq!(bytes[0], 0b1010_0000);
    }

    #[test]
    fn writer_appends_to_existing_bytes() {
        let bytes = encode_onto(vec![0xAA, 0xBB], |w| w.write_bits(0b1111_0000, 8));
        assert_eq!(bytes, vec![0xAA, 0xBB, 0xF0]);
    }

    #[test]
    fn read_past_end_returns_none() {
        let single = [0xFFu8];
        let mut reader = BitReader::new(&single);
        (0..8).for_each(|_| assert_eq!(reader.read_bit(), Some(true)));
        assert_eq!(reader.read_bit(), None);
    }

    #[test]
    fn read_unary_matches_bit_loop() {
        for prefix_bits in 0u8..5 {
            for q in [0u32, 1, 2, 5, 7, 8, 15, 16, 17, 23] {
                let bytes = encode(|w| {
                    if prefix_bits > 0 {
                        w.write_bits(0, prefix_bits);
                    }
                    (0..q).for_each(|_| w.write_bit(false));
                    w.write_bit(true);
                });
                let mut reader = BitReader::new(&bytes);
                if prefix_bits > 0 {
                    let _ = reader.read_bits(prefix_bits);
                }
                assert_eq!(reader.read_unary(), Some(q), "q={q} prefix_bits={prefix_bits}");
            }
        }
    }

    #[test]
    fn read_unary_truncated_returns_none() {
        let zeros = [0u8; 4];
        assert_eq!(BitReader::new(&zeros).read_unary(), None);
    }

    #[test]
    fn read_unary_terminator_at_last_bit_of_byte() {
        let bytes = [0x01u8, 0x80];
        let mut reader = BitReader::new(&bytes);
        assert_eq!(reader.read_unary(), Some(7));
        assert_eq!(reader.read_unary(), Some(0));
    }

    #[test]
    fn bit_count_tracks_partial_byte() {
        let mut buf = Vec::new();
        let mut w = BitWriter::new(&mut buf);
        // (value, width, expected running bit count after the write)
        let steps: [(u32, u8, usize); 3] = [(0b1011, 4, 4), (0b11, 2, 6), (0b10, 2, 8)];
        assert_eq!(w.bit_count(), 0);
        for (value, width, expected) in steps {
            w.write_bits(value, width);
            assert_eq!(w.bit_count(), expected);
        }
    }
}