// binary_codec/bitstream/reader.rs

1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6    buffer: &'a [u8],
7    bit_pos: usize,
8    last_read_byte: Option<u8>,
9    offset_end: usize,
10    crypto: Option<Box<dyn CryptoStream>>,
11    marker: Option<usize>,
12}
13
14impl<'a> BitStreamReader<'a> {
15    /// Create a new LSB-first reader
16    ///
17    /// # Properties
18    /// - `buffer`: The buffer to read from
19    pub fn new(buffer: &'a [u8]) -> Self {
20        Self {
21            buffer,
22            bit_pos: 0,
23            crypto: None,
24            offset_end: 0,
25            last_read_byte: None,
26            marker: None,
27        }
28    }
29
30    /// Return slice from reading position (or start, if `from_start` is true) to offset-end
31    pub fn slice(&self, from_start: bool) -> &[u8] {
32        let start = if from_start { 0 } else { self.byte_pos() };
33
34        &self.buffer[start..self.buffer.len() - self.offset_end]
35    }
36
37    /// Set marker at current byte
38    pub fn set_marker(&mut self) {
39        self.marker = Some(self.byte_pos());
40    }
41
42    /// Unset marker
43    pub fn reset_marker(&mut self) {
44        self.marker = None;
45    }
46
47    /// Return slice from marker (or start if marker is unset) to specific position or current byte
48    /// If crypto is set, it will return the decrypted slice. This is different from other `slice` methods which always returns the raw buffer slice.
49    pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
50        let start = self.marker.unwrap_or(0);
51        let end = to.unwrap_or(self.byte_pos());
52
53        if let Some(crypto) = self.crypto.as_ref() {
54            return &crypto.get_cached(false)[start..end];
55        }
56
57        &self.buffer[start..end]
58    }
59
60    /// Return slice from offset-end to end of buffer
61    /// If crypto is set, it will apply the keystream to the slice before returning.
62    /// This is different from other `slice` methods which always returns the raw buffer slice.
63    pub fn slice_end(&mut self) -> &[u8] {
64        let slice = &self.buffer[self.buffer.len() - self.offset_end..];
65
66        if let Some(crypto) = self.crypto.as_mut() {
67            crypto.apply_keystream(slice)
68        } else {
69            slice
70        }
71    }
72
73    /// Set crypto stream
74    pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
75        self.crypto = crypto;
76    }
77
78    /// Remove crypto stream
79    pub fn reset_crypto(&mut self) {
80        self.crypto = None;
81    }
82
83    /// Set integrity offset to ignore when reading
84    pub fn set_offset_end(&mut self, len: usize) {
85        self.offset_end = len;
86    }
87
88    /// Get byte position of reader
89    pub fn byte_pos(&self) -> usize {
90        self.bit_pos / 8
91    }
92
93    /// Get current byte, from last_read_byte cache or from buffer
94    fn current_byte(&mut self) -> u8 {
95        if let Some(b) = self.last_read_byte {
96            b
97        } else {
98            let mut b = self.buffer[self.byte_pos()];
99            if let Some(crypto) = self.crypto.as_mut() {
100                b = crypto.apply_keystream_byte(b);
101            }
102
103            self.last_read_byte = Some(b);
104            b
105        }
106    }
107
108    /// Read a single bit
109    pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
110        self.read_small(1).map(|v| v != 0)
111    }
112
113    /// Read 1-8 bits as u8 (LSB-first)
114    pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
115        assert!(bits > 0 && bits < 8);
116
117        let mut result: u8 = 0;
118        let mut shift = 0;
119
120        while bits > 0 {
121            if self.byte_pos() >= self.buffer.len() - self.offset_end {
122                return Err(DeserializationError::NotEnoughBytes(1));
123            }
124
125            // Find the bit position inside the current byte (0..7)
126            let bit_offset = self.bit_pos % 8;
127
128            // Determine how many bytes left to read. Min(bits left in this byte, bits to read)
129            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
130
131            // Create a mask to isolate the bits we want from this byte.
132            // Example: if bit_offset = 2 and bits_in_current_byte = 3,
133            // mask = 00011100 (only bits 2,3,4 are 1)
134            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
135            let byte_val = self.current_byte();
136
137            // Apply the mask to isolate the bits and shift them to LSB
138            // Example: byte_val = 10101100, mask = 00011100
139            // (10101100 & 00011100) >> 2 = 00000101
140            let val = (byte_val & mask) >> bit_offset;
141
142            // Merge the extracted bits into the final result.
143            // Shift them to the correct position based on how many bits we already read.
144            result |= val << shift;
145
146            // Decrease the remaining bits we need to read
147            bits -= bits_in_current_byte;
148
149            // Update the shift for the next batch of bits (if crossing byte boundary)
150            shift += bits_in_current_byte;
151
152            self.bit_pos += bits_in_current_byte as usize;
153
154            // If crossed byte boundary, reset last read byte
155            if self.bit_pos % 8 == 0 {
156                self.last_read_byte = None;
157            }
158        }
159
160        Ok(result)
161    }
162
163    /// Read a full byte, aligning to the next byte boundary
164    pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
165        self.align_byte();
166
167        if self.byte_pos() >= self.buffer.len() - self.offset_end {
168            return Err(DeserializationError::NotEnoughBytes(1));
169        }
170
171        let byte = self.current_byte();
172        self.bit_pos += 8;
173        self.last_read_byte = None;
174
175        Ok(byte)
176    }
177
178    /// Read a slice of bytes, aligning first
179    pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
180        self.align_byte();
181
182        let start = self.byte_pos();
183        if start + count > self.buffer.len() - self.offset_end {
184            return Err(DeserializationError::NotEnoughBytes(
185                start + count - self.buffer.len(),
186            ));
187        }
188
189        self.bit_pos += 8 * count;
190        self.last_read_byte = None;
191
192        let slice = &self.buffer[start..start + count];
193        if let Some(crypto) = self.crypto.as_mut() {
194            Ok(crypto.apply_keystream(slice))
195        } else {
196            Ok(slice)
197        }
198    }
199
200    /// Read a dynamic int, starting at the next byte bounary
201    /// The last bit is used as a continuation flag for the next byte
202    pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
203        self.align_byte();
204        let mut num: u128 = 0;
205        let mut multiplier: u128 = 1;
206
207        loop {
208            let byte = self.read_byte()?; // None if EOF
209            num += ((byte & 127) as u128) * multiplier;
210
211            // If no continuation bit, stop
212            if (byte & 1 << 7) == 0 {
213                break;
214            }
215
216            multiplier *= 128;
217        }
218
219        Ok(num)
220    }
221
222    /// Read a integer of fixed size from the buffer
223    pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
224        &mut self,
225    ) -> Result<T, DeserializationError> {
226        let data = self.read_bytes(S)?;
227        Ok(FixedInt::deserialize(data))
228    }
229
230    /// Align the reader to the next byte boundary
231    pub fn align_byte(&mut self) {
232        let rem = self.bit_pos % 8;
233        if rem != 0 {
234            self.bit_pos += 8 - rem;
235            self.last_read_byte = None;
236        }
237    }
238
239    /// Get bytes left
240    pub fn bytes_left(&self) -> usize {
241        let left = self.buffer.len() - self.byte_pos() - self.offset_end;
242        if self.bit_pos % 8 != 0 {
243            left - 1 // If not aligned, we can't read the last byte fully
244        } else {
245            left
246        }
247    }
248
249    /// Reset reading position
250    pub fn reset(&mut self) {
251        self.bit_pos = 0;
252    }
253}
254
#[cfg(test)]
mod tests {
    use crate::{DeserializationError, bitstream::CryptoStream};

    use super::BitStreamReader;

    /// Toy crypto stream for tests: "decrypts" by adding one to every byte and remembers
    /// everything it produced in `plain`.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // wrapping_add keeps the helper total for 0xFF instead of overflowing
            let plain = b.wrapping_add(1);
            self.plain.push(plain);
            plain
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            self.plain.extend(slice.iter().map(|s| s.wrapping_add(1)));
            &self.plain[self.plain.len() - slice.len()..]
        }

        fn get_cached(&self, _original: bool) -> &[u8] {
            &self.plain
        }

        fn replace(&mut self, other: Box<dyn CryptoStream>) {
            self.plain = other.get_cached(true).to_vec();
        }
    }

    /// Bytes and bits read through a crypto stream come back transformed
    #[test]
    fn test_decrypt_bytes() {
        let buf = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&buf);
        reader.set_crypto(Some(Box::new(PlusOneDecrypter { plain: Vec::new() })));

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Ok(4));
        // 4 = 00000100, +1 = 00000101
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    /// Helper to build buffers
    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // LSB-first: read bits starting from least significant
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(3), Ok(0b100));
        assert_eq!(reader.read_small(4), Ok(0b0101));
        assert_eq!(reader.read_small(1), Ok(0b1));
        assert_eq!(reader.read_small(4), Ok(0b0010));
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        // Read first 10 bits (crosses into second byte)
        assert_eq!(reader.read_small(7), Ok(0b00101100));
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap(); // advance 3 bits
        assert_eq!(reader.read_byte(), Ok(0b11010010)); // full second byte
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap(); // first bit
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap(); // 3 bits
        reader.align_byte(); // move to next byte
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false)); // bit 0
        assert_eq!(reader.read_small(3), Ok(0b101)); // bits 1-3
        assert_eq!(reader.read_byte(), Ok(0b11001100)); // aligned full byte
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = vec![0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(Ok(0), stream.read_byte());
        assert_eq!(Ok(127), stream.read_dyn_int());
        assert_eq!(Ok(128), stream.read_dyn_int());
        assert_eq!(Ok(268435455), stream.read_dyn_int());
        assert_eq!(
            Err(DeserializationError::NotEnoughBytes(1)),
            stream.read_dyn_int()
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let buf = vec![
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&buf);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(v1, 1);
        assert_eq!(v2, 1);
        assert_eq!(v3, 2);
        assert_eq!(v4, 2);
        assert_eq!(v5, 3);
        assert_eq!(v6, 3);
        assert_eq!(v7, 4);
        assert_eq!(v8, 4);
        assert_eq!(v9, 5);
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        reader.read_small(3).unwrap(); // read 3 bits
        assert_eq!(reader.bytes_left(), 3); // 3 full bytes left
        reader.read_byte().unwrap(); // read one byte
        assert_eq!(reader.bytes_left(), 2); // now 2 bytes left
        reader.read_byte().unwrap(); // read another byte
        assert_eq!(reader.bytes_left(), 1); // now 1 bytes left
        reader.read_bit().unwrap(); // read one bit
        assert_eq!(reader.bytes_left(), 0); // no full bytes left
    }

    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let buff = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&buff);

        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );

        // Lifting the offset makes the trailing bytes readable again
        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        assert_eq!(reader.read_byte(), Ok(4));
        assert_eq!(reader.read_byte(), Ok(5));
    }

    /// Without a marker, `slice_marker` starts at the beginning of the buffer and only
    /// includes bytes that were consumed completely
    #[test]
    fn test_slice_start() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_byte().unwrap(); // Read 10
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap(); // Read 4 bits of 20
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap(); // Read remaining 4 bits of 20
        assert_eq!(reader.slice_marker(None), &[10, 20]);

        reader.read_bytes(2).unwrap(); // Read 30, 40
        assert_eq!(reader.slice_marker(None), &[10, 20, 30, 40]);
    }

    /// With a marker set, `slice_marker` starts at the marked byte
    #[test]
    fn test_slice_start_with_marker() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        reader.read_byte().unwrap(); // Read 10
        assert_eq!(reader.slice_marker(None), &[10]);
        reader.set_marker();
        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_bytes(2).unwrap(); // Read 20, 30
        assert_eq!(reader.slice_marker(None), &[20, 30]);
    }
}