use crate::{Error, Result};
#[derive(Clone, Copy)]
pub struct BitReader<'a> {
data: &'a [u8],
byte_pos: usize,
acc: u64,
bits_in_acc: u32,
}
impl<'a> BitReader<'a> {
pub fn new(data: &'a [u8]) -> Self {
Self {
data,
byte_pos: 0,
acc: 0,
bits_in_acc: 0,
}
}
pub fn with_position(data: &'a [u8], byte_pos: usize) -> Self {
let byte_pos = byte_pos.min(data.len());
Self {
data,
byte_pos,
acc: 0,
bits_in_acc: 0,
}
}
pub fn bit_position(&self) -> u64 {
self.byte_pos as u64 * 8 - self.bits_in_acc as u64
}
pub fn byte_position(&self) -> usize {
(self.bit_position() / 8) as usize
}
pub fn bits_remaining(&self) -> u64 {
self.bits_in_acc as u64 + ((self.data.len() - self.byte_pos) as u64) * 8
}
pub fn is_byte_aligned(&self) -> bool {
self.bits_in_acc % 8 == 0
}
pub fn align_to_byte(&mut self) {
let drop = self.bits_in_acc % 8;
self.acc <<= drop;
self.bits_in_acc -= drop;
}
fn refill(&mut self) {
while self.bits_in_acc <= 56 && self.byte_pos < self.data.len() {
self.acc |= (self.data[self.byte_pos] as u64) << (56 - self.bits_in_acc);
self.bits_in_acc += 8;
self.byte_pos += 1;
}
}
pub fn read_u32(&mut self, n: u32) -> Result<u32> {
debug_assert!(n <= 32, "BitReader::read_u32 supports up to 32 bits");
if n == 0 {
return Ok(0);
}
if self.bits_in_acc < n {
self.refill();
if self.bits_in_acc < n {
return Err(Error::invalid("bitreader: out of bits"));
}
}
let v = (self.acc >> (64 - n)) as u32;
self.acc <<= n;
self.bits_in_acc -= n;
Ok(v)
}
pub fn read_u64(&mut self, n: u32) -> Result<u64> {
debug_assert!(n <= 64);
if n <= 32 {
return self.read_u32(n).map(|v| v as u64);
}
let hi = self.read_u32(n - 32)? as u64;
let lo = self.read_u32(32)? as u64;
Ok((hi << 32) | lo)
}
pub fn read_i32(&mut self, n: u32) -> Result<i32> {
if n == 0 {
return Ok(0);
}
let raw = self.read_u32(n)? as i32;
let shift = 32 - n;
Ok((raw << shift) >> shift)
}
pub fn read_bit(&mut self) -> Result<bool> {
Ok(self.read_u32(1)? != 0)
}
pub fn read_u1(&mut self) -> Result<u32> {
self.read_u32(1)
}
pub fn peek_u32(&mut self, n: u32) -> Result<u32> {
debug_assert!(n <= 32);
if n == 0 {
return Ok(0);
}
if self.bits_in_acc < n {
self.refill();
if self.bits_in_acc < n {
return Err(Error::invalid("bitreader: out of bits for peek"));
}
}
Ok((self.acc >> (64 - n)) as u32)
}
pub fn skip(&mut self, n: u32) -> Result<()> {
let mut left = n;
while left > 32 {
self.read_u32(32)?;
left -= 32;
}
self.read_u32(left)?;
Ok(())
}
pub fn consume(&mut self, n: u32) -> Result<()> {
self.skip(n)
}
pub fn read_unary(&mut self) -> Result<u32> {
let mut count = 0u32;
loop {
if self.bits_in_acc == 0 {
self.refill();
if self.bits_in_acc == 0 {
return Err(Error::invalid("bitreader: out of bits in unary code"));
}
}
let lz_total = self.acc.leading_zeros();
let lz_avail = lz_total.min(self.bits_in_acc);
count = count
.checked_add(lz_avail)
.ok_or_else(|| Error::invalid("bitreader: unary count overflow"))?;
if lz_avail >= 64 {
self.acc = 0;
} else {
self.acc <<= lz_avail;
}
self.bits_in_acc -= lz_avail;
if lz_avail < lz_total || self.bits_in_acc == 0 {
continue;
}
self.acc <<= 1;
self.bits_in_acc -= 1;
return Ok(count);
}
}
pub fn read_bytes(&mut self, n: usize) -> Result<Vec<u8>> {
if !self.is_byte_aligned() {
return Err(Error::invalid(
"bitreader: read_bytes requires byte alignment",
));
}
self.align_to_byte();
let start = self.byte_pos - (self.bits_in_acc as usize / 8);
if start + n > self.data.len() {
return Err(Error::invalid("bitreader: read_bytes past end"));
}
let out = self.data[start..start + n].to_vec();
self.acc = 0;
self.bits_in_acc = 0;
self.byte_pos = start + n;
Ok(out)
}
}
/// Bit writer that fills each output byte starting from its most significant
/// bit.
///
/// Completed bytes are appended to `data`; bits that do not yet form a whole
/// byte wait left-aligned at the top of `acc`.
pub struct BitWriter {
    /// Fully written output bytes.
    data: Vec<u8>,
    /// Pending bits, left-aligned.
    acc: u64,
    /// Number of pending bits in `acc`.
    bits_in_acc: u32,
}

impl BitWriter {
    /// Creates an empty writer.
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty writer whose buffer has room for `cap` bytes.
    pub fn with_capacity(cap: usize) -> Self {
        let data = Vec::with_capacity(cap);
        Self {
            data,
            acc: 0,
            bits_in_acc: 0,
        }
    }

    /// Returns the total number of bits written so far.
    pub fn bit_position(&self) -> u64 {
        8 * self.data.len() as u64 + u64::from(self.bits_in_acc)
    }

    /// Returns the number of completed bytes (pending bits are not counted).
    pub fn byte_len(&self) -> usize {
        self.bytes().len()
    }

    /// Returns `true` when no partial byte is pending.
    pub fn is_byte_aligned(&self) -> bool {
        (self.bits_in_acc % 8) == 0
    }

    /// Appends the low `n` bits of `value` (`n <= 32`), most significant bit
    /// first. Bits of `value` above the low `n` are ignored.
    pub fn write_u32(&mut self, value: u32, n: u32) {
        debug_assert!(n <= 32, "BitWriter::write_u32 supports up to 32 bits");
        if n == 0 {
            return;
        }
        // `1 << 32` would overflow a u32, so the full-width case keeps the
        // value as is.
        let kept = match n {
            32 => value,
            _ => value & ((1u32 << n) - 1),
        };
        // Place the new bits immediately below the pending ones.
        self.acc |= u64::from(kept) << (64 - self.bits_in_acc - n);
        self.bits_in_acc += n;
        // Emit every completed byte from the top of the accumulator.
        while self.bits_in_acc >= 8 {
            self.data.push((self.acc >> 56) as u8);
            self.acc <<= 8;
            self.bits_in_acc -= 8;
        }
    }

    /// Alias for [`BitWriter::write_u32`].
    pub fn write_bits(&mut self, value: u32, n: u32) {
        self.write_u32(value, n);
    }

    /// Appends the low `n` bits of `value` (`n <= 64`), most significant bit
    /// first.
    pub fn write_u64(&mut self, value: u64, n: u32) {
        debug_assert!(n <= 64);
        if n > 32 {
            // Upper part first so the overall order stays MSB-first.
            self.write_u32((value >> 32) as u32, n - 32);
            self.write_u32(value as u32, 32);
            return;
        }
        self.write_u32(value as u32, n);
    }

    /// Appends the low `n` bits of the two's-complement form of `value`.
    pub fn write_i32(&mut self, value: i32, n: u32) {
        self.write_u32(value as u32, n);
    }

    /// Appends a single bit.
    pub fn write_bit(&mut self, bit: bool) {
        self.write_u32(u32::from(bit), 1);
    }

    /// Appends a unary code: `n` zero bits followed by a single one bit.
    pub fn write_unary(&mut self, n: u32) {
        // Zeros go out in chunks of at most 32 bits, the widest single write.
        for _ in 0..n / 32 {
            self.write_u32(0, 32);
        }
        let tail = n % 32;
        if tail != 0 {
            self.write_u32(0, tail);
        }
        self.write_bit(true);
    }

    /// Appends one byte as eight bits.
    pub fn write_byte(&mut self, b: u8) {
        self.write_u32(u32::from(b), 8);
    }

    /// Appends every byte of `bytes`.
    ///
    /// When the writer is byte aligned the slice is copied directly into the
    /// buffer; otherwise each byte is written through the bit path.
    pub fn write_bytes(&mut self, bytes: &[u8]) {
        if !self.is_byte_aligned() {
            for byte in bytes.iter().copied() {
                self.write_u32(u32::from(byte), 8);
            }
            return;
        }
        self.data.extend_from_slice(bytes);
    }

    /// Pads with zero bits up to the next byte boundary; does nothing when
    /// already aligned.
    pub fn align_to_byte(&mut self) {
        let partial = self.bits_in_acc % 8;
        if partial != 0 {
            self.write_u32(0, 8 - partial);
        }
    }

    /// Alias for [`BitWriter::align_to_byte`].
    pub fn align_to_byte_zero(&mut self) {
        self.align_to_byte();
    }

    /// Returns the completed bytes; pending bits are not included.
    pub fn bytes(&self) -> &[u8] {
        self.data.as_slice()
    }

    /// Alias for [`BitWriter::bytes`].
    pub fn buffer(&self) -> &[u8] {
        self.bytes()
    }

    /// Consumes the writer and returns its bytes, emitting any pending bits
    /// as a final byte whose unused low bits are zero.
    pub fn finish(mut self) -> Vec<u8> {
        if self.bits_in_acc != 0 {
            let last = (self.acc >> 56) as u8;
            self.data.push(last);
        }
        self.data
    }

    /// Alias for [`BitWriter::finish`].
    pub fn into_bytes(self) -> Vec<u8> {
        self.finish()
    }
}

impl Default for BitWriter {
    /// Equivalent to [`BitWriter::new`].
    fn default() -> Self {
        BitWriter::new()
    }
}
/// Bit reader that consumes each byte starting from its least significant
/// bit.
///
/// Unread bits are cached right-aligned in a 64-bit accumulator: the next bit
/// to be returned is bit 0 of `acc`.
#[derive(Clone, Copy)]
pub struct BitReaderLsb<'a> {
    /// Input bytes being decoded.
    data: &'a [u8],
    /// Index of the next byte of `data` not yet loaded into `acc`.
    byte_pos: usize,
    /// Right-aligned cache of loaded-but-unread bits.
    acc: u64,
    /// Number of valid bits at the bottom of `acc` (0..=64).
    bits_in_acc: u32,
}

impl<'a> BitReaderLsb<'a> {
    /// Creates a reader positioned at the first bit of `data`.
    pub fn new(data: &'a [u8]) -> Self {
        Self {
            data,
            byte_pos: 0,
            acc: 0,
            bits_in_acc: 0,
        }
    }

    /// Returns the position of the next unread bit, counted from the start
    /// of `data`.
    pub fn bit_position(&self) -> u64 {
        // Bytes already loaded minus the bits still waiting in the cache.
        self.byte_pos as u64 * 8 - u64::from(self.bits_in_acc)
    }

    /// Returns `true` when the next unread bit is the first bit of a byte.
    pub fn is_byte_aligned(&self) -> bool {
        self.bits_in_acc % 8 == 0
    }

    /// Loads whole bytes into the accumulator while at least 8 bits are free
    /// and input remains.
    fn refill(&mut self) {
        while self.bits_in_acc <= 56 && self.byte_pos < self.data.len() {
            // Place the new byte directly above the bits already cached.
            self.acc |= u64::from(self.data[self.byte_pos]) << self.bits_in_acc;
            self.bits_in_acc += 8;
            self.byte_pos += 1;
        }
    }

    /// Reads the next `n` bits (`n <= 32`) as an unsigned value; the first
    /// bit read becomes the least significant bit of the result.
    ///
    /// # Errors
    ///
    /// Returns `Error::Eof`, without consuming anything, when fewer than `n`
    /// bits remain.
    ///
    /// # Panics
    ///
    /// Panics in debug builds if `n > 32`.
    pub fn read_u32(&mut self, n: u32) -> Result<u32> {
        debug_assert!(n <= 32, "BitReaderLsb::read_u32 supports up to 32 bits");
        if n == 0 {
            return Ok(0);
        }
        if self.bits_in_acc < n {
            self.refill();
            if self.bits_in_acc < n {
                return Err(Error::Eof);
            }
        }
        // `1 << 32` would overflow a u32, so the full-width mask is special-cased.
        let mask = if n == 32 { u32::MAX } else { (1u32 << n) - 1 };
        let v = (self.acc as u32) & mask;
        self.acc >>= n;
        self.bits_in_acc -= n;
        Ok(v)
    }

    /// Reads the next `n` bits (`n <= 64`) as an unsigned value.
    ///
    /// # Errors
    ///
    /// Returns `Error::Eof`, without consuming anything, when fewer than `n`
    /// bits remain.
    ///
    /// # Panics
    ///
    /// Panics in debug builds if `n > 64`.
    pub fn read_u64(&mut self, n: u32) -> Result<u64> {
        debug_assert!(n <= 64);
        if n == 0 {
            return Ok(0);
        }
        if n <= 32 {
            return Ok(u64::from(self.read_u32(n)?));
        }
        // The value is assembled from two reads; verify the full width first
        // so a short stream cannot leave the reader half-advanced.
        let unread =
            u64::from(self.bits_in_acc) + (self.data.len() - self.byte_pos) as u64 * 8;
        if unread < u64::from(n) {
            return Err(Error::Eof);
        }
        // Low half comes first in an LSB-first stream.
        let lo = u64::from(self.read_u32(32)?);
        let hi = u64::from(self.read_u32(n - 32)?);
        Ok(lo | (hi << 32))
    }

    /// Reads the next `n` bits (`n <= 32`) as a two's-complement signed value.
    ///
    /// # Errors
    ///
    /// Returns `Error::Eof` when fewer than `n` bits remain.
    pub fn read_i32(&mut self, n: u32) -> Result<i32> {
        if n == 0 {
            return Ok(0);
        }
        let raw = self.read_u32(n)? as i32;
        // Move the field's top bit into the sign position, then shift back
        // arithmetically to replicate it.
        let shift = 32 - n;
        Ok((raw << shift) >> shift)
    }

    /// Reads one bit as a boolean.
    pub fn read_bit(&mut self) -> Result<bool> {
        Ok(self.read_u32(1)? != 0)
    }
}
/// Bit writer that fills each output byte starting from its least significant
/// bit.
///
/// Completed bytes are appended to `data`; bits that do not yet form a whole
/// byte wait right-aligned at the bottom of `acc`.
pub struct BitWriterLsb {
    /// Fully written output bytes.
    data: Vec<u8>,
    /// Pending bits, right-aligned.
    acc: u64,
    /// Number of pending bits in `acc`.
    bits_in_acc: u32,
}

impl BitWriterLsb {
    /// Creates an empty writer.
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty writer whose buffer has room for `cap` bytes.
    pub fn with_capacity(cap: usize) -> Self {
        let data = Vec::with_capacity(cap);
        Self {
            data,
            acc: 0,
            bits_in_acc: 0,
        }
    }

    /// Returns the total number of bits written so far.
    pub fn bit_position(&self) -> u64 {
        8 * self.data.len() as u64 + u64::from(self.bits_in_acc)
    }

    /// Appends the low `n` bits of `value` (`n <= 32`), least significant bit
    /// first. Bits of `value` above the low `n` are ignored.
    pub fn write_u32(&mut self, value: u32, n: u32) {
        debug_assert!(n <= 32, "BitWriterLsb::write_u32 supports up to 32 bits");
        if n == 0 {
            return;
        }
        // `1 << 32` would overflow a u32, so the full-width case keeps the
        // value as is.
        let kept = match n {
            32 => value,
            _ => value & ((1u32 << n) - 1),
        };
        // Place the new bits immediately above the pending ones.
        self.acc |= u64::from(kept) << self.bits_in_acc;
        self.bits_in_acc += n;
        // Emit every completed byte from the bottom of the accumulator.
        while self.bits_in_acc >= 8 {
            let byte = (self.acc & 0xFF) as u8;
            self.data.push(byte);
            self.acc >>= 8;
            self.bits_in_acc -= 8;
        }
    }

    /// Appends the low `n` bits of `value` (`n <= 64`), least significant bit
    /// first.
    pub fn write_u64(&mut self, value: u64, n: u32) {
        debug_assert!(n <= 64);
        if n > 32 {
            // Lower part first so the overall order stays LSB-first.
            self.write_u32(value as u32, 32);
            self.write_u32((value >> 32) as u32, n - 32);
            return;
        }
        self.write_u32(value as u32, n);
    }

    /// Appends a single bit.
    pub fn write_bit(&mut self, bit: bool) {
        self.write_u32(u32::from(bit), 1);
    }

    /// Pads with zero bits up to the next byte boundary; does nothing when
    /// already aligned.
    pub fn align_to_byte(&mut self) {
        let partial = self.bits_in_acc % 8;
        if partial != 0 {
            self.write_u32(0, 8 - partial);
        }
    }

    /// Consumes the writer and returns its bytes, emitting any pending bits
    /// as a final byte whose unused high bits are zero.
    pub fn finish(mut self) -> Vec<u8> {
        if self.bits_in_acc != 0 {
            let last = (self.acc & 0xFF) as u8;
            self.data.push(last);
        }
        self.data
    }
}

impl Default for BitWriterLsb {
    /// Equivalent to [`BitWriterLsb::new`].
    fn default() -> Self {
        BitWriterLsb::new()
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests for the MSB-first and LSB-first bit readers and writers.

    use super::*;

    /// Eight single-bit writes pack into one byte, first bit in the MSB.
    #[test]
    fn msb_roundtrip_byte() {
        let mut w = BitWriter::new();
        for &b in &[1u32, 0, 1, 0, 0, 1, 0, 1] {
            w.write_u32(b, 1);
        }
        assert_eq!(w.finish(), vec![0xA5]);
    }

    /// Values of assorted widths read back exactly as written.
    #[test]
    fn msb_roundtrip_varied_widths() {
        let mut bw = BitWriter::new();
        let writes: Vec<(u32, u32)> = vec![
            (0b1, 1),
            (0b10101, 5),
            (0b111100001111, 12),
            (0xDEADBEEF, 32),
            (0b001, 3),
            (0xC, 4),
            (0xABCD, 16),
            (0x12345, 20),
            (0, 8),
            (0xFFFFFFFF, 32),
        ];
        for &(v, n) in &writes {
            bw.write_u32(v, n);
        }
        let bytes = bw.finish();
        let mut br = BitReader::new(&bytes);
        for &(v, n) in &writes {
            let got = br.read_u32(n).unwrap();
            let mask = if n == 32 { u32::MAX } else { (1 << n) - 1 };
            assert_eq!(got, v & mask, "mismatch for ({v:#x}, {n})");
        }
    }

    /// A field whose top bit is set sign-extends to a negative value.
    #[test]
    fn msb_signed_extension() {
        let mut br = BitReader::new(&[0xFF]);
        assert_eq!(br.read_i32(4).unwrap(), -1);
        assert_eq!(br.read_i32(4).unwrap(), -1);
    }

    /// Peeking does not consume; skipping does.
    #[test]
    fn msb_peek_skip() {
        let mut br = BitReader::new(&[0xFF, 0x00]);
        assert_eq!(br.peek_u32(12).unwrap(), 0xFF0);
        br.skip(4).unwrap();
        assert_eq!(br.read_u32(8).unwrap(), 0xF0);
    }

    /// Aligning drops the rest of the current byte.
    #[test]
    fn msb_alignment() {
        let mut br = BitReader::new(&[0xFF, 0x55]);
        br.read_u32(3).unwrap();
        assert!(!br.is_byte_aligned());
        br.align_to_byte();
        assert!(br.is_byte_aligned());
        assert_eq!(br.read_u32(8).unwrap(), 0x55);
    }

    /// Asking for more bits than remain is an error and consumes nothing.
    #[test]
    fn msb_out_of_bits_is_error() {
        let mut br = BitReader::new(&[0xFF]);
        assert!(br.read_u32(9).is_err());
        assert_eq!(br.read_u32(8).unwrap(), 0xFF);
    }

    /// Byte-aligned `write_bytes` copies the slice verbatim.
    #[test]
    fn msb_write_bytes_fast_path() {
        let mut w = BitWriter::new();
        w.write_bytes(&[0x11, 0x22, 0x33]);
        assert_eq!(w.finish(), vec![0x11, 0x22, 0x33]);
    }

    /// Unaligned `write_bytes` shifts the bytes into place bit by bit.
    #[test]
    fn msb_write_bytes_unaligned() {
        let mut w = BitWriter::new();
        w.write_u32(0b101, 3);
        w.write_bytes(&[0xFF, 0x00]);
        let out = w.finish();
        // 101 11111111 00000000 plus zero padding -> 0xBF 0xE0 0x00.
        // Checking the content, not just the length, catches misplaced bits.
        assert_eq!(out, vec![0xBF, 0xE0, 0x00]);
        let mut r = BitReader::new(&out);
        assert_eq!(r.read_u32(3).unwrap(), 0b101);
        assert_eq!(r.read_u32(8).unwrap(), 0xFF);
        assert_eq!(r.read_u32(8).unwrap(), 0x00);
    }

    /// `write_bits` behaves exactly like `write_u32`.
    #[test]
    fn msb_write_bits_alias() {
        let mut w = BitWriter::new();
        w.write_bits(0xA, 4);
        w.write_u32(0x5, 4);
        assert_eq!(w.finish(), vec![0xA5]);
    }

    /// `into_bytes` behaves exactly like `finish`.
    #[test]
    fn msb_into_bytes_alias() {
        let mut w = BitWriter::new();
        w.write_u32(0xA5, 8);
        assert_eq!(w.into_bytes(), vec![0xA5]);
    }

    /// A full 64-bit value survives a write/read round trip.
    #[test]
    fn msb_read_u64_high_bits() {
        let mut w = BitWriter::new();
        w.write_u64(0x1234567890ABCDEF, 64);
        let bytes = w.finish();
        let mut r = BitReader::new(&bytes);
        assert_eq!(r.read_u64(64).unwrap(), 0x1234567890ABCDEF);
    }

    /// Unary codes round-trip, including runs that span accumulator refills.
    #[test]
    fn msb_unary_roundtrip() {
        let mut w = BitWriter::new();
        let counts: Vec<u32> = vec![0, 1, 7, 31, 32, 33, 64, 65, 100];
        for &c in &counts {
            w.write_unary(c);
        }
        let bytes = w.finish();
        let mut r = BitReader::new(&bytes);
        for &c in &counts {
            assert_eq!(r.read_unary().unwrap(), c, "unary roundtrip for {c}");
        }
    }

    /// `read_bytes` resumes at the right byte even when the accumulator had
    /// already loaded the bytes being copied.
    #[test]
    fn msb_read_bytes_aligned() {
        let mut br = BitReader::new(&[0xAA, 0xBB, 0xCC, 0xDD]);
        let _ = br.read_u32(8).unwrap();
        let got = br.read_bytes(2).unwrap();
        assert_eq!(got, vec![0xBB, 0xCC]);
        assert_eq!(br.read_u32(8).unwrap(), 0xDD);
    }

    /// Eight single-bit writes pack into one byte, first bit in the LSB.
    #[test]
    fn lsb_roundtrip_byte() {
        let mut w = BitWriterLsb::new();
        for &b in &[1u32, 0, 1, 0, 0, 1, 0, 1] {
            w.write_u32(b, 1);
        }
        assert_eq!(w.finish(), vec![0xA5]);
    }

    /// A 16-bit value is emitted low byte first and reads back unchanged.
    #[test]
    fn lsb_multi_byte() {
        let mut w = BitWriterLsb::new();
        w.write_u32(0x3412, 16);
        let bytes = w.finish();
        assert_eq!(bytes, vec![0x12, 0x34]);
        let mut r = BitReaderLsb::new(&bytes);
        assert_eq!(r.read_u32(16).unwrap(), 0x3412);
    }

    /// Values of assorted widths read back exactly as written.
    #[test]
    fn lsb_roundtrip_varied_widths() {
        let mut bw = BitWriterLsb::new();
        let writes: Vec<(u32, u32)> = vec![(5, 3), (0xABCD, 16), (0x1234567, 27), (1, 1)];
        for &(v, n) in &writes {
            bw.write_u32(v, n);
        }
        let bytes = bw.finish();
        let mut r = BitReaderLsb::new(&bytes);
        for &(v, n) in &writes {
            assert_eq!(r.read_u32(n).unwrap(), v);
        }
    }

    /// A field whose top bit is set sign-extends to a negative value.
    #[test]
    fn lsb_signed_extension() {
        let mut r = BitReaderLsb::new(&[0x0F]);
        assert_eq!(r.read_i32(4).unwrap(), -1);
        assert_eq!(r.read_i32(4).unwrap(), 0);
    }
}