1use super::BitMap;
4use bytes::{Buf, BufMut};
5use commonware_codec::{EncodeSize, Error as CodecError, Read, ReadExt, Write};
6use thiserror::Error;
7
/// Errors returned when constructing a [`Prunable`] bitmap.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum Error {
    /// The requested number of pruned chunks, multiplied by the chunk size in bits,
    /// does not fit in a `u64`.
    #[error("pruned_chunks * CHUNK_SIZE_BITS overflows u64")]
    PrunedChunksOverflow,
}
15
/// A bitmap whose leading chunks can be pruned while logical bit indices stay stable.
///
/// Bits are grouped into chunks of `N` bytes. Only the chunks that have not been pruned
/// are stored in `bitmap`; `pruned_chunks` records how many leading chunks were dropped.
/// A logical bit index is mapped onto `bitmap` by subtracting
/// `pruned_chunks * CHUNK_SIZE_BITS`.
#[derive(Clone, Debug)]
pub struct Prunable<const N: usize> {
    /// Storage for the bits that have not been pruned.
    bitmap: BitMap<N>,

    /// Number of leading chunks that have been pruned.
    ///
    /// Invariant relied on by `len()` and `pruned_bits()`:
    /// `pruned_chunks * CHUNK_SIZE_BITS` must not overflow `u64`.
    pruned_chunks: usize,
}
34
35impl<const N: usize> Prunable<N> {
    /// Number of bits in one chunk, taken from the underlying [`BitMap`].
    pub const CHUNK_SIZE_BITS: u64 = BitMap::<N>::CHUNK_SIZE_BITS;
38
    /// Creates an empty bitmap with no pruned chunks.
    pub const fn new() -> Self {
        Self {
            bitmap: BitMap::new(),
            pruned_chunks: 0,
        }
    }
48
49 pub fn new_with_pruned_chunks(pruned_chunks: usize) -> Result<Self, Error> {
56 let pruned_chunks_u64 = pruned_chunks as u64;
58 pruned_chunks_u64
59 .checked_mul(Self::CHUNK_SIZE_BITS)
60 .ok_or(Error::PrunedChunksOverflow)?;
61
62 Ok(Self {
63 bitmap: BitMap::new(),
64 pruned_chunks,
65 })
66 }
67
68 #[inline]
72 pub const fn len(&self) -> u64 {
73 let pruned_bits = (self.pruned_chunks as u64)
74 .checked_mul(Self::CHUNK_SIZE_BITS)
75 .expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64");
76
77 pruned_bits
78 .checked_add(self.bitmap.len())
79 .expect("invariant violated: pruned_bits + bitmap.len() overflows u64")
80 }
81
    /// Returns `true` when [`Self::len`] is zero.
    ///
    /// Pruned bits count toward the length.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.len() == 0
    }
87
    /// Returns `true` when the total length (pruned bits included) is a multiple of
    /// [`Self::CHUNK_SIZE_BITS`], i.e. the bitmap ends exactly on a chunk boundary.
    #[inline]
    pub const fn is_chunk_aligned(&self) -> bool {
        self.len().is_multiple_of(Self::CHUNK_SIZE_BITS)
    }
93
    /// Returns the number of chunks, counting pruned chunks as well as the chunks
    /// reported by the retained bitmap.
    #[inline]
    pub fn chunks_len(&self) -> usize {
        self.pruned_chunks + self.bitmap.chunks_len()
    }
99
    /// Returns the number of chunks that have been pruned.
    #[inline]
    pub const fn pruned_chunks(&self) -> usize {
        self.pruned_chunks
    }
105
106 #[inline]
109 pub fn complete_chunks(&self) -> usize {
110 self.pruned_chunks
111 + self
112 .bitmap
113 .chunks_len()
114 .saturating_sub(if self.is_chunk_aligned() { 0 } else { 1 })
115 }
116
117 #[inline]
119 pub const fn pruned_bits(&self) -> u64 {
120 (self.pruned_chunks as u64)
121 .checked_mul(Self::CHUNK_SIZE_BITS)
122 .expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64")
123 }
124
125 #[inline]
133 pub fn get_bit(&self, bit: u64) -> bool {
134 let chunk_num = Self::to_chunk_index(bit);
135 assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
136
137 self.bitmap.get(bit - self.pruned_bits())
139 }
140
141 #[inline]
147 pub fn get_chunk_containing(&self, bit: u64) -> &[u8; N] {
148 let chunk_num = Self::to_chunk_index(bit);
149 assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
150
151 self.bitmap.get_chunk_containing(bit - self.pruned_bits())
153 }
154
    /// Reads the value of `bit` out of a standalone `chunk`.
    ///
    /// Delegates to `BitMap::get_bit_from_chunk`. The tests in this file show bits
    /// numbered from the least significant bit of each byte, bytes in array order.
    #[inline]
    pub const fn get_bit_from_chunk(chunk: &[u8; N], bit: u64) -> bool {
        BitMap::<N>::get_bit_from_chunk(chunk, bit)
    }
161
    /// Returns the last chunk of the retained bitmap together with a `u64` reported by
    /// `BitMap::last_chunk`.
    ///
    /// NOTE(review): in this file's tests the `u64` equals the number of bits already
    /// written into that last chunk — confirm against `BitMap::last_chunk`.
    #[inline]
    pub fn last_chunk(&self) -> (&[u8; N], u64) {
        self.bitmap.last_chunk()
    }
167
168 pub fn set_bit(&mut self, bit: u64, value: bool) {
176 let chunk_num = Self::to_chunk_index(bit);
177 assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
178
179 self.bitmap.set(bit - self.pruned_bits(), value);
181 }
182
    /// Appends `bit` to the end of the bitmap.
    pub fn push(&mut self, bit: bool) {
        self.bitmap.push(bit);
    }
187
    /// Removes the last bit from the bitmap and returns its value.
    ///
    /// NOTE(review): behaviour when no retained bits remain is decided by
    /// `BitMap::pop` (presumably a panic) — confirm.
    pub fn pop(&mut self) -> bool {
        self.bitmap.pop()
    }
196
    /// Appends the eight bits of `byte` to the end of the bitmap.
    ///
    /// NOTE(review): any alignment precondition is enforced by `BitMap::push_byte` —
    /// confirm there.
    pub fn push_byte(&mut self, byte: u8) {
        self.bitmap.push_byte(byte);
    }
205
    /// Appends a whole chunk (`N` bytes) to the end of the bitmap.
    ///
    /// NOTE(review): any alignment precondition is enforced by `BitMap::push_chunk` —
    /// confirm there.
    pub fn push_chunk(&mut self, chunk: &[u8; N]) {
        self.bitmap.push_chunk(chunk);
    }
214
    /// Removes the last chunk from the bitmap and returns it.
    ///
    /// # Panics
    ///
    /// The tests in this file expect a panic when the bitmap is not chunk aligned, and
    /// when it holds fewer than `CHUNK_SIZE_BITS` bits (both raised by
    /// `BitMap::pop_chunk`).
    pub fn pop_chunk(&mut self) -> [u8; N] {
        self.bitmap.pop_chunk()
    }
223
224 pub fn prune_to_bit(&mut self, bit: u64) {
238 assert!(
239 bit <= self.len(),
240 "bit {} out of bounds (len: {})",
241 bit,
242 self.len()
243 );
244
245 let chunk = Self::to_chunk_index(bit);
246 if chunk < self.pruned_chunks {
247 return;
248 }
249
250 let chunks_to_prune = chunk - self.pruned_chunks;
251 self.bitmap.prune_chunks(chunks_to_prune);
252 self.pruned_chunks = chunk;
253 }
254
    /// Returns the single-bit mask for `bit` within the byte that holds it.
    ///
    /// Delegates to `BitMap::chunk_byte_bitmask`. Per the tests in this file, bit 0 maps
    /// to `0b0000_0001`, bit 7 to `0b1000_0000`, and bit 8 wraps back to `0b0000_0001`.
    #[inline]
    pub const fn chunk_byte_bitmask(bit: u64) -> u8 {
        BitMap::<N>::chunk_byte_bitmask(bit)
    }
262
    /// Returns the index, within its chunk, of the byte that holds `bit`.
    ///
    /// Delegates to `BitMap::chunk_byte_offset`. Per the tests in this file (with
    /// `N = 4`), bits 0, 8, 16 and 24 map to bytes 0..=3 and bit 32 wraps back to byte 0.
    #[inline]
    pub const fn chunk_byte_offset(bit: u64) -> usize {
        BitMap::<N>::chunk_byte_offset(bit)
    }
268
    /// Returns the index of the chunk that contains `bit`.
    ///
    /// The result is an absolute chunk index: callers in this type compare it directly
    /// against `pruned_chunks`. Delegates to `BitMap::to_chunk_index`.
    #[inline]
    pub fn to_chunk_index(bit: u64) -> usize {
        BitMap::<N>::to_chunk_index(bit)
    }
278
279 #[inline]
285 pub fn get_chunk(&self, chunk: usize) -> &[u8; N] {
286 assert!(
287 chunk >= self.pruned_chunks,
288 "chunk {chunk} is pruned (pruned_chunks: {})",
289 self.pruned_chunks
290 );
291 self.bitmap.get_chunk(chunk - self.pruned_chunks)
292 }
293
294 pub(super) fn set_chunk_by_index(&mut self, chunk_index: usize, chunk_data: &[u8; N]) {
300 assert!(
301 chunk_index >= self.pruned_chunks,
302 "cannot set pruned chunk {chunk_index} (pruned_chunks: {})",
303 self.pruned_chunks
304 );
305 let bitmap_chunk_idx = chunk_index - self.pruned_chunks;
306 self.bitmap.set_chunk_by_index(bitmap_chunk_idx, chunk_data);
307 }
308
309 pub(super) fn unprune_chunks(&mut self, chunks: &[[u8; N]]) {
320 assert!(
321 chunks.len() <= self.pruned_chunks,
322 "cannot unprune {} chunks (only {} pruned)",
323 chunks.len(),
324 self.pruned_chunks
325 );
326
327 for chunk in chunks.iter() {
328 self.bitmap.prepend_chunk(chunk);
329 }
330
331 self.pruned_chunks -= chunks.len();
332 }
333}
334
/// The default value is an empty bitmap with nothing pruned, same as [`Prunable::new`].
impl<const N: usize> Default for Prunable<N> {
    fn default() -> Self {
        Self::new()
    }
}
340
341impl<const N: usize> Write for Prunable<N> {
342 fn write(&self, buf: &mut impl BufMut) {
343 (self.pruned_chunks as u64).write(buf);
344 self.bitmap.write(buf);
345 }
346}
347
348impl<const N: usize> Read for Prunable<N> {
349 type Cfg = u64;
351
352 fn read_cfg(buf: &mut impl Buf, max_len: &Self::Cfg) -> Result<Self, CodecError> {
353 let pruned_chunks_u64 = u64::read(buf)?;
354
355 let pruned_bits =
357 pruned_chunks_u64
358 .checked_mul(Self::CHUNK_SIZE_BITS)
359 .ok_or(CodecError::Invalid(
360 "Prunable",
361 "pruned_chunks would overflow when computing pruned_bits",
362 ))?;
363
364 let pruned_chunks = usize::try_from(pruned_chunks_u64)
365 .map_err(|_| CodecError::Invalid("Prunable", "pruned_chunks doesn't fit in usize"))?;
366
367 let bitmap = BitMap::<N>::read_cfg(buf, max_len)?;
368
369 pruned_bits
371 .checked_add(bitmap.len())
372 .ok_or(CodecError::Invalid(
373 "Prunable",
374 "total bitmap length (pruned + unpruned) would overflow u64",
375 ))?;
376
377 Ok(Self {
378 bitmap,
379 pruned_chunks,
380 })
381 }
382}
383
384impl<const N: usize> EncodeSize for Prunable<N> {
385 fn encode_size(&self) -> usize {
386 (self.pruned_chunks as u64).encode_size() + self.bitmap.encode_size()
387 }
388}
389
#[cfg(feature = "arbitrary")]
impl<const N: usize> arbitrary::Arbitrary<'_> for Prunable<N> {
    /// Generates an arbitrary inner bitmap, then prunes it to an arbitrary bit within
    /// its length.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let inner = BitMap::<N>::arbitrary(u)?;
        let mut prunable = Self {
            bitmap: inner,
            pruned_chunks: 0,
        };

        let cutoff = u.int_in_range(0..=prunable.len())?;
        prunable.prune_to_bit(cutoff);

        Ok(prunable)
    }
}
402
403#[cfg(test)]
404mod tests {
405 use super::*;
406 use crate::hex;
407 use bytes::BytesMut;
408 use commonware_codec::Encode;
409
410 #[test]
411 fn test_new() {
412 let prunable: Prunable<32> = Prunable::new();
413 assert_eq!(prunable.len(), 0);
414 assert_eq!(prunable.pruned_bits(), 0);
415 assert_eq!(prunable.pruned_chunks(), 0);
416 assert!(prunable.is_empty());
417 assert_eq!(prunable.chunks_len(), 0); }
419
420 #[test]
421 fn test_new_with_pruned_chunks() {
422 let prunable: Prunable<2> = Prunable::new_with_pruned_chunks(1).unwrap();
423 assert_eq!(prunable.len(), 16);
424 assert_eq!(prunable.pruned_bits(), 16);
425 assert_eq!(prunable.pruned_chunks(), 1);
426 assert_eq!(prunable.chunks_len(), 1);
427 }
428
429 #[test]
430 fn test_new_with_pruned_chunks_overflow() {
431 let overflowing_pruned_chunks = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) as usize + 1;
433 let result = Prunable::<4>::new_with_pruned_chunks(overflowing_pruned_chunks);
434
435 assert!(matches!(result, Err(Error::PrunedChunksOverflow)));
436 }
437
438 #[test]
439 fn test_push_and_get_bits() {
440 let mut prunable: Prunable<4> = Prunable::new();
441
442 prunable.push(true);
444 prunable.push(false);
445 prunable.push(true);
446
447 assert_eq!(prunable.len(), 3);
448 assert!(!prunable.is_empty());
449 assert!(prunable.get_bit(0));
450 assert!(!prunable.get_bit(1));
451 assert!(prunable.get_bit(2));
452 }
453
454 #[test]
455 fn test_push_byte() {
456 let mut prunable: Prunable<4> = Prunable::new();
457
458 prunable.push_byte(0xFF);
460 assert_eq!(prunable.len(), 8);
461
462 for i in 0..8 {
464 assert!(prunable.get_bit(i as u64));
465 }
466
467 prunable.push_byte(0x00);
468 assert_eq!(prunable.len(), 16);
469
470 for i in 8..16 {
472 assert!(!prunable.get_bit(i as u64));
473 }
474 }
475
476 #[test]
477 fn test_push_chunk() {
478 let mut prunable: Prunable<4> = Prunable::new();
479 let chunk = hex!("0xAABBCCDD");
480
481 prunable.push_chunk(&chunk);
482 assert_eq!(prunable.len(), 32); let retrieved_chunk = prunable.get_chunk_containing(0);
485 assert_eq!(retrieved_chunk, &chunk);
486 }
487
488 #[test]
489 fn test_set_bit() {
490 let mut prunable: Prunable<4> = Prunable::new();
491
492 prunable.push(false);
494 prunable.push(false);
495 prunable.push(false);
496
497 assert!(!prunable.get_bit(1));
498
499 prunable.set_bit(1, true);
501 assert!(prunable.get_bit(1));
502
503 prunable.set_bit(1, false);
505 assert!(!prunable.get_bit(1));
506 }
507
508 #[test]
509 fn test_pruning_basic() {
510 let mut prunable: Prunable<4> = Prunable::new();
511
512 let chunk1 = hex!("0x01020304");
514 let chunk2 = hex!("0x05060708");
515 let chunk3 = hex!("0x090A0B0C");
516
517 prunable.push_chunk(&chunk1);
518 prunable.push_chunk(&chunk2);
519 prunable.push_chunk(&chunk3);
520
521 assert_eq!(prunable.len(), 96); assert_eq!(prunable.pruned_chunks(), 0);
523
524 prunable.prune_to_bit(32);
526 assert_eq!(prunable.pruned_chunks(), 1);
527 assert_eq!(prunable.pruned_bits(), 32);
528 assert_eq!(prunable.len(), 96); assert_eq!(prunable.get_chunk_containing(32), &chunk2);
532 assert_eq!(prunable.get_chunk_containing(64), &chunk3);
533
534 prunable.prune_to_bit(64);
536 assert_eq!(prunable.pruned_chunks(), 2);
537 assert_eq!(prunable.pruned_bits(), 64);
538 assert_eq!(prunable.len(), 96);
539
540 assert_eq!(prunable.get_chunk_containing(64), &chunk3);
542 }
543
544 #[test]
545 #[should_panic(expected = "bit pruned")]
546 fn test_get_pruned_bit_panics() {
547 let mut prunable: Prunable<4> = Prunable::new();
548
549 prunable.push_chunk(&[1, 2, 3, 4]);
551 prunable.push_chunk(&[5, 6, 7, 8]);
552
553 prunable.prune_to_bit(32);
555
556 prunable.get_bit(0);
558 }
559
560 #[test]
561 #[should_panic(expected = "bit pruned")]
562 fn test_get_pruned_chunk_panics() {
563 let mut prunable: Prunable<4> = Prunable::new();
564
565 prunable.push_chunk(&[1, 2, 3, 4]);
567 prunable.push_chunk(&[5, 6, 7, 8]);
568
569 prunable.prune_to_bit(32);
571
572 prunable.get_chunk_containing(0);
574 }
575
576 #[test]
577 #[should_panic(expected = "bit pruned")]
578 fn test_set_pruned_bit_panics() {
579 let mut prunable: Prunable<4> = Prunable::new();
580
581 prunable.push_chunk(&[1, 2, 3, 4]);
583 prunable.push_chunk(&[5, 6, 7, 8]);
584
585 prunable.prune_to_bit(32);
587
588 prunable.set_bit(0, true);
590 }
591
592 #[test]
593 #[should_panic(expected = "bit 25 out of bounds (len: 24)")]
594 fn test_prune_to_bit_out_of_bounds() {
595 let mut prunable: Prunable<1> = Prunable::new();
596
597 prunable.push_byte(1);
599 prunable.push_byte(2);
600 prunable.push_byte(3);
601
602 prunable.prune_to_bit(25);
604 }
605
606 #[test]
607 fn test_pruning_with_partial_chunk() {
608 let mut prunable: Prunable<4> = Prunable::new();
609
610 prunable.push_chunk(&[0xFF; 4]);
612 prunable.push_chunk(&[0xAA; 4]);
613 prunable.push(true);
614 prunable.push(false);
615 prunable.push(true);
616
617 assert_eq!(prunable.len(), 67); prunable.prune_to_bit(32);
621 assert_eq!(prunable.pruned_chunks(), 1);
622 assert_eq!(prunable.len(), 67);
623
624 assert!(prunable.get_bit(64));
626 assert!(!prunable.get_bit(65));
627 assert!(prunable.get_bit(66));
628 }
629
630 #[test]
631 fn test_prune_idempotent() {
632 let mut prunable: Prunable<4> = Prunable::new();
633
634 prunable.push_chunk(&[1, 2, 3, 4]);
636 prunable.push_chunk(&[5, 6, 7, 8]);
637
638 prunable.prune_to_bit(32);
640 assert_eq!(prunable.pruned_chunks(), 1);
641
642 prunable.prune_to_bit(32);
644 assert_eq!(prunable.pruned_chunks(), 1);
645
646 prunable.prune_to_bit(16);
647 assert_eq!(prunable.pruned_chunks(), 1);
648 }
649
650 #[test]
651 fn test_push_after_pruning() {
652 let mut prunable: Prunable<4> = Prunable::new();
653
654 prunable.push_chunk(&[1, 2, 3, 4]);
656 prunable.push_chunk(&[5, 6, 7, 8]);
657
658 prunable.prune_to_bit(32);
660 assert_eq!(prunable.len(), 64);
661 assert_eq!(prunable.pruned_chunks(), 1);
662
663 prunable.push_chunk(&[9, 10, 11, 12]);
665 assert_eq!(prunable.len(), 96); assert_eq!(prunable.get_chunk_containing(64), &[9, 10, 11, 12]);
669 }
670
671 #[test]
672 fn test_chunk_calculations() {
673 assert_eq!(Prunable::<4>::to_chunk_index(0), 0);
675 assert_eq!(Prunable::<4>::to_chunk_index(31), 0);
676 assert_eq!(Prunable::<4>::to_chunk_index(32), 1);
677 assert_eq!(Prunable::<4>::to_chunk_index(63), 1);
678 assert_eq!(Prunable::<4>::to_chunk_index(64), 2);
679
680 assert_eq!(Prunable::<4>::chunk_byte_offset(0), 0);
682 assert_eq!(Prunable::<4>::chunk_byte_offset(8), 1);
683 assert_eq!(Prunable::<4>::chunk_byte_offset(16), 2);
684 assert_eq!(Prunable::<4>::chunk_byte_offset(24), 3);
685 assert_eq!(Prunable::<4>::chunk_byte_offset(32), 0); assert_eq!(Prunable::<4>::chunk_byte_bitmask(0), 0b00000001);
689 assert_eq!(Prunable::<4>::chunk_byte_bitmask(1), 0b00000010);
690 assert_eq!(Prunable::<4>::chunk_byte_bitmask(7), 0b10000000);
691 assert_eq!(Prunable::<4>::chunk_byte_bitmask(8), 0b00000001); }
693
694 #[test]
695 fn test_last_chunk_with_pruning() {
696 let mut prunable: Prunable<4> = Prunable::new();
697
698 prunable.push_chunk(&[1, 2, 3, 4]);
700 prunable.push_chunk(&[5, 6, 7, 8]);
701 prunable.push(true);
702 prunable.push(false);
703
704 let (_, next_bit) = prunable.last_chunk();
705 assert_eq!(next_bit, 2);
706
707 let chunk_data = *prunable.last_chunk().0;
709
710 prunable.prune_to_bit(32);
712 let (chunk2, next_bit2) = prunable.last_chunk();
713 assert_eq!(next_bit2, 2);
714 assert_eq!(&chunk_data, chunk2);
715 }
716
717 #[test]
718 fn test_different_chunk_sizes() {
719 let mut p8: Prunable<8> = Prunable::new();
721 let mut p16: Prunable<16> = Prunable::new();
722 let mut p32: Prunable<32> = Prunable::new();
723
724 for i in 0..10 {
726 p8.push(i % 2 == 0);
727 p16.push(i % 2 == 0);
728 p32.push(i % 2 == 0);
729 }
730
731 assert_eq!(p8.len(), 10);
733 assert_eq!(p16.len(), 10);
734 assert_eq!(p32.len(), 10);
735
736 for i in 0..10 {
738 let expected = i % 2 == 0;
739 if expected {
740 assert!(p8.get_bit(i));
741 assert!(p16.get_bit(i));
742 assert!(p32.get_bit(i));
743 } else {
744 assert!(!p8.get_bit(i));
745 assert!(!p16.get_bit(i));
746 assert!(!p32.get_bit(i));
747 }
748 }
749 }
750
751 #[test]
752 fn test_get_bit_from_chunk() {
753 let chunk: [u8; 4] = [0b10101010, 0b11001100, 0b11110000, 0b00001111];
754
755 assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 0));
757 assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 1));
758 assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 2));
759 assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 3));
760
761 assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 8));
763 assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 9));
764 assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 10));
765 assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 11));
766 }
767
768 #[test]
769 fn test_get_chunk() {
770 let mut prunable: Prunable<4> = Prunable::new();
771 let chunk1 = hex!("0x11223344");
772 let chunk2 = hex!("0x55667788");
773 let chunk3 = hex!("0x99AABBCC");
774
775 prunable.push_chunk(&chunk1);
776 prunable.push_chunk(&chunk2);
777 prunable.push_chunk(&chunk3);
778
779 assert_eq!(prunable.get_chunk(0), &chunk1);
781 assert_eq!(prunable.get_chunk(1), &chunk2);
782 assert_eq!(prunable.get_chunk(2), &chunk3);
783
784 prunable.prune_to_bit(32);
786 assert_eq!(prunable.get_chunk(1), &chunk2);
787 assert_eq!(prunable.get_chunk(2), &chunk3);
788 }
789
790 #[test]
791 fn test_pop() {
792 let mut prunable: Prunable<4> = Prunable::new();
793
794 prunable.push(true);
795 prunable.push(false);
796 prunable.push(true);
797 assert_eq!(prunable.len(), 3);
798
799 assert!(prunable.pop());
800 assert_eq!(prunable.len(), 2);
801
802 assert!(!prunable.pop());
803 assert_eq!(prunable.len(), 1);
804
805 assert!(prunable.pop());
806 assert_eq!(prunable.len(), 0);
807 assert!(prunable.is_empty());
808
809 for i in 0..100 {
810 prunable.push(i % 3 == 0);
811 }
812 assert_eq!(prunable.len(), 100);
813
814 for i in (0..100).rev() {
815 let expected = i % 3 == 0;
816 assert_eq!(prunable.pop(), expected);
817 assert_eq!(prunable.len(), i);
818 }
819
820 assert!(prunable.is_empty());
821 }
822
823 #[test]
824 fn test_pop_chunk() {
825 let mut prunable: Prunable<4> = Prunable::new();
826 const CHUNK_SIZE: u64 = Prunable::<4>::CHUNK_SIZE_BITS;
827
828 let chunk1 = hex!("0xAABBCCDD");
830 prunable.push_chunk(&chunk1);
831 assert_eq!(prunable.len(), CHUNK_SIZE);
832 let popped = prunable.pop_chunk();
833 assert_eq!(popped, chunk1);
834 assert_eq!(prunable.len(), 0);
835 assert!(prunable.is_empty());
836
837 let chunk2 = hex!("0x11223344");
839 let chunk3 = hex!("0x55667788");
840 let chunk4 = hex!("0x99AABBCC");
841
842 prunable.push_chunk(&chunk2);
843 prunable.push_chunk(&chunk3);
844 prunable.push_chunk(&chunk4);
845 assert_eq!(prunable.len(), CHUNK_SIZE * 3);
846
847 assert_eq!(prunable.pop_chunk(), chunk4);
848 assert_eq!(prunable.len(), CHUNK_SIZE * 2);
849
850 assert_eq!(prunable.pop_chunk(), chunk3);
851 assert_eq!(prunable.len(), CHUNK_SIZE);
852
853 assert_eq!(prunable.pop_chunk(), chunk2);
854 assert_eq!(prunable.len(), 0);
855
856 prunable = Prunable::new();
858 let first_chunk = hex!("0xAABBCCDD");
859 let second_chunk = hex!("0x11223344");
860 prunable.push_chunk(&first_chunk);
861 prunable.push_chunk(&second_chunk);
862
863 assert_eq!(prunable.pop_chunk(), second_chunk);
865 assert_eq!(prunable.len(), CHUNK_SIZE);
866
867 for i in 0..CHUNK_SIZE {
868 let byte_idx = (i / 8) as usize;
869 let bit_idx = i % 8;
870 let expected = (first_chunk[byte_idx] >> bit_idx) & 1 == 1;
871 assert_eq!(prunable.get_bit(i), expected);
872 }
873
874 assert_eq!(prunable.pop_chunk(), first_chunk);
875 assert_eq!(prunable.len(), 0);
876 }
877
878 #[test]
879 #[should_panic(expected = "cannot pop chunk when not chunk aligned")]
880 fn test_pop_chunk_not_aligned() {
881 let mut prunable: Prunable<4> = Prunable::new();
882
883 prunable.push_chunk(&[0xFF; 4]);
885 prunable.push(true);
886
887 prunable.pop_chunk();
889 }
890
891 #[test]
892 #[should_panic(expected = "cannot pop chunk: bitmap has fewer than CHUNK_SIZE_BITS bits")]
893 fn test_pop_chunk_insufficient_bits() {
894 let mut prunable: Prunable<4> = Prunable::new();
895
896 prunable.push(true);
898 prunable.push(false);
899
900 prunable.pop_chunk();
902 }
903
904 #[test]
905 fn test_write_read_empty() {
906 let original: Prunable<4> = Prunable::new();
907 let encoded = original.encode();
908
909 let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
910 assert_eq!(decoded.len(), original.len());
911 assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
912 assert!(decoded.is_empty());
913 }
914
915 #[test]
916 fn test_write_read_non_empty() {
917 let mut original: Prunable<4> = Prunable::new();
918 original.push_chunk(&hex!("0xAABBCCDD"));
919 original.push_chunk(&hex!("0x11223344"));
920 original.push(true);
921 original.push(false);
922 original.push(true);
923
924 let encoded = original.encode();
925 let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
926
927 assert_eq!(decoded.len(), original.len());
928 assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
929 assert_eq!(decoded.len(), 67);
930
931 for i in 0..original.len() {
933 assert_eq!(decoded.get_bit(i), original.get_bit(i));
934 }
935 }
936
937 #[test]
938 fn test_write_read_with_pruning() {
939 let mut original: Prunable<4> = Prunable::new();
940 original.push_chunk(&hex!("0x01020304"));
941 original.push_chunk(&hex!("0x05060708"));
942 original.push_chunk(&hex!("0x090A0B0C"));
943
944 original.prune_to_bit(32);
946 assert_eq!(original.pruned_chunks(), 1);
947 assert_eq!(original.len(), 96);
948
949 let encoded = original.encode();
950 let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
951
952 assert_eq!(decoded.len(), original.len());
953 assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
954 assert_eq!(decoded.pruned_chunks(), 1);
955 assert_eq!(decoded.len(), 96);
956
957 assert_eq!(decoded.get_chunk_containing(32), &hex!("0x05060708"));
959 assert_eq!(decoded.get_chunk_containing(64), &hex!("0x090A0B0C"));
960 }
961
962 #[test]
963 fn test_write_read_with_pruning_2() {
964 let mut original: Prunable<4> = Prunable::new();
965
966 for i in 0..5 {
968 let chunk = [
969 (i * 4) as u8,
970 (i * 4 + 1) as u8,
971 (i * 4 + 2) as u8,
972 (i * 4 + 3) as u8,
973 ];
974 original.push_chunk(&chunk);
975 }
976
977 original.prune_to_bit(96); assert_eq!(original.pruned_chunks(), 3);
980 assert_eq!(original.len(), 160);
981
982 let encoded = original.encode();
983 let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
984
985 assert_eq!(decoded.len(), original.len());
986 assert_eq!(decoded.pruned_chunks(), 3);
987
988 for i in 96..original.len() {
990 assert_eq!(decoded.get_bit(i), original.get_bit(i));
991 }
992 }
993
994 #[test]
995 fn test_encode_size_matches() {
996 let mut prunable: Prunable<4> = Prunable::new();
997 prunable.push_chunk(&[1, 2, 3, 4]);
998 prunable.push_chunk(&[5, 6, 7, 8]);
999 prunable.push(true);
1000
1001 let size = prunable.encode_size();
1002 let encoded = prunable.encode();
1003
1004 assert_eq!(size, encoded.len());
1005 }
1006
1007 #[test]
1008 fn test_encode_size_with_pruning() {
1009 let mut prunable: Prunable<4> = Prunable::new();
1010 prunable.push_chunk(&[1, 2, 3, 4]);
1011 prunable.push_chunk(&[5, 6, 7, 8]);
1012 prunable.push_chunk(&[9, 10, 11, 12]);
1013
1014 prunable.prune_to_bit(32);
1015
1016 let size = prunable.encode_size();
1017 let encoded = prunable.encode();
1018
1019 assert_eq!(size, encoded.len());
1020 }
1021
1022 #[test]
1023 fn test_read_max_len_validation() {
1024 let mut original: Prunable<4> = Prunable::new();
1025 for _ in 0..10 {
1026 original.push(true);
1027 }
1028
1029 let encoded = original.encode();
1030
1031 assert!(Prunable::<4>::read_cfg(&mut encoded.as_ref(), &100).is_ok());
1033
1034 let result = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &5);
1036 assert!(result.is_err());
1037 }
1038
1039 #[test]
1040 fn test_codec_roundtrip_different_chunk_sizes() {
1041 let mut p8: Prunable<8> = Prunable::new();
1043 let mut p16: Prunable<16> = Prunable::new();
1044 let mut p32: Prunable<32> = Prunable::new();
1045
1046 for i in 0..100 {
1047 let bit = i % 3 == 0;
1048 p8.push(bit);
1049 p16.push(bit);
1050 p32.push(bit);
1051 }
1052
1053 let encoded8 = p8.encode();
1055 let decoded8 = Prunable::<8>::read_cfg(&mut encoded8.as_ref(), &u64::MAX).unwrap();
1056 assert_eq!(decoded8.len(), p8.len());
1057
1058 let encoded16 = p16.encode();
1059 let decoded16 = Prunable::<16>::read_cfg(&mut encoded16.as_ref(), &u64::MAX).unwrap();
1060 assert_eq!(decoded16.len(), p16.len());
1061
1062 let encoded32 = p32.encode();
1063 let decoded32 = Prunable::<32>::read_cfg(&mut encoded32.as_ref(), &u64::MAX).unwrap();
1064 assert_eq!(decoded32.len(), p32.len());
1065 }
1066
1067 #[test]
1068 fn test_read_pruned_chunks_overflow() {
1069 let mut buf = BytesMut::new();
1070
1071 let overflowing_pruned_chunks = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) + 1;
1073 overflowing_pruned_chunks.write(&mut buf);
1074
1075 0u64.write(&mut buf); let result = Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX);
1080 match result {
1081 Err(CodecError::Invalid(type_name, msg)) => {
1082 assert_eq!(type_name, "Prunable");
1083 assert_eq!(
1084 msg,
1085 "pruned_chunks would overflow when computing pruned_bits"
1086 );
1087 }
1088 Ok(_) => panic!("Expected error but got Ok"),
1089 Err(e) => panic!("Expected Invalid error for pruned_bits overflow, got: {e:?}"),
1090 }
1091 }
1092
1093 #[test]
1094 fn test_read_total_length_overflow() {
1095 let mut buf = BytesMut::new();
1096
1097 let max_safe_pruned_chunks = u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS;
1099 let pruned_bits = max_safe_pruned_chunks * Prunable::<4>::CHUNK_SIZE_BITS;
1100
1101 let remaining_space = u64::MAX - pruned_bits;
1103 let bitmap_len = remaining_space + 1; max_safe_pruned_chunks.write(&mut buf);
1107 bitmap_len.write(&mut buf);
1108
1109 let num_chunks = bitmap_len.div_ceil(Prunable::<4>::CHUNK_SIZE_BITS);
1111 for _ in 0..(num_chunks * 4) {
1112 0u8.write(&mut buf);
1113 }
1114
1115 let result = Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX);
1117 match result {
1118 Err(CodecError::Invalid(type_name, msg)) => {
1119 assert_eq!(type_name, "Prunable");
1120 assert_eq!(
1121 msg,
1122 "total bitmap length (pruned + unpruned) would overflow u64"
1123 );
1124 }
1125 Ok(_) => panic!("Expected error but got Ok"),
1126 Err(e) => panic!("Expected Invalid error for total length overflow, got: {e:?}"),
1127 }
1128 }
1129
1130 #[test]
1131 fn test_is_chunk_aligned() {
1132 let prunable: Prunable<4> = Prunable::new();
1134 assert!(prunable.is_chunk_aligned());
1135
1136 let mut prunable: Prunable<4> = Prunable::new();
1138 for i in 1..=32 {
1139 prunable.push(i % 2 == 0);
1140 if i == 32 {
1141 assert!(prunable.is_chunk_aligned()); } else {
1143 assert!(!prunable.is_chunk_aligned()); }
1145 }
1146
1147 for i in 33..=64 {
1149 prunable.push(i % 2 == 0);
1150 if i == 64 {
1151 assert!(prunable.is_chunk_aligned()); } else {
1153 assert!(!prunable.is_chunk_aligned()); }
1155 }
1156
1157 let mut prunable: Prunable<4> = Prunable::new();
1159 assert!(prunable.is_chunk_aligned());
1160 prunable.push_chunk(&[1, 2, 3, 4]);
1161 assert!(prunable.is_chunk_aligned()); prunable.push_chunk(&[5, 6, 7, 8]);
1163 assert!(prunable.is_chunk_aligned()); prunable.push(true);
1165 assert!(!prunable.is_chunk_aligned()); let mut prunable: Prunable<4> = Prunable::new();
1169 prunable.push_chunk(&[1, 2, 3, 4]);
1170 prunable.push_chunk(&[5, 6, 7, 8]);
1171 prunable.push_chunk(&[9, 10, 11, 12]);
1172 assert!(prunable.is_chunk_aligned()); prunable.prune_to_bit(32);
1176 assert!(prunable.is_chunk_aligned());
1177 assert_eq!(prunable.len(), 96);
1178
1179 prunable.push(true);
1181 prunable.push(false);
1182 assert!(!prunable.is_chunk_aligned()); prunable.prune_to_bit(64);
1186 assert!(!prunable.is_chunk_aligned()); let prunable: Prunable<4> = Prunable::new_with_pruned_chunks(2).unwrap();
1190 assert!(prunable.is_chunk_aligned()); let mut prunable: Prunable<4> = Prunable::new_with_pruned_chunks(1).unwrap();
1193 assert!(prunable.is_chunk_aligned()); prunable.push(true);
1195 assert!(!prunable.is_chunk_aligned()); let mut prunable: Prunable<4> = Prunable::new();
1199 for _ in 0..4 {
1200 prunable.push_byte(0xFF);
1201 }
1202 assert!(prunable.is_chunk_aligned()); prunable.pop();
1206 assert!(!prunable.is_chunk_aligned()); for _ in 0..31 {
1210 prunable.pop();
1211 }
1212 assert!(prunable.is_chunk_aligned()); }
1214
    // Codec conformance tests; only compiled when the `arbitrary` feature is enabled.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        // Instantiates the conformance test macro for `Prunable` with 16-byte chunks.
        commonware_conformance::conformance_tests! {
            CodecConformance<Prunable<16>>,
        }
    }
1224}