binary_codec/bitstream/
reader.rs

1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6    buffer: &'a [u8],
7    bit_pos: usize,
8    last_read_byte: Option<u8>,
9    offset_end: usize,
10    crypto: Option<Box<dyn CryptoStream>>,
11}
12
13impl<'a> BitStreamReader<'a> {
14    /// Create a new LSB-first reader
15    ///
16    /// # Properties
17    /// - `buffer`: The buffer to read from
18    pub fn new(buffer: &'a [u8]) -> Self {
19        Self {
20            buffer,
21            bit_pos: 0,
22            crypto: None,
23            offset_end: 0,
24            last_read_byte: None,
25        }
26    }
27
28    /// Return slice from reading position (or start, if `from_start` is true) to offset-end
29    pub fn slice(&self, from_start: bool) -> &[u8] {
30        let start = if from_start { 0 } else { self.byte_pos() };
31
32        &self.buffer[start..self.buffer.len() - self.offset_end]
33    }
34
35    /// Return slice from offset-end to end of buffer
36    pub fn slice_end(&self) -> &[u8] {
37        &self.buffer[self.buffer.len() - self.offset_end..]
38    }
39
40    /// Set crypto stream
41    pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
42        self.crypto = crypto;
43    }
44
45    /// Set integrity offset to ignore when reading
46    pub fn set_offset_end(&mut self, len: usize) {
47        self.offset_end = len;
48    }
49
50    /// Get byte position of reader
51    pub fn byte_pos(&self) -> usize {
52        self.bit_pos / 8
53    }
54
55    /// Get current byte, from last_read_byte cache or from buffer
56    fn current_byte(&mut self) -> u8 {
57        if let Some(b) = self.last_read_byte {
58            b
59        } else {
60            let mut b = self.buffer[self.byte_pos()];
61            if let Some(crypto) = self.crypto.as_mut() {
62                b = crypto.apply_keystream_byte(b);
63            }
64
65            self.last_read_byte = Some(b);
66            b
67        }
68    }
69
70    /// Read a single bit
71    pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
72        self.read_small(1).map(|v| v != 0)
73    }
74
75    /// Read 1-8 bits as u8 (LSB-first)
76    pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
77        assert!(bits > 0 && bits < 8);
78
79        let mut result: u8 = 0;
80        let mut shift = 0;
81
82        while bits > 0 {
83            if self.byte_pos() >= self.buffer.len() - self.offset_end {
84                return Err(DeserializationError::NotEnoughBytes(1));
85            }
86
87            // Find the bit position inside the current byte (0..7)
88            let bit_offset = self.bit_pos % 8;
89
90            // Determine how many bytes left to read. Min(bits left in this byte, bits to read)
91            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
92
93            // Create a mask to isolate the bits we want from this byte.
94            // Example: if bit_offset = 2 and bits_in_current_byte = 3,
95            // mask = 00011100 (only bits 2,3,4 are 1)
96            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
97            let byte_val = self.current_byte();
98
99            // Apply the mask to isolate the bits and shift them to LSB
100            // Example: byte_val = 10101100, mask = 00011100
101            // (10101100 & 00011100) >> 2 = 00000101
102            let val = (byte_val & mask) >> bit_offset;
103
104            // Merge the extracted bits into the final result.
105            // Shift them to the correct position based on how many bits we already read.
106            result |= val << shift;
107
108            // Decrease the remaining bits we need to read
109            bits -= bits_in_current_byte;
110
111            // Update the shift for the next batch of bits (if crossing byte boundary)
112            shift += bits_in_current_byte;
113
114            self.bit_pos += bits_in_current_byte as usize;
115
116            // If crossed byte boundary, reset last read byte
117            if self.bit_pos % 8 == 0 {
118                self.last_read_byte = None;
119            }
120        }
121
122        Ok(result)
123    }
124
125    /// Read a full byte, aligning to the next byte boundary
126    pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
127        self.align_byte();
128
129        if self.byte_pos() >= self.buffer.len() - self.offset_end {
130            return Err(DeserializationError::NotEnoughBytes(1));
131        }
132
133        let byte = self.current_byte();
134        self.bit_pos += 8;
135        self.last_read_byte = None;
136
137        Ok(byte)
138    }
139
140    /// Read a slice of bytes, aligning first
141    pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
142        self.align_byte();
143
144        let start = self.byte_pos();
145        if start + count > self.buffer.len() - self.offset_end {
146            return Err(DeserializationError::NotEnoughBytes(
147                start + count - self.buffer.len(),
148            ));
149        }
150
151        self.bit_pos += 8 * count;
152        self.last_read_byte = None;
153
154        let slice = &self.buffer[start..start + count];
155        if let Some(crypto) = self.crypto.as_mut() {
156            Ok(crypto.apply_keystream(slice))
157        } else {
158            Ok(slice)
159        }
160    }
161
162    /// Read a dynamic int, starting at the next byte bounary
163    /// The last bit is used as a continuation flag for the next byte
164    pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
165        self.align_byte();
166        let mut num: u128 = 0;
167        let mut multiplier: u128 = 1;
168
169        loop {
170            let byte = self.read_byte()?; // None if EOF
171            num += ((byte & 127) as u128) * multiplier;
172
173            // If no continuation bit, stop
174            if (byte & 1 << 7) == 0 {
175                break;
176            }
177
178            multiplier *= 128;
179        }
180
181        Ok(num)
182    }
183
184    /// Read a integer of fixed size from the buffer
185    pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
186        &mut self,
187    ) -> Result<T, DeserializationError> {
188        let data = self.read_bytes(S)?;
189        Ok(FixedInt::deserialize(data))
190    }
191
192    /// Align the reader to the next byte boundary
193    pub fn align_byte(&mut self) {
194        let rem = self.bit_pos % 8;
195        if rem != 0 {
196            self.bit_pos += 8 - rem;
197            self.last_read_byte = None;
198        }
199    }
200
201    /// Get bytes left
202    pub fn bytes_left(&self) -> usize {
203        let left = self.buffer.len() - self.byte_pos() - self.offset_end;
204        if self.bit_pos % 8 != 0 {
205            left - 1 // If not aligned, we can't read the last byte fully
206        } else {
207            left
208        }
209    }
210
211    /// Reset reading position
212    pub fn reset(&mut self) {
213        self.bit_pos = 0;
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::BitStreamReader;
    use crate::DeserializationError;
    use crate::bitstream::CryptoStream;

    /// Test crypto stream that adds one to every byte. It keeps every produced byte so
    /// `apply_keystream` can return a slice borrowed from itself.
    struct PlusOneDecrypter {
        decrypted: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, byte: u8) -> u8 {
            let plain = byte + 1;
            self.decrypted.push(plain);
            plain
        }

        fn apply_keystream(&mut self, input: &[u8]) -> &[u8] {
            let first_new = self.decrypted.len();
            self.decrypted.extend(input.iter().map(|byte| byte + 1));
            &self.decrypted[first_new..]
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&buf);
        reader.set_crypto(Some(Box::new(PlusOneDecrypter {
            decrypted: Vec::new(),
        })));

        for expected in [2, 3, 4] {
            assert_eq!(reader.read_byte(), Ok(expected));
        }
        // The fourth byte is 4, decrypted to 5 = 0b0000_0101, read LSB-first.
        for expected in [true, false, true] {
            assert_eq!(reader.read_bit(), Ok(expected));
        }
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    /// Shared four-byte fixture.
    fn make_buffer() -> Vec<u8> {
        Vec::from([0b10101100, 0b11010010, 0xFF, 0x00])
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // 0b10101100 walked from bit 0 (least significant) up to bit 7.
        let expected_bits = [false, false, true, true, false, true, false, true];
        for bit in expected_bits {
            assert_eq!(reader.read_bit(), Ok(bit));
        }
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        // (width, expected value) pairs consumed in order.
        let steps = [(3, 0b100), (4, 0b0101), (1, 0b1), (4, 0b0010)];
        for (width, value) in steps {
            assert_eq!(reader.read_small(width), Ok(value));
        }
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        // 7 + 3 = 10 bits, so the second read spans the byte boundary.
        let low_seven = reader.read_small(7);
        let spanning_three = reader.read_small(3);
        assert_eq!(low_seven, Ok(0b00101100));
        assert_eq!(spanning_three, Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        // read_byte skips the unread rest of the first byte.
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap();
        // The partially read first byte is skipped before the slice is taken.
        assert_eq!(reader.read_bytes(3).unwrap(), &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        reader.align_byte();
        // After aligning, the next full byte is the second one.
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));

        // With the only byte consumed, every kind of read must fail.
        let bit_result = reader.read_bit();
        assert_eq!(bit_result, Err(DeserializationError::NotEnoughBytes(1)));
        let byte_result = reader.read_byte();
        assert_eq!(byte_result, Err(DeserializationError::NotEnoughBytes(1)));
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false)); // bit 0
        assert_eq!(reader.read_small(3), Ok(0b101)); // bits 1..=3
        // read_byte discards bits 4..=7 and reads the whole second byte.
        assert_eq!(reader.read_byte(), Ok(0b11001100));
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        // The buffer is exhausted now.
        let after_end = reader.read_bit();
        assert_eq!(after_end, Err(DeserializationError::NotEnoughBytes(1)));
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = vec![0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(Ok(0), stream.read_byte());
        for expected in [127, 128, 268435455] {
            assert_eq!(Ok(expected), stream.read_dyn_int());
        }
        let exhausted = stream.read_dyn_int();
        assert_eq!(Err(DeserializationError::NotEnoughBytes(1)), exhausted);
    }

    #[test]
    fn test_read_fixed_int() {
        let buf = vec![
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&buf);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        // Each unsigned/signed pair of the same width decodes to the same value.
        assert_eq!((v1, v2), (1, 1));
        assert_eq!((v3, v4), (2, 2));
        assert_eq!((v5, v6), (3, 3));
        assert_eq!((v7, v8), (4, 4));
        assert_eq!((v9, v10), (5, 5));
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        reader.read_small(3).unwrap();
        // The partially read first byte no longer counts.
        assert_eq!(reader.bytes_left(), 3);
        for remaining in [2, 1] {
            reader.read_byte().unwrap();
            assert_eq!(reader.bytes_left(), remaining);
        }
        reader.read_bit().unwrap();
        // Only a partially read byte is left.
        assert_eq!(reader.bytes_left(), 0);
    }

    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let buff = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&buff);

        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        for expected in [2, 3] {
            assert_eq!(reader.read_byte(), Ok(expected));
        }
        let blocked = reader.read_byte();
        assert_eq!(blocked, Err(DeserializationError::NotEnoughBytes(1)));

        // Dropping the offset exposes the two trailing bytes again.
        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        for expected in [4, 5] {
            assert_eq!(reader.read_byte(), Ok(expected));
        }
    }
}
449        assert_eq!(reader.read_byte(), Ok(5));
450    }
451}