Skip to main content

vexil_runtime/
bit_reader.rs

1use crate::error::DecodeError;
2use crate::{MAX_BYTES_LENGTH, MAX_RECURSION_DEPTH};
3
/// Cursor over a borrowed byte slice supporting bit-granular and byte-aligned reads.
///
/// The reader never modifies or copies the underlying slice; cloning it is
/// cheap and yields an independent cursor at the same position.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    /// Input being decoded.
    data: &'a [u8],
    /// Index of the byte the next read starts in.
    byte_pos: usize,
    /// Number of bits already consumed from `data[byte_pos]`; reset to 0 when it reaches 8.
    bit_offset: u8,
    /// Nesting depth maintained by `enter_recursive` / `leave_recursive`.
    recursion_depth: u32,
}
10
11impl<'a> BitReader<'a> {
12    pub fn new(data: &'a [u8]) -> Self {
13        Self {
14            data,
15            byte_pos: 0,
16            bit_offset: 0,
17            recursion_depth: 0,
18        }
19    }
20
21    /// Read `count` bits LSB-first into a u64.
22    pub fn read_bits(&mut self, count: u8) -> Result<u64, DecodeError> {
23        let mut result: u64 = 0;
24        for i in 0..count {
25            if self.byte_pos >= self.data.len() {
26                return Err(DecodeError::UnexpectedEof);
27            }
28            let bit = (self.data[self.byte_pos] >> self.bit_offset) & 1;
29            result |= u64::from(bit) << i;
30            self.bit_offset += 1;
31            if self.bit_offset == 8 {
32                self.byte_pos += 1;
33                self.bit_offset = 0;
34            }
35        }
36        Ok(result)
37    }
38
39    /// Read a single bit as bool.
40    pub fn read_bool(&mut self) -> Result<bool, DecodeError> {
41        Ok(self.read_bits(1)? != 0)
42    }
43
44    /// Advance to the next byte boundary, discarding any remaining bits in the current byte.
45    /// Infallible.
46    pub fn flush_to_byte_boundary(&mut self) {
47        if self.bit_offset > 0 {
48            self.byte_pos += 1;
49            self.bit_offset = 0;
50        }
51    }
52
53    /// Remaining bytes from byte_pos.
54    fn remaining(&self) -> usize {
55        self.data.len().saturating_sub(self.byte_pos)
56    }
57
58    pub fn read_u8(&mut self) -> Result<u8, DecodeError> {
59        self.flush_to_byte_boundary();
60        if self.remaining() < 1 {
61            return Err(DecodeError::UnexpectedEof);
62        }
63        let v = self.data[self.byte_pos];
64        self.byte_pos += 1;
65        Ok(v)
66    }
67
68    pub fn read_u16(&mut self) -> Result<u16, DecodeError> {
69        self.flush_to_byte_boundary();
70        if self.remaining() < 2 {
71            return Err(DecodeError::UnexpectedEof);
72        }
73        let bytes: [u8; 2] = self.data[self.byte_pos..self.byte_pos + 2]
74            .try_into()
75            .unwrap();
76        self.byte_pos += 2;
77        Ok(u16::from_le_bytes(bytes))
78    }
79
80    pub fn read_u32(&mut self) -> Result<u32, DecodeError> {
81        self.flush_to_byte_boundary();
82        if self.remaining() < 4 {
83            return Err(DecodeError::UnexpectedEof);
84        }
85        let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
86            .try_into()
87            .unwrap();
88        self.byte_pos += 4;
89        Ok(u32::from_le_bytes(bytes))
90    }
91
92    pub fn read_u64(&mut self) -> Result<u64, DecodeError> {
93        self.flush_to_byte_boundary();
94        if self.remaining() < 8 {
95            return Err(DecodeError::UnexpectedEof);
96        }
97        let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
98            .try_into()
99            .unwrap();
100        self.byte_pos += 8;
101        Ok(u64::from_le_bytes(bytes))
102    }
103
104    pub fn read_i8(&mut self) -> Result<i8, DecodeError> {
105        self.flush_to_byte_boundary();
106        if self.remaining() < 1 {
107            return Err(DecodeError::UnexpectedEof);
108        }
109        let bytes: [u8; 1] = [self.data[self.byte_pos]];
110        self.byte_pos += 1;
111        Ok(i8::from_le_bytes(bytes))
112    }
113
114    pub fn read_i16(&mut self) -> Result<i16, DecodeError> {
115        self.flush_to_byte_boundary();
116        if self.remaining() < 2 {
117            return Err(DecodeError::UnexpectedEof);
118        }
119        let bytes: [u8; 2] = self.data[self.byte_pos..self.byte_pos + 2]
120            .try_into()
121            .unwrap();
122        self.byte_pos += 2;
123        Ok(i16::from_le_bytes(bytes))
124    }
125
126    pub fn read_i32(&mut self) -> Result<i32, DecodeError> {
127        self.flush_to_byte_boundary();
128        if self.remaining() < 4 {
129            return Err(DecodeError::UnexpectedEof);
130        }
131        let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
132            .try_into()
133            .unwrap();
134        self.byte_pos += 4;
135        Ok(i32::from_le_bytes(bytes))
136    }
137
138    pub fn read_i64(&mut self) -> Result<i64, DecodeError> {
139        self.flush_to_byte_boundary();
140        if self.remaining() < 8 {
141            return Err(DecodeError::UnexpectedEof);
142        }
143        let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
144            .try_into()
145            .unwrap();
146        self.byte_pos += 8;
147        Ok(i64::from_le_bytes(bytes))
148    }
149
150    pub fn read_f32(&mut self) -> Result<f32, DecodeError> {
151        self.flush_to_byte_boundary();
152        if self.remaining() < 4 {
153            return Err(DecodeError::UnexpectedEof);
154        }
155        let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
156            .try_into()
157            .unwrap();
158        self.byte_pos += 4;
159        Ok(f32::from_le_bytes(bytes))
160    }
161
162    pub fn read_f64(&mut self) -> Result<f64, DecodeError> {
163        self.flush_to_byte_boundary();
164        if self.remaining() < 8 {
165            return Err(DecodeError::UnexpectedEof);
166        }
167        let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
168            .try_into()
169            .unwrap();
170        self.byte_pos += 8;
171        Ok(f64::from_le_bytes(bytes))
172    }
173
174    /// Read a LEB128-encoded u64, consuming at most `max_bytes` bytes.
175    pub fn read_leb128(&mut self, max_bytes: u8) -> Result<u64, DecodeError> {
176        self.flush_to_byte_boundary();
177        let (value, consumed) = crate::leb128::decode(&self.data[self.byte_pos..], max_bytes)?;
178        self.byte_pos += consumed;
179        Ok(value)
180    }
181
182    /// Read a ZigZag + LEB128 encoded signed integer.
183    pub fn read_zigzag(&mut self, _type_bits: u8, max_bytes: u8) -> Result<i64, DecodeError> {
184        let raw = self.read_leb128(max_bytes)?;
185        Ok(crate::zigzag::zigzag_decode(raw))
186    }
187
188    /// Read a length-prefixed UTF-8 string.
189    pub fn read_string(&mut self) -> Result<String, DecodeError> {
190        self.flush_to_byte_boundary();
191        let len = self.read_leb128(crate::MAX_LENGTH_PREFIX_BYTES)?;
192        if len > MAX_BYTES_LENGTH {
193            return Err(DecodeError::LimitExceeded {
194                field: "string",
195                limit: MAX_BYTES_LENGTH,
196                actual: len,
197            });
198        }
199        let len = len as usize;
200        if self.remaining() < len {
201            return Err(DecodeError::UnexpectedEof);
202        }
203        let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
204        self.byte_pos += len;
205        String::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8)
206    }
207
208    /// Read a length-prefixed byte vector.
209    pub fn read_bytes(&mut self) -> Result<Vec<u8>, DecodeError> {
210        self.flush_to_byte_boundary();
211        let len = self.read_leb128(crate::MAX_LENGTH_PREFIX_BYTES)?;
212        if len > MAX_BYTES_LENGTH {
213            return Err(DecodeError::LimitExceeded {
214                field: "bytes",
215                limit: MAX_BYTES_LENGTH,
216                actual: len,
217            });
218        }
219        let len = len as usize;
220        if self.remaining() < len {
221            return Err(DecodeError::UnexpectedEof);
222        }
223        let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
224        self.byte_pos += len;
225        Ok(bytes)
226    }
227
228    /// Read exactly `len` raw bytes with no length prefix.
229    pub fn read_raw_bytes(&mut self, len: usize) -> Result<Vec<u8>, DecodeError> {
230        self.flush_to_byte_boundary();
231        if self.remaining() < len {
232            return Err(DecodeError::UnexpectedEof);
233        }
234        let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
235        self.byte_pos += len;
236        Ok(bytes)
237    }
238
239    /// Increment recursion depth; return error if limit exceeded.
240    pub fn enter_recursive(&mut self) -> Result<(), DecodeError> {
241        self.recursion_depth += 1;
242        if self.recursion_depth > MAX_RECURSION_DEPTH {
243            return Err(DecodeError::RecursionLimitExceeded);
244        }
245        Ok(())
246    }
247
248    /// Decrement recursion depth.
249    pub fn leave_recursive(&mut self) {
250        self.recursion_depth = self.recursion_depth.saturating_sub(1);
251    }
252}
253
/// Unit tests for `BitReader`, mostly round-trips through `BitWriter`.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BitWriter;

    #[test]
    fn read_single_bit() {
        let mut r = BitReader::new(&[0x01]);
        assert!(r.read_bool().unwrap());
    }

    /// Reading more bits than the input holds reports EOF instead of panicking.
    #[test]
    fn read_bits_past_end_is_eof() {
        assert_eq!(
            BitReader::new(&[0xFF]).read_bits(9).unwrap_err(),
            DecodeError::UnexpectedEof
        );
    }

    /// Fields narrower than a byte may straddle byte boundaries.
    #[test]
    fn round_trip_sub_byte() {
        let mut w = BitWriter::new();
        w.write_bits(5, 3);
        w.write_bits(19, 5);
        w.write_bits(42, 6);
        let buf = w.finish();
        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(3).unwrap(), 5);
        assert_eq!(r.read_bits(5).unwrap(), 19);
        assert_eq!(r.read_bits(6).unwrap(), 42);
    }

    #[test]
    fn round_trip_u16() {
        let mut w = BitWriter::new();
        w.write_u16(0x1234);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_u16().unwrap(), 0x1234);
    }

    #[test]
    fn round_trip_i32_neg() {
        let mut w = BitWriter::new();
        w.write_i32(-42);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_i32().unwrap(), -42);
    }

    #[test]
    fn round_trip_f32() {
        let mut w = BitWriter::new();
        w.write_f32(std::f32::consts::PI);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_f32().unwrap(), std::f32::consts::PI);
    }

    /// NaN must come back with the exact bit pattern that was written.
    #[test]
    fn round_trip_f64_nan() {
        let mut w = BitWriter::new();
        w.write_f64(f64::NAN);
        let b = w.finish();
        let v = BitReader::new(&b).read_f64().unwrap();
        assert!(v.is_nan());
        assert_eq!(v.to_bits(), 0x7FF8000000000000);
    }

    #[test]
    fn round_trip_string() {
        let mut w = BitWriter::new();
        w.write_string("hello");
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_string().unwrap(), "hello");
    }

    #[test]
    fn round_trip_leb128() {
        let mut w = BitWriter::new();
        w.write_leb128(300);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_leb128(4).unwrap(), 300);
    }

    #[test]
    fn round_trip_zigzag() {
        let mut w = BitWriter::new();
        w.write_zigzag(-42, 64);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_zigzag(64, 10).unwrap(), -42);
    }

    #[test]
    fn unexpected_eof() {
        assert_eq!(
            BitReader::new(&[]).read_u8().unwrap_err(),
            DecodeError::UnexpectedEof
        );
    }

    /// `read_raw_bytes` returns exactly the requested bytes and reports EOF when short.
    #[test]
    fn raw_bytes_exact_and_eof() {
        let mut r = BitReader::new(&[1, 2, 3]);
        assert_eq!(r.read_raw_bytes(2).unwrap(), vec![1, 2]);
        assert_eq!(r.read_raw_bytes(2).unwrap_err(), DecodeError::UnexpectedEof);
    }

    #[test]
    fn invalid_utf8() {
        let mut w = BitWriter::new();
        w.write_leb128(2);
        w.write_raw_bytes(&[0xFF, 0xFE]);
        let b = w.finish();
        assert_eq!(
            BitReader::new(&b).read_string().unwrap_err(),
            DecodeError::InvalidUtf8
        );
    }

    /// Entering exactly `MAX_RECURSION_DEPTH` levels succeeds; one more fails.
    #[test]
    fn recursion_depth_limit() {
        let mut r = BitReader::new(&[]);
        // Use the crate constant so the test tracks the configured limit.
        for _ in 0..MAX_RECURSION_DEPTH {
            r.enter_recursive().unwrap();
        }
        assert_eq!(
            r.enter_recursive().unwrap_err(),
            DecodeError::RecursionLimitExceeded
        );
    }

    /// Leaving one level at the limit frees room for one more `enter_recursive`.
    #[test]
    fn recursion_depth_leave() {
        let mut r = BitReader::new(&[]);
        for _ in 0..MAX_RECURSION_DEPTH {
            r.enter_recursive().unwrap();
        }
        r.leave_recursive();
        r.enter_recursive().unwrap();
    }

    /// After a partial byte, flushing skips the padding bits on both sides.
    #[test]
    fn flush_reader() {
        let mut w = BitWriter::new();
        w.write_bits(0b101, 3);
        w.flush_to_byte_boundary();
        w.write_u8(0xAB);
        let b = w.finish();
        let mut r = BitReader::new(&b);
        assert_eq!(r.read_bits(3).unwrap(), 0b101);
        r.flush_to_byte_boundary();
        assert_eq!(r.read_u8().unwrap(), 0xAB);
    }
}