1use std::cmp::min;
2
3use crate::{DeserializationError, encoding::fixed_int::FixedInt};
4
/// Cursor over a borrowed byte slice that can read individual bits as well as
/// byte-aligned values.
///
/// Bits within a byte are consumed starting from the least-significant bit.
/// Cloning the reader yields an independent cursor over the same bytes, which is
/// handy for look-ahead.
#[derive(Debug, Clone)]
pub struct BitStreamReader<'a> {
    /// The bytes being read; the reader never modifies them.
    buffer: &'a [u8],
    /// Absolute index of the next bit to read, counted from bit 0 of `buffer[0]`.
    bit_pos: usize,
}
9
10impl<'a> BitStreamReader<'a> {
11 pub fn new(buffer: &'a [u8]) -> Self {
13 Self { buffer, bit_pos: 0 }
14 }
15
16 pub fn byte_pos(&self) -> usize {
18 self.bit_pos / 8
19 }
20
21 pub fn read_bit(&mut self) -> Result<bool, DeserializationError> {
23 self.read_small(1).map(|v| v != 0)
24 }
25
26 pub fn read_small(&mut self, mut bits: u8) -> Result<u8, DeserializationError> {
28 assert!(bits > 0 && bits < 8);
29
30 let mut result: u8 = 0;
31 let mut shift = 0;
32
33 while bits > 0 {
34 if self.byte_pos() >= self.buffer.len() {
35 return Err(DeserializationError::NotEnoughBytes(1));
36 }
37
38 let bit_offset = self.bit_pos % 8;
40
41 let bits_in_current_byte = min(8 - bit_offset as u8, bits);
43
44 let mask = ((1 << bits_in_current_byte) - 1) << bit_offset;
48 let byte_val = self.buffer[self.byte_pos()];
49
50 let val = (byte_val & mask) >> bit_offset;
54
55 result |= val << shift;
58
59 bits -= bits_in_current_byte;
61
62 shift += bits_in_current_byte;
64
65 self.bit_pos += bits_in_current_byte as usize;
66 }
67
68 Ok(result)
69 }
70
71 pub fn read_byte(&mut self) -> Result<u8, DeserializationError> {
73 self.align_byte();
74
75 if self.byte_pos() >= self.buffer.len() {
76 return Err(DeserializationError::NotEnoughBytes(1));
77 }
78
79 let b = self.buffer[self.byte_pos()];
80 self.bit_pos += 8;
81 Ok(b)
82 }
83
84 pub fn read_bytes(&mut self, count: usize) -> Result<&'a [u8], DeserializationError> {
86 self.align_byte();
87
88 let start = self.byte_pos();
89 if start + count > self.buffer.len() {
90 return Err(DeserializationError::NotEnoughBytes(
91 start + count - self.buffer.len(),
92 ));
93 }
94
95 self.bit_pos += 8 * count;
96
97 Ok(&self.buffer[start..start + count])
98 }
99
100 pub fn read_dyn_int(&mut self) -> Result<u128, DeserializationError> {
103 self.align_byte();
104 let mut num: u128 = 0;
105 let mut multiplier: u128 = 1;
106
107 loop {
108 let byte = self.read_byte()?; num += ((byte & 127) as u128) * multiplier;
110
111 if (byte & 1 << 7) == 0 {
113 break;
114 }
115
116 multiplier *= 128;
117 }
118
119 Ok(num)
120 }
121
122 pub fn read_fixed_int<const S: usize, T: FixedInt<S>>(
124 &mut self,
125 ) -> Result<T, DeserializationError> {
126 let data = self.read_bytes(S)?;
127 Ok(FixedInt::deserialize(data))
128 }
129
130 pub fn align_byte(&mut self) {
132 let rem = self.bit_pos % 8;
133 if rem != 0 {
134 self.bit_pos += 8 - rem;
135 }
136 }
137
138 pub fn bytes_left(&self) -> usize {
140 let left = self.buffer.len() - self.byte_pos();
141 if self.bit_pos % 8 != 0 {
142 left - 1 } else {
144 left
145 }
146 }
147
148 pub fn reset(&mut self) {
150 self.bit_pos = 0;
151 }
152}
153
#[cfg(test)]
mod tests {
    use super::BitStreamReader;
    use crate::DeserializationError;

    /// Four-byte fixture; the single-bit test only walks through its first byte.
    fn make_buffer() -> Vec<u8> {
        vec![0b1010_1100, 0b1101_0010, 0xFF, 0x00]
    }

    #[test]
    fn test_read_single_bits() {
        let data = make_buffer();
        let mut r = BitStreamReader::new(&data);

        // 0b1010_1100 read least-significant bit first.
        let expected = [false, false, true, true, false, true, false, true];
        for &bit in expected.iter() {
            assert_eq!(r.read_bit(), Ok(bit));
        }
    }

    #[test]
    fn test_read_small() {
        let data: [u8; 2] = [0b1010_1100, 0b1101_0010];
        let mut r = BitStreamReader::new(&data);

        // (width, expected) pairs; the last read starts on the second byte.
        let steps: [(u8, u8); 4] = [(3, 0b100), (4, 0b0101), (1, 0b1), (4, 0b0010)];
        for &(width, want) in steps.iter() {
            assert_eq!(r.read_small(width), Ok(want));
        }
    }

    #[test]
    fn test_read_cross_byte() {
        let data: [u8; 2] = [0b1010_1100, 0b1101_0001];
        let mut r = BitStreamReader::new(&data);

        // Seven bits from byte 0, then its last bit plus two bits of byte 1.
        assert_eq!(r.read_small(7), Ok(0b0010_1100));
        assert_eq!(r.read_small(3), Ok(0b011));
    }

    #[test]
    fn test_read_byte() {
        let data: [u8; 2] = [0b1010_1100, 0b1101_0010];
        let mut r = BitStreamReader::new(&data);

        // A partial read is followed by an implicit jump to the next byte.
        r.read_small(3).unwrap();
        assert_eq!(r.read_byte(), Ok(0b1101_0010));
    }

    #[test]
    fn test_read_bytes() {
        let data: [u8; 4] = [0x01, 0xAA, 0xBB, 0xCC];
        let mut r = BitStreamReader::new(&data);

        r.read_bit().unwrap();
        let got = r.read_bytes(3).unwrap();
        assert_eq!(got, &[0xAA, 0xBB, 0xCC]);
    }

    #[test]
    fn test_align_byte() {
        let data: [u8; 2] = [0b1010_1100, 0b1101_0010];
        let mut r = BitStreamReader::new(&data);

        r.read_small(3).unwrap();
        r.align_byte();
        assert_eq!(r.read_byte(), Ok(0b1101_0010));
    }

    #[test]
    fn test_eof_behavior() {
        let data: [u8; 1] = [0xFF];
        let mut r = BitStreamReader::new(&data);

        assert_eq!(r.read_byte(), Ok(0xFF));
        // Once exhausted, every read kind reports how many bytes are missing.
        assert_eq!(r.read_bit(), Err(DeserializationError::NotEnoughBytes(1)));
        assert_eq!(r.read_byte(), Err(DeserializationError::NotEnoughBytes(1)));
        assert_eq!(r.read_bytes(2), Err(DeserializationError::NotEnoughBytes(2)));
    }

    #[test]
    fn test_multiple_operations() {
        let data: [u8; 4] = [0b1010_1010, 0b1100_1100, 0xFF, 0x00];
        let mut r = BitStreamReader::new(&data);

        assert_eq!(r.read_bit(), Ok(false));
        assert_eq!(r.read_small(3), Ok(0b101));
        assert_eq!(r.read_byte(), Ok(0b1100_1100));
        assert_eq!(r.read_bytes(2), Ok(&[0xFF, 0x00][..]));
        assert_eq!(r.read_bit(), Err(DeserializationError::NotEnoughBytes(1)));
    }

    #[test]
    fn test_read_dyn_int() {
        let data: [u8; 8] = [0, 127, 128, 1, 255, 255, 255, 127];
        let mut r = BitStreamReader::new(&data);

        assert_eq!(r.read_byte(), Ok(0));
        assert_eq!(r.read_dyn_int(), Ok(127));
        assert_eq!(r.read_dyn_int(), Ok(128));
        assert_eq!(r.read_dyn_int(), Ok(268_435_455));
        assert_eq!(r.read_dyn_int(), Err(DeserializationError::NotEnoughBytes(1)));
    }

    #[test]
    fn test_read_fixed_int() {
        // One field per row, in the order the values are read back below.
        let data: &[u8] = &[
            1, // u8
            2, // i8
            0, 2, // u16
            0, 4, // i16
            0, 0, 0, 3, // u32
            0, 0, 0, 6, // i32
            0, 0, 0, 0, 0, 0, 0, 4, // u64
            0, 0, 0, 0, 0, 0, 0, 8, // i64
            0, 0, 0, 0, 0, 0, 0, 0, //
            0, 0, 0, 0, 0, 0, 0, 5, // u128
            0, 0, 0, 0, 0, 0, 0, 0, //
            0, 0, 0, 0, 0, 0, 0, 10, // i128
        ];
        let mut r = BitStreamReader::new(data);

        let a: u8 = r.read_fixed_int().unwrap();
        assert_eq!(a, 1);
        let b: i8 = r.read_fixed_int().unwrap();
        assert_eq!(b, 1);
        let c: u16 = r.read_fixed_int().unwrap();
        assert_eq!(c, 2);
        let d: i16 = r.read_fixed_int().unwrap();
        assert_eq!(d, 2);
        let e: u32 = r.read_fixed_int().unwrap();
        assert_eq!(e, 3);
        let f: i32 = r.read_fixed_int().unwrap();
        assert_eq!(f, 3);
        let g: u64 = r.read_fixed_int().unwrap();
        assert_eq!(g, 4);
        let h: i64 = r.read_fixed_int().unwrap();
        assert_eq!(h, 4);
        let i: u128 = r.read_fixed_int().unwrap();
        assert_eq!(i, 5);
        let j: i128 = r.read_fixed_int().unwrap();
        assert_eq!(j, 5);
    }

    #[test]
    fn test_bytes_left() {
        let data: [u8; 4] = [0b1010_1100, 0b1101_0010, 0xFF, 0x00];
        let mut r = BitStreamReader::new(&data);

        assert_eq!(r.bytes_left(), 4);
        // A partially consumed byte no longer counts as left.
        r.read_small(3).unwrap();
        assert_eq!(r.bytes_left(), 3);
        r.read_byte().unwrap();
        assert_eq!(r.bytes_left(), 2);
        r.read_byte().unwrap();
        assert_eq!(r.bytes_left(), 1);
        r.read_bit().unwrap();
        assert_eq!(r.bytes_left(), 0);
    }
}