use super::JxsError;
/// MSB-first bit reader over a borrowed byte slice.
///
/// Bits are consumed from the most significant bit (bit 7) of each byte down to bit 0 before
/// moving on to the next byte. `Clone` lets callers snapshot a position for lookahead.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    /// Input being read; only ever borrowed immutably.
    data: &'a [u8],
    /// Index of the byte holding the next unread bit (equals `data.len()` once exhausted).
    byte_pos: usize,
    /// Position of the next unread bit within `data[byte_pos]`: 7 = MSB, 0 = LSB.
    bit_pos: u8,
}
impl<'a> BitReader<'a> {
    /// Creates a reader positioned at the most significant bit of the first byte of `data`.
    pub fn new(data: &'a [u8]) -> Self {
        Self {
            data,
            byte_pos: 0,
            bit_pos: 7,
        }
    }

    /// Reads a single bit (MSB-first within each byte), returning `0` or `1`.
    ///
    /// # Errors
    ///
    /// Returns [`JxsError::TruncatedStream`] when every bit of the input has been consumed.
    pub fn read_bit(&mut self) -> Result<u8, JxsError> {
        if self.byte_pos >= self.data.len() {
            return Err(self.truncated_error());
        }
        let bit = (self.data[self.byte_pos] >> self.bit_pos) & 1;
        if self.bit_pos == 0 {
            self.byte_pos += 1;
            self.bit_pos = 7;
        } else {
            self.bit_pos -= 1;
        }
        Ok(bit)
    }

    /// Reads `n` bits MSB-first and returns them right-aligned in a `u32`.
    ///
    /// `n == 0` returns `Ok(0)` without moving. If `n > 32`, all `n` bits are consumed but
    /// only the last 32 read are returned (earlier ones are shifted out).
    ///
    /// # Errors
    ///
    /// Returns [`JxsError::TruncatedStream`] if fewer than `n` bits remain. The reader
    /// position is left unchanged in that case, so no bits are silently consumed.
    pub fn read_bits_u32(&mut self, n: u8) -> Result<u32, JxsError> {
        self.ensure_bits(n)?;
        // All `n` bits are in range, so `peek_bits_u32` never hits its zero-padding path.
        let val = self.peek_bits_u32(n);
        self.advance_bits(usize::from(n));
        Ok(val)
    }

    /// Reads the next 8 bits as a `u8`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::read_bits_u32`].
    pub fn read_u8(&mut self) -> Result<u8, JxsError> {
        // An 8-bit read is always < 256, so the narrowing cast is lossless.
        self.read_bits_u32(8).map(|v| v as u8)
    }

    /// Reads the next 16 bits as a big-endian (MSB-first) `u16`.
    ///
    /// # Errors
    ///
    /// Same as [`Self::read_bits_u32`].
    pub fn read_u16_be(&mut self) -> Result<u16, JxsError> {
        // A 16-bit read is always < 65536, so the narrowing cast is lossless.
        self.read_bits_u32(16).map(|v| v as u16)
    }

    /// Moves to the start of the next byte unless already at a byte boundary.
    ///
    /// Never fails; any unread bits of a partially consumed byte are discarded.
    pub fn byte_align(&mut self) {
        if self.bit_pos != 7 {
            self.byte_pos += 1;
            self.bit_pos = 7;
        }
    }

    /// Index of the byte containing the next unread bit.
    pub fn byte_pos(&self) -> usize {
        self.byte_pos
    }

    /// Number of bytes from the current byte onward (a partially read byte counts as one).
    pub fn remaining_bytes(&self) -> usize {
        self.data.len().saturating_sub(self.byte_pos)
    }

    /// Returns the next `n` bits MSB-first without advancing the reader.
    ///
    /// Bits past the end of the input read as `0`, so this never fails. For `n > 32` only
    /// the last 32 bits examined are returned.
    pub fn peek_bits_u32(&self, n: u8) -> u32 {
        let mut tmp_byte = self.byte_pos;
        let mut tmp_bit = self.bit_pos;
        let mut val: u32 = 0;
        for _ in 0..n {
            let bit = if tmp_byte < self.data.len() {
                (self.data[tmp_byte] >> tmp_bit) & 1
            } else {
                0
            };
            val = (val << 1) | u32::from(bit);
            if tmp_bit == 0 {
                tmp_byte += 1;
                tmp_bit = 7;
            } else {
                tmp_bit -= 1;
            }
        }
        val
    }

    /// Skips `n` bits.
    ///
    /// # Errors
    ///
    /// Returns [`JxsError::TruncatedStream`] if fewer than `n` bits remain; the position is
    /// left unchanged in that case.
    pub fn skip_bits(&mut self, n: u8) -> Result<(), JxsError> {
        self.ensure_bits(n)?;
        self.advance_bits(usize::from(n));
        Ok(())
    }

    /// Number of bits not yet consumed.
    fn bits_remaining(&self) -> usize {
        match self.data.len().checked_sub(self.byte_pos) {
            None | Some(0) => 0,
            // The current byte still holds `bit_pos + 1` unread bits; later bytes are whole.
            Some(bytes) => (bytes - 1)
                .saturating_mul(8)
                .saturating_add(usize::from(self.bit_pos) + 1),
        }
    }

    /// Fails with a truncation error, without touching the position, if fewer than `n` bits remain.
    fn ensure_bits(&self, n: u8) -> Result<(), JxsError> {
        if self.bits_remaining() < usize::from(n) {
            Err(self.truncated_error())
        } else {
            Ok(())
        }
    }

    /// Moves the cursor forward by `n` bits; callers must have checked `n <= bits_remaining()`.
    fn advance_bits(&mut self, n: usize) {
        let consumed = self.byte_pos * 8 + usize::from(7 - self.bit_pos) + n;
        self.byte_pos = consumed / 8;
        // `consumed % 8` is in 0..=7, so the narrowing cast cannot truncate.
        self.bit_pos = 7 - (consumed % 8) as u8;
    }

    /// Error reported when a read runs past the end of the input.
    ///
    /// `byte_pos` never exceeds `data.len()`, so the first missing byte is always index
    /// `data.len()`; hence `need = data.len() + 1`.
    fn truncated_error(&self) -> JxsError {
        JxsError::TruncatedStream {
            need: self.data.len() + 1,
            have: self.data.len(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Bits come out from bit 7 down to bit 0.
    #[test]
    fn read_bit_msb_first() {
        let bytes = [0b1010_0011u8];
        let mut reader = BitReader::new(&bytes);
        let observed: Vec<u8> = (0..8).map(|_| reader.read_bit().unwrap()).collect();
        assert_eq!(observed, vec![1, 0, 1, 0, 0, 0, 1, 1]);
    }

    #[test]
    fn read_bits_u32_full_byte() {
        let bytes = [0xABu8];
        let value = BitReader::new(&bytes).read_bits_u32(8).unwrap();
        assert_eq!(value, 0xAB);
    }

    #[test]
    fn read_bits_u32_crosses_byte_boundary() {
        let bytes = [0xABu8, 0xCDu8];
        let value = BitReader::new(&bytes).read_bits_u32(12).unwrap();
        assert_eq!(value, 0xABC);
    }

    #[test]
    fn read_u16_be() {
        let bytes = [0x12u8, 0x34u8];
        let value = BitReader::new(&bytes).read_u16_be().unwrap();
        assert_eq!(value, 0x1234);
    }

    /// Reading one bit past the final byte must fail.
    #[test]
    fn truncated_stream_error() {
        let bytes = [0xFFu8];
        let mut reader = BitReader::new(&bytes);
        reader.read_bits_u32(8).unwrap();
        let outcome = reader.read_bit();
        assert!(outcome.is_err());
    }

    #[test]
    fn byte_align_advances_past_partial_byte() {
        let bytes = [0xFFu8, 0xAAu8];
        let mut reader = BitReader::new(&bytes);
        reader.read_bits_u32(3).unwrap();
        reader.byte_align();
        assert_eq!(reader.byte_pos(), 1);
        assert_eq!(reader.read_u8().unwrap(), 0xAA);
    }

    #[test]
    fn byte_align_on_boundary_is_noop() {
        let bytes = [0xFFu8, 0xAAu8];
        let mut reader = BitReader::new(&bytes);
        reader.read_bits_u32(8).unwrap();
        let before = reader.byte_pos();
        assert_eq!(before, 1);
        reader.byte_align();
        assert_eq!(reader.byte_pos(), before);
    }

    /// Peeking reports the upcoming bits but leaves the cursor where it was.
    #[test]
    fn peek_does_not_advance() {
        let bytes = [0b1100_0000u8];
        let mut reader = BitReader::new(&bytes);
        assert_eq!(reader.peek_bits_u32(4), 0b1100);
        assert_eq!(reader.byte_pos(), 0);
        assert_eq!(reader.read_bits_u32(4).unwrap(), 0b1100);
    }

    #[test]
    fn skip_bits_advances_position() {
        let bytes = [0b1111_0000u8];
        let mut reader = BitReader::new(&bytes);
        reader.skip_bits(4).unwrap();
        let tail = reader.read_bits_u32(4).unwrap();
        assert_eq!(tail, 0b0000);
    }
}