1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6 buffer: &'a [u8],
7 bit_pos: usize,
8 last_read_byte: Option<u8>,
9 crypto: Option<Box<dyn CryptoStream>>,
10}
11
12impl<'a> BitStreamReader<'a> {
13 pub fn new(buffer: &'a [u8]) -> Self {
15 Self {
16 buffer,
17 bit_pos: 0,
18 crypto: None,
19 last_read_byte: None,
20 }
21 }
22
23 pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
25 self.crypto = crypto;
26 }
27
28 pub fn byte_pos(&self) -> usize {
30 self.bit_pos / 8
31 }
32
33 fn current_byte(&mut self) -> u8 {
35 if let Some(b) = self.last_read_byte {
36 b
37 } else {
38 let mut b = self.buffer[self.byte_pos()];
39 if let Some(crypto) = self.crypto.as_mut() {
40 b = crypto.decrypt_byte(b);
41 }
42
43 self.last_read_byte = Some(b);
44 b
45 }
46 }
47
48 pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
50 self.read_small(1).map(|v| v != 0)
51 }
52
53 pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
55 assert!(bits > 0 && bits < 8);
56
57 let mut result: u8 = 0;
58 let mut shift = 0;
59
60 while bits > 0 {
61 if self.byte_pos() >= self.buffer.len() {
62 return Err(DeserializationError::NotEnoughBytes(1));
63 }
64
65 let bit_offset = self.bit_pos % 8;
67
68 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
70
71 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
75 let byte_val = self.current_byte();
76
77 let val = (byte_val & mask) >> bit_offset;
81
82 result |= val << shift;
85
86 bits -= bits_in_current_byte;
88
89 shift += bits_in_current_byte;
91
92 self.bit_pos += bits_in_current_byte as usize;
93
94 if self.bit_pos % 8 == 0 {
96 self.last_read_byte = None;
97 }
98 }
99
100 Ok(result)
101 }
102
103 pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
105 self.align_byte();
106
107 if self.byte_pos() >= self.buffer.len() {
108 return Err(DeserializationError::NotEnoughBytes(1));
109 }
110
111 let byte = self.current_byte();
112 self.bit_pos += 8;
113 self.last_read_byte = None;
114
115 Ok(byte)
116 }
117
118 pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
120 self.align_byte();
121
122 let start = self.byte_pos();
123 if start + count > self.buffer.len() {
124 return Err(DeserializationError::NotEnoughBytes(
125 start + count - self.buffer.len(),
126 ));
127 }
128
129 self.bit_pos += 8 * count;
130 self.last_read_byte = None;
131
132 let slice = &self.buffer[start..start + count];
133 if let Some(crypto) = self.crypto.as_mut() {
134 Ok(crypto.decrypt_slice(slice))
135 } else {
136 Ok(slice)
137 }
138 }
139
140 pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
143 self.align_byte();
144 let mut num: u128 = 0;
145 let mut multiplier: u128 = 1;
146
147 loop {
148 let byte = self.read_byte()?; num += ((byte & 127) as u128) * multiplier;
150
151 if (byte & 1 << 7) == 0 {
153 break;
154 }
155
156 multiplier *= 128;
157 }
158
159 Ok(num)
160 }
161
162 pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
164 &mut self,
165 ) -> Result<T, DeserializationError> {
166 let data = self.read_bytes(S)?;
167 Ok(FixedInt::deserialize(data))
168 }
169
170 pub fn align_byte(&mut self) {
172 let rem = self.bit_pos % 8;
173 if rem != 0 {
174 self.bit_pos += 8 - rem;
175 self.last_read_byte = None;
176 }
177 }
178
179 pub fn bytes_left(&self) -> usize {
181 let left = self.buffer.len() - self.byte_pos();
182 if self.bit_pos % 8 != 0 {
183 left - 1 } else {
185 left
186 }
187 }
188
189 pub fn reset(&mut self) {
191 self.bit_pos = 0;
192 }
193}
194
#[cfg(test)]
mod tests {
    use crate::{DeserializationError, bitstream::CryptoStream};

    use super::BitStreamReader;

    /// Test "cipher" that adds one to every byte and keeps the plaintext it produced.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn decrypt_byte(&mut self, b: u8) -> u8 {
            // `wrapping_add` keeps a 0xFF input from tripping overflow checks.
            let decrypted = b.wrapping_add(1);
            self.plain.push(decrypted);
            decrypted
        }

        fn decrypt_slice(&mut self, slice: &[u8]) -> &[u8] {
            self.plain.extend(slice.iter().map(|s| s.wrapping_add(1)));
            &self.plain[self.plain.len() - slice.len()..]
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&buf);
        reader.set_crypto(Some(Box::new(PlusOneDecrypter { plain: Vec::new() })));

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Ok(4));
        // The fourth byte decrypts to 5 = 0b101, consumed LSB first.
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        // `read_bytes` skips the rest of the partially read byte.
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // 0b10101100 read from bit 0 upwards.
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(3), Ok(0b100));
        assert_eq!(reader.read_small(4), Ok(0b0101));
        assert_eq!(reader.read_small(1), Ok(0b1));
        assert_eq!(reader.read_small(4), Ok(0b0010));
    }

    #[test]
    #[should_panic]
    fn test_read_small_zero_bits_panics() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);
        let _ = reader.read_small(0);
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(7), Ok(0b00101100));
        // Last bit of byte 0, then the two low bits of byte 1.
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap();
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        reader.align_byte();
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_small(3), Ok(0b101));
        assert_eq!(reader.read_byte(), Ok(0b11001100));
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = [0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(Ok(0), stream.read_byte());
        assert_eq!(Ok(127), stream.read_dyn_int());
        assert_eq!(Ok(128), stream.read_dyn_int());
        assert_eq!(Ok(268435455), stream.read_dyn_int());
        assert_eq!(
            Err(DeserializationError::NotEnoughBytes(1)),
            stream.read_dyn_int()
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let buf = [
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&buf);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(v1, 1);
        assert_eq!(v2, 1);
        assert_eq!(v3, 2);
        assert_eq!(v4, 2);
        assert_eq!(v5, 3);
        assert_eq!(v6, 3);
        assert_eq!(v7, 4);
        assert_eq!(v8, 4);
        assert_eq!(v9, 5);
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        // A partially read byte no longer counts as left.
        reader.read_small(3).unwrap();
        assert_eq!(reader.bytes_left(), 3);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 2);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 1);
        reader.read_bit().unwrap();
        assert_eq!(reader.bytes_left(), 0);
    }
}