binary_codec/bitstream/
reader.rs

1use std::cmp::min;
2
3use crate::{DeserializationError, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6    buffer: &'a [u8],
7    bit_pos: usize,
8}
9
10impl<'a> BitStreamReader<'a> {
11    /// Create a new LSB-first reader
12    pub fn new(buffer: &'a [u8]) -> Self {
13        Self { buffer, bit_pos: 0 }
14    }
15
16    /// Get byte position of reader
17    pub fn byte_pos(&self) -> usize {
18        self.bit_pos / 8
19    }
20
21    /// Read a single bit
22    pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
23        self.read_small(1).map(|v| v != 0)
24    }
25
26    /// Read 1-8 bits as u8 (LSB-first)
27    pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
28        assert!(bits > 0 && bits < 8);
29
30        let mut result: u8 = 0;
31        let mut shift = 0;
32
33        while bits > 0 {
34            if self.byte_pos() >= self.buffer.len() {
35                return Err(DeserializationError::NotEnoughBytes(1));
36            }
37
38            // Find the bit position inside the current byte (0..7)
39            let bit_offset = self.bit_pos % 8;
40
41            // Determine how many bytes left to read. Min(bits left in this byte, bits to read)
42            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
43
44            // Create a mask to isolate the bits we want from this byte.
45            // Example: if bit_offset = 2 and bits_in_current_byte = 3,
46            // mask = 00011100 (only bits 2,3,4 are 1)
47            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
48            let byte_val = self.buffer[self.byte_pos()];
49
50            // Apply the mask to isolate the bits and shift them to LSB
51            // Example: byte_val = 10101100, mask = 00011100
52            // (10101100 & 00011100) >> 2 = 00000101
53            let val = (byte_val & mask) >> bit_offset;
54
55            // Merge the extracted bits into the final result.
56            // Shift them to the correct position based on how many bits we already read.
57            result |= val << shift;
58
59            // Decrease the remaining bits we need to read
60            bits -= bits_in_current_byte;
61
62            // Update the shift for the next batch of bits (if crossing byte boundary)
63            shift += bits_in_current_byte;
64
65            self.bit_pos += bits_in_current_byte as usize;
66        }
67
68        Ok(result)
69    }
70
71    /// Read a full byte, aligning to the next byte boundary
72    pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
73        self.align_byte();
74
75        if self.byte_pos() >= self.buffer.len() {
76            return Err(DeserializationError::NotEnoughBytes(1));
77        }
78
79        let b = self.buffer[self.byte_pos()];
80        self.bit_pos += 8;
81        Ok(b)
82    }
83
84    /// Read a slice of bytes, aligning first
85    pub fn read_bytes(&mut self, count: usize) -> Result<&'a [u8], DeserializationError> {
86        self.align_byte();
87
88        let start = self.byte_pos();
89        if start + count > self.buffer.len() {
90            return Err(DeserializationError::NotEnoughBytes(
91                start + count - self.buffer.len(),
92            ));
93        }
94
95        self.bit_pos += 8 * count;
96
97        Ok(&self.buffer[start..start + count])
98    }
99
100    /// Read a dynamic int, starting at the next byte bounary
101    /// The last bit is used as a continuation flag for the next byte
102    pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
103        self.align_byte();
104        let mut num: u128 = 0;
105        let mut multiplier: u128 = 1;
106
107        loop {
108            let byte = self.read_byte()?; // None if EOF
109            num += ((byte & 127) as u128) * multiplier;
110
111            // If no continuation bit, stop
112            if (byte & 1 << 7) == 0 {
113                break;
114            }
115
116            multiplier *= 128;
117        }
118
119        Ok(num)
120    }
121
122    /// Read a integer of fixed size from the buffer
123    pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
124        &mut self,
125    ) -> Result<T, DeserializationError> {
126        let data = self.read_bytes(S)?;
127        Ok(FixedInt::deserialize(data))
128    }
129
130    /// Align the reader to the next byte boundary
131    pub fn align_byte(&mut self) {
132        let rem = self.bit_pos % 8;
133        if rem != 0 {
134            self.bit_pos += 8 - rem;
135        }
136    }
137
138    /// Get bytes left
139    pub fn bytes_left(&self) -> usize {
140        let left = self.buffer.len() - self.byte_pos();
141        if self.bit_pos % 8 != 0 {
142            left - 1 // If not aligned, we can't read the last byte fully
143        } else {
144            left
145        }
146    }
147
148    /// Reset reading position
149    pub fn reset(&mut self) {
150        self.bit_pos = 0;
151    }
152}
153
#[cfg(test)]
mod tests {
    use crate::DeserializationError;

    use super::BitStreamReader;

    /// Four-byte fixture shared by tests that do not need custom data.
    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let data = make_buffer();
        let mut bits = BitStreamReader::new(&data);

        // First byte is 0b10101100; bits come out least-significant first.
        let expected = [false, false, true, true, false, true, false, true];
        for &want in &expected {
            assert_eq!(bits.read_bit(), Ok(want));
        }
    }

    #[test]
    fn test_read_small() {
        let data = [0b10101100, 0b11010010];
        let mut r = BitStreamReader::new(&data);

        // (number of bits to read, expected value)
        let steps = [(3u8, 0b100u8), (4, 0b0101), (1, 0b1), (4, 0b0010)];
        for &(width, want) in &steps {
            assert_eq!(r.read_small(width), Ok(want));
        }
    }

    #[test]
    fn test_read_cross_byte() {
        let data = [0b10101100, 0b11010001];
        let mut r = BitStreamReader::new(&data);

        // 7 + 3 = 10 bits, so the second read straddles the byte boundary.
        let low_seven = r.read_small(7);
        let next_three = r.read_small(3);

        assert_eq!(low_seven, Ok(0b00101100));
        assert_eq!(next_three, Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let data = [0b10101100, 0b11010010];
        let mut r = BitStreamReader::new(&data);

        // Consume three bits so the reader sits in the middle of byte 0.
        r.read_small(3).unwrap();

        // The byte read skips the rest of byte 0 and yields all of byte 1.
        assert_eq!(r.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_read_bytes() {
        let data = [0x01, 0xAA, 0xBB, 0xCC];
        let mut r = BitStreamReader::new(&data);

        // One bit in: the slice read must start at the next whole byte.
        r.read_bit().unwrap();

        let tail = r.read_bytes(3).unwrap();
        assert_eq!(tail, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let data = [0b10101100, 0b11010010];
        let mut r = BitStreamReader::new(&data);

        r.read_small(3).unwrap();
        // Explicit alignment jumps to the start of byte 1.
        r.align_byte();

        assert_eq!(r.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let data = [0xFF];
        let mut r = BitStreamReader::new(&data);

        // The only byte can be read once...
        assert_eq!(r.read_byte(), Ok(0xFF));

        // ...after which every kind of read reports how much is missing.
        assert_eq!(
            r.read_bit().unwrap_err(),
            DeserializationError::NotEnoughBytes(1)
        );
        assert_eq!(
            r.read_byte().unwrap_err(),
            DeserializationError::NotEnoughBytes(1)
        );
        assert_eq!(
            r.read_bytes(2).unwrap_err(),
            DeserializationError::NotEnoughBytes(2)
        );
    }

    #[test]
    fn test_multiple_operations() {
        let data = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut r = BitStreamReader::new(&data);

        // Bit 0, then bits 1-3 of the first byte.
        assert_eq!(r.read_bit(), Ok(false));
        assert_eq!(r.read_small(3), Ok(0b101));

        // Byte-level reads align first, so the rest of byte 0 is skipped.
        assert_eq!(r.read_byte(), Ok(0b11001100));
        assert_eq!(r.read_bytes(2), Ok(&[0xFF, 0x00][..]));

        // Nothing is left afterwards.
        assert_eq!(
            r.read_bit().unwrap_err(),
            DeserializationError::NotEnoughBytes(1)
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let data = vec![0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&data);

        // Skip the leading zero byte with a plain byte read.
        assert_eq!(stream.read_byte(), Ok(0));

        // One-, two- and four-byte encodings in sequence.
        for &want in &[127u128, 128, 268435455] {
            assert_eq!(stream.read_dyn_int(), Ok(want));
        }

        // The buffer is exhausted now.
        assert_eq!(
            stream.read_dyn_int(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let data = vec![
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];
        let mut stream = BitStreamReader::new(&data);

        // One unsigned/signed pair per width, read back to back.
        let a: u8 = stream.read_fixed_int().unwrap();
        let b: i8 = stream.read_fixed_int().unwrap();
        let c: u16 = stream.read_fixed_int().unwrap();
        let d: i16 = stream.read_fixed_int().unwrap();
        let e: u32 = stream.read_fixed_int().unwrap();
        let f: i32 = stream.read_fixed_int().unwrap();
        let g: u64 = stream.read_fixed_int().unwrap();
        let h: i64 = stream.read_fixed_int().unwrap();
        let i: u128 = stream.read_fixed_int().unwrap();
        let j: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(
            (a, b, c, d, e, f, g, h, i, j),
            (1, 1, 2, 2, 3, 3, 4, 4, 5, 5)
        );
    }

    #[test]
    fn test_bytes_left() {
        let data = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut r = BitStreamReader::new(&data);

        // Untouched reader: every byte is available.
        assert_eq!(r.bytes_left(), 4);

        // A partially consumed byte no longer counts as a full byte.
        r.read_small(3).unwrap();
        assert_eq!(r.bytes_left(), 3);

        // Each byte read (after aligning) removes exactly one more.
        r.read_byte().unwrap();
        assert_eq!(r.bytes_left(), 2);
        r.read_byte().unwrap();
        assert_eq!(r.bytes_left(), 1);

        // Touching the last byte with a single bit leaves no whole bytes.
        r.read_bit().unwrap();
        assert_eq!(r.bytes_left(), 0);
    }
}