use core::fmt;
/// Errors reported by [`BitReader`] and [`BitWriter`].
///
/// Each fallible method validates its request before moving its cursor, so
/// receiving an error means the reader/writer position is unchanged.
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum BitError {
    /// The request needs more bits than are left after the cursor.
    OutOfBounds {
        /// Bits the operation asked for.
        needed_bits: usize,
        /// Bits available after the cursor when the request was made.
        remaining_bits: usize,
    },
    /// A single read or write asked for more than 64 bits, the width of the
    /// `u64` used to carry values.
    TooManyBits {
        /// The width that was requested.
        requested: u32,
    },
    /// A value passed to [`BitWriter::write_bits`] has a set bit at position
    /// `bits` or higher, so it cannot be encoded in `bits` bits.
    ValueTooWide {
        /// The offending value.
        value: u64,
        /// The field width the value was supposed to fit in.
        bits: u32,
    },
}

impl fmt::Display for BitError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // All payload fields are `Copy`, so bind them by value.
        match *self {
            Self::OutOfBounds {
                needed_bits,
                remaining_bits,
            } => write!(
                f,
                "bit buffer out of bounds: need {needed_bits} bit(s), {remaining_bits} remaining"
            ),
            Self::TooManyBits { requested } => {
                write!(f, "requested {requested} bits exceeds the 64-bit carrier")
            }
            Self::ValueTooWide { value, bits } => {
                write!(f, "value {value:#x} does not fit in {bits} bit(s)")
            }
        }
    }
}

/// `BitError` wraps no underlying cause, so the default `source()` (`None`) applies.
impl core::error::Error for BitError {}
/// Sequential MSB-first bit reader over a borrowed byte slice.
///
/// Stream bit 0 is the most significant bit of `data[0]`; the cursor only
/// moves forward.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    /// Source bytes (shared borrow, never written).
    data: &'a [u8],
    /// Absolute index of the next bit to read; kept within `0..=data.len() * 8`.
    bit_pos: usize,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the most significant bit of `data[0]`.
    #[must_use]
    pub fn new(data: &'a [u8]) -> Self {
        BitReader { data, bit_pos: 0 }
    }

    /// Total number of bits in the underlying slice (`data.len() * 8`).
    #[must_use]
    pub fn total_bits(&self) -> usize {
        8 * self.data.len()
    }

    /// Number of bits consumed (read or skipped) so far.
    #[must_use]
    pub fn bits_read(&self) -> usize {
        self.bit_pos
    }

    /// Number of bits still available after the cursor.
    #[must_use]
    pub fn bits_remaining(&self) -> usize {
        let total = self.total_bits();
        total - self.bit_pos
    }

    /// Returns `true` when the cursor sits on a byte boundary.
    #[must_use]
    pub fn is_byte_aligned(&self) -> bool {
        (self.bit_pos & 7) == 0
    }

    /// Reads `n` bits, most significant first, and returns them right-aligned
    /// in a `u64`. Reading zero bits yields `0` and leaves the cursor in place.
    ///
    /// # Errors
    ///
    /// - [`BitError::TooManyBits`] if `n > 64`.
    /// - [`BitError::OutOfBounds`] if fewer than `n` bits remain.
    ///
    /// On error the cursor does not move.
    pub fn read_bits(&mut self, n: u32) -> Result<u64, BitError> {
        if n > 64 {
            return Err(BitError::TooManyBits { requested: n });
        }
        let wanted = n as usize;
        let available = self.bits_remaining();
        if wanted > available {
            return Err(BitError::OutOfBounds {
                needed_bits: wanted,
                remaining_bits: available,
            });
        }
        let start = self.bit_pos;
        let end = start + wanted;
        let data = self.data;
        let value = (start..end).fold(0u64, |acc, pos| {
            // Bit `pos` lives in byte `pos / 8`, counted from that byte's MSB.
            let bit = (data[pos / 8] >> (7 - pos % 8)) & 1;
            (acc << 1) | u64::from(bit)
        });
        self.bit_pos = end;
        Ok(value)
    }

    /// Reads a single bit as a boolean (`1` is `true`).
    ///
    /// # Errors
    ///
    /// [`BitError::OutOfBounds`] if no bits remain.
    pub fn read_bool(&mut self) -> Result<bool, BitError> {
        self.read_bits(1).map(|bit| bit == 1)
    }

    /// Advances the cursor by `n` bits without decoding them.
    ///
    /// # Errors
    ///
    /// [`BitError::OutOfBounds`] if fewer than `n` bits remain; the cursor
    /// does not move in that case.
    pub fn skip_bits(&mut self, n: usize) -> Result<(), BitError> {
        let available = self.bits_remaining();
        if n <= available {
            self.bit_pos += n;
            Ok(())
        } else {
            Err(BitError::OutOfBounds {
                needed_bits: n,
                remaining_bits: available,
            })
        }
    }

    /// Moves the cursor forward to the next byte boundary; no-op if already
    /// aligned. An unaligned cursor is always inside an existing byte, so this
    /// cannot run past the end of the slice.
    pub fn align_to_byte(&mut self) {
        self.bit_pos = self.bit_pos.next_multiple_of(8);
    }
}
/// Sequential MSB-first bit writer over a caller-provided mutable byte slice.
///
/// Stream bit 0 is the most significant bit of `data[0]`, matching
/// [`BitReader`]. Every written bit is explicitly set or cleared, so `data`
/// does not need to be zeroed beforehand. Bits after the cursor are left as
/// they were.
#[derive(Debug)]
pub struct BitWriter<'a> {
    /// Destination bytes.
    data: &'a mut [u8],
    /// Absolute index of the next bit to write; kept within `0..=data.len() * 8`.
    bit_pos: usize,
}

impl<'a> BitWriter<'a> {
    /// Creates a writer positioned at the most significant bit of `data[0]`.
    #[must_use]
    pub fn new(data: &'a mut [u8]) -> Self {
        Self { data, bit_pos: 0 }
    }

    /// Total number of bits the buffer can hold (`data.len() * 8`).
    #[must_use]
    pub fn capacity_bits(&self) -> usize {
        self.data.len() * 8
    }

    /// Number of bits written so far.
    #[must_use]
    pub fn bits_written(&self) -> usize {
        self.bit_pos
    }

    /// Number of bits that can still be written before the buffer is full.
    #[must_use]
    pub fn bits_remaining(&self) -> usize {
        self.capacity_bits() - self.bit_pos
    }

    /// Returns `true` when the cursor sits on a byte boundary.
    #[must_use]
    pub fn is_byte_aligned(&self) -> bool {
        self.bit_pos % 8 == 0
    }

    /// Writes the low `n` bits of `value`, most significant first.
    ///
    /// A zero-width write (`n == 0`) writes nothing and succeeds only when
    /// `value == 0`.
    ///
    /// # Errors
    ///
    /// - [`BitError::TooManyBits`] if `n > 64`.
    /// - [`BitError::ValueTooWide`] if `value` has any bit set at position `n`
    ///   or above. This includes any nonzero `value` when `n == 0`.
    /// - [`BitError::OutOfBounds`] if fewer than `n` bits of capacity remain.
    ///
    /// All checks run before any byte is modified. On error, neither the
    /// buffer nor the cursor changes.
    pub fn write_bits(&mut self, value: u64, n: u32) -> Result<(), BitError> {
        if n > 64 {
            return Err(BitError::TooManyBits { requested: n });
        }
        // The width check must come before the zero-width shortcut. Otherwise
        // a nonzero value written with `n == 0` would be silently dropped.
        // `n < 64` guards the shift, which would overflow for `n == 64`.
        if n < 64 && value >= (1u64 << n) {
            return Err(BitError::ValueTooWide { value, bits: n });
        }
        if n == 0 {
            return Ok(());
        }
        let need = n as usize;
        let remaining = self.bits_remaining();
        if need > remaining {
            return Err(BitError::OutOfBounds {
                needed_bits: need,
                remaining_bits: remaining,
            });
        }
        for i in (0..n).rev() {
            let byte_idx = self.bit_pos / 8;
            // Bits within a byte are filled from the MSB downwards.
            let mask = 1u8 << (7 - self.bit_pos % 8);
            if ((value >> i) & 1) == 1 {
                self.data[byte_idx] |= mask;
            } else {
                self.data[byte_idx] &= !mask;
            }
            self.bit_pos += 1;
        }
        Ok(())
    }

    /// Writes a single bit: `1` for `true`, `0` for `false`.
    ///
    /// # Errors
    ///
    /// [`BitError::OutOfBounds`] if the buffer is full.
    pub fn write_bool(&mut self, value: bool) -> Result<(), BitError> {
        self.write_bits(u64::from(value), 1)
    }

    /// Pads with zero bits up to the next byte boundary; no-op when already
    /// aligned.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`Self::write_bits`]. The padding never
    /// leaves the byte the cursor is already in, so a failure is not expected.
    pub fn align_to_byte(&mut self) -> Result<(), BitError> {
        let rem = self.bit_pos % 8;
        if rem != 0 {
            // `rem` is in 1..=7, so the padding width fits a `u32` losslessly.
            self.write_bits(0, (8 - rem) as u32)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    // Binary literals below are grouped by bit-field (e.g. `0b1_01_10101` for
    // fields of width 1, 2 and 5) rather than by nibble, which clippy flags.
    #![allow(clippy::unusual_byte_groupings)]
    use super::*;

    #[test]
    fn single_byte_fields_round_trip() {
        let mut buf = [0u8; 1];
        let mut w = BitWriter::new(&mut buf);
        w.write_bits(0b1, 1).unwrap();
        w.write_bits(0b01, 2).unwrap();
        w.write_bits(0b10101, 5).unwrap();
        assert_eq!(w.bits_written(), 8);
        assert_eq!(buf[0], 0b1_01_10101);
        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(1).unwrap(), 0b1);
        assert_eq!(r.read_bits(2).unwrap(), 0b01);
        assert_eq!(r.read_bits(5).unwrap(), 0b10101);
        assert_eq!(r.bits_remaining(), 0);
    }

    #[test]
    fn field_crossing_byte_boundary_round_trips() {
        let val18 = 0b10_1010_1010_1010_1011u64;
        let mut buf = [0u8; 4];
        let mut w = BitWriter::new(&mut buf);
        w.write_bits(0b101, 3).unwrap();
        w.write_bits(val18, 18).unwrap();
        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(3).unwrap(), 0b101);
        assert_eq!(r.read_bits(18).unwrap(), val18);
    }

    #[test]
    fn read_zero_bits_is_noop() {
        let buf = [0xFFu8];
        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(0).unwrap(), 0);
        assert_eq!(r.bits_read(), 0);
    }

    #[test]
    fn full_64_bit_field() {
        let mut buf = [0u8; 8];
        let value = 0xDEAD_BEEF_CAFE_F00Du64;
        let mut w = BitWriter::new(&mut buf);
        w.write_bits(value, 64).unwrap();
        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(64).unwrap(), value);
    }

    #[test]
    fn read_past_end_errs() {
        let buf = [0xFFu8];
        let mut r = BitReader::new(&buf);
        r.read_bits(7).unwrap();
        let err = r.read_bits(2).unwrap_err();
        assert_eq!(
            err,
            BitError::OutOfBounds {
                needed_bits: 2,
                remaining_bits: 1,
            }
        );
        // A failed read must not move the cursor.
        assert_eq!(r.bits_read(), 7);
    }

    #[test]
    fn read_too_many_bits_errs() {
        let buf = [0u8; 16];
        let mut r = BitReader::new(&buf);
        assert_eq!(
            r.read_bits(65).unwrap_err(),
            BitError::TooManyBits { requested: 65 }
        );
    }

    #[test]
    fn write_too_many_bits_errs() {
        let mut buf = [0u8; 16];
        let mut w = BitWriter::new(&mut buf);
        assert_eq!(
            w.write_bits(0, 65).unwrap_err(),
            BitError::TooManyBits { requested: 65 }
        );
        assert_eq!(w.bits_written(), 0);
    }

    #[test]
    fn write_value_too_wide_errs() {
        let mut buf = [0u8; 4];
        let mut w = BitWriter::new(&mut buf);
        assert_eq!(
            w.write_bits(0b100, 2).unwrap_err(),
            BitError::ValueTooWide {
                value: 0b100,
                bits: 2
            }
        );
    }

    #[test]
    fn write_past_end_errs() {
        let mut buf = [0u8; 1];
        let mut w = BitWriter::new(&mut buf);
        w.write_bits(0, 7).unwrap();
        assert_eq!(
            w.write_bits(0b11, 2).unwrap_err(),
            BitError::OutOfBounds {
                needed_bits: 2,
                remaining_bits: 1,
            }
        );
        assert_eq!(w.bits_written(), 7);
    }

    #[test]
    fn failed_write_leaves_buffer_untouched() {
        let mut buf = [0xAAu8; 1];
        let mut w = BitWriter::new(&mut buf);
        w.write_bits(0b1, 1).unwrap();
        assert!(matches!(
            w.write_bits(0xFF, 8),
            Err(BitError::OutOfBounds { .. })
        ));
        assert_eq!(w.bits_written(), 1);
        assert_eq!(buf[0], 0xAA);
    }

    #[test]
    fn writer_does_not_require_zeroed_buffer() {
        let mut buf = [0xFFu8; 1];
        let mut w = BitWriter::new(&mut buf);
        w.write_bits(0b0000_0000, 8).unwrap();
        assert_eq!(buf[0], 0x00);
    }

    #[test]
    fn bool_round_trips() {
        let mut buf = [0u8; 1];
        let mut w = BitWriter::new(&mut buf);
        w.write_bool(true).unwrap();
        w.write_bool(false).unwrap();
        w.write_bool(true).unwrap();
        let mut r = BitReader::new(&buf);
        assert!(r.read_bool().unwrap());
        assert!(!r.read_bool().unwrap());
        assert!(r.read_bool().unwrap());
    }

    #[test]
    fn skip_and_align() {
        let buf = [0b1010_1100u8, 0b1111_0000];
        let mut r = BitReader::new(&buf);
        r.read_bits(2).unwrap();
        r.skip_bits(3).unwrap();
        assert!(!r.is_byte_aligned());
        r.align_to_byte();
        assert!(r.is_byte_aligned());
        assert_eq!(r.read_bits(4).unwrap(), 0b1111);
    }

    #[test]
    fn skip_past_end_errs_without_moving() {
        let buf = [0u8; 1];
        let mut r = BitReader::new(&buf);
        r.skip_bits(5).unwrap();
        assert_eq!(
            r.skip_bits(4).unwrap_err(),
            BitError::OutOfBounds {
                needed_bits: 4,
                remaining_bits: 3,
            }
        );
        assert_eq!(r.bits_read(), 5);
    }

    #[test]
    fn writer_align_pads_with_zero() {
        let mut buf = [0xFFu8; 1];
        let mut w = BitWriter::new(&mut buf);
        w.write_bits(0b101, 3).unwrap();
        w.align_to_byte().unwrap();
        assert_eq!(w.bits_written(), 8);
        assert_eq!(buf[0], 0b1010_0000);
    }

    #[test]
    fn boundary_values_round_trip_at_every_width() {
        // Cover the full 1..=64 range so the 64-bit branch is exercised.
        for bits in 1u32..=64 {
            let max = if bits == 64 {
                u64::MAX
            } else {
                (1u64 << bits) - 1
            };
            for value in [0u64, 1, max, max / 2] {
                let mut buf = [0u8; 8];
                let mut w = BitWriter::new(&mut buf);
                w.write_bits(value, bits).unwrap();
                let mut r = BitReader::new(&buf);
                assert_eq!(
                    r.read_bits(bits).unwrap(),
                    value,
                    "round-trip failed: value={value:#x} bits={bits}"
                );
            }
        }
    }
}