Skip to main content

oximedia_codec/jpegxl/
bitreader.rs

1//! Bit-level reader and writer for JPEG-XL bitstream processing.
2//!
3//! JPEG-XL uses variable-length bit fields extensively. This module provides
4//! efficient bit-granularity access to byte buffers in both reading and writing
5//! directions.
6
7use crate::error::{CodecError, CodecResult};
8
9/// Bit-level reader for parsing JPEG-XL bitstreams.
10///
11/// Reads bits from a byte buffer in LSB-first order, which is the convention
12/// used by JPEG-XL codestreams (and ANS entropy coding).
13pub struct BitReader<'a> {
14    data: &'a [u8],
15    /// Current byte position in the data buffer.
16    byte_pos: usize,
17    /// Current bit position within the current byte (0-7, 0 = LSB).
18    bit_pos: u8,
19}
20
21impl<'a> BitReader<'a> {
22    /// Create a new bit reader over a byte slice.
23    pub fn new(data: &'a [u8]) -> Self {
24        Self {
25            data,
26            byte_pos: 0,
27            bit_pos: 0,
28        }
29    }
30
31    /// Read up to 32 bits from the stream (LSB first).
32    ///
33    /// # Errors
34    ///
35    /// Returns `CodecError::InvalidBitstream` if not enough bits remain.
36    pub fn read_bits(&mut self, n: u8) -> CodecResult<u32> {
37        if n == 0 {
38            return Ok(0);
39        }
40        if n > 32 {
41            return Err(CodecError::InvalidBitstream(
42                "Cannot read more than 32 bits at once".into(),
43            ));
44        }
45        if self.remaining_bits() < n as usize {
46            return Err(CodecError::InvalidBitstream(
47                "Not enough bits remaining in stream".into(),
48            ));
49        }
50
51        let mut result: u32 = 0;
52        let mut bits_read: u8 = 0;
53
54        while bits_read < n {
55            let bits_available_in_byte = 8 - self.bit_pos;
56            let bits_needed = n - bits_read;
57            let bits_to_read = bits_available_in_byte.min(bits_needed);
58
59            let byte_val = self.data[self.byte_pos] as u32;
60            let mask = (1u32 << bits_to_read) - 1;
61            let extracted = (byte_val >> self.bit_pos) & mask;
62
63            result |= extracted << bits_read;
64            bits_read += bits_to_read;
65
66            self.bit_pos += bits_to_read;
67            if self.bit_pos >= 8 {
68                self.bit_pos = 0;
69                self.byte_pos += 1;
70            }
71        }
72
73        Ok(result)
74    }
75
76    /// Read a single boolean bit.
77    pub fn read_bool(&mut self) -> CodecResult<bool> {
78        Ok(self.read_bits(1)? != 0)
79    }
80
81    /// Read up to 8 bits as a u8.
82    pub fn read_u8(&mut self, n: u8) -> CodecResult<u8> {
83        if n > 8 {
84            return Err(CodecError::InvalidBitstream(
85                "Cannot read more than 8 bits into u8".into(),
86            ));
87        }
88        self.read_bits(n).map(|v| v as u8)
89    }
90
91    /// Read up to 16 bits as a u16.
92    pub fn read_u16(&mut self, n: u8) -> CodecResult<u16> {
93        if n > 16 {
94            return Err(CodecError::InvalidBitstream(
95                "Cannot read more than 16 bits into u16".into(),
96            ));
97        }
98        self.read_bits(n).map(|v| v as u16)
99    }
100
101    /// Read up to 32 bits as a u32.
102    pub fn read_u32(&mut self, n: u8) -> CodecResult<u32> {
103        self.read_bits(n)
104    }
105
106    /// Read a u64 value using JPEG-XL's U64 encoding.
107    ///
108    /// The JXL U64 encoding uses a selector to determine how many bits follow:
109    /// - 0: value = 0
110    /// - 1: value = 1 + read(4)
111    /// - 2: value = 17 + read(8)
112    /// - 3: value = read(12) + (read variable chunks until done)
113    pub fn read_u64(&mut self) -> CodecResult<u64> {
114        let selector = self.read_bits(2)?;
115        match selector {
116            0 => Ok(0),
117            1 => {
118                let extra = self.read_bits(4)? as u64;
119                Ok(1 + extra)
120            }
121            2 => {
122                let extra = self.read_bits(8)? as u64;
123                Ok(17 + extra)
124            }
125            3 => {
126                // Variable-length: read 12 bits, then optionally more in 8-bit chunks
127                let mut value = self.read_bits(12)? as u64;
128                let mut shift = 12u32;
129                while shift < 60 {
130                    let more = self.read_bool()?;
131                    if more {
132                        let chunk = self.read_bits(8)? as u64;
133                        value |= chunk << shift;
134                        shift += 8;
135                    } else {
136                        break;
137                    }
138                }
139                // Final chunk if we reached 60 bits
140                if shift >= 60 {
141                    let chunk = self.read_bits(4)? as u64;
142                    value |= chunk << shift;
143                }
144                Ok(273 + value)
145            }
146            _ => Err(CodecError::InvalidBitstream("Invalid U64 selector".into())),
147        }
148    }
149
150    /// Number of bits remaining in the stream.
151    pub fn remaining_bits(&self) -> usize {
152        if self.byte_pos >= self.data.len() {
153            return 0;
154        }
155        (self.data.len() - self.byte_pos) * 8 - self.bit_pos as usize
156    }
157
158    /// Align the reader to the next byte boundary.
159    ///
160    /// If already at a byte boundary, this is a no-op.
161    pub fn align_to_byte(&mut self) {
162        if self.bit_pos != 0 {
163            self.bit_pos = 0;
164            self.byte_pos += 1;
165        }
166    }
167
168    /// Get current byte position.
169    pub fn byte_position(&self) -> usize {
170        self.byte_pos
171    }
172
173    /// Check if the reader has been exhausted.
174    pub fn is_empty(&self) -> bool {
175        self.remaining_bits() == 0
176    }
177}
178
179/// Bit-level writer for constructing JPEG-XL bitstreams.
180///
181/// Writes bits in LSB-first order, accumulating a byte buffer.
182pub struct BitWriter {
183    data: Vec<u8>,
184    current_byte: u8,
185    bit_pos: u8,
186}
187
188impl BitWriter {
189    /// Create a new empty bit writer.
190    pub fn new() -> Self {
191        Self {
192            data: Vec::new(),
193            current_byte: 0,
194            bit_pos: 0,
195        }
196    }
197
198    /// Create a new bit writer with pre-allocated capacity (in bytes).
199    pub fn with_capacity(bytes: usize) -> Self {
200        Self {
201            data: Vec::with_capacity(bytes),
202            current_byte: 0,
203            bit_pos: 0,
204        }
205    }
206
207    /// Write up to 32 bits (LSB first).
208    pub fn write_bits(&mut self, value: u32, n: u8) {
209        if n == 0 {
210            return;
211        }
212        let mut remaining = n;
213        let mut val = value;
214        let mut written: u8 = 0;
215
216        while written < n {
217            let space_in_byte = 8 - self.bit_pos;
218            let bits_to_write = space_in_byte.min(remaining);
219            let mask = (1u32 << bits_to_write) - 1;
220            let bits = (val & mask) as u8;
221
222            self.current_byte |= bits << self.bit_pos;
223            self.bit_pos += bits_to_write;
224            val >>= bits_to_write;
225            written += bits_to_write;
226            remaining -= bits_to_write;
227
228            if self.bit_pos >= 8 {
229                self.data.push(self.current_byte);
230                self.current_byte = 0;
231                self.bit_pos = 0;
232            }
233        }
234    }
235
236    /// Write a single boolean bit.
237    pub fn write_bool(&mut self, v: bool) {
238        self.write_bits(u32::from(v), 1);
239    }
240
241    /// Write a u64 value using JPEG-XL's U64 encoding.
242    pub fn write_u64(&mut self, value: u64) {
243        if value == 0 {
244            self.write_bits(0, 2); // selector 0
245        } else if value <= 16 {
246            self.write_bits(1, 2); // selector 1
247            self.write_bits((value - 1) as u32, 4);
248        } else if value <= 272 {
249            self.write_bits(2, 2); // selector 2
250            self.write_bits((value - 17) as u32, 8);
251        } else {
252            self.write_bits(3, 2); // selector 3
253            let mut remaining = value - 273;
254            // Write first 12 bits
255            self.write_bits((remaining & 0xFFF) as u32, 12);
256            remaining >>= 12;
257            let mut shift = 12u32;
258            while shift < 60 && remaining > 0 {
259                self.write_bool(true); // more chunks
260                self.write_bits((remaining & 0xFF) as u32, 8);
261                remaining >>= 8;
262                shift += 8;
263            }
264            if shift < 60 {
265                self.write_bool(false); // no more chunks
266            } else if shift >= 60 {
267                // Final 4 bits
268                self.write_bits((remaining & 0xF) as u32, 4);
269            }
270        }
271    }
272
273    /// Align to the next byte boundary by writing zero bits.
274    pub fn align_to_byte(&mut self) {
275        if self.bit_pos != 0 {
276            self.data.push(self.current_byte);
277            self.current_byte = 0;
278            self.bit_pos = 0;
279        }
280    }
281
282    /// Consume the writer and return the accumulated byte buffer.
283    ///
284    /// Any partial byte is flushed with zero-padding.
285    pub fn finish(mut self) -> Vec<u8> {
286        self.align_to_byte();
287        self.data
288    }
289
290    /// Current number of bytes written (excluding partial byte).
291    pub fn bytes_written(&self) -> usize {
292        self.data.len()
293    }
294
295    /// Current total bits written.
296    pub fn bits_written(&self) -> usize {
297        self.data.len() * 8 + self.bit_pos as usize
298    }
299}
300
301impl Default for BitWriter {
302    fn default() -> Self {
303        Self::new()
304    }
305}
306
#[cfg(test)]
mod tests {
    use super::*;

    /// Write each `(value, width)` field in order and return the padded buffer.
    fn encode(fields: &[(u32, u8)]) -> Vec<u8> {
        let mut writer = BitWriter::new();
        for &(value, width) in fields {
            writer.write_bits(value, width);
        }
        writer.finish()
    }

    #[test]
    #[ignore]
    fn test_bitreader_basic() {
        let bytes = [0b1010_0110u8, 0b1100_0011];
        let mut r = BitReader::new(&bytes);

        // LSB first: low nibble of byte 0, its high nibble, then all of byte 1.
        for (width, want) in [(4u8, 0b0110u32), (4, 0b1010), (8, 0b1100_0011)] {
            assert_eq!(r.read_bits(width).expect("ok"), want);
        }
    }

    #[test]
    #[ignore]
    fn test_bitreader_cross_byte() {
        let bytes = [0xFF, 0x00];
        let mut r = BitReader::new(&bytes);

        assert_eq!(r.read_bits(4).expect("ok"), 0xF);
        // Upper nibble of byte 0 lands in the low bits, byte 1 supplies the zeros.
        assert_eq!(r.read_bits(8).expect("ok"), 0x0F);
    }

    #[test]
    #[ignore]
    fn test_bitreader_bool() {
        let bytes = [0b0000_0101];
        let mut r = BitReader::new(&bytes);

        for want in [true, false, true] {
            assert_eq!(r.read_bool().expect("ok"), want);
        }
    }

    #[test]
    #[ignore]
    fn test_bitreader_eof() {
        let bytes = [0xFF];
        let mut r = BitReader::new(&bytes);
        r.read_bits(8).expect("ok");
        assert!(r.read_bits(1).is_err());
    }

    #[test]
    #[ignore]
    fn test_bitreader_remaining() {
        let bytes = [0xFF, 0xFF];
        let mut r = BitReader::new(&bytes);
        assert_eq!(r.remaining_bits(), 16);
        r.read_bits(3).expect("ok");
        assert_eq!(r.remaining_bits(), 13);
    }

    #[test]
    #[ignore]
    fn test_bitreader_align() {
        let bytes = [0xFF, 0xAA];
        let mut r = BitReader::new(&bytes);
        r.read_bits(3).expect("ok");
        r.align_to_byte();
        // The rest of byte 0 is skipped; the next read starts at byte 1.
        assert_eq!(r.read_bits(8).expect("ok"), 0xAA);
    }

    #[test]
    #[ignore]
    fn test_bitwriter_basic() {
        assert_eq!(encode(&[(0b0110, 4), (0b1010, 4)]), vec![0b1010_0110]);
    }

    #[test]
    #[ignore]
    fn test_bitwriter_cross_byte() {
        assert_eq!(encode(&[(0xF, 4), (0x0F, 8)]), vec![0xFF, 0x00]);
    }

    #[test]
    #[ignore]
    fn test_bitwriter_bool() {
        let mut writer = BitWriter::new();
        for bit in [true, false, true, false, false, false, false, false] {
            writer.write_bool(bit);
        }
        assert_eq!(writer.finish(), vec![0b0000_0101]);
    }

    #[test]
    #[ignore]
    fn test_roundtrip_bits() {
        let fields = [(42u32, 7u8), (1023, 10), (0, 3), (255, 8)];
        let bytes = encode(&fields);

        let mut r = BitReader::new(&bytes);
        for (value, width) in fields {
            assert_eq!(r.read_bits(width).expect("ok"), value);
        }
    }

    #[test]
    #[ignore]
    fn test_roundtrip_u64() {
        let samples = [0u64, 1, 5, 16, 17, 100, 272, 273, 1000, 65535, 1_000_000];
        for value in samples {
            let bytes = {
                let mut writer = BitWriter::new();
                writer.write_u64(value);
                writer.finish()
            };
            let decoded = BitReader::new(&bytes).read_u64().expect("ok");
            assert_eq!(decoded, value, "U64 roundtrip failed for {value}");
        }
    }

    #[test]
    #[ignore]
    fn test_bitwriter_align() {
        let mut writer = BitWriter::new();
        writer.write_bits(0b101, 3);
        writer.align_to_byte();
        writer.write_bits(0xAA, 8);
        let bytes = writer.finish();

        assert_eq!(bytes.len(), 2);
        assert_eq!(bytes[0], 0b0000_0101);
        assert_eq!(bytes[1], 0xAA);
    }
}