Skip to main content

commonware_utils/bitmap/
prunable.rs

1//! A prunable wrapper around BitMap that tracks pruned chunks.
2
3use super::BitMap;
4use bytes::{Buf, BufMut};
5use commonware_codec::{EncodeSize, Error as CodecError, Read, ReadExt, Write};
6use thiserror::Error;
7
/// Errors that can occur when working with a prunable bitmap.
///
/// Currently the only failure mode is an out-of-range pruned-chunk count
/// passed to `Prunable::new_with_pruned_chunks`.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum Error {
    /// The provided pruned_chunks value would overflow: the number of pruned
    /// bits (`pruned_chunks * CHUNK_SIZE_BITS`) cannot be represented as a `u64`.
    #[error("pruned_chunks * CHUNK_SIZE_BITS overflows u64")]
    PrunedChunksOverflow,
}
15
/// A prunable bitmap that stores data in chunks of N bytes.
///
/// Bit and chunk indices exposed by this type are absolute: pruned bits still
/// count towards `len()` and towards every index, even though only the number
/// of pruned chunks is remembered and their contents are no longer stored.
///
/// # Panics
///
/// Operations panic if `bit / CHUNK_SIZE_BITS > usize::MAX`. On 32-bit systems
/// with N=32, this occurs at bit >= 1,099,511,627,776.
#[derive(Clone, Debug)]
pub struct Prunable<const N: usize> {
    /// The underlying BitMap storing the actual bits.
    ///
    /// Index 0 of this inner bitmap corresponds to absolute bit
    /// `pruned_chunks * CHUNK_SIZE_BITS`.
    bitmap: BitMap<N>,

    /// The number of bitmap chunks that have been pruned.
    ///
    /// # Invariant
    ///
    /// Must satisfy: `pruned_chunks as u64 * CHUNK_SIZE_BITS + bitmap.len() <= u64::MAX`
    pruned_chunks: usize,
}
34
35impl<const N: usize> Prunable<N> {
36    /// The size of a chunk in bits.
37    pub const CHUNK_SIZE_BITS: u64 = BitMap::<N>::CHUNK_SIZE_BITS;
38
39    /* Constructors */
40
41    /// Create a new empty prunable bitmap.
42    pub const fn new() -> Self {
43        Self {
44            bitmap: BitMap::new(),
45            pruned_chunks: 0,
46        }
47    }
48
49    /// Create a new empty prunable bitmap with the given number of pruned chunks.
50    ///
51    /// # Errors
52    ///
53    /// Returns an error if `pruned_chunks` violates the invariant that
54    /// `pruned_chunks as u64 * CHUNK_SIZE_BITS` must not overflow u64.
55    pub fn new_with_pruned_chunks(pruned_chunks: usize) -> Result<Self, Error> {
56        // Validate the invariant: pruned_chunks * CHUNK_SIZE_BITS must fit in u64
57        let pruned_chunks_u64 = pruned_chunks as u64;
58        pruned_chunks_u64
59            .checked_mul(Self::CHUNK_SIZE_BITS)
60            .ok_or(Error::PrunedChunksOverflow)?;
61
62        Ok(Self {
63            bitmap: BitMap::new(),
64            pruned_chunks,
65        })
66    }
67
68    /* Length */
69
70    /// Return the number of bits in the bitmap, irrespective of any pruning.
71    #[inline]
72    pub const fn len(&self) -> u64 {
73        let pruned_bits = (self.pruned_chunks as u64)
74            .checked_mul(Self::CHUNK_SIZE_BITS)
75            .expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64");
76
77        pruned_bits
78            .checked_add(self.bitmap.len())
79            .expect("invariant violated: pruned_bits + bitmap.len() overflows u64")
80    }
81
82    /// Return true if the bitmap is empty.
83    #[inline]
84    pub const fn is_empty(&self) -> bool {
85        self.len() == 0
86    }
87
88    /// Returns true if the bitmap length is aligned to a chunk boundary.
89    #[inline]
90    pub const fn is_chunk_aligned(&self) -> bool {
91        self.len().is_multiple_of(Self::CHUNK_SIZE_BITS)
92    }
93
94    /// Return the number of chunks, including pruned chunks.
95    #[inline]
96    pub fn chunks_len(&self) -> usize {
97        self.pruned_chunks + self.bitmap.chunks_len()
98    }
99
100    /// Return the number of pruned chunks.
101    #[inline]
102    pub const fn pruned_chunks(&self) -> usize {
103        self.pruned_chunks
104    }
105
106    /// Return the number of complete (fully filled) chunks in the bitmap, accounting
107    /// for pruning. A chunk is complete when all CHUNK_SIZE_BITS bits have been written.
108    #[inline]
109    pub fn complete_chunks(&self) -> usize {
110        self.pruned_chunks
111            + self
112                .bitmap
113                .chunks_len()
114                .saturating_sub(if self.is_chunk_aligned() { 0 } else { 1 })
115    }
116
117    /// Return the number of pruned bits.
118    #[inline]
119    pub const fn pruned_bits(&self) -> u64 {
120        (self.pruned_chunks as u64)
121            .checked_mul(Self::CHUNK_SIZE_BITS)
122            .expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64")
123    }
124
125    /* Getters */
126
127    /// Get the value of a bit.
128    ///
129    /// # Warning
130    ///
131    /// Panics if the bit doesn't exist or has been pruned.
132    #[inline]
133    pub fn get_bit(&self, bit: u64) -> bool {
134        let chunk_num = Self::to_chunk_index(bit);
135        assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
136
137        // Adjust bit to account for pruning
138        self.bitmap.get(bit - self.pruned_bits())
139    }
140
141    /// Returns the bitmap chunk containing the specified bit.
142    ///
143    /// # Warning
144    ///
145    /// Panics if the bit doesn't exist or has been pruned.
146    #[inline]
147    pub fn get_chunk_containing(&self, bit: u64) -> &[u8; N] {
148        let chunk_num = Self::to_chunk_index(bit);
149        assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
150
151        // Adjust bit to account for pruning
152        self.bitmap.get_chunk_containing(bit - self.pruned_bits())
153    }
154
155    /// Get the value of a bit from its chunk.
156    /// `bit` is an index into the entire bitmap, not just the chunk.
157    #[inline]
158    pub const fn get_bit_from_chunk(chunk: &[u8; N], bit: u64) -> bool {
159        BitMap::<N>::get_bit_from_chunk(chunk, bit)
160    }
161
162    /// Return the last chunk of the bitmap and its size in bits.
163    #[inline]
164    pub fn last_chunk(&self) -> (&[u8; N], u64) {
165        self.bitmap.last_chunk()
166    }
167
168    /* Setters */
169
170    /// Set the value of the given bit.
171    ///
172    /// # Warning
173    ///
174    /// Panics if the bit doesn't exist or has been pruned.
175    pub fn set_bit(&mut self, bit: u64, value: bool) {
176        let chunk_num = Self::to_chunk_index(bit);
177        assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
178
179        // Adjust bit to account for pruning
180        self.bitmap.set(bit - self.pruned_bits(), value);
181    }
182
183    /// Add a single bit to the end of the bitmap.
184    pub fn push(&mut self, bit: bool) {
185        self.bitmap.push(bit);
186    }
187
188    /// Remove and return the last bit from the bitmap.
189    ///
190    /// # Warning
191    ///
192    /// Panics if the bitmap is empty.
193    pub fn pop(&mut self) -> bool {
194        self.bitmap.pop()
195    }
196
197    /// Add a byte to the bitmap.
198    ///
199    /// # Warning
200    ///
201    /// Panics if self.next_bit is not byte aligned.
202    pub fn push_byte(&mut self, byte: u8) {
203        self.bitmap.push_byte(byte);
204    }
205
206    /// Add a chunk of bits to the bitmap.
207    ///
208    /// # Warning
209    ///
210    /// Panics if self.next_bit is not chunk aligned.
211    pub fn push_chunk(&mut self, chunk: &[u8; N]) {
212        self.bitmap.push_chunk(chunk);
213    }
214
215    /// Remove and return the last complete chunk from the bitmap.
216    ///
217    /// # Warning
218    ///
219    /// Panics if the bitmap has fewer than `CHUNK_SIZE_BITS` bits or if not chunk-aligned.
220    pub fn pop_chunk(&mut self) -> [u8; N] {
221        self.bitmap.pop_chunk()
222    }
223
224    /* Pruning */
225
226    /// Prune all complete chunks before the chunk containing the given bit.
227    ///
228    /// The chunk containing `bit` and all subsequent chunks are retained. All chunks
229    /// before it are pruned.
230    ///
231    /// If `bit` equals the bitmap length, this prunes all complete chunks while retaining
232    /// the empty trailing chunk, preparing the bitmap for appending new data.
233    ///
234    /// # Warning
235    ///
236    /// Panics if `bit` is greater than the bitmap length.
237    pub fn prune_to_bit(&mut self, bit: u64) {
238        assert!(
239            bit <= self.len(),
240            "bit {} out of bounds (len: {})",
241            bit,
242            self.len()
243        );
244
245        let chunk = Self::to_chunk_index(bit);
246        if chunk < self.pruned_chunks {
247            return;
248        }
249
250        let chunks_to_prune = chunk - self.pruned_chunks;
251        self.bitmap.prune_chunks(chunks_to_prune);
252        self.pruned_chunks = chunk;
253    }
254
255    /* Indexing Helpers */
256
257    /// Convert a bit into a bitmask for the byte containing that bit.
258    #[inline]
259    pub const fn chunk_byte_bitmask(bit: u64) -> u8 {
260        BitMap::<N>::chunk_byte_bitmask(bit)
261    }
262
263    /// Convert a bit into the index of the byte within a chunk containing the bit.
264    #[inline]
265    pub const fn chunk_byte_offset(bit: u64) -> usize {
266        BitMap::<N>::chunk_byte_offset(bit)
267    }
268
269    /// Convert a bit into the index of the chunk it belongs to.
270    ///
271    /// # Panics
272    ///
273    /// Panics if `bit / CHUNK_SIZE_BITS > usize::MAX`.
274    #[inline]
275    pub fn to_chunk_index(bit: u64) -> usize {
276        BitMap::<N>::to_chunk_index(bit)
277    }
278
279    /// Get a reference to a chunk by its absolute index (includes pruned chunks).
280    ///
281    /// # Panics
282    ///
283    /// Panics if the chunk has been pruned or is out of bounds.
284    #[inline]
285    pub fn get_chunk(&self, chunk: usize) -> &[u8; N] {
286        assert!(
287            chunk >= self.pruned_chunks,
288            "chunk {chunk} is pruned (pruned_chunks: {})",
289            self.pruned_chunks
290        );
291        self.bitmap.get_chunk(chunk - self.pruned_chunks)
292    }
293
294    /// Overwrite a chunk's data by its absolute chunk index.
295    ///
296    /// # Panics
297    ///
298    /// Panics if the chunk is pruned or out of bounds.
299    pub(super) fn set_chunk_by_index(&mut self, chunk_index: usize, chunk_data: &[u8; N]) {
300        assert!(
301            chunk_index >= self.pruned_chunks,
302            "cannot set pruned chunk {chunk_index} (pruned_chunks: {})",
303            self.pruned_chunks
304        );
305        let bitmap_chunk_idx = chunk_index - self.pruned_chunks;
306        self.bitmap.set_chunk_by_index(bitmap_chunk_idx, chunk_data);
307    }
308
309    /// Unprune chunks by prepending them back to the front of the bitmap.
310    ///
311    /// The caller must provide chunks in **reverse** order: to restore chunks with
312    /// indices [0, 1, 2], pass them as [2, 1, 0]. This is necessary because each chunk
313    /// is prepended to the front, so the last chunk provided becomes the first chunk
314    /// in the bitmap.
315    ///
316    /// # Panics
317    ///
318    /// Panics if chunks.len() > self.pruned_chunks.
319    pub(super) fn unprune_chunks(&mut self, chunks: &[[u8; N]]) {
320        assert!(
321            chunks.len() <= self.pruned_chunks,
322            "cannot unprune {} chunks (only {} pruned)",
323            chunks.len(),
324            self.pruned_chunks
325        );
326
327        for chunk in chunks.iter() {
328            self.bitmap.prepend_chunk(chunk);
329        }
330
331        self.pruned_chunks -= chunks.len();
332    }
333}
334
335impl<const N: usize> Default for Prunable<N> {
336    fn default() -> Self {
337        Self::new()
338    }
339}
340
341impl<const N: usize> Write for Prunable<N> {
342    fn write(&self, buf: &mut impl BufMut) {
343        (self.pruned_chunks as u64).write(buf);
344        self.bitmap.write(buf);
345    }
346}
347
348impl<const N: usize> Read for Prunable<N> {
349    // Max length for the unpruned portion of the bitmap.
350    type Cfg = u64;
351
352    fn read_cfg(buf: &mut impl Buf, max_len: &Self::Cfg) -> Result<Self, CodecError> {
353        let pruned_chunks_u64 = u64::read(buf)?;
354
355        // Validate that pruned_chunks * CHUNK_SIZE_BITS doesn't overflow u64
356        let pruned_bits =
357            pruned_chunks_u64
358                .checked_mul(Self::CHUNK_SIZE_BITS)
359                .ok_or(CodecError::Invalid(
360                    "Prunable",
361                    "pruned_chunks would overflow when computing pruned_bits",
362                ))?;
363
364        let pruned_chunks = usize::try_from(pruned_chunks_u64)
365            .map_err(|_| CodecError::Invalid("Prunable", "pruned_chunks doesn't fit in usize"))?;
366
367        let bitmap = BitMap::<N>::read_cfg(buf, max_len)?;
368
369        // Validate that total length (pruned_bits + bitmap.len()) doesn't overflow u64
370        pruned_bits
371            .checked_add(bitmap.len())
372            .ok_or(CodecError::Invalid(
373                "Prunable",
374                "total bitmap length (pruned + unpruned) would overflow u64",
375            ))?;
376
377        Ok(Self {
378            bitmap,
379            pruned_chunks,
380        })
381    }
382}
383
384impl<const N: usize> EncodeSize for Prunable<N> {
385    fn encode_size(&self) -> usize {
386        (self.pruned_chunks as u64).encode_size() + self.bitmap.encode_size()
387    }
388}
389
#[cfg(feature = "arbitrary")]
impl<const N: usize> arbitrary::Arbitrary<'_> for Prunable<N> {
    /// Generate an arbitrary unpruned bitmap, then prune it up to an arbitrary
    /// bit position within `0..=len`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        let inner = BitMap::<N>::arbitrary(u)?;
        let mut prunable = Self {
            bitmap: inner,
            pruned_chunks: 0,
        };

        let cutoff = u.int_in_range(0..=prunable.len())?;
        prunable.prune_to_bit(cutoff);

        Ok(prunable)
    }
}
402
403#[cfg(test)]
404mod tests {
405    use super::*;
406    use crate::hex;
407    use bytes::BytesMut;
408    use commonware_codec::Encode;
409
410    #[test]
411    fn test_new() {
412        let prunable: Prunable<32> = Prunable::new();
413        assert_eq!(prunable.len(), 0);
414        assert_eq!(prunable.pruned_bits(), 0);
415        assert_eq!(prunable.pruned_chunks(), 0);
416        assert!(prunable.is_empty());
417        assert_eq!(prunable.chunks_len(), 0); // No chunks when empty
418    }
419
420    #[test]
421    fn test_new_with_pruned_chunks() {
422        let prunable: Prunable<2> = Prunable::new_with_pruned_chunks(1).unwrap();
423        assert_eq!(prunable.len(), 16);
424        assert_eq!(prunable.pruned_bits(), 16);
425        assert_eq!(prunable.pruned_chunks(), 1);
426        assert_eq!(prunable.chunks_len(), 1);
427    }
428
429    #[test]
430    fn test_new_with_pruned_chunks_overflow() {
431        // Try to create a Prunable with pruned_chunks that would overflow
432        let overflowing_pruned_chunks = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) as usize + 1;
433        let result = Prunable::<4>::new_with_pruned_chunks(overflowing_pruned_chunks);
434
435        assert!(matches!(result, Err(Error::PrunedChunksOverflow)));
436    }
437
438    #[test]
439    fn test_push_and_get_bits() {
440        let mut prunable: Prunable<4> = Prunable::new();
441
442        // Add some bits
443        prunable.push(true);
444        prunable.push(false);
445        prunable.push(true);
446
447        assert_eq!(prunable.len(), 3);
448        assert!(!prunable.is_empty());
449        assert!(prunable.get_bit(0));
450        assert!(!prunable.get_bit(1));
451        assert!(prunable.get_bit(2));
452    }
453
454    #[test]
455    fn test_push_byte() {
456        let mut prunable: Prunable<4> = Prunable::new();
457
458        // Add a byte
459        prunable.push_byte(0xFF);
460        assert_eq!(prunable.len(), 8);
461
462        // All bits should be set
463        for i in 0..8 {
464            assert!(prunable.get_bit(i as u64));
465        }
466
467        prunable.push_byte(0x00);
468        assert_eq!(prunable.len(), 16);
469
470        // Next 8 bits should be clear
471        for i in 8..16 {
472            assert!(!prunable.get_bit(i as u64));
473        }
474    }
475
476    #[test]
477    fn test_push_chunk() {
478        let mut prunable: Prunable<4> = Prunable::new();
479        let chunk = hex!("0xAABBCCDD");
480
481        prunable.push_chunk(&chunk);
482        assert_eq!(prunable.len(), 32); // 4 bytes * 8 bits
483
484        let retrieved_chunk = prunable.get_chunk_containing(0);
485        assert_eq!(retrieved_chunk, &chunk);
486    }
487
488    #[test]
489    fn test_set_bit() {
490        let mut prunable: Prunable<4> = Prunable::new();
491
492        // Add some bits
493        prunable.push(false);
494        prunable.push(false);
495        prunable.push(false);
496
497        assert!(!prunable.get_bit(1));
498
499        // Set a bit
500        prunable.set_bit(1, true);
501        assert!(prunable.get_bit(1));
502
503        // Set it back
504        prunable.set_bit(1, false);
505        assert!(!prunable.get_bit(1));
506    }
507
508    #[test]
509    fn test_pruning_basic() {
510        let mut prunable: Prunable<4> = Prunable::new();
511
512        // Add multiple chunks (4 bytes each)
513        let chunk1 = hex!("0x01020304");
514        let chunk2 = hex!("0x05060708");
515        let chunk3 = hex!("0x090A0B0C");
516
517        prunable.push_chunk(&chunk1);
518        prunable.push_chunk(&chunk2);
519        prunable.push_chunk(&chunk3);
520
521        assert_eq!(prunable.len(), 96); // 3 chunks * 32 bits
522        assert_eq!(prunable.pruned_chunks(), 0);
523
524        // Prune to second chunk (bit 32 is start of second chunk)
525        prunable.prune_to_bit(32);
526        assert_eq!(prunable.pruned_chunks(), 1);
527        assert_eq!(prunable.pruned_bits(), 32);
528        assert_eq!(prunable.len(), 96); // Total count unchanged
529
530        // Can still access non-pruned bits
531        assert_eq!(prunable.get_chunk_containing(32), &chunk2);
532        assert_eq!(prunable.get_chunk_containing(64), &chunk3);
533
534        // Prune to third chunk
535        prunable.prune_to_bit(64);
536        assert_eq!(prunable.pruned_chunks(), 2);
537        assert_eq!(prunable.pruned_bits(), 64);
538        assert_eq!(prunable.len(), 96);
539
540        // Can still access the third chunk
541        assert_eq!(prunable.get_chunk_containing(64), &chunk3);
542    }
543
544    #[test]
545    #[should_panic(expected = "bit pruned")]
546    fn test_get_pruned_bit_panics() {
547        let mut prunable: Prunable<4> = Prunable::new();
548
549        // Add two chunks
550        prunable.push_chunk(&[1, 2, 3, 4]);
551        prunable.push_chunk(&[5, 6, 7, 8]);
552
553        // Prune first chunk
554        prunable.prune_to_bit(32);
555
556        // Try to access pruned bit - should panic
557        prunable.get_bit(0);
558    }
559
560    #[test]
561    #[should_panic(expected = "bit pruned")]
562    fn test_get_pruned_chunk_panics() {
563        let mut prunable: Prunable<4> = Prunable::new();
564
565        // Add two chunks
566        prunable.push_chunk(&[1, 2, 3, 4]);
567        prunable.push_chunk(&[5, 6, 7, 8]);
568
569        // Prune first chunk
570        prunable.prune_to_bit(32);
571
572        // Try to access pruned chunk - should panic
573        prunable.get_chunk_containing(0);
574    }
575
576    #[test]
577    #[should_panic(expected = "bit pruned")]
578    fn test_set_pruned_bit_panics() {
579        let mut prunable: Prunable<4> = Prunable::new();
580
581        // Add two chunks
582        prunable.push_chunk(&[1, 2, 3, 4]);
583        prunable.push_chunk(&[5, 6, 7, 8]);
584
585        // Prune first chunk
586        prunable.prune_to_bit(32);
587
588        // Try to set pruned bit - should panic
589        prunable.set_bit(0, true);
590    }
591
592    #[test]
593    #[should_panic(expected = "bit 25 out of bounds (len: 24)")]
594    fn test_prune_to_bit_out_of_bounds() {
595        let mut prunable: Prunable<1> = Prunable::new();
596
597        // Add 3 bytes (24 bits total)
598        prunable.push_byte(1);
599        prunable.push_byte(2);
600        prunable.push_byte(3);
601
602        // Try to prune to a bit beyond the bitmap
603        prunable.prune_to_bit(25);
604    }
605
606    #[test]
607    fn test_pruning_with_partial_chunk() {
608        let mut prunable: Prunable<4> = Prunable::new();
609
610        // Add two full chunks and some partial bits
611        prunable.push_chunk(&[0xFF; 4]);
612        prunable.push_chunk(&[0xAA; 4]);
613        prunable.push(true);
614        prunable.push(false);
615        prunable.push(true);
616
617        assert_eq!(prunable.len(), 67); // 64 + 3 bits
618
619        // Prune to second chunk
620        prunable.prune_to_bit(32);
621        assert_eq!(prunable.pruned_chunks(), 1);
622        assert_eq!(prunable.len(), 67);
623
624        // Can still access the partial bits
625        assert!(prunable.get_bit(64));
626        assert!(!prunable.get_bit(65));
627        assert!(prunable.get_bit(66));
628    }
629
630    #[test]
631    fn test_prune_idempotent() {
632        let mut prunable: Prunable<4> = Prunable::new();
633
634        // Add chunks
635        prunable.push_chunk(&[1, 2, 3, 4]);
636        prunable.push_chunk(&[5, 6, 7, 8]);
637
638        // Prune to bit 32
639        prunable.prune_to_bit(32);
640        assert_eq!(prunable.pruned_chunks(), 1);
641
642        // Pruning to same or earlier point should be no-op
643        prunable.prune_to_bit(32);
644        assert_eq!(prunable.pruned_chunks(), 1);
645
646        prunable.prune_to_bit(16);
647        assert_eq!(prunable.pruned_chunks(), 1);
648    }
649
650    #[test]
651    fn test_push_after_pruning() {
652        let mut prunable: Prunable<4> = Prunable::new();
653
654        // Add initial chunks
655        prunable.push_chunk(&[1, 2, 3, 4]);
656        prunable.push_chunk(&[5, 6, 7, 8]);
657
658        // Prune first chunk
659        prunable.prune_to_bit(32);
660        assert_eq!(prunable.len(), 64);
661        assert_eq!(prunable.pruned_chunks(), 1);
662
663        // Add more data
664        prunable.push_chunk(&[9, 10, 11, 12]);
665        assert_eq!(prunable.len(), 96); // 32 pruned + 64 active
666
667        // New chunk should be accessible
668        assert_eq!(prunable.get_chunk_containing(64), &[9, 10, 11, 12]);
669    }
670
671    #[test]
672    fn test_chunk_calculations() {
673        // Test chunk_num calculation
674        assert_eq!(Prunable::<4>::to_chunk_index(0), 0);
675        assert_eq!(Prunable::<4>::to_chunk_index(31), 0);
676        assert_eq!(Prunable::<4>::to_chunk_index(32), 1);
677        assert_eq!(Prunable::<4>::to_chunk_index(63), 1);
678        assert_eq!(Prunable::<4>::to_chunk_index(64), 2);
679
680        // Test chunk_byte_offset
681        assert_eq!(Prunable::<4>::chunk_byte_offset(0), 0);
682        assert_eq!(Prunable::<4>::chunk_byte_offset(8), 1);
683        assert_eq!(Prunable::<4>::chunk_byte_offset(16), 2);
684        assert_eq!(Prunable::<4>::chunk_byte_offset(24), 3);
685        assert_eq!(Prunable::<4>::chunk_byte_offset(32), 0); // Wraps to next chunk
686
687        // Test chunk_byte_bitmask
688        assert_eq!(Prunable::<4>::chunk_byte_bitmask(0), 0b00000001);
689        assert_eq!(Prunable::<4>::chunk_byte_bitmask(1), 0b00000010);
690        assert_eq!(Prunable::<4>::chunk_byte_bitmask(7), 0b10000000);
691        assert_eq!(Prunable::<4>::chunk_byte_bitmask(8), 0b00000001); // Next byte
692    }
693
694    #[test]
695    fn test_last_chunk_with_pruning() {
696        let mut prunable: Prunable<4> = Prunable::new();
697
698        // Add chunks
699        prunable.push_chunk(&[1, 2, 3, 4]);
700        prunable.push_chunk(&[5, 6, 7, 8]);
701        prunable.push(true);
702        prunable.push(false);
703
704        let (_, next_bit) = prunable.last_chunk();
705        assert_eq!(next_bit, 2);
706
707        // Store the chunk data for comparison
708        let chunk_data = *prunable.last_chunk().0;
709
710        // Pruning shouldn't affect last_chunk
711        prunable.prune_to_bit(32);
712        let (chunk2, next_bit2) = prunable.last_chunk();
713        assert_eq!(next_bit2, 2);
714        assert_eq!(&chunk_data, chunk2);
715    }
716
717    #[test]
718    fn test_different_chunk_sizes() {
719        // Test with different chunk sizes
720        let mut p8: Prunable<8> = Prunable::new();
721        let mut p16: Prunable<16> = Prunable::new();
722        let mut p32: Prunable<32> = Prunable::new();
723
724        // Add same pattern to each
725        for i in 0..10 {
726            p8.push(i % 2 == 0);
727            p16.push(i % 2 == 0);
728            p32.push(i % 2 == 0);
729        }
730
731        // All should have same bit count
732        assert_eq!(p8.len(), 10);
733        assert_eq!(p16.len(), 10);
734        assert_eq!(p32.len(), 10);
735
736        // All should have same bit values
737        for i in 0..10 {
738            let expected = i % 2 == 0;
739            if expected {
740                assert!(p8.get_bit(i));
741                assert!(p16.get_bit(i));
742                assert!(p32.get_bit(i));
743            } else {
744                assert!(!p8.get_bit(i));
745                assert!(!p16.get_bit(i));
746                assert!(!p32.get_bit(i));
747            }
748        }
749    }
750
751    #[test]
752    fn test_get_bit_from_chunk() {
753        let chunk: [u8; 4] = [0b10101010, 0b11001100, 0b11110000, 0b00001111];
754
755        // Test first byte
756        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 0));
757        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 1));
758        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 2));
759        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 3));
760
761        // Test second byte
762        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 8));
763        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 9));
764        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 10));
765        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 11));
766    }
767
768    #[test]
769    fn test_get_chunk() {
770        let mut prunable: Prunable<4> = Prunable::new();
771        let chunk1 = hex!("0x11223344");
772        let chunk2 = hex!("0x55667788");
773        let chunk3 = hex!("0x99AABBCC");
774
775        prunable.push_chunk(&chunk1);
776        prunable.push_chunk(&chunk2);
777        prunable.push_chunk(&chunk3);
778
779        // Before pruning
780        assert_eq!(prunable.get_chunk(0), &chunk1);
781        assert_eq!(prunable.get_chunk(1), &chunk2);
782        assert_eq!(prunable.get_chunk(2), &chunk3);
783
784        // After pruning
785        prunable.prune_to_bit(32);
786        assert_eq!(prunable.get_chunk(1), &chunk2);
787        assert_eq!(prunable.get_chunk(2), &chunk3);
788    }
789
790    #[test]
791    fn test_pop() {
792        let mut prunable: Prunable<4> = Prunable::new();
793
794        prunable.push(true);
795        prunable.push(false);
796        prunable.push(true);
797        assert_eq!(prunable.len(), 3);
798
799        assert!(prunable.pop());
800        assert_eq!(prunable.len(), 2);
801
802        assert!(!prunable.pop());
803        assert_eq!(prunable.len(), 1);
804
805        assert!(prunable.pop());
806        assert_eq!(prunable.len(), 0);
807        assert!(prunable.is_empty());
808
809        for i in 0..100 {
810            prunable.push(i % 3 == 0);
811        }
812        assert_eq!(prunable.len(), 100);
813
814        for i in (0..100).rev() {
815            let expected = i % 3 == 0;
816            assert_eq!(prunable.pop(), expected);
817            assert_eq!(prunable.len(), i);
818        }
819
820        assert!(prunable.is_empty());
821    }
822
823    #[test]
824    fn test_pop_chunk() {
825        let mut prunable: Prunable<4> = Prunable::new();
826        const CHUNK_SIZE: u64 = Prunable::<4>::CHUNK_SIZE_BITS;
827
828        // Test 1: Pop a single chunk and verify it returns the correct data
829        let chunk1 = hex!("0xAABBCCDD");
830        prunable.push_chunk(&chunk1);
831        assert_eq!(prunable.len(), CHUNK_SIZE);
832        let popped = prunable.pop_chunk();
833        assert_eq!(popped, chunk1);
834        assert_eq!(prunable.len(), 0);
835        assert!(prunable.is_empty());
836
837        // Test 2: Pop multiple chunks in reverse order
838        let chunk2 = hex!("0x11223344");
839        let chunk3 = hex!("0x55667788");
840        let chunk4 = hex!("0x99AABBCC");
841
842        prunable.push_chunk(&chunk2);
843        prunable.push_chunk(&chunk3);
844        prunable.push_chunk(&chunk4);
845        assert_eq!(prunable.len(), CHUNK_SIZE * 3);
846
847        assert_eq!(prunable.pop_chunk(), chunk4);
848        assert_eq!(prunable.len(), CHUNK_SIZE * 2);
849
850        assert_eq!(prunable.pop_chunk(), chunk3);
851        assert_eq!(prunable.len(), CHUNK_SIZE);
852
853        assert_eq!(prunable.pop_chunk(), chunk2);
854        assert_eq!(prunable.len(), 0);
855
856        // Test 3: Verify data integrity when popping chunks
857        prunable = Prunable::new();
858        let first_chunk = hex!("0xAABBCCDD");
859        let second_chunk = hex!("0x11223344");
860        prunable.push_chunk(&first_chunk);
861        prunable.push_chunk(&second_chunk);
862
863        // Pop the second chunk, verify it and that first chunk is intact
864        assert_eq!(prunable.pop_chunk(), second_chunk);
865        assert_eq!(prunable.len(), CHUNK_SIZE);
866
867        for i in 0..CHUNK_SIZE {
868            let byte_idx = (i / 8) as usize;
869            let bit_idx = i % 8;
870            let expected = (first_chunk[byte_idx] >> bit_idx) & 1 == 1;
871            assert_eq!(prunable.get_bit(i), expected);
872        }
873
874        assert_eq!(prunable.pop_chunk(), first_chunk);
875        assert_eq!(prunable.len(), 0);
876    }
877
878    #[test]
879    #[should_panic(expected = "cannot pop chunk when not chunk aligned")]
880    fn test_pop_chunk_not_aligned() {
881        let mut prunable: Prunable<4> = Prunable::new();
882
883        // Push a full chunk plus one bit
884        prunable.push_chunk(&[0xFF; 4]);
885        prunable.push(true);
886
887        // Should panic because not chunk-aligned
888        prunable.pop_chunk();
889    }
890
891    #[test]
892    #[should_panic(expected = "cannot pop chunk: bitmap has fewer than CHUNK_SIZE_BITS bits")]
893    fn test_pop_chunk_insufficient_bits() {
894        let mut prunable: Prunable<4> = Prunable::new();
895
896        // Push only a few bits (less than a full chunk)
897        prunable.push(true);
898        prunable.push(false);
899
900        // Should panic because we don't have a full chunk to pop
901        prunable.pop_chunk();
902    }
903
904    #[test]
905    fn test_write_read_empty() {
906        let original: Prunable<4> = Prunable::new();
907        let encoded = original.encode();
908
909        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
910        assert_eq!(decoded.len(), original.len());
911        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
912        assert!(decoded.is_empty());
913    }
914
915    #[test]
916    fn test_write_read_non_empty() {
917        let mut original: Prunable<4> = Prunable::new();
918        original.push_chunk(&hex!("0xAABBCCDD"));
919        original.push_chunk(&hex!("0x11223344"));
920        original.push(true);
921        original.push(false);
922        original.push(true);
923
924        let encoded = original.encode();
925        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
926
927        assert_eq!(decoded.len(), original.len());
928        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
929        assert_eq!(decoded.len(), 67);
930
931        // Verify all bits match
932        for i in 0..original.len() {
933            assert_eq!(decoded.get_bit(i), original.get_bit(i));
934        }
935    }
936
937    #[test]
938    fn test_write_read_with_pruning() {
939        let mut original: Prunable<4> = Prunable::new();
940        original.push_chunk(&hex!("0x01020304"));
941        original.push_chunk(&hex!("0x05060708"));
942        original.push_chunk(&hex!("0x090A0B0C"));
943
944        // Prune first chunk
945        original.prune_to_bit(32);
946        assert_eq!(original.pruned_chunks(), 1);
947        assert_eq!(original.len(), 96);
948
949        let encoded = original.encode();
950        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
951
952        assert_eq!(decoded.len(), original.len());
953        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
954        assert_eq!(decoded.pruned_chunks(), 1);
955        assert_eq!(decoded.len(), 96);
956
957        // Verify remaining chunks match
958        assert_eq!(decoded.get_chunk_containing(32), &hex!("0x05060708"));
959        assert_eq!(decoded.get_chunk_containing(64), &hex!("0x090A0B0C"));
960    }
961
962    #[test]
963    fn test_write_read_with_pruning_2() {
964        let mut original: Prunable<4> = Prunable::new();
965
966        // Add several chunks
967        for i in 0..5 {
968            let chunk = [
969                (i * 4) as u8,
970                (i * 4 + 1) as u8,
971                (i * 4 + 2) as u8,
972                (i * 4 + 3) as u8,
973            ];
974            original.push_chunk(&chunk);
975        }
976
977        // Keep only last two chunks
978        original.prune_to_bit(96); // Prune first 3 chunks
979        assert_eq!(original.pruned_chunks(), 3);
980        assert_eq!(original.len(), 160);
981
982        let encoded = original.encode();
983        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
984
985        assert_eq!(decoded.len(), original.len());
986        assert_eq!(decoded.pruned_chunks(), 3);
987
988        // Verify remaining accessible bits match
989        for i in 96..original.len() {
990            assert_eq!(decoded.get_bit(i), original.get_bit(i));
991        }
992    }
993
994    #[test]
995    fn test_encode_size_matches() {
996        let mut prunable: Prunable<4> = Prunable::new();
997        prunable.push_chunk(&[1, 2, 3, 4]);
998        prunable.push_chunk(&[5, 6, 7, 8]);
999        prunable.push(true);
1000
1001        let size = prunable.encode_size();
1002        let encoded = prunable.encode();
1003
1004        assert_eq!(size, encoded.len());
1005    }
1006
1007    #[test]
1008    fn test_encode_size_with_pruning() {
1009        let mut prunable: Prunable<4> = Prunable::new();
1010        prunable.push_chunk(&[1, 2, 3, 4]);
1011        prunable.push_chunk(&[5, 6, 7, 8]);
1012        prunable.push_chunk(&[9, 10, 11, 12]);
1013
1014        prunable.prune_to_bit(32);
1015
1016        let size = prunable.encode_size();
1017        let encoded = prunable.encode();
1018
1019        assert_eq!(size, encoded.len());
1020    }
1021
1022    #[test]
1023    fn test_read_max_len_validation() {
1024        let mut original: Prunable<4> = Prunable::new();
1025        for _ in 0..10 {
1026            original.push(true);
1027        }
1028
1029        let encoded = original.encode();
1030
1031        // Should succeed with sufficient max_len
1032        assert!(Prunable::<4>::read_cfg(&mut encoded.as_ref(), &100).is_ok());
1033
1034        // Should fail with insufficient max_len
1035        let result = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &5);
1036        assert!(result.is_err());
1037    }
1038
1039    #[test]
1040    fn test_codec_roundtrip_different_chunk_sizes() {
1041        // Test with different chunk sizes
1042        let mut p8: Prunable<8> = Prunable::new();
1043        let mut p16: Prunable<16> = Prunable::new();
1044        let mut p32: Prunable<32> = Prunable::new();
1045
1046        for i in 0..100 {
1047            let bit = i % 3 == 0;
1048            p8.push(bit);
1049            p16.push(bit);
1050            p32.push(bit);
1051        }
1052
1053        // Roundtrip each
1054        let encoded8 = p8.encode();
1055        let decoded8 = Prunable::<8>::read_cfg(&mut encoded8.as_ref(), &u64::MAX).unwrap();
1056        assert_eq!(decoded8.len(), p8.len());
1057
1058        let encoded16 = p16.encode();
1059        let decoded16 = Prunable::<16>::read_cfg(&mut encoded16.as_ref(), &u64::MAX).unwrap();
1060        assert_eq!(decoded16.len(), p16.len());
1061
1062        let encoded32 = p32.encode();
1063        let decoded32 = Prunable::<32>::read_cfg(&mut encoded32.as_ref(), &u64::MAX).unwrap();
1064        assert_eq!(decoded32.len(), p32.len());
1065    }
1066
1067    #[test]
1068    fn test_read_pruned_chunks_overflow() {
1069        let mut buf = BytesMut::new();
1070
1071        // Write a pruned_chunks value that would overflow when multiplied by CHUNK_SIZE_BITS
1072        let overflowing_pruned_chunks = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) + 1;
1073        overflowing_pruned_chunks.write(&mut buf);
1074
1075        // Write a valid bitmap (empty)
1076        0u64.write(&mut buf); // len = 0
1077
1078        // Try to read - should fail with overflow error
1079        let result = Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX);
1080        match result {
1081            Err(CodecError::Invalid(type_name, msg)) => {
1082                assert_eq!(type_name, "Prunable");
1083                assert_eq!(
1084                    msg,
1085                    "pruned_chunks would overflow when computing pruned_bits"
1086                );
1087            }
1088            Ok(_) => panic!("Expected error but got Ok"),
1089            Err(e) => panic!("Expected Invalid error for pruned_bits overflow, got: {e:?}"),
1090        }
1091    }
1092
1093    #[test]
1094    fn test_read_total_length_overflow() {
1095        let mut buf = BytesMut::new();
1096
1097        // Make pruned_bits as large as possible without overflowing
1098        let max_safe_pruned_chunks = u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS;
1099        let pruned_bits = max_safe_pruned_chunks * Prunable::<4>::CHUNK_SIZE_BITS;
1100
1101        // Make bitmap_len large enough that adding it overflows
1102        let remaining_space = u64::MAX - pruned_bits;
1103        let bitmap_len = remaining_space + 1; // Go over by 1 to trigger overflow
1104
1105        // Write the serialized data
1106        max_safe_pruned_chunks.write(&mut buf);
1107        bitmap_len.write(&mut buf);
1108
1109        // Write bitmap chunk data
1110        let num_chunks = bitmap_len.div_ceil(Prunable::<4>::CHUNK_SIZE_BITS);
1111        for _ in 0..(num_chunks * 4) {
1112            0u8.write(&mut buf);
1113        }
1114
1115        // Try to read - should fail because pruned_bits + bitmap_len overflows u64
1116        let result = Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX);
1117        match result {
1118            Err(CodecError::Invalid(type_name, msg)) => {
1119                assert_eq!(type_name, "Prunable");
1120                assert_eq!(
1121                    msg,
1122                    "total bitmap length (pruned + unpruned) would overflow u64"
1123                );
1124            }
1125            Ok(_) => panic!("Expected error but got Ok"),
1126            Err(e) => panic!("Expected Invalid error for total length overflow, got: {e:?}"),
1127        }
1128    }
1129
1130    #[test]
1131    fn test_is_chunk_aligned() {
1132        // Empty bitmap is chunk aligned
1133        let prunable: Prunable<4> = Prunable::new();
1134        assert!(prunable.is_chunk_aligned());
1135
1136        // Add bits one at a time and check alignment
1137        let mut prunable: Prunable<4> = Prunable::new();
1138        for i in 1..=32 {
1139            prunable.push(i % 2 == 0);
1140            if i == 32 {
1141                assert!(prunable.is_chunk_aligned()); // Exactly one chunk
1142            } else {
1143                assert!(!prunable.is_chunk_aligned()); // Partial chunk
1144            }
1145        }
1146
1147        // Add another full chunk
1148        for i in 33..=64 {
1149            prunable.push(i % 2 == 0);
1150            if i == 64 {
1151                assert!(prunable.is_chunk_aligned()); // Exactly two chunks
1152            } else {
1153                assert!(!prunable.is_chunk_aligned()); // Partial chunk
1154            }
1155        }
1156
1157        // Test with push_chunk
1158        let mut prunable: Prunable<4> = Prunable::new();
1159        assert!(prunable.is_chunk_aligned());
1160        prunable.push_chunk(&[1, 2, 3, 4]);
1161        assert!(prunable.is_chunk_aligned()); // 32 bits = 1 chunk
1162        prunable.push_chunk(&[5, 6, 7, 8]);
1163        assert!(prunable.is_chunk_aligned()); // 64 bits = 2 chunks
1164        prunable.push(true);
1165        assert!(!prunable.is_chunk_aligned()); // 65 bits = partial chunk
1166
1167        // Test alignment with pruning
1168        let mut prunable: Prunable<4> = Prunable::new();
1169        prunable.push_chunk(&[1, 2, 3, 4]);
1170        prunable.push_chunk(&[5, 6, 7, 8]);
1171        prunable.push_chunk(&[9, 10, 11, 12]);
1172        assert!(prunable.is_chunk_aligned()); // 96 bits = 3 chunks
1173
1174        // Prune first chunk - still aligned (64 bits remaining)
1175        prunable.prune_to_bit(32);
1176        assert!(prunable.is_chunk_aligned());
1177        assert_eq!(prunable.len(), 96);
1178
1179        // Add a partial chunk
1180        prunable.push(true);
1181        prunable.push(false);
1182        assert!(!prunable.is_chunk_aligned()); // 98 bits total
1183
1184        // Prune to align again
1185        prunable.prune_to_bit(64);
1186        assert!(!prunable.is_chunk_aligned()); // 98 bits total (34 bits remaining)
1187
1188        // Test with new_with_pruned_chunks
1189        let prunable: Prunable<4> = Prunable::new_with_pruned_chunks(2).unwrap();
1190        assert!(prunable.is_chunk_aligned()); // 64 bits pruned, 0 bits in bitmap
1191
1192        let mut prunable: Prunable<4> = Prunable::new_with_pruned_chunks(1).unwrap();
1193        assert!(prunable.is_chunk_aligned()); // 32 bits pruned, 0 bits in bitmap
1194        prunable.push(true);
1195        assert!(!prunable.is_chunk_aligned()); // 33 bits total
1196
1197        // Test with push_byte
1198        let mut prunable: Prunable<4> = Prunable::new();
1199        for _ in 0..4 {
1200            prunable.push_byte(0xFF);
1201        }
1202        assert!(prunable.is_chunk_aligned()); // 32 bits = 1 chunk
1203
1204        // Test after pop
1205        prunable.pop();
1206        assert!(!prunable.is_chunk_aligned()); // 31 bits
1207
1208        // Pop back to alignment
1209        for _ in 0..31 {
1210            prunable.pop();
1211        }
1212        assert!(prunable.is_chunk_aligned()); // 0 bits
1213    }
1214
    // Codec conformance checks for `Prunable<16>`; this module is compiled only
    // when the `arbitrary` feature is enabled.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        // NOTE(review): the macro's expansion is not visible from this file;
        // presumably it generates the conformance test cases for each listed
        // type. Confirm against the `commonware_conformance` documentation.
        commonware_conformance::conformance_tests! {
            CodecConformance<Prunable<16>>,
        }
    }
1224}