1use std::{error, fmt};
2
/// Errors produced while reading from a [`Cursor`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// A read needed more bits or bytes than remain in the buffer.
    BufferOverflow,
    /// A bit width was out of the accepted range, or a VBR-encoded value
    /// did not fit in 64 bits.
    VbrOverflow,
    /// A byte-oriented operation was attempted at a position that is not
    /// suitably aligned.
    Alignment,
}
9
10impl fmt::Display for Error {
11 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
12 f.write_str(match self {
13 Self::BufferOverflow => "buffer overflow",
14 Self::VbrOverflow => "vbr overflow",
15 Self::Alignment => "bad alignment",
16 })
17 }
18}
19
20impl error::Error for Error {}
21
/// Bit-granular reader over a borrowed byte slice.
///
/// Bits are consumed least-significant-bit first within each byte, and
/// multi-byte values are assembled little-endian.
#[derive(Clone)]
pub struct Cursor<'input> {
    /// Current read position, measured in bits from the start of `buffer`.
    offset: usize,
    /// The bytes being read.
    buffer: &'input [u8],
}
27
28impl<'input> Cursor<'input> {
29 #[must_use]
30 pub fn new(buffer: &'input [u8]) -> Self {
31 Self { buffer, offset: 0 }
32 }
33
34 #[must_use]
35 pub fn is_at_end(&self) -> bool {
36 self.offset >= (self.buffer.len() << 3)
37 }
38
39 #[inline]
40 pub fn peek(&self, bits: u8) -> Result<u64, Error> {
41 self.read_bits(bits).ok_or(Error::BufferOverflow)
42 }
43
44 #[inline]
45 pub fn read(&mut self, bits: u8) -> Result<u64, Error> {
46 if bits < 1 || bits > 64 {
47 return Err(Error::VbrOverflow);
48 }
49 let res = self.read_bits(bits).ok_or(Error::BufferOverflow)?;
50 self.offset += bits as usize;
51 Ok(res)
52 }
53
54 #[inline]
55 fn read_bits(&self, bits: u8) -> Option<u64> {
56 if bits == 0 || bits > 64 {
57 return None;
58 }
59
60 let byte_start = self.offset >> 3;
61 let shift = (self.offset & 7) as u8;
62
63 let extra_len = shift + (bits & 7);
64 let byte_len = usize::from(extra_len.div_ceil(8) + (bits >> 3));
65
66 let bytes = self.buffer.get(byte_start..byte_start + byte_len)?;
67 let accumulate = byte_len.min(8);
68 let mut res = 0u64;
69 for (i, &b) in bytes.get(..accumulate)?.iter().enumerate().take(accumulate) {
70 res |= (b as u64) << (i << 3);
71 }
72
73 res >>= shift;
74
75 if let Some(&extra_byte) = bytes.get(8) {
76 res |= u64::from(extra_byte) << (64 - shift);
77 }
78
79 if bits < 64 {
80 res &= (1 << bits) - 1;
81 }
82
83 Some(res)
84 }
85
86 pub fn read_bytes(&mut self, length_bytes: usize) -> Result<&'input [u8], Error> {
87 if !self.offset.is_multiple_of(8) {
88 return Err(Error::Alignment);
89 }
90 let byte_start = self.offset >> 3;
91 let byte_end = byte_start + length_bytes;
92 let bytes = self
93 .buffer
94 .get(byte_start..byte_end)
95 .ok_or(Error::BufferOverflow)?;
96 self.offset = byte_end << 3;
97 Ok(bytes)
98 }
99
100 pub fn skip_bytes(&mut self, count: usize) -> Result<(), Error> {
101 if !self.offset.is_multiple_of(8) {
102 return Err(Error::Alignment);
103 }
104 let byte_end = (self.offset >> 3) + count;
105 if byte_end > self.buffer.len() {
106 return Err(Error::BufferOverflow);
107 }
108 self.offset = byte_end << 3;
109 Ok(())
110 }
111
112 pub(crate) fn take_slice(&mut self, length_bytes: usize) -> Result<Self, Error> {
115 if !self.offset.is_multiple_of(32) {
116 return Err(Error::Alignment);
117 }
118 Ok(Cursor {
119 buffer: self.read_bytes(length_bytes)?,
120 offset: 0,
121 })
122 }
123
124 #[inline]
127 pub fn read_vbr(&mut self, width: u8) -> Result<u64, Error> {
128 match width {
129 6 => self.read_vbr_fixed::<6>(),
130 8 => self.read_vbr_fixed::<8>(),
131 _ => self.read_vbr_inline(width),
132 }
133 }
134
135 pub(crate) fn read_vbr_fixed<const WIDTH: u8>(&mut self) -> Result<u64, Error> {
136 self.read_vbr_inline(WIDTH)
137 }
138
139 #[inline(always)]
140 pub(crate) fn read_vbr_inline(&mut self, width: u8) -> Result<u64, Error> {
141 if width < 1 || width > 32 {
142 return Err(Error::VbrOverflow);
144 }
145 let test_bit = 1u64 << (width - 1);
146 let mask = test_bit - 1;
147 let mut res = 0;
148 let mut offset = 0;
149 loop {
150 let next = self.read(width)?;
151 res |= (next & mask) << offset;
152 offset += width - 1;
153 if offset > 63 + width {
155 return Err(Error::VbrOverflow);
156 }
157 if next & test_bit == 0 {
158 break;
159 }
160 }
161 Ok(res)
162 }
163
164 pub fn align32(&mut self) -> Result<(), Error> {
166 let new_offset = if self.offset.is_multiple_of(32) {
167 self.offset
168 } else {
169 (self.offset + 32) & !(32 - 1)
170 };
171 self.buffer = self
172 .buffer
173 .get((new_offset >> 3)..)
174 .ok_or(Error::BufferOverflow)?;
175 self.offset = 0;
176 Ok(())
177 }
178
179 #[must_use]
181 pub fn unconsumed_bit_len(&self) -> usize {
182 (self.buffer.len() << 3) - self.offset
183 }
184}
185
/// `Debug` adapter that prints a byte slice as one run of two-digit hex
/// bytes (truncated after the first 200 bytes), followed by the total length.
struct CursorDebugBytes<'a>(&'a [u8]);
187
188impl fmt::Debug for CursorDebugBytes<'_> {
189 #[cold]
190 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
191 f.write_str("[0x")?;
192 for &b in self.0.iter().take(200) {
193 write!(f, "{b:02x}")?;
194 }
195 if self.0.len() > 200 {
196 f.write_str("...")?;
197 }
198 write!(f, "; {}]", self.0.len())
199 }
200}
201
202impl fmt::Debug for Cursor<'_> {
203 #[cold]
205 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
206 let byte_offset = self.offset / 8;
207 let bit_offset = self.offset % 8;
208 let buffer = CursorDebugBytes(self.buffer.get(byte_offset..).unwrap_or_default());
209 f.debug_struct("Cursor")
210 .field("offset", &bit_offset)
211 .field("buffer", &buffer)
212 .field("nextvbr6", &self.peek(6).ok())
213 .finish()
214 }
215}
216
/// After a leading read of every width 1..=64, a full 64-bit read and a
/// single-bit read must still see all ones.
#[test]
fn test_all_bits() {
    // 17 bytes = 136 bits: enough for up to 64 + 64 + 1 bits.
    let ones = [u8::MAX; 17];
    for lead in 1..=64u8 {
        let mut cursor = Cursor::new(&ones);
        cursor.read(lead).unwrap();
        assert_eq!(cursor.read(64).unwrap(), u64::MAX);
        assert_eq!(cursor.read(1).unwrap(), 1);
    }
}
226
/// Exercises `peek`/`read` at small widths, reads that cross byte
/// boundaries, zero-width and too-long reads, `align32`, and `take_slice`.
#[test]
fn test_cursor_bits() {
    // One byte with only its top bit (bit 7) set.
    let mut c = Cursor::new(&[0b1000_0000]);
    assert_eq!(0, c.peek(1).unwrap());
    // Only 8 bits exist, so a 9-bit peek must fail.
    assert!(c.peek(9).is_err());
    assert_eq!(0, c.peek(2).unwrap());
    assert_eq!(0, c.peek(3).unwrap());
    assert_eq!(0, c.peek(4).unwrap());
    assert_eq!(0, c.peek(5).unwrap());
    assert_eq!(0, c.peek(6).unwrap());
    assert_eq!(0, c.peek(7).unwrap());
    assert_eq!(0b1000_0000, c.peek(8).unwrap());
    // Consume bits 0..6; bits 6 (0) and 7 (1) remain.
    assert_eq!(0, c.read(6).unwrap());
    assert_eq!(0b10, c.peek(2).unwrap());
    assert_eq!(0, c.peek(1).unwrap());
    assert_eq!(0, c.read(1).unwrap());
    assert_eq!(0b1, c.peek(1).unwrap());
    assert_eq!(0b1, c.read(1).unwrap());

    // 15-byte buffer; byte 1 supplies bit 8, so 9-bit peeks see 0x100.
    let mut c = Cursor::new(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 0x55, 0x11, 0xff, 1, 127, 0x51]);
    assert_eq!(0, c.peek(1).unwrap());
    assert_eq!(0b1_0000_0000, c.peek(9).unwrap());
    assert_eq!(0, c.peek(2).unwrap());
    assert_eq!(0, c.peek(3).unwrap());
    assert_eq!(0, c.peek(4).unwrap());
    assert_eq!(0, c.peek(5).unwrap());
    assert_eq!(0, c.peek(6).unwrap());
    assert_eq!(0, c.peek(7).unwrap());
    assert_eq!(0, c.peek(8).unwrap());
    // Peeking repeatedly must not move the cursor.
    assert_eq!(0b1_0000_0000, c.peek(9).unwrap());

    assert_eq!(0, c.peek(7).unwrap());
    // Zero-width reads are rejected.
    assert!(c.read(0).is_err());
    // Consecutive reads of widths 1..=9 consume 45 bits in total.
    assert_eq!(0, c.read(1).unwrap());
    assert_eq!(0, c.read(2).unwrap());
    assert_eq!(0, c.read(3).unwrap());
    assert_eq!(4, c.read(4).unwrap());
    assert_eq!(0, c.read(5).unwrap());
    assert_eq!(4, c.read(6).unwrap());
    assert_eq!(24, c.read(7).unwrap());
    assert_eq!(64, c.read(8).unwrap());
    assert_eq!(80, c.read(9).unwrap());
    // Bit 45 rounds up to bit 64: the cursor is rebased onto byte 8.
    c.align32().unwrap();
    // `d` covers the next 6 bytes; `c` moves past them to the final 0x51.
    let mut d = c.take_slice(6).unwrap();
    assert_eq!(0x51, c.read(8).unwrap());
    assert!(d.read(0).is_err());
    assert_eq!(0, d.read(1).unwrap());
    assert_eq!(0, d.read(2).unwrap());
    assert_eq!(1, d.read(3).unwrap());
    assert_eq!(4, d.read(4).unwrap());
    assert_eq!(21, d.read(5).unwrap());
    assert_eq!(34, d.read(6).unwrap());
    assert_eq!(120, d.read(7).unwrap());
    assert_eq!(31, d.read(8).unwrap());
    // 36 of `d`'s 48 bits are used, so a 63-bit read cannot fit.
    assert!(d.read(63).is_err());
    assert_eq!(496, d.read(9).unwrap());
    assert!(d.read(0).is_err());
    assert_eq!(1, d.read(1).unwrap());
    // At bit 46 the next 32-bit boundary (bit 64) is past the 48-bit slice.
    assert!(d.align32().is_err());
    assert_eq!(1, d.read(2).unwrap());
    assert!(d.align32().is_err());
    // All 48 bits are consumed.
    assert!(d.read(1).is_err());
}
290
/// Checks `peek` at every starting bit 0..8 and every width 1..=64 against a
/// bit-at-a-time reference, plus targeted reads that straddle a 9th byte.
#[test]
fn test_read_bits_edge_cases() {
    // Reference decoder: `bits` bits starting at bit `start`, LSB-first.
    fn reference(data: &[u8], start: usize, bits: u8) -> u64 {
        (0..usize::from(bits)).fold(0u64, |acc, i| {
            let pos = start + i;
            acc | (u64::from((data[pos >> 3] >> (pos & 7)) & 1) << i)
        })
    }

    // Peeks every width at every in-byte start and compares with `reference`.
    fn check_all_widths(data: &[u8]) {
        let base = Cursor::new(data);
        for start in 0..8u8 {
            for bits in 1..=64u8 {
                let mut c = base.clone();
                if start > 0 {
                    c.read(start).unwrap();
                }
                assert_eq!(
                    c.peek(bits).unwrap(),
                    reference(data, usize::from(start), bits),
                    "start={start} bits={bits}"
                );
            }
        }
    }

    // A 64-bit peek with a non-zero shift needs a 9th byte.
    let data = [0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x00];
    let mut c = Cursor::new(&data);
    c.read(1).unwrap();
    assert_eq!(c.peek(64).unwrap(), u64::MAX);

    check_all_widths(&[0xAA; 10]);

    let test_data = [0xFF; 10];
    let mut c = Cursor::new(&test_data);
    c.read(7).unwrap();
    assert_eq!(c.peek(64).unwrap(), 0xFFFF_FFFF_FFFF_FFFF);

    // Byte order: the high nibble of byte 1 joins the low nibble of byte 2.
    let mut c = Cursor::new(&[0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A]);
    assert_eq!(c.peek(8).unwrap(), 0x01);
    c.read(8).unwrap();
    assert_eq!(c.peek(8).unwrap(), 0x02);
    c.read(4).unwrap();
    assert_eq!(c.peek(8).unwrap(), 0x30);

    let data = [0xFF; 10];
    let c = Cursor::new(&data);
    let mut c_test = c.clone();
    c_test.read(7).unwrap();
    assert_eq!(c_test.peek(58).unwrap(), (1u64 << 58) - 1);
    let mut c_test2 = c.clone();
    c_test2.read(1).unwrap();
    assert_eq!(c_test2.peek(64).unwrap(), u64::MAX);
    check_all_widths(&data);

    // A non-uniform pattern so that byte-order or shift mistakes show up.
    check_all_widths(&[0x01, 0x23, 0x45, 0x67, 0x89, 0xAB, 0xCD, 0xEF, 0x5A, 0xA5]);
}
340
/// Mixes little-endian multi-byte peeks/reads with byte-level borrowing and
/// skipping, including a failed over-long request that must not consume.
#[test]
fn test_cursor_bytes() {
    let mut cursor = Cursor::new(&[0, 1, 2, 3, 4, 5, 6, 7, 8]);
    // Offset 0 is already 32-bit aligned, so this succeeds without moving.
    cursor.align32().unwrap();

    // Multi-byte peeks are little-endian and leave the position unchanged.
    assert_eq!(cursor.peek(16).unwrap(), 0x0100);
    assert_eq!(cursor.peek(24).unwrap(), 0x0002_0100);
    assert_eq!(cursor.peek(32).unwrap(), 0x0302_0100);

    assert_eq!(cursor.read(16).unwrap(), 0x0100);
    assert_eq!(cursor.read(8).unwrap(), 0x02);

    // Byte-aligned again: whole bytes can be borrowed or skipped.
    assert_eq!(cursor.read_bytes(4).unwrap(), [3, 4, 5, 6]);
    cursor.skip_bytes(1).unwrap();
    // Only byte 8 is left, so asking for two fails...
    assert!(cursor.read_bytes(2).is_err());
    // ...without consuming anything.
    assert_eq!(cursor.read_bytes(1).unwrap(), [8]);
}