jxl_bitstream/
bitstream.rs

1use crate::{BitstreamResult, Error};
2
/// Bit reader over a borrowed byte slice.
///
/// Bits are delivered least-significant bit first, starting from the first byte of the slice.
/// Implementation is mostly from [jxl-rs].
///
/// [jxl-rs]: https://github.com/libjxl/jxl-rs
#[derive(Clone)]
pub struct Bitstream<'buf> {
    // Input bytes whose bits are not yet counted in `remaining_buf_bits`; the bit right after
    // the buffered ones is bit 0 of `bytes[0]`.
    bytes: &'buf [u8],
    // Bit buffer; the next bit to be read is bit 0.
    buf: u64,
    // Total number of bits read or skipped since construction.
    num_read_bits: usize,
    // Number of valid, not yet consumed bits held in `buf`.
    remaining_buf_bits: usize,
}
15
16impl std::fmt::Debug for Bitstream<'_> {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        f.debug_struct("Bitstream")
19            .field(
20                "bytes",
21                &format_args!(
22                    "({} byte{} left)",
23                    self.bytes.len(),
24                    if self.bytes.len() == 1 { "" } else { "s" },
25                ),
26            )
27            .field("buf", &format_args!("0x{:016x}", self.buf))
28            .field("num_read_bits", &self.num_read_bits)
29            .field("remaining_buf_bits", &self.remaining_buf_bits)
30            .finish()
31    }
32}
33
34impl<'buf> Bitstream<'buf> {
35    /// Create a new bitstream reader.
36    #[inline]
37    pub fn new(bytes: &'buf [u8]) -> Self {
38        Self {
39            bytes,
40            buf: 0,
41            num_read_bits: 0,
42            remaining_buf_bits: 0,
43        }
44    }
45
46    /// Returns the number of bits that are read or skipped.
47    #[inline]
48    pub fn num_read_bits(&self) -> usize {
49        self.num_read_bits
50    }
51}
52
53impl Bitstream<'_> {
54    /// Fills bit buffer from byte buffer.
55    #[inline]
56    fn refill(&mut self) {
57        if let &[b0, b1, b2, b3, b4, b5, b6, b7, ..] = self.bytes {
58            let bits = u64::from_le_bytes([b0, b1, b2, b3, b4, b5, b6, b7]);
59            self.buf |= bits << self.remaining_buf_bits;
60            let read_bytes = (63 - self.remaining_buf_bits) >> 3;
61            self.remaining_buf_bits |= 56;
62            // SAFETY: read_bytes < 8, self.bytes.len() >= 8 (from the pattern).
63            self.bytes = unsafe {
64                std::slice::from_raw_parts(
65                    self.bytes.as_ptr().add(read_bytes),
66                    self.bytes.len() - read_bytes,
67                )
68            };
69        } else {
70            self.refill_slow()
71        }
72    }
73
74    #[inline(never)]
75    fn refill_slow(&mut self) {
76        while self.remaining_buf_bits < 56 {
77            let Some((&b, next)) = self.bytes.split_first() else {
78                return;
79            };
80
81            self.buf |= (b as u64) << self.remaining_buf_bits;
82            self.remaining_buf_bits += 8;
83            self.bytes = next;
84        }
85    }
86}
87
88impl Bitstream<'_> {
89    /// Peeks bits from bitstream, without consuming them.
90    ///
91    /// This method refills the bit buffer.
92    #[inline]
93    pub fn peek_bits(&mut self, n: usize) -> u32 {
94        debug_assert!(n <= 32);
95        self.refill();
96        (self.buf & ((1u64 << n) - 1)) as u32
97    }
98
99    /// Peeks bits from bitstream, without consuming them.
100    ///
101    /// This method refills the bit buffer.
102    #[inline]
103    pub fn peek_bits_const<const N: usize>(&mut self) -> u32 {
104        debug_assert!(N <= 32);
105        self.refill();
106        (self.buf & ((1u64 << N) - 1)) as u32
107    }
108
109    /// Peeks bits from already filled bitstream, without consuming them.
110    ///
111    /// This method *does not* refill the bit buffer.
112    #[inline]
113    pub fn peek_bits_prefilled(&mut self, n: usize) -> u32 {
114        debug_assert!(n <= 32);
115        (self.buf & ((1u64 << n) - 1)) as u32
116    }
117
118    /// Peeks bits from already filled bitstream, without consuming them.
119    ///
120    /// This method *does not* refill the bit buffer.
121    #[inline]
122    pub fn peek_bits_prefilled_const<const N: usize>(&mut self) -> u32 {
123        debug_assert!(N <= 32);
124        (self.buf & ((1u64 << N) - 1)) as u32
125    }
126
127    /// Consumes bits in bit buffer.
128    ///
129    /// # Errors
130    /// This method returns `Err(Io(std::io::ErrorKind::UnexpectedEof))` when there are not enough
131    /// bits in the bit buffer.
132    #[inline]
133    pub fn consume_bits(&mut self, n: usize) -> BitstreamResult<()> {
134        self.remaining_buf_bits = self
135            .remaining_buf_bits
136            .checked_sub(n)
137            .ok_or(Error::Io(std::io::ErrorKind::UnexpectedEof.into()))?;
138        self.num_read_bits += n;
139        self.buf >>= n;
140        Ok(())
141    }
142
143    /// Consumes bits in bit buffer.
144    ///
145    /// # Errors
146    /// This method returns `Err(Io(std::io::ErrorKind::UnexpectedEof))` when there are not enough
147    /// bits in the bit buffer.
148    #[inline]
149    pub fn consume_bits_const<const N: usize>(&mut self) -> BitstreamResult<()> {
150        self.remaining_buf_bits = self
151            .remaining_buf_bits
152            .checked_sub(N)
153            .ok_or(Error::Io(std::io::ErrorKind::UnexpectedEof.into()))?;
154        self.num_read_bits += N;
155        self.buf >>= N;
156        Ok(())
157    }
158
159    /// Read and consume bits from bitstream.
160    #[inline]
161    pub fn read_bits(&mut self, n: usize) -> BitstreamResult<u32> {
162        let ret = self.peek_bits(n);
163        self.consume_bits(n)?;
164        Ok(ret)
165    }
166
167    #[inline(never)]
168    pub fn skip_bits(&mut self, mut n: usize) -> BitstreamResult<()> {
169        if let Some(next_remaining_bits) = self.remaining_buf_bits.checked_sub(n) {
170            self.num_read_bits += n;
171            self.remaining_buf_bits = next_remaining_bits;
172            self.buf >>= n;
173            return Ok(());
174        }
175
176        n -= self.remaining_buf_bits;
177        self.num_read_bits += self.remaining_buf_bits;
178        self.buf = 0;
179        self.remaining_buf_bits = 0;
180        if n > self.bytes.len() * 8 {
181            self.num_read_bits += self.bytes.len() * 8;
182            return Err(Error::Io(std::io::ErrorKind::UnexpectedEof.into()));
183        }
184
185        self.num_read_bits += n;
186        self.bytes = &self.bytes[n / 8..];
187        n %= 8;
188        self.refill();
189        self.remaining_buf_bits = self
190            .remaining_buf_bits
191            .checked_sub(n)
192            .ok_or(Error::Io(std::io::ErrorKind::UnexpectedEof.into()))?;
193        self.buf >>= n;
194        Ok(())
195    }
196
197    /// Performs `ZeroPadToByte` as defined in the JPEG XL specification.
198    pub fn zero_pad_to_byte(&mut self) -> BitstreamResult<()> {
199        let byte_boundary = self.num_read_bits.div_ceil(8) * 8;
200        let n = byte_boundary - self.num_read_bits;
201        if self.read_bits(n)? != 0 {
202            Err(Error::NonZeroPadding)
203        } else {
204            Ok(())
205        }
206    }
207}
208
209impl Bitstream<'_> {
210    /// Reads an `U32` as defined in the JPEG XL specification.
211    ///
212    /// # Example
213    ///
214    /// ```
215    /// use jxl_bitstream::{Bitstream, U};
216    ///
217    /// let buf = [0b110010];
218    /// let mut bitstream = Bitstream::new(&buf);
219    /// let val = bitstream.read_u32(1, U(2), 3 + U(4), 19 + U(8)).expect("failed to read data");
220    /// assert_eq!(val, 15);
221    /// ```
222    #[inline]
223    pub fn read_u32(
224        &mut self,
225        d0: impl Into<U32Specifier>,
226        d1: impl Into<U32Specifier>,
227        d2: impl Into<U32Specifier>,
228        d3: impl Into<U32Specifier>,
229    ) -> BitstreamResult<u32> {
230        let d = match self.read_bits(2)? {
231            0 => d0.into(),
232            1 => d1.into(),
233            2 => d2.into(),
234            3 => d3.into(),
235            _ => unreachable!(),
236        };
237        match d {
238            U32Specifier::Constant(x) => Ok(x),
239            U32Specifier::BitsOffset(offset, n) => {
240                self.read_bits(n).map(|x| x.wrapping_add(offset))
241            }
242        }
243    }
244
245    /// Reads an `U64` as defined in the JPEG XL specification.
246    pub fn read_u64(&mut self) -> BitstreamResult<u64> {
247        let selector = self.read_bits(2)?;
248        Ok(match selector {
249            0 => 0u64,
250            1 => self.read_bits(4)? as u64 + 1,
251            2 => self.read_bits(8)? as u64 + 17,
252            3 => {
253                let mut value = self.read_bits(12)? as u64;
254                let mut shift = 12u32;
255                while self.read_bits(1)? == 1 {
256                    if shift == 60 {
257                        value |= (self.read_bits(4)? as u64) << shift;
258                        break;
259                    }
260                    value |= (self.read_bits(8)? as u64) << shift;
261                    shift += 8;
262                }
263                value
264            }
265            _ => unreachable!(),
266        })
267    }
268
269    /// Reads a `Bool` as defined in the JPEG XL specification.
270    #[inline]
271    pub fn read_bool(&mut self) -> BitstreamResult<bool> {
272        self.read_bits(1).map(|x| x != 0)
273    }
274
275    /// Reads an `F16` as defined in the JPEG XL specification, and convert it to `f32`.
276    ///
277    /// # Errors
278    /// Returns `Error::InvalidFloat` if the value is `NaN` or `Infinity`.
279    pub fn read_f16_as_f32(&mut self) -> BitstreamResult<f32> {
280        let v = self.read_bits(16)?;
281        let neg_bit = (v & 0x8000) << 16;
282
283        if v & 0x7fff == 0 {
284            // Zero
285            return Ok(f32::from_bits(neg_bit));
286        }
287        let mantissa = v & 0x3ff; // 10 bits
288        let exponent = (v >> 10) & 0x1f; // 5 bits
289        if exponent == 0x1f {
290            // NaN, Infinity
291            Err(Error::InvalidFloat)
292        } else if exponent == 0 {
293            // Subnormal
294            let val = (1.0 / 16384.0) * (mantissa as f32 / 1024.0);
295            Ok(if neg_bit != 0 { -val } else { val })
296        } else {
297            // Normal
298            let mantissa = mantissa << 13; // 23 bits
299            let exponent = exponent + 112;
300            let bitpattern = mantissa | (exponent << 23) | neg_bit;
301            Ok(f32::from_bits(bitpattern))
302        }
303    }
304
305    /// Reads an enum as defined in the JPEG XL specification.
306    pub fn read_enum<E: TryFrom<u32>>(&mut self) -> BitstreamResult<E> {
307        let v = self.read_u32(0, 1, 2 + U(4), 18 + U(6))?;
308        E::try_from(v).map_err(|_| Error::InvalidEnum {
309            name: std::any::type_name::<E>(),
310            value: v,
311        })
312    }
313}
314
/// Bit specifier for [`Bitstream::read_u32`].
///
/// Usually built through conversions: a plain `u32` becomes [`U32Specifier::Constant`],
/// `U(n)` becomes "read `n` bits", and `offset + U(n)` becomes "read `n` bits, add `offset`".
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum U32Specifier {
    /// The selector maps to this fixed value; no further bits are read.
    Constant(u32),
    /// Read the given number of bits (second field) and add the offset (first field).
    BitsOffset(u32, usize),
}
320
/// Bit count for use in [`Bitstream::read_u32`].
///
/// `U(n)` means "read `n` bits"; combine it with an offset as `offset + U(n)`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct U(pub usize);
323
324impl From<u32> for U32Specifier {
325    fn from(value: u32) -> Self {
326        Self::Constant(value)
327    }
328}
329
330impl From<U> for U32Specifier {
331    fn from(value: U) -> Self {
332        Self::BitsOffset(0, value.0)
333    }
334}
335
336impl std::ops::Add<U> for u32 {
337    type Output = U32Specifier;
338
339    fn add(self, rhs: U) -> Self::Output {
340        U32Specifier::BitsOffset(self, rhs.0)
341    }
342}