1use super::BitMap;
4use bytes::{Buf, BufMut};
5use commonware_codec::{EncodeSize, Error as CodecError, Read, ReadExt, Write};
6use thiserror::Error;
7
/// Errors produced by [`Prunable`] operations.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum Error {
    /// The requested number of pruned chunks, converted to bits
    /// (`pruned_chunks * CHUNK_SIZE_BITS`), does not fit in a `u64`.
    #[error("pruned_chunks * CHUNK_SIZE_BITS overflows u64")]
    PrunedChunksOverflow,
}
15
/// A bitmap whose leading chunks can be pruned (discarded) while global bit
/// indices stay stable.
///
/// Bits are always addressed by their index in the full, unpruned sequence.
/// Only the chunks after the pruning boundary are actually stored.
#[derive(Clone, Debug)]
pub struct Prunable<const N: usize> {
    /// The retained (unpruned) bits. Its first chunk corresponds to global
    /// chunk index `pruned_chunks`.
    bitmap: BitMap<N>,

    /// Number of leading chunks that have been pruned and are no longer stored.
    pruned_chunks: usize,
}
34
35impl<const N: usize> Prunable<N> {
36 pub const CHUNK_SIZE_BITS: u64 = BitMap::<N>::CHUNK_SIZE_BITS;
38
39 pub const fn new() -> Self {
43 Self {
44 bitmap: BitMap::new(),
45 pruned_chunks: 0,
46 }
47 }
48
49 pub fn new_with_pruned_chunks(pruned_chunks: usize) -> Result<Self, Error> {
56 let pruned_chunks_u64 = pruned_chunks as u64;
58 pruned_chunks_u64
59 .checked_mul(Self::CHUNK_SIZE_BITS)
60 .ok_or(Error::PrunedChunksOverflow)?;
61
62 Ok(Self {
63 bitmap: BitMap::new(),
64 pruned_chunks,
65 })
66 }
67
68 #[inline]
72 pub const fn len(&self) -> u64 {
73 let pruned_bits = (self.pruned_chunks as u64)
74 .checked_mul(Self::CHUNK_SIZE_BITS)
75 .expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64");
76
77 pruned_bits
78 .checked_add(self.bitmap.len())
79 .expect("invariant violated: pruned_bits + bitmap.len() overflows u64")
80 }
81
82 #[inline]
84 pub const fn is_empty(&self) -> bool {
85 self.len() == 0
86 }
87
88 #[inline]
90 pub const fn is_chunk_aligned(&self) -> bool {
91 self.len().is_multiple_of(Self::CHUNK_SIZE_BITS)
92 }
93
94 #[inline]
96 pub fn chunks_len(&self) -> usize {
97 self.bitmap.chunks_len()
98 }
99
100 #[inline]
102 pub const fn pruned_chunks(&self) -> usize {
103 self.pruned_chunks
104 }
105
106 #[inline]
108 pub const fn pruned_bits(&self) -> u64 {
109 (self.pruned_chunks as u64)
110 .checked_mul(Self::CHUNK_SIZE_BITS)
111 .expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64")
112 }
113
114 #[inline]
122 pub fn get_bit(&self, bit: u64) -> bool {
123 let chunk_num = Self::unpruned_chunk(bit);
124 assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
125
126 self.bitmap.get(bit - self.pruned_bits())
128 }
129
130 #[inline]
136 pub fn get_chunk_containing(&self, bit: u64) -> &[u8; N] {
137 let chunk_num = Self::unpruned_chunk(bit);
138 assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
139
140 self.bitmap.get_chunk_containing(bit - self.pruned_bits())
142 }
143
144 #[inline]
147 pub const fn get_bit_from_chunk(chunk: &[u8; N], bit: u64) -> bool {
148 BitMap::<N>::get_from_chunk(chunk, bit)
149 }
150
151 #[inline]
153 pub fn last_chunk(&self) -> (&[u8; N], u64) {
154 self.bitmap.last_chunk()
155 }
156
157 pub fn set_bit(&mut self, bit: u64, value: bool) {
165 let chunk_num = Self::unpruned_chunk(bit);
166 assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
167
168 self.bitmap.set(bit - self.pruned_bits(), value);
170 }
171
172 pub fn push(&mut self, bit: bool) {
174 self.bitmap.push(bit);
175 }
176
177 pub fn pop(&mut self) -> bool {
183 self.bitmap.pop()
184 }
185
186 pub fn push_byte(&mut self, byte: u8) {
192 self.bitmap.push_byte(byte);
193 }
194
195 pub fn push_chunk(&mut self, chunk: &[u8; N]) {
201 self.bitmap.push_chunk(chunk);
202 }
203
204 pub fn pop_chunk(&mut self) -> [u8; N] {
210 self.bitmap.pop_chunk()
211 }
212
213 pub fn prune_to_bit(&mut self, bit: u64) {
227 assert!(
228 bit <= self.len(),
229 "bit {} out of bounds (len: {})",
230 bit,
231 self.len()
232 );
233
234 let chunk = Self::unpruned_chunk(bit);
235 if chunk < self.pruned_chunks {
236 return;
237 }
238
239 let chunks_to_prune = chunk - self.pruned_chunks;
240 self.bitmap.prune_chunks(chunks_to_prune);
241 self.pruned_chunks = chunk;
242 }
243
244 #[inline]
248 pub const fn chunk_byte_bitmask(bit: u64) -> u8 {
249 BitMap::<N>::chunk_byte_bitmask(bit)
250 }
251
252 #[inline]
254 pub const fn chunk_byte_offset(bit: u64) -> usize {
255 BitMap::<N>::chunk_byte_offset(bit)
256 }
257
258 #[inline]
266 pub fn pruned_chunk(&self, bit: u64) -> usize {
267 assert!(bit < self.len(), "out of bounds: {bit}");
268 let chunk = Self::unpruned_chunk(bit);
269 assert!(chunk >= self.pruned_chunks, "bit pruned: {bit}");
270
271 chunk - self.pruned_chunks
272 }
273
274 #[inline]
281 pub fn unpruned_chunk(bit: u64) -> usize {
282 BitMap::<N>::chunk(bit)
283 }
284
285 #[inline]
288 pub fn get_chunk(&self, chunk: usize) -> &[u8; N] {
289 self.bitmap.get_chunk(chunk)
290 }
291
292 pub(super) fn set_chunk_by_index(&mut self, chunk_index: usize, chunk_data: &[u8; N]) {
298 assert!(
299 chunk_index >= self.pruned_chunks,
300 "cannot set pruned chunk {chunk_index} (pruned_chunks: {})",
301 self.pruned_chunks
302 );
303 let bitmap_chunk_idx = chunk_index - self.pruned_chunks;
304 self.bitmap.set_chunk_by_index(bitmap_chunk_idx, chunk_data);
305 }
306
307 pub(super) fn unprune_chunks(&mut self, chunks: &[[u8; N]]) {
318 assert!(
319 chunks.len() <= self.pruned_chunks,
320 "cannot unprune {} chunks (only {} pruned)",
321 chunks.len(),
322 self.pruned_chunks
323 );
324
325 for chunk in chunks.iter() {
326 self.bitmap.prepend_chunk(chunk);
327 }
328
329 self.pruned_chunks -= chunks.len();
330 }
331}
332
333impl<const N: usize> Default for Prunable<N> {
334 fn default() -> Self {
335 Self::new()
336 }
337}
338
339impl<const N: usize> Write for Prunable<N> {
340 fn write(&self, buf: &mut impl BufMut) {
341 (self.pruned_chunks as u64).write(buf);
342 self.bitmap.write(buf);
343 }
344}
345
346impl<const N: usize> Read for Prunable<N> {
347 type Cfg = u64;
349
350 fn read_cfg(buf: &mut impl Buf, max_len: &Self::Cfg) -> Result<Self, CodecError> {
351 let pruned_chunks_u64 = u64::read(buf)?;
352
353 let pruned_bits =
355 pruned_chunks_u64
356 .checked_mul(Self::CHUNK_SIZE_BITS)
357 .ok_or(CodecError::Invalid(
358 "Prunable",
359 "pruned_chunks would overflow when computing pruned_bits",
360 ))?;
361
362 let pruned_chunks = usize::try_from(pruned_chunks_u64)
363 .map_err(|_| CodecError::Invalid("Prunable", "pruned_chunks doesn't fit in usize"))?;
364
365 let bitmap = BitMap::<N>::read_cfg(buf, max_len)?;
366
367 pruned_bits
369 .checked_add(bitmap.len())
370 .ok_or(CodecError::Invalid(
371 "Prunable",
372 "total bitmap length (pruned + unpruned) would overflow u64",
373 ))?;
374
375 Ok(Self {
376 bitmap,
377 pruned_chunks,
378 })
379 }
380}
381
382impl<const N: usize> EncodeSize for Prunable<N> {
383 fn encode_size(&self) -> usize {
384 (self.pruned_chunks as u64).encode_size() + self.bitmap.encode_size()
385 }
386}
387
#[cfg(feature = "arbitrary")]
impl<const N: usize> arbitrary::Arbitrary<'_> for Prunable<N> {
    /// Builds an arbitrary bitmap, then prunes it to an arbitrary bit within
    /// `0..=len`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let mut prunable = Self {
            bitmap: BitMap::<N>::arbitrary(u)?,
            pruned_chunks: 0,
        };
        let boundary = u.int_in_range(0..=prunable.len())?;
        prunable.prune_to_bit(boundary);
        Ok(prunable)
    }
}
400
#[cfg(test)]
mod tests {
    // Unit tests for `Prunable`. They cover construction, bit/chunk access,
    // pruning, and the codec (`Write`/`Read`/`EncodeSize`) round-trip and
    // overflow checks.
    use super::*;
    use crate::hex;
    use bytes::BytesMut;
    use commonware_codec::Encode;

    #[test]
    fn test_new() {
        let prunable: Prunable<32> = Prunable::new();
        assert_eq!(prunable.len(), 0);
        assert_eq!(prunable.pruned_bits(), 0);
        assert_eq!(prunable.pruned_chunks(), 0);
        assert!(prunable.is_empty());
        assert_eq!(prunable.chunks_len(), 0);
    }

    #[test]
    fn test_new_with_pruned_chunks() {
        // One pruned 2-byte chunk contributes 16 bits to `len` but no retained chunks.
        let prunable: Prunable<2> = Prunable::new_with_pruned_chunks(1).unwrap();
        assert_eq!(prunable.len(), 16);
        assert_eq!(prunable.pruned_bits(), 16);
        assert_eq!(prunable.pruned_chunks(), 1);
        assert_eq!(prunable.chunks_len(), 0);
    }

    #[test]
    fn test_new_with_pruned_chunks_overflow() {
        // Smallest chunk count whose bit offset no longer fits in a u64.
        // NOTE(review): the `as usize` cast assumes a 64-bit target.
        let overflowing_pruned_chunks = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) as usize + 1;
        let result = Prunable::<4>::new_with_pruned_chunks(overflowing_pruned_chunks);

        assert!(matches!(result, Err(Error::PrunedChunksOverflow)));
    }

    #[test]
    fn test_push_and_get_bits() {
        let mut prunable: Prunable<4> = Prunable::new();

        prunable.push(true);
        prunable.push(false);
        prunable.push(true);

        assert_eq!(prunable.len(), 3);
        assert!(!prunable.is_empty());
        assert!(prunable.get_bit(0));
        assert!(!prunable.get_bit(1));
        assert!(prunable.get_bit(2));
    }

    #[test]
    fn test_push_byte() {
        let mut prunable: Prunable<4> = Prunable::new();

        // 0xFF appends eight set bits.
        prunable.push_byte(0xFF);
        assert_eq!(prunable.len(), 8);

        for i in 0..8 {
            assert!(prunable.get_bit(i as u64));
        }

        // 0x00 appends eight cleared bits.
        prunable.push_byte(0x00);
        assert_eq!(prunable.len(), 16);

        for i in 8..16 {
            assert!(!prunable.get_bit(i as u64));
        }
    }

    #[test]
    fn test_push_chunk() {
        let mut prunable: Prunable<4> = Prunable::new();
        let chunk = hex!("0xAABBCCDD");

        prunable.push_chunk(&chunk);
        assert_eq!(prunable.len(), 32);
        let retrieved_chunk = prunable.get_chunk_containing(0);
        assert_eq!(retrieved_chunk, &chunk);
    }

    #[test]
    fn test_set_bit() {
        let mut prunable: Prunable<4> = Prunable::new();

        prunable.push(false);
        prunable.push(false);
        prunable.push(false);

        assert!(!prunable.get_bit(1));

        prunable.set_bit(1, true);
        assert!(prunable.get_bit(1));

        prunable.set_bit(1, false);
        assert!(!prunable.get_bit(1));
    }

    #[test]
    fn test_pruning_basic() {
        let mut prunable: Prunable<4> = Prunable::new();

        let chunk1 = hex!("0x01020304");
        let chunk2 = hex!("0x05060708");
        let chunk3 = hex!("0x090A0B0C");

        prunable.push_chunk(&chunk1);
        prunable.push_chunk(&chunk2);
        prunable.push_chunk(&chunk3);

        assert_eq!(prunable.len(), 96);
        assert_eq!(prunable.pruned_chunks(), 0);

        // Pruning to a chunk boundary drops the chunks before it; `len` is unchanged.
        prunable.prune_to_bit(32);
        assert_eq!(prunable.pruned_chunks(), 1);
        assert_eq!(prunable.pruned_bits(), 32);
        assert_eq!(prunable.len(), 96);
        // Global bit indices still address the retained chunks.
        assert_eq!(prunable.get_chunk_containing(32), &chunk2);
        assert_eq!(prunable.get_chunk_containing(64), &chunk3);

        prunable.prune_to_bit(64);
        assert_eq!(prunable.pruned_chunks(), 2);
        assert_eq!(prunable.pruned_bits(), 64);
        assert_eq!(prunable.len(), 96);

        assert_eq!(prunable.get_chunk_containing(64), &chunk3);
    }

    #[test]
    #[should_panic(expected = "bit pruned")]
    fn test_get_pruned_bit_panics() {
        let mut prunable: Prunable<4> = Prunable::new();

        prunable.push_chunk(&[1, 2, 3, 4]);
        prunable.push_chunk(&[5, 6, 7, 8]);

        prunable.prune_to_bit(32);

        // Bit 0 lives in the pruned first chunk.
        prunable.get_bit(0);
    }

    #[test]
    #[should_panic(expected = "bit pruned")]
    fn test_get_pruned_chunk_panics() {
        let mut prunable: Prunable<4> = Prunable::new();

        prunable.push_chunk(&[1, 2, 3, 4]);
        prunable.push_chunk(&[5, 6, 7, 8]);

        prunable.prune_to_bit(32);

        // Bit 0 lives in the pruned first chunk.
        prunable.get_chunk_containing(0);
    }

    #[test]
    #[should_panic(expected = "bit pruned")]
    fn test_set_pruned_bit_panics() {
        let mut prunable: Prunable<4> = Prunable::new();

        prunable.push_chunk(&[1, 2, 3, 4]);
        prunable.push_chunk(&[5, 6, 7, 8]);

        prunable.prune_to_bit(32);

        // Bit 0 lives in the pruned first chunk.
        prunable.set_bit(0, true);
    }

    #[test]
    #[should_panic(expected = "bit 25 out of bounds (len: 24)")]
    fn test_prune_to_bit_out_of_bounds() {
        let mut prunable: Prunable<1> = Prunable::new();

        prunable.push_byte(1);
        prunable.push_byte(2);
        prunable.push_byte(3);

        // `len` is 24, so 25 is one past the largest accepted prune point.
        prunable.prune_to_bit(25);
    }

    #[test]
    fn test_pruning_with_partial_chunk() {
        let mut prunable: Prunable<4> = Prunable::new();

        prunable.push_chunk(&[0xFF; 4]);
        prunable.push_chunk(&[0xAA; 4]);
        prunable.push(true);
        prunable.push(false);
        prunable.push(true);

        assert_eq!(prunable.len(), 67);
        prunable.prune_to_bit(32);
        assert_eq!(prunable.pruned_chunks(), 1);
        assert_eq!(prunable.len(), 67);

        // Bits in the trailing partial chunk are unaffected by pruning.
        assert!(prunable.get_bit(64));
        assert!(!prunable.get_bit(65));
        assert!(prunable.get_bit(66));
    }

    #[test]
    fn test_prune_idempotent() {
        let mut prunable: Prunable<4> = Prunable::new();

        prunable.push_chunk(&[1, 2, 3, 4]);
        prunable.push_chunk(&[5, 6, 7, 8]);

        prunable.prune_to_bit(32);
        assert_eq!(prunable.pruned_chunks(), 1);

        // Repeating the same prune is a no-op.
        prunable.prune_to_bit(32);
        assert_eq!(prunable.pruned_chunks(), 1);

        // Pruning to an earlier bit never un-prunes.
        prunable.prune_to_bit(16);
        assert_eq!(prunable.pruned_chunks(), 1);
    }

    #[test]
    fn test_push_after_pruning() {
        let mut prunable: Prunable<4> = Prunable::new();

        prunable.push_chunk(&[1, 2, 3, 4]);
        prunable.push_chunk(&[5, 6, 7, 8]);

        prunable.prune_to_bit(32);
        assert_eq!(prunable.len(), 64);
        assert_eq!(prunable.pruned_chunks(), 1);

        // New data is addressed by its global index, past the pruned region.
        prunable.push_chunk(&[9, 10, 11, 12]);
        assert_eq!(prunable.len(), 96);
        assert_eq!(prunable.get_chunk_containing(64), &[9, 10, 11, 12]);
    }

    #[test]
    fn test_chunk_calculations() {
        // With 4-byte chunks, each chunk spans 32 bits.
        assert_eq!(Prunable::<4>::unpruned_chunk(0), 0);
        assert_eq!(Prunable::<4>::unpruned_chunk(31), 0);
        assert_eq!(Prunable::<4>::unpruned_chunk(32), 1);
        assert_eq!(Prunable::<4>::unpruned_chunk(63), 1);
        assert_eq!(Prunable::<4>::unpruned_chunk(64), 2);

        // Byte offsets are relative to the containing chunk and wrap at chunk boundaries.
        assert_eq!(Prunable::<4>::chunk_byte_offset(0), 0);
        assert_eq!(Prunable::<4>::chunk_byte_offset(8), 1);
        assert_eq!(Prunable::<4>::chunk_byte_offset(16), 2);
        assert_eq!(Prunable::<4>::chunk_byte_offset(24), 3);
        assert_eq!(Prunable::<4>::chunk_byte_offset(32), 0);

        // Masks select the bit within its byte, least-significant bit first.
        assert_eq!(Prunable::<4>::chunk_byte_bitmask(0), 0b00000001);
        assert_eq!(Prunable::<4>::chunk_byte_bitmask(1), 0b00000010);
        assert_eq!(Prunable::<4>::chunk_byte_bitmask(7), 0b10000000);
        assert_eq!(Prunable::<4>::chunk_byte_bitmask(8), 0b00000001);
    }

    #[test]
    fn test_pruned_chunk() {
        let mut prunable: Prunable<4> = Prunable::new();

        for i in 0..3 {
            let chunk = [
                (i * 4) as u8,
                (i * 4 + 1) as u8,
                (i * 4 + 2) as u8,
                (i * 4 + 3) as u8,
            ];
            prunable.push_chunk(&chunk);
        }

        assert_eq!(prunable.pruned_chunk(0), 0);
        assert_eq!(prunable.pruned_chunk(32), 1);
        assert_eq!(prunable.pruned_chunk(64), 2);

        // After pruning one chunk, retained-chunk indices shift down by one.
        prunable.prune_to_bit(32);
        assert_eq!(prunable.pruned_chunk(32), 0);
        assert_eq!(prunable.pruned_chunk(64), 1);
    }

    #[test]
    fn test_last_chunk_with_pruning() {
        let mut prunable: Prunable<4> = Prunable::new();

        prunable.push_chunk(&[1, 2, 3, 4]);
        prunable.push_chunk(&[5, 6, 7, 8]);
        prunable.push(true);
        prunable.push(false);

        let (_, next_bit) = prunable.last_chunk();
        assert_eq!(next_bit, 2);

        let chunk_data = *prunable.last_chunk().0;

        // Pruning earlier chunks must not change the last (partial) chunk.
        prunable.prune_to_bit(32);
        let (chunk2, next_bit2) = prunable.last_chunk();
        assert_eq!(next_bit2, 2);
        assert_eq!(&chunk_data, chunk2);
    }

    #[test]
    fn test_different_chunk_sizes() {
        let mut p8: Prunable<8> = Prunable::new();
        let mut p16: Prunable<16> = Prunable::new();
        let mut p32: Prunable<32> = Prunable::new();

        for i in 0..10 {
            p8.push(i % 2 == 0);
            p16.push(i % 2 == 0);
            p32.push(i % 2 == 0);
        }

        assert_eq!(p8.len(), 10);
        assert_eq!(p16.len(), 10);
        assert_eq!(p32.len(), 10);

        for i in 0..10 {
            let expected = i % 2 == 0;
            if expected {
                assert!(p8.get_bit(i));
                assert!(p16.get_bit(i));
                assert!(p32.get_bit(i));
            } else {
                assert!(!p8.get_bit(i));
                assert!(!p16.get_bit(i));
                assert!(!p32.get_bit(i));
            }
        }
    }

    #[test]
    fn test_get_bit_from_chunk() {
        let chunk: [u8; 4] = [0b10101010, 0b11001100, 0b11110000, 0b00001111];

        // Bits 0..4 come from byte 0, least-significant bit first.
        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 0));
        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 1));
        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 2));
        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 3));

        // Bits 8..12 come from byte 1.
        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 8));
        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 9));
        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 10));
        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 11));
    }

    #[test]
    fn test_get_chunk() {
        let mut prunable: Prunable<4> = Prunable::new();
        let chunk1 = hex!("0x11223344");
        let chunk2 = hex!("0x55667788");
        let chunk3 = hex!("0x99AABBCC");

        prunable.push_chunk(&chunk1);
        prunable.push_chunk(&chunk2);
        prunable.push_chunk(&chunk3);

        assert_eq!(prunable.get_chunk(0), &chunk1);
        assert_eq!(prunable.get_chunk(1), &chunk2);
        assert_eq!(prunable.get_chunk(2), &chunk3);

        // `get_chunk` indexes retained chunks, so index 0 moves after pruning.
        prunable.prune_to_bit(32);
        assert_eq!(prunable.get_chunk(0), &chunk2);
        assert_eq!(prunable.get_chunk(1), &chunk3);
    }

    #[test]
    fn test_pop() {
        let mut prunable: Prunable<4> = Prunable::new();

        prunable.push(true);
        prunable.push(false);
        prunable.push(true);
        assert_eq!(prunable.len(), 3);

        assert!(prunable.pop());
        assert_eq!(prunable.len(), 2);

        assert!(!prunable.pop());
        assert_eq!(prunable.len(), 1);

        assert!(prunable.pop());
        assert_eq!(prunable.len(), 0);
        assert!(prunable.is_empty());

        // Pops return bits in reverse push order across multiple chunks.
        for i in 0..100 {
            prunable.push(i % 3 == 0);
        }
        assert_eq!(prunable.len(), 100);

        for i in (0..100).rev() {
            let expected = i % 3 == 0;
            assert_eq!(prunable.pop(), expected);
            assert_eq!(prunable.len(), i);
        }

        assert!(prunable.is_empty());
    }

    #[test]
    fn test_pop_chunk() {
        let mut prunable: Prunable<4> = Prunable::new();
        const CHUNK_SIZE: u64 = Prunable::<4>::CHUNK_SIZE_BITS;

        // Single chunk: push then pop returns it and empties the bitmap.
        let chunk1 = hex!("0xAABBCCDD");
        prunable.push_chunk(&chunk1);
        assert_eq!(prunable.len(), CHUNK_SIZE);
        let popped = prunable.pop_chunk();
        assert_eq!(popped, chunk1);
        assert_eq!(prunable.len(), 0);
        assert!(prunable.is_empty());

        // Multiple chunks pop in LIFO order.
        let chunk2 = hex!("0x11223344");
        let chunk3 = hex!("0x55667788");
        let chunk4 = hex!("0x99AABBCC");

        prunable.push_chunk(&chunk2);
        prunable.push_chunk(&chunk3);
        prunable.push_chunk(&chunk4);
        assert_eq!(prunable.len(), CHUNK_SIZE * 3);

        assert_eq!(prunable.pop_chunk(), chunk4);
        assert_eq!(prunable.len(), CHUNK_SIZE * 2);

        assert_eq!(prunable.pop_chunk(), chunk3);
        assert_eq!(prunable.len(), CHUNK_SIZE);

        assert_eq!(prunable.pop_chunk(), chunk2);
        assert_eq!(prunable.len(), 0);

        // Popping the last chunk leaves earlier chunks' bits intact.
        prunable = Prunable::new();
        let first_chunk = hex!("0xAABBCCDD");
        let second_chunk = hex!("0x11223344");
        prunable.push_chunk(&first_chunk);
        prunable.push_chunk(&second_chunk);

        assert_eq!(prunable.pop_chunk(), second_chunk);
        assert_eq!(prunable.len(), CHUNK_SIZE);

        for i in 0..CHUNK_SIZE {
            let byte_idx = (i / 8) as usize;
            let bit_idx = i % 8;
            let expected = (first_chunk[byte_idx] >> bit_idx) & 1 == 1;
            assert_eq!(prunable.get_bit(i), expected);
        }

        assert_eq!(prunable.pop_chunk(), first_chunk);
        assert_eq!(prunable.len(), 0);
    }

    #[test]
    #[should_panic(expected = "cannot pop chunk when not chunk aligned")]
    fn test_pop_chunk_not_aligned() {
        let mut prunable: Prunable<4> = Prunable::new();

        // One full chunk plus one extra bit: not chunk aligned.
        prunable.push_chunk(&[0xFF; 4]);
        prunable.push(true);

        prunable.pop_chunk();
    }

    #[test]
    #[should_panic(expected = "cannot pop chunk: bitmap has fewer than CHUNK_SIZE_BITS bits")]
    fn test_pop_chunk_insufficient_bits() {
        let mut prunable: Prunable<4> = Prunable::new();

        prunable.push(true);
        prunable.push(false);

        prunable.pop_chunk();
    }

    #[test]
    fn test_write_read_empty() {
        let original: Prunable<4> = Prunable::new();
        let encoded = original.encode();

        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
        assert_eq!(decoded.len(), original.len());
        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
        assert!(decoded.is_empty());
    }

    #[test]
    fn test_write_read_non_empty() {
        let mut original: Prunable<4> = Prunable::new();
        original.push_chunk(&hex!("0xAABBCCDD"));
        original.push_chunk(&hex!("0x11223344"));
        original.push(true);
        original.push(false);
        original.push(true);

        let encoded = original.encode();
        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();

        assert_eq!(decoded.len(), original.len());
        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
        assert_eq!(decoded.len(), 67);

        for i in 0..original.len() {
            assert_eq!(decoded.get_bit(i), original.get_bit(i));
        }
    }

    #[test]
    fn test_write_read_with_pruning() {
        let mut original: Prunable<4> = Prunable::new();
        original.push_chunk(&hex!("0x01020304"));
        original.push_chunk(&hex!("0x05060708"));
        original.push_chunk(&hex!("0x090A0B0C"));

        original.prune_to_bit(32);
        assert_eq!(original.pruned_chunks(), 1);
        assert_eq!(original.len(), 96);

        // The pruned-chunk count survives the round-trip, so global indices still line up.
        let encoded = original.encode();
        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();

        assert_eq!(decoded.len(), original.len());
        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
        assert_eq!(decoded.pruned_chunks(), 1);
        assert_eq!(decoded.len(), 96);

        assert_eq!(decoded.get_chunk_containing(32), &hex!("0x05060708"));
        assert_eq!(decoded.get_chunk_containing(64), &hex!("0x090A0B0C"));
    }

    #[test]
    fn test_write_read_with_pruning_2() {
        let mut original: Prunable<4> = Prunable::new();

        for i in 0..5 {
            let chunk = [
                (i * 4) as u8,
                (i * 4 + 1) as u8,
                (i * 4 + 2) as u8,
                (i * 4 + 3) as u8,
            ];
            original.push_chunk(&chunk);
        }

        original.prune_to_bit(96);
        assert_eq!(original.pruned_chunks(), 3);
        assert_eq!(original.len(), 160);

        let encoded = original.encode();
        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();

        assert_eq!(decoded.len(), original.len());
        assert_eq!(decoded.pruned_chunks(), 3);

        for i in 96..original.len() {
            assert_eq!(decoded.get_bit(i), original.get_bit(i));
        }
    }

    #[test]
    fn test_encode_size_matches() {
        let mut prunable: Prunable<4> = Prunable::new();
        prunable.push_chunk(&[1, 2, 3, 4]);
        prunable.push_chunk(&[5, 6, 7, 8]);
        prunable.push(true);

        let size = prunable.encode_size();
        let encoded = prunable.encode();

        assert_eq!(size, encoded.len());
    }

    #[test]
    fn test_encode_size_with_pruning() {
        let mut prunable: Prunable<4> = Prunable::new();
        prunable.push_chunk(&[1, 2, 3, 4]);
        prunable.push_chunk(&[5, 6, 7, 8]);
        prunable.push_chunk(&[9, 10, 11, 12]);

        prunable.prune_to_bit(32);

        let size = prunable.encode_size();
        let encoded = prunable.encode();

        assert_eq!(size, encoded.len());
    }

    #[test]
    fn test_read_max_len_validation() {
        let mut original: Prunable<4> = Prunable::new();
        for _ in 0..10 {
            original.push(true);
        }

        let encoded = original.encode();

        // A limit above the 10 encoded bits is accepted...
        assert!(Prunable::<4>::read_cfg(&mut encoded.as_ref(), &100).is_ok());

        // ...while a limit below it is rejected.
        let result = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &5);
        assert!(result.is_err());
    }

    #[test]
    fn test_codec_roundtrip_different_chunk_sizes() {
        let mut p8: Prunable<8> = Prunable::new();
        let mut p16: Prunable<16> = Prunable::new();
        let mut p32: Prunable<32> = Prunable::new();

        for i in 0..100 {
            let bit = i % 3 == 0;
            p8.push(bit);
            p16.push(bit);
            p32.push(bit);
        }

        let encoded8 = p8.encode();
        let decoded8 = Prunable::<8>::read_cfg(&mut encoded8.as_ref(), &u64::MAX).unwrap();
        assert_eq!(decoded8.len(), p8.len());

        let encoded16 = p16.encode();
        let decoded16 = Prunable::<16>::read_cfg(&mut encoded16.as_ref(), &u64::MAX).unwrap();
        assert_eq!(decoded16.len(), p16.len());

        let encoded32 = p32.encode();
        let decoded32 = Prunable::<32>::read_cfg(&mut encoded32.as_ref(), &u64::MAX).unwrap();
        assert_eq!(decoded32.len(), p32.len());
    }

    #[test]
    fn test_read_pruned_chunks_overflow() {
        let mut buf = BytesMut::new();

        // Header: a pruned-chunk count whose bit offset overflows u64.
        let overflowing_pruned_chunks = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) + 1;
        overflowing_pruned_chunks.write(&mut buf);

        // Followed by a u64 zero, intended as an empty bitmap.
        0u64.write(&mut buf);
        let result = Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX);
        match result {
            Err(CodecError::Invalid(type_name, msg)) => {
                assert_eq!(type_name, "Prunable");
                assert_eq!(
                    msg,
                    "pruned_chunks would overflow when computing pruned_bits"
                );
            }
            Ok(_) => panic!("Expected error but got Ok"),
            Err(e) => panic!("Expected Invalid error for pruned_bits overflow, got: {e:?}"),
        }
    }

    #[test]
    fn test_read_total_length_overflow() {
        let mut buf = BytesMut::new();

        // Largest pruned-chunk count whose bit offset still fits in a u64.
        let max_safe_pruned_chunks = u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS;
        let pruned_bits = max_safe_pruned_chunks * Prunable::<4>::CHUNK_SIZE_BITS;

        // One more retained bit than the remaining u64 headroom allows.
        let remaining_space = u64::MAX - pruned_bits;
        let bitmap_len = remaining_space + 1;
        max_safe_pruned_chunks.write(&mut buf);
        bitmap_len.write(&mut buf);

        // Zero-filled chunk payload covering `bitmap_len` bits.
        let num_chunks = bitmap_len.div_ceil(Prunable::<4>::CHUNK_SIZE_BITS);
        for _ in 0..(num_chunks * 4) {
            0u8.write(&mut buf);
        }

        let result = Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX);
        match result {
            Err(CodecError::Invalid(type_name, msg)) => {
                assert_eq!(type_name, "Prunable");
                assert_eq!(
                    msg,
                    "total bitmap length (pruned + unpruned) would overflow u64"
                );
            }
            Ok(_) => panic!("Expected error but got Ok"),
            Err(e) => panic!("Expected Invalid error for total length overflow, got: {e:?}"),
        }
    }

    #[test]
    fn test_is_chunk_aligned() {
        // Empty bitmap is aligned.
        let prunable: Prunable<4> = Prunable::new();
        assert!(prunable.is_chunk_aligned());

        // Bit-by-bit pushes are aligned only at multiples of 32.
        let mut prunable: Prunable<4> = Prunable::new();
        for i in 1..=32 {
            prunable.push(i % 2 == 0);
            if i == 32 {
                assert!(prunable.is_chunk_aligned());
            } else {
                assert!(!prunable.is_chunk_aligned());
            }
        }

        for i in 33..=64 {
            prunable.push(i % 2 == 0);
            if i == 64 {
                assert!(prunable.is_chunk_aligned());
            } else {
                assert!(!prunable.is_chunk_aligned());
            }
        }

        // Whole-chunk pushes keep alignment; a single extra bit breaks it.
        let mut prunable: Prunable<4> = Prunable::new();
        assert!(prunable.is_chunk_aligned());
        prunable.push_chunk(&[1, 2, 3, 4]);
        assert!(prunable.is_chunk_aligned());
        prunable.push_chunk(&[5, 6, 7, 8]);
        assert!(prunable.is_chunk_aligned());
        prunable.push(true);
        assert!(!prunable.is_chunk_aligned());

        // Pruning never changes alignment, since `len` is unchanged.
        let mut prunable: Prunable<4> = Prunable::new();
        prunable.push_chunk(&[1, 2, 3, 4]);
        prunable.push_chunk(&[5, 6, 7, 8]);
        prunable.push_chunk(&[9, 10, 11, 12]);
        assert!(prunable.is_chunk_aligned());
        prunable.prune_to_bit(32);
        assert!(prunable.is_chunk_aligned());
        assert_eq!(prunable.len(), 96);

        prunable.push(true);
        prunable.push(false);
        assert!(!prunable.is_chunk_aligned());
        prunable.prune_to_bit(64);
        assert!(!prunable.is_chunk_aligned());

        // Bitmaps created with pre-pruned chunks start aligned.
        let prunable: Prunable<4> = Prunable::new_with_pruned_chunks(2).unwrap();
        assert!(prunable.is_chunk_aligned());
        let mut prunable: Prunable<4> = Prunable::new_with_pruned_chunks(1).unwrap();
        assert!(prunable.is_chunk_aligned());
        prunable.push(true);
        assert!(!prunable.is_chunk_aligned());

        // Popping back to a chunk boundary restores alignment.
        let mut prunable: Prunable<4> = Prunable::new();
        for _ in 0..4 {
            prunable.push_byte(0xFF);
        }
        assert!(prunable.is_chunk_aligned());
        prunable.pop();
        assert!(!prunable.is_chunk_aligned());
        for _ in 0..31 {
            prunable.pop();
        }
        assert!(prunable.is_chunk_aligned());
    }

    #[cfg(feature = "arbitrary")]
    mod conformance {
        // Codec conformance checks driven by the `Arbitrary` implementation.
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Prunable<16>>,
        }
    }
}