// vexil_runtime/bit_reader.rs

use crate::error::DecodeError;
use crate::{MAX_BYTES_LENGTH, MAX_RECURSION_DEPTH};

/// Cursor over a borrowed byte buffer that supports both sub-byte
/// (bit-level, LSB-first) and byte-aligned reads, plus a recursion-depth
/// guard for nested decoding.
pub struct BitReader<'a> {
    // Borrowed input buffer being decoded.
    data: &'a [u8],
    // Index of the next byte to read from `data`.
    byte_pos: usize,
    // Bit position (0..=7) inside the current byte; 0 means byte-aligned.
    bit_offset: u8,
    // Current nesting depth, checked against MAX_RECURSION_DEPTH.
    recursion_depth: u32,
}
10
11impl<'a> BitReader<'a> {
12 pub fn new(data: &'a [u8]) -> Self {
13 Self {
14 data,
15 byte_pos: 0,
16 bit_offset: 0,
17 recursion_depth: 0,
18 }
19 }
20
21 pub fn read_bits(&mut self, count: u8) -> Result<u64, DecodeError> {
23 let mut result: u64 = 0;
24 for i in 0..count {
25 if self.byte_pos >= self.data.len() {
26 return Err(DecodeError::UnexpectedEof);
27 }
28 let bit = (self.data[self.byte_pos] >> self.bit_offset) & 1;
29 result |= u64::from(bit) << i;
30 self.bit_offset += 1;
31 if self.bit_offset == 8 {
32 self.byte_pos += 1;
33 self.bit_offset = 0;
34 }
35 }
36 Ok(result)
37 }
38
39 pub fn read_bool(&mut self) -> Result<bool, DecodeError> {
41 Ok(self.read_bits(1)? != 0)
42 }
43
44 pub fn flush_to_byte_boundary(&mut self) {
47 if self.bit_offset > 0 {
48 self.byte_pos += 1;
49 self.bit_offset = 0;
50 }
51 }
52
53 fn remaining(&self) -> usize {
55 self.data.len().saturating_sub(self.byte_pos)
56 }
57
58 pub fn read_u8(&mut self) -> Result<u8, DecodeError> {
59 self.flush_to_byte_boundary();
60 if self.remaining() < 1 {
61 return Err(DecodeError::UnexpectedEof);
62 }
63 let v = self.data[self.byte_pos];
64 self.byte_pos += 1;
65 Ok(v)
66 }
67
68 pub fn read_u16(&mut self) -> Result<u16, DecodeError> {
69 self.flush_to_byte_boundary();
70 if self.remaining() < 2 {
71 return Err(DecodeError::UnexpectedEof);
72 }
73 let bytes: [u8; 2] = self.data[self.byte_pos..self.byte_pos + 2]
74 .try_into()
75 .unwrap();
76 self.byte_pos += 2;
77 Ok(u16::from_le_bytes(bytes))
78 }
79
80 pub fn read_u32(&mut self) -> Result<u32, DecodeError> {
81 self.flush_to_byte_boundary();
82 if self.remaining() < 4 {
83 return Err(DecodeError::UnexpectedEof);
84 }
85 let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
86 .try_into()
87 .unwrap();
88 self.byte_pos += 4;
89 Ok(u32::from_le_bytes(bytes))
90 }
91
92 pub fn read_u64(&mut self) -> Result<u64, DecodeError> {
93 self.flush_to_byte_boundary();
94 if self.remaining() < 8 {
95 return Err(DecodeError::UnexpectedEof);
96 }
97 let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
98 .try_into()
99 .unwrap();
100 self.byte_pos += 8;
101 Ok(u64::from_le_bytes(bytes))
102 }
103
104 pub fn read_i8(&mut self) -> Result<i8, DecodeError> {
105 self.flush_to_byte_boundary();
106 if self.remaining() < 1 {
107 return Err(DecodeError::UnexpectedEof);
108 }
109 let bytes: [u8; 1] = [self.data[self.byte_pos]];
110 self.byte_pos += 1;
111 Ok(i8::from_le_bytes(bytes))
112 }
113
114 pub fn read_i16(&mut self) -> Result<i16, DecodeError> {
115 self.flush_to_byte_boundary();
116 if self.remaining() < 2 {
117 return Err(DecodeError::UnexpectedEof);
118 }
119 let bytes: [u8; 2] = self.data[self.byte_pos..self.byte_pos + 2]
120 .try_into()
121 .unwrap();
122 self.byte_pos += 2;
123 Ok(i16::from_le_bytes(bytes))
124 }
125
126 pub fn read_i32(&mut self) -> Result<i32, DecodeError> {
127 self.flush_to_byte_boundary();
128 if self.remaining() < 4 {
129 return Err(DecodeError::UnexpectedEof);
130 }
131 let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
132 .try_into()
133 .unwrap();
134 self.byte_pos += 4;
135 Ok(i32::from_le_bytes(bytes))
136 }
137
138 pub fn read_i64(&mut self) -> Result<i64, DecodeError> {
139 self.flush_to_byte_boundary();
140 if self.remaining() < 8 {
141 return Err(DecodeError::UnexpectedEof);
142 }
143 let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
144 .try_into()
145 .unwrap();
146 self.byte_pos += 8;
147 Ok(i64::from_le_bytes(bytes))
148 }
149
150 pub fn read_f32(&mut self) -> Result<f32, DecodeError> {
151 self.flush_to_byte_boundary();
152 if self.remaining() < 4 {
153 return Err(DecodeError::UnexpectedEof);
154 }
155 let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
156 .try_into()
157 .unwrap();
158 self.byte_pos += 4;
159 Ok(f32::from_le_bytes(bytes))
160 }
161
162 pub fn read_f64(&mut self) -> Result<f64, DecodeError> {
163 self.flush_to_byte_boundary();
164 if self.remaining() < 8 {
165 return Err(DecodeError::UnexpectedEof);
166 }
167 let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
168 .try_into()
169 .unwrap();
170 self.byte_pos += 8;
171 Ok(f64::from_le_bytes(bytes))
172 }
173
174 pub fn read_leb128(&mut self, max_bytes: u8) -> Result<u64, DecodeError> {
176 self.flush_to_byte_boundary();
177 let (value, consumed) = crate::leb128::decode(&self.data[self.byte_pos..], max_bytes)?;
178 self.byte_pos += consumed;
179 Ok(value)
180 }
181
182 pub fn read_zigzag(&mut self, _type_bits: u8, max_bytes: u8) -> Result<i64, DecodeError> {
184 let raw = self.read_leb128(max_bytes)?;
185 Ok(crate::zigzag::zigzag_decode(raw))
186 }
187
188 pub fn read_string(&mut self) -> Result<String, DecodeError> {
190 self.flush_to_byte_boundary();
191 let len = self.read_leb128(crate::MAX_LENGTH_PREFIX_BYTES)?;
192 if len > MAX_BYTES_LENGTH {
193 return Err(DecodeError::LimitExceeded {
194 field: "string",
195 limit: MAX_BYTES_LENGTH,
196 actual: len,
197 });
198 }
199 let len = len as usize;
200 if self.remaining() < len {
201 return Err(DecodeError::UnexpectedEof);
202 }
203 let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
204 self.byte_pos += len;
205 String::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8)
206 }
207
208 pub fn read_bytes(&mut self) -> Result<Vec<u8>, DecodeError> {
210 self.flush_to_byte_boundary();
211 let len = self.read_leb128(crate::MAX_LENGTH_PREFIX_BYTES)?;
212 if len > MAX_BYTES_LENGTH {
213 return Err(DecodeError::LimitExceeded {
214 field: "bytes",
215 limit: MAX_BYTES_LENGTH,
216 actual: len,
217 });
218 }
219 let len = len as usize;
220 if self.remaining() < len {
221 return Err(DecodeError::UnexpectedEof);
222 }
223 let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
224 self.byte_pos += len;
225 Ok(bytes)
226 }
227
228 pub fn read_raw_bytes(&mut self, len: usize) -> Result<Vec<u8>, DecodeError> {
230 self.flush_to_byte_boundary();
231 if self.remaining() < len {
232 return Err(DecodeError::UnexpectedEof);
233 }
234 let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
235 self.byte_pos += len;
236 Ok(bytes)
237 }
238
239 pub fn enter_recursive(&mut self) -> Result<(), DecodeError> {
241 self.recursion_depth += 1;
242 if self.recursion_depth > MAX_RECURSION_DEPTH {
243 return Err(DecodeError::RecursionLimitExceeded);
244 }
245 Ok(())
246 }
247
248 pub fn leave_recursive(&mut self) {
250 self.recursion_depth = self.recursion_depth.saturating_sub(1);
251 }
252}
253
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BitWriter;

    #[test]
    fn read_single_bit() {
        let mut r = BitReader::new(&[0x01]);
        assert!(r.read_bool().unwrap());
    }

    #[test]
    fn round_trip_sub_byte() {
        let mut w = BitWriter::new();
        w.write_bits(5, 3);
        w.write_bits(19, 5);
        w.write_bits(42, 6);
        let buf = w.finish();
        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(3).unwrap(), 5);
        assert_eq!(r.read_bits(5).unwrap(), 19);
        assert_eq!(r.read_bits(6).unwrap(), 42);
    }

    #[test]
    fn round_trip_u16() {
        let mut w = BitWriter::new();
        w.write_u16(0x1234);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_u16().unwrap(), 0x1234);
    }

    #[test]
    fn round_trip_i32_neg() {
        let mut w = BitWriter::new();
        w.write_i32(-42);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_i32().unwrap(), -42);
    }

    #[test]
    fn round_trip_f32() {
        let mut w = BitWriter::new();
        w.write_f32(std::f32::consts::PI);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_f32().unwrap(), std::f32::consts::PI);
    }

    #[test]
    fn round_trip_f64_nan() {
        let mut w = BitWriter::new();
        w.write_f64(f64::NAN);
        let b = w.finish();
        let v = BitReader::new(&b).read_f64().unwrap();
        assert!(v.is_nan());
        // NaN payload must survive the round trip bit-for-bit.
        assert_eq!(v.to_bits(), 0x7FF8000000000000);
    }

    #[test]
    fn round_trip_string() {
        let mut w = BitWriter::new();
        w.write_string("hello");
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_string().unwrap(), "hello");
    }

    #[test]
    fn round_trip_leb128() {
        let mut w = BitWriter::new();
        w.write_leb128(300);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_leb128(4).unwrap(), 300);
    }

    #[test]
    fn round_trip_zigzag() {
        let mut w = BitWriter::new();
        w.write_zigzag(-42, 64);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_zigzag(64, 10).unwrap(), -42);
    }

    #[test]
    fn unexpected_eof() {
        assert_eq!(
            BitReader::new(&[]).read_u8().unwrap_err(),
            DecodeError::UnexpectedEof
        );
    }

    #[test]
    fn invalid_utf8() {
        let mut w = BitWriter::new();
        w.write_leb128(2);
        w.write_raw_bytes(&[0xFF, 0xFE]);
        let b = w.finish();
        assert_eq!(
            BitReader::new(&b).read_string().unwrap_err(),
            DecodeError::InvalidUtf8
        );
    }

    #[test]
    fn recursion_depth_limit() {
        let mut r = BitReader::new(&[]);
        // Fill to exactly the configured limit (was a hard-coded 64,
        // which would silently diverge if the constant ever changed).
        for _ in 0..crate::MAX_RECURSION_DEPTH {
            r.enter_recursive().unwrap();
        }
        assert_eq!(
            r.enter_recursive().unwrap_err(),
            DecodeError::RecursionLimitExceeded
        );
    }

    #[test]
    fn recursion_depth_leave() {
        let mut r = BitReader::new(&[]);
        for _ in 0..crate::MAX_RECURSION_DEPTH {
            r.enter_recursive().unwrap();
        }
        // Leaving one level must free room for exactly one more enter.
        r.leave_recursive();
        r.enter_recursive().unwrap();
    }

    #[test]
    fn flush_reader() {
        let mut w = BitWriter::new();
        w.write_bits(0b101, 3);
        w.flush_to_byte_boundary();
        w.write_u8(0xAB);
        let b = w.finish();
        let mut r = BitReader::new(&b);
        assert_eq!(r.read_bits(3).unwrap(), 0b101);
        r.flush_to_byte_boundary();
        assert_eq!(r.read_u8().unwrap(), 0xAB);
    }
}