1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6 buffer: &'a [u8],
7 bit_pos: usize,
8 last_read_byte: Option<u8>,
9 offset_end: usize,
10 crypto: Option<Box<dyn CryptoStream>>,
11 marker: Option<usize>
12}
13
14impl<'a> BitStreamReader<'a> {
15 pub fn new(buffer: &'a [u8]) -> Self {
20 Self {
21 buffer,
22 bit_pos: 0,
23 crypto: None,
24 offset_end: 0,
25 last_read_byte: None,
26 marker: None,
27 }
28 }
29
30 pub fn slice(&self, from_start: bool) -> &[u8] {
32 let start = if from_start { 0 } else { self.byte_pos() };
33
34 &self.buffer[start..self.buffer.len() - self.offset_end]
35 }
36
37 pub fn set_marker(&mut self) {
39 self.marker = Some(self.byte_pos());
40 }
41
42 pub fn reset_marker(&mut self) {
44 self.marker = None;
45 }
46
47 pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
49 &self.buffer[self.marker.unwrap_or(0)..to.unwrap_or(self.byte_pos())]
50 }
51
52 pub fn slice_end(&self) -> &[u8] {
54 &self.buffer[self.buffer.len() - self.offset_end..]
55 }
56
57 pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
59 self.crypto = crypto;
60 }
61
62 pub fn set_offset_end(&mut self, len: usize) {
64 self.offset_end = len;
65 }
66
67 pub fn byte_pos(&self) -> usize {
69 self.bit_pos / 8
70 }
71
72 fn current_byte(&mut self) -> u8 {
74 if let Some(b) = self.last_read_byte {
75 b
76 } else {
77 let mut b = self.buffer[self.byte_pos()];
78 if let Some(crypto) = self.crypto.as_mut() {
79 b = crypto.apply_keystream_byte(b);
80 }
81
82 self.last_read_byte = Some(b);
83 b
84 }
85 }
86
87 pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
89 self.read_small(1).map(|v| v != 0)
90 }
91
92 pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
94 assert!(bits > 0 && bits < 8);
95
96 let mut result: u8 = 0;
97 let mut shift = 0;
98
99 while bits > 0 {
100 if self.byte_pos() >= self.buffer.len() - self.offset_end {
101 return Err(DeserializationError::NotEnoughBytes(1));
102 }
103
104 let bit_offset = self.bit_pos % 8;
106
107 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
109
110 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
114 let byte_val = self.current_byte();
115
116 let val = (byte_val & mask) >> bit_offset;
120
121 result |= val << shift;
124
125 bits -= bits_in_current_byte;
127
128 shift += bits_in_current_byte;
130
131 self.bit_pos += bits_in_current_byte as usize;
132
133 if self.bit_pos % 8 == 0 {
135 self.last_read_byte = None;
136 }
137 }
138
139 Ok(result)
140 }
141
142 pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
144 self.align_byte();
145
146 if self.byte_pos() >= self.buffer.len() - self.offset_end {
147 return Err(DeserializationError::NotEnoughBytes(1));
148 }
149
150 let byte = self.current_byte();
151 self.bit_pos += 8;
152 self.last_read_byte = None;
153
154 Ok(byte)
155 }
156
157 pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
159 self.align_byte();
160
161 let start = self.byte_pos();
162 if start + count > self.buffer.len() - self.offset_end {
163 return Err(DeserializationError::NotEnoughBytes(
164 start + count - self.buffer.len(),
165 ));
166 }
167
168 self.bit_pos += 8 * count;
169 self.last_read_byte = None;
170
171 let slice = &self.buffer[start..start + count];
172 if let Some(crypto) = self.crypto.as_mut() {
173 Ok(crypto.apply_keystream(slice))
174 } else {
175 Ok(slice)
176 }
177 }
178
179 pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
182 self.align_byte();
183 let mut num: u128 = 0;
184 let mut multiplier: u128 = 1;
185
186 loop {
187 let byte = self.read_byte()?; num += ((byte & 127) as u128) * multiplier;
189
190 if (byte & 1 << 7) == 0 {
192 break;
193 }
194
195 multiplier *= 128;
196 }
197
198 Ok(num)
199 }
200
201 pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
203 &mut self,
204 ) -> Result<T, DeserializationError> {
205 let data = self.read_bytes(S)?;
206 Ok(FixedInt::deserialize(data))
207 }
208
209 pub fn align_byte(&mut self) {
211 let rem = self.bit_pos % 8;
212 if rem != 0 {
213 self.bit_pos += 8 - rem;
214 self.last_read_byte = None;
215 }
216 }
217
218 pub fn bytes_left(&self) -> usize {
220 let left = self.buffer.len() - self.byte_pos() - self.offset_end;
221 if self.bit_pos % 8 != 0 {
222 left - 1 } else {
224 left
225 }
226 }
227
228 pub fn reset(&mut self) {
230 self.bit_pos = 0;
231 }
232}
233
#[cfg(test)]
mod tests {
    use super::BitStreamReader;
    use crate::{DeserializationError, bitstream::CryptoStream};

    /// Error value returned by a read that is `n` bytes short.
    fn not_enough(n: usize) -> DeserializationError {
        DeserializationError::NotEnoughBytes(n)
    }

    /// Fake keystream that "decrypts" each byte by adding one. Every produced
    /// byte is kept in `plain`, so `apply_keystream` can return a slice that
    /// borrows from `self`.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            let decrypted = b + 1;
            self.plain.push(decrypted);
            decrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let first_new = self.plain.len();
            self.plain.extend(slice.iter().map(|s| s + 1));
            &self.plain[first_new..]
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf: Vec<u8> = (1..=10).collect();
        let mut reader = BitStreamReader::new(&buf);
        reader.set_crypto(Some(Box::new(PlusOneDecrypter { plain: Vec::new() })));

        for expected in 2..=4 {
            assert_eq!(reader.read_byte(), Ok(expected));
        }
        // The stored 4 decrypts to 0b101; bits come out least-significant first.
        for expected in [true, false, true] {
            assert_eq!(reader.read_bit(), Ok(expected));
        }
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    /// Four-byte fixture for the single-bit test.
    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // 0b10101100 walked from bit 0 upwards.
        let expected = [false, false, true, true, false, true, false, true];
        for bit in expected {
            assert_eq!(reader.read_bit(), Ok(bit));
        }
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        for (width, expected) in [(3, 0b100), (4, 0b0101), (1, 0b1), (4, 0b0010)] {
            assert_eq!(reader.read_small(width), Ok(expected));
        }
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        // Low seven bits of the first byte.
        assert_eq!(reader.read_small(7), Ok(0b00101100));
        // Top bit of the first byte, then the two lowest bits of the second.
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        // The partially consumed first byte is skipped.
        reader.read_small(3).unwrap();
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap();
        let bytes = reader.read_bytes(3).unwrap();
        assert_eq!(bytes, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        reader.align_byte();
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(reader.read_bit(), Err(not_enough(1)));
        assert_eq!(reader.read_byte(), Err(not_enough(1)));
        assert_eq!(reader.read_bytes(2), Err(not_enough(2)));
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_small(3), Ok(0b101));
        assert_eq!(reader.read_byte(), Ok(0b11001100));
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(reader.read_bit(), Err(not_enough(1)));
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = vec![0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(stream.read_byte(), Ok(0));
        // One-byte, two-byte and four-byte encodings.
        for expected in [127, 128, 268435455] {
            assert_eq!(stream.read_dyn_int(), Ok(expected));
        }
        assert_eq!(stream.read_dyn_int(), Err(not_enough(1)));
    }

    #[test]
    fn test_read_fixed_int() {
        // Encoded values, one row per integer type, in read order.
        let rows: [&[u8]; 10] = [
            &[1],                                               // u8
            &[2],                                               // i8
            &[0, 2],                                            // u16
            &[0, 4],                                            // i16
            &[0, 0, 0, 3],                                      // u32
            &[0, 0, 0, 6],                                      // i32
            &[0, 0, 0, 0, 0, 0, 0, 4],                          // u64
            &[0, 0, 0, 0, 0, 0, 0, 8],                          // i64
            &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5],  // u128
            &[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 10], // i128
        ];
        let buf = rows.concat();
        let mut stream = BitStreamReader::new(&buf);

        let a_u8: u8 = stream.read_fixed_int().unwrap();
        let a_i8: i8 = stream.read_fixed_int().unwrap();
        let a_u16: u16 = stream.read_fixed_int().unwrap();
        let a_i16: i16 = stream.read_fixed_int().unwrap();
        let a_u32: u32 = stream.read_fixed_int().unwrap();
        let a_i32: i32 = stream.read_fixed_int().unwrap();
        let a_u64: u64 = stream.read_fixed_int().unwrap();
        let a_i64: i64 = stream.read_fixed_int().unwrap();
        let a_u128: u128 = stream.read_fixed_int().unwrap();
        let a_i128: i128 = stream.read_fixed_int().unwrap();

        assert_eq!((a_u8, a_u16, a_u32, a_u64, a_u128), (1, 2, 3, 4, 5));
        assert_eq!((a_i8, a_i16, a_i32, a_i64, a_i128), (1, 2, 3, 4, 5));
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        let mut expected_left = 4;
        assert_eq!(reader.bytes_left(), expected_left);

        // Each step finishes off exactly one more byte.
        for step in 0..4 {
            match step {
                0 => {
                    reader.read_small(3).unwrap();
                }
                1 | 2 => {
                    reader.read_byte().unwrap();
                }
                _ => {
                    reader.read_bit().unwrap();
                }
            }
            expected_left -= 1;
            assert_eq!(reader.bytes_left(), expected_left);
        }
    }

    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let buff = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&buff);

        // Hide the last two bytes from reads.
        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        for expected in [2, 3] {
            assert_eq!(reader.read_byte(), Ok(expected));
        }
        assert_eq!(reader.read_byte(), Err(not_enough(1)));

        // Dropping the reservation makes the tail readable again.
        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        for expected in [4, 5] {
            assert_eq!(reader.read_byte(), Ok(expected));
        }
    }

    #[test]
    fn test_slice_start() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        // Half a byte does not extend the slice; completing the byte does.
        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);
        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20]);

        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20, 30, 40]);
    }

    #[test]
    fn test_slice_start_with_marker() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        // Moving the marker to the current position empties the slice.
        reader.set_marker();
        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[20, 30]);
    }
}