// binary_codec/bitstream/reader.rs

1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6    buffer: &'a [u8],
7    bit_pos: usize,
8    last_read_byte: Option<u8>,
9    offset_end: usize,
10    crypto: Option<Box<dyn CryptoStream>>,
11    marker: Option<usize>
12}
13
14impl<'a> BitStreamReader<'a> {
15    /// Create a new LSB-first reader
16    ///
17    /// # Properties
18    /// - `buffer`: The buffer to read from
19    pub fn new(buffer: &'a [u8]) -> Self {
20        Self {
21            buffer,
22            bit_pos: 0,
23            crypto: None,
24            offset_end: 0,
25            last_read_byte: None,
26            marker: None,
27        }
28    }
29
30    /// Return slice from reading position (or start, if `from_start` is true) to offset-end
31    pub fn slice(&self, from_start: bool) -> &[u8] {
32        let start = if from_start { 0 } else { self.byte_pos() };
33
34        &self.buffer[start..self.buffer.len() - self.offset_end]
35    }
36
37    /// Set marker at current byte
38    pub fn set_marker(&mut self) {
39        self.marker = Some(self.byte_pos());
40    }
41
42    /// Unset marker
43    pub fn reset_marker(&mut self) {
44        self.marker = None;
45    }
46
47    /// Return slice from marker (or start if marker is unset) to specific position or current byte
48    pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
49        &self.buffer[self.marker.unwrap_or(0)..to.unwrap_or(self.byte_pos())]
50    }
51
52    /// Return slice from offset-end to end of buffer
53    pub fn slice_end(&self) -> &[u8] {
54        &self.buffer[self.buffer.len() - self.offset_end..]
55    }
56
57    /// Set crypto stream
58    pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
59        self.crypto = crypto;
60    }
61
62    /// Set integrity offset to ignore when reading
63    pub fn set_offset_end(&mut self, len: usize) {
64        self.offset_end = len;
65    }
66
67    /// Get byte position of reader
68    pub fn byte_pos(&self) -> usize {
69        self.bit_pos / 8
70    }
71
72    /// Get current byte, from last_read_byte cache or from buffer
73    fn current_byte(&mut self) -> u8 {
74        if let Some(b) = self.last_read_byte {
75            b
76        } else {
77            let mut b = self.buffer[self.byte_pos()];
78            if let Some(crypto) = self.crypto.as_mut() {
79                b = crypto.apply_keystream_byte(b);
80            }
81
82            self.last_read_byte = Some(b);
83            b
84        }
85    }
86
87    /// Read a single bit
88    pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
89        self.read_small(1).map(|v| v != 0)
90    }
91
92    /// Read 1-8 bits as u8 (LSB-first)
93    pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
94        assert!(bits > 0 && bits < 8);
95
96        let mut result: u8 = 0;
97        let mut shift = 0;
98
99        while bits > 0 {
100            if self.byte_pos() >= self.buffer.len() - self.offset_end {
101                return Err(DeserializationError::NotEnoughBytes(1));
102            }
103
104            // Find the bit position inside the current byte (0..7)
105            let bit_offset = self.bit_pos % 8;
106
107            // Determine how many bytes left to read. Min(bits left in this byte, bits to read)
108            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
109
110            // Create a mask to isolate the bits we want from this byte.
111            // Example: if bit_offset = 2 and bits_in_current_byte = 3,
112            // mask = 00011100 (only bits 2,3,4 are 1)
113            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
114            let byte_val = self.current_byte();
115
116            // Apply the mask to isolate the bits and shift them to LSB
117            // Example: byte_val = 10101100, mask = 00011100
118            // (10101100 & 00011100) >> 2 = 00000101
119            let val = (byte_val & mask) >> bit_offset;
120
121            // Merge the extracted bits into the final result.
122            // Shift them to the correct position based on how many bits we already read.
123            result |= val << shift;
124
125            // Decrease the remaining bits we need to read
126            bits -= bits_in_current_byte;
127
128            // Update the shift for the next batch of bits (if crossing byte boundary)
129            shift += bits_in_current_byte;
130
131            self.bit_pos += bits_in_current_byte as usize;
132
133            // If crossed byte boundary, reset last read byte
134            if self.bit_pos % 8 == 0 {
135                self.last_read_byte = None;
136            }
137        }
138
139        Ok(result)
140    }
141
142    /// Read a full byte, aligning to the next byte boundary
143    pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
144        self.align_byte();
145
146        if self.byte_pos() >= self.buffer.len() - self.offset_end {
147            return Err(DeserializationError::NotEnoughBytes(1));
148        }
149
150        let byte = self.current_byte();
151        self.bit_pos += 8;
152        self.last_read_byte = None;
153
154        Ok(byte)
155    }
156
157    /// Read a slice of bytes, aligning first
158    pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
159        self.align_byte();
160
161        let start = self.byte_pos();
162        if start + count > self.buffer.len() - self.offset_end {
163            return Err(DeserializationError::NotEnoughBytes(
164                start + count - self.buffer.len(),
165            ));
166        }
167
168        self.bit_pos += 8 * count;
169        self.last_read_byte = None;
170
171        let slice = &self.buffer[start..start + count];
172        if let Some(crypto) = self.crypto.as_mut() {
173            Ok(crypto.apply_keystream(slice))
174        } else {
175            Ok(slice)
176        }
177    }
178
179    /// Read a dynamic int, starting at the next byte bounary
180    /// The last bit is used as a continuation flag for the next byte
181    pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
182        self.align_byte();
183        let mut num: u128 = 0;
184        let mut multiplier: u128 = 1;
185
186        loop {
187            let byte = self.read_byte()?; // None if EOF
188            num += ((byte & 127) as u128) * multiplier;
189
190            // If no continuation bit, stop
191            if (byte & 1 << 7) == 0 {
192                break;
193            }
194
195            multiplier *= 128;
196        }
197
198        Ok(num)
199    }
200
201    /// Read a integer of fixed size from the buffer
202    pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
203        &mut self,
204    ) -> Result<T, DeserializationError> {
205        let data = self.read_bytes(S)?;
206        Ok(FixedInt::deserialize(data))
207    }
208
209    /// Align the reader to the next byte boundary
210    pub fn align_byte(&mut self) {
211        let rem = self.bit_pos % 8;
212        if rem != 0 {
213            self.bit_pos += 8 - rem;
214            self.last_read_byte = None;
215        }
216    }
217
218    /// Get bytes left
219    pub fn bytes_left(&self) -> usize {
220        let left = self.buffer.len() - self.byte_pos() - self.offset_end;
221        if self.bit_pos % 8 != 0 {
222            left - 1 // If not aligned, we can't read the last byte fully
223        } else {
224            left
225        }
226    }
227
228    /// Reset reading position
229    pub fn reset(&mut self) {
230        self.bit_pos = 0;
231    }
232}
233
#[cfg(test)]
mod tests {
    use crate::{DeserializationError, bitstream::CryptoStream};

    use super::BitStreamReader;

    /// Test "decrypter" that adds 1 (wrapping, so 0xFF does not overflow) to
    /// every byte. It keeps its output so `apply_keystream` can return a
    /// borrowed slice.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            let decrypted = b.wrapping_add(1);
            self.plain.push(decrypted);
            decrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.plain.len();
            self.plain.extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.plain[start..]
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&buf);
        reader.set_crypto(Some(Box::new(PlusOneDecrypter { plain: Vec::new() })));

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Ok(4));
        // 4 = 00000100, +1 = 00000101
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    /// Helper to build buffers
    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // LSB-first: bits come out starting from the least significant one
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(3), Ok(0b100));
        assert_eq!(reader.read_small(4), Ok(0b0101));
        assert_eq!(reader.read_small(1), Ok(0b1));
        assert_eq!(reader.read_small(4), Ok(0b0010));
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        // Read 10 bits in total; the second read spans the byte boundary
        assert_eq!(reader.read_small(7), Ok(0b00101100));
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap(); // advance 3 bits
        assert_eq!(reader.read_byte(), Ok(0b11010010)); // full second byte
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap(); // first bit
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_read_bytes_within_offset_end() {
        let buf = [1, 2, 3, 4];
        let mut reader = BitStreamReader::new(&buf);

        reader.set_offset_end(1);
        assert_eq!(reader.read_bytes(3), Ok(&[1, 2, 3][..]));
        assert_eq!(reader.bytes_left(), 0);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap(); // 3 bits
        reader.align_byte(); // move to next byte
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false)); // bit 0
        assert_eq!(reader.read_small(3), Ok(0b101)); // bits 1-3
        assert_eq!(reader.read_byte(), Ok(0b11001100)); // aligned full byte
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = vec![0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(Ok(0), stream.read_byte());
        assert_eq!(Ok(127), stream.read_dyn_int());
        assert_eq!(Ok(128), stream.read_dyn_int());
        assert_eq!(Ok(268435455), stream.read_dyn_int());
        assert_eq!(
            Err(DeserializationError::NotEnoughBytes(1)),
            stream.read_dyn_int()
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let buf = vec![
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&buf);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(v1, 1);
        assert_eq!(v2, 1);
        assert_eq!(v3, 2);
        assert_eq!(v4, 2);
        assert_eq!(v5, 3);
        assert_eq!(v6, 3);
        assert_eq!(v7, 4);
        assert_eq!(v8, 4);
        assert_eq!(v9, 5);
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        reader.read_small(3).unwrap(); // read 3 bits
        assert_eq!(reader.bytes_left(), 3); // 3 full bytes left
        reader.read_byte().unwrap(); // read one byte
        assert_eq!(reader.bytes_left(), 2); // now 2 bytes left
        reader.read_byte().unwrap(); // read another byte
        assert_eq!(reader.bytes_left(), 1); // now 1 byte left
        reader.read_bit().unwrap(); // read one bit
        assert_eq!(reader.bytes_left(), 0); // no full bytes left
    }

    #[test]
    fn test_reset_rewinds_to_start() {
        let buf = [1, 2, 3];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(1));
        assert_eq!(reader.read_byte(), Ok(2));
        reader.reset();
        assert_eq!(reader.read_byte(), Ok(1));
    }

    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let buff = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&buff);

        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );

        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        assert_eq!(reader.read_byte(), Ok(4));
        assert_eq!(reader.read_byte(), Ok(5));
    }

    #[test]
    fn test_slice_marker_without_marker() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        // Without a marker, slices start at the beginning of the buffer
        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_byte().unwrap(); // Read 10
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap(); // Read 4 bits of 20; partial byte is excluded
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap(); // Read remaining 4 bits of 20
        assert_eq!(reader.slice_marker(None), &[10, 20]);

        reader.read_bytes(2).unwrap(); // Read 30, 40
        assert_eq!(reader.slice_marker(None), &[10, 20, 30, 40]);
    }

    #[test]
    fn test_slice_marker_with_marker() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        reader.read_byte().unwrap(); // Read 10
        assert_eq!(reader.slice_marker(None), &[10]);
        reader.set_marker();
        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_bytes(2).unwrap(); // Read 20, 30
        assert_eq!(reader.slice_marker(None), &[20, 30]);
    }
}