Skip to main content

llvm_bitcode/
bits.rs

1use std::{error, fmt};
2
/// Failure modes of [`Cursor`] operations.
///
/// The type is a plain fieldless enum, so it is cheap to copy and can be
/// compared directly (e.g. `assert_eq!(err, Error::Alignment)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// Not enough bits or bytes remain in the buffer for the requested read,
    /// skip, or alignment. Also returned by `peek` for a width outside `1..=64`.
    BufferOverflow,
    /// A bit width outside the supported range was requested for `read` or a
    /// VBR read, or a VBR-encoded value ran past 64 bits.
    VbrOverflow,
    /// A byte-aligned or 32-bit-aligned operation was attempted at an
    /// unaligned bit position.
    Alignment,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::BufferOverflow => "buffer overflow",
            Self::VbrOverflow => "vbr overflow",
            Self::Alignment => "bad alignment",
        })
    }
}

impl error::Error for Error {}
21
22#[derive(Clone)]
23pub struct Cursor<'input> {
24    buffer: &'input [u8],
25    offset: usize,
26}
27
28impl<'input> Cursor<'input> {
29    #[must_use]
30    pub fn new(buffer: &'input [u8]) -> Self {
31        Self { buffer, offset: 0 }
32    }
33
34    #[must_use]
35    pub fn is_at_end(&self) -> bool {
36        self.offset >= (self.buffer.len() << 3)
37    }
38
39    #[inline]
40    pub fn peek(&self, bits: u8) -> Result<u64, Error> {
41        self.read_bits(bits).ok_or(Error::BufferOverflow)
42    }
43
44    #[inline]
45    pub fn read(&mut self, bits: u8) -> Result<u64, Error> {
46        if bits < 1 || bits > 64 {
47            return Err(Error::VbrOverflow);
48        }
49        let res = self.read_bits(bits).ok_or(Error::BufferOverflow)?;
50        self.offset += bits as usize;
51        Ok(res)
52    }
53
54    #[inline]
55    fn read_bits(&self, bits: u8) -> Option<u64> {
56        if bits == 0 || bits > 64 {
57            return None;
58        }
59
60        let byte_start = self.offset >> 3;
61        let shift = (self.offset & 7) as u8;
62
63        let extra_len = shift + (bits & 7);
64        let byte_len = usize::from(extra_len.div_ceil(8) + (bits >> 3));
65
66        let bytes = self.buffer.get(byte_start..byte_start + byte_len)?;
67        let accumulate = byte_len.min(8);
68        let mut res = 0u64;
69        for (i, &b) in bytes.get(..accumulate)?.iter().enumerate().take(accumulate) {
70            res |= (b as u64) << (i << 3);
71        }
72
73        res >>= shift;
74
75        if let Some(&extra_byte) = bytes.get(8) {
76            res |= u64::from(extra_byte) << (64 - shift);
77        }
78
79        if bits < 64 {
80            res &= (1 << bits) - 1;
81        }
82
83        Some(res)
84    }
85
86    pub fn read_bytes(&mut self, length_bytes: usize) -> Result<&'input [u8], Error> {
87        if !self.offset.is_multiple_of(8) {
88            return Err(Error::Alignment);
89        }
90        let byte_start = self.offset >> 3;
91        let byte_end = byte_start + length_bytes;
92        let bytes = self
93            .buffer
94            .get(byte_start..byte_end)
95            .ok_or(Error::BufferOverflow)?;
96        self.offset = byte_end << 3;
97        Ok(bytes)
98    }
99
100    pub fn skip_bytes(&mut self, count: usize) -> Result<(), Error> {
101        if !self.offset.is_multiple_of(8) {
102            return Err(Error::Alignment);
103        }
104        let byte_end = (self.offset >> 3) + count;
105        if byte_end > self.buffer.len() {
106            return Err(Error::BufferOverflow);
107        }
108        self.offset = byte_end << 3;
109        Ok(())
110    }
111
112    /// Create a cursor for `length_bytes`, and skip over `length_bytes`
113    /// Must be aligned to 32 bits.
114    pub(crate) fn take_slice(&mut self, length_bytes: usize) -> Result<Self, Error> {
115        if !self.offset.is_multiple_of(32) {
116            return Err(Error::Alignment);
117        }
118        Ok(Cursor {
119            buffer: self.read_bytes(length_bytes)?,
120            offset: 0,
121        })
122    }
123
124    /// Read a VBR number in `width`-wide encoding.
125    /// The number may be up to 64-bit long regardless of the `width`.
126    #[inline]
127    pub fn read_vbr(&mut self, width: u8) -> Result<u64, Error> {
128        match width {
129            6 => self.read_vbr_fixed::<6>(),
130            8 => self.read_vbr_fixed::<8>(),
131            _ => self.read_vbr_inline(width),
132        }
133    }
134
135    pub(crate) fn read_vbr_fixed<const WIDTH: u8>(&mut self) -> Result<u64, Error> {
136        self.read_vbr_inline(WIDTH)
137    }
138
139    #[inline(always)]
140    pub(crate) fn read_vbr_inline(&mut self, width: u8) -> Result<u64, Error> {
141        if width < 1 || width > 32 {
142            // This is `MaxChunkSize` in LLVM
143            return Err(Error::VbrOverflow);
144        }
145        let test_bit = 1u64 << (width - 1);
146        let mask = test_bit - 1;
147        let mut res = 0;
148        let mut offset = 0;
149        loop {
150            let next = self.read(width)?;
151            res |= (next & mask) << offset;
152            offset += width - 1;
153            // 64 may not be divisible by width
154            if offset > 63 + width {
155                return Err(Error::VbrOverflow);
156            }
157            if next & test_bit == 0 {
158                break;
159            }
160        }
161        Ok(res)
162    }
163
164    /// Skip bytes until a 32-bit boundary (no-op if already aligned)
165    pub fn align32(&mut self) -> Result<(), Error> {
166        let new_offset = if self.offset.is_multiple_of(32) {
167            self.offset
168        } else {
169            (self.offset + 32) & !(32 - 1)
170        };
171        self.buffer = self
172            .buffer
173            .get((new_offset >> 3)..)
174            .ok_or(Error::BufferOverflow)?;
175        self.offset = 0;
176        Ok(())
177    }
178
179    /// Maximum number of bits that can be read
180    #[must_use]
181    pub fn unconsumed_bit_len(&self) -> usize {
182        (self.buffer.len() << 3) - self.offset
183    }
184}
185
/// Debug wrapper that renders a byte slice as one compact hex string,
/// truncated after 200 bytes and followed by the total length.
struct CursorDebugBytes<'a>(&'a [u8]);

impl fmt::Debug for CursorDebugBytes<'_> {
    #[cold]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const MAX_SHOWN: usize = 200;
        let all = self.0;
        let shown = &all[..all.len().min(MAX_SHOWN)];

        f.write_str("[0x")?;
        shown.iter().try_for_each(|byte| write!(f, "{byte:02x}"))?;
        let truncated = all.len() > MAX_SHOWN;
        if truncated {
            f.write_str("...")?;
        }
        write!(f, "; {}]", all.len())
    }
}
201
202impl fmt::Debug for Cursor<'_> {
203    /// Debug-print only the accessible part of the internal buffer
204    #[cold]
205    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206        let byte_offset = self.offset / 8;
207        let bit_offset = self.offset % 8;
208        let buffer = CursorDebugBytes(self.buffer.get(byte_offset..).unwrap_or_default());
209        f.debug_struct("Cursor")
210            .field("offset", &bit_offset)
211            .field("buffer", &buffer)
212            .field("nextvbr6", &self.peek(6).ok())
213            .finish()
214    }
215}
216
/// Every lead-in width 1..=64 must leave a full 64-bit read plus one more bit
/// available in 17 bytes of all-ones.
#[test]
fn test_all_bits() {
    let ones = [!0u8; 17];
    (1..=64u8).for_each(|lead| {
        let mut cursor = Cursor::new(&ones);
        cursor.read(lead).unwrap();
        assert_eq!(u64::MAX, cursor.read(64).unwrap());
        assert_eq!(1, cursor.read(1).unwrap());
    });
}
226
/// Exercises `peek`/`read` across byte boundaries, `align32`, and `take_slice`
/// sub-cursors. Reads mutate the cursors, so their order is significant.
#[test]
fn test_cursor_bits() {
    // A single byte with only its most significant bit set.
    let msb_only = [0b1000_0000u8];
    let mut c = Cursor::new(&msb_only);
    assert!(c.peek(9).is_err());
    let single_byte_peeks: [(u8, u64); 8] = [
        (1, 0),
        (2, 0),
        (3, 0),
        (4, 0),
        (5, 0),
        (6, 0),
        (7, 0),
        (8, 0b1000_0000),
    ];
    for (width, want) in single_byte_peeks {
        assert_eq!(want, c.peek(width).unwrap());
    }
    assert_eq!(0, c.read(6).unwrap());
    assert_eq!(0b10, c.peek(2).unwrap());
    assert_eq!(0, c.peek(1).unwrap());
    assert_eq!(0, c.read(1).unwrap());
    assert_eq!(0b1, c.peek(1).unwrap());
    assert_eq!(0b1, c.read(1).unwrap());

    // Multi-byte buffer: reads of increasing width walk across byte boundaries.
    let bytes = [0u8, 1, 2, 3, 4, 5, 6, 7, 8, 0x55, 0x11, 0xff, 1, 127, 0x51];
    let mut c = Cursor::new(&bytes);
    assert_eq!(0b1_0000_0000, c.peek(9).unwrap());
    for width in 1..=8 {
        assert_eq!(0, c.peek(width).unwrap());
    }
    assert!(c.read(0).is_err());
    let outer_reads: [(u8, u64); 9] = [
        (1, 0),
        (2, 0),
        (3, 0),
        (4, 4),
        (5, 0),
        (6, 4),
        (7, 24),
        (8, 64),
        (9, 80),
    ];
    for (width, want) in outer_reads {
        assert_eq!(want, c.read(width).unwrap());
    }

    c.align32().unwrap();
    let mut d = c.take_slice(6).unwrap();
    assert_eq!(0x51, c.read(8).unwrap());

    // The 6-byte sub-cursor starts at its own bit 0.
    assert!(d.read(0).is_err());
    let inner_reads: [(u8, u64); 8] = [
        (1, 0),
        (2, 0),
        (3, 1),
        (4, 4),
        (5, 21),
        (6, 34),
        (7, 120),
        (8, 31),
    ];
    for (width, want) in inner_reads {
        assert_eq!(want, d.read(width).unwrap());
    }
    assert!(d.read(63).is_err());
    assert_eq!(496, d.read(9).unwrap());
    assert!(d.read(0).is_err());
    assert_eq!(1, d.read(1).unwrap());
    assert!(d.align32().is_err());
    assert_eq!(1, d.read(2).unwrap());
    assert!(d.align32().is_err());
    assert!(d.read(1).is_err());
}
290
/// Unaligned wide peeks, including 64-bit reads that straddle nine bytes.
#[test]
fn test_read_bits_edge_cases() {
    /// Peeks every width 1..=64 after each sub-byte lead-in of 0..8 bits.
    fn sweep_all_widths(base: &Cursor<'_>) {
        for lead in 0..8u8 {
            for width in 1..=64u8 {
                let mut probe = base.clone();
                if lead != 0 {
                    probe.read(lead).unwrap();
                }
                probe.peek(width).unwrap();
            }
        }
    }

    let straddle = [0xFFu8, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00];
    let mut cursor = Cursor::new(&straddle);
    cursor.read(1).unwrap();
    cursor.peek(64).unwrap();

    let alternating = [0xAAu8; 10];
    sweep_all_widths(&Cursor::new(&alternating));

    let ones = [0xFFu8; 10];
    let mut cursor = Cursor::new(&ones);
    cursor.read(7).unwrap();
    let wide = cursor.peek(64).unwrap();
    assert_eq!(wide, u64::MAX);

    let counting = [0x01u8, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A];
    let mut cursor = Cursor::new(&counting);
    assert_eq!(cursor.peek(8).unwrap(), 0x01);
    cursor.read(8).unwrap();
    assert_eq!(cursor.peek(8).unwrap(), 0x02);
    cursor.read(4).unwrap();
    assert_eq!(cursor.peek(8).unwrap(), 0x30);

    let base = Cursor::new(&ones);
    let mut after_seven = base.clone();
    after_seven.read(7).unwrap();
    after_seven.peek(58).unwrap();
    let mut after_one = base.clone();
    after_one.read(1).unwrap();
    after_one.peek(64).unwrap();
    sweep_all_widths(&base);
}
340
/// Byte-level operations: little-endian peeks, `read_bytes`, `skip_bytes`,
/// and an over-long `read_bytes` that must fail without consuming anything.
#[test]
fn test_cursor_bytes() {
    let bytes = [0u8, 1, 2, 3, 4, 5, 6, 7, 8];
    let mut cursor = Cursor::new(&bytes);
    // Already at bit 0, so this is a no-op.
    cursor.align32().unwrap();
    for (width, want) in [(16u8, 0x0100u64), (24, 0x02_0100), (32, 0x0302_0100)] {
        assert_eq!(want, cursor.peek(width).unwrap());
    }
    assert_eq!(0x0100, cursor.read(16).unwrap());
    assert_eq!(0x02, cursor.read(8).unwrap());
    assert_eq!(&[3u8, 4, 5, 6][..], cursor.read_bytes(4).unwrap());
    cursor.skip_bytes(1).unwrap();
    assert!(cursor.read_bytes(2).is_err());
    assert_eq!(&[8u8][..], cursor.read_bytes(1).unwrap());
}