binary_codec/bitstream/
reader.rs

1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6    buffer: &'a [u8],
7    bit_pos: usize,
8    last_read_byte: Option<u8>,
9    crypto: Option<Box<dyn CryptoStream>>,
10}
11
12impl<'a> BitStreamReader<'a> {
13    /// Create a new LSB-first reader
14    pub fn new(buffer: &'a [u8]) -> Self {
15        Self {
16            buffer,
17            bit_pos: 0,
18            crypto: None,
19            last_read_byte: None,
20        }
21    }
22
23    /// Get byte position of reader
24    pub fn byte_pos(&self) -> usize {
25        self.bit_pos / 8
26    }
27
28    /// Get current byte, from last_read_byte cache or from buffer
29    fn current_byte(&mut self) -> u8 {
30        if let Some(b) = self.last_read_byte {
31            b
32        } else {
33            let mut b = self.buffer[self.byte_pos()];
34            if let Some(crypto) = self.crypto.as_mut() {
35                b = crypto.decrypt_byte(b);
36            }
37
38            self.last_read_byte = Some(b);
39            b
40        }
41    }
42
43    /// Read a single bit
44    pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
45        self.read_small(1).map(|v| v != 0)
46    }
47
48    /// Read 1-8 bits as u8 (LSB-first)
49    pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
50        assert!(bits > 0 && bits < 8);
51
52        let mut result: u8 = 0;
53        let mut shift = 0;
54
55        while bits > 0 {
56            if self.byte_pos() >= self.buffer.len() {
57                return Err(DeserializationError::NotEnoughBytes(1));
58            }
59
60            // Find the bit position inside the current byte (0..7)
61            let bit_offset = self.bit_pos % 8;
62
63            // Determine how many bytes left to read. Min(bits left in this byte, bits to read)
64            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
65
66            // Create a mask to isolate the bits we want from this byte.
67            // Example: if bit_offset = 2 and bits_in_current_byte = 3,
68            // mask = 00011100 (only bits 2,3,4 are 1)
69            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
70            let byte_val = self.current_byte();
71
72            // Apply the mask to isolate the bits and shift them to LSB
73            // Example: byte_val = 10101100, mask = 00011100
74            // (10101100 & 00011100) >> 2 = 00000101
75            let val = (byte_val & mask) >> bit_offset;
76
77            // Merge the extracted bits into the final result.
78            // Shift them to the correct position based on how many bits we already read.
79            result |= val << shift;
80
81            // Decrease the remaining bits we need to read
82            bits -= bits_in_current_byte;
83
84            // Update the shift for the next batch of bits (if crossing byte boundary)
85            shift += bits_in_current_byte;
86
87            self.bit_pos += bits_in_current_byte as usize;
88
89            // If crossed byte boundary, reset last read byte
90            if self.bit_pos % 8 == 0 {
91                self.last_read_byte = None;
92            }
93        }
94
95        Ok(result)
96    }
97
98    /// Read a full byte, aligning to the next byte boundary
99    pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
100        self.align_byte();
101
102        if self.byte_pos() >= self.buffer.len() {
103            return Err(DeserializationError::NotEnoughBytes(1));
104        }
105
106        let byte = self.current_byte();
107        self.bit_pos += 8;
108        self.last_read_byte = None;
109        
110        Ok(byte)
111    }
112
113    /// Read a slice of bytes, aligning first
114    pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
115        self.align_byte();
116
117        let start = self.byte_pos();
118        if start + count > self.buffer.len() {
119            return Err(DeserializationError::NotEnoughBytes(
120                start + count - self.buffer.len(),
121            ));
122        }
123
124        self.bit_pos += 8 * count;
125        self.last_read_byte = None;
126
127        let slice = &self.buffer[start..start + count];
128        if let Some(crypto) = self.crypto.as_mut() {
129            Ok(crypto.decrypt_slice(slice))
130        } else {
131            Ok(slice)
132        }
133    }
134
135    /// Read a dynamic int, starting at the next byte bounary
136    /// The last bit is used as a continuation flag for the next byte
137    pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
138        self.align_byte();
139        let mut num: u128 = 0;
140        let mut multiplier: u128 = 1;
141
142        loop {
143            let byte = self.read_byte()?; // None if EOF
144            num += ((byte & 127) as u128) * multiplier;
145
146            // If no continuation bit, stop
147            if (byte & 1 << 7) == 0 {
148                break;
149            }
150
151            multiplier *= 128;
152        }
153
154        Ok(num)
155    }
156
157    /// Read a integer of fixed size from the buffer
158    pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
159        &mut self,
160    ) -> Result<T, DeserializationError> {
161        let data = self.read_bytes(S)?;
162        Ok(FixedInt::deserialize(data))
163    }
164
165    /// Align the reader to the next byte boundary
166    pub fn align_byte(&mut self) {
167        let rem = self.bit_pos % 8;
168        if rem != 0 {
169            self.bit_pos += 8 - rem;
170            self.last_read_byte = None;
171        }
172    }
173
174    /// Get bytes left
175    pub fn bytes_left(&self) -> usize {
176        let left = self.buffer.len() - self.byte_pos();
177        if self.bit_pos % 8 != 0 {
178            left - 1 // If not aligned, we can't read the last byte fully
179        } else {
180            left
181        }
182    }
183
184    /// Reset reading position
185    pub fn reset(&mut self) {
186        self.bit_pos = 0;
187    }
188}
189
#[cfg(test)]
mod tests {
    use crate::{DeserializationError, bitstream::CryptoStream};

    use super::BitStreamReader;

    /// Test cipher that "decrypts" each byte by adding one (wrapping, so 0xFF cannot
    /// panic). Plaintext is recorded in `plain` so `decrypt_slice` can return a borrow.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn decrypt_byte(&mut self, b: u8) -> u8 {
            let decrypted = b.wrapping_add(1);
            self.plain.push(decrypted);
            decrypted
        }

        fn decrypt_slice(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.plain.len();
            self.plain.extend(slice.iter().map(|s| s.wrapping_add(1)));
            &self.plain[start..]
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&buf);
        reader.crypto = Some(Box::new(PlusOneDecrypter { plain: Vec::new() }));

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Ok(4));
        // Byte 3 is 4 = 0b0000_0100, decrypted to 5 = 0b0000_0101; read LSB-first.
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        // read_bytes first aligns past the rest of byte 3.
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    /// Shared fixture for the single-bit test.
    fn make_buffer() -> [u8; 4] {
        [0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // LSB-first: bits come out starting from the least significant one.
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(3), Ok(0b100));
        assert_eq!(reader.read_small(4), Ok(0b0101));
        assert_eq!(reader.read_small(1), Ok(0b1));
        assert_eq!(reader.read_small(4), Ok(0b0010));
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        // 7 bits from byte 0, then a 3-bit read that spans bit 7 of byte 0 and bits
        // 0-1 of byte 1.
        assert_eq!(reader.read_small(7), Ok(0b00101100));
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap(); // advance 3 bits
        assert_eq!(reader.read_byte(), Ok(0b11010010)); // aligns, then reads byte 1
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap(); // first bit
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap(); // 3 bits
        reader.align_byte(); // move to the next byte
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false)); // bit 0
        assert_eq!(reader.read_small(3), Ok(0b101)); // bits 1-3
        assert_eq!(reader.read_byte(), Ok(0b11001100)); // aligned full byte
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = [0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(Ok(0), stream.read_byte());
        assert_eq!(Ok(127), stream.read_dyn_int());
        assert_eq!(Ok(128), stream.read_dyn_int());
        assert_eq!(Ok(268435455), stream.read_dyn_int());
        assert_eq!(
            Err(DeserializationError::NotEnoughBytes(1)),
            stream.read_dyn_int()
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let buf = [
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&buf);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(v1, 1);
        assert_eq!(v2, 1);
        assert_eq!(v3, 2);
        assert_eq!(v4, 2);
        assert_eq!(v5, 3);
        assert_eq!(v6, 3);
        assert_eq!(v7, 4);
        assert_eq!(v8, 4);
        assert_eq!(v9, 5);
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        reader.read_small(3).unwrap(); // read 3 bits
        assert_eq!(reader.bytes_left(), 3); // 3 full bytes left
        reader.read_byte().unwrap(); // read one byte
        assert_eq!(reader.bytes_left(), 2); // now 2 bytes left
        reader.read_byte().unwrap(); // read another byte
        assert_eq!(reader.bytes_left(), 1); // now 1 byte left
        reader.read_bit().unwrap(); // read one bit
        assert_eq!(reader.bytes_left(), 0); // no full bytes left
    }
}