1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6 buffer: &'a [u8],
7 bit_pos: usize,
8 last_read_byte: Option<u8>,
9 offset_end: usize,
10 crypto: Option<Box<dyn CryptoStream>>,
11 marker: Option<usize>,
12}
13
14impl<'a> BitStreamReader<'a> {
15 pub fn new(buffer: &'a [u8]) -> Self {
20 Self {
21 buffer,
22 bit_pos: 0,
23 crypto: None,
24 offset_end: 0,
25 last_read_byte: None,
26 marker: None,
27 }
28 }
29
30 pub fn slice(&self, from_start: bool) -> &[u8] {
32 let start = if from_start { 0 } else { self.byte_pos() };
33
34 &self.buffer[start..self.buffer.len() - self.offset_end]
35 }
36
37 pub fn set_marker(&mut self) {
39 self.marker = Some(self.byte_pos());
40 }
41
42 pub fn reset_marker(&mut self) {
44 self.marker = None;
45 }
46
47 pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
50 let start = self.marker.unwrap_or(0);
51 let end = to.unwrap_or(self.byte_pos());
52
53 if let Some(crypto) = self.crypto.as_ref() {
54 return &crypto.get_cached(false)[start..end];
55 }
56
57 &self.buffer[start..end]
58 }
59
60 pub fn slice_end(&mut self) -> &[u8] {
64 let slice = &self.buffer[self.buffer.len() - self.offset_end..];
65
66 if let Some(crypto) = self.crypto.as_mut() {
67 crypto.apply_keystream(slice)
68 } else {
69 slice
70 }
71 }
72
73 pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
75 self.crypto = crypto;
76 }
77
78 pub fn reset_crypto(&mut self) {
80 self.crypto = None;
81 }
82
83 pub fn set_offset_end(&mut self, len: usize) {
85 self.offset_end = len;
86 }
87
88 pub fn byte_pos(&self) -> usize {
90 self.bit_pos / 8
91 }
92
93 fn current_byte(&mut self) -> u8 {
95 if let Some(b) = self.last_read_byte {
96 b
97 } else {
98 let mut b = self.buffer[self.byte_pos()];
99 if let Some(crypto) = self.crypto.as_mut() {
100 b = crypto.apply_keystream_byte(b);
101 }
102
103 self.last_read_byte = Some(b);
104 b
105 }
106 }
107
108 pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
110 self.read_small(1).map(|v| v != 0)
111 }
112
113 pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
115 assert!(bits > 0 && bits < 8);
116
117 let mut result: u8 = 0;
118 let mut shift = 0;
119
120 while bits > 0 {
121 if self.byte_pos() >= self.buffer.len() - self.offset_end {
122 return Err(DeserializationError::NotEnoughBytes(1));
123 }
124
125 let bit_offset = self.bit_pos % 8;
127
128 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
130
131 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
135 let byte_val = self.current_byte();
136
137 let val = (byte_val & mask) >> bit_offset;
141
142 result |= val << shift;
145
146 bits -= bits_in_current_byte;
148
149 shift += bits_in_current_byte;
151
152 self.bit_pos += bits_in_current_byte as usize;
153
154 if self.bit_pos % 8 == 0 {
156 self.last_read_byte = None;
157 }
158 }
159
160 Ok(result)
161 }
162
163 pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
165 self.align_byte();
166
167 if self.byte_pos() >= self.buffer.len() - self.offset_end {
168 return Err(DeserializationError::NotEnoughBytes(1));
169 }
170
171 let byte = self.current_byte();
172 self.bit_pos += 8;
173 self.last_read_byte = None;
174
175 Ok(byte)
176 }
177
178 pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
180 self.align_byte();
181
182 let start = self.byte_pos();
183 if start + count > self.buffer.len() - self.offset_end {
184 return Err(DeserializationError::NotEnoughBytes(
185 start + count - self.buffer.len(),
186 ));
187 }
188
189 self.bit_pos += 8 * count;
190 self.last_read_byte = None;
191
192 let slice = &self.buffer[start..start + count];
193 if let Some(crypto) = self.crypto.as_mut() {
194 Ok(crypto.apply_keystream(slice))
195 } else {
196 Ok(slice)
197 }
198 }
199
200 pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
203 self.align_byte();
204 let mut num: u128 = 0;
205 let mut multiplier: u128 = 1;
206
207 loop {
208 let byte = self.read_byte()?; num += ((byte & 127) as u128) * multiplier;
210
211 if (byte & 1 << 7) == 0 {
213 break;
214 }
215
216 multiplier *= 128;
217 }
218
219 Ok(num)
220 }
221
222 pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
224 &mut self,
225 ) -> Result<T, DeserializationError> {
226 let data = self.read_bytes(S)?;
227 Ok(FixedInt::deserialize(data))
228 }
229
230 pub fn align_byte(&mut self) {
232 let rem = self.bit_pos % 8;
233 if rem != 0 {
234 self.bit_pos += 8 - rem;
235 self.last_read_byte = None;
236 }
237 }
238
239 pub fn bytes_left(&self) -> usize {
241 let left = self.buffer.len() - self.byte_pos() - self.offset_end;
242 if self.bit_pos % 8 != 0 {
243 left - 1 } else {
245 left
246 }
247 }
248
249 pub fn reset(&mut self) {
251 self.bit_pos = 0;
252 }
253}
254
#[cfg(test)]
mod tests {
    use crate::{DeserializationError, bitstream::CryptoStream};

    use super::BitStreamReader;

    /// Toy keystream for tests: "decrypts" each byte by adding 1 (wrapping,
    /// so 0xFF cannot panic) and keeps every produced byte in `plain`.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            let decrypted = b.wrapping_add(1);
            self.plain.push(decrypted);
            decrypted
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.plain.len();
            self.plain.extend(slice.iter().map(|b| b.wrapping_add(1)));
            &self.plain[start..]
        }

        // This double keeps a single (decrypted) copy, so the flag is unused.
        fn get_cached(&self, _original: bool) -> &[u8] {
            &self.plain
        }

        fn replace(&mut self, other: Box<dyn CryptoStream>) {
            self.plain = other.get_cached(true).to_vec();
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&buf);
        reader.set_crypto(Some(Box::new(PlusOneDecrypter { plain: Vec::new() })));

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Ok(4));
        // Byte 4 decrypts to 5 = 0b101, read LSB first.
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // 0b10101100 read least-significant bit first.
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(3), Ok(0b100));
        assert_eq!(reader.read_small(4), Ok(0b0101));
        assert_eq!(reader.read_small(1), Ok(0b1));
        assert_eq!(reader.read_small(4), Ok(0b0010));
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(7), Ok(0b00101100));
        // Last bit of byte 0 (1) followed by the low two bits of byte 1 (01).
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        // read_byte discards the rest of the partially read byte.
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap();
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        reader.align_byte();
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_small(3), Ok(0b101));
        assert_eq!(reader.read_byte(), Ok(0b11001100));
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = [0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(Ok(0), stream.read_byte());
        assert_eq!(Ok(127), stream.read_dyn_int());
        assert_eq!(Ok(128), stream.read_dyn_int());
        assert_eq!(Ok(268435455), stream.read_dyn_int());
        assert_eq!(
            Err(DeserializationError::NotEnoughBytes(1)),
            stream.read_dyn_int()
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let buf = [
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&buf);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(v1, 1);
        assert_eq!(v2, 1);
        assert_eq!(v3, 2);
        assert_eq!(v4, 2);
        assert_eq!(v5, 3);
        assert_eq!(v6, 3);
        assert_eq!(v7, 4);
        assert_eq!(v8, 4);
        assert_eq!(v9, 5);
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        // A partially read byte no longer counts as "left".
        reader.read_small(3).unwrap();
        assert_eq!(reader.bytes_left(), 3);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 2);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 1);
        reader.read_bit().unwrap();
        assert_eq!(reader.bytes_left(), 0);
    }

    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let buff = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&buff);

        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );

        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        assert_eq!(reader.read_byte(), Ok(4));
        assert_eq!(reader.read_byte(), Ok(5));
    }

    #[test]
    fn test_slice_start() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        // Half a byte read: the partial byte is not included yet.
        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20]);

        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20, 30, 40]);
    }

    #[test]
    fn test_slice_start_with_marker() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);
        reader.set_marker();
        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[20, 30]);
    }
}