Skip to main content

jxl_encoder/
bit_writer.rs

1// Copyright (c) Imazen LLC and the JPEG XL Project Authors.
2// Algorithms and constants derived from libjxl (BSD-3-Clause).
3// Licensed under AGPL-3.0-or-later. Commercial licenses at https://www.imazen.io/pricing
4
5//! BitWriter for encoding JPEG XL bitstreams.
6//!
7//! Writes bits in little-endian order, least-significant-bit first within
8//! each byte. This is the inverse of the BitReader used in decoding.
9
10use crate::error::{Error, Result};
11
/// Largest `n_bits` accepted by one call to `BitWriter::write`.
///
/// This is seven whole bytes. A value of this width, shifted left by at most
/// 7 bits to line up with a partially filled byte, still fits in a `u64`.
/// It is kept equal to the decoder's `MAX_BITS_PER_CALL` so the two sides
/// stay symmetric.
pub const MAX_BITS_PER_CALL: usize = 7 * 8;
15
/// A growable, bit-granular output buffer for JPEG XL bitstreams.
///
/// Values are packed least-significant-bit first. The low bit of each value
/// goes into the lowest unused bit of the current byte, and bytes are laid
/// out in increasing address order (little-endian), as the JXL bitstream
/// requires.
///
/// # Example
///
/// ```
/// use jxl_encoder::bit_writer::BitWriter;
///
/// let mut writer = BitWriter::new();
/// writer.write(8, 0x12).unwrap();
/// writer.write(4, 0x3).unwrap();
/// writer.write(4, 0x4).unwrap();
/// writer.zero_pad_to_byte();
///
/// let bytes = writer.finish();
/// assert_eq!(bytes, vec![0x12, 0x43]);
/// ```
#[derive(Clone, Debug)]
pub struct BitWriter {
    /// Backing bytes. May extend past the written data. The writing methods
    /// keep those trailing bytes zeroed so that new bits can be OR-ed in.
    storage: Vec<u8>,
    /// Number of bits written so far (not the length of `storage`).
    bits_written: usize,
}
42
43impl Default for BitWriter {
44    fn default() -> Self {
45        Self::new()
46    }
47}
48
49impl BitWriter {
50    /// Creates a new empty BitWriter.
51    pub fn new() -> Self {
52        Self {
53            storage: Vec::new(),
54            bits_written: 0,
55        }
56    }
57
58    /// Creates a new BitWriter with pre-allocated capacity.
59    ///
60    /// # Arguments
61    ///
62    /// * `capacity_bytes` - Initial capacity in bytes.
63    pub fn with_capacity(capacity_bytes: usize) -> Self {
64        Self {
65            storage: Vec::with_capacity(capacity_bytes),
66            bits_written: 0,
67        }
68    }
69
70    /// Returns the total number of bits written.
71    #[inline]
72    pub fn bits_written(&self) -> usize {
73        self.bits_written
74    }
75
76    /// Returns the number of bytes written (rounded up).
77    #[inline]
78    pub fn bytes_written(&self) -> usize {
79        self.bits_written.div_ceil(8)
80    }
81
82    /// Returns true if the writer is aligned to a byte boundary.
83    #[inline]
84    pub fn is_byte_aligned(&self) -> bool {
85        self.bits_written.is_multiple_of(8)
86    }
87
88    /// Returns the number of bits needed to reach the next byte boundary.
89    #[inline]
90    pub fn bits_to_byte_boundary(&self) -> usize {
91        if self.bits_written.is_multiple_of(8) {
92            0
93        } else {
94            8 - (self.bits_written % 8)
95        }
96    }
97
98    /// Ensures the storage has capacity for at least `additional_bits` more bits.
99    fn ensure_capacity(&mut self, additional_bits: usize) -> Result<()> {
100        let total_bits = self.bits_written + additional_bits;
101        let required_bytes = total_bits.div_ceil(8) + 8; // Extra 8 bytes for unaligned writes
102
103        if self.storage.len() < required_bytes {
104            self.storage
105                .try_reserve(required_bytes - self.storage.len())?;
106            self.storage.resize(required_bytes, 0);
107        }
108        Ok(())
109    }
110
111    /// Writes up to 56 bits to the buffer.
112    ///
113    /// Bits are written in little-endian order, least-significant-bit first.
114    /// The value must fit in `n_bits` bits.
115    ///
116    /// # Arguments
117    ///
118    /// * `n_bits` - Number of bits to write (0-56).
119    /// * `bits` - The value to write. Only the lower `n_bits` are used.
120    ///
121    /// # Errors
122    ///
123    /// Returns an error if `n_bits > 56` or if allocation fails.
124    #[inline]
125    pub fn write(&mut self, n_bits: usize, bits: u64) -> Result<()> {
126        if n_bits > MAX_BITS_PER_CALL {
127            return Err(Error::TooManyBitsPerCall(n_bits));
128        }
129
130        if n_bits == 0 {
131            return Ok(());
132        }
133
134        debug_assert!(
135            bits >> n_bits == 0 || n_bits == 64,
136            "bits {bits:#x} has more than {n_bits} bits"
137        );
138
139        self.ensure_capacity(n_bits)?;
140
141        let byte_offset = self.bits_written / 8;
142        let bits_in_first_byte = self.bits_written % 8;
143
144        // Shift the bits to align with the current position
145        let shifted_bits = bits << bits_in_first_byte;
146
147        // Read current value, OR in new bits, write back
148        // This handles the case where we're continuing a partial byte
149        let p = &mut self.storage[byte_offset..];
150
151        // Use little-endian 64-bit write for efficiency
152        // This may write more bytes than strictly necessary, but the extra
153        // bytes are already zero-initialized
154        let mut current = u64::from_le_bytes(p[..8].try_into().unwrap());
155        current |= shifted_bits;
156        p[..8].copy_from_slice(&current.to_le_bytes());
157
158        self.bits_written += n_bits;
159        Ok(())
160    }
161
162    /// Writes zeros to pad to the next byte boundary.
163    ///
164    /// If already byte-aligned, this is a no-op.
165    pub fn zero_pad_to_byte(&mut self) {
166        let remainder = self.bits_to_byte_boundary();
167        if remainder > 0 {
168            // We know this won't fail since remainder <= 7
169            let _ = self.write(remainder, 0);
170        }
171        debug_assert!(self.is_byte_aligned());
172    }
173
174    /// Appends byte-aligned data from a slice.
175    ///
176    /// The writer must be byte-aligned before calling this method.
177    ///
178    /// # Errors
179    ///
180    /// Returns an error if the writer is not byte-aligned.
181    pub fn append_bytes(&mut self, data: &[u8]) -> Result<()> {
182        if !self.is_byte_aligned() {
183            return Err(Error::NotByteAligned(self.bits_written));
184        }
185
186        if data.is_empty() {
187            return Ok(());
188        }
189
190        let byte_offset = self.bits_written / 8;
191        let new_len = byte_offset + data.len() + 8; // Extra padding for future writes
192
193        if self.storage.len() < new_len {
194            self.storage.try_reserve(new_len - self.storage.len())?;
195            self.storage.resize(new_len, 0);
196        }
197
198        self.storage[byte_offset..byte_offset + data.len()].copy_from_slice(data);
199        self.bits_written += data.len() * 8;
200
201        // Ensure trailing zero for next write
202        if byte_offset + data.len() < self.storage.len() {
203            self.storage[byte_offset + data.len()] = 0;
204        }
205
206        Ok(())
207    }
208
209    /// Appends another BitWriter's contents.
210    ///
211    /// Both writers must be byte-aligned.
212    ///
213    /// # Errors
214    ///
215    /// Returns an error if either writer is not byte-aligned.
216    pub fn append_byte_aligned(&mut self, other: &BitWriter) -> Result<()> {
217        if !self.is_byte_aligned() {
218            return Err(Error::NotByteAligned(self.bits_written));
219        }
220        if !other.is_byte_aligned() {
221            return Err(Error::NotByteAligned(other.bits_written));
222        }
223
224        let other_bytes = other.bytes_written();
225        self.append_bytes(&other.storage[..other_bytes])
226    }
227
228    /// Appends another BitWriter's contents, allowing unaligned data.
229    ///
230    /// This is slower than `append_byte_aligned` but works with any alignment.
231    pub fn append_unaligned(&mut self, other: &BitWriter) -> Result<()> {
232        let full_bytes = other.bits_written / 8;
233        let remaining_bits = other.bits_written % 8;
234
235        for &byte in &other.storage[..full_bytes] {
236            self.write(8, byte as u64)?;
237        }
238
239        if remaining_bits > 0 {
240            let mask = (1u64 << remaining_bits) - 1;
241            let last_bits = other.storage[full_bytes] as u64 & mask;
242            self.write(remaining_bits, last_bits)?;
243        }
244
245        Ok(())
246    }
247
248    /// Returns a view of the written bytes.
249    ///
250    /// The writer must be byte-aligned.
251    ///
252    /// # Panics
253    ///
254    /// Panics if the writer is not byte-aligned.
255    pub fn as_bytes(&self) -> &[u8] {
256        assert!(
257            self.is_byte_aligned(),
258            "BitWriter must be byte-aligned to get bytes"
259        );
260        &self.storage[..self.bytes_written()]
261    }
262
263    /// Returns a view of the internal storage for debugging.
264    ///
265    /// Unlike `as_bytes()`, this does not require byte alignment and returns
266    /// the raw storage including any partial bytes. Useful for debugging
267    /// bit-level encoding issues.
268    pub fn peek_bytes(&self) -> &[u8] {
269        let bytes = self.bits_written.div_ceil(8);
270        &self.storage[..bytes.min(self.storage.len())]
271    }
272
273    /// Consumes the writer and returns the written bytes.
274    ///
275    /// The writer must be byte-aligned.
276    ///
277    /// # Panics
278    ///
279    /// Panics if the writer is not byte-aligned.
280    pub fn finish(mut self) -> Vec<u8> {
281        assert!(
282            self.is_byte_aligned(),
283            "BitWriter must be byte-aligned to finish"
284        );
285        self.storage.truncate(self.bytes_written());
286        self.storage
287    }
288
289    /// Consumes the writer and returns the written bytes, padding if necessary.
290    ///
291    /// Unlike `finish`, this will zero-pad to byte alignment if needed.
292    pub fn finish_with_padding(mut self) -> Vec<u8> {
293        self.zero_pad_to_byte();
294        self.storage.truncate(self.bytes_written());
295        self.storage
296    }
297}
298
299// Convenience write methods for common types
300impl BitWriter {
301    /// Writes a single bit (0 or 1).
302    #[inline]
303    pub fn write_bit(&mut self, bit: bool) -> Result<()> {
304        self.write(1, bit as u64)
305    }
306
307    /// Writes an 8-bit unsigned integer.
308    #[inline]
309    pub fn write_u8(&mut self, value: u8) -> Result<()> {
310        self.write(8, value as u64)
311    }
312
313    /// Writes a 16-bit unsigned integer in little-endian order.
314    #[inline]
315    pub fn write_u16(&mut self, value: u16) -> Result<()> {
316        self.write(16, value as u64)
317    }
318
319    /// Writes a 32-bit unsigned integer in little-endian order.
320    #[inline]
321    pub fn write_u32(&mut self, value: u32) -> Result<()> {
322        self.write(32, value as u64)
323    }
324
325    /// Writes a U32 value using the JXL variable-length encoding.
326    ///
327    /// The encoding is selector-based:
328    /// - 0: value is `d0`
329    /// - 1: value is `d1`
330    /// - 2: value is `d2`
331    /// - 3: `u_bits` bits follow, value is `d3 + read_bits`
332    ///
333    /// # Arguments
334    ///
335    /// * `value` - The value to encode.
336    /// * `d0`, `d1`, `d2`, `d3` - Direct values for selectors 0-2 and offset for selector 3.
337    /// * `u_bits` - Number of bits for the variable portion (selector 3).
338    pub fn write_u32_coder(
339        &mut self,
340        value: u32,
341        d0: u32,
342        d1: u32,
343        d2: u32,
344        d3: u32,
345        u_bits: usize,
346    ) -> Result<()> {
347        if value == d0 {
348            self.write(2, 0)?;
349        } else if value == d1 {
350            self.write(2, 1)?;
351        } else if value == d2 {
352            self.write(2, 2)?;
353        } else {
354            debug_assert!(value >= d3, "value {value} < d3 {d3}");
355            debug_assert!(
356                (value - d3) < (1 << u_bits),
357                "value {value} - d3 {d3} doesn't fit in {u_bits} bits"
358            );
359            self.write(2, 3)?;
360            self.write(u_bits, (value - d3) as u64)?;
361        }
362        Ok(())
363    }
364
365    /// Writes an enum value using the jxl-rs default u2S encoding.
366    /// This uses u2S(0, 1, Bits(4)+2, Bits(6)+18):
367    /// - selector 0 → value 0
368    /// - selector 1 → value 1
369    /// - selector 2 → 2 + Bits(4) = values 2-17
370    /// - selector 3 → 18 + Bits(6) = values 18-81
371    pub fn write_enum_default(&mut self, value: u32) -> Result<()> {
372        if value == 0 {
373            self.write(2, 0)?;
374        } else if value == 1 {
375            self.write(2, 1)?;
376        } else if value < 18 {
377            self.write(2, 2)?;
378            self.write(4, (value - 2) as u64)?;
379        } else {
380            debug_assert!(
381                value < 82,
382                "value {value} too large for default enum encoding"
383            );
384            self.write(2, 3)?;
385            self.write(6, (value - 18) as u64)?;
386        }
387        Ok(())
388    }
389
390    /// Writes a U64 value using the JXL variable-length encoding.
391    ///
392    /// The encoding depends on the selector:
393    /// - 0: value is 0
394    /// - 1: 4 bits follow, value is 1 + read_bits (1-16)
395    /// - 2: 8 bits follow, value is 17 + read_bits (17-272)
396    /// - 3: 12 bits follow, then:
397    ///   - If high bit is 0: value is 273 + low 12 bits (273-4368)
398    ///   - If high bit is 1: 32 more bits follow, value is combined
399    pub fn write_u64_coder(&mut self, value: u64) -> Result<()> {
400        if value == 0 {
401            self.write(2, 0)?;
402        } else if value <= 16 {
403            self.write(2, 1)?;
404            self.write(4, value - 1)?;
405        } else if value <= 272 {
406            self.write(2, 2)?;
407            self.write(8, value - 17)?;
408        } else if value <= 4368 {
409            self.write(2, 3)?;
410            self.write(12, value - 273)?;
411        } else {
412            // Large value: selector 3 with shift indicator
413            self.write(2, 3)?;
414            let low = (value - 273) & 0xFFF;
415            let high = (value - 273) >> 12;
416            self.write(12, low | 0x1000)?; // Set high bit to indicate more data
417            self.write(32, high)?;
418        }
419        Ok(())
420    }
421}
422
#[cfg(test)]
mod tests {
    // Unit tests for `BitWriter` bit packing, alignment helpers, and coders.
    use super::*;

    #[test]
    fn test_write_simple() {
        let mut writer = BitWriter::new();
        writer.write(8, 0x12).unwrap();
        writer.write(8, 0x34).unwrap();

        let bytes = writer.finish();
        assert_eq!(bytes, vec![0x12, 0x34]);
    }

    #[test]
    fn test_write_partial_bytes() {
        let mut writer = BitWriter::new();
        writer.write(4, 0x2).unwrap(); // Lower nibble
        writer.write(4, 0x1).unwrap(); // Upper nibble
        // Result: 0x12 (little-endian, LSB first)

        let bytes = writer.finish();
        assert_eq!(bytes, vec![0x12]);
    }

    #[test]
    fn test_write_across_bytes() {
        let mut writer = BitWriter::new();
        writer.write(4, 0x2).unwrap();
        writer.write(8, 0x34).unwrap();
        writer.write(4, 0x1).unwrap();

        let bytes = writer.finish();
        // Bits: 0010 | 0011_0100 | 0001
        // Byte 0: 0010 + lower 4 of 0x34 (0100) = 0100_0010 = 0x42
        // Byte 1: upper 4 of 0x34 (0011) + 0001 = 0001_0011 = 0x13
        assert_eq!(bytes, vec![0x42, 0x13]);
    }

    #[test]
    fn test_zero_pad() {
        let mut writer = BitWriter::new();
        writer.write(5, 0x15).unwrap();
        assert!(!writer.is_byte_aligned());
        assert_eq!(writer.bits_to_byte_boundary(), 3);

        writer.zero_pad_to_byte();
        assert!(writer.is_byte_aligned());

        let bytes = writer.finish();
        assert_eq!(bytes, vec![0x15]); // Lower 5 bits of 0x15, padded with zeros
    }

    #[test]
    fn test_bits_to_byte_boundary_cycles() {
        let mut writer = BitWriter::new();
        assert_eq!(writer.bits_to_byte_boundary(), 0);
        writer.write(1, 1).unwrap();
        assert_eq!(writer.bits_to_byte_boundary(), 7);
        writer.write(7, 0).unwrap();
        assert_eq!(writer.bits_to_byte_boundary(), 0);
        assert_eq!(writer.bytes_written(), 1);
    }

    #[test]
    fn test_peek_bytes_unaligned() {
        let mut writer = BitWriter::new();
        writer.write(4, 0xA).unwrap();
        assert_eq!(writer.peek_bytes(), &[0x0A]);
        // 0xBC occupies bits 4-11: 0xBCA -> bytes [0xCA, 0x0B]
        writer.write(8, 0xBC).unwrap();
        assert_eq!(writer.peek_bytes(), &[0xCA, 0x0B]);
    }

    #[test]
    fn test_append_bytes() {
        let mut writer = BitWriter::new();
        writer.write(8, 0x12).unwrap();
        writer.append_bytes(&[0x34, 0x56]).unwrap();

        let bytes = writer.finish();
        assert_eq!(bytes, vec![0x12, 0x34, 0x56]);
    }

    #[test]
    fn test_append_bytes_unaligned_fails() {
        let mut writer = BitWriter::new();
        writer.write(4, 0x2).unwrap();

        let result = writer.append_bytes(&[0x34]);
        assert!(result.is_err());
    }

    #[test]
    fn test_write_too_many_bits() {
        let mut writer = BitWriter::new();
        let result = writer.write(57, 0);
        assert!(matches!(result, Err(Error::TooManyBitsPerCall(57))));
    }

    #[test]
    fn test_bits_written() {
        let mut writer = BitWriter::new();
        assert_eq!(writer.bits_written(), 0);

        writer.write(5, 0).unwrap();
        assert_eq!(writer.bits_written(), 5);

        writer.write(11, 0).unwrap();
        assert_eq!(writer.bits_written(), 16);
    }

    #[test]
    fn test_append_byte_aligned() {
        let mut writer1 = BitWriter::new();
        writer1.write(8, 0x12).unwrap();

        let mut writer2 = BitWriter::new();
        writer2.write(16, 0x5634).unwrap();

        writer1.append_byte_aligned(&writer2).unwrap();

        let bytes = writer1.finish();
        assert_eq!(bytes, vec![0x12, 0x34, 0x56]);
    }

    #[test]
    fn test_append_unaligned() {
        let mut writer1 = BitWriter::new();
        writer1.write(4, 0x2).unwrap();

        let mut writer2 = BitWriter::new();
        writer2.write(8, 0x34).unwrap();

        writer1.append_unaligned(&writer2).unwrap();
        writer1.zero_pad_to_byte();

        let bytes = writer1.finish();
        // 4 bits: 0010
        // 8 bits: 0011_0100
        // Result: 0010 + 0100 = 0100_0010 = 0x42, then 0011 padded = 0x03
        assert_eq!(bytes, vec![0x42, 0x03]);
    }

    #[test]
    fn test_append_unaligned_partial_tail() {
        let mut writer1 = BitWriter::new();
        writer1.write(3, 0b101).unwrap();

        let mut writer2 = BitWriter::new();
        writer2.write(5, 0b10011).unwrap();

        writer1.append_unaligned(&writer2).unwrap();
        assert_eq!(writer1.bits_written(), 8);
        // 0b10011 << 3 | 0b101 = 0b1001_1101
        assert_eq!(writer1.finish(), vec![0x9D]);
    }

    #[test]
    fn test_append_unaligned_matches_bytewise_writes() {
        // More than 7 whole bytes, plus a partial tail, into an unaligned writer.
        let data: Vec<u8> = (1..=10u8).map(|i| i.wrapping_mul(37)).collect();
        let mut source = BitWriter::new();
        source.append_bytes(&data).unwrap();
        source.write(3, 0b110).unwrap();

        let mut writer = BitWriter::new();
        writer.write(4, 0xA).unwrap();
        writer.append_unaligned(&source).unwrap();

        let mut expected = BitWriter::new();
        expected.write(4, 0xA).unwrap();
        for &byte in &data {
            expected.write(8, u64::from(byte)).unwrap();
        }
        expected.write(3, 0b110).unwrap();

        assert_eq!(writer.bits_written(), expected.bits_written());
        assert_eq!(writer.finish_with_padding(), expected.finish_with_padding());
    }

    #[test]
    fn test_finish_with_padding() {
        let mut writer = BitWriter::new();
        writer.write(5, 0x15).unwrap();

        let bytes = writer.finish_with_padding();
        assert_eq!(bytes, vec![0x15]);
    }

    #[test]
    fn test_write_u16_u32_little_endian() {
        let mut writer = BitWriter::new();
        writer.write_u16(0x1234).unwrap();
        writer.write_u32(0x89AB_CDEF).unwrap();
        assert_eq!(writer.finish(), vec![0x34, 0x12, 0xEF, 0xCD, 0xAB, 0x89]);
    }

    #[test]
    fn test_u32_coder() {
        // Test direct values
        let mut writer = BitWriter::new();
        writer.write_u32_coder(0, 0, 1, 2, 3, 8).unwrap();
        writer.zero_pad_to_byte();
        assert_eq!(writer.as_bytes(), &[0b00]); // selector 0

        let mut writer = BitWriter::new();
        writer.write_u32_coder(1, 0, 1, 2, 3, 8).unwrap();
        writer.zero_pad_to_byte();
        assert_eq!(writer.as_bytes(), &[0b01]); // selector 1

        let mut writer = BitWriter::new();
        writer.write_u32_coder(2, 0, 1, 2, 3, 8).unwrap();
        writer.zero_pad_to_byte();
        assert_eq!(writer.as_bytes(), &[0b10]); // selector 2

        // Test variable encoding
        let mut writer = BitWriter::new();
        writer.write_u32_coder(10, 0, 1, 2, 3, 8).unwrap(); // 10 - 3 = 7
        writer.zero_pad_to_byte();
        // Selector 3 (0b11) fills bits 0-1; 7 (0b0000_0111) fills bits 2-9.
        // Byte 0 = bits 7..0 = 0b0001_1111 = 0x1F.
        // Byte 1 = the two remaining (zero) high bits of 7, padded = 0x00.
        assert_eq!(writer.as_bytes(), &[0x1F, 0x00]);
    }

    #[test]
    fn test_enum_default() {
        let cases: [(u32, &[u8]); 4] = [
            (0, &[0x00]), // selector 0
            (1, &[0x01]), // selector 1
            (5, &[0x0E]), // selector 2 (0b10) + Bits(4) = 3
            (20, &[0x0B]), // selector 3 (0b11) + Bits(6) = 2
        ];
        for (value, bytes) in cases {
            let mut writer = BitWriter::new();
            writer.write_enum_default(value).unwrap();
            assert_eq!(writer.finish_with_padding(), bytes, "value {value}");
        }
    }

    #[test]
    fn test_u64_coder_short_forms() {
        let cases: [(u64, &[u8]); 5] = [
            (0, &[0x00]),         // selector 0
            (5, &[0x11]),         // selector 1 + 4 bits of 4
            (16, &[0x3D]),        // selector 1 + 4 bits of 15
            (17, &[0x02]),        // selector 2 + 8 bits of 0
            (272, &[0xFE, 0x03]), // selector 2 + 8 bits of 255
        ];
        for (value, bytes) in cases {
            let mut writer = BitWriter::new();
            writer.write_u64_coder(value).unwrap();
            assert_eq!(writer.finish_with_padding(), bytes, "value {value}");
        }
    }
}