Skip to main content

binary_codec/bitstream/
reader.rs

1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6    buffer: &'a [u8],
7    bit_pos: usize,
8    last_read_byte: Option<u8>,
9    offset_end: usize,
10    crypto: Option<Box<dyn CryptoStream>>,
11    marker: Option<usize>,
12}
13
14impl<'a> BitStreamReader<'a> {
15    /// Create a new LSB-first reader
16    ///
17    /// # Properties
18    /// - `buffer`: The buffer to read from
19    pub fn new(buffer: &'a [u8]) -> Self {
20        Self {
21            buffer,
22            bit_pos: 0,
23            crypto: None,
24            offset_end: 0,
25            last_read_byte: None,
26            marker: None,
27        }
28    }
29
30    /// Return slice from reading position (or start, if `from_start` is true) to offset-end
31    pub fn slice(&self, from_start: bool) -> &[u8] {
32        let start = if from_start { 0 } else { self.byte_pos() };
33
34        &self.buffer[start..self.buffer.len() - self.offset_end]
35    }
36
37    /// Set marker at current byte
38    pub fn set_marker(&mut self) {
39        self.marker = Some(self.byte_pos());
40    }
41
42    /// Unset marker
43    pub fn reset_marker(&mut self) {
44        self.marker = None;
45    }
46
47    /// Return slice from marker (or start if marker is unset) to specific position or current byte
48    /// If crypto is set, it will return the decrypted slice. This is different from other `slice` methods which always returns the raw buffer slice.
49    pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
50        let start = self.marker.unwrap_or(0);
51        let end = to.unwrap_or(self.byte_pos());
52
53        if let Some(crypto) = self.crypto.as_ref() {
54            return &crypto.get_cached(false)[start..end];
55        }
56
57        &self.buffer[start..end]
58    }
59
60    /// Return slice from offset-end to end of buffer
61    /// If crypto is set, it will apply the keystream to the slice before returning.
62    /// This is different from other `slice` methods which always returns the raw buffer slice.
63    pub fn slice_end(&mut self) -> &[u8] {
64        let slice = &self.buffer[self.buffer.len() - self.offset_end..];
65
66        if let Some(crypto) = self.crypto.as_mut() {
67            crypto.apply_keystream(slice)
68        } else {
69            slice
70        }
71    }
72
73    /// Set crypto stream
74    pub fn set_crypto(&mut self, mut crypto: Option<Box<dyn CryptoStream>>) {
75        if let Some(existing) = self.crypto.as_ref()
76            && let Some(new_crypto) = crypto.as_mut()
77        {
78            new_crypto.replace(existing);
79            self.crypto = crypto;
80        } else {
81            self.crypto = crypto;
82        }
83    }
84
85    /// Remove crypto stream
86    pub fn reset_crypto(&mut self) {
87        self.crypto = None;
88    }
89
90    /// Set integrity offset to ignore when reading
91    pub fn set_offset_end(&mut self, len: usize) {
92        self.offset_end = len;
93    }
94
95    /// Get byte position of reader
96    pub fn byte_pos(&self) -> usize {
97        self.bit_pos / 8
98    }
99
100    /// Get current byte, from last_read_byte cache or from buffer
101    fn current_byte(&mut self) -> u8 {
102        if let Some(b) = self.last_read_byte {
103            b
104        } else {
105            let mut b = self.buffer[self.byte_pos()];
106            if let Some(crypto) = self.crypto.as_mut() {
107                b = crypto.apply_keystream_byte(b);
108            }
109
110            self.last_read_byte = Some(b);
111            b
112        }
113    }
114
115    /// Read a single bit
116    pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
117        self.read_small(1).map(|v| v != 0)
118    }
119
120    /// Read 1-8 bits as u8 (LSB-first)
121    pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
122        assert!(bits > 0 && bits < 8);
123
124        let mut result: u8 = 0;
125        let mut shift = 0;
126
127        while bits > 0 {
128            if self.byte_pos() >= self.buffer.len() - self.offset_end {
129                return Err(DeserializationError::NotEnoughBytes(1));
130            }
131
132            // Find the bit position inside the current byte (0..7)
133            let bit_offset = self.bit_pos % 8;
134
135            // Determine how many bytes left to read. Min(bits left in this byte, bits to read)
136            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
137
138            // Create a mask to isolate the bits we want from this byte.
139            // Example: if bit_offset = 2 and bits_in_current_byte = 3,
140            // mask = 00011100 (only bits 2,3,4 are 1)
141            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
142            let byte_val = self.current_byte();
143
144            // Apply the mask to isolate the bits and shift them to LSB
145            // Example: byte_val = 10101100, mask = 00011100
146            // (10101100 & 00011100) >> 2 = 00000101
147            let val = (byte_val & mask) >> bit_offset;
148
149            // Merge the extracted bits into the final result.
150            // Shift them to the correct position based on how many bits we already read.
151            result |= val << shift;
152
153            // Decrease the remaining bits we need to read
154            bits -= bits_in_current_byte;
155
156            // Update the shift for the next batch of bits (if crossing byte boundary)
157            shift += bits_in_current_byte;
158
159            self.bit_pos += bits_in_current_byte as usize;
160
161            // If crossed byte boundary, reset last read byte
162            if self.bit_pos % 8 == 0 {
163                self.last_read_byte = None;
164            }
165        }
166
167        Ok(result)
168    }
169
170    /// Read a full byte, aligning to the next byte boundary
171    pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
172        self.align_byte();
173
174        if self.byte_pos() >= self.buffer.len() - self.offset_end {
175            return Err(DeserializationError::NotEnoughBytes(1));
176        }
177
178        let byte = self.current_byte();
179        self.bit_pos += 8;
180        self.last_read_byte = None;
181
182        Ok(byte)
183    }
184
185    /// Read a slice of bytes, aligning first
186    pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
187        self.align_byte();
188
189        let start = self.byte_pos();
190        if start + count > self.buffer.len() - self.offset_end {
191            return Err(DeserializationError::NotEnoughBytes(
192                start + count - self.buffer.len(),
193            ));
194        }
195
196        self.bit_pos += 8 * count;
197        self.last_read_byte = None;
198
199        let slice = &self.buffer[start..start + count];
200        if let Some(crypto) = self.crypto.as_mut() {
201            Ok(crypto.apply_keystream(slice))
202        } else {
203            Ok(slice)
204        }
205    }
206
207    /// Read a dynamic int, starting at the next byte bounary
208    /// The last bit is used as a continuation flag for the next byte
209    pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
210        self.align_byte();
211        let mut num: u128 = 0;
212        let mut multiplier: u128 = 1;
213
214        loop {
215            let byte = self.read_byte()?; // None if EOF
216            num += ((byte & 127) as u128) * multiplier;
217
218            // If no continuation bit, stop
219            if (byte & 1 << 7) == 0 {
220                break;
221            }
222
223            multiplier *= 128;
224        }
225
226        Ok(num)
227    }
228
229    /// Read a integer of fixed size from the buffer
230    pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
231        &mut self,
232    ) -> Result<T, DeserializationError> {
233        let data = self.read_bytes(S)?;
234        Ok(FixedInt::deserialize(data))
235    }
236
237    /// Align the reader to the next byte boundary
238    pub fn align_byte(&mut self) {
239        let rem = self.bit_pos % 8;
240        if rem != 0 {
241            self.bit_pos += 8 - rem;
242            self.last_read_byte = None;
243        }
244    }
245
246    /// Get bytes left
247    pub fn bytes_left(&self) -> usize {
248        let left = self.buffer.len() - self.byte_pos() - self.offset_end;
249        if self.bit_pos % 8 != 0 {
250            left - 1 // If not aligned, we can't read the last byte fully
251        } else {
252            left
253        }
254    }
255
256    /// Reset reading position
257    pub fn reset(&mut self) {
258        self.bit_pos = 0;
259    }
260}
261
#[cfg(test)]
mod tests {
    use crate::{DeserializationError, bitstream::CryptoStream};

    use super::BitStreamReader;

    /// Test double that "decrypts" by adding one to every byte (wrapping at 255)
    /// and caches every byte it produced.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            self.plain.push(b.wrapping_add(1));
            *self.plain.last().unwrap()
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let d = slice.iter().map(|s| s.wrapping_add(1));
            self.plain.extend(d);
            &self.plain[self.plain.len() - slice.len()..]
        }

        // This double keeps a single cache, so `original` is irrelevant.
        fn get_cached(&self, _original: bool) -> &[u8] {
            &self.plain
        }

        fn replace(&mut self, other: &Box<dyn CryptoStream>) {
            self.plain = other.get_cached(true).to_vec();
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&buf);
        reader.crypto = Some(Box::new(PlusOneDecrypter { plain: Vec::new() }));

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Ok(4));
        // 4 = 00000100, +1 = 00000101
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    /// Helper to build buffers
    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // LSB-first: read bits starting from least significant
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(3), Ok(0b100));
        assert_eq!(reader.read_small(4), Ok(0b0101));
        assert_eq!(reader.read_small(1), Ok(0b1));
        assert_eq!(reader.read_small(4), Ok(0b0010));
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        // Read first 10 bits (crosses into second byte)
        assert_eq!(reader.read_small(7), Ok(0b00101100));
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap(); // advance 3 bits
        assert_eq!(reader.read_byte(), Ok(0b11010010)); // full second byte
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap(); // first bit
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap(); // 3 bits
        reader.align_byte(); // move to next byte
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false)); // bit 0
        assert_eq!(reader.read_small(3), Ok(0b101)); // bits 1-3
        assert_eq!(reader.read_byte(), Ok(0b11001100)); // aligned full byte
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = vec![0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(Ok(0), stream.read_byte());
        assert_eq!(Ok(127), stream.read_dyn_int());
        assert_eq!(Ok(128), stream.read_dyn_int());
        assert_eq!(Ok(268435455), stream.read_dyn_int());
        assert_eq!(
            Err(DeserializationError::NotEnoughBytes(1)),
            stream.read_dyn_int()
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let buf = vec![
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&buf);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(v1, 1);
        assert_eq!(v2, 1);
        assert_eq!(v3, 2);
        assert_eq!(v4, 2);
        assert_eq!(v5, 3);
        assert_eq!(v6, 3);
        assert_eq!(v7, 4);
        assert_eq!(v8, 4);
        assert_eq!(v9, 5);
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        reader.read_small(3).unwrap(); // read 3 bits
        assert_eq!(reader.bytes_left(), 3); // 3 full bytes left
        reader.read_byte().unwrap(); // read one byte
        assert_eq!(reader.bytes_left(), 2); // now 2 bytes left
        reader.read_byte().unwrap(); // read another byte
        assert_eq!(reader.bytes_left(), 1); // now 1 byte left
        reader.read_bit().unwrap(); // read one bit
        assert_eq!(reader.bytes_left(), 0); // no full bytes left
    }

    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let buff = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&buff);

        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );

        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        assert_eq!(reader.read_byte(), Ok(4));
        assert_eq!(reader.read_byte(), Ok(5));
    }

    #[test]
    fn test_slice_start() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_byte().unwrap(); // Read 10
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap(); // Read 4 bits of 20
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap(); // Read remaining 4 bits of 20
        assert_eq!(reader.slice_marker(None), &[10, 20]);

        reader.read_bytes(2).unwrap(); // Read 30, 40
        assert_eq!(reader.slice_marker(None), &[10, 20, 30, 40]);
    }

    #[test]
    fn test_slice_start_with_marker() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        reader.read_byte().unwrap(); // Read 10
        assert_eq!(reader.slice_marker(None), &[10]);
        reader.set_marker();
        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_bytes(2).unwrap(); // Read 20, 30
        assert_eq!(reader.slice_marker(None), &[20, 30]);
    }
}
534
535        reader.read_bytes(2).unwrap(); // Read 20, 30
536        assert_eq!(reader.slice_marker(None), &[20, 30]);
537    }
538}