// binary_codec/bitstream/reader.rs

1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6    buffer: &'a [u8],
7    bit_pos: usize,
8    last_read_byte: Option<u8>,
9    offset_end: usize,
10    crypto: Option<Box<dyn CryptoStream>>,
11    marker: Option<usize>,
12}
13
14impl<'a> BitStreamReader<'a> {
15    /// Create a new LSB-first reader
16    ///
17    /// # Properties
18    /// - `buffer`: The buffer to read from
19    pub fn new(buffer: &'a [u8]) -> Self {
20        Self {
21            buffer,
22            bit_pos: 0,
23            crypto: None,
24            offset_end: 0,
25            last_read_byte: None,
26            marker: None,
27        }
28    }
29
30    /// Return slice from reading position (or start, if `from_start` is true) to offset-end
31    pub fn slice(&self, from_start: bool) -> &[u8] {
32        let start = if from_start { 0 } else { self.byte_pos() };
33
34        &self.buffer[start..self.buffer.len() - self.offset_end]
35    }
36
37    /// Set marker at current byte
38    pub fn set_marker(&mut self) {
39        self.marker = Some(self.byte_pos());
40    }
41
42    /// Unset marker
43    pub fn reset_marker(&mut self) {
44        self.marker = None;
45    }
46
47    /// Return slice from marker (or start if marker is unset) to specific position or current byte
48    /// If crypto is set, it will return the decrypted slice. This is different from other `slice` methods which always returns the raw buffer slice.
49    pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
50        let start = self.marker.unwrap_or(0);
51        let end = to.unwrap_or(self.byte_pos());
52
53        if let Some(crypto) = self.crypto.as_ref() {
54            return &crypto.get_plaintext()[start..end];
55        }
56
57        &self.buffer[start..end]
58    }
59
60    /// Return slice from offset-end to end of buffer
61    /// If crypto is set, it will apply the keystream to the slice before returning.
62    /// This is different from other `slice` methods which always returns the raw buffer slice.
63    pub fn slice_end(&mut self) -> &[u8] {
64        let slice = &self.buffer[self.buffer.len() - self.offset_end..];
65
66        if let Some(crypto) = self.crypto.as_mut() {
67            crypto.apply_keystream(slice)
68        } else {
69            slice
70        }
71    }
72
73    /// Set crypto stream
74    pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
75        self.crypto = crypto;
76    }
77
78    /// Set integrity offset to ignore when reading
79    pub fn set_offset_end(&mut self, len: usize) {
80        self.offset_end = len;
81    }
82
83    /// Get byte position of reader
84    pub fn byte_pos(&self) -> usize {
85        self.bit_pos / 8
86    }
87
88    /// Get current byte, from last_read_byte cache or from buffer
89    fn current_byte(&mut self) -> u8 {
90        if let Some(b) = self.last_read_byte {
91            b
92        } else {
93            let mut b = self.buffer[self.byte_pos()];
94            if let Some(crypto) = self.crypto.as_mut() {
95                b = crypto.apply_keystream_byte(b);
96            }
97
98            self.last_read_byte = Some(b);
99            b
100        }
101    }
102
103    /// Read a single bit
104    pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
105        self.read_small(1).map(|v| v != 0)
106    }
107
108    /// Read 1-8 bits as u8 (LSB-first)
109    pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
110        assert!(bits > 0 && bits < 8);
111
112        let mut result: u8 = 0;
113        let mut shift = 0;
114
115        while bits > 0 {
116            if self.byte_pos() >= self.buffer.len() - self.offset_end {
117                return Err(DeserializationError::NotEnoughBytes(1));
118            }
119
120            // Find the bit position inside the current byte (0..7)
121            let bit_offset = self.bit_pos % 8;
122
123            // Determine how many bytes left to read. Min(bits left in this byte, bits to read)
124            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
125
126            // Create a mask to isolate the bits we want from this byte.
127            // Example: if bit_offset = 2 and bits_in_current_byte = 3,
128            // mask = 00011100 (only bits 2,3,4 are 1)
129            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
130            let byte_val = self.current_byte();
131
132            // Apply the mask to isolate the bits and shift them to LSB
133            // Example: byte_val = 10101100, mask = 00011100
134            // (10101100 & 00011100) >> 2 = 00000101
135            let val = (byte_val & mask) >> bit_offset;
136
137            // Merge the extracted bits into the final result.
138            // Shift them to the correct position based on how many bits we already read.
139            result |= val << shift;
140
141            // Decrease the remaining bits we need to read
142            bits -= bits_in_current_byte;
143
144            // Update the shift for the next batch of bits (if crossing byte boundary)
145            shift += bits_in_current_byte;
146
147            self.bit_pos += bits_in_current_byte as usize;
148
149            // If crossed byte boundary, reset last read byte
150            if self.bit_pos % 8 == 0 {
151                self.last_read_byte = None;
152            }
153        }
154
155        Ok(result)
156    }
157
158    /// Read a full byte, aligning to the next byte boundary
159    pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
160        self.align_byte();
161
162        if self.byte_pos() >= self.buffer.len() - self.offset_end {
163            return Err(DeserializationError::NotEnoughBytes(1));
164        }
165
166        let byte = self.current_byte();
167        self.bit_pos += 8;
168        self.last_read_byte = None;
169
170        Ok(byte)
171    }
172
173    /// Read a slice of bytes, aligning first
174    pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
175        self.align_byte();
176
177        let start = self.byte_pos();
178        if start + count > self.buffer.len() - self.offset_end {
179            return Err(DeserializationError::NotEnoughBytes(
180                start + count - self.buffer.len(),
181            ));
182        }
183
184        self.bit_pos += 8 * count;
185        self.last_read_byte = None;
186
187        let slice = &self.buffer[start..start + count];
188        if let Some(crypto) = self.crypto.as_mut() {
189            Ok(crypto.apply_keystream(slice))
190        } else {
191            Ok(slice)
192        }
193    }
194
195    /// Read a dynamic int, starting at the next byte bounary
196    /// The last bit is used as a continuation flag for the next byte
197    pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
198        self.align_byte();
199        let mut num: u128 = 0;
200        let mut multiplier: u128 = 1;
201
202        loop {
203            let byte = self.read_byte()?; // None if EOF
204            num += ((byte & 127) as u128) * multiplier;
205
206            // If no continuation bit, stop
207            if (byte & 1 << 7) == 0 {
208                break;
209            }
210
211            multiplier *= 128;
212        }
213
214        Ok(num)
215    }
216
217    /// Read a integer of fixed size from the buffer
218    pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
219        &mut self,
220    ) -> Result<T, DeserializationError> {
221        let data = self.read_bytes(S)?;
222        Ok(FixedInt::deserialize(data))
223    }
224
225    /// Align the reader to the next byte boundary
226    pub fn align_byte(&mut self) {
227        let rem = self.bit_pos % 8;
228        if rem != 0 {
229            self.bit_pos += 8 - rem;
230            self.last_read_byte = None;
231        }
232    }
233
234    /// Get bytes left
235    pub fn bytes_left(&self) -> usize {
236        let left = self.buffer.len() - self.byte_pos() - self.offset_end;
237        if self.bit_pos % 8 != 0 {
238            left - 1 // If not aligned, we can't read the last byte fully
239        } else {
240            left
241        }
242    }
243
244    /// Reset reading position
245    pub fn reset(&mut self) {
246        self.bit_pos = 0;
247    }
248}
249
/// Unit tests for `BitStreamReader`.
#[cfg(test)]
mod tests {
    use crate::{DeserializationError, bitstream::CryptoStream};

    use super::BitStreamReader;

    /// Test crypto stream that "decrypts" each byte by adding 1 and records every
    /// produced byte in `plain`, which `get_plaintext` exposes.
    /// `b + 1` overflows on 255, so inputs used with it must stay below 255.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            self.plain.push(b + 1);
            *self.plain.last().unwrap()
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let d = slice.iter().map(|s| s + 1);
            self.plain.extend(d);
            // Return only the bytes produced by this call.
            &self.plain[self.plain.len() - slice.len()..]
        }

        fn get_plaintext(&self) -> &[u8] {
            &self.plain
        }
    }

    /// Byte, bit and multi-byte reads all pass through the crypto stream.
    #[test]
    fn test_decrypt_bytes() {
        let buf = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&buf);
        reader.crypto = Some(Box::new(PlusOneDecrypter { plain: Vec::new() }));

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Ok(4));
        // 4 = 00000100, +1 = 00000101
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        // read_bytes aligns first, skipping the remaining 5 bits of the byte above.
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    /// Helper to build buffers
    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // LSB-first: read bits starting from least significant
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        // 3 + 4 + 1 bits consume byte 0 exactly; the last read starts byte 1.
        assert_eq!(reader.read_small(3), Ok(0b100));
        assert_eq!(reader.read_small(4), Ok(0b0101));
        assert_eq!(reader.read_small(1), Ok(0b1));
        assert_eq!(reader.read_small(4), Ok(0b0010));
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        // Read first 10 bits (crosses into second byte)
        // The 3-bit read takes bit 7 of byte 0 and bits 0-1 of byte 1.
        assert_eq!(reader.read_small(7), Ok(0b00101100));
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap(); // advance 3 bits
        assert_eq!(reader.read_byte(), Ok(0b11010010)); // full second byte
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap(); // first bit
        // read_bytes aligns past the rest of byte 0 before reading.
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap(); // 3 bits
        reader.align_byte(); // move to next byte
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        // read_bytes reports the number of missing bytes (both of the 2 requested).
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false)); // bit 0
        assert_eq!(reader.read_small(3), Ok(0b101)); // bits 1-3
        assert_eq!(reader.read_byte(), Ok(0b11001100)); // aligned full byte
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    /// Each byte contributes its low 7 bits, and the high bit marks continuation:
    /// [128, 1] = 0 + 1 * 128, and [255, 255, 255, 127] = 2^28 - 1.
    #[test]
    fn test_read_dyn_int() {
        let buf = vec![0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(Ok(0), stream.read_byte());
        assert_eq!(Ok(127), stream.read_dyn_int());
        assert_eq!(Ok(128), stream.read_dyn_int());
        assert_eq!(Ok(268435455), stream.read_dyn_int());
        assert_eq!(
            Err(DeserializationError::NotEnoughBytes(1)),
            stream.read_dyn_int()
        );
    }

    /// The expected values imply big-endian byte order, with signed types
    /// zigzag-encoded (e.g. the i8 byte 2 decodes to 1, and the i16 bytes [0, 4] to 2).
    #[test]
    fn test_read_fixed_int() {
        let buf = vec![
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&buf);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(v1, 1);
        assert_eq!(v2, 1);
        assert_eq!(v3, 2);
        assert_eq!(v4, 2);
        assert_eq!(v5, 3);
        assert_eq!(v6, 3);
        assert_eq!(v7, 4);
        assert_eq!(v8, 4);
        assert_eq!(v9, 5);
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        reader.read_small(3).unwrap(); // read 3 bits
        assert_eq!(reader.bytes_left(), 3); // 3 full bytes left
        reader.read_byte().unwrap(); // read one byte
        assert_eq!(reader.bytes_left(), 2); // now 2 bytes left
        reader.read_byte().unwrap(); // read another byte
        assert_eq!(reader.bytes_left(), 1); // now 1 byte left
        reader.read_bit().unwrap(); // read one bit
        assert_eq!(reader.bytes_left(), 0); // no full bytes left
    }

    /// The trailing `offset_end` bytes are unreadable but available via `slice_end`,
    /// and clearing the offset makes them readable again.
    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let buff = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&buff);

        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );

        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        assert_eq!(reader.read_byte(), Ok(4));
        assert_eq!(reader.read_byte(), Ok(5));
    }

    /// Without a marker, `slice_marker(None)` runs from the buffer start to
    /// `byte_pos()`, so a partially read byte is excluded until it is fully consumed.
    #[test]
    fn test_slice_start() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_byte().unwrap(); // Read 10
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap(); // Read 4 bits of 20
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap(); // Read remaining 4 bits of 20
        assert_eq!(reader.slice_marker(None), &[10, 20]);

        reader.read_bytes(2).unwrap(); // Read 30, 40
        assert_eq!(reader.slice_marker(None), &[10, 20, 30, 40]);
    }

    /// After `set_marker`, `slice_marker` starts at the marked byte position.
    #[test]
    fn test_slice_start_with_marker() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        reader.read_byte().unwrap(); // Read 10
        assert_eq!(reader.slice_marker(None), &[10]);
        reader.set_marker();
        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_bytes(2).unwrap(); // Read 20, 30
        assert_eq!(reader.slice_marker(None), &[20, 30]);
    }
}