oxiarc_zstd/bitwriter.rs
//! Bitstream writers for Zstandard encoding.
//!
//! Zstandard uses two bitstream directions:
//! - **Forward bitstream** (LSB first): used for FSE table descriptions, literals headers,
//!   and other metadata fields.
//! - **Backward bitstream**: used for FSE sequence encoding. The decoder consumes it
//!   starting from the end of the buffer, so the output bytes are laid out in reverse
//!   order, with a sentinel bit in the final byte marking where the data begins.
//!   `BackwardBitWriter` arranges the bytes so that the bits written first are the
//!   first bits the decoder reads.
9
/// Forward bitstream writer (LSB first).
///
/// Used for writing FSE table descriptions, various headers, and other
/// forward-direction bitstream data in Zstandard frames.
///
/// Bits are packed into bytes starting from the least significant bit.
/// When a byte is full, it is pushed to the output buffer and the
/// partial-byte accumulator is reset.
#[derive(Debug, Clone)]
pub struct ForwardBitWriter {
    /// Completed output bytes.
    output: Vec<u8>,
    /// Byte currently being assembled; only its low `bits_in_current` bits are meaningful.
    current_byte: u8,
    /// Number of valid bits in `current_byte` (always in `0..8` between calls).
    bits_in_current: u8,
}
26
27impl ForwardBitWriter {
28 /// Create a new forward bitstream writer.
29 pub fn new() -> Self {
30 Self {
31 output: Vec::new(),
32 current_byte: 0,
33 bits_in_current: 0,
34 }
35 }
36
37 /// Create a new forward bitstream writer with a capacity hint.
38 pub fn with_capacity(byte_capacity: usize) -> Self {
39 Self {
40 output: Vec::with_capacity(byte_capacity),
41 current_byte: 0,
42 bits_in_current: 0,
43 }
44 }
45
46 /// Write `num_bits` bits from `value` (LSB first, up to 25 bits).
47 ///
48 /// The lowest `num_bits` bits of `value` are written to the stream.
49 /// Bits are packed into bytes starting from the least significant bit.
50 ///
51 /// # Panics
52 ///
53 /// Panics if `num_bits` exceeds 25.
54 pub fn write_bits(&mut self, value: u32, num_bits: u8) {
55 debug_assert!(
56 num_bits <= 25,
57 "ForwardBitWriter supports up to 25 bits per call"
58 );
59
60 if num_bits == 0 {
61 return;
62 }
63
64 // Mask off any extraneous high bits from value.
65 let mask = if num_bits >= 32 {
66 u32::MAX
67 } else {
68 (1u32 << num_bits) - 1
69 };
70 let masked_value = value & mask;
71
72 // Pack bits into current_byte, flushing full bytes as we go.
73 let mut remaining_bits = num_bits;
74 let mut bits_to_write = masked_value;
75
76 while remaining_bits > 0 {
77 let space_in_current = 8 - self.bits_in_current;
78 let take = remaining_bits.min(space_in_current);
79
80 // Extract the lowest `take` bits from bits_to_write.
81 let take_mask = if take >= 32 {
82 u32::MAX
83 } else {
84 (1u32 << take) - 1
85 };
86 let chunk = (bits_to_write & take_mask) as u8;
87
88 // Place them at the correct position in current_byte.
89 self.current_byte |= chunk << self.bits_in_current;
90 self.bits_in_current += take;
91
92 // Advance past the bits we consumed.
93 bits_to_write >>= take;
94 remaining_bits -= take;
95
96 // If the byte is full, flush it.
97 if self.bits_in_current == 8 {
98 self.output.push(self.current_byte);
99 self.current_byte = 0;
100 self.bits_in_current = 0;
101 }
102 }
103 }
104
105 /// Write a single bit (0 or 1).
106 pub fn write_bit(&mut self, bit: bool) {
107 self.write_bits(if bit { 1 } else { 0 }, 1);
108 }
109
110 /// Flush remaining bits, padding with zeros to the next byte boundary.
111 ///
112 /// Consumes the writer and returns the accumulated output bytes.
113 /// If there are any pending bits that do not fill a complete byte,
114 /// they are padded with zeros in the high bits.
115 pub fn finish(mut self) -> Vec<u8> {
116 if self.bits_in_current > 0 {
117 // Pad remaining bits with zeros (already zero from initialization).
118 self.output.push(self.current_byte);
119 }
120 self.output
121 }
122
123 /// Current bit position (total number of bits written so far).
124 pub fn bit_position(&self) -> usize {
125 self.output.len() * 8 + self.bits_in_current as usize
126 }
127
128 /// Current byte length of the output (not counting partial byte).
129 pub fn byte_len(&self) -> usize {
130 self.output.len()
131 }
132
133 /// Whether no bits have been written yet.
134 pub fn is_empty(&self) -> bool {
135 self.output.is_empty() && self.bits_in_current == 0
136 }
137
138 /// Get a reference to the bytes written so far (not including partial byte).
139 pub fn as_bytes(&self) -> &[u8] {
140 &self.output
141 }
142}
143
144impl Default for ForwardBitWriter {
145 fn default() -> Self {
146 Self::new()
147 }
148}
149
/// Backward bitstream writer for FSE sequence encoding.
///
/// Produces a byte array compatible with the `FseBitReader`:
/// - The last byte contains a sentinel (its highest set bit) with the first
///   `n % 8` data bits in the positions below it.
/// - Preceding bytes carry the subsequent data bits, 8 per byte: byte N-2 is
///   read right after the sentinel byte, then N-3, and so on down to byte 0.
///
/// The encoder writes bits in the same order the decoder reads them
/// (first written = first decoded).
///
/// Internally, every written bit is buffered as one `u8` (0 or 1) in write
/// order; packing into bytes, reversing the byte order, and appending the
/// sentinel all happen in `finish()`.
#[derive(Debug, Clone)]
pub struct BackwardBitWriter {
    /// Data bits in write order, one `u8` (0 or 1) per bit.
    data_bits: Vec<u8>,
    /// Total number of data bits written (kept equal to `data_bits.len()`).
    total_bits: usize,
}
171
172impl BackwardBitWriter {
173 /// Create a new backward bitstream writer.
174 pub fn new() -> Self {
175 Self {
176 data_bits: Vec::new(),
177 total_bits: 0,
178 }
179 }
180
181 /// Create a new backward bitstream writer with a capacity hint.
182 pub fn with_capacity(byte_capacity: usize) -> Self {
183 Self {
184 data_bits: Vec::with_capacity(byte_capacity * 8),
185 total_bits: 0,
186 }
187 }
188
189 /// Write `num_bits` bits from `value` into the backward stream.
190 ///
191 /// The lowest `num_bits` bits of `value` are appended. The first call's
192 /// bits will be the first bits the decoder reads.
193 pub fn write_bits(&mut self, value: u64, num_bits: u8) {
194 if num_bits == 0 {
195 return;
196 }
197
198 // Store individual bits (LSB of value first).
199 for i in 0..num_bits {
200 let bit = ((value >> i) & 1) as u8;
201 self.data_bits.push(bit);
202 }
203 self.total_bits += num_bits as usize;
204 }
205
206 /// Write a single bit (0 or 1).
207 pub fn write_bit(&mut self, bit: bool) {
208 self.data_bits.push(if bit { 1 } else { 0 });
209 self.total_bits += 1;
210 }
211
212 /// Finalize the backward bitstream.
213 ///
214 /// Produces a byte array where:
215 /// - The last byte contains the sentinel and the first data bits.
216 /// - Preceding bytes (read from index N-2 down to 0) contain later data bits.
217 ///
218 /// The `FseBitReader` loads the sentinel byte's data bits first (into the
219 /// accumulator's LSB), then loads byte N-2, N-3, ..., 0 into successively
220 /// higher accumulator positions.
221 ///
222 /// Returns the finalized byte vector. If no bits were written, returns `[0x01]`.
223 pub fn finish(self) -> Vec<u8> {
224 if self.data_bits.is_empty() {
225 return vec![0x01];
226 }
227
228 // The FseBitReader reads:
229 // 1. Sentinel byte (last byte): data bits below sentinel loaded first (LSB of accumulator)
230 // 2. Byte at index N-2: loaded into bits above sentinel data
231 // 3. Byte at index N-3: loaded above that
232 // ...
233 // N. Byte at index 0: loaded into highest positions
234 //
235 // So the first data bits go into the sentinel byte, next 8 bits into byte N-2,
236 // next 8 bits into byte N-3, etc.
237 //
238 // Build the output in reverse: start with byte 0, then byte 1, ..., then sentinel.
239
240 let n = self.data_bits.len();
241
242 // Figure out how many bits go into the sentinel byte.
243 // The sentinel byte can hold up to 7 data bits (bits 0-6, sentinel at bit 7 max).
244 // If total bits mod 8 == 0, sentinel gets 0 data bits (sentinel-only byte).
245 // If total bits mod 8 == k (1..7), sentinel gets k data bits.
246 // Actually, we need the total bits to decompose into: sentinel_bits + full_bytes * 8.
247 // sentinel_bits can be 0..7. If 0, we need an extra sentinel-only byte.
248
249 // Pack the data bits into bytes. The reader reads:
250 // sentinel_data (first S data bits, S=0..7), then
251 // byte N-2 (next 8 bits), byte N-3 (next 8), ..., byte 0 (last 8 bits).
252 //
253 // So byte 0 has the LAST 8 data bits, byte 1 has the second-to-last 8, etc.
254
255 let sentinel_data_bits = n % 8;
256 let full_bytes = n / 8;
257
258 // Build from byte 0 (which has the last 8 data bits) to the sentinel byte.
259 let mut output = Vec::with_capacity(full_bytes + 1);
260
261 // Byte 0 has data bits at indices [n - 8, n - 1] (the last 8 data bits).
262 // Byte 1 has data bits at indices [n - 16, n - 9].
263 // ...
264 // Byte k has data bits at indices [n - 8*(k+1), n - 8*k - 1].
265 //
266 // If sentinel_data_bits > 0, the sentinel covers indices [0, sentinel_data_bits-1].
267 // The remaining full_bytes cover indices [sentinel_data_bits, n-1].
268
269 // Build full bytes: byte 0 = last 8, byte 1 = second-to-last 8, etc.
270 for byte_idx in 0..full_bytes {
271 // This byte covers data_bits starting at offset:
272 // sentinel_data_bits + (full_bytes - 1 - byte_idx) * 8
273 let start = sentinel_data_bits + (full_bytes - 1 - byte_idx) * 8;
274 let mut byte_val = 0u8;
275 for bit in 0..8 {
276 if self.data_bits[start + bit] != 0 {
277 byte_val |= 1 << bit;
278 }
279 }
280 output.push(byte_val);
281 }
282
283 // Build sentinel byte: first sentinel_data_bits of data_bits.
284 let mut sentinel_byte = 0u8;
285 for bit in 0..sentinel_data_bits {
286 if self.data_bits[bit] != 0 {
287 sentinel_byte |= 1 << bit;
288 }
289 }
290 sentinel_byte |= 1 << sentinel_data_bits; // Sentinel bit
291 output.push(sentinel_byte);
292
293 output
294 }
295
296 /// Number of data bits written so far (excludes sentinel).
297 pub fn len(&self) -> usize {
298 self.total_bits
299 }
300
301 /// Whether no bits have been written yet.
302 pub fn is_empty(&self) -> bool {
303 self.total_bits == 0
304 }
305}
306
307impl Default for BackwardBitWriter {
308 fn default() -> Self {
309 Self::new()
310 }
311}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// Decode a backward bitstream the way `BackwardBitWriter::finish` lays it
    /// out: the data bits below the sentinel in the last byte come first, then
    /// bytes N-2 down to 0, each read LSB first. Returns one `u8` per bit.
    fn decode_backward(stream: &[u8]) -> Vec<u8> {
        let (&sentinel_byte, rest) = stream
            .split_last()
            .expect("a finished backward stream is never empty");
        assert_ne!(sentinel_byte, 0, "the last byte must carry the sentinel bit");
        let sentinel_pos = 7 - sentinel_byte.leading_zeros() as usize;
        let mut bits: Vec<u8> = (0..sentinel_pos)
            .map(|i| (sentinel_byte >> i) & 1)
            .collect();
        for &byte in rest.iter().rev() {
            bits.extend((0..8).map(|i| (byte >> i) & 1));
        }
        bits
    }

    #[test]
    fn test_forward_empty() {
        let writer = ForwardBitWriter::new();
        assert!(writer.is_empty());
        assert_eq!(writer.bit_position(), 0);
        let output = writer.finish();
        assert!(output.is_empty());
    }

    #[test]
    fn test_forward_single_byte() {
        let mut writer = ForwardBitWriter::new();
        writer.write_bits(0xAB, 8);
        assert_eq!(writer.bit_position(), 8);
        let output = writer.finish();
        assert_eq!(output, vec![0xAB]);
    }

    #[test]
    fn test_forward_partial_byte() {
        let mut writer = ForwardBitWriter::new();
        // Write 3 bits: binary 101 = 5
        writer.write_bits(5, 3);
        assert_eq!(writer.bit_position(), 3);
        let output = writer.finish();
        // Should be padded: 0b00000_101 = 0x05
        assert_eq!(output, vec![0x05]);
    }

    #[test]
    fn test_forward_multi_byte() {
        let mut writer = ForwardBitWriter::new();
        // Write 12 bits: 0xABC & 0xFFF = 0xABC
        // LSB first: low 8 bits = 0xBC, then high 4 bits = 0x0A
        writer.write_bits(0xABC, 12);
        let output = writer.finish();
        assert_eq!(output, vec![0xBC, 0x0A]);
    }

    #[test]
    fn test_forward_cross_byte_boundary() {
        let mut writer = ForwardBitWriter::new();
        writer.write_bits(0x07, 3); // bits: 111
        writer.write_bits(0x1F, 5); // bits: 11111
        // Combined: 11111_111 = 0xFF
        let output = writer.finish();
        assert_eq!(output, vec![0xFF]);
    }

    #[test]
    fn test_forward_unaligned_writes() {
        let mut writer = ForwardBitWriter::new();
        writer.write_bits(0x5, 3); // bits 0..3: 101
        writer.write_bits(0x1FF, 9); // bits 3..12: nine ones, straddling the byte boundary
        // Combined value: 0x5 | (0x1FF << 3) = 0xFFD
        let output = writer.finish();
        assert_eq!(output, vec![0xFD, 0x0F]);
    }

    #[test]
    fn test_forward_multiple_writes() {
        let mut writer = ForwardBitWriter::new();
        writer.write_bits(1, 1); // bit 0: 1
        writer.write_bits(0, 1); // bit 1: 0
        writer.write_bits(1, 1); // bit 2: 1
        writer.write_bits(0, 1); // bit 3: 0
        writer.write_bits(1, 1); // bit 4: 1
        writer.write_bits(0, 1); // bit 5: 0
        writer.write_bits(1, 1); // bit 6: 1
        writer.write_bits(0, 1); // bit 7: 0
        // Binary: 01010101 = 0x55
        let output = writer.finish();
        assert_eq!(output, vec![0x55]);
    }

    #[test]
    fn test_forward_write_bit() {
        let mut writer = ForwardBitWriter::new();
        for _ in 0..8 {
            writer.write_bit(true);
        }
        let output = writer.finish();
        assert_eq!(output, vec![0xFF]);
    }

    #[test]
    fn test_forward_zero_bits() {
        let mut writer = ForwardBitWriter::new();
        writer.write_bits(0xFF, 0); // Should write nothing
        assert!(writer.is_empty());
        let output = writer.finish();
        assert!(output.is_empty());
    }

    #[test]
    fn test_forward_25_bits() {
        let mut writer = ForwardBitWriter::new();
        let val = (1u32 << 25) - 1; // 25 bits all ones
        writer.write_bits(val, 25);
        assert_eq!(writer.bit_position(), 25);
        let output = writer.finish();
        // 25 bits = 3 full bytes (24 bits) + 1 partial byte (1 bit)
        assert_eq!(output.len(), 4);
        assert_eq!(output[0], 0xFF);
        assert_eq!(output[1], 0xFF);
        assert_eq!(output[2], 0xFF);
        assert_eq!(output[3], 0x01); // 1 bit set, padded
    }

    #[test]
    fn test_forward_byte_len_and_as_bytes() {
        let mut writer = ForwardBitWriter::new();
        writer.write_bits(0xABC, 12);
        // Only the completed low byte is visible; the 4 pending bits are not.
        assert!(!writer.is_empty());
        assert_eq!(writer.bit_position(), 12);
        assert_eq!(writer.byte_len(), 1);
        assert_eq!(writer.as_bytes(), &[0xBC]);
    }

    #[test]
    fn test_backward_empty() {
        let writer = BackwardBitWriter::new();
        assert!(writer.is_empty());
        assert_eq!(writer.len(), 0);
        let output = writer.finish();
        // Sentinel-only byte.
        assert_eq!(output, vec![0x01]);
    }

    #[test]
    fn test_backward_single_bit() {
        let mut writer = BackwardBitWriter::new();
        writer.write_bit(true);
        let output = writer.finish();
        // 1 data bit = 1, sentinel_data_bits = 1 mod 8 = 1.
        // sentinel = 1 | (1 << 1) = 0x03
        assert_eq!(output, vec![0x03]);
    }

    #[test]
    fn test_backward_single_byte_data() {
        let mut writer = BackwardBitWriter::new();
        // Write 8 bits of data: 0xAB
        writer.write_bits(0xAB, 8);
        let output = writer.finish();
        // 8 data bits: sentinel gets 0 data bits (8 mod 8 = 0).
        // 1 full byte = 0xAB at index 0, sentinel 0x01 at index 1.
        assert_eq!(output, vec![0xAB, 0x01]);
    }

    #[test]
    fn test_backward_partial_bits() {
        let mut writer = BackwardBitWriter::new();
        // Write 5 bits: 0b10110 = 22
        writer.write_bits(22, 5);
        let output = writer.finish();
        // 5 data bits, 0 full bytes, sentinel gets all 5 bits.
        // sentinel = 22 | (1 << 5) = 0x36
        assert_eq!(output, vec![0x36]);
    }

    #[test]
    fn test_backward_multi_byte() {
        let mut writer = BackwardBitWriter::new();
        writer.write_bits(0xFF, 8);
        writer.write_bits(0xAA, 8);
        let output = writer.finish();
        // 16 data bits: sentinel_data_bits = 16 % 8 = 0, full_bytes = 2.
        // Byte 0 holds the last 8 data bits (0xAA), byte 1 the first 8 (0xFF),
        // followed by the sentinel-only byte 0x01.
        assert_eq!(output, vec![0xAA, 0xFF, 0x01]);
    }

    #[test]
    fn test_backward_sentinel_bits_and_full_bytes() {
        let mut writer = BackwardBitWriter::new();
        writer.write_bits(0x5, 3);
        writer.write_bits(0xAB, 8);
        writer.write_bits(0x3, 2);
        // 13 data bits; as one LSB-first value: 0x5 | (0xAB << 3) | (0x3 << 11) = 0x1D5D.
        // The first 13 % 8 = 5 bits (0b11101) share the sentinel byte: 0x1D | (1 << 5) = 0x3D.
        // The next 8 bits (0x1D5D >> 5 = 0xEA) form byte 0.
        assert_eq!(writer.len(), 13);
        assert_eq!(writer.finish(), vec![0xEA, 0x3D]);
    }

    #[test]
    fn test_backward_round_trip() {
        let writes: [(u64, u8); 7] = [
            (0b1, 1),
            (0x2A, 6),
            (0, 0),
            (0x1_2345, 17),
            (0x7F, 7),
            (u64::MAX, 64),
            (0b10, 3),
        ];
        let mut writer = BackwardBitWriter::new();
        let mut expected = Vec::new();
        for &(value, width) in &writes {
            writer.write_bits(value, width);
            expected.extend((0..width).map(|i| ((value >> i) & 1) as u8));
        }
        assert_eq!(writer.len(), expected.len());
        assert_eq!(decode_backward(&writer.finish()), expected);
    }

    #[test]
    fn test_backward_len() {
        let mut writer = BackwardBitWriter::new();
        writer.write_bits(0, 3);
        assert_eq!(writer.len(), 3);
        writer.write_bits(0, 10);
        assert_eq!(writer.len(), 13);
    }

    #[test]
    fn test_backward_zero_bits() {
        let mut writer = BackwardBitWriter::new();
        writer.write_bits(0xFF, 0); // Should write nothing
        assert!(writer.is_empty());
    }

    #[test]
    fn test_forward_with_capacity() {
        let writer = ForwardBitWriter::with_capacity(128);
        assert!(writer.is_empty());
        let output = writer.finish();
        assert!(output.is_empty());
    }

    #[test]
    fn test_backward_with_capacity() {
        let writer = BackwardBitWriter::with_capacity(128);
        assert!(writer.is_empty());
    }
}