1use crate::error::{Error, Result};
11
/// Maximum number of bits accepted by a single [`BitWriter::write`] call.
///
/// Capped at 56 so that the value, after being shifted left by the current
/// bit offset within a byte (up to 7), still fits in the single `u64` used
/// for the read-modify-write in `write` (56 + 7 = 63 <= 64).
pub const MAX_BITS_PER_CALL: usize = 56;
15
/// Appends bits to a growable byte buffer, packing them LSB-first within
/// each byte.
///
/// The backing buffer keeps up to 8 bytes of zeroed slack past the written
/// region so `write` can perform a whole-`u64` read-modify-write.
#[derive(Debug, Clone)]
pub struct BitWriter {
    // Backing buffer; bytes beyond `bits_written` are always zero.
    storage: Vec<u8>,
    // Total number of bits written so far.
    bits_written: usize,
}
42
43impl Default for BitWriter {
44 fn default() -> Self {
45 Self::new()
46 }
47}
48
49impl BitWriter {
50 pub fn new() -> Self {
52 Self {
53 storage: Vec::new(),
54 bits_written: 0,
55 }
56 }
57
58 pub fn with_capacity(capacity_bytes: usize) -> Self {
64 Self {
65 storage: Vec::with_capacity(capacity_bytes),
66 bits_written: 0,
67 }
68 }
69
70 #[inline]
72 pub fn bits_written(&self) -> usize {
73 self.bits_written
74 }
75
76 #[inline]
78 pub fn bytes_written(&self) -> usize {
79 self.bits_written.div_ceil(8)
80 }
81
82 #[inline]
84 pub fn is_byte_aligned(&self) -> bool {
85 self.bits_written.is_multiple_of(8)
86 }
87
88 #[inline]
90 pub fn bits_to_byte_boundary(&self) -> usize {
91 if self.bits_written.is_multiple_of(8) {
92 0
93 } else {
94 8 - (self.bits_written % 8)
95 }
96 }
97
98 fn ensure_capacity(&mut self, additional_bits: usize) -> Result<()> {
100 let total_bits = self.bits_written + additional_bits;
101 let required_bytes = total_bits.div_ceil(8) + 8; if self.storage.len() < required_bytes {
104 self.storage
105 .try_reserve(required_bytes - self.storage.len())?;
106 self.storage.resize(required_bytes, 0);
107 }
108 Ok(())
109 }
110
111 #[inline]
125 pub fn write(&mut self, n_bits: usize, bits: u64) -> Result<()> {
126 if n_bits > MAX_BITS_PER_CALL {
127 return Err(Error::TooManyBitsPerCall(n_bits));
128 }
129
130 if n_bits == 0 {
131 return Ok(());
132 }
133
134 debug_assert!(
135 bits >> n_bits == 0 || n_bits == 64,
136 "bits {bits:#x} has more than {n_bits} bits"
137 );
138
139 self.ensure_capacity(n_bits)?;
140
141 let byte_offset = self.bits_written / 8;
142 let bits_in_first_byte = self.bits_written % 8;
143
144 let shifted_bits = bits << bits_in_first_byte;
146
147 let p = &mut self.storage[byte_offset..];
150
151 let mut current = u64::from_le_bytes(p[..8].try_into().unwrap());
155 current |= shifted_bits;
156 p[..8].copy_from_slice(¤t.to_le_bytes());
157
158 self.bits_written += n_bits;
159 Ok(())
160 }
161
162 pub fn zero_pad_to_byte(&mut self) {
166 let remainder = self.bits_to_byte_boundary();
167 if remainder > 0 {
168 let _ = self.write(remainder, 0);
170 }
171 debug_assert!(self.is_byte_aligned());
172 }
173
174 pub fn append_bytes(&mut self, data: &[u8]) -> Result<()> {
182 if !self.is_byte_aligned() {
183 return Err(Error::NotByteAligned(self.bits_written));
184 }
185
186 if data.is_empty() {
187 return Ok(());
188 }
189
190 let byte_offset = self.bits_written / 8;
191 let new_len = byte_offset + data.len() + 8; if self.storage.len() < new_len {
194 self.storage.try_reserve(new_len - self.storage.len())?;
195 self.storage.resize(new_len, 0);
196 }
197
198 self.storage[byte_offset..byte_offset + data.len()].copy_from_slice(data);
199 self.bits_written += data.len() * 8;
200
201 if byte_offset + data.len() < self.storage.len() {
203 self.storage[byte_offset + data.len()] = 0;
204 }
205
206 Ok(())
207 }
208
209 pub fn append_byte_aligned(&mut self, other: &BitWriter) -> Result<()> {
217 if !self.is_byte_aligned() {
218 return Err(Error::NotByteAligned(self.bits_written));
219 }
220 if !other.is_byte_aligned() {
221 return Err(Error::NotByteAligned(other.bits_written));
222 }
223
224 let other_bytes = other.bytes_written();
225 self.append_bytes(&other.storage[..other_bytes])
226 }
227
228 pub fn append_unaligned(&mut self, other: &BitWriter) -> Result<()> {
232 let full_bytes = other.bits_written / 8;
233 let remaining_bits = other.bits_written % 8;
234
235 for &byte in &other.storage[..full_bytes] {
236 self.write(8, byte as u64)?;
237 }
238
239 if remaining_bits > 0 {
240 let mask = (1u64 << remaining_bits) - 1;
241 let last_bits = other.storage[full_bytes] as u64 & mask;
242 self.write(remaining_bits, last_bits)?;
243 }
244
245 Ok(())
246 }
247
248 pub fn as_bytes(&self) -> &[u8] {
256 assert!(
257 self.is_byte_aligned(),
258 "BitWriter must be byte-aligned to get bytes"
259 );
260 &self.storage[..self.bytes_written()]
261 }
262
263 pub fn peek_bytes(&self) -> &[u8] {
269 let bytes = self.bits_written.div_ceil(8);
270 &self.storage[..bytes.min(self.storage.len())]
271 }
272
273 pub fn finish(mut self) -> Vec<u8> {
281 assert!(
282 self.is_byte_aligned(),
283 "BitWriter must be byte-aligned to finish"
284 );
285 self.storage.truncate(self.bytes_written());
286 self.storage
287 }
288
289 pub fn finish_with_padding(mut self) -> Vec<u8> {
293 self.zero_pad_to_byte();
294 self.storage.truncate(self.bytes_written());
295 self.storage
296 }
297}
298
299impl BitWriter {
301 #[inline]
303 pub fn write_bit(&mut self, bit: bool) -> Result<()> {
304 self.write(1, bit as u64)
305 }
306
307 #[inline]
309 pub fn write_u8(&mut self, value: u8) -> Result<()> {
310 self.write(8, value as u64)
311 }
312
313 #[inline]
315 pub fn write_u16(&mut self, value: u16) -> Result<()> {
316 self.write(16, value as u64)
317 }
318
319 #[inline]
321 pub fn write_u32(&mut self, value: u32) -> Result<()> {
322 self.write(32, value as u64)
323 }
324
325 pub fn write_u32_coder(
339 &mut self,
340 value: u32,
341 d0: u32,
342 d1: u32,
343 d2: u32,
344 d3: u32,
345 u_bits: usize,
346 ) -> Result<()> {
347 if value == d0 {
348 self.write(2, 0)?;
349 } else if value == d1 {
350 self.write(2, 1)?;
351 } else if value == d2 {
352 self.write(2, 2)?;
353 } else {
354 debug_assert!(value >= d3, "value {value} < d3 {d3}");
355 debug_assert!(
356 (value - d3) < (1 << u_bits),
357 "value {value} - d3 {d3} doesn't fit in {u_bits} bits"
358 );
359 self.write(2, 3)?;
360 self.write(u_bits, (value - d3) as u64)?;
361 }
362 Ok(())
363 }
364
365 pub fn write_enum_default(&mut self, value: u32) -> Result<()> {
372 if value == 0 {
373 self.write(2, 0)?;
374 } else if value == 1 {
375 self.write(2, 1)?;
376 } else if value < 18 {
377 self.write(2, 2)?;
378 self.write(4, (value - 2) as u64)?;
379 } else {
380 debug_assert!(
381 value < 82,
382 "value {value} too large for default enum encoding"
383 );
384 self.write(2, 3)?;
385 self.write(6, (value - 18) as u64)?;
386 }
387 Ok(())
388 }
389
390 pub fn write_u64_coder(&mut self, value: u64) -> Result<()> {
400 if value == 0 {
401 self.write(2, 0)?;
402 } else if value <= 16 {
403 self.write(2, 1)?;
404 self.write(4, value - 1)?;
405 } else if value <= 272 {
406 self.write(2, 2)?;
407 self.write(8, value - 17)?;
408 } else {
409 self.write(2, 3)?;
411 let mut remaining = value;
412 self.write(12, remaining & 0xFFF)?;
413 remaining >>= 12;
414 let mut shift = 12;
415 while remaining > 0 && shift < 60 {
416 self.write(1, 1)?; self.write(8, remaining & 0xFF)?;
418 remaining >>= 8;
419 shift += 8;
420 }
421 if remaining > 0 {
422 self.write(1, 1)?;
424 self.write(4, remaining & 0xF)?;
425 } else {
426 self.write(1, 0)?; }
428 }
429 Ok(())
430 }
431}
432
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_write_simple() {
        let mut w = BitWriter::new();
        w.write(8, 0x12).unwrap();
        w.write(8, 0x34).unwrap();
        assert_eq!(w.finish(), vec![0x12, 0x34]);
    }

    #[test]
    fn test_write_partial_bytes() {
        // Two nibbles pack LSB-first into one byte.
        let mut w = BitWriter::new();
        w.write(4, 0x2).unwrap();
        w.write(4, 0x1).unwrap();
        assert_eq!(w.finish(), vec![0x12]);
    }

    #[test]
    fn test_write_across_bytes() {
        let mut w = BitWriter::new();
        w.write(4, 0x2).unwrap();
        w.write(8, 0x34).unwrap();
        w.write(4, 0x1).unwrap();
        assert_eq!(w.finish(), vec![0x42, 0x13]);
    }

    #[test]
    fn test_zero_pad() {
        let mut w = BitWriter::new();
        w.write(5, 0x15).unwrap();
        assert!(!w.is_byte_aligned());
        assert_eq!(w.bits_to_byte_boundary(), 3);

        w.zero_pad_to_byte();
        assert!(w.is_byte_aligned());
        assert_eq!(w.finish(), vec![0x15]);
    }

    #[test]
    fn test_append_bytes() {
        let mut w = BitWriter::new();
        w.write(8, 0x12).unwrap();
        w.append_bytes(&[0x34, 0x56]).unwrap();
        assert_eq!(w.finish(), vec![0x12, 0x34, 0x56]);
    }

    #[test]
    fn test_append_bytes_unaligned_fails() {
        let mut w = BitWriter::new();
        w.write(4, 0x2).unwrap();
        assert!(w.append_bytes(&[0x34]).is_err());
    }

    #[test]
    fn test_write_too_many_bits() {
        let mut w = BitWriter::new();
        assert!(matches!(w.write(57, 0), Err(Error::TooManyBitsPerCall(57))));
    }

    #[test]
    fn test_bits_written() {
        let mut w = BitWriter::new();
        assert_eq!(w.bits_written(), 0);

        w.write(5, 0).unwrap();
        assert_eq!(w.bits_written(), 5);

        w.write(11, 0).unwrap();
        assert_eq!(w.bits_written(), 16);
    }

    #[test]
    fn test_append_byte_aligned() {
        let mut dst = BitWriter::new();
        dst.write(8, 0x12).unwrap();

        let mut src = BitWriter::new();
        src.write(16, 0x5634).unwrap();

        dst.append_byte_aligned(&src).unwrap();
        assert_eq!(dst.finish(), vec![0x12, 0x34, 0x56]);
    }

    #[test]
    fn test_append_unaligned() {
        let mut dst = BitWriter::new();
        dst.write(4, 0x2).unwrap();

        let mut src = BitWriter::new();
        src.write(8, 0x34).unwrap();

        dst.append_unaligned(&src).unwrap();
        dst.zero_pad_to_byte();
        assert_eq!(dst.finish(), vec![0x42, 0x03]);
    }

    #[test]
    fn test_finish_with_padding() {
        let mut w = BitWriter::new();
        w.write(5, 0x15).unwrap();
        assert_eq!(w.finish_with_padding(), vec![0x15]);
    }

    #[test]
    fn test_u32_coder() {
        // Distinguished values 0, 1, 2 with d3 = 3 and an 8-bit tail.
        let encode = |value: u32| {
            let mut w = BitWriter::new();
            w.write_u32_coder(value, 0, 1, 2, 3, 8).unwrap();
            w.zero_pad_to_byte();
            w.finish()
        };

        assert_eq!(encode(0), vec![0b00]);
        assert_eq!(encode(1), vec![0b01]);
        assert_eq!(encode(2), vec![0b10]);
        // 10: selector 3 (bits 11), then 10 - 3 = 7 in 8 bits.
        assert_eq!(encode(10), vec![0x1F, 0x00]);
    }

    /// Encodes `value` with `write_u64_coder`, returning the bit count
    /// before padding and the padded bytes.
    fn u64_encode(value: u64) -> (usize, Vec<u8>) {
        let mut w = BitWriter::new();
        w.write_u64_coder(value).unwrap();
        let bits = w.bits_written();
        w.zero_pad_to_byte();
        (bits, w.finish())
    }

    /// Reference decoder mirroring `write_u64_coder`'s LSB-first layout.
    fn u64_decode(data: &[u8]) -> u64 {
        let mut cursor = 0usize;
        let take = |data: &[u8], cursor: &mut usize, count: usize| -> u64 {
            let mut out = 0u64;
            for bit in 0..count {
                let idx = *cursor + bit;
                if idx / 8 < data.len() && (data[idx / 8] >> (idx % 8)) & 1 == 1 {
                    out |= 1u64 << bit;
                }
            }
            *cursor += count;
            out
        };

        match take(data, &mut cursor, 2) {
            0 => 0,
            1 => 1 + take(data, &mut cursor, 4),
            2 => 17 + take(data, &mut cursor, 8),
            3 => {
                let mut value = take(data, &mut cursor, 12);
                let mut shift = 12u32;
                while shift < 60 {
                    if take(data, &mut cursor, 1) == 0 {
                        break;
                    }
                    value |= take(data, &mut cursor, 8) << shift;
                    shift += 8;
                }
                // Optional final 4-bit chunk for bits 60..64.
                if shift == 60 && take(data, &mut cursor, 1) == 1 {
                    value |= take(data, &mut cursor, 4) << shift;
                }
                value
            }
            _ => unreachable!(),
        }
    }

    #[test]
    fn test_u64_coder_small_values() {
        let roundtrip = |v: u64| u64_decode(&u64_encode(v).1);

        assert_eq!(u64_encode(0).0, 2);
        assert_eq!(roundtrip(0), 0);

        assert_eq!(roundtrip(1), 1);
        assert_eq!(roundtrip(15), 15);
        assert_eq!(roundtrip(16), 16);
        assert_eq!(u64_encode(1).0, 6);

        assert_eq!(roundtrip(17), 17);
        assert_eq!(roundtrip(271), 271);
        assert_eq!(roundtrip(272), 272);
        assert_eq!(u64_encode(17).0, 10);
    }

    #[test]
    fn test_u64_coder_selector3_varint() {
        let roundtrip = |v: u64| u64_decode(&u64_encode(v).1);

        assert_eq!(u64_encode(273).0, 15);
        assert_eq!(roundtrip(273), 273);

        assert_eq!(roundtrip(4096), 4096);
        assert_eq!(u64_encode(4096).0, 24);

        assert_eq!(roundtrip(1 << 16), 1 << 16);
        assert_eq!(roundtrip(1 << 28), 1 << 28);
        assert_eq!(roundtrip((1u64 << 32) - 1), (1u64 << 32) - 1);
        assert_eq!(roundtrip(1u64 << 32), 1u64 << 32);
        assert_eq!(roundtrip(1u64 << 63), 1u64 << 63);
    }

    #[test]
    fn test_u64_coder_roundtrip_exhaustive() {
        // Boundary values for every selector and varint chunk edge.
        let test_values: &[u64] = &[
            0,
            1,
            15,
            16,
            17,
            271,
            272,
            273,
            4096,
            1 << 16,
            1 << 28,
            (1u64 << 32) - 1,
            1u64 << 32,
            1u64 << 63,
        ];
        for &v in test_values {
            let encoded = u64_encode(v).1;
            let decoded = u64_decode(&encoded);
            assert_eq!(
                decoded, v,
                "U64 roundtrip failed for value {v}: encoded {encoded:?}, decoded {decoded}"
            );
        }
    }
}