Skip to main content

binary_codec/bitstream/
reader.rs

1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6    buffer: &'a [u8],
7    bit_pos: usize,
8    last_read_byte: Option<u8>,
9    offset_end: usize,
10    crypto: Option<Box<dyn CryptoStream>>,
11    marker: Option<usize>,
12}
13
14impl<'a> BitStreamReader<'a> {
15    /// Create a new LSB-first reader
16    ///
17    /// # Properties
18    /// - `buffer`: The buffer to read from
19    pub fn new(buffer: &'a [u8]) -> Self {
20        Self {
21            buffer,
22            bit_pos: 0,
23            crypto: None,
24            offset_end: 0,
25            last_read_byte: None,
26            marker: None,
27        }
28    }
29
30    /// Return slice from reading position (or start, if `from_start` is true) to offset-end
31    pub fn slice(&self, from_start: bool) -> &[u8] {
32        let start = if from_start { 0 } else { self.byte_pos() };
33
34        &self.buffer[start..self.buffer.len() - self.offset_end]
35    }
36
37    /// Set marker at current byte
38    pub fn set_marker(&mut self) {
39        self.marker = Some(self.byte_pos());
40    }
41
42    /// Unset marker
43    pub fn reset_marker(&mut self) {
44        self.marker = None;
45    }
46
47    /// Return slice from marker (or start if marker is unset) to specific position or current byte
48    /// If crypto is set, it will return the decrypted slice. This is different from other `slice` methods which always returns the raw buffer slice.
49    pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
50        let start = self.marker.unwrap_or(0);
51        let end = to.unwrap_or(self.byte_pos());
52
53        if let Some(crypto) = self.crypto.as_ref() {
54            return &crypto.get_plaintext()[start..end];
55        }
56
57        &self.buffer[start..end]
58    }
59
60    /// Return slice from offset-end to end of buffer
61    pub fn slice_end(&self) -> &[u8] {
62        &self.buffer[self.buffer.len() - self.offset_end..]
63    }
64
65    /// Set crypto stream
66    pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
67        self.crypto = crypto;
68    }
69
70    /// Set integrity offset to ignore when reading
71    pub fn set_offset_end(&mut self, len: usize) {
72        self.offset_end = len;
73    }
74
75    /// Get byte position of reader
76    pub fn byte_pos(&self) -> usize {
77        self.bit_pos / 8
78    }
79
80    /// Get current byte, from last_read_byte cache or from buffer
81    fn current_byte(&mut self) -> u8 {
82        if let Some(b) = self.last_read_byte {
83            b
84        } else {
85            let mut b = self.buffer[self.byte_pos()];
86            if let Some(crypto) = self.crypto.as_mut() {
87                b = crypto.apply_keystream_byte(b);
88            }
89
90            self.last_read_byte = Some(b);
91            b
92        }
93    }
94
95    /// Read a single bit
96    pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
97        self.read_small(1).map(|v| v != 0)
98    }
99
100    /// Read 1-8 bits as u8 (LSB-first)
101    pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
102        assert!(bits > 0 && bits < 8);
103
104        let mut result: u8 = 0;
105        let mut shift = 0;
106
107        while bits > 0 {
108            if self.byte_pos() >= self.buffer.len() - self.offset_end {
109                return Err(DeserializationError::NotEnoughBytes(1));
110            }
111
112            // Find the bit position inside the current byte (0..7)
113            let bit_offset = self.bit_pos % 8;
114
115            // Determine how many bytes left to read. Min(bits left in this byte, bits to read)
116            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
117
118            // Create a mask to isolate the bits we want from this byte.
119            // Example: if bit_offset = 2 and bits_in_current_byte = 3,
120            // mask = 00011100 (only bits 2,3,4 are 1)
121            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
122            let byte_val = self.current_byte();
123
124            // Apply the mask to isolate the bits and shift them to LSB
125            // Example: byte_val = 10101100, mask = 00011100
126            // (10101100 & 00011100) >> 2 = 00000101
127            let val = (byte_val & mask) >> bit_offset;
128
129            // Merge the extracted bits into the final result.
130            // Shift them to the correct position based on how many bits we already read.
131            result |= val << shift;
132
133            // Decrease the remaining bits we need to read
134            bits -= bits_in_current_byte;
135
136            // Update the shift for the next batch of bits (if crossing byte boundary)
137            shift += bits_in_current_byte;
138
139            self.bit_pos += bits_in_current_byte as usize;
140
141            // If crossed byte boundary, reset last read byte
142            if self.bit_pos % 8 == 0 {
143                self.last_read_byte = None;
144            }
145        }
146
147        Ok(result)
148    }
149
150    /// Read a full byte, aligning to the next byte boundary
151    pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
152        self.align_byte();
153
154        if self.byte_pos() >= self.buffer.len() - self.offset_end {
155            return Err(DeserializationError::NotEnoughBytes(1));
156        }
157
158        let byte = self.current_byte();
159        self.bit_pos += 8;
160        self.last_read_byte = None;
161
162        Ok(byte)
163    }
164
165    /// Read a slice of bytes, aligning first
166    pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
167        self.align_byte();
168
169        let start = self.byte_pos();
170        if start + count > self.buffer.len() - self.offset_end {
171            return Err(DeserializationError::NotEnoughBytes(
172                start + count - self.buffer.len(),
173            ));
174        }
175
176        self.bit_pos += 8 * count;
177        self.last_read_byte = None;
178
179        let slice = &self.buffer[start..start + count];
180        if let Some(crypto) = self.crypto.as_mut() {
181            Ok(crypto.apply_keystream(slice))
182        } else {
183            Ok(slice)
184        }
185    }
186
187    /// Read a dynamic int, starting at the next byte bounary
188    /// The last bit is used as a continuation flag for the next byte
189    pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
190        self.align_byte();
191        let mut num: u128 = 0;
192        let mut multiplier: u128 = 1;
193
194        loop {
195            let byte = self.read_byte()?; // None if EOF
196            num += ((byte & 127) as u128) * multiplier;
197
198            // If no continuation bit, stop
199            if (byte & 1 << 7) == 0 {
200                break;
201            }
202
203            multiplier *= 128;
204        }
205
206        Ok(num)
207    }
208
209    /// Read a integer of fixed size from the buffer
210    pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
211        &mut self,
212    ) -> Result<T, DeserializationError> {
213        let data = self.read_bytes(S)?;
214        Ok(FixedInt::deserialize(data))
215    }
216
217    /// Align the reader to the next byte boundary
218    pub fn align_byte(&mut self) {
219        let rem = self.bit_pos % 8;
220        if rem != 0 {
221            self.bit_pos += 8 - rem;
222            self.last_read_byte = None;
223        }
224    }
225
226    /// Get bytes left
227    pub fn bytes_left(&self) -> usize {
228        let left = self.buffer.len() - self.byte_pos() - self.offset_end;
229        if self.bit_pos % 8 != 0 {
230            left - 1 // If not aligned, we can't read the last byte fully
231        } else {
232            left
233        }
234    }
235
236    /// Reset reading position
237    pub fn reset(&mut self) {
238        self.bit_pos = 0;
239    }
240}
241
#[cfg(test)]
mod tests {
    use super::BitStreamReader;
    use crate::{DeserializationError, bitstream::CryptoStream};

    /// Stand-in cipher: "decrypts" by adding 1 to every byte and remembers each
    /// decrypted byte so `get_plaintext` returns everything seen so far.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            let decrypted = b + 1;
            self.plain.push(decrypted);
            decrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let first_new = self.plain.len();
            self.plain.extend(slice.iter().map(|byte| byte + 1));
            &self.plain[first_new..]
        }

        fn get_plaintext(&self) -> &[u8] {
            &self.plain
        }
    }

    /// The error every read reports when it runs `missing` bytes past the end.
    fn not_enough<T>(missing: usize) -> Result<T, DeserializationError> {
        Err(DeserializationError::NotEnoughBytes(missing))
    }

    /// Read `count` single bits in order, failing the test if any read errors.
    fn read_bits(reader: &mut BitStreamReader<'_>, count: usize) -> Vec<bool> {
        (0..count)
            .map(|_| reader.read_bit().expect("bit should be readable"))
            .collect()
    }

    /// Four-byte fixture: two mixed bit patterns, then all-ones and all-zeros.
    fn make_buffer() -> Vec<u8> {
        vec![0b1010_1100, 0b1101_0010, 0xFF, 0x00]
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf: Vec<u8> = (1..=10).collect();
        let mut reader = BitStreamReader::new(&buf);
        reader.set_crypto(Some(Box::new(PlusOneDecrypter { plain: Vec::new() })));

        for expected in 2..=4u8 {
            assert_eq!(reader.read_byte(), Ok(expected));
        }
        // The fourth raw byte is 4, which decrypts to 0b0000_0101 (consumed LSB-first).
        assert_eq!(read_bits(&mut reader, 3), [true, false, true]);
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // First byte 0b1010_1100, walked from bit 0 upwards.
        let expected = [false, false, true, true, false, true, false, true];
        assert_eq!(read_bits(&mut reader, expected.len()), expected);
    }

    #[test]
    fn test_read_small() {
        let buf = [0b1010_1100, 0b1101_0010];
        let mut reader = BitStreamReader::new(&buf);

        // (width, expected value), read back to back across the byte boundary.
        let steps: [(u8, u8); 4] = [(3, 0b100), (4, 0b0101), (1, 0b1), (4, 0b0010)];
        for &(width, expected) in &steps {
            assert_eq!(reader.read_small(width), Ok(expected));
        }
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b1010_1100, 0b1101_0001];
        let mut reader = BitStreamReader::new(&buf);

        // Seven bits from byte 0 ...
        assert_eq!(reader.read_small(7), Ok(0b0010_1100));
        // ... then bit 7 of byte 0 followed by bits 0-1 of byte 1.
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b1010_1100, 0b1101_0010];
        let mut reader = BitStreamReader::new(&buf);

        // The partially read first byte is abandoned; read_byte starts at byte 1.
        reader.read_small(3).unwrap();
        assert_eq!(reader.read_byte(), Ok(0b1101_0010));
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap();
        let tail = reader.read_bytes(3).unwrap();
        assert_eq!(tail, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b1010_1100, 0b1101_0010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        reader.align_byte(); // skip the remaining 5 bits of byte 0
        assert_eq!(reader.read_byte(), Ok(0b1101_0010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        // Past the only byte, every kind of read reports how much is missing.
        assert_eq!(reader.read_bit(), not_enough(1));
        assert_eq!(reader.read_byte(), not_enough(1));
        assert_eq!(reader.read_bytes(2), not_enough(2));
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b1010_1010, 0b1100_1100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false)); // bit 0
        assert_eq!(reader.read_small(3), Ok(0b101)); // bits 1..=3
        assert_eq!(reader.read_byte(), Ok(0b1100_1100)); // rest of byte 0 skipped
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(reader.read_bit(), not_enough(1));
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = [0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(stream.read_byte(), Ok(0));
        // One byte, two bytes, then four 7-bit groups of ones (2^28 - 1).
        for &expected in &[127u128, 128, 268_435_455] {
            assert_eq!(stream.read_dyn_int(), Ok(expected));
        }
        assert_eq!(stream.read_dyn_int(), not_enough(1));
    }

    #[test]
    fn test_read_fixed_int() {
        // Raw bytes in read order, each field as wide as its integer type. The signed
        // fields carry the raw values 2, 4, 6, 8, 10, which the asserts below expect to
        // decode to 1..=5.
        let mut buf: Vec<u8> = vec![1, 2];
        buf.extend_from_slice(&2u16.to_be_bytes());
        buf.extend_from_slice(&4u16.to_be_bytes());
        buf.extend_from_slice(&3u32.to_be_bytes());
        buf.extend_from_slice(&6u32.to_be_bytes());
        buf.extend_from_slice(&4u64.to_be_bytes());
        buf.extend_from_slice(&8u64.to_be_bytes());
        buf.extend_from_slice(&5u128.to_be_bytes());
        buf.extend_from_slice(&10u128.to_be_bytes());

        let mut stream = BitStreamReader::new(&buf);

        let v1: u8 = stream.read_fixed_int().unwrap();
        assert_eq!(v1, 1);
        let v2: i8 = stream.read_fixed_int().unwrap();
        assert_eq!(v2, 1);
        let v3: u16 = stream.read_fixed_int().unwrap();
        assert_eq!(v3, 2);
        let v4: i16 = stream.read_fixed_int().unwrap();
        assert_eq!(v4, 2);
        let v5: u32 = stream.read_fixed_int().unwrap();
        assert_eq!(v5, 3);
        let v6: i32 = stream.read_fixed_int().unwrap();
        assert_eq!(v6, 3);
        let v7: u64 = stream.read_fixed_int().unwrap();
        assert_eq!(v7, 4);
        let v8: i64 = stream.read_fixed_int().unwrap();
        assert_eq!(v8, 4);
        let v9: u128 = stream.read_fixed_int().unwrap();
        assert_eq!(v9, 5);
        let v10: i128 = stream.read_fixed_int().unwrap();
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b1010_1100, 0b1101_0010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        reader.read_small(3).unwrap();
        assert_eq!(reader.bytes_left(), 3); // a half-read byte no longer counts
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 2);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 1);
        reader.read_bit().unwrap();
        assert_eq!(reader.bytes_left(), 0); // last byte is now partially consumed
    }

    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let buff = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&buff);

        // Hide the final two bytes from normal reads.
        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        for expected in 2..=3 {
            assert_eq!(reader.read_byte(), Ok(expected));
        }
        assert_eq!(reader.read_byte(), not_enough(1));

        // Dropping the offset makes the trailing bytes readable again.
        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        for expected in 4..=5 {
            assert_eq!(reader.read_byte(), Ok(expected));
        }
    }

    #[test]
    fn test_slice_start() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        assert!(reader.slice_marker(None).is_empty());

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        // Half a byte does not extend the slice; finishing the byte does.
        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);
        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20]);

        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20, 30, 40]);
    }

    #[test]
    fn test_slice_start_with_marker() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        // Moving the marker to the current byte empties the slice ...
        reader.set_marker();
        assert!(reader.slice_marker(None).is_empty());

        // ... and it then grows from the marker onwards.
        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[20, 30]);
    }
}