1use super::BitMap;
4use bytes::{Buf, BufMut};
5use commonware_codec::{EncodeSize, Error as CodecError, Read, ReadExt, Write};
6use thiserror::Error;
7
/// Errors returned when constructing a [`Prunable`] bitmap.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum Error {
    /// The requested number of pruned chunks, multiplied by the number of bits
    /// per chunk, does not fit in a `u64`.
    #[error("pruned_chunks * CHUNK_SIZE_BITS overflows u64")]
    PrunedChunksOverflow,
}
15
/// A bitmap made of `N`-byte chunks whose leading chunks can be pruned.
///
/// Bit indices passed to the accessors are absolute: they keep counting the
/// pruned bits, and are translated into an index within the retained
/// [`BitMap`] by subtracting the pruned bits.
#[derive(Clone, Debug)]
pub struct Prunable<const N: usize> {
    /// Storage for the bits that have not been pruned.
    bitmap: BitMap<N>,

    /// Number of chunks that have been pruned.
    ///
    /// Invariant: `pruned_chunks * CHUNK_SIZE_BITS` fits in a `u64` (the
    /// length accessors panic if this is violated).
    pruned_chunks: usize,
}
34
35impl<const N: usize> Prunable<N> {
36 pub const CHUNK_SIZE_BITS: u64 = BitMap::<N>::CHUNK_SIZE_BITS;
38
39 pub fn new() -> Self {
43 Self {
44 bitmap: BitMap::new(),
45 pruned_chunks: 0,
46 }
47 }
48
49 pub fn new_with_pruned_chunks(pruned_chunks: usize) -> Result<Self, Error> {
56 let pruned_chunks_u64 = pruned_chunks as u64;
58 pruned_chunks_u64
59 .checked_mul(Self::CHUNK_SIZE_BITS)
60 .ok_or(Error::PrunedChunksOverflow)?;
61
62 Ok(Self {
63 bitmap: BitMap::new(),
64 pruned_chunks,
65 })
66 }
67
68 #[inline]
72 pub fn len(&self) -> u64 {
73 let pruned_bits = (self.pruned_chunks as u64)
74 .checked_mul(Self::CHUNK_SIZE_BITS)
75 .expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64");
76
77 pruned_bits
78 .checked_add(self.bitmap.len())
79 .expect("invariant violated: pruned_bits + bitmap.len() overflows u64")
80 }
81
82 #[inline]
84 pub fn is_empty(&self) -> bool {
85 self.len() == 0
86 }
87
88 #[inline]
90 pub fn is_chunk_aligned(&self) -> bool {
91 self.len().is_multiple_of(Self::CHUNK_SIZE_BITS)
92 }
93
94 #[inline]
96 pub fn chunks_len(&self) -> usize {
97 self.bitmap.chunks_len()
98 }
99
100 #[inline]
102 pub fn pruned_chunks(&self) -> usize {
103 self.pruned_chunks
104 }
105
106 #[inline]
108 pub fn pruned_bits(&self) -> u64 {
109 (self.pruned_chunks as u64)
110 .checked_mul(Self::CHUNK_SIZE_BITS)
111 .expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64")
112 }
113
114 #[inline]
122 pub fn get_bit(&self, bit: u64) -> bool {
123 let chunk_num = Self::unpruned_chunk(bit);
124 assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
125
126 self.bitmap.get(bit - self.pruned_bits())
128 }
129
130 #[inline]
136 pub fn get_chunk_containing(&self, bit: u64) -> &[u8; N] {
137 let chunk_num = Self::unpruned_chunk(bit);
138 assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
139
140 self.bitmap.get_chunk_containing(bit - self.pruned_bits())
142 }
143
144 #[inline]
147 pub fn get_bit_from_chunk(chunk: &[u8; N], bit: u64) -> bool {
148 BitMap::<N>::get_from_chunk(chunk, bit)
149 }
150
151 #[inline]
153 pub fn last_chunk(&self) -> (&[u8; N], u64) {
154 self.bitmap.last_chunk()
155 }
156
157 pub fn set_bit(&mut self, bit: u64, value: bool) {
165 let chunk_num = Self::unpruned_chunk(bit);
166 assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
167
168 self.bitmap.set(bit - self.pruned_bits(), value);
170 }
171
172 pub fn push(&mut self, bit: bool) {
174 self.bitmap.push(bit);
175 }
176
177 pub fn pop(&mut self) -> bool {
183 self.bitmap.pop()
184 }
185
186 pub fn push_byte(&mut self, byte: u8) {
192 self.bitmap.push_byte(byte);
193 }
194
195 pub fn push_chunk(&mut self, chunk: &[u8; N]) {
201 self.bitmap.push_chunk(chunk);
202 }
203
204 pub fn pop_chunk(&mut self) -> [u8; N] {
210 self.bitmap.pop_chunk()
211 }
212
213 pub fn prune_to_bit(&mut self, bit: u64) {
227 assert!(
228 bit <= self.len(),
229 "bit {} out of bounds (len: {})",
230 bit,
231 self.len()
232 );
233
234 let chunk = Self::unpruned_chunk(bit);
235 if chunk < self.pruned_chunks {
236 return;
237 }
238
239 let chunks_to_prune = chunk - self.pruned_chunks;
240 self.bitmap.prune_chunks(chunks_to_prune);
241 self.pruned_chunks = chunk;
242 }
243
244 #[inline]
248 pub fn chunk_byte_bitmask(bit: u64) -> u8 {
249 BitMap::<N>::chunk_byte_bitmask(bit)
250 }
251
252 #[inline]
254 pub fn chunk_byte_offset(bit: u64) -> usize {
255 BitMap::<N>::chunk_byte_offset(bit)
256 }
257
258 #[inline]
266 pub fn pruned_chunk(&self, bit: u64) -> usize {
267 assert!(bit < self.len(), "out of bounds: {bit}");
268 let chunk = Self::unpruned_chunk(bit);
269 assert!(chunk >= self.pruned_chunks, "bit pruned: {bit}");
270
271 chunk - self.pruned_chunks
272 }
273
274 #[inline]
281 pub fn unpruned_chunk(bit: u64) -> usize {
282 BitMap::<N>::chunk(bit)
283 }
284
285 #[inline]
288 pub fn get_chunk(&self, chunk: usize) -> &[u8; N] {
289 self.bitmap.get_chunk(chunk)
290 }
291
292 pub(super) fn set_chunk_by_index(&mut self, chunk_index: usize, chunk_data: &[u8; N]) {
298 assert!(
299 chunk_index >= self.pruned_chunks,
300 "cannot set pruned chunk {chunk_index} (pruned_chunks: {})",
301 self.pruned_chunks
302 );
303 let bitmap_chunk_idx = chunk_index - self.pruned_chunks;
304 self.bitmap.set_chunk_by_index(bitmap_chunk_idx, chunk_data);
305 }
306
307 pub(super) fn unprune_chunks(&mut self, chunks: &[[u8; N]]) {
318 assert!(
319 chunks.len() <= self.pruned_chunks,
320 "cannot unprune {} chunks (only {} pruned)",
321 chunks.len(),
322 self.pruned_chunks
323 );
324
325 for chunk in chunks.iter() {
326 self.bitmap.prepend_chunk(chunk);
327 }
328
329 self.pruned_chunks -= chunks.len();
330 }
331}
332
333impl<const N: usize> Default for Prunable<N> {
334 fn default() -> Self {
335 Self::new()
336 }
337}
338
339impl<const N: usize> Write for Prunable<N> {
340 fn write(&self, buf: &mut impl BufMut) {
341 (self.pruned_chunks as u64).write(buf);
342 self.bitmap.write(buf);
343 }
344}
345
346impl<const N: usize> Read for Prunable<N> {
347 type Cfg = u64;
349
350 fn read_cfg(buf: &mut impl Buf, max_len: &Self::Cfg) -> Result<Self, CodecError> {
351 let pruned_chunks_u64 = u64::read(buf)?;
352
353 let pruned_bits =
355 pruned_chunks_u64
356 .checked_mul(Self::CHUNK_SIZE_BITS)
357 .ok_or(CodecError::Invalid(
358 "Prunable",
359 "pruned_chunks would overflow when computing pruned_bits",
360 ))?;
361
362 let pruned_chunks = usize::try_from(pruned_chunks_u64)
363 .map_err(|_| CodecError::Invalid("Prunable", "pruned_chunks doesn't fit in usize"))?;
364
365 let bitmap = BitMap::<N>::read_cfg(buf, max_len)?;
366
367 pruned_bits
369 .checked_add(bitmap.len())
370 .ok_or(CodecError::Invalid(
371 "Prunable",
372 "total bitmap length (pruned + unpruned) would overflow u64",
373 ))?;
374
375 Ok(Self {
376 bitmap,
377 pruned_chunks,
378 })
379 }
380}
381
382impl<const N: usize> EncodeSize for Prunable<N> {
383 fn encode_size(&self) -> usize {
384 (self.pruned_chunks as u64).encode_size() + self.bitmap.encode_size()
385 }
386}
387
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hex;
    use bytes::BytesMut;
    use commonware_codec::Encode;

    /// Builds a bitmap holding two 4-byte chunks and prunes the first one.
    fn pruned_first_of_two() -> Prunable<4> {
        let mut map = Prunable::<4>::new();
        map.push_chunk(&[1, 2, 3, 4]);
        map.push_chunk(&[5, 6, 7, 8]);
        map.prune_to_bit(32);
        map
    }

    /// Encodes `original` and decodes it again with an unbounded length limit.
    fn roundtrip<const N: usize>(original: &Prunable<N>) -> Prunable<N> {
        let encoded = original.encode();
        Prunable::<N>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap()
    }

    // A new bitmap is empty, stores no chunks, and has pruned nothing.
    #[test]
    fn test_new() {
        let map = Prunable::<32>::new();
        assert_eq!(map.len(), 0);
        assert_eq!(map.pruned_bits(), 0);
        assert_eq!(map.pruned_chunks(), 0);
        assert!(map.is_empty());
        assert_eq!(map.chunks_len(), 0);
    }

    // Pruned chunks count towards the length even though no chunk is stored.
    #[test]
    fn test_new_with_pruned_chunks() {
        let map = Prunable::<2>::new_with_pruned_chunks(1).unwrap();
        assert_eq!(map.len(), 16);
        assert_eq!(map.pruned_bits(), 16);
        assert_eq!(map.pruned_chunks(), 1);
        assert_eq!(map.chunks_len(), 0);
    }

    // One chunk more than fits in a u64 worth of bits must be rejected.
    #[test]
    fn test_new_with_pruned_chunks_overflow() {
        let too_many = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) as usize + 1;
        let outcome = Prunable::<4>::new_with_pruned_chunks(too_many);

        assert!(matches!(outcome, Err(Error::PrunedChunksOverflow)));
    }

    // Pushed bits are read back at the index they were pushed at.
    #[test]
    fn test_push_and_get_bits() {
        let pattern = [true, false, true];
        let mut map = Prunable::<4>::new();
        for bit in pattern {
            map.push(bit);
        }

        assert_eq!(map.len(), 3);
        assert!(!map.is_empty());
        for (index, expected) in (0u64..).zip(pattern) {
            assert_eq!(map.get_bit(index), expected);
        }
    }

    // A pushed byte contributes eight bits.
    #[test]
    fn test_push_byte() {
        let mut map = Prunable::<4>::new();

        map.push_byte(0xFF);
        assert_eq!(map.len(), 8);
        assert!((0..8u64).all(|index| map.get_bit(index)));

        map.push_byte(0x00);
        assert_eq!(map.len(), 16);
        assert!((8..16u64).all(|index| !map.get_bit(index)));
    }

    // A pushed chunk is returned unchanged.
    #[test]
    fn test_push_chunk() {
        let pushed = hex!("0xAABBCCDD");
        let mut map = Prunable::<4>::new();

        map.push_chunk(&pushed);
        assert_eq!(map.len(), 32);
        assert_eq!(map.get_chunk_containing(0), &pushed);
    }

    // Setting a bit flips it in both directions.
    #[test]
    fn test_set_bit() {
        let mut map = Prunable::<4>::new();
        for _ in 0..3 {
            map.push(false);
        }
        assert!(!map.get_bit(1));

        for value in [true, false] {
            map.set_bit(1, value);
            assert_eq!(map.get_bit(1), value);
        }
    }

    // Pruning advances the boundary but keeps the length and the later chunks.
    #[test]
    fn test_pruning_basic() {
        let first = hex!("0x01020304");
        let second = hex!("0x05060708");
        let third = hex!("0x090A0B0C");

        let mut map = Prunable::<4>::new();
        map.push_chunk(&first);
        map.push_chunk(&second);
        map.push_chunk(&third);

        assert_eq!(map.len(), 96);
        assert_eq!(map.pruned_chunks(), 0);

        // Drop the first chunk.
        map.prune_to_bit(32);
        assert_eq!(map.pruned_chunks(), 1);
        assert_eq!(map.pruned_bits(), 32);
        assert_eq!(map.len(), 96);
        assert_eq!(map.get_chunk_containing(32), &second);
        assert_eq!(map.get_chunk_containing(64), &third);

        // Drop the second chunk as well.
        map.prune_to_bit(64);
        assert_eq!(map.pruned_chunks(), 2);
        assert_eq!(map.pruned_bits(), 64);
        assert_eq!(map.len(), 96);
        assert_eq!(map.get_chunk_containing(64), &third);
    }

    #[test]
    #[should_panic(expected = "bit pruned")]
    fn test_get_pruned_bit_panics() {
        let map = pruned_first_of_two();

        // Bit 0 lives in the pruned chunk.
        map.get_bit(0);
    }

    #[test]
    #[should_panic(expected = "bit pruned")]
    fn test_get_pruned_chunk_panics() {
        let map = pruned_first_of_two();

        // Bit 0 lives in the pruned chunk.
        map.get_chunk_containing(0);
    }

    #[test]
    #[should_panic(expected = "bit pruned")]
    fn test_set_pruned_bit_panics() {
        let mut map = pruned_first_of_two();

        // Bit 0 lives in the pruned chunk.
        map.set_bit(0, true);
    }

    #[test]
    #[should_panic(expected = "bit 25 out of bounds (len: 24)")]
    fn test_prune_to_bit_out_of_bounds() {
        let mut map = Prunable::<1>::new();
        for byte in 1..=3u8 {
            map.push_byte(byte);
        }

        // Only 24 bits exist, so 25 is one past the allowed maximum.
        map.prune_to_bit(25);
    }

    // Bits in a trailing partial chunk survive pruning of earlier chunks.
    #[test]
    fn test_pruning_with_partial_chunk() {
        let tail = [true, false, true];

        let mut map = Prunable::<4>::new();
        map.push_chunk(&[0xFF; 4]);
        map.push_chunk(&[0xAA; 4]);
        for bit in tail {
            map.push(bit);
        }
        assert_eq!(map.len(), 67);

        map.prune_to_bit(32);
        assert_eq!(map.pruned_chunks(), 1);
        assert_eq!(map.len(), 67);

        for (index, expected) in (64u64..).zip(tail) {
            assert_eq!(map.get_bit(index), expected);
        }
    }

    // Pruning to the current boundary, or to a bit before it, changes nothing.
    #[test]
    fn test_prune_idempotent() {
        let mut map = pruned_first_of_two();
        assert_eq!(map.pruned_chunks(), 1);

        for bit in [32, 16] {
            map.prune_to_bit(bit);
            assert_eq!(map.pruned_chunks(), 1);
        }
    }

    // Chunks pushed after pruning are addressed by their absolute position.
    #[test]
    fn test_push_after_pruning() {
        let mut map = pruned_first_of_two();
        assert_eq!(map.len(), 64);
        assert_eq!(map.pruned_chunks(), 1);

        map.push_chunk(&[9, 10, 11, 12]);
        assert_eq!(map.len(), 96);
        assert_eq!(map.get_chunk_containing(64), &[9, 10, 11, 12]);
    }

    // Index arithmetic for 4-byte (32-bit) chunks.
    #[test]
    fn test_chunk_calculations() {
        for (bit, chunk) in [(0, 0), (31, 0), (32, 1), (63, 1), (64, 2)] {
            assert_eq!(Prunable::<4>::unpruned_chunk(bit), chunk);
        }

        for (bit, offset) in [(0, 0), (8, 1), (16, 2), (24, 3), (32, 0)] {
            assert_eq!(Prunable::<4>::chunk_byte_offset(bit), offset);
        }

        let masks = [
            (0, 0b0000_0001),
            (1, 0b0000_0010),
            (7, 0b1000_0000),
            (8, 0b0000_0001),
        ];
        for (bit, mask) in masks {
            assert_eq!(Prunable::<4>::chunk_byte_bitmask(bit), mask);
        }
    }

    // `pruned_chunk` reports indices relative to the pruning boundary.
    #[test]
    fn test_pruned_chunk() {
        let mut map = Prunable::<4>::new();
        for base in [0u8, 4, 8] {
            map.push_chunk(&[base, base + 1, base + 2, base + 3]);
        }

        assert_eq!(map.pruned_chunk(0), 0);
        assert_eq!(map.pruned_chunk(32), 1);
        assert_eq!(map.pruned_chunk(64), 2);

        map.prune_to_bit(32);
        assert_eq!(map.pruned_chunk(32), 0);
        assert_eq!(map.pruned_chunk(64), 1);
    }

    // Pruning earlier chunks leaves the last chunk untouched.
    #[test]
    fn test_last_chunk_with_pruning() {
        let mut map = Prunable::<4>::new();
        map.push_chunk(&[1, 2, 3, 4]);
        map.push_chunk(&[5, 6, 7, 8]);
        map.push(true);
        map.push(false);

        let (last, next_bit) = map.last_chunk();
        assert_eq!(next_bit, 2);
        // Copy the chunk out so the borrow ends before the bitmap is mutated.
        let before = *last;

        map.prune_to_bit(32);

        let (after, next_bit_after) = map.last_chunk();
        assert_eq!(next_bit_after, 2);
        assert_eq!(&before, after);
    }

    // The same bit sequence behaves identically for several chunk sizes.
    #[test]
    fn test_different_chunk_sizes() {
        let mut p8 = Prunable::<8>::new();
        let mut p16 = Prunable::<16>::new();
        let mut p32 = Prunable::<32>::new();

        for i in 0..10 {
            let bit = i % 2 == 0;
            p8.push(bit);
            p16.push(bit);
            p32.push(bit);
        }

        assert_eq!(p8.len(), 10);
        assert_eq!(p16.len(), 10);
        assert_eq!(p32.len(), 10);

        for i in 0..10u64 {
            let expected = i % 2 == 0;
            assert_eq!(p8.get_bit(i), expected);
            assert_eq!(p16.get_bit(i), expected);
            assert_eq!(p32.get_bit(i), expected);
        }
    }

    // Bits are numbered from the least significant bit of each byte.
    #[test]
    fn test_get_bit_from_chunk() {
        let chunk: [u8; 4] = [0b10101010, 0b11001100, 0b11110000, 0b00001111];

        let expectations = [
            // First byte: 0b10101010.
            (0, false),
            (1, true),
            (2, false),
            (3, true),
            // Second byte: 0b11001100.
            (8, false),
            (9, false),
            (10, true),
            (11, true),
        ];
        for (bit, expected) in expectations {
            assert_eq!(Prunable::<4>::get_bit_from_chunk(&chunk, bit), expected);
        }
    }

    // `get_chunk` indexes the retained chunks, so indices shift after pruning.
    #[test]
    fn test_get_chunk() {
        let first = hex!("0x11223344");
        let second = hex!("0x55667788");
        let third = hex!("0x99AABBCC");

        let mut map = Prunable::<4>::new();
        map.push_chunk(&first);
        map.push_chunk(&second);
        map.push_chunk(&third);

        assert_eq!(map.get_chunk(0), &first);
        assert_eq!(map.get_chunk(1), &second);
        assert_eq!(map.get_chunk(2), &third);

        map.prune_to_bit(32);
        assert_eq!(map.get_chunk(0), &second);
        assert_eq!(map.get_chunk(1), &third);
    }

    // Bits pop in reverse push order and shrink the length one at a time.
    #[test]
    fn test_pop() {
        let mut map = Prunable::<4>::new();

        map.push(true);
        map.push(false);
        map.push(true);
        assert_eq!(map.len(), 3);

        for (expected_bit, remaining) in [(true, 2), (false, 1), (true, 0)] {
            assert_eq!(map.pop(), expected_bit);
            assert_eq!(map.len(), remaining);
        }
        assert!(map.is_empty());

        // Repeat with a sequence that spans several chunks.
        for i in 0..100 {
            map.push(i % 3 == 0);
        }
        assert_eq!(map.len(), 100);

        for i in (0..100).rev() {
            assert_eq!(map.pop(), i % 3 == 0);
            assert_eq!(map.len(), i);
        }

        assert!(map.is_empty());
    }

    // Chunks pop in reverse push order and leave earlier chunks intact.
    #[test]
    fn test_pop_chunk() {
        const CHUNK_SIZE: u64 = Prunable::<4>::CHUNK_SIZE_BITS;
        let mut map = Prunable::<4>::new();

        // Single chunk.
        let only = hex!("0xAABBCCDD");
        map.push_chunk(&only);
        assert_eq!(map.len(), CHUNK_SIZE);
        assert_eq!(map.pop_chunk(), only);
        assert_eq!(map.len(), 0);
        assert!(map.is_empty());

        // Several chunks.
        let a = hex!("0x11223344");
        let b = hex!("0x55667788");
        let c = hex!("0x99AABBCC");

        map.push_chunk(&a);
        map.push_chunk(&b);
        map.push_chunk(&c);
        assert_eq!(map.len(), CHUNK_SIZE * 3);

        for (expected_chunk, chunks_left) in [(c, 2), (b, 1), (a, 0)] {
            assert_eq!(map.pop_chunk(), expected_chunk);
            assert_eq!(map.len(), CHUNK_SIZE * chunks_left);
        }

        // The chunk that remains after a pop is still readable bit by bit.
        map = Prunable::new();
        let first = hex!("0xAABBCCDD");
        let second = hex!("0x11223344");
        map.push_chunk(&first);
        map.push_chunk(&second);

        assert_eq!(map.pop_chunk(), second);
        assert_eq!(map.len(), CHUNK_SIZE);

        for i in 0..CHUNK_SIZE {
            let byte = first[(i / 8) as usize];
            let expected = byte & (1u8 << (i % 8)) != 0;
            assert_eq!(map.get_bit(i), expected);
        }

        assert_eq!(map.pop_chunk(), first);
        assert_eq!(map.len(), 0);
    }

    #[test]
    #[should_panic(expected = "cannot pop chunk when not chunk aligned")]
    fn test_pop_chunk_not_aligned() {
        let mut map = Prunable::<4>::new();

        // One full chunk plus a single extra bit.
        map.push_chunk(&[0xFF; 4]);
        map.push(true);

        map.pop_chunk();
    }

    #[test]
    #[should_panic(expected = "cannot pop chunk: bitmap has fewer than CHUNK_SIZE_BITS bits")]
    fn test_pop_chunk_insufficient_bits() {
        let mut map = Prunable::<4>::new();

        // Only two bits, far less than one chunk.
        map.push(true);
        map.push(false);

        map.pop_chunk();
    }

    #[test]
    fn test_write_read_empty() {
        let original = Prunable::<4>::new();
        let decoded = roundtrip(&original);

        assert_eq!(decoded.len(), original.len());
        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_write_read_non_empty() {
        let mut original = Prunable::<4>::new();
        original.push_chunk(&hex!("0xAABBCCDD"));
        original.push_chunk(&hex!("0x11223344"));
        for bit in [true, false, true] {
            original.push(bit);
        }

        let decoded = roundtrip(&original);

        assert_eq!(decoded.len(), original.len());
        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
        assert_eq!(decoded.len(), 67);

        for bit in 0..original.len() {
            assert_eq!(decoded.get_bit(bit), original.get_bit(bit));
        }
    }

    #[test]
    fn test_write_read_with_pruning() {
        let second = hex!("0x05060708");
        let third = hex!("0x090A0B0C");

        let mut original = Prunable::<4>::new();
        original.push_chunk(&hex!("0x01020304"));
        original.push_chunk(&second);
        original.push_chunk(&third);

        original.prune_to_bit(32);
        assert_eq!(original.pruned_chunks(), 1);
        assert_eq!(original.len(), 96);

        let decoded = roundtrip(&original);

        assert_eq!(decoded.len(), original.len());
        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
        assert_eq!(decoded.pruned_chunks(), 1);
        assert_eq!(decoded.len(), 96);

        assert_eq!(decoded.get_chunk_containing(32), &second);
        assert_eq!(decoded.get_chunk_containing(64), &third);
    }

    #[test]
    fn test_write_read_with_pruning_2() {
        let mut original = Prunable::<4>::new();
        for base in [0u8, 4, 8, 12, 16] {
            original.push_chunk(&[base, base + 1, base + 2, base + 3]);
        }

        // Prune the first three of the five chunks.
        original.prune_to_bit(96);
        assert_eq!(original.pruned_chunks(), 3);
        assert_eq!(original.len(), 160);

        let decoded = roundtrip(&original);

        assert_eq!(decoded.len(), original.len());
        assert_eq!(decoded.pruned_chunks(), 3);

        // Only the retained bits can be compared.
        for bit in 96..original.len() {
            assert_eq!(decoded.get_bit(bit), original.get_bit(bit));
        }
    }

    #[test]
    fn test_encode_size_matches() {
        let mut map = Prunable::<4>::new();
        map.push_chunk(&[1, 2, 3, 4]);
        map.push_chunk(&[5, 6, 7, 8]);
        map.push(true);

        let predicted = map.encode_size();
        let actual = map.encode().len();

        assert_eq!(predicted, actual);
    }

    #[test]
    fn test_encode_size_with_pruning() {
        let mut map = Prunable::<4>::new();
        map.push_chunk(&[1, 2, 3, 4]);
        map.push_chunk(&[5, 6, 7, 8]);
        map.push_chunk(&[9, 10, 11, 12]);

        map.prune_to_bit(32);

        let predicted = map.encode_size();
        let actual = map.encode().len();

        assert_eq!(predicted, actual);
    }

    // The length limit passed as config is enforced while decoding.
    #[test]
    fn test_read_max_len_validation() {
        let mut original = Prunable::<4>::new();
        for _ in 0..10 {
            original.push(true);
        }
        let encoded = original.encode();

        // A limit above the stored length is accepted.
        let generous = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &100);
        assert!(generous.is_ok());

        // A limit below the stored length is rejected.
        let strict = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &5);
        assert!(strict.is_err());
    }

    #[test]
    fn test_codec_roundtrip_different_chunk_sizes() {
        let mut p8 = Prunable::<8>::new();
        let mut p16 = Prunable::<16>::new();
        let mut p32 = Prunable::<32>::new();

        for i in 0..100 {
            let bit = i % 3 == 0;
            p8.push(bit);
            p16.push(bit);
            p32.push(bit);
        }

        assert_eq!(roundtrip(&p8).len(), p8.len());
        assert_eq!(roundtrip(&p16).len(), p16.len());
        assert_eq!(roundtrip(&p32).len(), p32.len());
    }

    // A pruned chunk count whose bit count overflows u64 is rejected up front.
    #[test]
    fn test_read_pruned_chunks_overflow() {
        let mut buf = BytesMut::new();

        let too_many = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) + 1;
        too_many.write(&mut buf);

        // Zero written where the bitmap would follow; decoding should fail before using it.
        0u64.write(&mut buf);

        let (type_name, msg) = match Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX) {
            Err(CodecError::Invalid(type_name, msg)) => (type_name, msg),
            Ok(_) => panic!("Expected error but got Ok"),
            Err(e) => panic!("Expected Invalid error for pruned_bits overflow, got: {e:?}"),
        };
        assert_eq!(type_name, "Prunable");
        assert_eq!(
            msg,
            "pruned_chunks would overflow when computing pruned_bits"
        );
    }

    // Pruned bits plus retained bits must also fit in a u64.
    #[test]
    fn test_read_total_length_overflow() {
        let mut buf = BytesMut::new();

        // Largest pruned chunk count whose bit count still fits in a u64.
        let max_safe_pruned_chunks = u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS;
        let pruned_bits = max_safe_pruned_chunks * Prunable::<4>::CHUNK_SIZE_BITS;

        // One bit more than the space left below u64::MAX.
        let bitmap_len = (u64::MAX - pruned_bits) + 1;

        max_safe_pruned_chunks.write(&mut buf);
        bitmap_len.write(&mut buf);

        // Zero-filled chunk data covering `bitmap_len` bits (4 bytes per chunk).
        let chunk_bytes = bitmap_len.div_ceil(Prunable::<4>::CHUNK_SIZE_BITS) * 4;
        for _ in 0..chunk_bytes {
            0u8.write(&mut buf);
        }

        let (type_name, msg) = match Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX) {
            Err(CodecError::Invalid(type_name, msg)) => (type_name, msg),
            Ok(_) => panic!("Expected error but got Ok"),
            Err(e) => panic!("Expected Invalid error for total length overflow, got: {e:?}"),
        };
        assert_eq!(type_name, "Prunable");
        assert_eq!(
            msg,
            "total bitmap length (pruned + unpruned) would overflow u64"
        );
    }

    #[test]
    fn test_is_chunk_aligned() {
        // An empty bitmap is aligned.
        let map = Prunable::<4>::new();
        assert!(map.is_chunk_aligned());

        // Pushing bit by bit: aligned exactly at multiples of 32 bits.
        let mut map = Prunable::<4>::new();
        for i in 1..=64 {
            map.push(i % 2 == 0);
            assert_eq!(map.is_chunk_aligned(), i % 32 == 0);
        }

        // Whole chunks keep alignment; one extra bit breaks it.
        let mut map = Prunable::<4>::new();
        assert!(map.is_chunk_aligned());
        map.push_chunk(&[1, 2, 3, 4]);
        assert!(map.is_chunk_aligned());
        map.push_chunk(&[5, 6, 7, 8]);
        assert!(map.is_chunk_aligned());
        map.push(true);
        assert!(!map.is_chunk_aligned());

        // Pruning does not affect alignment.
        let mut map = Prunable::<4>::new();
        map.push_chunk(&[1, 2, 3, 4]);
        map.push_chunk(&[5, 6, 7, 8]);
        map.push_chunk(&[9, 10, 11, 12]);
        assert!(map.is_chunk_aligned());

        map.prune_to_bit(32);
        assert!(map.is_chunk_aligned());
        assert_eq!(map.len(), 96);

        map.push(true);
        map.push(false);
        assert!(!map.is_chunk_aligned());

        map.prune_to_bit(64);
        assert!(!map.is_chunk_aligned());

        // Bitmaps that start out with pruned chunks.
        let map = Prunable::<4>::new_with_pruned_chunks(2).unwrap();
        assert!(map.is_chunk_aligned());

        let mut map = Prunable::<4>::new_with_pruned_chunks(1).unwrap();
        assert!(map.is_chunk_aligned());
        map.push(true);
        assert!(!map.is_chunk_aligned());

        // Popping: 32 bits is aligned, 31 is not, 0 is aligned again.
        let mut map = Prunable::<4>::new();
        for _ in 0..4 {
            map.push_byte(0xFF);
        }
        assert!(map.is_chunk_aligned());

        map.pop();
        assert!(!map.is_chunk_aligned());

        for _ in 0..31 {
            map.pop();
        }
        assert!(map.is_chunk_aligned());
    }
}