Skip to main content

oxiarc_zstd/
bitwriter.rs

//! Bitstream writers for Zstandard encoding.
//!
//! Zstandard uses two bitstream directions:
//! - **Forward bitstream** (LSB first): used for FSE table descriptions, literals headers,
//!   and other metadata fields.
//! - **Backward bitstream**: used for FSE sequence encoding. The decoder consumes this
//!   stream starting from its final byte, which carries a sentinel bit marking the start
//!   of data. `BackwardBitWriter` accepts bits in the order the decoder will read them
//!   and lays the output bytes out in reverse accordingly.
9
/// Forward bitstream writer (LSB first).
///
/// Used for writing FSE table descriptions, various headers, and other
/// forward-direction bitstream data in Zstandard frames.
///
/// Bits are packed into bytes starting from the least significant bit.
/// Every completed byte is appended to the output buffer; at most seven
/// bits are ever held back in the partial byte.
#[derive(Debug, Clone, Default)]
pub struct ForwardBitWriter {
    /// Completed output bytes.
    output: Vec<u8>,
    /// Partial byte being assembled. Bits at or above `bits_in_current`
    /// are always zero, so it can be emitted as-is as a zero-padded byte.
    current_byte: u8,
    /// Number of valid bits in `current_byte` (always in `0..8`).
    bits_in_current: u8,
}

impl ForwardBitWriter {
    /// Create a new, empty forward bitstream writer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a new forward bitstream writer whose output buffer is
    /// pre-allocated for `byte_capacity` bytes.
    pub fn with_capacity(byte_capacity: usize) -> Self {
        Self {
            output: Vec::with_capacity(byte_capacity),
            ..Self::default()
        }
    }

    /// Write `num_bits` bits from `value` (LSB first, up to 25 bits).
    ///
    /// The lowest `num_bits` bits of `value` are written to the stream; any
    /// higher bits of `value` are ignored. Bits are packed into bytes
    /// starting from the least significant bit. Writing zero bits is a no-op.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `num_bits` exceeds 25. Release builds do
    /// not check the limit: bit positions 32 and above are written as zeros.
    pub fn write_bits(&mut self, value: u32, num_bits: u8) {
        debug_assert!(
            num_bits <= 25,
            "ForwardBitWriter supports up to 25 bits per call"
        );

        if num_bits == 0 {
            return;
        }

        // Keep only the requested low bits. `checked_shl` yields `None` for
        // shifts of 32 or more, in which case the whole word is kept.
        let mask = 1u32
            .checked_shl(u32::from(num_bits))
            .map_or(u32::MAX, |bit| bit - 1);

        // Stack the new bits on top of the pending partial byte. At most
        // 7 + 32 significant bits are involved, so a u64 cannot overflow.
        let mut acc =
            u64::from(self.current_byte) | (u64::from(value & mask) << self.bits_in_current);
        let pending = u16::from(self.bits_in_current) + u16::from(num_bits);

        // Emit every completed byte, lowest first.
        for _ in 0..pending / 8 {
            self.output.push((acc & 0xFF) as u8);
            acc >>= 8;
        }

        // Fewer than 8 bits remain, so `acc` now fits in a single byte.
        self.current_byte = (acc & 0xFF) as u8;
        self.bits_in_current = (pending % 8) as u8;
    }

    /// Write a single bit (0 or 1).
    pub fn write_bit(&mut self, bit: bool) {
        self.write_bits(u32::from(bit), 1);
    }

    /// Flush remaining bits, padding with zeros to the next byte boundary.
    ///
    /// Consumes the writer and returns the accumulated output bytes.
    /// If there are pending bits that do not fill a complete byte, they are
    /// emitted as one final byte whose unused high bits are zero.
    pub fn finish(mut self) -> Vec<u8> {
        if self.bits_in_current > 0 {
            // The unused high bits of `current_byte` are already zero.
            self.output.push(self.current_byte);
        }
        self.output
    }

    /// Current bit position (total number of bits written so far).
    pub fn bit_position(&self) -> usize {
        self.output.len() * 8 + usize::from(self.bits_in_current)
    }

    /// Current byte length of the output (not counting the partial byte).
    pub fn byte_len(&self) -> usize {
        self.output.len()
    }

    /// Whether no bits have been written yet.
    pub fn is_empty(&self) -> bool {
        self.output.is_empty() && self.bits_in_current == 0
    }

    /// Get a reference to the complete bytes written so far (the partial
    /// byte, if any, is not included).
    pub fn as_bytes(&self) -> &[u8] {
        &self.output
    }
}
149
/// Backward bitstream writer for FSE sequence encoding.
///
/// Produces a byte array compatible with the `FseBitReader`:
/// - The last byte contains a sentinel (highest set bit) and the first data bits.
/// - Preceding bytes contain subsequent data bits, with byte at index N-2
///   being read after the sentinel byte, N-3 after that, etc.
///
/// The encoder writes bits in the same order the decoder reads them
/// (first written = first decoded).
///
/// Internally the data bits are kept packed, eight per byte, in write order.
/// The final layout depends on the total bit count modulo 8, which is only
/// known once writing is done, so the re-alignment and the sentinel are
/// applied in `finish()`.
#[derive(Debug, Clone, Default)]
pub struct BackwardBitWriter {
    /// Data bits packed LSB-first in write order: stream bit `i` is stored in
    /// bit `i % 8` of `packed[i / 8]`. Unused high bits of the last byte are
    /// zero, and `packed.len()` is always `ceil(total_bits / 8)`.
    packed: Vec<u8>,
    /// Total number of data bits written.
    total_bits: usize,
}

impl BackwardBitWriter {
    /// Create a new, empty backward bitstream writer.
    pub fn new() -> Self {
        Self::default()
    }

    /// Create a new backward bitstream writer with room pre-allocated for
    /// `byte_capacity` bytes (that is, `byte_capacity * 8` data bits).
    pub fn with_capacity(byte_capacity: usize) -> Self {
        Self {
            packed: Vec::with_capacity(byte_capacity),
            total_bits: 0,
        }
    }

    /// Write `num_bits` bits from `value` into the backward stream.
    ///
    /// The lowest `num_bits` bits of `value` are appended, least significant
    /// bit first; higher bits of `value` are ignored. The first call's bits
    /// will be the first bits the decoder reads. Writing zero bits is a no-op.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `num_bits` exceeds 64. Release builds do
    /// not check the limit: bit positions 64 and above are written as zeros.
    pub fn write_bits(&mut self, value: u64, num_bits: u8) {
        debug_assert!(
            num_bits <= 64,
            "BackwardBitWriter supports up to 64 bits per call"
        );

        // Drop any bits of `value` above the requested width.
        let mut pending = if num_bits < 64 {
            value & ((1u64 << num_bits) - 1)
        } else {
            value
        };
        let mut remaining = usize::from(num_bits);

        while remaining > 0 {
            let byte_idx = self.total_bits / 8;
            let used = self.total_bits % 8;
            if byte_idx == self.packed.len() {
                // The previous byte is full (or nothing was written yet).
                self.packed.push(0);
            }

            // Fill as much of the current byte as possible (1..=8 bits).
            let take = remaining.min(8 - used);
            let chunk = (pending & ((1u64 << take) - 1)) as u8;
            self.packed[byte_idx] |= chunk << used;

            pending >>= take;
            remaining -= take;
            self.total_bits += take;
        }
    }

    /// Write a single bit (0 or 1).
    pub fn write_bit(&mut self, bit: bool) {
        self.write_bits(u64::from(bit), 1);
    }

    /// Finalize the backward bitstream.
    ///
    /// Produces a byte array where:
    /// - The last byte contains the sentinel and the first data bits.
    /// - Preceding bytes (read from index N-2 down to 0) contain later data bits.
    ///
    /// With `n` data bits written, the first `n % 8` bits go into the sentinel
    /// byte (below the sentinel bit), the next 8 bits into byte N-2, the
    /// following 8 into byte N-3, and so on, so that byte 0 holds the last
    /// 8 bits written.
    ///
    /// Returns the finalized byte vector. If no bits were written, returns `[0x01]`.
    pub fn finish(self) -> Vec<u8> {
        let sentinel_data_bits = self.total_bits % 8;
        let full_bytes = self.total_bits / 8;

        let mut output = Vec::with_capacity(full_bytes + 1);

        // Full byte `group` covers stream bits starting at
        // `sentinel_data_bits + group * 8`, which straddles `packed[group]`
        // and `packed[group + 1]`. The last-written group is emitted first.
        output.extend((0..full_bytes).rev().map(|group| {
            let low = u16::from(self.packed[group]);
            let high = self.packed.get(group + 1).map_or(0, |&b| u16::from(b));
            // Shift the 16-bit window down and keep its low byte.
            ((((high << 8) | low) >> sentinel_data_bits) & 0xFF) as u8
        }));

        // Sentinel byte: the first `sentinel_data_bits` data bits with the
        // sentinel bit directly above them. An empty stream yields 0x01.
        let head = self.packed.first().copied().unwrap_or(0);
        let data_mask = (1u8 << sentinel_data_bits) - 1;
        output.push((head & data_mask) | (1u8 << sentinel_data_bits));

        output
    }

    /// Number of data bits written so far (excludes sentinel).
    pub fn len(&self) -> usize {
        self.total_bits
    }

    /// Whether no bits have been written yet.
    pub fn is_empty(&self) -> bool {
        self.total_bits == 0
    }
}
312
#[cfg(test)]
mod tests {
    //! Unit tests pinning the exact byte layout produced by both writers.

    use super::*;

    // ---- ForwardBitWriter -------------------------------------------------

    #[test]
    fn test_forward_empty() {
        let writer = ForwardBitWriter::new();
        assert!(writer.is_empty());
        assert_eq!(writer.bit_position(), 0);
        let output = writer.finish();
        assert!(output.is_empty());
    }

    #[test]
    fn test_forward_single_byte() {
        let mut writer = ForwardBitWriter::new();
        writer.write_bits(0xAB, 8);
        assert_eq!(writer.bit_position(), 8);
        let output = writer.finish();
        assert_eq!(output, vec![0xAB]);
    }

    #[test]
    fn test_forward_partial_byte() {
        let mut writer = ForwardBitWriter::new();
        // Write 3 bits: binary 101 = 5
        writer.write_bits(5, 3);
        assert_eq!(writer.bit_position(), 3);
        let output = writer.finish();
        // Should be padded: 0b00000_101 = 0x05
        assert_eq!(output, vec![0x05]);
    }

    #[test]
    fn test_forward_multi_byte() {
        let mut writer = ForwardBitWriter::new();
        // Write 12 bits: 0xABC & 0xFFF = 0xABC
        // LSB first: low 8 bits = 0xBC, then high 4 bits = 0x0A
        writer.write_bits(0xABC, 12);
        let output = writer.finish();
        assert_eq!(output, vec![0xBC, 0x0A]);
    }

    #[test]
    fn test_forward_cross_byte_boundary() {
        let mut writer = ForwardBitWriter::new();
        writer.write_bits(0x07, 3); // bits: 111
        writer.write_bits(0x1F, 5); // bits: 11111
        // Combined: 11111_111 = 0xFF
        let output = writer.finish();
        assert_eq!(output, vec![0xFF]);
    }

    #[test]
    fn test_forward_multiple_writes() {
        let mut writer = ForwardBitWriter::new();
        // Alternate 1, 0, 1, 0, ... starting at bit 0.
        for bit_index in 0..8u32 {
            writer.write_bits(u32::from(bit_index % 2 == 0), 1);
        }
        // Binary: 01010101 = 0x55
        let output = writer.finish();
        assert_eq!(output, vec![0x55]);
    }

    #[test]
    fn test_forward_write_bit() {
        let mut writer = ForwardBitWriter::new();
        for _ in 0..8 {
            writer.write_bit(true);
        }
        let output = writer.finish();
        assert_eq!(output, vec![0xFF]);
    }

    #[test]
    fn test_forward_zero_bits() {
        let mut writer = ForwardBitWriter::new();
        writer.write_bits(0xFF, 0); // Should write nothing
        assert!(writer.is_empty());
        let output = writer.finish();
        assert!(output.is_empty());
    }

    #[test]
    fn test_forward_25_bits() {
        let mut writer = ForwardBitWriter::new();
        let val = (1u32 << 25) - 1; // 25 bits all ones
        writer.write_bits(val, 25);
        assert_eq!(writer.bit_position(), 25);
        let output = writer.finish();
        // 25 bits = 3 full bytes (24 bits) + 1 partial byte (1 bit, zero-padded)
        assert_eq!(output, vec![0xFF, 0xFF, 0xFF, 0x01]);
    }

    #[test]
    fn test_forward_masks_high_bits() {
        let mut writer = ForwardBitWriter::new();
        // Only the low 3 bits of the value may reach the stream.
        writer.write_bits(0xFFFF_FFFF, 3);
        assert_eq!(writer.bit_position(), 3);
        assert_eq!(writer.finish(), vec![0x07]);
    }

    #[test]
    fn test_forward_as_bytes_excludes_partial_byte() {
        let mut writer = ForwardBitWriter::new();
        writer.write_bits(0xABC, 12);
        // One complete byte is visible; the 4 pending bits are not.
        assert_eq!(writer.byte_len(), 1);
        assert_eq!(writer.as_bytes(), &[0xBC]);
        assert!(!writer.is_empty());
    }

    #[test]
    fn test_forward_with_capacity() {
        let writer = ForwardBitWriter::with_capacity(128);
        assert!(writer.is_empty());
        let output = writer.finish();
        assert!(output.is_empty());
    }

    // ---- BackwardBitWriter ------------------------------------------------

    #[test]
    fn test_backward_empty() {
        let writer = BackwardBitWriter::new();
        assert!(writer.is_empty());
        assert_eq!(writer.len(), 0);
        let output = writer.finish();
        // Sentinel-only byte.
        assert_eq!(output, vec![0x01]);
    }

    #[test]
    fn test_backward_single_bit() {
        let mut writer = BackwardBitWriter::new();
        writer.write_bit(true);
        let output = writer.finish();
        // 1 data bit = 1, sentinel_data_bits = 1 mod 8 = 1.
        // sentinel = 1 | (1 << 1) = 0x03
        assert_eq!(output, vec![0x03]);
    }

    #[test]
    fn test_backward_single_byte_data() {
        let mut writer = BackwardBitWriter::new();
        // Write 8 bits of data: 0xAB
        writer.write_bits(0xAB, 8);
        let output = writer.finish();
        // 8 data bits: sentinel gets 0 data bits (8 mod 8 = 0).
        // 1 full byte = 0xAB at index 0, sentinel 0x01 at index 1.
        assert_eq!(output, vec![0xAB, 0x01]);
    }

    #[test]
    fn test_backward_partial_bits() {
        let mut writer = BackwardBitWriter::new();
        // Write 5 bits: 0b10110 = 22
        writer.write_bits(22, 5);
        let output = writer.finish();
        // 5 data bits, 0 full bytes, sentinel gets all 5 bits.
        // sentinel = 22 | (1 << 5) = 0x36
        assert_eq!(output, vec![0x36]);
    }

    #[test]
    fn test_backward_multi_byte() {
        let mut writer = BackwardBitWriter::new();
        writer.write_bits(0xFF, 8);
        writer.write_bits(0xAA, 8);
        let output = writer.finish();
        // 16 data bits: 2 full bytes, sentinel gets 0 data bits.
        // Output byte 0 holds the last 8 bits written (0xAA), byte 1 holds
        // the first 8 bits written (0xFF), followed by the sentinel 0x01.
        assert_eq!(output, vec![0xAA, 0xFF, 0x01]);
    }

    #[test]
    fn test_backward_unaligned_multi_byte() {
        let mut writer = BackwardBitWriter::new();
        writer.write_bits(0b101, 3);
        writer.write_bits(0xC3, 8);
        let output = writer.finish();
        // 11 data bits: the first 3 (0b101) share the sentinel byte,
        // the remaining 8 (0xC3) form the single full byte at index 0.
        // sentinel = 0b101 | (1 << 3) = 0x0D
        assert_eq!(output, vec![0xC3, 0x0D]);
    }

    #[test]
    fn test_backward_masks_high_bits() {
        let mut writer = BackwardBitWriter::new();
        // Only the low 3 bits of the value may reach the stream.
        writer.write_bits(0xFF, 3);
        assert_eq!(writer.len(), 3);
        // sentinel = 0b111 | (1 << 3) = 0x0F
        assert_eq!(writer.finish(), vec![0x0F]);
    }

    #[test]
    fn test_backward_write_bit_matches_write_bits() {
        let mut bitwise = BackwardBitWriter::new();
        for i in 0..5 {
            bitwise.write_bit((22u64 >> i) & 1 == 1);
        }
        let mut bulk = BackwardBitWriter::new();
        bulk.write_bits(22, 5);
        assert_eq!(bitwise.len(), bulk.len());
        assert_eq!(bitwise.finish(), bulk.finish());
    }

    #[test]
    fn test_backward_len() {
        let mut writer = BackwardBitWriter::new();
        writer.write_bits(0, 3);
        assert_eq!(writer.len(), 3);
        writer.write_bits(0, 10);
        assert_eq!(writer.len(), 13);
    }

    #[test]
    fn test_backward_zero_bits() {
        let mut writer = BackwardBitWriter::new();
        writer.write_bits(0xFF, 0); // Should write nothing
        assert!(writer.is_empty());
    }

    #[test]
    fn test_backward_with_capacity() {
        let writer = BackwardBitWriter::with_capacity(128);
        assert!(writer.is_empty());
        // The capacity hint must not affect the output.
        assert_eq!(writer.finish(), vec![0x01]);
    }
}