vexil_runtime/
bit_reader.rs1use crate::error::DecodeError;
2use crate::{MAX_BYTES_LENGTH, MAX_RECURSION_DEPTH};
3
/// Cursor over a borrowed byte buffer that decodes both sub-byte bit fields and
/// byte-aligned values.
///
/// Bit fields are consumed least-significant-bit first within each byte. Every
/// fixed-width, LEB128, or length-prefixed read first skips any unread bits of the
/// current byte so that it starts on a byte boundary.
#[derive(Debug)]
pub struct BitReader<'a> {
    /// Input being decoded; the reader only ever borrows it immutably.
    data: &'a [u8],
    /// Index of the byte that holds the next unread bit.
    byte_pos: usize,
    /// Number of bits (0..=7) already consumed from `data[byte_pos]`.
    bit_offset: u8,
    /// Current nesting depth tracked by `enter_recursive` / `leave_recursive`.
    recursion_depth: u32,
}
19
20impl<'a> BitReader<'a> {
21 pub fn new(data: &'a [u8]) -> Self {
23 Self {
24 data,
25 byte_pos: 0,
26 bit_offset: 0,
27 recursion_depth: 0,
28 }
29 }
30
31 pub fn read_bits(&mut self, count: u8) -> Result<u64, DecodeError> {
33 let mut result: u64 = 0;
34 for i in 0..count {
35 if self.byte_pos >= self.data.len() {
36 return Err(DecodeError::UnexpectedEof);
37 }
38 let bit = (self.data[self.byte_pos] >> self.bit_offset) & 1;
39 result |= u64::from(bit) << i;
40 self.bit_offset += 1;
41 if self.bit_offset == 8 {
42 self.byte_pos += 1;
43 self.bit_offset = 0;
44 }
45 }
46 Ok(result)
47 }
48
49 pub fn read_bool(&mut self) -> Result<bool, DecodeError> {
51 Ok(self.read_bits(1)? != 0)
52 }
53
54 pub fn flush_to_byte_boundary(&mut self) {
57 if self.bit_offset > 0 {
58 self.byte_pos += 1;
59 self.bit_offset = 0;
60 }
61 }
62
63 fn remaining(&self) -> usize {
65 self.data.len().saturating_sub(self.byte_pos)
66 }
67
68 pub fn read_u8(&mut self) -> Result<u8, DecodeError> {
70 self.flush_to_byte_boundary();
71 if self.remaining() < 1 {
72 return Err(DecodeError::UnexpectedEof);
73 }
74 let v = self.data[self.byte_pos];
75 self.byte_pos += 1;
76 Ok(v)
77 }
78
79 pub fn read_u16(&mut self) -> Result<u16, DecodeError> {
81 self.flush_to_byte_boundary();
82 if self.remaining() < 2 {
83 return Err(DecodeError::UnexpectedEof);
84 }
85 let bytes: [u8; 2] = self.data[self.byte_pos..self.byte_pos + 2]
86 .try_into()
87 .map_err(|_| DecodeError::UnexpectedEof)?;
88 self.byte_pos += 2;
89 Ok(u16::from_le_bytes(bytes))
90 }
91
92 pub fn read_u32(&mut self) -> Result<u32, DecodeError> {
94 self.flush_to_byte_boundary();
95 if self.remaining() < 4 {
96 return Err(DecodeError::UnexpectedEof);
97 }
98 let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
99 .try_into()
100 .map_err(|_| DecodeError::UnexpectedEof)?;
101 self.byte_pos += 4;
102 Ok(u32::from_le_bytes(bytes))
103 }
104
105 pub fn read_u64(&mut self) -> Result<u64, DecodeError> {
107 self.flush_to_byte_boundary();
108 if self.remaining() < 8 {
109 return Err(DecodeError::UnexpectedEof);
110 }
111 let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
112 .try_into()
113 .map_err(|_| DecodeError::UnexpectedEof)?;
114 self.byte_pos += 8;
115 Ok(u64::from_le_bytes(bytes))
116 }
117
118 pub fn read_i8(&mut self) -> Result<i8, DecodeError> {
120 self.flush_to_byte_boundary();
121 if self.remaining() < 1 {
122 return Err(DecodeError::UnexpectedEof);
123 }
124 let bytes: [u8; 1] = [self.data[self.byte_pos]];
125 self.byte_pos += 1;
126 Ok(i8::from_le_bytes(bytes))
127 }
128
129 pub fn read_i16(&mut self) -> Result<i16, DecodeError> {
131 self.flush_to_byte_boundary();
132 if self.remaining() < 2 {
133 return Err(DecodeError::UnexpectedEof);
134 }
135 let bytes: [u8; 2] = self.data[self.byte_pos..self.byte_pos + 2]
136 .try_into()
137 .map_err(|_| DecodeError::UnexpectedEof)?;
138 self.byte_pos += 2;
139 Ok(i16::from_le_bytes(bytes))
140 }
141
142 pub fn read_i32(&mut self) -> Result<i32, DecodeError> {
144 self.flush_to_byte_boundary();
145 if self.remaining() < 4 {
146 return Err(DecodeError::UnexpectedEof);
147 }
148 let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
149 .try_into()
150 .map_err(|_| DecodeError::UnexpectedEof)?;
151 self.byte_pos += 4;
152 Ok(i32::from_le_bytes(bytes))
153 }
154
155 pub fn read_i64(&mut self) -> Result<i64, DecodeError> {
157 self.flush_to_byte_boundary();
158 if self.remaining() < 8 {
159 return Err(DecodeError::UnexpectedEof);
160 }
161 let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
162 .try_into()
163 .map_err(|_| DecodeError::UnexpectedEof)?;
164 self.byte_pos += 8;
165 Ok(i64::from_le_bytes(bytes))
166 }
167
168 pub fn read_f32(&mut self) -> Result<f32, DecodeError> {
170 self.flush_to_byte_boundary();
171 if self.remaining() < 4 {
172 return Err(DecodeError::UnexpectedEof);
173 }
174 let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
175 .try_into()
176 .map_err(|_| DecodeError::UnexpectedEof)?;
177 self.byte_pos += 4;
178 Ok(f32::from_le_bytes(bytes))
179 }
180
181 pub fn read_f64(&mut self) -> Result<f64, DecodeError> {
183 self.flush_to_byte_boundary();
184 if self.remaining() < 8 {
185 return Err(DecodeError::UnexpectedEof);
186 }
187 let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
188 .try_into()
189 .map_err(|_| DecodeError::UnexpectedEof)?;
190 self.byte_pos += 8;
191 Ok(f64::from_le_bytes(bytes))
192 }
193
194 pub fn read_leb128(&mut self, max_bytes: u8) -> Result<u64, DecodeError> {
196 self.flush_to_byte_boundary();
197 let (value, consumed) = crate::leb128::decode(&self.data[self.byte_pos..], max_bytes)?;
198 self.byte_pos += consumed;
199 Ok(value)
200 }
201
202 pub fn read_zigzag(&mut self, _type_bits: u8, max_bytes: u8) -> Result<i64, DecodeError> {
204 let raw = self.read_leb128(max_bytes)?;
205 Ok(crate::zigzag::zigzag_decode(raw))
206 }
207
208 pub fn read_string(&mut self) -> Result<String, DecodeError> {
210 self.flush_to_byte_boundary();
211 let len = self.read_leb128(crate::MAX_LENGTH_PREFIX_BYTES)?;
212 if len > MAX_BYTES_LENGTH {
213 return Err(DecodeError::LimitExceeded {
214 field: "string",
215 limit: MAX_BYTES_LENGTH,
216 actual: len,
217 });
218 }
219 let len = len as usize;
220 if self.remaining() < len {
221 return Err(DecodeError::UnexpectedEof);
222 }
223 let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
224 self.byte_pos += len;
225 String::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8)
226 }
227
228 pub fn read_bytes(&mut self) -> Result<Vec<u8>, DecodeError> {
230 self.flush_to_byte_boundary();
231 let len = self.read_leb128(crate::MAX_LENGTH_PREFIX_BYTES)?;
232 if len > MAX_BYTES_LENGTH {
233 return Err(DecodeError::LimitExceeded {
234 field: "bytes",
235 limit: MAX_BYTES_LENGTH,
236 actual: len,
237 });
238 }
239 let len = len as usize;
240 if self.remaining() < len {
241 return Err(DecodeError::UnexpectedEof);
242 }
243 let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
244 self.byte_pos += len;
245 Ok(bytes)
246 }
247
248 pub fn read_raw_bytes(&mut self, len: usize) -> Result<Vec<u8>, DecodeError> {
250 self.flush_to_byte_boundary();
251 if self.remaining() < len {
252 return Err(DecodeError::UnexpectedEof);
253 }
254 let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
255 self.byte_pos += len;
256 Ok(bytes)
257 }
258
259 pub fn read_remaining(&mut self) -> Vec<u8> {
262 self.flush_to_byte_boundary();
263 let remaining = self.data.len().saturating_sub(self.byte_pos);
264 if remaining == 0 {
265 return Vec::new();
266 }
267 let result = self.data[self.byte_pos..].to_vec();
268 self.byte_pos = self.data.len();
269 result
270 }
271
272 pub fn enter_recursive(&mut self) -> Result<(), DecodeError> {
274 self.recursion_depth += 1;
275 if self.recursion_depth > MAX_RECURSION_DEPTH {
276 return Err(DecodeError::RecursionLimitExceeded);
277 }
278 Ok(())
279 }
280
281 pub fn leave_recursive(&mut self) {
283 self.recursion_depth = self.recursion_depth.saturating_sub(1);
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BitWriter;

    #[test]
    fn read_single_bit() {
        let mut r = BitReader::new(&[0x01]);
        assert!(r.read_bool().unwrap());
    }

    #[test]
    fn read_bool_false() {
        let mut r = BitReader::new(&[0x00]);
        assert!(!r.read_bool().unwrap());
    }

    #[test]
    fn round_trip_sub_byte() {
        let mut w = BitWriter::new();
        w.write_bits(5, 3);
        w.write_bits(19, 5);
        w.write_bits(42, 6);
        let buf = w.finish();
        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(3).unwrap(), 5);
        assert_eq!(r.read_bits(5).unwrap(), 19);
        assert_eq!(r.read_bits(6).unwrap(), 42);
    }

    #[test]
    fn read_bits_crosses_byte_boundary() {
        // Bit 7 of byte 0 and bit 0 of byte 1 are set; a 2-bit read straddles them.
        let mut r = BitReader::new(&[0b1000_0000, 0b0000_0001]);
        assert_eq!(r.read_bits(7).unwrap(), 0);
        assert_eq!(r.read_bits(2).unwrap(), 0b11);
    }

    #[test]
    fn read_bits_full_width() {
        let mut r = BitReader::new(&[0xFF; 8]);
        assert_eq!(r.read_bits(64).unwrap(), u64::MAX);
    }

    #[test]
    fn read_bits_eof() {
        let mut r = BitReader::new(&[0xFF]);
        assert_eq!(r.read_bits(9).unwrap_err(), DecodeError::UnexpectedEof);
    }

    #[test]
    fn round_trip_u16() {
        let mut w = BitWriter::new();
        w.write_u16(0x1234);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_u16().unwrap(), 0x1234);
    }

    #[test]
    fn round_trip_i32_neg() {
        let mut w = BitWriter::new();
        w.write_i32(-42);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_i32().unwrap(), -42);
    }

    #[test]
    fn read_fixed_width_little_endian() {
        assert_eq!(BitReader::new(&[0xFF]).read_i8().unwrap(), -1);
        assert_eq!(
            BitReader::new(&0xBEEFu16.to_le_bytes()).read_u16().unwrap(),
            0xBEEF
        );
        assert_eq!(
            BitReader::new(&i16::MIN.to_le_bytes()).read_i16().unwrap(),
            i16::MIN
        );
        assert_eq!(
            BitReader::new(&u32::MAX.to_le_bytes()).read_u32().unwrap(),
            u32::MAX
        );
        assert_eq!(
            BitReader::new(&u64::MAX.to_le_bytes()).read_u64().unwrap(),
            u64::MAX
        );
        assert_eq!(
            BitReader::new(&i64::MIN.to_le_bytes()).read_i64().unwrap(),
            i64::MIN
        );
    }

    #[test]
    fn fixed_width_read_skips_partial_byte() {
        // A byte-aligned read after a bit read starts at the next whole byte.
        let mut r = BitReader::new(&[0x01, 0xAB]);
        assert!(r.read_bool().unwrap());
        assert_eq!(r.read_u8().unwrap(), 0xAB);
    }

    #[test]
    fn fixed_width_eof() {
        assert_eq!(
            BitReader::new(&[0x01, 0x02, 0x03]).read_u32().unwrap_err(),
            DecodeError::UnexpectedEof
        );
    }

    #[test]
    fn round_trip_f32() {
        let mut w = BitWriter::new();
        w.write_f32(std::f32::consts::PI);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_f32().unwrap(), std::f32::consts::PI);
    }

    #[test]
    fn round_trip_f64_nan() {
        let mut w = BitWriter::new();
        w.write_f64(f64::NAN);
        let b = w.finish();
        let v = BitReader::new(&b).read_f64().unwrap();
        assert!(v.is_nan());
        assert_eq!(v.to_bits(), 0x7FF8000000000000);
    }

    #[test]
    fn round_trip_string() {
        let mut w = BitWriter::new();
        w.write_string("hello");
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_string().unwrap(), "hello");
    }

    #[test]
    fn string_payload_truncated() {
        // The prefix declares 5 bytes but only 2 follow.
        let mut w = BitWriter::new();
        w.write_leb128(5);
        w.write_raw_bytes(b"he");
        let b = w.finish();
        assert_eq!(
            BitReader::new(&b).read_string().unwrap_err(),
            DecodeError::UnexpectedEof
        );
    }

    #[test]
    fn round_trip_bytes() {
        let mut w = BitWriter::new();
        w.write_leb128(3);
        w.write_raw_bytes(&[0x01, 0x02, 0x03]);
        let b = w.finish();
        assert_eq!(
            BitReader::new(&b).read_bytes().unwrap(),
            vec![0x01, 0x02, 0x03]
        );
    }

    #[test]
    fn read_raw_bytes_exact_and_eof() {
        let mut r = BitReader::new(&[0x01, 0x02, 0x03]);
        assert_eq!(r.read_raw_bytes(2).unwrap(), vec![0x01, 0x02]);
        assert_eq!(r.read_raw_bytes(2).unwrap_err(), DecodeError::UnexpectedEof);
    }

    #[test]
    fn round_trip_leb128() {
        let mut w = BitWriter::new();
        w.write_leb128(300);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_leb128(4).unwrap(), 300);
    }

    #[test]
    fn round_trip_zigzag() {
        let mut w = BitWriter::new();
        w.write_zigzag(-42, 64);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_zigzag(64, 10).unwrap(), -42);
    }

    #[test]
    fn unexpected_eof() {
        assert_eq!(
            BitReader::new(&[]).read_u8().unwrap_err(),
            DecodeError::UnexpectedEof
        );
    }

    #[test]
    fn invalid_utf8() {
        let mut w = BitWriter::new();
        w.write_leb128(2);
        w.write_raw_bytes(&[0xFF, 0xFE]);
        let b = w.finish();
        assert_eq!(
            BitReader::new(&b).read_string().unwrap_err(),
            DecodeError::InvalidUtf8
        );
    }

    #[test]
    fn recursion_depth_limit() {
        // Track the crate constant rather than duplicating its value here.
        let mut r = BitReader::new(&[]);
        for _ in 0..MAX_RECURSION_DEPTH {
            r.enter_recursive().unwrap();
        }
        assert_eq!(
            r.enter_recursive().unwrap_err(),
            DecodeError::RecursionLimitExceeded
        );
    }

    #[test]
    fn recursion_depth_leave() {
        let mut r = BitReader::new(&[]);
        for _ in 0..MAX_RECURSION_DEPTH {
            r.enter_recursive().unwrap();
        }
        r.leave_recursive();
        r.enter_recursive().unwrap();
    }

    #[test]
    fn leave_recursive_at_zero_is_noop() {
        // Leaving with no matching enter must not give extra headroom.
        let mut r = BitReader::new(&[]);
        r.leave_recursive();
        for _ in 0..MAX_RECURSION_DEPTH {
            r.enter_recursive().unwrap();
        }
        assert!(r.enter_recursive().is_err());
    }

    #[test]
    fn trailing_bytes_not_rejected() {
        // A u32 (42) followed by two bytes the decoder never asks for.
        let data = [0x2a, 0x00, 0x00, 0x00, 0x63, 0x00];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_u32().unwrap(), 42);
        r.flush_to_byte_boundary();
        // The trailing bytes are neither rejected nor silently discarded.
        assert_eq!(r.read_u16().unwrap(), 0x0063);
    }

    #[test]
    fn read_remaining_after_partial_decode() {
        let data = [0x2a, 0x00, 0x00, 0x00, 0x63, 0x00];
        let mut r = BitReader::new(&data);
        let _x = r.read_u32().unwrap();
        let remaining = r.read_remaining();
        assert_eq!(remaining, vec![0x63, 0x00]);
    }

    #[test]
    fn read_remaining_when_fully_consumed() {
        let data = [0x2a, 0x00, 0x00, 0x00];
        let mut r = BitReader::new(&data);
        let _x = r.read_u32().unwrap();
        let remaining = r.read_remaining();
        assert!(remaining.is_empty());
    }

    #[test]
    fn read_remaining_from_start() {
        let data = [0x01, 0x02, 0x03];
        let mut r = BitReader::new(&data);
        let remaining = r.read_remaining();
        assert_eq!(remaining, vec![0x01, 0x02, 0x03]);
    }

    #[test]
    fn flush_reader() {
        let mut w = BitWriter::new();
        w.write_bits(0b101, 3);
        w.flush_to_byte_boundary();
        w.write_u8(0xAB);
        let b = w.finish();
        let mut r = BitReader::new(&b);
        assert_eq!(r.read_bits(3).unwrap(), 0b101);
        r.flush_to_byte_boundary();
        assert_eq!(r.read_u8().unwrap(), 0xAB);
    }
}