binary_codec/bitstream/reader.rs

1use std::cmp::min;
2
3use crate::{DeserializationError, encoding::fixed_int::FixedInt};
4
/// On-the-fly decryption hook used by [`BitStreamReader`].
///
/// The reader routes bytes consumed by bit-level and single-byte reads through
/// [`decrypt_byte`](Self::decrypt_byte) and bulk reads through
/// [`decrypt_slice`](Self::decrypt_slice). Both take `&mut self`, so an implementation
/// may carry state between calls.
pub trait StreamDecrypter {
    /// Decrypt `slice` and return the plaintext.
    ///
    /// The returned slice borrows from the decrypter itself (elided lifetime of
    /// `&mut self`), so implementations need to own a buffer to write the plaintext into.
    fn decrypt_slice(&mut self, slice: &[u8]) -> &[u8];

    /// Decrypt a single byte and return its plaintext value.
    fn decrypt_byte(&mut self, byte: u8) -> u8;
}
9
10pub struct BitStreamReader<'a> {
11    buffer: &'a [u8],
12    bit_pos: usize,
13    last_read_byte: Option<u8>,
14    crypto: Option<Box<dyn StreamDecrypter>>,
15}
16
17impl<'a> BitStreamReader<'a> {
18    /// Create a new LSB-first reader
19    pub fn new(buffer: &'a [u8]) -> Self {
20        Self {
21            buffer,
22            bit_pos: 0,
23            crypto: None,
24            last_read_byte: None,
25        }
26    }
27
28    /// Get byte position of reader
29    pub fn byte_pos(&self) -> usize {
30        self.bit_pos / 8
31    }
32
33    /// Get current byte, from last_read_byte cache or from buffer
34    fn current_byte(&mut self) -> u8 {
35        if let Some(b) = self.last_read_byte {
36            b
37        } else {
38            let mut b = self.buffer[self.byte_pos()];
39            if let Some(crypto) = self.crypto.as_mut() {
40                b = crypto.decrypt_byte(b);
41            }
42
43            self.last_read_byte = Some(b);
44            b
45        }
46    }
47
48    /// Read a single bit
49    pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
50        self.read_small(1).map(|v| v != 0)
51    }
52
53    /// Read 1-8 bits as u8 (LSB-first)
54    pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
55        assert!(bits > 0 && bits < 8);
56
57        let mut result: u8 = 0;
58        let mut shift = 0;
59
60        while bits > 0 {
61            if self.byte_pos() >= self.buffer.len() {
62                return Err(DeserializationError::NotEnoughBytes(1));
63            }
64
65            // Find the bit position inside the current byte (0..7)
66            let bit_offset = self.bit_pos % 8;
67
68            // Determine how many bytes left to read. Min(bits left in this byte, bits to read)
69            let bits_in_current_byte = min(8 - bit_offset as u8, bits);
70
71            // Create a mask to isolate the bits we want from this byte.
72            // Example: if bit_offset = 2 and bits_in_current_byte = 3,
73            // mask = 00011100 (only bits 2,3,4 are 1)
74            let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
75            let byte_val = self.current_byte();
76
77            // Apply the mask to isolate the bits and shift them to LSB
78            // Example: byte_val = 10101100, mask = 00011100
79            // (10101100 & 00011100) >> 2 = 00000101
80            let val = (byte_val & mask) >> bit_offset;
81
82            // Merge the extracted bits into the final result.
83            // Shift them to the correct position based on how many bits we already read.
84            result |= val << shift;
85
86            // Decrease the remaining bits we need to read
87            bits -= bits_in_current_byte;
88
89            // Update the shift for the next batch of bits (if crossing byte boundary)
90            shift += bits_in_current_byte;
91
92            self.bit_pos += bits_in_current_byte as usize;
93
94            // If crossed byte boundary, reset last read byte
95            if self.bit_pos % 8 == 0 {
96                self.last_read_byte = None;
97            }
98        }
99
100        Ok(result)
101    }
102
103    /// Read a full byte, aligning to the next byte boundary
104    pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
105        self.align_byte();
106
107        if self.byte_pos() >= self.buffer.len() {
108            return Err(DeserializationError::NotEnoughBytes(1));
109        }
110
111        let byte = self.current_byte();
112        self.bit_pos += 8;
113        self.last_read_byte = None;
114        
115        Ok(byte)
116    }
117
118    /// Read a slice of bytes, aligning first
119    pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
120        self.align_byte();
121
122        let start = self.byte_pos();
123        if start + count > self.buffer.len() {
124            return Err(DeserializationError::NotEnoughBytes(
125                start + count - self.buffer.len(),
126            ));
127        }
128
129        self.bit_pos += 8 * count;
130        self.last_read_byte = None;
131
132        let slice = &self.buffer[start..start + count];
133        if let Some(crypto) = self.crypto.as_mut() {
134            Ok(crypto.decrypt_slice(slice))
135        } else {
136            Ok(slice)
137        }
138    }
139
140    /// Read a dynamic int, starting at the next byte bounary
141    /// The last bit is used as a continuation flag for the next byte
142    pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
143        self.align_byte();
144        let mut num: u128 = 0;
145        let mut multiplier: u128 = 1;
146
147        loop {
148            let byte = self.read_byte()?; // None if EOF
149            num += ((byte & 127) as u128) * multiplier;
150
151            // If no continuation bit, stop
152            if (byte & 1 << 7) == 0 {
153                break;
154            }
155
156            multiplier *= 128;
157        }
158
159        Ok(num)
160    }
161
162    /// Read a integer of fixed size from the buffer
163    pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
164        &mut self,
165    ) -> Result<T, DeserializationError> {
166        let data = self.read_bytes(S)?;
167        Ok(FixedInt::deserialize(data))
168    }
169
170    /// Align the reader to the next byte boundary
171    pub fn align_byte(&mut self) {
172        let rem = self.bit_pos % 8;
173        if rem != 0 {
174            self.bit_pos += 8 - rem;
175            self.last_read_byte = None;
176        }
177    }
178
179    /// Get bytes left
180    pub fn bytes_left(&self) -> usize {
181        let left = self.buffer.len() - self.byte_pos();
182        if self.bit_pos % 8 != 0 {
183            left - 1 // If not aligned, we can't read the last byte fully
184        } else {
185            left
186        }
187    }
188
189    /// Reset reading position
190    pub fn reset(&mut self) {
191        self.bit_pos = 0;
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::{BitStreamReader, StreamDecrypter};
    use crate::DeserializationError;

    /// Test "cipher" whose plaintext is every ciphertext byte plus one (wrapping).
    ///
    /// `plain` holds the output of the latest `decrypt_slice` call so it can be
    /// returned by reference.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl StreamDecrypter for PlusOneDecrypter {
        fn decrypt_byte(&mut self, b: u8) -> u8 {
            // wrapping_add so a 0xFF input cannot panic in a debug-build test run.
            b.wrapping_add(1)
        }

        fn decrypt_slice(&mut self, slice: &[u8]) -> &[u8] {
            self.plain.clear();
            self.plain.extend(slice.iter().map(|s| s.wrapping_add(1)));
            &self.plain
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&buf);
        reader.crypto = Some(Box::new(PlusOneDecrypter { plain: Vec::new() }));

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Ok(4));
        // 4th buffer byte: 4 = 00000100, +1 = 00000101
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    /// Shared 4-byte fixture.
    fn make_buffer() -> [u8; 4] {
        [0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // LSB-first: read bits starting from least significant
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(3), Ok(0b100));
        assert_eq!(reader.read_small(4), Ok(0b0101));
        assert_eq!(reader.read_small(1), Ok(0b1));
        assert_eq!(reader.read_small(4), Ok(0b0010));
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        // Read first 10 bits (crosses into second byte)
        assert_eq!(reader.read_small(7), Ok(0b00101100));
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap(); // advance 3 bits
        assert_eq!(reader.read_byte(), Ok(0b11010010)); // full second byte
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap(); // first bit
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap(); // 3 bits
        reader.align_byte(); // move to next byte
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false)); // bit 0
        assert_eq!(reader.read_small(3), Ok(0b101)); // bits 1-3
        assert_eq!(reader.read_byte(), Ok(0b11001100)); // aligned full byte
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = [0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(Ok(0), stream.read_byte());
        assert_eq!(Ok(127), stream.read_dyn_int());
        assert_eq!(Ok(128), stream.read_dyn_int());
        assert_eq!(Ok(268435455), stream.read_dyn_int());
        assert_eq!(
            Err(DeserializationError::NotEnoughBytes(1)),
            stream.read_dyn_int()
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let buf = [
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&buf);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(v1, 1);
        assert_eq!(v2, 1);
        assert_eq!(v3, 2);
        assert_eq!(v4, 2);
        assert_eq!(v5, 3);
        assert_eq!(v6, 3);
        assert_eq!(v7, 4);
        assert_eq!(v8, 4);
        assert_eq!(v9, 5);
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        reader.read_small(3).unwrap(); // read 3 bits
        assert_eq!(reader.bytes_left(), 3); // 3 full bytes left
        reader.read_byte().unwrap(); // read one byte
        assert_eq!(reader.bytes_left(), 2); // now 2 bytes left
        reader.read_byte().unwrap(); // read another byte
        assert_eq!(reader.bytes_left(), 1); // now 1 byte left
        reader.read_bit().unwrap(); // read one bit
        assert_eq!(reader.bytes_left(), 0); // no full bytes left
    }
}
401        assert_eq!(reader.bytes_left(), 0); // no full bytes left
402    }
403}