1use crate::error::AecError;
2
3#[derive(Debug, Clone)]
5pub struct BitReader<'a> {
6 data: &'a [u8],
7 bit_pos: usize,
8}
9
10impl<'a> BitReader<'a> {
11 pub fn new(data: &'a [u8]) -> Self {
12 Self { data, bit_pos: 0 }
13 }
14
15 pub fn bits_read(&self) -> usize {
16 self.bit_pos
17 }
18
19 pub fn align_to_byte(&mut self) {
20 let rem = self.bit_pos % 8;
21 if rem != 0 {
22 self.bit_pos += 8 - rem;
23 }
24 }
25
26 pub fn read_bit(&mut self) -> Result<bool, AecError> {
27 Ok(self.read_bits_u32(1)? != 0)
28 }
29
30 pub fn read_bits_u32(&mut self, nbits: usize) -> Result<u32, AecError> {
31 if nbits == 0 {
32 return Ok(0);
33 }
34 if nbits > 32 {
35 return Err(AecError::InvalidInput("read_bits_u32 supports up to 32 bits"));
36 }
37
38 let mut out: u32 = 0;
39 for _ in 0..nbits {
40 let byte_idx = self.bit_pos / 8;
41 let bit_in_byte = self.bit_pos % 8;
42 let byte = *self
43 .data
44 .get(byte_idx)
45 .ok_or(AecError::UnexpectedEof { bit_pos: self.bit_pos })?;
46 let bit = (byte >> (7 - bit_in_byte)) & 1;
47 out = (out << 1) | (bit as u32);
48 self.bit_pos += 1;
49 }
50 Ok(out)
51 }
52}
53
54#[derive(Debug, Clone)]
59pub struct BitReaderLsb<'a> {
60 data: &'a [u8],
61 bit_pos: usize,
62}
63
64impl<'a> BitReaderLsb<'a> {
65 pub fn new(data: &'a [u8]) -> Self {
66 Self { data, bit_pos: 0 }
67 }
68
69 pub fn bits_read(&self) -> usize {
70 self.bit_pos
71 }
72
73 pub fn align_to_byte(&mut self) {
74 let rem = self.bit_pos % 8;
75 if rem != 0 {
76 self.bit_pos += 8 - rem;
77 }
78 }
79
80 pub fn read_bit(&mut self) -> Result<bool, AecError> {
81 Ok(self.read_bits_u32(1)? != 0)
82 }
83
84 pub fn read_bits_u32(&mut self, nbits: usize) -> Result<u32, AecError> {
85 if nbits == 0 {
86 return Ok(0);
87 }
88 if nbits > 32 {
89 return Err(AecError::InvalidInput("read_bits_u32 supports up to 32 bits"));
90 }
91
92 let mut out: u32 = 0;
93 for _ in 0..nbits {
94 let byte_idx = self.bit_pos / 8;
95 let bit_in_byte = self.bit_pos % 8;
96 let byte = *self
97 .data
98 .get(byte_idx)
99 .ok_or(AecError::UnexpectedEof { bit_pos: self.bit_pos })?;
100 let bit = (byte >> bit_in_byte) & 1;
101 out = (out << 1) | (bit as u32);
102 self.bit_pos += 1;
103 }
104 Ok(out)
105 }
106}
107
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn read_bits_across_bytes() -> anyhow::Result<()> {
        let data = [0b1010_1100u8, 0b0101_0001u8];
        let mut r = BitReader::new(&data);

        assert_eq!(r.read_bits_u32(4)?, 0b1010);
        assert_eq!(r.read_bits_u32(4)?, 0b1100);
        assert_eq!(r.read_bits_u32(3)?, 0b010);
        assert_eq!(r.read_bits_u32(5)?, 0b10001);

        Ok(())
    }

    #[test]
    fn align_to_byte() -> anyhow::Result<()> {
        let data = [0xffu8, 0x12u8];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_bits_u32(1)?, 1);
        r.align_to_byte();
        assert_eq!(r.read_bits_u32(8)?, 0x12);
        Ok(())
    }

    /// Zero-width reads are free; more than 32 bits is rejected without consuming input.
    #[test]
    fn zero_and_oversized_reads() -> anyhow::Result<()> {
        let data = [0x12u8, 0x34, 0x56, 0x78];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_bits_u32(0)?, 0);
        assert_eq!(r.bits_read(), 0);
        assert!(matches!(r.read_bits_u32(33), Err(AecError::InvalidInput(_))));
        assert_eq!(r.bits_read(), 0);
        assert_eq!(r.read_bits_u32(32)?, 0x1234_5678);
        Ok(())
    }

    /// Running out of data reports the position of the first missing bit.
    #[test]
    fn eof_reports_first_missing_bit() -> anyhow::Result<()> {
        let data = [0xabu8];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_bits_u32(4)?, 0xa);
        assert!(matches!(
            r.read_bits_u32(8),
            Err(AecError::UnexpectedEof { bit_pos: 8 })
        ));
        Ok(())
    }

    /// The LSB reader consumes bit 0 of each byte first, but the first bit read still lands
    /// in the most significant position of the result.
    #[test]
    fn lsb_reader_bit_order() -> anyhow::Result<()> {
        let data = [0b0000_0110u8, 0b1000_0000u8];
        let mut r = BitReaderLsb::new(&data);
        assert_eq!(r.read_bits_u32(3)?, 0b011);
        r.align_to_byte();
        assert_eq!(r.bits_read(), 8);
        assert_eq!(r.read_bits_u32(7)?, 0);
        assert!(r.read_bit()?);
        assert!(matches!(
            r.read_bit(),
            Err(AecError::UnexpectedEof { bit_pos: 16 })
        ));
        Ok(())
    }
}