use super::{AlacError, AlacResult};
/// Read-only cursor over a byte slice that hands out bits most-significant-bit first.
///
/// Every advancing operation checks the remaining bit count before moving the
/// cursor, so the cursor never moves past the end of the slice.
pub struct BitReader<'a> {
    /// Backing bytes; the reader only ever reads from them.
    data: &'a [u8],
    /// Index of the next unread bit, where bit 0 is the MSB of `data[0]`.
    bit_pos: usize,
}

impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the first bit of `data`.
    #[must_use]
    pub fn new(data: &'a [u8]) -> Self {
        Self { data, bit_pos: 0 }
    }

    /// Number of bits in the underlying slice (`data.len() * 8`).
    #[must_use]
    pub fn total_bits(&self) -> usize {
        self.data.len() * 8
    }

    /// Index of the next unread bit, counted from the start of the slice.
    #[must_use]
    pub fn position(&self) -> usize {
        self.bit_pos
    }

    /// Number of bits that can still be read before the end of the slice.
    #[must_use]
    pub fn bits_left(&self) -> usize {
        self.total_bits().saturating_sub(self.bit_pos)
    }

    /// Reads a single bit and advances the cursor by one.
    ///
    /// # Errors
    ///
    /// Returns [`AlacError::Truncated`] when no bits are left; the cursor is
    /// not moved in that case.
    pub fn read_bit(&mut self) -> AlacResult<bool> {
        if self.bit_pos >= self.total_bits() {
            return Err(AlacError::Truncated("read_bit past end of frame".into()));
        }
        let byte = self.data[self.bit_pos >> 3];
        // Bit 0 of the stream is the MSB of the byte, hence the `7 - ...`.
        let shift = 7 - (self.bit_pos & 7);
        self.bit_pos += 1;
        Ok((byte >> shift) & 1 != 0)
    }

    /// Reads `count` bits as an unsigned value, first bit read being the most
    /// significant bit of the result. `count == 0` yields `0` without moving.
    ///
    /// `count` is expected to be at most 32; this is only checked with a
    /// `debug_assert!`.
    ///
    /// # Errors
    ///
    /// Returns [`AlacError::Truncated`] when fewer than `count` bits are left;
    /// the cursor is not moved in that case.
    pub fn read_bits(&mut self, count: u32) -> AlacResult<u32> {
        debug_assert!(count <= 32);
        if count == 0 {
            return Ok(0);
        }
        // Compare against the remaining bits instead of adding to the cursor,
        // so the bounds check itself cannot overflow.
        if count as usize > self.bits_left() {
            return Err(AlacError::Truncated(format!(
                "read_bits({count}) past end of frame"
            )));
        }
        let mut value: u32 = 0;
        let mut remaining = count;
        while remaining > 0 {
            let byte_index = self.bit_pos >> 3;
            // Unread bits left in the current byte (1..=8).
            let avail = 8 - (self.bit_pos & 7) as u32;
            let take = remaining.min(avail);
            let byte = u32::from(self.data[byte_index]);
            // Drop the bits of this byte that stay unread after this step.
            let shift = avail - take;
            // `take` never exceeds 8, so this shift cannot overflow.
            let mask = (1u32 << take) - 1;
            let chunk = (byte >> shift) & mask;
            value = (value << take) | chunk;
            self.bit_pos += take as usize;
            remaining -= take;
        }
        Ok(value)
    }

    /// Reads `count` bits and sign-extends them from `count` bits to an `i32`
    /// via `sign_extend`.
    ///
    /// # Errors
    ///
    /// Propagates the error of [`BitReader::read_bits`].
    pub fn read_signed(&mut self, count: u32) -> AlacResult<i32> {
        let raw = self.read_bits(count)?;
        Ok(sign_extend(raw, count))
    }

    /// Counts consecutive one bits up to, and consuming, the terminating zero bit.
    ///
    /// # Errors
    ///
    /// Returns [`AlacError::Truncated`] if the data ends before a zero bit is
    /// found, and [`AlacError::InvalidBitstream`] once the run is longer than
    /// `1 << 20` bits. Bits consumed before the failure stay consumed.
    pub fn read_unary(&mut self) -> AlacResult<u32> {
        let mut count = 0u32;
        loop {
            let bit = self.read_bit()?;
            if !bit {
                break;
            }
            count += 1;
            if count > (1u32 << 20) {
                return Err(AlacError::InvalidBitstream(
                    "unary run length exceeds sanity bound".into(),
                ));
            }
        }
        Ok(count)
    }

    /// Returns the next 24 bits without moving the cursor.
    ///
    /// Positions past the end of the data are read as zero bits, so this
    /// never fails.
    #[must_use]
    pub fn peek_24(&self) -> u32 {
        let mut value = 0u32;
        for i in 0..24 {
            let pos = self.bit_pos + i;
            let bit = if pos < self.total_bits() {
                let byte = self.data[pos >> 3];
                let shift = 7 - (pos & 7);
                u32::from((byte >> shift) & 1)
            } else {
                0
            };
            value = (value << 1) | bit;
        }
        value
    }

    /// Advances the cursor by `count` bits without returning them.
    ///
    /// # Errors
    ///
    /// Returns [`AlacError::Truncated`] when fewer than `count` bits are left;
    /// the cursor is not moved in that case.
    pub fn skip(&mut self, count: usize) -> AlacResult<()> {
        // `self.bit_pos + count` could overflow for a huge `count` (panic in
        // debug builds, silent wrap-around in release), so compare against
        // the remaining bits instead.
        if count > self.bits_left() {
            return Err(AlacError::Truncated("skip past end of frame".into()));
        }
        self.bit_pos += count;
        Ok(())
    }
}
/// Accumulates bits most-significant-bit first into a growable byte buffer.
#[derive(Default)]
pub struct BitWriter {
    /// Bytes that have been completely filled, in write order.
    bytes: Vec<u8>,
    /// Byte under construction; the bits written so far occupy its low `fill` bits.
    cur: u8,
    /// How many bits `cur` currently holds (0..=7 between calls).
    fill: u8,
}

impl BitWriter {
    /// Creates an empty writer.
    #[must_use]
    pub fn new() -> Self {
        Self {
            bytes: Vec::new(),
            cur: 0,
            fill: 0,
        }
    }

    /// Total number of bits written so far, including the pending partial byte.
    #[must_use]
    pub fn bit_len(&self) -> usize {
        usize::from(self.fill) + 8 * self.bytes.len()
    }

    /// Appends one bit; a byte is flushed to the buffer as soon as it has 8 bits.
    pub fn write_bit(&mut self, bit: bool) {
        let pending = (self.cur << 1) | u8::from(bit);
        if self.fill == 7 {
            // This bit completes the byte.
            self.bytes.push(pending);
            self.cur = 0;
            self.fill = 0;
        } else {
            self.cur = pending;
            self.fill += 1;
        }
    }

    /// Appends the low `count` bits of `value`, most significant of them first.
    ///
    /// `count` is expected to be at most 32 (checked with `debug_assert!`).
    pub fn write_bits(&mut self, value: u32, count: u32) {
        debug_assert!(count <= 32);
        for bit_index in (0..count).rev() {
            self.write_bit((value >> bit_index) & 1 != 0);
        }
    }

    /// Appends the low `count` bits of the two's-complement form of `value`.
    pub fn write_signed(&mut self, value: i32, count: u32) {
        // `checked_shl` is `None` for counts of 32 and above, where every bit is kept.
        let mask = 1u32.checked_shl(count).map_or(u32::MAX, |top| top - 1);
        self.write_bits((value as u32) & mask, count);
    }

    /// Appends `n` one bits followed by a single zero bit.
    pub fn write_unary(&mut self, n: u32) {
        (0..n).for_each(|_| self.write_bit(true));
        self.write_bit(false);
    }

    /// Consumes the writer and returns the bytes; a trailing partial byte is
    /// left-aligned and padded with zero bits.
    #[must_use]
    pub fn finish(self) -> Vec<u8> {
        let Self {
            mut bytes,
            cur,
            fill,
        } = self;
        if fill != 0 {
            bytes.push(cur << (8 - fill));
        }
        bytes
    }

    /// Like [`BitWriter::finish`], but also returns the number of bits that
    /// were written (excluding padding).
    #[must_use]
    pub fn finish_with_len(self) -> (Vec<u8>, usize) {
        let written = self.bit_len();
        let packed = self.finish();
        (packed, written)
    }

    /// Appends the first `bit_len` bits of `bytes`, read most-significant-bit first.
    ///
    /// # Panics
    ///
    /// Panics on the out-of-range index if `bit_len` exceeds `bytes.len() * 8`.
    pub fn append_bits(&mut self, bytes: &[u8], bit_len: usize) {
        for index in 0..bit_len {
            let source = bytes[index / 8];
            let probe = 0x80u8 >> (index % 8);
            self.write_bit(source & probe != 0);
        }
    }
}
/// Interprets the low `bits` bits of `value` as a two's-complement number and
/// widens it to `i32`.
///
/// Bits of `value` above position `bits` are discarded. When `bits` is `0` or
/// at least `32`, `value` is reinterpreted as `i32` unchanged.
#[inline]
#[must_use]
pub fn sign_extend(value: u32, bits: u32) -> i32 {
    if !(1..32).contains(&bits) {
        return value as i32;
    }
    // Move the field's sign bit into bit 31, then let the arithmetic right
    // shift replicate it back down.
    let unused = 32 - bits;
    let left_aligned = (value << unused) as i32;
    left_aligned >> unused
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Unsigned fields of mixed widths written back-to-back read back unchanged.
    #[test]
    fn test_write_read_bits_roundtrip() {
        // (value, width in bits)
        let fields: [(u32, u32); 4] = [(0b101, 3), (0xABCD, 16), (0, 5), (0x7FFF_FFFF, 31)];

        let mut writer = BitWriter::new();
        for &(value, width) in &fields {
            writer.write_bits(value, width);
        }
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        for &(expected, width) in &fields {
            assert_eq!(reader.read_bits(width).unwrap(), expected);
        }
    }

    /// 12-bit signed values, including both extremes, survive a round trip.
    #[test]
    fn test_signed_roundtrip() {
        let samples: [i32; 7] = [-1, -100, 0, 1, 100, -2048, 2047];

        let mut writer = BitWriter::new();
        for sample in samples {
            writer.write_signed(sample, 12);
        }
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        for sample in samples {
            assert_eq!(reader.read_signed(12).unwrap(), sample);
        }
    }

    /// Unary runs, including ones that cross byte boundaries, survive a round trip.
    #[test]
    fn test_unary_roundtrip() {
        let runs: [u32; 5] = [0, 1, 5, 17, 33];

        let mut writer = BitWriter::new();
        for run in runs {
            writer.write_unary(run);
        }
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        for run in runs {
            assert_eq!(reader.read_unary().unwrap(), run);
        }
    }

    /// Once the single byte is consumed, every further read reports an error.
    #[test]
    fn test_read_past_end_errors() {
        let single = [0xFFu8];
        let mut reader = BitReader::new(&single);
        assert!(reader.read_bits(8).is_ok());
        assert!(reader.read_bit().is_err());
        assert!(reader.read_bits(1).is_err());
    }

    /// Spot checks of `sign_extend` around the sign bit of a 4-bit field.
    #[test]
    fn test_sign_extend() {
        // (raw value, field width, expected result)
        let cases: [(u32, u32, i32); 4] = [
            (0b1111, 4, -1),
            (0b0111, 4, 7),
            (0b1000, 4, -8),
            (0, 16, 0),
        ];
        for (raw, width, expected) in cases {
            assert_eq!(sign_extend(raw, width), expected);
        }
    }

    /// Peeking at the start of three bytes returns them as one 24-bit value.
    #[test]
    fn test_peek_24() {
        let three = [0xAB, 0xCD, 0xEF];
        let reader = BitReader::new(&three);
        assert_eq!(reader.peek_24(), 0xABCDEF);
    }
}