Skip to main content

// binary_codec/bitstream/reader.rs

1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6    buffer: &'a [u8],
7    bit_pos: usize,
8    last_read_byte: Option<u8>,
9    offset_end: usize,
10    crypto: Option<Box<dyn CryptoStream>>,
11    marker: Option<usize>,
12}
13
14impl<'a> BitStreamReader<'a> {
15    /// Create a new LSB-first reader
16    ///
17    /// # Properties
18    /// - `buffer`: The buffer to read from
19    pub fn new(buffer: &'a [u8]) -> Self {
20        Self {
21            buffer,
22            bit_pos: 0,
23            crypto: None,
24            offset_end: 0,
25            last_read_byte: None,
26            marker: None,
27        }
28    }
29
30    /// Return slice from reading position (or start, if `from_start` is true) to offset-end
31    pub fn slice(&self, from_start: bool) -> &[u8] {
32        let start = if from_start { 0 } else { self.byte_pos() };
33
34        &self.buffer[start..self.buffer.len() - self.offset_end]
35    }
36
37    /// Set marker at current byte
38    pub fn set_marker(&mut self) {
39        self.marker = Some(self.byte_pos());
40    }
41
42    /// Unset marker
43    pub fn reset_marker(&mut self) {
44        self.marker = None;
45    }
46
47    /// Return slice from marker (or start if marker is unset) to specific position or current byte
48    /// If crypto is set, it will return the decrypted slice. This is different from other `slice` methods which always returns the raw buffer slice.
49    pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
50        let start = self.marker.unwrap_or(0);
51        let end = to.unwrap_or(self.byte_pos());
52
53        if let Some(crypto) = self.crypto.as_ref() {
54            return &crypto.get_cached(false)[start..end];
55        }
56
57        &self.buffer[start..end]
58    }
59
60    /// Return slice from offset-end to end of buffer
61    /// If crypto is set, it will apply the keystream to the slice before returning.
62    /// This is different from other `slice` methods which always returns the raw buffer slice.
63    pub fn slice_end(&mut self) -> &[u8] {
64        let slice = &self.buffer[self.buffer.len() - self.offset_end..];
65
66        if let Some(crypto) = self.crypto.as_mut() {
67            crypto.apply_keystream(slice)
68        } else {
69            slice
70        }
71    }
72
73    /// Return slice from start of buffer to current byte position
74    pub fn slice_start(&self) -> &[u8] {
75        &self.buffer[0..self.byte_pos()]
76    }
77
78    /// Set crypto stream
79    pub fn set_crypto(&mut self, mut crypto: Option<Box<dyn CryptoStream>>) {
80        if let Some(new) = crypto.as_mut() {
81            if let Some(existing) = self.crypto.as_ref() {
82                new.replace(existing);
83            } else {
84                // Initialize new crypto stream with the data read so far, so that it can cache full plaintext
85                new.set_cached(self.slice_start());
86            }
87        }
88
89        self.crypto = crypto;
90    }
91
92    /// Remove crypto stream
93    pub fn reset_crypto(&mut self) {
94        self.crypto = None;
95    }
96
97    /// Set integrity offset to ignore when reading
98    pub fn set_offset_end(&mut self, len: usize) {
99        self.offset_end = len;
100    }
101
102    /// Get byte position of reader
103    pub fn byte_pos(&self) -> usize {
104        self.bit_pos / 8
105    }
106
107    /// Get current byte, from last_read_byte cache or from buffer
108    fn current_byte(&mut self) -> u8 {
109        if let Some(b) = self.last_read_byte {
110            b
111        } else {
112            let mut b = self.buffer[self.byte_pos()];
113            if let Some(crypto) = self.crypto.as_mut() {
114                b = crypto.apply_keystream_byte(b);
115            }
116
117            self.last_read_byte = Some(b);
118            b
119        }
120    }
121
122    /// Read a single bit
123    pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
124        self.read_small(1).map(|v| v != 0)
125    }
126
127    /// Read 1-8 bits as u8 (LSB-first)
128    pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
129        assert!(bits > 0 && bits < 8);
130
131        let mut result: u8 = 0;
132        let mut shift = 0;
133
134        while bits > 0 {
135            if self.byte_pos() >= self.buffer.len() - self.offset_end {
136                return Err(DeserializationError::NotEnoughBytes(1));
137            }
138
139            // Find the bit position inside the current byte (0..7)
140            let bit_offset = self.bit_pos % 8;
141
142            // Determine how many bytes left to read. Min(bits left in this byte, bits to read)
143            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
144
145            // Create a mask to isolate the bits we want from this byte.
146            // Example: if bit_offset = 2 and bits_in_current_byte = 3,
147            // mask = 00011100 (only bits 2,3,4 are 1)
148            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
149            let byte_val = self.current_byte();
150
151            // Apply the mask to isolate the bits and shift them to LSB
152            // Example: byte_val = 10101100, mask = 00011100
153            // (10101100 & 00011100) >> 2 = 00000101
154            let val = (byte_val & mask) >> bit_offset;
155
156            // Merge the extracted bits into the final result.
157            // Shift them to the correct position based on how many bits we already read.
158            result |= val << shift;
159
160            // Decrease the remaining bits we need to read
161            bits -= bits_in_current_byte;
162
163            // Update the shift for the next batch of bits (if crossing byte boundary)
164            shift += bits_in_current_byte;
165
166            self.bit_pos += bits_in_current_byte as usize;
167
168            // If crossed byte boundary, reset last read byte
169            if self.bit_pos % 8 == 0 {
170                self.last_read_byte = None;
171            }
172        }
173
174        Ok(result)
175    }
176
177    /// Read a full byte, aligning to the next byte boundary
178    pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
179        self.align_byte();
180
181        if self.byte_pos() >= self.buffer.len() - self.offset_end {
182            return Err(DeserializationError::NotEnoughBytes(1));
183        }
184
185        let byte = self.current_byte();
186        self.bit_pos += 8;
187        self.last_read_byte = None;
188
189        Ok(byte)
190    }
191
192    /// Read a slice of bytes, aligning first
193    pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
194        self.align_byte();
195
196        let start = self.byte_pos();
197        if start + count > self.buffer.len() - self.offset_end {
198            return Err(DeserializationError::NotEnoughBytes(
199                (start + count - self.buffer.len()) as u64,
200            ));
201        }
202
203        self.bit_pos += 8 * count;
204        self.last_read_byte = None;
205
206        let slice = &self.buffer[start..start + count];
207        if let Some(crypto) = self.crypto.as_mut() {
208            Ok(crypto.apply_keystream(slice))
209        } else {
210            Ok(slice)
211        }
212    }
213
214    /// Read a dynamic int, starting at the next byte bounary
215    /// The last bit is used as a continuation flag for the next byte
216    pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
217        self.align_byte();
218        let mut num: u128 = 0;
219        let mut multiplier: u128 = 1;
220
221        loop {
222            let byte = self.read_byte()?; // None if EOF
223            num += ((byte & 127) as u128) * multiplier;
224
225            // If no continuation bit, stop
226            if (byte & 1 << 7) == 0 {
227                break;
228            }
229
230            multiplier *= 128;
231        }
232
233        Ok(num)
234    }
235
236    /// Read a integer of fixed size from the buffer
237    pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
238        &mut self,
239    ) -> Result<T, DeserializationError> {
240        let data = self.read_bytes(S)?;
241        Ok(FixedInt::deserialize(data))
242    }
243
244    /// Align the reader to the next byte boundary
245    pub fn align_byte(&mut self) {
246        let rem = self.bit_pos % 8;
247        if rem != 0 {
248            self.bit_pos += 8 - rem;
249            self.last_read_byte = None;
250        }
251    }
252
253    /// Get bytes left
254    pub fn bytes_left(&self) -> usize {
255        let left = self.buffer.len() - self.byte_pos() - self.offset_end;
256        if self.bit_pos % 8 != 0 {
257            left - 1 // If not aligned, we can't read the last byte fully
258        } else {
259            left
260        }
261    }
262
263    /// Reset reading position
264    pub fn reset(&mut self) {
265        self.bit_pos = 0;
266    }
267}
268
#[cfg(test)]
mod tests {
    use crate::{DeserializationError, bitstream::CryptoStream};

    use super::BitStreamReader;

    /// Toy crypto stream: "decrypts" by adding one to every byte and records the output.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            let out = b + 1;
            self.plain.push(out);
            out
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let first_new = self.plain.len();
            self.plain.extend(slice.iter().map(|byte| byte + 1));
            &self.plain[first_new..]
        }

        fn get_cached(&self, _original: bool) -> &[u8] {
            &self.plain
        }

        fn replace(&mut self, other: &Box<dyn CryptoStream>) {
            self.plain = other.get_cached(true).to_vec();
        }

        fn set_cached(&mut self, data: &[u8]) {
            self.plain = data.to_vec();
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let input = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&input);
        reader.crypto = Some(Box::new(PlusOneDecrypter { plain: Vec::new() }));

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Ok(4));
        // Fourth byte: 4 = 00000100, decrypted (+1) = 00000101, read bit by bit
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        // The bulk read aligns first, skipping the rest of the fourth byte
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    /// Shared four-byte fixture
    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let input = make_buffer();
        let mut reader = BitStreamReader::new(&input);

        // 0b10101100 read from the least significant bit upwards
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
    }

    #[test]
    fn test_read_small() {
        let input = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&input);

        assert_eq!(reader.read_small(3), Ok(0b100));
        assert_eq!(reader.read_small(4), Ok(0b0101));
        assert_eq!(reader.read_small(1), Ok(0b1));
        assert_eq!(reader.read_small(4), Ok(0b0010));
    }

    #[test]
    fn test_read_cross_byte() {
        let input = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&input);

        // Ten bits in total; the second read spills into the second byte
        assert_eq!(reader.read_small(7), Ok(0b00101100));
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let input = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&input);

        // Consume three bits, then a byte read must skip to the second byte
        reader.read_small(3).unwrap();
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_read_bytes() {
        let input = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&input);

        // One bit of the first byte, then the bulk read starts at the second byte
        reader.read_bit().unwrap();
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let input = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&input);

        reader.read_small(3).unwrap(); // cursor is now inside the first byte
        reader.align_byte(); // jump to the start of the second byte
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let input = [0xFF];
        let mut reader = BitStreamReader::new(&input);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let input = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&input);

        assert_eq!(reader.read_bit(), Ok(false)); // bit 0
        assert_eq!(reader.read_small(3), Ok(0b101)); // bits 1-3
        assert_eq!(reader.read_byte(), Ok(0b11001100)); // aligns, then full byte
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let input = vec![0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&input);

        assert_eq!(Ok(0), stream.read_byte());
        assert_eq!(Ok(127), stream.read_dyn_int());
        assert_eq!(Ok(128), stream.read_dyn_int());
        assert_eq!(Ok(268435455), stream.read_dyn_int());
        assert_eq!(
            Err(DeserializationError::NotEnoughBytes(1)),
            stream.read_dyn_int()
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let input = vec![
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&input);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(v1, 1);
        assert_eq!(v2, 1);
        assert_eq!(v3, 2);
        assert_eq!(v4, 2);
        assert_eq!(v5, 3);
        assert_eq!(v6, 3);
        assert_eq!(v7, 4);
        assert_eq!(v8, 4);
        assert_eq!(v9, 5);
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let input = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&input);

        assert_eq!(reader.bytes_left(), 4);
        reader.read_small(3).unwrap(); // inside the first byte
        assert_eq!(reader.bytes_left(), 3); // only whole bytes count
        reader.read_byte().unwrap(); // aligns and consumes the second byte
        assert_eq!(reader.bytes_left(), 2);
        reader.read_byte().unwrap(); // third byte
        assert_eq!(reader.bytes_left(), 1);
        reader.read_bit().unwrap(); // touch the last byte
        assert_eq!(reader.bytes_left(), 0); // last byte is no longer whole
    }

    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let input = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&input);

        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );

        // Releasing the reserved tail makes the last two bytes readable
        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        assert_eq!(reader.read_byte(), Ok(4));
        assert_eq!(reader.read_byte(), Ok(5));
    }

    #[test]
    fn test_slice_start() {
        let input = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&input);

        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_byte().unwrap(); // consumes 10
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap(); // first half of 20
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap(); // second half of 20
        assert_eq!(reader.slice_marker(None), &[10, 20]);

        reader.read_bytes(2).unwrap(); // consumes 30 and 40
        assert_eq!(reader.slice_marker(None), &[10, 20, 30, 40]);
    }

    #[test]
    fn test_slice_start_with_marker() {
        let input = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&input);

        reader.read_byte().unwrap(); // consumes 10
        assert_eq!(reader.slice_marker(None), &[10]);
        reader.set_marker();
        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_bytes(2).unwrap(); // consumes 20 and 30
        assert_eq!(reader.slice_marker(None), &[20, 30]);
    }

    #[test]
    fn test_can_read_dynint_0() {
        let input = vec![0, 1];
        let mut stream = BitStreamReader::new(&input);

        assert_eq!(stream.read_dyn_int(), Ok(0));
        assert_eq!(stream.read_byte(), Ok(1));
    }
}