1use crate::{Error, Result};
35
36#[derive(Clone, Copy)]
45pub struct BitReader<'a> {
46 data: &'a [u8],
47 byte_pos: usize,
49 acc: u64,
51 bits_in_acc: u32,
53}
54
55impl<'a> BitReader<'a> {
56 pub fn new(data: &'a [u8]) -> Self {
57 Self {
58 data,
59 byte_pos: 0,
60 acc: 0,
61 bits_in_acc: 0,
62 }
63 }
64
65 pub fn with_position(data: &'a [u8], byte_pos: usize) -> Self {
68 let byte_pos = byte_pos.min(data.len());
69 Self {
70 data,
71 byte_pos,
72 acc: 0,
73 bits_in_acc: 0,
74 }
75 }
76
77 pub fn bit_position(&self) -> u64 {
79 self.byte_pos as u64 * 8 - self.bits_in_acc as u64
80 }
81
82 pub fn byte_position(&self) -> usize {
84 (self.bit_position() / 8) as usize
85 }
86
87 pub fn bits_remaining(&self) -> u64 {
89 self.bits_in_acc as u64 + ((self.data.len() - self.byte_pos) as u64) * 8
90 }
91
92 pub fn is_byte_aligned(&self) -> bool {
94 self.bits_in_acc % 8 == 0
95 }
96
97 pub fn align_to_byte(&mut self) {
99 let drop = self.bits_in_acc % 8;
100 self.acc <<= drop;
101 self.bits_in_acc -= drop;
102 }
103
104 fn refill(&mut self) {
105 while self.bits_in_acc <= 56 && self.byte_pos < self.data.len() {
106 self.acc |= (self.data[self.byte_pos] as u64) << (56 - self.bits_in_acc);
107 self.bits_in_acc += 8;
108 self.byte_pos += 1;
109 }
110 }
111
112 pub fn read_u32(&mut self, n: u32) -> Result<u32> {
114 debug_assert!(n <= 32, "BitReader::read_u32 supports up to 32 bits");
115 if n == 0 {
116 return Ok(0);
117 }
118 if self.bits_in_acc < n {
119 self.refill();
120 if self.bits_in_acc < n {
121 return Err(Error::invalid("bitreader: out of bits"));
122 }
123 }
124 let v = (self.acc >> (64 - n)) as u32;
125 self.acc <<= n;
126 self.bits_in_acc -= n;
127 Ok(v)
128 }
129
130 pub fn read_u64(&mut self, n: u32) -> Result<u64> {
132 debug_assert!(n <= 64);
133 if n <= 32 {
134 return self.read_u32(n).map(|v| v as u64);
135 }
136 let hi = self.read_u32(n - 32)? as u64;
137 let lo = self.read_u32(32)? as u64;
138 Ok((hi << 32) | lo)
139 }
140
141 pub fn read_i32(&mut self, n: u32) -> Result<i32> {
143 if n == 0 {
144 return Ok(0);
145 }
146 let raw = self.read_u32(n)? as i32;
147 let shift = 32 - n;
148 Ok((raw << shift) >> shift)
149 }
150
151 pub fn read_bit(&mut self) -> Result<bool> {
153 Ok(self.read_u32(1)? != 0)
154 }
155
156 pub fn read_u1(&mut self) -> Result<u32> {
158 self.read_u32(1)
159 }
160
161 pub fn peek_u32(&mut self, n: u32) -> Result<u32> {
163 debug_assert!(n <= 32);
164 if n == 0 {
165 return Ok(0);
166 }
167 if self.bits_in_acc < n {
168 self.refill();
169 if self.bits_in_acc < n {
170 return Err(Error::invalid("bitreader: out of bits for peek"));
171 }
172 }
173 Ok((self.acc >> (64 - n)) as u32)
174 }
175
176 pub fn skip(&mut self, n: u32) -> Result<()> {
178 let mut left = n;
179 while left > 32 {
180 self.read_u32(32)?;
181 left -= 32;
182 }
183 self.read_u32(left)?;
184 Ok(())
185 }
186
187 pub fn consume(&mut self, n: u32) -> Result<()> {
189 self.skip(n)
190 }
191
192 pub fn read_unary(&mut self) -> Result<u32> {
197 let mut count = 0u32;
198 loop {
199 if self.bits_in_acc == 0 {
200 self.refill();
201 if self.bits_in_acc == 0 {
202 return Err(Error::invalid("bitreader: out of bits in unary code"));
203 }
204 }
205 let lz_total = self.acc.leading_zeros();
206 let lz_avail = lz_total.min(self.bits_in_acc);
207 count = count
208 .checked_add(lz_avail)
209 .ok_or_else(|| Error::invalid("bitreader: unary count overflow"))?;
210 if lz_avail >= 64 {
213 self.acc = 0;
214 } else {
215 self.acc <<= lz_avail;
216 }
217 self.bits_in_acc -= lz_avail;
218 if lz_avail < lz_total || self.bits_in_acc == 0 {
219 continue;
220 }
221 self.acc <<= 1;
223 self.bits_in_acc -= 1;
224 return Ok(count);
225 }
226 }
227
228 pub fn read_bytes(&mut self, n: usize) -> Result<Vec<u8>> {
230 if !self.is_byte_aligned() {
231 return Err(Error::invalid(
232 "bitreader: read_bytes requires byte alignment",
233 ));
234 }
235 self.align_to_byte();
236 let start = self.byte_pos - (self.bits_in_acc as usize / 8);
237 if start + n > self.data.len() {
242 return Err(Error::invalid("bitreader: read_bytes past end"));
243 }
244 let out = self.data[start..start + n].to_vec();
245 self.acc = 0;
248 self.bits_in_acc = 0;
249 self.byte_pos = start + n;
250 Ok(out)
251 }
252}
253
/// Big-endian (MSB-first) bit writer backed by a growable byte buffer.
///
/// Bits are staged left-aligned in a 64-bit accumulator and moved to the
/// buffer one whole byte at a time.
pub struct BitWriter {
    /// Bytes completed so far.
    data: Vec<u8>,
    /// Staged bits; the oldest pending bit is bit 63.
    acc: u64,
    /// Number of staged bits in `acc` (below 8 between calls).
    bits_in_acc: u32,
}

impl BitWriter {
    /// Creates an empty writer.
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty writer whose buffer has room for `cap` bytes.
    pub fn with_capacity(cap: usize) -> Self {
        Self {
            data: Vec::with_capacity(cap),
            acc: 0,
            bits_in_acc: 0,
        }
    }

    /// Returns the total number of bits written so far.
    pub fn bit_position(&self) -> u64 {
        let flushed_bits = self.data.len() as u64 * 8;
        flushed_bits + u64::from(self.bits_in_acc)
    }

    /// Returns the number of complete bytes emitted so far (staged bits of a
    /// partial byte are not counted).
    pub fn byte_len(&self) -> usize {
        self.data.len()
    }

    /// Returns `true` when no partial byte is pending.
    pub fn is_byte_aligned(&self) -> bool {
        (self.bits_in_acc % 8) == 0
    }

    /// Appends the low `n` bits of `value` (`n <= 32`), most significant of
    /// those bits first. Higher bits of `value` are ignored; `n == 0` writes
    /// nothing.
    ///
    /// # Panics
    ///
    /// Debug builds assert `n <= 32`.
    pub fn write_u32(&mut self, value: u32, n: u32) {
        debug_assert!(n <= 32, "BitWriter::write_u32 supports up to 32 bits");
        if n == 0 {
            return;
        }
        // `1 << 32` would overflow a u32, hence the special case.
        let keep: u32 = if n == 32 { u32::MAX } else { (1u32 << n) - 1 };
        let field = u64::from(value & keep);
        // Place the field directly below the bits already staged.
        self.acc |= field << (64 - self.bits_in_acc - n);
        self.bits_in_acc += n;
        // Move every completed byte from the top of the accumulator out.
        while self.bits_in_acc >= 8 {
            self.data.push((self.acc >> 56) as u8);
            self.acc <<= 8;
            self.bits_in_acc -= 8;
        }
    }

    /// Alias for [`BitWriter::write_u32`].
    pub fn write_bits(&mut self, value: u32, n: u32) {
        self.write_u32(value, n)
    }

    /// Appends the low `n` bits of `value` (`n <= 64`), most significant of
    /// those bits first.
    ///
    /// # Panics
    ///
    /// Debug builds assert `n <= 64`.
    pub fn write_u64(&mut self, value: u64, n: u32) {
        debug_assert!(n <= 64);
        if n > 32 {
            // Upper part first so the overall order stays MSB-first.
            self.write_u32((value >> 32) as u32, n - 32);
            self.write_u32(value as u32, 32);
        } else {
            self.write_u32(value as u32, n);
        }
    }

    /// Appends the low `n` bits of the two's-complement representation of
    /// `value`.
    pub fn write_i32(&mut self, value: i32, n: u32) {
        self.write_u32(value as u32, n);
    }

    /// Appends a single bit.
    pub fn write_bit(&mut self, bit: bool) {
        self.write_u32(u32::from(bit), 1);
    }

    /// Appends a unary code: `n` zero bits followed by a single one bit.
    pub fn write_unary(&mut self, n: u32) {
        // Zeros go out in chunks of at most 32 bits, the `write_u32` limit.
        for _ in 0..n / 32 {
            self.write_u32(0, 32);
        }
        // A zero-width write is a no-op, so no guard is needed here.
        self.write_u32(0, n % 32);
        self.write_bit(true);
    }

    /// Appends one byte (eight bits), at any bit alignment.
    pub fn write_byte(&mut self, b: u8) {
        self.write_u32(u32::from(b), 8);
    }

    /// Appends every byte of `bytes`, at any bit alignment.
    pub fn write_bytes(&mut self, bytes: &[u8]) {
        if self.is_byte_aligned() {
            // Nothing is staged, so the bytes can be copied straight across.
            self.data.extend_from_slice(bytes);
            return;
        }
        for &byte in bytes {
            self.write_u32(u32::from(byte), 8);
        }
    }

    /// Pads with zero bits up to the next byte boundary; does nothing when
    /// already aligned.
    pub fn align_to_byte(&mut self) {
        let partial = self.bits_in_acc % 8;
        if partial != 0 {
            self.write_u32(0, 8 - partial);
        }
    }

    /// Alias for [`BitWriter::align_to_byte`].
    pub fn align_to_byte_zero(&mut self) {
        self.align_to_byte()
    }

    /// Returns the complete bytes emitted so far.
    pub fn bytes(&self) -> &[u8] {
        &self.data
    }

    /// Alias for [`BitWriter::bytes`].
    pub fn buffer(&self) -> &[u8] {
        &self.data
    }

    /// Consumes the writer and returns its bytes; a pending partial byte is
    /// emitted with its unused low bits set to zero.
    pub fn finish(self) -> Vec<u8> {
        let Self {
            mut data,
            acc,
            bits_in_acc,
        } = self;
        if bits_in_acc > 0 {
            data.push((acc >> 56) as u8);
        }
        data
    }

    /// Alias for [`BitWriter::finish`].
    pub fn into_bytes(self) -> Vec<u8> {
        self.finish()
    }
}

impl Default for BitWriter {
    /// Equivalent to [`BitWriter::new`].
    fn default() -> Self {
        Self::new()
    }
}
418
419#[derive(Clone, Copy)]
425pub struct BitReaderLsb<'a> {
426 data: &'a [u8],
427 byte_pos: usize,
428 acc: u64,
430 bits_in_acc: u32,
431}
432
433impl<'a> BitReaderLsb<'a> {
434 pub fn new(data: &'a [u8]) -> Self {
435 Self {
436 data,
437 byte_pos: 0,
438 acc: 0,
439 bits_in_acc: 0,
440 }
441 }
442
443 pub fn bit_position(&self) -> u64 {
444 self.byte_pos as u64 * 8 - self.bits_in_acc as u64
445 }
446
447 pub fn is_byte_aligned(&self) -> bool {
448 self.bits_in_acc % 8 == 0
449 }
450
451 fn refill(&mut self) {
452 while self.bits_in_acc <= 56 && self.byte_pos < self.data.len() {
453 self.acc |= (self.data[self.byte_pos] as u64) << self.bits_in_acc;
454 self.bits_in_acc += 8;
455 self.byte_pos += 1;
456 }
457 }
458
459 pub fn read_u32(&mut self, n: u32) -> Result<u32> {
460 debug_assert!(n <= 32, "BitReaderLsb::read_u32 supports up to 32 bits");
461 if n == 0 {
462 return Ok(0);
463 }
464 if self.bits_in_acc < n {
465 self.refill();
466 if self.bits_in_acc < n {
467 return Err(Error::Eof);
468 }
469 }
470 let mask = if n == 32 { u32::MAX } else { (1u32 << n) - 1 };
471 let v = (self.acc as u32) & mask;
472 self.acc >>= n;
473 self.bits_in_acc -= n;
474 Ok(v)
475 }
476
477 pub fn read_u64(&mut self, n: u32) -> Result<u64> {
478 debug_assert!(n <= 64);
479 if n == 0 {
480 return Ok(0);
481 }
482 if n <= 32 {
483 return Ok(self.read_u32(n)? as u64);
484 }
485 let lo = self.read_u32(32)? as u64;
486 let hi = self.read_u32(n - 32)? as u64;
487 Ok(lo | (hi << 32))
488 }
489
490 pub fn read_i32(&mut self, n: u32) -> Result<i32> {
491 if n == 0 {
492 return Ok(0);
493 }
494 let raw = self.read_u32(n)? as i32;
495 let shift = 32 - n;
496 Ok((raw << shift) >> shift)
497 }
498
499 pub fn read_bit(&mut self) -> Result<bool> {
500 Ok(self.read_u32(1)? != 0)
501 }
502}
503
/// Little-endian (LSB-first) bit writer backed by a growable byte buffer.
///
/// Bits are staged right-aligned in a 64-bit accumulator and moved to the
/// buffer one whole byte at a time.
pub struct BitWriterLsb {
    /// Bytes completed so far.
    data: Vec<u8>,
    /// Staged bits; the oldest pending bit is bit 0.
    acc: u64,
    /// Number of staged bits in `acc` (below 8 between calls).
    bits_in_acc: u32,
}

impl BitWriterLsb {
    /// Creates an empty writer.
    pub fn new() -> Self {
        Self::with_capacity(0)
    }

    /// Creates an empty writer whose buffer has room for `cap` bytes.
    pub fn with_capacity(cap: usize) -> Self {
        Self {
            data: Vec::with_capacity(cap),
            acc: 0,
            bits_in_acc: 0,
        }
    }

    /// Returns the total number of bits written so far.
    pub fn bit_position(&self) -> u64 {
        let flushed_bits = self.data.len() as u64 * 8;
        flushed_bits + u64::from(self.bits_in_acc)
    }

    /// Appends the low `n` bits of `value` (`n <= 32`), least significant bit
    /// first. Higher bits of `value` are ignored; `n == 0` writes nothing.
    ///
    /// # Panics
    ///
    /// Debug builds assert `n <= 32`.
    pub fn write_u32(&mut self, value: u32, n: u32) {
        debug_assert!(n <= 32, "BitWriterLsb::write_u32 supports up to 32 bits");
        if n == 0 {
            return;
        }
        // `1 << 32` would overflow a u32, hence the special case.
        let keep: u32 = if n == 32 { u32::MAX } else { (1u32 << n) - 1 };
        // Place the field directly above the bits already staged.
        self.acc |= u64::from(value & keep) << self.bits_in_acc;
        self.bits_in_acc += n;
        // Move every completed byte from the bottom of the accumulator out.
        while self.bits_in_acc >= 8 {
            self.data.push(self.acc as u8);
            self.acc >>= 8;
            self.bits_in_acc -= 8;
        }
    }

    /// Appends the low `n` bits of `value` (`n <= 64`), least significant bit
    /// first.
    ///
    /// # Panics
    ///
    /// Debug builds assert `n <= 64`.
    pub fn write_u64(&mut self, value: u64, n: u32) {
        debug_assert!(n <= 64);
        if n > 32 {
            // Lower part first so the overall order stays LSB-first.
            self.write_u32(value as u32, 32);
            self.write_u32((value >> 32) as u32, n - 32);
        } else {
            self.write_u32(value as u32, n);
        }
    }

    /// Appends a single bit.
    pub fn write_bit(&mut self, bit: bool) {
        self.write_u32(u32::from(bit), 1);
    }

    /// Pads with zero bits up to the next byte boundary; does nothing when
    /// already aligned.
    pub fn align_to_byte(&mut self) {
        let partial = self.bits_in_acc % 8;
        if partial != 0 {
            self.write_u32(0, 8 - partial);
        }
    }

    /// Consumes the writer and returns its bytes; a pending partial byte is
    /// emitted with its unused high bits set to zero.
    pub fn finish(self) -> Vec<u8> {
        let Self {
            mut data,
            acc,
            bits_in_acc,
        } = self;
        if bits_in_acc > 0 {
            data.push(acc as u8);
        }
        data
    }
}

impl Default for BitWriterLsb {
    /// Equivalent to [`BitWriterLsb::new`].
    fn default() -> Self {
        Self::new()
    }
}
584
#[cfg(test)]
mod tests {
    use super::*;

    /// Eight single-bit writes fill one byte, first bit written on top.
    #[test]
    fn msb_roundtrip_byte() {
        let mut w = BitWriter::new();
        for &b in &[1u32, 0, 1, 0, 0, 1, 0, 1] {
            w.write_u32(b, 1);
        }
        assert_eq!(w.finish(), vec![0xA5]);
    }

    /// Fields of assorted widths written back-to-back read back unchanged.
    #[test]
    fn msb_roundtrip_varied_widths() {
        let mut bw = BitWriter::new();
        let writes: Vec<(u32, u32)> = vec![
            (0b1, 1),
            (0b10101, 5),
            (0b111100001111, 12),
            (0xDEADBEEF, 32),
            (0b001, 3),
            (0xC, 4),
            (0xABCD, 16),
            (0x12345, 20),
            (0, 8),
            (0xFFFFFFFF, 32),
        ];
        for &(v, n) in &writes {
            bw.write_u32(v, n);
        }
        let bytes = bw.finish();
        let mut br = BitReader::new(&bytes);
        for &(v, n) in &writes {
            let got = br.read_u32(n).unwrap();
            let mask = if n == 32 { u32::MAX } else { (1 << n) - 1 };
            assert_eq!(got, v & mask, "mismatch for ({v:#x}, {n})");
        }
    }

    /// A 4-bit field of all ones sign-extends to -1.
    #[test]
    fn msb_signed_extension() {
        let mut br = BitReader::new(&[0xFF]);
        assert_eq!(br.read_i32(4).unwrap(), -1);
        assert_eq!(br.read_i32(4).unwrap(), -1);
    }

    /// Peeking does not consume; skipping does.
    #[test]
    fn msb_peek_skip() {
        let mut br = BitReader::new(&[0xFF, 0x00]);
        assert_eq!(br.peek_u32(12).unwrap(), 0xFF0);
        br.skip(4).unwrap();
        assert_eq!(br.read_u32(8).unwrap(), 0xF0);
    }

    /// Aligning after a partial read drops the rest of the current byte.
    #[test]
    fn msb_alignment() {
        let mut br = BitReader::new(&[0xFF, 0x55]);
        br.read_u32(3).unwrap();
        assert!(!br.is_byte_aligned());
        br.align_to_byte();
        assert!(br.is_byte_aligned());
        assert_eq!(br.read_u32(8).unwrap(), 0x55);
    }

    /// Byte writes on an aligned writer come out verbatim.
    #[test]
    fn msb_write_bytes_fast_path() {
        let mut w = BitWriter::new();
        w.write_bytes(&[0x11, 0x22, 0x33]);
        assert_eq!(w.finish(), vec![0x11, 0x22, 0x33]);
    }

    /// Byte writes on an unaligned writer are shifted into place bit by bit.
    #[test]
    fn msb_write_bytes_unaligned() {
        let mut w = BitWriter::new();
        w.write_u32(0b101, 3);
        w.write_bytes(&[0xFF, 0x00]);
        let out = w.finish();
        assert_eq!(out.len(), 3);
        // 101 11111111 00000000 + 5 zero padding bits.
        assert_eq!(out, vec![0xBF, 0xE0, 0x00]);
    }

    /// `write_bits` behaves exactly like `write_u32`.
    #[test]
    fn msb_write_bits_alias() {
        let mut w = BitWriter::new();
        w.write_bits(0xA, 4);
        w.write_u32(0x5, 4);
        assert_eq!(w.finish(), vec![0xA5]);
    }

    /// `into_bytes` behaves exactly like `finish`.
    #[test]
    fn msb_into_bytes_alias() {
        let mut w = BitWriter::new();
        w.write_u32(0xA5, 8);
        assert_eq!(w.into_bytes(), vec![0xA5]);
    }

    /// A full 64-bit value survives the split into two 32-bit halves.
    #[test]
    fn msb_read_u64_high_bits() {
        let mut w = BitWriter::new();
        w.write_u64(0x1234567890ABCDEF, 64);
        let bytes = w.finish();
        let mut r = BitReader::new(&bytes);
        assert_eq!(r.read_u64(64).unwrap(), 0x1234567890ABCDEF);
    }

    /// Unary codes round-trip, including runs longer than the accumulator.
    #[test]
    fn msb_unary_roundtrip() {
        let mut w = BitWriter::new();
        let counts: Vec<u32> = vec![0, 1, 7, 31, 32, 33, 64, 65, 100];
        for &c in &counts {
            w.write_unary(c);
        }
        let bytes = w.finish();
        let mut r = BitReader::new(&bytes);
        for &c in &counts {
            assert_eq!(r.read_unary().unwrap(), c, "unary roundtrip for {c}");
        }
    }

    /// `read_bytes` resumes correctly after and before bit-level reads.
    #[test]
    fn msb_read_bytes_aligned() {
        let mut br = BitReader::new(&[0xAA, 0xBB, 0xCC, 0xDD]);
        let _ = br.read_u32(8).unwrap();
        let got = br.read_bytes(2).unwrap();
        assert_eq!(got, vec![0xBB, 0xCC]);
        assert_eq!(br.read_u32(8).unwrap(), 0xDD);
    }

    /// Eight single-bit writes fill one byte, first bit written at the bottom.
    #[test]
    fn lsb_roundtrip_byte() {
        let mut w = BitWriterLsb::new();
        for &b in &[1u32, 0, 1, 0, 0, 1, 0, 1] {
            w.write_u32(b, 1);
        }
        assert_eq!(w.finish(), vec![0xA5]);
    }

    /// A 16-bit field is emitted low byte first and reads back unchanged.
    #[test]
    fn lsb_multi_byte() {
        let mut w = BitWriterLsb::new();
        w.write_u32(0x3412, 16);
        let bytes = w.finish();
        assert_eq!(bytes, vec![0x12, 0x34]);
        let mut r = BitReaderLsb::new(&bytes);
        assert_eq!(r.read_u32(16).unwrap(), 0x3412);
    }

    /// Fields of assorted widths written back-to-back read back unchanged.
    #[test]
    fn lsb_roundtrip_varied_widths() {
        let mut bw = BitWriterLsb::new();
        let writes: Vec<(u32, u32)> = vec![(5, 3), (0xABCD, 16), (0x1234567, 27), (1, 1)];
        for &(v, n) in &writes {
            bw.write_u32(v, n);
        }
        let bytes = bw.finish();
        let mut r = BitReaderLsb::new(&bytes);
        for &(v, n) in &writes {
            assert_eq!(r.read_u32(n).unwrap(), v);
        }
    }
}