/// MSB-first bit writer that packs variable-width fields into a byte vector.
pub(crate) struct BitWriter {
    /// Bytes that have already been completed and pushed.
    out: Vec<u8>,
    /// Byte under construction; fields are placed from its most significant
    /// end downward, and its free low bits are kept at zero.
    acc: u8,
    /// Count of still-free bit slots at the low end of `acc`
    /// (8 means `acc` holds nothing yet).
    bits: u32,
}
impl BitWriter {
    /// Creates a writer with no output and an empty accumulator byte.
    pub fn new() -> Self {
        Self {
            out: Vec::new(),
            acc: 0,
            bits: 8,
        }
    }

    /// Appends the low `n` bits of `data`, most significant bit first.
    ///
    /// Bits of `data` at positions `n` and above are ignored.
    ///
    /// # Panics
    ///
    /// In debug builds, panics unless `1 <= n <= 32`.
    pub fn emit(&mut self, data: u32, n: u32) {
        debug_assert!(n > 0 && n <= 32);
        let val = if n < 32 { data & ((1u32 << n) - 1) } else { data };
        if n <= self.bits {
            // The whole field fits in the free low bits of the current byte.
            self.bits -= n;
            self.acc |= (val as u8) << self.bits;
        } else {
            // `self.bits >= 1` between calls (see `flush_if_full`), so this
            // shift amount is at most 31 and cannot overflow for n == 32.
            let mut remaining = n - self.bits;
            // The top `self.bits` bits of the field complete the current byte.
            self.acc |= (val >> remaining) as u8;
            self.out.push(self.acc);
            // Emit every further whole byte of the field.
            while remaining > 8 {
                remaining -= 8;
                self.out.push((val >> remaining) as u8);
            }
            // The final 1..=8 bits start a fresh accumulator byte.
            self.bits = 8 - remaining;
            self.acc = (val as u8) << self.bits;
        }
        // Never leave a full, unflushed byte behind: the shift above relies on
        // `bits >= 1` on entry.
        self.flush_if_full();
    }

    /// Appends `fs` zero bits followed by a single terminating one bit
    /// (a fundamental sequence / unary code for `fs`).
    pub fn emit_fs(&mut self, fs: u32) {
        let mut zeros = fs;
        if zeros >= self.bits {
            // The zeros overrun the current byte: flush it, then emit any
            // whole zero bytes in one step instead of one at a time.
            zeros -= self.bits;
            self.out.push(self.acc);
            self.acc = 0;
            self.bits = 8;
            let zero_bytes = (zeros / 8) as usize;
            self.out.resize(self.out.len() + zero_bytes, 0);
            zeros %= 8;
        }
        // Now `zeros < self.bits`: skip the zeros and set the terminating one.
        self.bits -= zeros + 1;
        self.acc |= 1u8 << self.bits;
        self.flush_if_full();
    }

    /// Number of bits written so far, including those still in the accumulator.
    pub fn bit_position(&self) -> u64 {
        self.out.len() as u64 * 8 + (8 - self.bits) as u64
    }

    /// Consumes the writer and returns the output, zero-padding the last
    /// partial byte if there is one.
    pub fn finish(mut self) -> Vec<u8> {
        self.pad_to_byte();
        self.out
    }

    /// Zero-pads the current partial byte (if any) so the next field starts
    /// on a byte boundary.
    pub fn pad_to_byte(&mut self) {
        if self.bits < 8 {
            self.out.push(self.acc);
            self.acc = 0;
            self.bits = 8;
        }
    }

    /// Pushes the accumulator once all 8 of its bits are used, restoring the
    /// invariant `1 <= bits <= 8`.
    fn flush_if_full(&mut self) {
        if self.bits == 0 {
            self.out.push(self.acc);
            self.acc = 0;
            self.bits = 8;
        }
    }
}
/// MSB-first bit reader over a borrowed byte slice.
pub(crate) struct BitReader<'a> {
    /// Input bytes being decoded.
    data: &'a [u8],
    /// Index of the next byte of `data` that has not been loaded into `acc`.
    pos: usize,
    /// Bit buffer: its low `bitp` bits are the unread bits in stream order;
    /// bits above that may hold stale, already-consumed data.
    acc: u64,
    /// Number of unread bits currently buffered in `acc`.
    bitp: u32,
}
impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the very first bit of `data`.
    pub fn new(data: &'a [u8]) -> Self {
        Self::from_bit_offset(data, 0)
    }

    /// Creates a reader that starts `bit_offset` bits into `data`.
    ///
    /// For a non-byte-aligned offset, the byte containing it is buffered
    /// immediately with its already-skipped high bits masked off. An offset
    /// at or past the end of `data` yields a reader whose reads return `None`.
    pub fn from_bit_offset(data: &'a [u8], bit_offset: u64) -> Self {
        let start = (bit_offset / 8) as usize;
        let skip = (bit_offset % 8) as u32;
        match data.get(start) {
            Some(&first) if skip != 0 => {
                let live = 8 - skip;
                Self {
                    data,
                    pos: start + 1,
                    acc: u64::from(first) & ((1u64 << live) - 1),
                    bitp: live,
                }
            }
            _ => Self {
                data,
                pos: start,
                acc: 0,
                bitp: 0,
            },
        }
    }

    /// Loads whole bytes into `acc` until at least `n` unread bits are buffered.
    ///
    /// Returns `false` if the input runs out first; bytes loaded before that
    /// point remain buffered.
    fn fill(&mut self, n: u32) -> bool {
        while self.bitp < n {
            match self.data.get(self.pos) {
                Some(&byte) => {
                    self.acc = (self.acc << 8) | u64::from(byte);
                    self.pos += 1;
                    self.bitp += 8;
                }
                None => return false,
            }
        }
        true
    }

    /// Reads the next `n` bits, most significant first, as an unsigned integer.
    ///
    /// Returns `None` if fewer than `n` bits remain.
    ///
    /// # Panics
    ///
    /// In debug builds, panics unless `1 <= n <= 32`.
    pub fn read(&mut self, n: u32) -> Option<u32> {
        debug_assert!(n > 0 && n <= 32);
        if !self.fill(n) {
            return None;
        }
        self.bitp -= n;
        let mask = (1u64 << n) - 1;
        Some(((self.acc >> self.bitp) & mask) as u32)
    }

    /// Decodes a fundamental sequence: counts zero bits up to the next one
    /// bit, consumes that one bit too, and returns the zero count.
    ///
    /// Returns `None` if the input ends before a one bit is found.
    pub fn read_fs(&mut self) -> Option<u32> {
        // Clear stale consumed bits above the unread window so only unread
        // bits can be mistaken for the terminator. A zero-width window
        // (`bitp == 0`) gives a zero mask and clears everything.
        self.acc &= (1u64 << self.bitp) - 1;
        let mut zeros: u32 = 0;
        while self.acc == 0 {
            if self.pos >= self.data.len() {
                return None;
            }
            // Every buffered unread bit was zero.
            zeros += self.bitp;
            // Refill with up to 7 bytes (56 bits) so the window fits in a u64.
            let data = self.data;
            let end = data.len().min(self.pos + 7);
            let chunk = &data[self.pos..end];
            self.acc = chunk
                .iter()
                .fold(0u64, |word, &byte| (word << 8) | u64::from(byte));
            self.pos = end;
            self.bitp = 8 * chunk.len() as u32;
        }
        // Bit index of the terminating one inside `acc`.
        let top = 63 - self.acc.leading_zeros();
        zeros += self.bitp - 1 - top;
        // The one bit and everything above it are now consumed.
        self.bitp = top;
        Some(zeros)
    }

    /// Number of bits consumed so far, counted from the start of `data`
    /// (bits skipped via `from_bit_offset` included).
    ///
    /// NOTE(review): when `from_bit_offset` starts past the end of `data`,
    /// the sub-byte part of the offset is not reflected here.
    // May have no callers; silence the dead-code lint.
    #[allow(dead_code)]
    pub fn bits_consumed(&self) -> u64 {
        self.pos as u64 * 8 - u64::from(self.bitp)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixed-width fields, one of them straddling a byte boundary, round-trip.
    #[test]
    fn write_then_read_bits() {
        let mut writer = BitWriter::new();
        writer.emit(0b101, 3);
        writer.emit(0b1111_0000, 8);
        writer.emit(0b1, 1);
        let bytes = writer.finish();

        let mut reader = BitReader::new(&bytes);
        assert_eq!(reader.read(3), Some(0b101));
        assert_eq!(reader.read(8), Some(0b1111_0000));
        assert_eq!(reader.read(1), Some(1));
    }

    /// Consecutive fundamental sequences, including zero-length ones, round-trip.
    #[test]
    fn write_then_read_fs() {
        let values = [0u32, 5, 0, 12];
        let mut writer = BitWriter::new();
        for &v in values.iter() {
            writer.emit_fs(v);
        }
        let bytes = writer.finish();

        let mut reader = BitReader::new(&bytes);
        for &v in values.iter() {
            assert_eq!(reader.read_fs(), Some(v));
        }
    }

    /// A full 32-bit field is laid out most significant byte first.
    #[test]
    fn emit_32_bits() {
        let mut writer = BitWriter::new();
        writer.emit(0xDEADBEEF, 32);
        assert_eq!(writer.finish(), vec![0xDE, 0xAD, 0xBE, 0xEF]);
    }

    /// `bit_position` counts flushed bytes plus bits still in the accumulator.
    #[test]
    fn bit_position_tracking() {
        let mut writer = BitWriter::new();
        assert_eq!(writer.bit_position(), 0);
        for &(width, expected) in [(3u32, 3u64), (5, 8), (1, 9)].iter() {
            writer.emit(0, width);
            assert_eq!(writer.bit_position(), expected);
        }
    }

    /// A reader can start mid-byte; the skipped high bits are ignored.
    #[test]
    fn reader_from_bit_offset() {
        let data = [0b1010_1100, 0b0011_1111];
        let mut reader = BitReader::from_bit_offset(&data, 4);
        assert_eq!(reader.read(4), Some(0b1100));
        assert_eq!(reader.read(4), Some(0b0011));
    }

    /// Every width from 1 to 32 round-trips an all-ones value.
    #[test]
    fn round_trip_various_widths() {
        for n in 1..=32u32 {
            let all_ones = u32::MAX >> (32 - n);
            let mut writer = BitWriter::new();
            writer.emit(all_ones, n);
            let bytes = writer.finish();

            let mut reader = BitReader::new(&bytes);
            assert_eq!(reader.read(n), Some(all_ones), "failed for n={n}");
        }
    }

    /// A long run of zeros spanning many bytes decodes correctly.
    #[test]
    fn fs_large_value() {
        let mut writer = BitWriter::new();
        writer.emit_fs(100);
        let bytes = writer.finish();

        let mut reader = BitReader::new(&bytes);
        assert_eq!(reader.read_fs(), Some(100));
    }

    /// Fixed-width reads and fundamental-sequence reads can be freely mixed.
    #[test]
    fn interleaved_fs_and_bits() {
        let mut writer = BitWriter::new();
        writer.emit(0b110, 3);
        writer.emit_fs(3);
        writer.emit(0xFF, 8);
        writer.emit_fs(0);
        let bytes = writer.finish();

        let mut reader = BitReader::new(&bytes);
        assert_eq!(reader.read(3), Some(0b110));
        assert_eq!(reader.read_fs(), Some(3));
        assert_eq!(reader.read(8), Some(0xFF));
        assert_eq!(reader.read_fs(), Some(0));
    }
}