1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6 buffer: &'a [u8],
7 bit_pos: usize,
8 last_read_byte: Option<u8>,
9 offset_end: usize,
10 crypto: Option<Box<dyn CryptoStream>>,
11}
12
13impl<'a> BitStreamReader<'a> {
14 pub fn new(buffer: &'a [u8]) -> Self {
19 Self {
20 buffer,
21 bit_pos: 0,
22 crypto: None,
23 offset_end: 0,
24 last_read_byte: None,
25 }
26 }
27
28 pub fn slice(&self, from_start: bool) -> &[u8] {
30 let start = if from_start { 0 } else { self.byte_pos() };
31
32 &self.buffer[start..self.buffer.len() - self.offset_end]
33 }
34
35 pub fn slice_end(&self) -> &[u8] {
37 &self.buffer[self.buffer.len() - self.offset_end..]
38 }
39
40 pub fn set_crypto(&mut self, crypto: Option<Box<dyn CryptoStream>>) {
42 self.crypto = crypto;
43 }
44
45 pub fn set_offset_end(&mut self, len: usize) {
47 self.offset_end = len;
48 }
49
50 pub fn byte_pos(&self) -> usize {
52 self.bit_pos / 8
53 }
54
55 fn current_byte(&mut self) -> u8 {
57 if let Some(b) = self.last_read_byte {
58 b
59 } else {
60 let mut b = self.buffer[self.byte_pos()];
61 if let Some(crypto) = self.crypto.as_mut() {
62 b = crypto.apply_keystream_byte(b);
63 }
64
65 self.last_read_byte = Some(b);
66 b
67 }
68 }
69
70 pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
72 self.read_small(1).map(|v| v != 0)
73 }
74
75 pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
77 assert!(bits > 0 && bits < 8);
78
79 let mut result: u8 = 0;
80 let mut shift = 0;
81
82 while bits > 0 {
83 if self.byte_pos() >= self.buffer.len() - self.offset_end {
84 return Err(DeserializationError::NotEnoughBytes(1));
85 }
86
87 let bit_offset = self.bit_pos % 8;
89
90 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
92
93 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
97 let byte_val = self.current_byte();
98
99 let val = (byte_val & mask) >> bit_offset;
103
104 result |= val << shift;
107
108 bits -= bits_in_current_byte;
110
111 shift += bits_in_current_byte;
113
114 self.bit_pos += bits_in_current_byte as usize;
115
116 if self.bit_pos % 8 == 0 {
118 self.last_read_byte = None;
119 }
120 }
121
122 Ok(result)
123 }
124
125 pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
127 self.align_byte();
128
129 if self.byte_pos() >= self.buffer.len() - self.offset_end {
130 return Err(DeserializationError::NotEnoughBytes(1));
131 }
132
133 let byte = self.current_byte();
134 self.bit_pos += 8;
135 self.last_read_byte = None;
136
137 Ok(byte)
138 }
139
140 pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
142 self.align_byte();
143
144 let start = self.byte_pos();
145 if start + count > self.buffer.len() - self.offset_end {
146 return Err(DeserializationError::NotEnoughBytes(
147 start + count - self.buffer.len(),
148 ));
149 }
150
151 self.bit_pos += 8 * count;
152 self.last_read_byte = None;
153
154 let slice = &self.buffer[start..start + count];
155 if let Some(crypto) = self.crypto.as_mut() {
156 Ok(crypto.apply_keystream(slice))
157 } else {
158 Ok(slice)
159 }
160 }
161
162 pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
165 self.align_byte();
166 let mut num: u128 = 0;
167 let mut multiplier: u128 = 1;
168
169 loop {
170 let byte = self.read_byte()?; num += ((byte & 127) as u128) * multiplier;
172
173 if (byte & 1 << 7) == 0 {
175 break;
176 }
177
178 multiplier *= 128;
179 }
180
181 Ok(num)
182 }
183
184 pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
186 &mut self,
187 ) -> Result<T, DeserializationError> {
188 let data = self.read_bytes(S)?;
189 Ok(FixedInt::deserialize(data))
190 }
191
192 pub fn align_byte(&mut self) {
194 let rem = self.bit_pos % 8;
195 if rem != 0 {
196 self.bit_pos += 8 - rem;
197 self.last_read_byte = None;
198 }
199 }
200
201 pub fn bytes_left(&self) -> usize {
203 let left = self.buffer.len() - self.byte_pos() - self.offset_end;
204 if self.bit_pos % 8 != 0 {
205 left - 1 } else {
207 left
208 }
209 }
210
211 pub fn reset(&mut self) {
213 self.bit_pos = 0;
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::BitStreamReader;
    use crate::{DeserializationError, bitstream::CryptoStream};

    /// Test keystream that adds one to every byte. It keeps everything it has
    /// produced so that `apply_keystream` can return a borrow of its output.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn apply_keystream_byte(&mut self, byte: u8) -> u8 {
            let decrypted = byte + 1;
            self.plain.push(decrypted);
            decrypted
        }

        fn apply_keystream(&mut self, input: &[u8]) -> &[u8] {
            let first_new = self.plain.len();
            self.plain.extend(input.iter().map(|byte| byte + 1));
            &self.plain[first_new..]
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&data);
        reader.set_crypto(Some(Box::new(PlusOneDecrypter { plain: Vec::new() })));

        for &expected in &[2u8, 3, 4] {
            assert_eq!(reader.read_byte(), Ok(expected));
        }

        // The fourth byte decrypts to 5 = 0b101, consumed LSB first.
        for &expected in &[true, false, true] {
            assert_eq!(reader.read_bit(), Ok(expected));
        }

        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    /// Shared fixture for the single-bit test.
    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let data = make_buffer();
        let mut reader = BitStreamReader::new(&data);

        // 0b10101100 read from the least-significant bit upwards.
        let expected_bits = [false, false, true, true, false, true, false, true];
        for &bit in expected_bits.iter() {
            assert_eq!(reader.read_bit(), Ok(bit));
        }
    }

    #[test]
    fn test_read_small() {
        let data = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&data);

        let steps: [(u8, u8); 4] = [(3, 0b100), (4, 0b0101), (1, 0b1), (4, 0b0010)];
        for &(width, expected) in steps.iter() {
            assert_eq!(reader.read_small(width), Ok(expected));
        }
    }

    #[test]
    fn test_read_cross_byte() {
        let data = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&data);

        assert_eq!(reader.read_small(7), Ok(0b00101100));
        // One bit from the first byte, then two from the second.
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let data = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&data);

        reader.read_small(3).unwrap();
        // read_byte skips the rest of the partially consumed byte.
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_read_bytes() {
        let data = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&data);

        reader.read_bit().unwrap();
        let chunk = reader.read_bytes(3).unwrap();
        assert_eq!(chunk, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let data = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&data);

        reader.read_small(3).unwrap();
        reader.align_byte();
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let data = [0xFF];
        let mut reader = BitStreamReader::new(&data);

        assert_eq!(reader.read_byte(), Ok(0xFF));

        let one_missing = Err(DeserializationError::NotEnoughBytes(1));
        assert_eq!(reader.read_bit(), one_missing);
        assert_eq!(reader.read_byte(), one_missing);
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let data = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&data);

        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_small(3), Ok(0b101));
        assert_eq!(reader.read_byte(), Ok(0b11001100));
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let data = vec![0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&data);

        assert_eq!(stream.read_byte(), Ok(0));
        assert_eq!(stream.read_dyn_int(), Ok(127));
        assert_eq!(stream.read_dyn_int(), Ok(128));
        assert_eq!(stream.read_dyn_int(), Ok(268435455));
        assert_eq!(
            stream.read_dyn_int(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let data = vec![
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];
        let mut stream = BitStreamReader::new(&data);

        let unsigned_8: u8 = stream.read_fixed_int().unwrap();
        assert_eq!(unsigned_8, 1);
        let signed_8: i8 = stream.read_fixed_int().unwrap();
        assert_eq!(signed_8, 1);

        let unsigned_16: u16 = stream.read_fixed_int().unwrap();
        assert_eq!(unsigned_16, 2);
        let signed_16: i16 = stream.read_fixed_int().unwrap();
        assert_eq!(signed_16, 2);

        let unsigned_32: u32 = stream.read_fixed_int().unwrap();
        assert_eq!(unsigned_32, 3);
        let signed_32: i32 = stream.read_fixed_int().unwrap();
        assert_eq!(signed_32, 3);

        let unsigned_64: u64 = stream.read_fixed_int().unwrap();
        assert_eq!(unsigned_64, 4);
        let signed_64: i64 = stream.read_fixed_int().unwrap();
        assert_eq!(signed_64, 4);

        let unsigned_128: u128 = stream.read_fixed_int().unwrap();
        assert_eq!(unsigned_128, 5);
        let signed_128: i128 = stream.read_fixed_int().unwrap();
        assert_eq!(signed_128, 5);
    }

    #[test]
    fn test_bytes_left() {
        let data = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&data);
        assert_eq!(reader.bytes_left(), 4);

        // A partially consumed byte no longer counts as "left".
        reader.read_small(3).unwrap();
        assert_eq!(reader.bytes_left(), 3);

        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 2);

        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 1);

        reader.read_bit().unwrap();
        assert_eq!(reader.bytes_left(), 0);
    }

    #[test]
    fn offset_end_ignores_bytes_and_can_slice() {
        let data = [1, 2, 3, 4, 5];
        let mut reader = BitStreamReader::new(&data);

        reader.set_offset_end(2);
        assert_eq!(reader.bytes_left(), 3);
        assert_eq!(reader.read_byte(), Ok(1));

        assert_eq!(reader.slice(true), &[1, 2, 3]);
        assert_eq!(reader.slice(false), &[2, 3]);
        assert_eq!(reader.slice_end(), &[4, 5]);

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );

        // Lifting the offset makes the tail readable again.
        reader.set_offset_end(0);
        assert_eq!(reader.bytes_left(), 2);
        assert_eq!(reader.read_byte(), Ok(4));
        assert_eq!(reader.read_byte(), Ok(5));
    }
}