1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6 buffer: &'a [u8],
7 bit_pos: usize,
8 last_read_byte: Option<u8>,
9 offset_end: usize,
10 crypto: Option<Box<dyn CryptoStream>>,
11 marker: Option<usize>,
12}
13
14impl<'a> BitStreamReader<'a> {
15 pub fn new(buffer: &'a [u8]) -> Self {
20 Self {
21 buffer,
22 bit_pos: 0,
23 crypto: None,
24 offset_end: 0,
25 last_read_byte: None,
26 marker: None,
27 }
28 }
29
30 pub fn slice(&self, from_start: bool) -> &[u8] {
32 let start = if from_start { 0 } else { self.byte_pos() };
33
34 &self.buffer[start..self.buffer.len() - self.offset_end]
35 }
36
37 pub fn set_marker(&mut self) {
39 self.marker = Some(self.byte_pos());
40 }
41
42 pub fn reset_marker(&mut self) {
44 self.marker = None;
45 }
46
47 pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
50 let start = self.marker.unwrap_or(0);
51 let end = to.unwrap_or(self.byte_pos());
52
53 if let Some(crypto) = self.crypto.as_ref() {
54 return &crypto.get_cached(false)[start..end];
55 }
56
57 &self.buffer[start..end]
58 }
59
60 pub fn slice_end(&mut self) -> &[u8] {
64 let slice = &self.buffer[self.buffer.len() - self.offset_end..];
65
66 if let Some(crypto) = self.crypto.as_mut() {
67 crypto.apply_keystream(slice)
68 } else {
69 slice
70 }
71 }
72
73 pub fn set_crypto(&mut self, mut crypto: Option<Box<dyn CryptoStream>>) {
75 if let Some(existing) = self.crypto.as_ref()
76 && let Some(new_crypto) = crypto.as_mut()
77 {
78 new_crypto.replace(existing);
79 self.crypto = crypto;
80 } else {
81 self.crypto = crypto;
82 }
83 }
84
85 pub fn reset_crypto(&mut self) {
87 self.crypto = None;
88 }
89
90 pub fn set_offset_end(&mut self, len: usize) {
92 self.offset_end = len;
93 }
94
95 pub fn byte_pos(&self) -> usize {
97 self.bit_pos / 8
98 }
99
100 fn current_byte(&mut self) -> u8 {
102 if let Some(b) = self.last_read_byte {
103 b
104 } else {
105 let mut b = self.buffer[self.byte_pos()];
106 if let Some(crypto) = self.crypto.as_mut() {
107 b = crypto.apply_keystream_byte(b);
108 }
109
110 self.last_read_byte = Some(b);
111 b
112 }
113 }
114
115 pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
117 self.read_small(1).map(|v| v != 0)
118 }
119
120 pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
122 assert!(bits > 0 && bits < 8);
123
124 let mut result: u8 = 0;
125 let mut shift = 0;
126
127 while bits > 0 {
128 if self.byte_pos() >= self.buffer.len() - self.offset_end {
129 return Err(DeserializationError::NotEnoughBytes(1));
130 }
131
132 let bit_offset = self.bit_pos % 8;
134
135 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
137
138 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
142 let byte_val = self.current_byte();
143
144 let val = (byte_val & mask) >> bit_offset;
148
149 result |= val << shift;
152
153 bits -= bits_in_current_byte;
155
156 shift += bits_in_current_byte;
158
159 self.bit_pos += bits_in_current_byte as usize;
160
161 if self.bit_pos % 8 == 0 {
163 self.last_read_byte = None;
164 }
165 }
166
167 Ok(result)
168 }
169
170 pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
172 self.align_byte();
173
174 if self.byte_pos() >= self.buffer.len() - self.offset_end {
175 return Err(DeserializationError::NotEnoughBytes(1));
176 }
177
178 let byte = self.current_byte();
179 self.bit_pos += 8;
180 self.last_read_byte = None;
181
182 Ok(byte)
183 }
184
185 pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
187 self.align_byte();
188
189 let start = self.byte_pos();
190 if start + count > self.buffer.len() - self.offset_end {
191 return Err(DeserializationError::NotEnoughBytes(
192 start + count - self.buffer.len(),
193 ));
194 }
195
196 self.bit_pos += 8 * count;
197 self.last_read_byte = None;
198
199 let slice = &self.buffer[start..start + count];
200 if let Some(crypto) = self.crypto.as_mut() {
201 Ok(crypto.apply_keystream(slice))
202 } else {
203 Ok(slice)
204 }
205 }
206
207 pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
210 self.align_byte();
211 let mut num: u128 = 0;
212 let mut multiplier: u128 = 1;
213
214 loop {
215 let byte = self.read_byte()?; num += ((byte & 127) as u128) * multiplier;
217
218 if (byte & 1 << 7) == 0 {
220 break;
221 }
222
223 multiplier *= 128;
224 }
225
226 Ok(num)
227 }
228
229 pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
231 &mut self,
232 ) -> Result<T, DeserializationError> {
233 let data = self.read_bytes(S)?;
234 Ok(FixedInt::deserialize(data))
235 }
236
237 pub fn align_byte(&mut self) {
239 let rem = self.bit_pos % 8;
240 if rem != 0 {
241 self.bit_pos += 8 - rem;
242 self.last_read_byte = None;
243 }
244 }
245
246 pub fn bytes_left(&self) -> usize {
248 let left = self.buffer.len() - self.byte_pos() - self.offset_end;
249 if self.bit_pos % 8 != 0 {
250 left - 1 } else {
252 left
253 }
254 }
255
256 pub fn reset(&mut self) {
258 self.bit_pos = 0;
259 }
260}
261
#[cfg(test)]
mod tests {
    use crate::{DeserializationError, bitstream::CryptoStream};

    use super::BitStreamReader;

    /// Fake cipher for tests.
    ///
    /// It "decrypts" by adding 1 to each byte and keeps every plaintext byte
    /// it has produced in `plain`.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            let decrypted = b + 1;
            self.plain.push(decrypted);
            decrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let first_new = self.plain.len();
            self.plain.extend(slice.iter().map(|byte| byte + 1));
            &self.plain[first_new..]
        }

        fn get_cached(&self, _original: bool) -> &[u8] {
            &self.plain
        }

        fn replace(&mut self, other: &Box<dyn CryptoStream>) {
            self.plain = other.get_cached(true).to_vec();
        }
    }

    /// Reads one bit per entry of `expected` and asserts that each matches.
    fn assert_bits(reader: &mut BitStreamReader<'_>, expected: &[bool]) {
        for &bit in expected {
            assert_eq!(reader.read_bit(), Ok(bit));
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf: Vec<u8> = (1..=10).collect();
        let mut reader = BitStreamReader::new(&buf);
        reader.crypto = Some(Box::new(PlusOneDecrypter { plain: Vec::new() }));

        for expected in 2u8..=4 {
            assert_eq!(reader.read_byte(), Ok(expected));
        }
        // The fourth byte, 4, decrypts to 5 = 0b101 and is read LSB first.
        assert_bits(&mut reader, &[true, false, true]);
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // 0b10101100, least-significant bit first.
        assert_bits(
            &mut reader,
            &[false, false, true, true, false, true, false, true],
        );
    }

    #[test]
    fn test_read_small() {
        let buf: [u8; 2] = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        for (bits, expected) in [(3, 0b100), (4, 0b0101), (1, 0b1), (4, 0b0010)] {
            assert_eq!(reader.read_small(bits), Ok(expected));
        }
    }

    #[test]
    fn test_read_cross_byte() {
        let buf: [u8; 2] = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(7), Ok(0b00101100));
        // Last bit of byte 0, then the two lowest bits of byte 1.
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf: [u8; 2] = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        // The rest of byte 0 is skipped.
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_read_bytes() {
        let buf: [u8; 4] = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap();
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf: [u8; 2] = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        reader.align_byte();
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf: [u8; 1] = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(reader.read_bit(), Err(DeserializationError::NotEnoughBytes(1)));
        assert_eq!(reader.read_byte(), Err(DeserializationError::NotEnoughBytes(1)));
        assert_eq!(reader.read_bytes(2), Err(DeserializationError::NotEnoughBytes(2)));
    }

    #[test]
    fn test_multiple_operations() {
        let buf: [u8; 4] = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_small(3), Ok(0b101));
        assert_eq!(reader.read_byte(), Ok(0b11001100));
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(reader.read_bit(), Err(DeserializationError::NotEnoughBytes(1)));
    }

    #[test]
    fn test_read_dyn_int() {
        let buf: Vec<u8> = vec![0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(stream.read_byte(), Ok(0));
        for expected in [127, 128, 268435455] {
            assert_eq!(stream.read_dyn_int(), Ok(expected));
        }
        assert_eq!(
            stream.read_dyn_int(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let buf: Vec<u8> = vec![
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];
        let mut stream = BitStreamReader::new(&buf);

        let a: u8 = stream.read_fixed_int().unwrap();
        let b: i8 = stream.read_fixed_int().unwrap();
        assert_eq!((a, b), (1, 1));

        let a: u16 = stream.read_fixed_int().unwrap();
        let b: i16 = stream.read_fixed_int().unwrap();
        assert_eq!((a, b), (2, 2));

        let a: u32 = stream.read_fixed_int().unwrap();
        let b: i32 = stream.read_fixed_int().unwrap();
        assert_eq!((a, b), (3, 3));

        let a: u64 = stream.read_fixed_int().unwrap();
        let b: i64 = stream.read_fixed_int().unwrap();
        assert_eq!((a, b), (4, 4));

        let a: u128 = stream.read_fixed_int().unwrap();
        let b: i128 = stream.read_fixed_int().unwrap();
        assert_eq!((a, b), (5, 5));
    }

    #[test]
    fn test_bytes_left() {
        let buf: [u8; 4] = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        reader.read_small(3).unwrap();
        // A partially consumed byte no longer counts.
        assert_eq!(reader.bytes_left(), 3);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 2);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 1);
        reader.read_bit().unwrap();
        assert_eq!(reader.bytes_left(), 0);
    }

    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let buff = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&buff);

        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Err(DeserializationError::NotEnoughBytes(1)));

        // Lifting the reservation makes the tail readable again.
        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        assert_eq!(reader.read_byte(), Ok(4));
        assert_eq!(reader.read_byte(), Ok(5));
    }

    #[test]
    fn test_slice_start() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        // Half a byte does not extend the slice...
        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        // ...until that byte is finished.
        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20]);

        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20, 30, 40]);
    }

    #[test]
    fn test_slice_start_with_marker() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.set_marker();
        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[20, 30]);
    }
}