1use crate::error::{Error, Result};
11
/// Largest `n_bits` accepted by a single `BitWriter::write` call: seven whole
/// bytes. Up to 7 bits of the current partial byte may already be pending, and
/// 56 + 7 = 63 bits still fits in the 64-bit window that `write` updates.
pub const MAX_BITS_PER_CALL: usize = 7 * 8;
15
/// Bit-level writer that packs values least-significant bit first into a
/// growable byte buffer (each write is merged into a little-endian `u64`
/// window at the current byte).
#[derive(Debug, Clone)]
pub struct BitWriter {
    /// Output bytes. Grown zero-filled, with 8 bytes of slack past the last
    /// written bit, so bytes beyond the write cursor read as zero.
    storage: Vec<u8>,
    /// Total number of bits written so far; doubles as the write cursor.
    bits_written: usize,
}
42
43impl Default for BitWriter {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl BitWriter {
50 pub fn new() -> Self {
52 Self {
53 storage: Vec::new(),
54 bits_written: 0,
55 }
56 }
57
58 pub fn with_capacity(capacity_bytes: usize) -> Self {
64 Self {
65 storage: Vec::with_capacity(capacity_bytes),
66 bits_written: 0,
67 }
68 }
69
70 #[inline]
72 pub fn bits_written(&self) -> usize {
73 self.bits_written
74 }
75
76 #[inline]
78 pub fn bytes_written(&self) -> usize {
79 self.bits_written.div_ceil(8)
80 }
81
82 #[inline]
84 pub fn is_byte_aligned(&self) -> bool {
85 self.bits_written.is_multiple_of(8)
86 }
87
88 #[inline]
90 pub fn bits_to_byte_boundary(&self) -> usize {
91 if self.bits_written.is_multiple_of(8) {
92 0
93 } else {
94 8 - (self.bits_written % 8)
95 }
96 }
97
98 fn ensure_capacity(&mut self, additional_bits: usize) -> Result<()> {
100 let total_bits = self.bits_written + additional_bits;
101 let required_bytes = total_bits.div_ceil(8) + 8; if self.storage.len() < required_bytes {
104 self.storage
105 .try_reserve(required_bytes - self.storage.len())?;
106 self.storage.resize(required_bytes, 0);
107 }
108 Ok(())
109 }
110
111 #[inline]
125 pub fn write(&mut self, n_bits: usize, bits: u64) -> Result<()> {
126 if n_bits > MAX_BITS_PER_CALL {
127 return Err(Error::TooManyBitsPerCall(n_bits));
128 }
129
130 if n_bits == 0 {
131 return Ok(());
132 }
133
134 debug_assert!(
135 bits >> n_bits == 0 || n_bits == 64,
136 "bits {bits:#x} has more than {n_bits} bits"
137 );
138
139 self.ensure_capacity(n_bits)?;
140
141 let byte_offset = self.bits_written / 8;
142 let bits_in_first_byte = self.bits_written % 8;
143
144 let shifted_bits = bits << bits_in_first_byte;
146
147 let p = &mut self.storage[byte_offset..];
150
151 let mut current = u64::from_le_bytes(p[..8].try_into().unwrap());
155 current |= shifted_bits;
156 p[..8].copy_from_slice(¤t.to_le_bytes());
157
158 self.bits_written += n_bits;
159 Ok(())
160 }
161
162 pub fn zero_pad_to_byte(&mut self) {
166 let remainder = self.bits_to_byte_boundary();
167 if remainder > 0 {
168 let _ = self.write(remainder, 0);
170 }
171 debug_assert!(self.is_byte_aligned());
172 }
173
174 pub fn append_bytes(&mut self, data: &[u8]) -> Result<()> {
182 if !self.is_byte_aligned() {
183 return Err(Error::NotByteAligned(self.bits_written));
184 }
185
186 if data.is_empty() {
187 return Ok(());
188 }
189
190 let byte_offset = self.bits_written / 8;
191 let new_len = byte_offset + data.len() + 8; if self.storage.len() < new_len {
194 self.storage.try_reserve(new_len - self.storage.len())?;
195 self.storage.resize(new_len, 0);
196 }
197
198 self.storage[byte_offset..byte_offset + data.len()].copy_from_slice(data);
199 self.bits_written += data.len() * 8;
200
201 if byte_offset + data.len() < self.storage.len() {
203 self.storage[byte_offset + data.len()] = 0;
204 }
205
206 Ok(())
207 }
208
209 pub fn append_byte_aligned(&mut self, other: &BitWriter) -> Result<()> {
217 if !self.is_byte_aligned() {
218 return Err(Error::NotByteAligned(self.bits_written));
219 }
220 if !other.is_byte_aligned() {
221 return Err(Error::NotByteAligned(other.bits_written));
222 }
223
224 let other_bytes = other.bytes_written();
225 self.append_bytes(&other.storage[..other_bytes])
226 }
227
228 pub fn append_unaligned(&mut self, other: &BitWriter) -> Result<()> {
232 let full_bytes = other.bits_written / 8;
233 let remaining_bits = other.bits_written % 8;
234
235 for &byte in &other.storage[..full_bytes] {
236 self.write(8, byte as u64)?;
237 }
238
239 if remaining_bits > 0 {
240 let mask = (1u64 << remaining_bits) - 1;
241 let last_bits = other.storage[full_bytes] as u64 & mask;
242 self.write(remaining_bits, last_bits)?;
243 }
244
245 Ok(())
246 }
247
248 pub fn as_bytes(&self) -> &[u8] {
256 assert!(
257 self.is_byte_aligned(),
258 "BitWriter must be byte-aligned to get bytes"
259 );
260 &self.storage[..self.bytes_written()]
261 }
262
263 pub fn peek_bytes(&self) -> &[u8] {
269 let bytes = self.bits_written.div_ceil(8);
270 &self.storage[..bytes.min(self.storage.len())]
271 }
272
273 pub fn finish(mut self) -> Vec<u8> {
281 assert!(
282 self.is_byte_aligned(),
283 "BitWriter must be byte-aligned to finish"
284 );
285 self.storage.truncate(self.bytes_written());
286 self.storage
287 }
288
289 pub fn finish_with_padding(mut self) -> Vec<u8> {
293 self.zero_pad_to_byte();
294 self.storage.truncate(self.bytes_written());
295 self.storage
296 }
297}
298
299impl BitWriter {
301 #[inline]
303 pub fn write_bit(&mut self, bit: bool) -> Result<()> {
304 self.write(1, bit as u64)
305 }
306
307 #[inline]
309 pub fn write_u8(&mut self, value: u8) -> Result<()> {
310 self.write(8, value as u64)
311 }
312
313 #[inline]
315 pub fn write_u16(&mut self, value: u16) -> Result<()> {
316 self.write(16, value as u64)
317 }
318
319 #[inline]
321 pub fn write_u32(&mut self, value: u32) -> Result<()> {
322 self.write(32, value as u64)
323 }
324
325 pub fn write_u32_coder(
339 &mut self,
340 value: u32,
341 d0: u32,
342 d1: u32,
343 d2: u32,
344 d3: u32,
345 u_bits: usize,
346 ) -> Result<()> {
347 if value == d0 {
348 self.write(2, 0)?;
349 } else if value == d1 {
350 self.write(2, 1)?;
351 } else if value == d2 {
352 self.write(2, 2)?;
353 } else {
354 debug_assert!(value >= d3, "value {value} < d3 {d3}");
355 debug_assert!(
356 (value - d3) < (1 << u_bits),
357 "value {value} - d3 {d3} doesn't fit in {u_bits} bits"
358 );
359 self.write(2, 3)?;
360 self.write(u_bits, (value - d3) as u64)?;
361 }
362 Ok(())
363 }
364
365 pub fn write_enum_default(&mut self, value: u32) -> Result<()> {
372 if value == 0 {
373 self.write(2, 0)?;
374 } else if value == 1 {
375 self.write(2, 1)?;
376 } else if value < 18 {
377 self.write(2, 2)?;
378 self.write(4, (value - 2) as u64)?;
379 } else {
380 debug_assert!(
381 value < 82,
382 "value {value} too large for default enum encoding"
383 );
384 self.write(2, 3)?;
385 self.write(6, (value - 18) as u64)?;
386 }
387 Ok(())
388 }
389
390 pub fn write_u64_coder(&mut self, value: u64) -> Result<()> {
400 if value == 0 {
401 self.write(2, 0)?;
402 } else if value <= 16 {
403 self.write(2, 1)?;
404 self.write(4, value - 1)?;
405 } else if value <= 272 {
406 self.write(2, 2)?;
407 self.write(8, value - 17)?;
408 } else if value <= 4368 {
409 self.write(2, 3)?;
410 self.write(12, value - 273)?;
411 } else {
412 self.write(2, 3)?;
414 let low = (value - 273) & 0xFFF;
415 let high = (value - 273) >> 12;
416 self.write(12, low | 0x1000)?; self.write(32, high)?;
418 }
419 Ok(())
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// Applies each `(n_bits, value)` pair to a fresh writer, in order.
    fn writer_from(writes: &[(usize, u64)]) -> BitWriter {
        let mut writer = BitWriter::new();
        for &(n_bits, value) in writes {
            writer.write(n_bits, value).unwrap();
        }
        writer
    }

    /// Encodes `value` with the distribution (0, 1, 2, 3 + u(8)) and pads it
    /// to a byte boundary.
    fn u32_coded(value: u32) -> Vec<u8> {
        let mut writer = BitWriter::new();
        writer.write_u32_coder(value, 0, 1, 2, 3, 8).unwrap();
        writer.zero_pad_to_byte();
        writer.as_bytes().to_vec()
    }

    #[test]
    fn test_write_simple() {
        let bytes = writer_from(&[(8, 0x12), (8, 0x34)]).finish();
        assert_eq!(bytes, [0x12, 0x34]);
    }

    #[test]
    fn test_write_partial_bytes() {
        // Bytes fill from the least-significant end: 0x2 then 0x1 -> 0x12.
        let bytes = writer_from(&[(4, 0x2), (4, 0x1)]).finish();
        assert_eq!(bytes, [0x12]);
    }

    #[test]
    fn test_write_across_bytes() {
        // 0x2 | 0x34 << 4 | 0x1 << 12 == 0x1342, stored little-endian.
        let bytes = writer_from(&[(4, 0x2), (8, 0x34), (4, 0x1)]).finish();
        assert_eq!(bytes, [0x42, 0x13]);
    }

    #[test]
    fn test_zero_pad() {
        let mut writer = writer_from(&[(5, 0x15)]);
        assert!(!writer.is_byte_aligned());
        assert_eq!(writer.bits_to_byte_boundary(), 3);

        writer.zero_pad_to_byte();
        assert!(writer.is_byte_aligned());
        assert_eq!(writer.finish(), [0x15]);
    }

    #[test]
    fn test_append_bytes() {
        let mut writer = writer_from(&[(8, 0x12)]);
        writer.append_bytes(&[0x34, 0x56]).unwrap();
        assert_eq!(writer.finish(), [0x12, 0x34, 0x56]);
    }

    #[test]
    fn test_append_bytes_unaligned_fails() {
        let mut writer = writer_from(&[(4, 0x2)]);
        assert!(writer.append_bytes(&[0x34]).is_err());
    }

    #[test]
    fn test_write_too_many_bits() {
        let result = BitWriter::new().write(57, 0);
        assert!(matches!(result, Err(Error::TooManyBitsPerCall(57))));
    }

    #[test]
    fn test_bits_written() {
        let mut writer = BitWriter::new();
        let mut expected = 0;
        assert_eq!(writer.bits_written(), expected);
        for n_bits in [5, 11] {
            writer.write(n_bits, 0).unwrap();
            expected += n_bits;
            assert_eq!(writer.bits_written(), expected);
        }
    }

    #[test]
    fn test_append_byte_aligned() {
        let mut head = writer_from(&[(8, 0x12)]);
        let tail = writer_from(&[(16, 0x5634)]);
        head.append_byte_aligned(&tail).unwrap();
        assert_eq!(head.finish(), [0x12, 0x34, 0x56]);
    }

    #[test]
    fn test_append_unaligned() {
        let mut head = writer_from(&[(4, 0x2)]);
        let tail = writer_from(&[(8, 0x34)]);
        head.append_unaligned(&tail).unwrap();
        head.zero_pad_to_byte();
        assert_eq!(head.finish(), [0x42, 0x03]);
    }

    #[test]
    fn test_finish_with_padding() {
        let bytes = writer_from(&[(5, 0x15)]).finish_with_padding();
        assert_eq!(bytes, [0x15]);
    }

    #[test]
    fn test_u32_coder() {
        assert_eq!(u32_coded(0), [0b00]);
        assert_eq!(u32_coded(1), [0b01]);
        assert_eq!(u32_coded(2), [0b10]);
        // Selector 0b11, then 10 - 3 = 7 in 8 bits: 0b11 | 7 << 2 == 0x1F (10 bits).
        assert_eq!(u32_coded(10), [0x1F, 0x00]);
    }
}