1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6 buffer: &'a [u8],
7 bit_pos: usize,
8 last_read_byte: Option<u8>,
9 offset_end: usize,
10 crypto: Option<Box<dyn CryptoStream>>,
11 marker: Option<usize>,
12}
13
14impl<'a> BitStreamReader<'a> {
15 pub fn new(buffer: &'a [u8]) -> Self {
20 Self {
21 buffer,
22 bit_pos: 0,
23 crypto: None,
24 offset_end: 0,
25 last_read_byte: None,
26 marker: None,
27 }
28 }
29
30 pub fn slice(&self, from_start: bool) -> &[u8] {
32 let start = if from_start { 0 } else { self.byte_pos() };
33
34 &self.buffer[start..self.buffer.len() - self.offset_end]
35 }
36
37 pub fn set_marker(&mut self) {
39 self.marker = Some(self.byte_pos());
40 }
41
42 pub fn reset_marker(&mut self) {
44 self.marker = None;
45 }
46
47 pub fn slice_marker(&self, to: Option<usize>) -> &[u8] {
50 let start = self.marker.unwrap_or(0);
51 let end = to.unwrap_or(self.byte_pos());
52
53 if let Some(crypto) = self.crypto.as_ref() {
54 return &crypto.get_plaintext()[start..end];
55 }
56
57 &self.buffer[start..end]
58 }
59
60 pub fn slice_end(&self) -> &[u8] {
62 &self.buffer[self.buffer.len() - self.offset_end..]
63 }
64
65 pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
67 self.crypto = crypto;
68 }
69
70 pub fn set_offset_end(&mut self, len: usize) {
72 self.offset_end = len;
73 }
74
75 pub fn byte_pos(&self) -> usize {
77 self.bit_pos / 8
78 }
79
80 fn current_byte(&mut self) -> u8 {
82 if let Some(b) = self.last_read_byte {
83 b
84 } else {
85 let mut b = self.buffer[self.byte_pos()];
86 if let Some(crypto) = self.crypto.as_mut() {
87 b = crypto.apply_keystream_byte(b);
88 }
89
90 self.last_read_byte = Some(b);
91 b
92 }
93 }
94
95 pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
97 self.read_small(1).map(|v| v != 0)
98 }
99
100 pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
102 assert!(bits > 0 && bits < 8);
103
104 let mut result: u8 = 0;
105 let mut shift = 0;
106
107 while bits > 0 {
108 if self.byte_pos() >= self.buffer.len() - self.offset_end {
109 return Err(DeserializationError::NotEnoughBytes(1));
110 }
111
112 let bit_offset = self.bit_pos % 8;
114
115 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
117
118 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
122 let byte_val = self.current_byte();
123
124 let val = (byte_val & mask) >> bit_offset;
128
129 result |= val << shift;
132
133 bits -= bits_in_current_byte;
135
136 shift += bits_in_current_byte;
138
139 self.bit_pos += bits_in_current_byte as usize;
140
141 if self.bit_pos % 8 == 0 {
143 self.last_read_byte = None;
144 }
145 }
146
147 Ok(result)
148 }
149
150 pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
152 self.align_byte();
153
154 if self.byte_pos() >= self.buffer.len() - self.offset_end {
155 return Err(DeserializationError::NotEnoughBytes(1));
156 }
157
158 let byte = self.current_byte();
159 self.bit_pos += 8;
160 self.last_read_byte = None;
161
162 Ok(byte)
163 }
164
165 pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
167 self.align_byte();
168
169 let start = self.byte_pos();
170 if start + count > self.buffer.len() - self.offset_end {
171 return Err(DeserializationError::NotEnoughBytes(
172 start + count - self.buffer.len(),
173 ));
174 }
175
176 self.bit_pos += 8 * count;
177 self.last_read_byte = None;
178
179 let slice = &self.buffer[start..start + count];
180 if let Some(crypto) = self.crypto.as_mut() {
181 Ok(crypto.apply_keystream(slice))
182 } else {
183 Ok(slice)
184 }
185 }
186
187 pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
190 self.align_byte();
191 let mut num: u128 = 0;
192 let mut multiplier: u128 = 1;
193
194 loop {
195 let byte = self.read_byte()?; num += ((byte & 127) as u128) * multiplier;
197
198 if (byte & 1 << 7) == 0 {
200 break;
201 }
202
203 multiplier *= 128;
204 }
205
206 Ok(num)
207 }
208
209 pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
211 &mut self,
212 ) -> Result<T, DeserializationError> {
213 let data = self.read_bytes(S)?;
214 Ok(FixedInt::deserialize(data))
215 }
216
217 pub fn align_byte(&mut self) {
219 let rem = self.bit_pos % 8;
220 if rem != 0 {
221 self.bit_pos += 8 - rem;
222 self.last_read_byte = None;
223 }
224 }
225
226 pub fn bytes_left(&self) -> usize {
228 let left = self.buffer.len() - self.byte_pos() - self.offset_end;
229 if self.bit_pos % 8 != 0 {
230 left - 1 } else {
232 left
233 }
234 }
235
236 pub fn reset(&mut self) {
238 self.bit_pos = 0;
239 }
240}
241
#[cfg(test)]
mod tests {
    use crate::{DeserializationError, bitstream::CryptoStream};

    use super::BitStreamReader;

    /// Test "cipher" that decrypts by adding 1 to every byte. It records the
    /// resulting plaintext so `get_plaintext` can be checked.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, b: u8) -> u8 {
            // wrapping_add: an input of 0xFF must not panic in debug builds.
            let p = b.wrapping_add(1);
            self.plain.push(p);
            p
        }

        fn apply_keystream(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.plain.len();
            self.plain.extend(slice.iter().map(|s| s.wrapping_add(1)));
            &self.plain[start..]
        }

        fn get_plaintext(&self) -> &[u8] {
            &self.plain
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&buf);
        reader.crypto = Some(Box::new(PlusOneDecrypter { plain: Vec::new() }));

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Ok(4));
        // Byte 4 decrypts to 5 = 0b101, read LSB first.
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    #[test]
    fn test_decrypt_wraps() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);
        reader.crypto = Some(Box::new(PlusOneDecrypter { plain: Vec::new() }));

        assert_eq!(reader.read_byte(), Ok(0));
    }

    /// Two bit-patterned bytes followed by 0xFF and 0x00.
    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        // 0b10101100 read least-significant bit first.
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(3), Ok(0b100));
        assert_eq!(reader.read_small(4), Ok(0b0101));
        assert_eq!(reader.read_small(1), Ok(0b1));
        assert_eq!(reader.read_small(4), Ok(0b0010));
    }

    #[test]
    #[should_panic]
    fn test_read_small_rejects_full_byte() {
        let buf = [0u8];
        let mut reader = BitStreamReader::new(&buf);

        let _ = reader.read_small(8);
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(7), Ok(0b00101100));
        // Last bit of byte 0 (1), then bits 0..2 of byte 1 (1, 0).
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        // read_byte skips the rest of the partially read first byte.
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap();
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        reader.align_byte();
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_small(3), Ok(0b101));
        assert_eq!(reader.read_byte(), Ok(0b11001100));
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = vec![0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(Ok(0), stream.read_byte());
        assert_eq!(Ok(127), stream.read_dyn_int());
        assert_eq!(Ok(128), stream.read_dyn_int());
        assert_eq!(Ok(268435455), stream.read_dyn_int());
        assert_eq!(
            Err(DeserializationError::NotEnoughBytes(1)),
            stream.read_dyn_int()
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let buf = vec![
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&buf);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(v1, 1);
        assert_eq!(v2, 1);
        assert_eq!(v3, 2);
        assert_eq!(v4, 2);
        assert_eq!(v5, 3);
        assert_eq!(v6, 3);
        assert_eq!(v7, 4);
        assert_eq!(v8, 4);
        assert_eq!(v9, 5);
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        // A partially consumed byte no longer counts.
        reader.read_small(3).unwrap();
        assert_eq!(reader.bytes_left(), 3);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 2);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 1);
        reader.read_bit().unwrap();
        assert_eq!(reader.bytes_left(), 0);
    }

    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let buff = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&buff);

        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );

        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        assert_eq!(reader.read_byte(), Ok(4));
        assert_eq!(reader.read_byte(), Ok(5));
    }

    #[test]
    fn test_slice_start() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        // Half a byte read: the slice end rounds down to the byte boundary.
        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);

        reader.read_small(4).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20]);

        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[10, 20, 30, 40]);
    }

    #[test]
    fn test_slice_start_with_marker() {
        let buff = [10, 20, 30, 40, 50];
        let mut reader = BitStreamReader::new(&buff);

        reader.read_byte().unwrap();
        assert_eq!(reader.slice_marker(None), &[10]);
        reader.set_marker();
        assert_eq!(reader.slice_marker(None), &[]);

        reader.read_bytes(2).unwrap();
        assert_eq!(reader.slice_marker(None), &[20, 30]);
    }
}