use std::io;
/// Packs values into a bit stream, most significant bit first, and emits the
/// stream as 16-bit words, each stored little-endian in the output buffer.
///
/// NOTE(review): this layout (MSB-first bits, LE 16-bit words, 1–16 bits of
/// padding in `align`) looks like an LZX-style bit stream — confirm against
/// the format this writer targets.
#[derive(Debug, Clone)]
pub struct BitWriter {
    /// Bytes of every completed 16-bit word, plus any raw bytes.
    buffer: Vec<u8>,
    /// Pending bits, left-justified: the oldest pending bit is bit 63 and every
    /// bit below the pending region is zero.
    cur: u64,
    /// Number of pending bits in `cur`; fewer than 16 whenever a method returns.
    bits_used: u8,
}

impl BitWriter {
    /// Creates an empty writer with no pending bits.
    pub fn new() -> Self {
        Self { buffer: Vec::new(), cur: 0, bits_used: 0 }
    }

    /// Creates an empty writer whose output buffer preallocates `cap` bytes.
    pub fn with_capacity(cap: usize) -> Self {
        Self { buffer: Vec::with_capacity(cap), cur: 0, bits_used: 0 }
    }

    /// Appends the low `n_bits` bits of `value`, most significant bit first.
    ///
    /// Bits of `value` at or above `n_bits` are ignored; `n_bits == 0` is a
    /// no-op. `n_bits` must be at most 32 (checked in debug builds only).
    #[inline]
    pub fn write_bits(&mut self, value: u32, n_bits: u8) {
        debug_assert!(n_bits <= 32);
        if n_bits == 0 {
            return;
        }
        // Every method drains complete words before returning, so fewer than
        // 16 bits are pending here; with n_bits <= 32 at most 47 of the 64
        // accumulator bits are in use, so no pre-drain is needed.
        debug_assert!(self.bits_used < 16);
        let shift = 64 - self.bits_used - n_bits;
        let mask = if n_bits == 32 { u64::from(u32::MAX) } else { (1u64 << n_bits) - 1 };
        self.cur |= (u64::from(value) & mask) << shift;
        self.bits_used += n_bits;
        self.drain_words();
    }

    /// Moves every complete 16-bit word from the accumulator into `buffer`,
    /// storing each word little-endian.
    #[inline]
    fn drain_words(&mut self) {
        while self.bits_used >= 16 {
            let word = (self.cur >> 48) as u16;
            self.buffer.extend_from_slice(&word.to_le_bytes());
            self.cur <<= 16;
            self.bits_used -= 16;
        }
    }

    /// Zero-pads a pending partial word up to the next 16-bit boundary and
    /// emits it; does nothing when no bits are pending.
    fn pad_to_word(&mut self) {
        if self.bits_used != 0 {
            // The pad bits are already zero in `cur` (it is filled from the
            // top), so only the count needs bumping before draining.
            self.bits_used += 16 - self.bits_used % 16;
            self.drain_words();
        }
    }

    /// Appends a single bit.
    #[inline]
    pub fn write_bit(&mut self, bit: bool) {
        self.write_bits(u32::from(bit), 1);
    }

    /// Appends the low 24 bits of `value`: bits 23..8 as a 16-bit field, then
    /// bits 7..0 as an 8-bit field. `value` must be below 2^24 (checked in
    /// debug builds only; higher bits are otherwise dropped).
    #[inline]
    pub fn write_u24_be(&mut self, value: u32) {
        debug_assert!(value < (1 << 24));
        self.write_bits((value >> 8) & 0xFFFF, 16);
        self.write_bits(value & 0xFF, 8);
    }

    /// Appends `value` as two 16-bit fields, low half first. On a word-aligned
    /// stream this yields the four bytes of `value` in little-endian order.
    #[inline]
    pub fn write_u32_le(&mut self, value: u32) {
        self.write_bits(value & 0xFFFF, 16);
        self.write_bits(value >> 16, 16);
    }

    /// Pads the stream with zero bits to the next 16-bit boundary.
    ///
    /// Always emits between 1 and 16 padding bits: an already-aligned stream
    /// receives a full zero word.
    pub fn align(&mut self) {
        if self.bits_used == 0 {
            self.buffer.extend_from_slice(&[0, 0]);
        } else {
            self.pad_to_word();
        }
    }

    /// Appends `bytes` verbatim, bypassing the bit accumulator.
    ///
    /// No padding byte is added after odd-length input; later words are
    /// appended directly after these bytes.
    ///
    /// # Panics
    /// Panics if bits are pending (the stream is not word-aligned), because the
    /// pending bits would otherwise be emitted after `bytes`, out of order.
    pub fn write_raw(&mut self, bytes: &[u8]) {
        assert_eq!(self.bits_used, 0, "write_raw requires word-aligned stream");
        self.buffer.extend_from_slice(bytes);
    }

    /// Zero-pads any pending partial word and returns the encoded bytes.
    /// Unlike [`BitWriter::align`], nothing is added to an aligned stream.
    #[must_use]
    pub fn finish(mut self) -> Vec<u8> {
        self.pad_to_word();
        self.buffer
    }
}
impl Default for BitWriter {
fn default() -> Self {
Self::new()
}
}
/// Lets byte-oriented code (`write_all`, `io::copy`, ...) append raw bytes to
/// the output. Raw bytes bypass the bit accumulator, so the bit stream must be
/// word-aligned (no pending bits) when writing.
impl io::Write for BitWriter {
    /// Appends all of `buf` verbatim and reports its full length.
    ///
    /// # Errors
    /// Returns an [`io::ErrorKind::InvalidInput`] error, writing nothing, when
    /// bits are pending; call `align` first. This reports misuse through the
    /// `io::Write` contract instead of panicking or emitting a stream whose
    /// pending bits would land after these bytes.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        if self.bits_used != 0 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "BitWriter: raw write requires a word-aligned stream",
            ));
        }
        self.write_raw(buf);
        Ok(buf.len())
    }

    /// No-op: output is held in memory, and a pending partial word cannot be
    /// emitted without padding (use `align` or `finish` for that).
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn bits_pack_into_expected_byte_count() {
        let mut w = BitWriter::new();
        w.write_bits(0b1, 1);
        w.write_bits(0b10, 2);
        w.write_bits(0b111, 3);
        w.write_bits(0b1010, 4);
        w.write_bits(0b1, 1);
        w.write_bits(0xABCD, 16);
        w.write_bits(0x1234, 16);
        let bytes = w.finish();
        // 43 bits round up to three 16-bit words (6 bytes). Bits are packed
        // MSB-first and each word is stored little-endian:
        // 0xDEB5, 0x79A2, 0x4680.
        assert_eq!(bytes, vec![0xB5, 0xDE, 0xA2, 0x79, 0x80, 0x46]);
    }

    #[test]
    fn u32_le_roundtrips() {
        let mut w = BitWriter::new();
        w.write_u32_le(0x12345678);
        let bytes = w.finish();
        assert_eq!(bytes, vec![0x78, 0x56, 0x34, 0x12]);
    }

    #[test]
    fn u24_be_roundtrips() {
        let mut w = BitWriter::new();
        w.write_u24_be(0xABCDEF);
        let bytes = w.finish();
        // Words 0xABCD and 0xEF00 (zero-padded), each stored little-endian.
        assert_eq!(bytes, vec![0xCD, 0xAB, 0x00, 0xEF]);
    }

    #[test]
    fn align_on_aligned_stream_emits_a_zero_word() {
        let mut w = BitWriter::new();
        w.align();
        assert_eq!(w.finish(), vec![0x00, 0x00]);
    }

    #[test]
    fn align_pads_partial_word_with_zeros() {
        let mut w = BitWriter::new();
        w.write_bit(true);
        w.align();
        w.write_raw(&[0x11, 0x22]);
        assert_eq!(w.finish(), vec![0x00, 0x80, 0x11, 0x22]);
    }

    #[test]
    fn finish_on_empty_writer_is_empty() {
        assert!(BitWriter::new().finish().is_empty());
    }

    #[test]
    fn write_bits_ignores_high_bits_and_zero_width() {
        let mut w = BitWriter::new();
        w.write_bits(0xFFFF_FFFF, 0);
        w.write_bits(0xFF, 4);
        assert_eq!(w.finish(), vec![0x00, 0xF0]);
    }
}