1use std::{error, fmt};
2
/// Errors returned by [`Cursor`] operations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Error {
    /// An operation needed more bits or bytes than remain in the buffer.
    BufferOverflow,
    /// A bit width was outside the supported range, or a variable-bit-rate
    /// value was too long to fit in a `u64`.
    VbrOverflow,
    /// A byte- or 32-bit-aligned operation was attempted at an unaligned
    /// position.
    Alignment,
}

impl fmt::Display for Error {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(match self {
            Self::BufferOverflow => "buffer overflow",
            Self::VbrOverflow => "vbr overflow",
            Self::Alignment => "bad alignment",
        })
    }
}

impl error::Error for Error {}
21
22#[derive(Clone)]
23pub struct Cursor<'input> {
24 buffer: &'input [u8],
25 offset: usize,
26}
27
28impl<'input> Cursor<'input> {
29 #[must_use]
30 pub fn new(buffer: &'input [u8]) -> Self {
31 Self { buffer, offset: 0 }
32 }
33
34 #[must_use]
35 pub fn is_at_end(&self) -> bool {
36 self.offset >= (self.buffer.len() << 3)
37 }
38
39 #[inline]
40 pub fn peek(&self, bits: u8) -> Result<u64, Error> {
41 self.read_bits(bits).ok_or(Error::BufferOverflow)
42 }
43
44 #[inline]
45 pub fn read(&mut self, bits: u8) -> Result<u64, Error> {
46 if bits < 1 || bits > 64 {
47 return Err(Error::VbrOverflow);
48 }
49 let res = self.peek(bits)?;
50 self.offset += bits as usize;
51 Ok(res)
52 }
53
54 fn read_bits(&self, count: u8) -> Option<u64> {
55 let upper_bound = self.offset + count as usize;
56 let top_byte_index = upper_bound >> 3;
57 let mut res = 0;
58 if upper_bound & 7 != 0 {
59 let mask = (1u8 << (upper_bound & 7) as u8) - 1;
60 res = u64::from(*self.buffer.get(top_byte_index)? & mask);
61 }
62 for i in ((self.offset >> 3)..(upper_bound >> 3)).rev() {
63 res <<= 8;
64 res |= u64::from(*self.buffer.get(i)?);
65 }
66 if self.offset & 7 != 0 {
67 res >>= self.offset as u64 & 7;
68 }
69 Some(res)
70 }
71
72 pub fn read_bytes(&mut self, length_bytes: usize) -> Result<&'input [u8], Error> {
73 if !self.offset.is_multiple_of(8) {
74 return Err(Error::Alignment);
75 }
76 let byte_start = self.offset >> 3;
77 let byte_end = byte_start + length_bytes;
78 let bytes = self
79 .buffer
80 .get(byte_start..byte_end)
81 .ok_or(Error::BufferOverflow)?;
82 self.offset = byte_end << 3;
83 Ok(bytes)
84 }
85
86 pub fn skip_bytes(&mut self, count: usize) -> Result<(), Error> {
87 if !self.offset.is_multiple_of(8) {
88 return Err(Error::Alignment);
89 }
90 let byte_end = (self.offset >> 3) + count;
91 if byte_end > self.buffer.len() {
92 return Err(Error::BufferOverflow);
93 }
94 self.offset = byte_end << 3;
95 Ok(())
96 }
97
98 pub(crate) fn take_slice(&mut self, length_bytes: usize) -> Result<Self, Error> {
101 if !self.offset.is_multiple_of(32) {
102 return Err(Error::Alignment);
103 }
104 Ok(Cursor {
105 buffer: self.read_bytes(length_bytes)?,
106 offset: 0,
107 })
108 }
109
110 #[inline]
113 pub fn read_vbr(&mut self, width: u8) -> Result<u64, Error> {
114 if width < 1 || width > 32 {
115 return Err(Error::VbrOverflow);
117 }
118 let test_bit = 1u64 << (width - 1);
119 let mask = test_bit - 1;
120 let mut res = 0;
121 let mut offset = 0;
122 loop {
123 let next = self.read(width)?;
124 res |= (next & mask) << offset;
125 offset += width - 1;
126 if offset > 63 + width {
128 return Err(Error::VbrOverflow);
129 }
130 if next & test_bit == 0 {
131 break;
132 }
133 }
134 Ok(res)
135 }
136
137 pub fn align32(&mut self) -> Result<(), Error> {
139 let new_offset = if self.offset.is_multiple_of(32) {
140 self.offset
141 } else {
142 (self.offset + 32) & !(32 - 1)
143 };
144 self.buffer = self
145 .buffer
146 .get((new_offset >> 3)..)
147 .ok_or(Error::BufferOverflow)?;
148 self.offset = 0;
149 Ok(())
150 }
151
152 #[must_use]
154 pub fn unconsumed_bit_len(&self) -> usize {
155 (self.buffer.len() << 3) - self.offset
156 }
157}
158
/// `Debug` adapter that prints a byte slice as one hex run followed by its
/// length, e.g. `[0x01ab; 2]`, eliding everything after the first 200 bytes.
struct CursorDebugBytes<'a>(&'a [u8]);

impl fmt::Debug for CursorDebugBytes<'_> {
    #[cold]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        const SHOWN: usize = 200;
        let all = self.0;
        let (head, _) = all.split_at(all.len().min(SHOWN));
        f.write_str("[0x")?;
        head.iter().try_for_each(|&byte| write!(f, "{byte:02x}"))?;
        if all.len() > SHOWN {
            // Mark that the hex run above was cut short.
            f.write_str("...")?;
        }
        write!(f, "; {}]", all.len())
    }
}
174
175impl fmt::Debug for Cursor<'_> {
176 #[cold]
178 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
179 let byte_offset = self.offset / 8;
180 let bit_offset = self.offset % 8;
181 let buffer = CursorDebugBytes(self.buffer.get(byte_offset..).unwrap_or_default());
182 f.debug_struct("Cursor")
183 .field("offset", &bit_offset)
184 .field("buffer", &buffer)
185 .field("nextvbr6", &self.peek(6).ok())
186 .finish()
187 }
188}
189
#[test]
fn test_cursor_bits() {
    // Single byte: bits are consumed LSB first, so the set top bit comes last.
    let mut single = Cursor::new(&[0b1000_0000]);
    assert_eq!(single.peek(1).unwrap(), 0);
    assert!(single.peek(9).is_err(), "only 8 bits are available");
    for width in 2..=7 {
        assert_eq!(single.peek(width).unwrap(), 0, "peek({width})");
    }
    assert_eq!(single.peek(8).unwrap(), 0b1000_0000);
    assert_eq!(single.read(6).unwrap(), 0);
    assert_eq!(single.peek(2).unwrap(), 0b10);
    assert_eq!(single.peek(1).unwrap(), 0);
    assert_eq!(single.read(1).unwrap(), 0);
    assert_eq!(single.peek(1).unwrap(), 0b1);
    assert_eq!(single.read(1).unwrap(), 0b1);

    let mut stream = Cursor::new(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 0x55, 0x11, 0xff, 1, 127, 0x51]);
    assert_eq!(stream.peek(1).unwrap(), 0);
    assert_eq!(stream.peek(9).unwrap(), 0b1_0000_0000);
    for width in 2..=8 {
        assert_eq!(stream.peek(width).unwrap(), 0, "peek({width})");
    }
    assert_eq!(stream.peek(9).unwrap(), 0b1_0000_0000);
    assert_eq!(stream.peek(7).unwrap(), 0);

    assert!(stream.read(0).is_err(), "zero-width reads are rejected");
    // Consecutive reads of widths 1..=9 walk across byte boundaries.
    for (width, want) in [(1, 0), (2, 0), (3, 0), (4, 4), (5, 0), (6, 4), (7, 24), (8, 64), (9, 80)] {
        assert_eq!(stream.read(width).unwrap(), want, "read({width})");
    }
    stream.align32().unwrap();

    let mut slice = stream.take_slice(6).unwrap();
    assert_eq!(stream.read(8).unwrap(), 0x51);

    assert!(slice.read(0).is_err());
    for (width, want) in [(1, 0), (2, 0), (3, 1), (4, 4), (5, 21), (6, 34), (7, 120), (8, 31)] {
        assert_eq!(slice.read(width).unwrap(), want, "read({width})");
    }
    assert!(slice.read(63).is_err(), "not enough bits left");
    assert_eq!(slice.read(9).unwrap(), 496);
    assert!(slice.read(0).is_err());
    assert_eq!(slice.read(1).unwrap(), 1);
    assert!(slice.align32().is_err());
    assert_eq!(slice.read(2).unwrap(), 1);
    assert!(slice.align32().is_err());
    assert!(slice.read(1).is_err());
}
253
#[test]
fn test_cursor_bytes() {
    let mut cursor = Cursor::new(&[0, 1, 2, 3, 4, 5, 6, 7, 8]);
    // Already at offset 0, so aligning leaves the position unchanged.
    cursor.align32().unwrap();
    // Multi-byte peeks assemble bytes little-endian.
    for (bits, want) in [(16, 0x0100), (24, 0x02_0100), (32, 0x0302_0100)] {
        assert_eq!(cursor.peek(bits).unwrap(), want, "peek({bits})");
    }
    assert_eq!(cursor.read(16).unwrap(), 0x0100);
    assert_eq!(cursor.read(8).unwrap(), 0x02);
    let four = cursor.read_bytes(4).unwrap();
    assert_eq!(four, [3, 4, 5, 6]);
    cursor.skip_bytes(1).unwrap();
    assert!(cursor.read_bytes(2).is_err(), "only one byte remains");
    assert_eq!(cursor.read_bytes(1).unwrap(), [8]);
}