jxl_bitstream/
bitstream.rs1use crate::{BitstreamResult, Error};
2
/// LSB-first bit reader over a borrowed byte slice.
///
/// Up to 64 bits are cached in `buf`; the lowest `remaining_buf_bits` of them are the
/// next unread bits of the stream, starting at bit 0.
#[derive(Clone)]
pub struct Bitstream<'buf> {
    /// Input bytes that have not yet been counted into `remaining_buf_bits`.
    bytes: &'buf [u8],
    /// Bit cache; bit 0 is the next bit to be read. Bits at and above
    /// `remaining_buf_bits` are either zero or a lookahead copy of the start of `bytes`.
    buf: u64,
    /// Number of bits consumed (read or skipped) since construction.
    num_read_bits: usize,
    /// Number of valid bits at the bottom of `buf`; the methods in this module keep
    /// it below 64.
    remaining_buf_bits: usize,
}

16impl std::fmt::Debug for Bitstream<'_> {
17 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18 f.debug_struct("Bitstream")
19 .field(
20 "bytes",
21 &format_args!(
22 "({} byte{} left)",
23 self.bytes.len(),
24 if self.bytes.len() == 1 { "" } else { "s" },
25 ),
26 )
27 .field("buf", &format_args!("0x{:016x}", self.buf))
28 .field("num_read_bits", &self.num_read_bits)
29 .field("remaining_buf_bits", &self.remaining_buf_bits)
30 .finish()
31 }
32}
33
34impl<'buf> Bitstream<'buf> {
35 #[inline]
37 pub fn new(bytes: &'buf [u8]) -> Self {
38 Self {
39 bytes,
40 buf: 0,
41 num_read_bits: 0,
42 remaining_buf_bits: 0,
43 }
44 }
45
46 #[inline]
48 pub fn num_read_bits(&self) -> usize {
49 self.num_read_bits
50 }
51}
52
53impl Bitstream<'_> {
54 #[inline]
56 fn refill(&mut self) {
57 if let &[b0, b1, b2, b3, b4, b5, b6, b7, ..] = self.bytes {
58 let bits = u64::from_le_bytes([b0, b1, b2, b3, b4, b5, b6, b7]);
59 self.buf |= bits << self.remaining_buf_bits;
60 let read_bytes = (63 - self.remaining_buf_bits) >> 3;
61 self.remaining_buf_bits |= 56;
62 self.bytes = unsafe {
64 std::slice::from_raw_parts(
65 self.bytes.as_ptr().add(read_bytes),
66 self.bytes.len() - read_bytes,
67 )
68 };
69 } else {
70 self.refill_slow()
71 }
72 }
73
74 #[inline(never)]
75 fn refill_slow(&mut self) {
76 while self.remaining_buf_bits < 56 {
77 let Some((&b, next)) = self.bytes.split_first() else {
78 return;
79 };
80
81 self.buf |= (b as u64) << self.remaining_buf_bits;
82 self.remaining_buf_bits += 8;
83 self.bytes = next;
84 }
85 }
86}
87
88impl Bitstream<'_> {
89 #[inline]
93 pub fn peek_bits(&mut self, n: usize) -> u32 {
94 debug_assert!(n <= 32);
95 self.refill();
96 (self.buf & ((1u64 << n) - 1)) as u32
97 }
98
99 #[inline]
103 pub fn peek_bits_const<const N: usize>(&mut self) -> u32 {
104 debug_assert!(N <= 32);
105 self.refill();
106 (self.buf & ((1u64 << N) - 1)) as u32
107 }
108
109 #[inline]
113 pub fn peek_bits_prefilled(&mut self, n: usize) -> u32 {
114 debug_assert!(n <= 32);
115 (self.buf & ((1u64 << n) - 1)) as u32
116 }
117
118 #[inline]
122 pub fn peek_bits_prefilled_const<const N: usize>(&mut self) -> u32 {
123 debug_assert!(N <= 32);
124 (self.buf & ((1u64 << N) - 1)) as u32
125 }
126
127 #[inline]
133 pub fn consume_bits(&mut self, n: usize) -> BitstreamResult<()> {
134 self.remaining_buf_bits = self
135 .remaining_buf_bits
136 .checked_sub(n)
137 .ok_or(Error::Io(std::io::ErrorKind::UnexpectedEof.into()))?;
138 self.num_read_bits += n;
139 self.buf >>= n;
140 Ok(())
141 }
142
143 #[inline]
149 pub fn consume_bits_const<const N: usize>(&mut self) -> BitstreamResult<()> {
150 self.remaining_buf_bits = self
151 .remaining_buf_bits
152 .checked_sub(N)
153 .ok_or(Error::Io(std::io::ErrorKind::UnexpectedEof.into()))?;
154 self.num_read_bits += N;
155 self.buf >>= N;
156 Ok(())
157 }
158
159 #[inline]
161 pub fn read_bits(&mut self, n: usize) -> BitstreamResult<u32> {
162 let ret = self.peek_bits(n);
163 self.consume_bits(n)?;
164 Ok(ret)
165 }
166
167 #[inline(never)]
168 pub fn skip_bits(&mut self, mut n: usize) -> BitstreamResult<()> {
169 if let Some(next_remaining_bits) = self.remaining_buf_bits.checked_sub(n) {
170 self.num_read_bits += n;
171 self.remaining_buf_bits = next_remaining_bits;
172 self.buf >>= n;
173 return Ok(());
174 }
175
176 n -= self.remaining_buf_bits;
177 self.num_read_bits += self.remaining_buf_bits;
178 self.buf = 0;
179 self.remaining_buf_bits = 0;
180 if n > self.bytes.len() * 8 {
181 self.num_read_bits += self.bytes.len() * 8;
182 return Err(Error::Io(std::io::ErrorKind::UnexpectedEof.into()));
183 }
184
185 self.num_read_bits += n;
186 self.bytes = &self.bytes[n / 8..];
187 n %= 8;
188 self.refill();
189 self.remaining_buf_bits = self
190 .remaining_buf_bits
191 .checked_sub(n)
192 .ok_or(Error::Io(std::io::ErrorKind::UnexpectedEof.into()))?;
193 self.buf >>= n;
194 Ok(())
195 }
196
197 pub fn zero_pad_to_byte(&mut self) -> BitstreamResult<()> {
199 let byte_boundary = self.num_read_bits.div_ceil(8) * 8;
200 let n = byte_boundary - self.num_read_bits;
201 if self.read_bits(n)? != 0 {
202 Err(Error::NonZeroPadding)
203 } else {
204 Ok(())
205 }
206 }
207}
208
209impl Bitstream<'_> {
210 #[inline]
223 pub fn read_u32(
224 &mut self,
225 d0: impl Into<U32Specifier>,
226 d1: impl Into<U32Specifier>,
227 d2: impl Into<U32Specifier>,
228 d3: impl Into<U32Specifier>,
229 ) -> BitstreamResult<u32> {
230 let d = match self.read_bits(2)? {
231 0 => d0.into(),
232 1 => d1.into(),
233 2 => d2.into(),
234 3 => d3.into(),
235 _ => unreachable!(),
236 };
237 match d {
238 U32Specifier::Constant(x) => Ok(x),
239 U32Specifier::BitsOffset(offset, n) => {
240 self.read_bits(n).map(|x| x.wrapping_add(offset))
241 }
242 }
243 }
244
245 pub fn read_u64(&mut self) -> BitstreamResult<u64> {
247 let selector = self.read_bits(2)?;
248 Ok(match selector {
249 0 => 0u64,
250 1 => self.read_bits(4)? as u64 + 1,
251 2 => self.read_bits(8)? as u64 + 17,
252 3 => {
253 let mut value = self.read_bits(12)? as u64;
254 let mut shift = 12u32;
255 while self.read_bits(1)? == 1 {
256 if shift == 60 {
257 value |= (self.read_bits(4)? as u64) << shift;
258 break;
259 }
260 value |= (self.read_bits(8)? as u64) << shift;
261 shift += 8;
262 }
263 value
264 }
265 _ => unreachable!(),
266 })
267 }
268
269 #[inline]
271 pub fn read_bool(&mut self) -> BitstreamResult<bool> {
272 self.read_bits(1).map(|x| x != 0)
273 }
274
275 pub fn read_f16_as_f32(&mut self) -> BitstreamResult<f32> {
280 let v = self.read_bits(16)?;
281 let neg_bit = (v & 0x8000) << 16;
282
283 if v & 0x7fff == 0 {
284 return Ok(f32::from_bits(neg_bit));
286 }
287 let mantissa = v & 0x3ff; let exponent = (v >> 10) & 0x1f; if exponent == 0x1f {
290 Err(Error::InvalidFloat)
292 } else if exponent == 0 {
293 let val = (1.0 / 16384.0) * (mantissa as f32 / 1024.0);
295 Ok(if neg_bit != 0 { -val } else { val })
296 } else {
297 let mantissa = mantissa << 13; let exponent = exponent + 112;
300 let bitpattern = mantissa | (exponent << 23) | neg_bit;
301 Ok(f32::from_bits(bitpattern))
302 }
303 }
304
305 pub fn read_enum<E: TryFrom<u32>>(&mut self) -> BitstreamResult<E> {
307 let v = self.read_u32(0, 1, 2 + U(4), 18 + U(6))?;
308 E::try_from(v).map_err(|_| Error::InvalidEnum {
309 name: std::any::type_name::<E>(),
310 value: v,
311 })
312 }
313}
314
/// One of the four distributions selectable in [`Bitstream::read_u32`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum U32Specifier {
    /// Yields the given value without reading further bits.
    Constant(u32),
    /// `BitsOffset(offset, n)`: reads `n` bits and adds `offset` (wrapping).
    BitsOffset(u32, usize),
}

/// Bit count for a [`U32Specifier::BitsOffset`]; `offset + U(n)` builds one with an offset.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct U(pub usize);

324impl From<u32> for U32Specifier {
325 fn from(value: u32) -> Self {
326 Self::Constant(value)
327 }
328}
329
330impl From<U> for U32Specifier {
331 fn from(value: U) -> Self {
332 Self::BitsOffset(0, value.0)
333 }
334}
335
336impl std::ops::Add<U> for u32 {
337 type Output = U32Specifier;
338
339 fn add(self, rhs: U) -> Self::Output {
340 U32Specifier::BitsOffset(self, rhs.0)
341 }
342}