1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6 buffer: &'a [u8],
7 bit_pos: usize,
8 last_read_byte: Option<u8>,
9 offset_end: usize,
10 crypto: Option<Box<dyn CryptoStream>>,
11 marker: Option<usize>,
12}
13
14impl<'a> BitStreamReader<'a> {
15 pub fn new(buffer: &'a [u8]) -> Self {
20 Self {
21 buffer,
22 bit_pos: 0,
23 crypto: None,
24 offset_end: 0,
25 last_read_byte: None,
26 marker: None,
27 }
28 }
29
30 pub fn slice(&self, from_start: bool) -> &[u8] {
32 let start = if from_start { 0 } else { self.byte_pos() };
33
34 &self.buffer[start..self.buffer.len() - self.offset_end]
35 }
36
37 pub fn set_marker(&mut self) {
39 self.marker = Some(self.byte_pos());
40 }
41
42 pub fn reset_marker(&mut self) {
44 self.marker = None;
45 }
46
47 pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
50 let start = self.marker.unwrap_or(0);
51 let end = to.unwrap_or(self.byte_pos());
52
53 if let Some(crypto) = self.crypto.as_ref() {
54 return &crypto.get_cached(false)[start..end];
55 }
56
57 &self.buffer[start..end]
58 }
59
60 pub fn slice_end(&mut self) -> &[u8] {
64 let slice = &self.buffer[self.buffer.len() - self.offset_end..];
65
66 if let Some(crypto) = self.crypto.as_mut() {
67 crypto.apply_keystream(slice)
68 } else {
69 slice
70 }
71 }
72
73 pub fn slice_start(&self) -> &[u8] {
75 &self.buffer[0..self.byte_pos()]
76 }
77
78 pub fn set_crypto(&mut self, mut crypto: Option<Box<dyn CryptoStream>>) {
80 if let Some(new) = crypto.as_mut() {
81 if let Some(existing) = self.crypto.as_ref() {
82 new.replace(existing);
83 } else {
84 new.set_cached(self.slice_start());
86 }
87 }
88
89 self.crypto = crypto;
90 }
91
92 pub fn reset_crypto(&mut self) {
94 self.crypto = None;
95 }
96
97 pub fn set_offset_end(&mut self, len: usize) {
99 self.offset_end = len;
100 }
101
102 pub fn byte_pos(&self) -> usize {
104 self.bit_pos / 8
105 }
106
107 fn current_byte(&mut self) -> u8 {
109 if let Some(b) = self.last_read_byte {
110 b
111 } else {
112 let mut b = self.buffer[self.byte_pos()];
113 if let Some(crypto) = self.crypto.as_mut() {
114 b = crypto.apply_keystream_byte(b);
115 }
116
117 self.last_read_byte = Some(b);
118 b
119 }
120 }
121
122 pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
124 self.read_small(1).map(|v| v != 0)
125 }
126
127 pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
129 assert!(bits > 0 && bits < 8);
130
131 let mut result: u8 = 0;
132 let mut shift = 0;
133
134 while bits > 0 {
135 if self.byte_pos() >= self.buffer.len() - self.offset_end {
136 return Err(DeserializationError::NotEnoughBytes(1));
137 }
138
139 let bit_offset = self.bit_pos % 8;
141
142 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
144
145 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
149 let byte_val = self.current_byte();
150
151 let val = (byte_val & mask) >> bit_offset;
155
156 result |= val << shift;
159
160 bits -= bits_in_current_byte;
162
163 shift += bits_in_current_byte;
165
166 self.bit_pos += bits_in_current_byte as usize;
167
168 if self.bit_pos % 8 == 0 {
170 self.last_read_byte = None;
171 }
172 }
173
174 Ok(result)
175 }
176
177 pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
179 self.align_byte();
180
181 if self.byte_pos() >= self.buffer.len() - self.offset_end {
182 return Err(DeserializationError::NotEnoughBytes(1));
183 }
184
185 let byte = self.current_byte();
186 self.bit_pos += 8;
187 self.last_read_byte = None;
188
189 Ok(byte)
190 }
191
192 pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
194 self.align_byte();
195
196 let start = self.byte_pos();
197 if start + count > self.buffer.len() - self.offset_end {
198 return Err(DeserializationError::NotEnoughBytes(
199 (start + count - self.buffer.len()) as u64,
200 ));
201 }
202
203 self.bit_pos += 8 * count;
204 self.last_read_byte = None;
205
206 let slice = &self.buffer[start..start + count];
207 if let Some(crypto) = self.crypto.as_mut() {
208 Ok(crypto.apply_keystream(slice))
209 } else {
210 Ok(slice)
211 }
212 }
213
214 pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
217 self.align_byte();
218 let mut num: u128 = 0;
219 let mut multiplier: u128 = 1;
220
221 loop {
222 let byte = self.read_byte()?; num += ((byte & 127) as u128) * multiplier;
224
225 if (byte & 1 << 7) == 0 {
227 break;
228 }
229
230 multiplier *= 128;
231 }
232
233 Ok(num)
234 }
235
236 pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
238 &mut self,
239 ) -> Result<T, DeserializationError> {
240 let data = self.read_bytes(S)?;
241 Ok(FixedInt::deserialize(data))
242 }
243
244 pub fn align_byte(&mut self) {
246 let rem = self.bit_pos % 8;
247 if rem != 0 {
248 self.bit_pos += 8 - rem;
249 self.last_read_byte = None;
250 }
251 }
252
253 pub fn bytes_left(&self) -> usize {
255 let left = self.buffer.len() - self.byte_pos() - self.offset_end;
256 if self.bit_pos % 8 != 0 {
257 left - 1 } else {
259 left
260 }
261 }
262
263 pub fn reset(&mut self) {
265 self.bit_pos = 0;
266 }
267}
268
#[cfg(test)]
mod tests {
    use crate::{DeserializationError, bitstream::CryptoStream};

    use super::BitStreamReader;

    /// Test "cipher" that adds 1 to every byte and records the transformed
    /// output so it can be served back from `get_cached`.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // Wrapping keeps the test cipher total for 0xFF inputs.
            let plain = b.wrapping_add(1);
            self.plain.push(plain);
            plain
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            self.plain.extend(slice.iter().map(|s| s.wrapping_add(1)));
            &self.plain[self.plain.len() - slice.len()..]
        }

        fn get_cached(&self, _original: bool) -> &[u8] {
            &self.plain
        }

        fn replace(&mut self, other: &Box<dyn CryptoStream>) {
            self.plain = other.get_cached(true).to_vec();
        }

        fn set_cached(&mut self, data: &[u8]) {
            self.plain = data.to_vec();
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&buf);
        reader.crypto = Some(Box::new(PlusOneDecrypter { plain: Vec::new() }));

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Ok(4));
        // Byte 4 decrypts to 5 = 0b101, read LSB first.
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        // read_bytes aligns past the partially consumed byte first.
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // 0b10101100 read from the least-significant bit upwards.
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(3), Ok(0b100));
        assert_eq!(reader.read_small(4), Ok(0b0101));
        assert_eq!(reader.read_small(1), Ok(0b1));
        assert_eq!(reader.read_small(4), Ok(0b0010));
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(7), Ok(0b00101100));
        // One bit from the first byte, two from the second.
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        // read_byte skips the rest of the partially read byte.
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap();
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        reader.align_byte();
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_small(3), Ok(0b101));
        assert_eq!(reader.read_byte(), Ok(0b11001100));
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = [0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(Ok(0), stream.read_byte());
        assert_eq!(Ok(127), stream.read_dyn_int());
        assert_eq!(Ok(128), stream.read_dyn_int());
        assert_eq!(Ok(268435455), stream.read_dyn_int());
        assert_eq!(
            Err(DeserializationError::NotEnoughBytes(1)),
            stream.read_dyn_int()
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let buf = [
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&buf);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(v1, 1);
        assert_eq!(v2, 1);
        assert_eq!(v3, 2);
        assert_eq!(v4, 2);
        assert_eq!(v5, 3);
        assert_eq!(v6, 3);
        assert_eq!(v7, 4);
        assert_eq!(v8, 4);
        assert_eq!(v9, 5);
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        // A partially consumed byte no longer counts as left.
        reader.read_small(3).unwrap();
        assert_eq!(reader.bytes_left(), 3);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 2);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 1);
        reader.read_bit().unwrap();
        assert_eq!(reader.bytes_left(), 0);
    }

    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let buff = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&buff);

        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );

        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        assert_eq!(reader.read_byte(), Ok(4));
        assert_eq!(reader.read_byte(), Ok(5));
    }

    #[test]
    fn test_slice_marker_without_marker() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        // Half a byte read: the partial byte is not included yet.
        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20]);

        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20, 30, 40]);
    }

    #[test]
    fn test_slice_marker_with_marker() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);
        reader.set_marker();
        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[20, 30]);
    }

    #[test]
    fn test_can_read_dynint_0() {
        let buf = [0, 1];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(stream.read_dyn_int(), Ok(0));
        assert_eq!(stream.read_byte(), Ok(1));
    }
}