1use std::cmp::min;
2
3use crate::{DeserializationError, bitstream::CryptoStream, encoding::fixed_int::FixedInt};
4
5pub struct BitStreamReader<'a> {
6 buffer: &'a [u8],
7 bit_pos: usize,
8 last_read_byte: Option<u8>,
9 crypto: Option<Box<dyn CryptoStream>>,
10}
11
12impl<'a> BitStreamReader<'a> {
13 pub fn new(buffer: &'a [u8]) -> Self {
15 Self {
16 buffer,
17 bit_pos: 0,
18 crypto: None,
19 last_read_byte: None,
20 }
21 }
22
23 pub fn byte_pos(&self) -> usize {
25 self.bit_pos / 8
26 }
27
28 fn current_byte(&mut self) -> u8 {
30 if let Some(b) = self.last_read_byte {
31 b
32 } else {
33 let mut b = self.buffer[self.byte_pos()];
34 if let Some(crypto) = self.crypto.as_mut() {
35 b = crypto.decrypt_byte(b);
36 }
37
38 self.last_read_byte = Some(b);
39 b
40 }
41 }
42
43 pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
45 self.read_small(1).map(|v| v != 0)
46 }
47
48 pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
50 assert!(bits > 0 && bits < 8);
51
52 let mut result: u8 = 0;
53 let mut shift = 0;
54
55 while bits > 0 {
56 if self.byte_pos() >= self.buffer.len() {
57 return Err(DeserializationError::NotEnoughBytes(1));
58 }
59
60 let bit_offset = self.bit_pos % 8;
62
63 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
65
66 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
70 let byte_val = self.current_byte();
71
72 let val = (byte_val & mask) >> bit_offset;
76
77 result |= val << shift;
80
81 bits -= bits_in_current_byte;
83
84 shift += bits_in_current_byte;
86
87 self.bit_pos += bits_in_current_byte as usize;
88
89 if self.bit_pos % 8 == 0 {
91 self.last_read_byte = None;
92 }
93 }
94
95 Ok(result)
96 }
97
98 pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
100 self.align_byte();
101
102 if self.byte_pos() >= self.buffer.len() {
103 return Err(DeserializationError::NotEnoughBytes(1));
104 }
105
106 let byte = self.current_byte();
107 self.bit_pos += 8;
108 self.last_read_byte = None;
109
110 Ok(byte)
111 }
112
113 pub fn read_bytes(&mut self, count: usize) -> Result<&[u8], DeserializationError> {
115 self.align_byte();
116
117 let start = self.byte_pos();
118 if start + count > self.buffer.len() {
119 return Err(DeserializationError::NotEnoughBytes(
120 start + count - self.buffer.len(),
121 ));
122 }
123
124 self.bit_pos += 8 * count;
125 self.last_read_byte = None;
126
127 let slice = &self.buffer[start..start + count];
128 if let Some(crypto) = self.crypto.as_mut() {
129 Ok(crypto.decrypt_slice(slice))
130 } else {
131 Ok(slice)
132 }
133 }
134
135 pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
138 self.align_byte();
139 let mut num: u128 = 0;
140 let mut multiplier: u128 = 1;
141
142 loop {
143 let byte = self.read_byte()?; num += ((byte & 127) as u128) * multiplier;
145
146 if (byte & 1 << 7) == 0 {
148 break;
149 }
150
151 multiplier *= 128;
152 }
153
154 Ok(num)
155 }
156
157 pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
159 &mut self,
160 ) -> Result<T, DeserializationError> {
161 let data = self.read_bytes(S)?;
162 Ok(FixedInt::deserialize(data))
163 }
164
165 pub fn align_byte(&mut self) {
167 let rem = self.bit_pos % 8;
168 if rem != 0 {
169 self.bit_pos += 8 - rem;
170 self.last_read_byte = None;
171 }
172 }
173
174 pub fn bytes_left(&self) -> usize {
176 let left = self.buffer.len() - self.byte_pos();
177 if self.bit_pos % 8 != 0 {
178 left - 1 } else {
180 left
181 }
182 }
183
184 pub fn reset(&mut self) {
186 self.bit_pos = 0;
187 }
188}
189
#[cfg(test)]
mod tests {
    use crate::{DeserializationError, bitstream::CryptoStream};

    use super::BitStreamReader;

    /// Toy cipher for tests: each ciphertext byte decrypts to itself plus one,
    /// wrapping at 0xFF so any input byte is accepted. Decrypted output is kept
    /// in `plain` so `decrypt_slice` can return a borrowed slice.
    struct PlusOneDecrypter {
        plain: Vec<u8>,
    }

    impl CryptoStream for PlusOneDecrypter {
        fn decrypt_byte(&mut self, b: u8) -> u8 {
            let p = b.wrapping_add(1);
            self.plain.push(p);
            p
        }

        fn decrypt_slice(&mut self, slice: &[u8]) -> &[u8] {
            let start = self.plain.len();
            self.plain.extend(slice.iter().map(|s| s.wrapping_add(1)));
            &self.plain[start..]
        }
    }

    #[test]
    fn test_decrypt_bytes() {
        let buf = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];
        let mut reader = BitStreamReader::new(&buf);
        reader.crypto = Some(Box::new(PlusOneDecrypter { plain: Vec::new() }));

        assert_eq!(reader.read_byte(), Ok(2));
        assert_eq!(reader.read_byte(), Ok(3));
        assert_eq!(reader.read_byte(), Ok(4));
        // Fourth byte decrypts to 5 = 0b101, read LSB first.
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bytes(5), Ok(&[6, 7, 8, 9, 10][..]));
        assert_eq!(reader.read_byte(), Ok(11));
    }

    #[test]
    fn test_decrypt_wraps() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);
        reader.crypto = Some(Box::new(PlusOneDecrypter { plain: Vec::new() }));

        assert_eq!(reader.read_byte(), Ok(0x00));
    }

    fn make_buffer() -> Vec<u8> {
        vec![0b10101100, 0b11010010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let buf = make_buffer();
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_bit(), Ok(true));
    }

    #[test]
    fn test_read_small() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(3), Ok(0b100));
        assert_eq!(reader.read_small(4), Ok(0b0101));
        assert_eq!(reader.read_small(1), Ok(0b1));
        assert_eq!(reader.read_small(4), Ok(0b0010));
    }

    #[test]
    fn test_read_cross_byte() {
        let buf = [0b10101100, 0b11010001];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_small(7), Ok(0b00101100));
        assert_eq!(reader.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_read_bytes() {
        let buf = [0x01, 0xAA, 0xBB, 0xCC];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_bit().unwrap();
        let slice = reader.read_bytes(3).unwrap();
        assert_eq!(slice, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let buf = [0b10101100, 0b11010010];
        let mut reader = BitStreamReader::new(&buf);

        reader.read_small(3).unwrap();
        reader.align_byte();
        assert_eq!(reader.read_byte(), Ok(0b11010010));
    }

    #[test]
    fn test_eof_behavior() {
        let buf = [0xFF];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_byte(), Ok(0xFF));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_byte(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
        assert_eq!(
            reader.read_bytes(2),
            Err(DeserializationError::NotEnoughBytes(2))
        );
    }

    #[test]
    fn test_multiple_operations() {
        let buf = [0b10101010, 0b11001100, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.read_bit(), Ok(false));
        assert_eq!(reader.read_small(3), Ok(0b101));
        assert_eq!(reader.read_byte(), Ok(0b11001100));
        assert_eq!(reader.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(
            reader.read_bit(),
            Err(DeserializationError::NotEnoughBytes(1))
        );
    }

    #[test]
    fn test_read_dyn_int() {
        let buf = vec![0, 127, 128, 1, 255, 255, 255, 127];
        let mut stream = BitStreamReader::new(&buf);

        assert_eq!(Ok(0), stream.read_byte());
        assert_eq!(Ok(127), stream.read_dyn_int());
        assert_eq!(Ok(128), stream.read_dyn_int());
        assert_eq!(Ok(268435455), stream.read_dyn_int());
        assert_eq!(
            Err(DeserializationError::NotEnoughBytes(1)),
            stream.read_dyn_int()
        );
    }

    #[test]
    fn test_read_fixed_int() {
        let buf = vec![
            1, 2, 0, 2, 0, 4, 0, 0, 0, 3, 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 4, 0, 0, 0, 0, 0, 0, 0,
            8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
            0, 0, 0, 10,
        ];

        let mut stream = BitStreamReader::new(&buf);
        let v1: u8 = stream.read_fixed_int().unwrap();
        let v2: i8 = stream.read_fixed_int().unwrap();
        let v3: u16 = stream.read_fixed_int().unwrap();
        let v4: i16 = stream.read_fixed_int().unwrap();
        let v5: u32 = stream.read_fixed_int().unwrap();
        let v6: i32 = stream.read_fixed_int().unwrap();
        let v7: u64 = stream.read_fixed_int().unwrap();
        let v8: i64 = stream.read_fixed_int().unwrap();
        let v9: u128 = stream.read_fixed_int().unwrap();
        let v10: i128 = stream.read_fixed_int().unwrap();

        assert_eq!(v1, 1);
        assert_eq!(v2, 1);
        assert_eq!(v3, 2);
        assert_eq!(v4, 2);
        assert_eq!(v5, 3);
        assert_eq!(v6, 3);
        assert_eq!(v7, 4);
        assert_eq!(v8, 4);
        assert_eq!(v9, 5);
        assert_eq!(v10, 5);
    }

    #[test]
    fn test_bytes_left() {
        let buf = [0b10101100, 0b11010010, 0xFF, 0x00];
        let mut reader = BitStreamReader::new(&buf);

        assert_eq!(reader.bytes_left(), 4);
        reader.read_small(3).unwrap();
        assert_eq!(reader.bytes_left(), 3);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 2);
        reader.read_byte().unwrap();
        assert_eq!(reader.bytes_left(), 1);
        reader.read_bit().unwrap();
        assert_eq!(reader.bytes_left(), 0);
    }
}