1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6 buffer: &'a [u8],
7 bit_pos: usize,
8 last_read_byte: Option<u8>,
9 offset_end: usize,
10 crypto: Option<Box<dyn CryptoStream>>,
11 marker: Option<usize>,
12}
13
14impl<'a> BitStreamReader<'a> {
15 pub fn new(buffer: &'a [u8]) -> Self {
20 Self {
21 buffer,
22 bit_pos: 0,
23 crypto: None,
24 offset_end: 0,
25 last_read_byte: None,
26 marker: None,
27 }
28 }
29
30 pub fn slice(&self, from_start: bool) -> &[u8] {
32 let start = if from_start { 0 } else { self.byte_pos() };
33
34 &self.buffer[start..self.buffer.len() - self.offset_end]
35 }
36
37 pub fn set_marker(&mut self) {
39 self.marker = Some(self.byte_pos());
40 }
41
42 pub fn reset_marker(&mut self) {
44 self.marker = None;
45 }
46
47 pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
50 let start = self.marker.unwrap_or(0);
51 let end = to.unwrap_or(self.byte_pos());
52
53 if let Some(crypto) = self.crypto.as_ref() {
54 return &crypto.get_cached(false)[start..end];
55 }
56
57 &self.buffer[start..end]
58 }
59
60 pub fn slice_end(&mut self) -> &[u8] {
64 let slice = &self.buffer[self.buffer.len() - self.offset_end..];
65
66 if let Some(crypto) = self.crypto.as_mut() {
67 crypto.apply_keystream(slice)
68 } else {
69 slice
70 }
71 }
72
73 pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
75 self.crypto = crypto;
76 }
77
78 pub fn set_offset_end(&mut self, len: usize) {
80 self.offset_end = len;
81 }
82
83 pub fn byte_pos(&self) -> usize {
85 self.bit_pos / 8
86 }
87
88 fn current_byte(&mut self) -> u8 {
90 if let Some(b) = self.last_read_byte {
91 b
92 } else {
93 let mut b = self.buffer[self.byte_pos()];
94 if let Some(crypto) = self.crypto.as_mut() {
95 b = crypto.apply_keystream_byte(b);
96 }
97
98 self.last_read_byte = Some(b);
99 b
100 }
101 }
102
103 pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
105 self.read_small(1).map(|v| v != 0)
106 }
107
108 pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
110 assert!(bits > 0 && bits < 8);
111
112 let mut result: u8 = 0;
113 let mut shift = 0;
114
115 while bits > 0 {
116 if self.byte_pos() >= self.buffer.len() - self.offset_end {
117 return Err(DeserializationError::NotEnoughBytes(1));
118 }
119
120 let bit_offset = self.bit_pos % 8;
122
123 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
125
126 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
130 let byte_val = self.current_byte();
131
132 let val = (byte_val & mask) >> bit_offset;
136
137 result |= val << shift;
140
141 bits -= bits_in_current_byte;
143
144 shift += bits_in_current_byte;
146
147 self.bit_pos += bits_in_current_byte as usize;
148
149 if self.bit_pos % 8 == 0 {
151 self.last_read_byte = None;
152 }
153 }
154
155 Ok(result)
156 }
157
158 pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
160 self.align_byte();
161
162 if self.byte_pos() >= self.buffer.len() - self.offset_end {
163 return Err(DeserializationError::NotEnoughBytes(1));
164 }
165
166 let byte = self.current_byte();
167 self.bit_pos += 8;
168 self.last_read_byte = None;
169
170 Ok(byte)
171 }
172
173 pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
175 self.align_byte();
176
177 let start = self.byte_pos();
178 if start + count > self.buffer.len() - self.offset_end {
179 return Err(DeserializationError::NotEnoughBytes(
180 start + count - self.buffer.len(),
181 ));
182 }
183
184 self.bit_pos += 8 * count;
185 self.last_read_byte = None;
186
187 let slice = &self.buffer[start..start + count];
188 if let Some(crypto) = self.crypto.as_mut() {
189 Ok(crypto.apply_keystream(slice))
190 } else {
191 Ok(slice)
192 }
193 }
194
195 pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
198 self.align_byte();
199 let mut num: u128 = 0;
200 let mut multiplier: u128 = 1;
201
202 loop {
203 let byte = self.read_byte()?; num += ((byte & 127) as u128) * multiplier;
205
206 if (byte & 1 << 7) == 0 {
208 break;
209 }
210
211 multiplier *= 128;
212 }
213
214 Ok(num)
215 }
216
217 pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
219 &mut self,
220 ) -> Result<T, DeserializationError> {
221 let data = self.read_bytes(S)?;
222 Ok(FixedInt::deserialize(data))
223 }
224
225 pub fn align_byte(&mut self) {
227 let rem = self.bit_pos % 8;
228 if rem != 0 {
229 self.bit_pos += 8 - rem;
230 self.last_read_byte = None;
231 }
232 }
233
234 pub fn bytes_left(&self) -> usize {
236 let left = self.buffer.len() - self.byte_pos() - self.offset_end;
237 if self.bit_pos % 8 != 0 {
238 left - 1 } else {
240 left
241 }
242 }
243
244 pub fn reset(&mut self) {
246 self.bit_pos = 0;
247 }
248}
249
#[cfg(test)]
mod tests {
    use crate::{DeserializationError, bitstream::CryptoStream};

    use super::BitStreamReader;

    /// Fake cipher: "decrypts" by adding one to every byte and keeps all of its
    /// output in `plain`, which `get_cached` exposes.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            let out = b.wrapping_add(1);
            self.plain.push(out);
            out
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            self.plain.extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.plain[self.plain.len() - slice.len()..]
        }

        // The fake keeps no separate copy of the input, so the flag is ignored.
        fn get_cached(&self, _original: bool) -> &[u8] {
            &self.plain
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&buf);
        reader.set_crypto(Some(Box::new(PlusOneDecrypter { plain: Vec::new() })));

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Ok(4));
        // Byte 4 decrypts to 5 = 0b101; its bits are read LSB first.
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        // read_bytes realigns, skipping the rest of the partially read byte.
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // 0b10101100 read from the least-significant bit upwards.
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(3), Ok(0b100));
        assert_eq!(reader.read_small(4), Ok(0b0101));
        assert_eq!(reader.read_small(1), Ok(0b1));
        assert_eq!(reader.read_small(4), Ok(0b0010));
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(7), Ok(0b00101100));
        // Last bit of byte 0, then the two lowest bits of byte 1.
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        // read_byte discards the remaining bits of the partial byte.
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap();
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        reader.align_byte();
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_small(3), Ok(0b101));
        assert_eq!(reader.read_byte(), Ok(0b11001100));
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = vec![0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(Ok(0), stream.read_byte());
        assert_eq!(Ok(127), stream.read_dyn_int());
        assert_eq!(Ok(128), stream.read_dyn_int());
        assert_eq!(Ok(268435455), stream.read_dyn_int());
        assert_eq!(
            Err(DeserializationError::NotEnoughBytes(1)),
            stream.read_dyn_int()
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let buf = vec![
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&buf);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(v1, 1);
        assert_eq!(v2, 1);
        assert_eq!(v3, 2);
        assert_eq!(v4, 2);
        assert_eq!(v5, 3);
        assert_eq!(v6, 3);
        assert_eq!(v7, 4);
        assert_eq!(v8, 4);
        assert_eq!(v9, 5);
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        // A partially read byte no longer counts as a whole byte left.
        reader.read_small(3).unwrap();
        assert_eq!(reader.bytes_left(), 3);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 2);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 1);
        reader.read_bit().unwrap();
        assert_eq!(reader.bytes_left(), 0);
    }

    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let buff = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&buff);

        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );

        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        assert_eq!(reader.read_byte(), Ok(4));
        assert_eq!(reader.read_byte(), Ok(5));
    }

    #[test]
    fn test_slice_start() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        // Half a byte read: the current byte is not yet included.
        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20]);

        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20, 30, 40]);
    }

    #[test]
    fn test_slice_start_with_marker() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);
        reader.set_marker();
        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[20, 30]);
    }
}