commonware_utils/bitmap/
prunable.rs

1//! A prunable wrapper around BitMap that tracks pruned chunks.
2
3use super::BitMap;
4use bytes::{Buf, BufMut};
5use commonware_codec::{EncodeSize, Error as CodecError, Read, ReadExt, Write};
6use thiserror::Error;
7
8/// Errors that can occur when working with a prunable bitmap.
9#[derive(Debug, Error, Clone, PartialEq, Eq)]
10pub enum Error {
11    /// The provided pruned_chunks value would overflow.
12    #[error("pruned_chunks * CHUNK_SIZE_BITS overflows u64")]
13    PrunedChunksOverflow,
14}
15
16/// A prunable bitmap that stores data in chunks of N bytes.
17///
18/// # Panics
19///
20/// Operations panic if `bit / CHUNK_SIZE_BITS > usize::MAX`. On 32-bit systems
21/// with N=32, this occurs at bit >= 1,099,511,627,776.
22#[derive(Clone, Debug)]
23pub struct Prunable<const N: usize> {
24    /// The underlying BitMap storing the actual bits.
25    bitmap: BitMap<N>,
26
27    /// The number of bitmap chunks that have been pruned.
28    ///
29    /// # Invariant
30    ///
31    /// Must satisfy: `pruned_chunks as u64 * CHUNK_SIZE_BITS + bitmap.len() <= u64::MAX`
32    pruned_chunks: usize,
33}
34
35impl<const N: usize> Prunable<N> {
36    /// The size of a chunk in bits.
37    pub const CHUNK_SIZE_BITS: u64 = BitMap::<N>::CHUNK_SIZE_BITS;
38
39    /* Constructors */
40
41    /// Create a new empty prunable bitmap.
42    pub const fn new() -> Self {
43        Self {
44            bitmap: BitMap::new(),
45            pruned_chunks: 0,
46        }
47    }
48
49    /// Create a new empty prunable bitmap with the given number of pruned chunks.
50    ///
51    /// # Errors
52    ///
53    /// Returns an error if `pruned_chunks` violates the invariant that
54    /// `pruned_chunks as u64 * CHUNK_SIZE_BITS` must not overflow u64.
55    pub fn new_with_pruned_chunks(pruned_chunks: usize) -> Result<Self, Error> {
56        // Validate the invariant: pruned_chunks * CHUNK_SIZE_BITS must fit in u64
57        let pruned_chunks_u64 = pruned_chunks as u64;
58        pruned_chunks_u64
59            .checked_mul(Self::CHUNK_SIZE_BITS)
60            .ok_or(Error::PrunedChunksOverflow)?;
61
62        Ok(Self {
63            bitmap: BitMap::new(),
64            pruned_chunks,
65        })
66    }
67
68    /* Length */
69
70    /// Return the number of bits in the bitmap, irrespective of any pruning.
71    #[inline]
72    pub const fn len(&self) -> u64 {
73        let pruned_bits = (self.pruned_chunks as u64)
74            .checked_mul(Self::CHUNK_SIZE_BITS)
75            .expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64");
76
77        pruned_bits
78            .checked_add(self.bitmap.len())
79            .expect("invariant violated: pruned_bits + bitmap.len() overflows u64")
80    }
81
82    /// Return true if the bitmap is empty.
83    #[inline]
84    pub const fn is_empty(&self) -> bool {
85        self.len() == 0
86    }
87
88    /// Returns true if the bitmap length is aligned to a chunk boundary.
89    #[inline]
90    pub const fn is_chunk_aligned(&self) -> bool {
91        self.len().is_multiple_of(Self::CHUNK_SIZE_BITS)
92    }
93
94    /// Return the number of unpruned chunks in the bitmap.
95    #[inline]
96    pub fn chunks_len(&self) -> usize {
97        self.bitmap.chunks_len()
98    }
99
100    /// Return the number of pruned chunks.
101    #[inline]
102    pub const fn pruned_chunks(&self) -> usize {
103        self.pruned_chunks
104    }
105
106    /// Return the number of pruned bits.
107    #[inline]
108    pub const fn pruned_bits(&self) -> u64 {
109        (self.pruned_chunks as u64)
110            .checked_mul(Self::CHUNK_SIZE_BITS)
111            .expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64")
112    }
113
114    /* Getters */
115
116    /// Get the value of a bit.
117    ///
118    /// # Warning
119    ///
120    /// Panics if the bit doesn't exist or has been pruned.
121    #[inline]
122    pub fn get_bit(&self, bit: u64) -> bool {
123        let chunk_num = Self::unpruned_chunk(bit);
124        assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
125
126        // Adjust bit to account for pruning
127        self.bitmap.get(bit - self.pruned_bits())
128    }
129
130    /// Returns the bitmap chunk containing the specified bit.
131    ///
132    /// # Warning
133    ///
134    /// Panics if the bit doesn't exist or has been pruned.
135    #[inline]
136    pub fn get_chunk_containing(&self, bit: u64) -> &[u8; N] {
137        let chunk_num = Self::unpruned_chunk(bit);
138        assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
139
140        // Adjust bit to account for pruning
141        self.bitmap.get_chunk_containing(bit - self.pruned_bits())
142    }
143
144    /// Get the value of a bit from its chunk.
145    /// `bit` is an index into the entire bitmap, not just the chunk.
146    #[inline]
147    pub const fn get_bit_from_chunk(chunk: &[u8; N], bit: u64) -> bool {
148        BitMap::<N>::get_from_chunk(chunk, bit)
149    }
150
151    /// Return the last chunk of the bitmap and its size in bits.
152    #[inline]
153    pub fn last_chunk(&self) -> (&[u8; N], u64) {
154        self.bitmap.last_chunk()
155    }
156
157    /* Setters */
158
159    /// Set the value of the given bit.
160    ///
161    /// # Warning
162    ///
163    /// Panics if the bit doesn't exist or has been pruned.
164    pub fn set_bit(&mut self, bit: u64, value: bool) {
165        let chunk_num = Self::unpruned_chunk(bit);
166        assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
167
168        // Adjust bit to account for pruning
169        self.bitmap.set(bit - self.pruned_bits(), value);
170    }
171
172    /// Add a single bit to the end of the bitmap.
173    pub fn push(&mut self, bit: bool) {
174        self.bitmap.push(bit);
175    }
176
177    /// Remove and return the last bit from the bitmap.
178    ///
179    /// # Warning
180    ///
181    /// Panics if the bitmap is empty.
182    pub fn pop(&mut self) -> bool {
183        self.bitmap.pop()
184    }
185
186    /// Add a byte to the bitmap.
187    ///
188    /// # Warning
189    ///
190    /// Panics if self.next_bit is not byte aligned.
191    pub fn push_byte(&mut self, byte: u8) {
192        self.bitmap.push_byte(byte);
193    }
194
195    /// Add a chunk of bits to the bitmap.
196    ///
197    /// # Warning
198    ///
199    /// Panics if self.next_bit is not chunk aligned.
200    pub fn push_chunk(&mut self, chunk: &[u8; N]) {
201        self.bitmap.push_chunk(chunk);
202    }
203
204    /// Remove and return the last complete chunk from the bitmap.
205    ///
206    /// # Warning
207    ///
208    /// Panics if the bitmap has fewer than `CHUNK_SIZE_BITS` bits or if not chunk-aligned.
209    pub fn pop_chunk(&mut self) -> [u8; N] {
210        self.bitmap.pop_chunk()
211    }
212
213    /* Pruning */
214
215    /// Prune all complete chunks before the chunk containing the given bit.
216    ///
217    /// The chunk containing `bit` and all subsequent chunks are retained. All chunks
218    /// before it are pruned.
219    ///
220    /// If `bit` equals the bitmap length, this prunes all complete chunks while retaining
221    /// the empty trailing chunk, preparing the bitmap for appending new data.
222    ///
223    /// # Warning
224    ///
225    /// Panics if `bit` is greater than the bitmap length.
226    pub fn prune_to_bit(&mut self, bit: u64) {
227        assert!(
228            bit <= self.len(),
229            "bit {} out of bounds (len: {})",
230            bit,
231            self.len()
232        );
233
234        let chunk = Self::unpruned_chunk(bit);
235        if chunk < self.pruned_chunks {
236            return;
237        }
238
239        let chunks_to_prune = chunk - self.pruned_chunks;
240        self.bitmap.prune_chunks(chunks_to_prune);
241        self.pruned_chunks = chunk;
242    }
243
244    /* Indexing Helpers */
245
246    /// Convert a bit into a bitmask for the byte containing that bit.
247    #[inline]
248    pub const fn chunk_byte_bitmask(bit: u64) -> u8 {
249        BitMap::<N>::chunk_byte_bitmask(bit)
250    }
251
252    /// Convert a bit into the index of the byte within a chunk containing the bit.
253    #[inline]
254    pub const fn chunk_byte_offset(bit: u64) -> usize {
255        BitMap::<N>::chunk_byte_offset(bit)
256    }
257
258    /// Convert a bit into the index of the chunk it belongs to within the bitmap,
259    /// taking pruned chunks into account. That is, the returned value is a valid index into
260    /// the inner bitmap.
261    ///
262    /// # Warning
263    ///
264    /// Panics if the bit doesn't exist or has been pruned.
265    #[inline]
266    pub fn pruned_chunk(&self, bit: u64) -> usize {
267        assert!(bit < self.len(), "out of bounds: {bit}");
268        let chunk = Self::unpruned_chunk(bit);
269        assert!(chunk >= self.pruned_chunks, "bit pruned: {bit}");
270
271        chunk - self.pruned_chunks
272    }
273
274    /// Convert a bit into the number of the chunk it belongs to,
275    /// ignoring any pruning.
276    ///
277    /// # Panics
278    ///
279    /// Panics if `bit / CHUNK_SIZE_BITS > usize::MAX`.
280    #[inline]
281    pub fn unpruned_chunk(bit: u64) -> usize {
282        BitMap::<N>::chunk(bit)
283    }
284
285    /// Get a reference to a chunk by its index in the current bitmap
286    /// Note this is an index into the chunks, not a bit.
287    #[inline]
288    pub fn get_chunk(&self, chunk: usize) -> &[u8; N] {
289        self.bitmap.get_chunk(chunk)
290    }
291
292    /// Overwrite a chunk's data by its raw (unpruned) chunk index.
293    ///
294    /// # Panics
295    ///
296    /// Panics if the chunk is pruned or out of bounds.
297    pub(super) fn set_chunk_by_index(&mut self, chunk_index: usize, chunk_data: &[u8; N]) {
298        assert!(
299            chunk_index >= self.pruned_chunks,
300            "cannot set pruned chunk {chunk_index} (pruned_chunks: {})",
301            self.pruned_chunks
302        );
303        let bitmap_chunk_idx = chunk_index - self.pruned_chunks;
304        self.bitmap.set_chunk_by_index(bitmap_chunk_idx, chunk_data);
305    }
306
307    /// Unprune chunks by prepending them back to the front of the bitmap.
308    ///
309    /// The caller must provide chunks in **reverse** order: to restore chunks with
310    /// indices [0, 1, 2], pass them as [2, 1, 0]. This is necessary because each chunk
311    /// is prepended to the front, so the last chunk provided becomes the first chunk
312    /// in the bitmap.
313    ///
314    /// # Panics
315    ///
316    /// Panics if chunks.len() > self.pruned_chunks.
317    pub(super) fn unprune_chunks(&mut self, chunks: &[[u8; N]]) {
318        assert!(
319            chunks.len() <= self.pruned_chunks,
320            "cannot unprune {} chunks (only {} pruned)",
321            chunks.len(),
322            self.pruned_chunks
323        );
324
325        for chunk in chunks.iter() {
326            self.bitmap.prepend_chunk(chunk);
327        }
328
329        self.pruned_chunks -= chunks.len();
330    }
331}
332
333impl<const N: usize> Default for Prunable<N> {
334    fn default() -> Self {
335        Self::new()
336    }
337}
338
339impl<const N: usize> Write for Prunable<N> {
340    fn write(&self, buf: &mut impl BufMut) {
341        (self.pruned_chunks as u64).write(buf);
342        self.bitmap.write(buf);
343    }
344}
345
346impl<const N: usize> Read for Prunable<N> {
347    // Max length for the unpruned portion of the bitmap.
348    type Cfg = u64;
349
350    fn read_cfg(buf: &mut impl Buf, max_len: &Self::Cfg) -> Result<Self, CodecError> {
351        let pruned_chunks_u64 = u64::read(buf)?;
352
353        // Validate that pruned_chunks * CHUNK_SIZE_BITS doesn't overflow u64
354        let pruned_bits =
355            pruned_chunks_u64
356                .checked_mul(Self::CHUNK_SIZE_BITS)
357                .ok_or(CodecError::Invalid(
358                    "Prunable",
359                    "pruned_chunks would overflow when computing pruned_bits",
360                ))?;
361
362        let pruned_chunks = usize::try_from(pruned_chunks_u64)
363            .map_err(|_| CodecError::Invalid("Prunable", "pruned_chunks doesn't fit in usize"))?;
364
365        let bitmap = BitMap::<N>::read_cfg(buf, max_len)?;
366
367        // Validate that total length (pruned_bits + bitmap.len()) doesn't overflow u64
368        pruned_bits
369            .checked_add(bitmap.len())
370            .ok_or(CodecError::Invalid(
371                "Prunable",
372                "total bitmap length (pruned + unpruned) would overflow u64",
373            ))?;
374
375        Ok(Self {
376            bitmap,
377            pruned_chunks,
378        })
379    }
380}
381
382impl<const N: usize> EncodeSize for Prunable<N> {
383    fn encode_size(&self) -> usize {
384        (self.pruned_chunks as u64).encode_size() + self.bitmap.encode_size()
385    }
386}
387
#[cfg(feature = "arbitrary")]
impl<const N: usize> arbitrary::Arbitrary<'_> for Prunable<N> {
    /// Generate an arbitrary bitmap, then prune it up to an arbitrary bit in `0..=len`.
    fn arbitrary(u: &mut arbitrary::Unstructured<'_>) -> arbitrary::Result<Self> {
        // Start with nothing pruned so every bit index up to the length is a valid prune target.
        let inner = BitMap::<N>::arbitrary(u)?;
        let mut prunable = Self {
            bitmap: inner,
            pruned_chunks: 0,
        };

        let boundary = u.int_in_range(0..=prunable.len())?;
        prunable.prune_to_bit(boundary);
        Ok(prunable)
    }
}
400
401#[cfg(test)]
402mod tests {
403    use super::*;
404    use crate::hex;
405    use bytes::BytesMut;
406    use commonware_codec::Encode;
407
408    #[test]
409    fn test_new() {
410        let prunable: Prunable<32> = Prunable::new();
411        assert_eq!(prunable.len(), 0);
412        assert_eq!(prunable.pruned_bits(), 0);
413        assert_eq!(prunable.pruned_chunks(), 0);
414        assert!(prunable.is_empty());
415        assert_eq!(prunable.chunks_len(), 0); // No chunks when empty
416    }
417
418    #[test]
419    fn test_new_with_pruned_chunks() {
420        let prunable: Prunable<2> = Prunable::new_with_pruned_chunks(1).unwrap();
421        assert_eq!(prunable.len(), 16);
422        assert_eq!(prunable.pruned_bits(), 16);
423        assert_eq!(prunable.pruned_chunks(), 1);
424        assert_eq!(prunable.chunks_len(), 0);
425    }
426
427    #[test]
428    fn test_new_with_pruned_chunks_overflow() {
429        // Try to create a Prunable with pruned_chunks that would overflow
430        let overflowing_pruned_chunks = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) as usize + 1;
431        let result = Prunable::<4>::new_with_pruned_chunks(overflowing_pruned_chunks);
432
433        assert!(matches!(result, Err(Error::PrunedChunksOverflow)));
434    }
435
436    #[test]
437    fn test_push_and_get_bits() {
438        let mut prunable: Prunable<4> = Prunable::new();
439
440        // Add some bits
441        prunable.push(true);
442        prunable.push(false);
443        prunable.push(true);
444
445        assert_eq!(prunable.len(), 3);
446        assert!(!prunable.is_empty());
447        assert!(prunable.get_bit(0));
448        assert!(!prunable.get_bit(1));
449        assert!(prunable.get_bit(2));
450    }
451
452    #[test]
453    fn test_push_byte() {
454        let mut prunable: Prunable<4> = Prunable::new();
455
456        // Add a byte
457        prunable.push_byte(0xFF);
458        assert_eq!(prunable.len(), 8);
459
460        // All bits should be set
461        for i in 0..8 {
462            assert!(prunable.get_bit(i as u64));
463        }
464
465        prunable.push_byte(0x00);
466        assert_eq!(prunable.len(), 16);
467
468        // Next 8 bits should be clear
469        for i in 8..16 {
470            assert!(!prunable.get_bit(i as u64));
471        }
472    }
473
474    #[test]
475    fn test_push_chunk() {
476        let mut prunable: Prunable<4> = Prunable::new();
477        let chunk = hex!("0xAABBCCDD");
478
479        prunable.push_chunk(&chunk);
480        assert_eq!(prunable.len(), 32); // 4 bytes * 8 bits
481
482        let retrieved_chunk = prunable.get_chunk_containing(0);
483        assert_eq!(retrieved_chunk, &chunk);
484    }
485
486    #[test]
487    fn test_set_bit() {
488        let mut prunable: Prunable<4> = Prunable::new();
489
490        // Add some bits
491        prunable.push(false);
492        prunable.push(false);
493        prunable.push(false);
494
495        assert!(!prunable.get_bit(1));
496
497        // Set a bit
498        prunable.set_bit(1, true);
499        assert!(prunable.get_bit(1));
500
501        // Set it back
502        prunable.set_bit(1, false);
503        assert!(!prunable.get_bit(1));
504    }
505
506    #[test]
507    fn test_pruning_basic() {
508        let mut prunable: Prunable<4> = Prunable::new();
509
510        // Add multiple chunks (4 bytes each)
511        let chunk1 = hex!("0x01020304");
512        let chunk2 = hex!("0x05060708");
513        let chunk3 = hex!("0x090A0B0C");
514
515        prunable.push_chunk(&chunk1);
516        prunable.push_chunk(&chunk2);
517        prunable.push_chunk(&chunk3);
518
519        assert_eq!(prunable.len(), 96); // 3 chunks * 32 bits
520        assert_eq!(prunable.pruned_chunks(), 0);
521
522        // Prune to second chunk (bit 32 is start of second chunk)
523        prunable.prune_to_bit(32);
524        assert_eq!(prunable.pruned_chunks(), 1);
525        assert_eq!(prunable.pruned_bits(), 32);
526        assert_eq!(prunable.len(), 96); // Total count unchanged
527
528        // Can still access non-pruned bits
529        assert_eq!(prunable.get_chunk_containing(32), &chunk2);
530        assert_eq!(prunable.get_chunk_containing(64), &chunk3);
531
532        // Prune to third chunk
533        prunable.prune_to_bit(64);
534        assert_eq!(prunable.pruned_chunks(), 2);
535        assert_eq!(prunable.pruned_bits(), 64);
536        assert_eq!(prunable.len(), 96);
537
538        // Can still access the third chunk
539        assert_eq!(prunable.get_chunk_containing(64), &chunk3);
540    }
541
542    #[test]
543    #[should_panic(expected = "bit pruned")]
544    fn test_get_pruned_bit_panics() {
545        let mut prunable: Prunable<4> = Prunable::new();
546
547        // Add two chunks
548        prunable.push_chunk(&[1, 2, 3, 4]);
549        prunable.push_chunk(&[5, 6, 7, 8]);
550
551        // Prune first chunk
552        prunable.prune_to_bit(32);
553
554        // Try to access pruned bit - should panic
555        prunable.get_bit(0);
556    }
557
558    #[test]
559    #[should_panic(expected = "bit pruned")]
560    fn test_get_pruned_chunk_panics() {
561        let mut prunable: Prunable<4> = Prunable::new();
562
563        // Add two chunks
564        prunable.push_chunk(&[1, 2, 3, 4]);
565        prunable.push_chunk(&[5, 6, 7, 8]);
566
567        // Prune first chunk
568        prunable.prune_to_bit(32);
569
570        // Try to access pruned chunk - should panic
571        prunable.get_chunk_containing(0);
572    }
573
574    #[test]
575    #[should_panic(expected = "bit pruned")]
576    fn test_set_pruned_bit_panics() {
577        let mut prunable: Prunable<4> = Prunable::new();
578
579        // Add two chunks
580        prunable.push_chunk(&[1, 2, 3, 4]);
581        prunable.push_chunk(&[5, 6, 7, 8]);
582
583        // Prune first chunk
584        prunable.prune_to_bit(32);
585
586        // Try to set pruned bit - should panic
587        prunable.set_bit(0, true);
588    }
589
590    #[test]
591    #[should_panic(expected = "bit 25 out of bounds (len: 24)")]
592    fn test_prune_to_bit_out_of_bounds() {
593        let mut prunable: Prunable<1> = Prunable::new();
594
595        // Add 3 bytes (24 bits total)
596        prunable.push_byte(1);
597        prunable.push_byte(2);
598        prunable.push_byte(3);
599
600        // Try to prune to a bit beyond the bitmap
601        prunable.prune_to_bit(25);
602    }
603
604    #[test]
605    fn test_pruning_with_partial_chunk() {
606        let mut prunable: Prunable<4> = Prunable::new();
607
608        // Add two full chunks and some partial bits
609        prunable.push_chunk(&[0xFF; 4]);
610        prunable.push_chunk(&[0xAA; 4]);
611        prunable.push(true);
612        prunable.push(false);
613        prunable.push(true);
614
615        assert_eq!(prunable.len(), 67); // 64 + 3 bits
616
617        // Prune to second chunk
618        prunable.prune_to_bit(32);
619        assert_eq!(prunable.pruned_chunks(), 1);
620        assert_eq!(prunable.len(), 67);
621
622        // Can still access the partial bits
623        assert!(prunable.get_bit(64));
624        assert!(!prunable.get_bit(65));
625        assert!(prunable.get_bit(66));
626    }
627
628    #[test]
629    fn test_prune_idempotent() {
630        let mut prunable: Prunable<4> = Prunable::new();
631
632        // Add chunks
633        prunable.push_chunk(&[1, 2, 3, 4]);
634        prunable.push_chunk(&[5, 6, 7, 8]);
635
636        // Prune to bit 32
637        prunable.prune_to_bit(32);
638        assert_eq!(prunable.pruned_chunks(), 1);
639
640        // Pruning to same or earlier point should be no-op
641        prunable.prune_to_bit(32);
642        assert_eq!(prunable.pruned_chunks(), 1);
643
644        prunable.prune_to_bit(16);
645        assert_eq!(prunable.pruned_chunks(), 1);
646    }
647
648    #[test]
649    fn test_push_after_pruning() {
650        let mut prunable: Prunable<4> = Prunable::new();
651
652        // Add initial chunks
653        prunable.push_chunk(&[1, 2, 3, 4]);
654        prunable.push_chunk(&[5, 6, 7, 8]);
655
656        // Prune first chunk
657        prunable.prune_to_bit(32);
658        assert_eq!(prunable.len(), 64);
659        assert_eq!(prunable.pruned_chunks(), 1);
660
661        // Add more data
662        prunable.push_chunk(&[9, 10, 11, 12]);
663        assert_eq!(prunable.len(), 96); // 32 pruned + 64 active
664
665        // New chunk should be accessible
666        assert_eq!(prunable.get_chunk_containing(64), &[9, 10, 11, 12]);
667    }
668
669    #[test]
670    fn test_chunk_calculations() {
671        // Test chunk_num calculation
672        assert_eq!(Prunable::<4>::unpruned_chunk(0), 0);
673        assert_eq!(Prunable::<4>::unpruned_chunk(31), 0);
674        assert_eq!(Prunable::<4>::unpruned_chunk(32), 1);
675        assert_eq!(Prunable::<4>::unpruned_chunk(63), 1);
676        assert_eq!(Prunable::<4>::unpruned_chunk(64), 2);
677
678        // Test chunk_byte_offset
679        assert_eq!(Prunable::<4>::chunk_byte_offset(0), 0);
680        assert_eq!(Prunable::<4>::chunk_byte_offset(8), 1);
681        assert_eq!(Prunable::<4>::chunk_byte_offset(16), 2);
682        assert_eq!(Prunable::<4>::chunk_byte_offset(24), 3);
683        assert_eq!(Prunable::<4>::chunk_byte_offset(32), 0); // Wraps to next chunk
684
685        // Test chunk_byte_bitmask
686        assert_eq!(Prunable::<4>::chunk_byte_bitmask(0), 0b00000001);
687        assert_eq!(Prunable::<4>::chunk_byte_bitmask(1), 0b00000010);
688        assert_eq!(Prunable::<4>::chunk_byte_bitmask(7), 0b10000000);
689        assert_eq!(Prunable::<4>::chunk_byte_bitmask(8), 0b00000001); // Next byte
690    }
691
692    #[test]
693    fn test_pruned_chunk() {
694        let mut prunable: Prunable<4> = Prunable::new();
695
696        // Add three chunks
697        for i in 0..3 {
698            let chunk = [
699                (i * 4) as u8,
700                (i * 4 + 1) as u8,
701                (i * 4 + 2) as u8,
702                (i * 4 + 3) as u8,
703            ];
704            prunable.push_chunk(&chunk);
705        }
706
707        // Before pruning
708        assert_eq!(prunable.pruned_chunk(0), 0);
709        assert_eq!(prunable.pruned_chunk(32), 1);
710        assert_eq!(prunable.pruned_chunk(64), 2);
711
712        // After pruning first chunk
713        prunable.prune_to_bit(32);
714        assert_eq!(prunable.pruned_chunk(32), 0); // Now at index 0
715        assert_eq!(prunable.pruned_chunk(64), 1); // Now at index 1
716    }
717
718    #[test]
719    fn test_last_chunk_with_pruning() {
720        let mut prunable: Prunable<4> = Prunable::new();
721
722        // Add chunks
723        prunable.push_chunk(&[1, 2, 3, 4]);
724        prunable.push_chunk(&[5, 6, 7, 8]);
725        prunable.push(true);
726        prunable.push(false);
727
728        let (_, next_bit) = prunable.last_chunk();
729        assert_eq!(next_bit, 2);
730
731        // Store the chunk data for comparison
732        let chunk_data = *prunable.last_chunk().0;
733
734        // Pruning shouldn't affect last_chunk
735        prunable.prune_to_bit(32);
736        let (chunk2, next_bit2) = prunable.last_chunk();
737        assert_eq!(next_bit2, 2);
738        assert_eq!(&chunk_data, chunk2);
739    }
740
741    #[test]
742    fn test_different_chunk_sizes() {
743        // Test with different chunk sizes
744        let mut p8: Prunable<8> = Prunable::new();
745        let mut p16: Prunable<16> = Prunable::new();
746        let mut p32: Prunable<32> = Prunable::new();
747
748        // Add same pattern to each
749        for i in 0..10 {
750            p8.push(i % 2 == 0);
751            p16.push(i % 2 == 0);
752            p32.push(i % 2 == 0);
753        }
754
755        // All should have same bit count
756        assert_eq!(p8.len(), 10);
757        assert_eq!(p16.len(), 10);
758        assert_eq!(p32.len(), 10);
759
760        // All should have same bit values
761        for i in 0..10 {
762            let expected = i % 2 == 0;
763            if expected {
764                assert!(p8.get_bit(i));
765                assert!(p16.get_bit(i));
766                assert!(p32.get_bit(i));
767            } else {
768                assert!(!p8.get_bit(i));
769                assert!(!p16.get_bit(i));
770                assert!(!p32.get_bit(i));
771            }
772        }
773    }
774
775    #[test]
776    fn test_get_bit_from_chunk() {
777        let chunk: [u8; 4] = [0b10101010, 0b11001100, 0b11110000, 0b00001111];
778
779        // Test first byte
780        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 0));
781        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 1));
782        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 2));
783        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 3));
784
785        // Test second byte
786        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 8));
787        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 9));
788        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 10));
789        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 11));
790    }
791
792    #[test]
793    fn test_get_chunk() {
794        let mut prunable: Prunable<4> = Prunable::new();
795        let chunk1 = hex!("0x11223344");
796        let chunk2 = hex!("0x55667788");
797        let chunk3 = hex!("0x99AABBCC");
798
799        prunable.push_chunk(&chunk1);
800        prunable.push_chunk(&chunk2);
801        prunable.push_chunk(&chunk3);
802
803        // Before pruning
804        assert_eq!(prunable.get_chunk(0), &chunk1);
805        assert_eq!(prunable.get_chunk(1), &chunk2);
806        assert_eq!(prunable.get_chunk(2), &chunk3);
807
808        // After pruning
809        prunable.prune_to_bit(32);
810        assert_eq!(prunable.get_chunk(0), &chunk2);
811        assert_eq!(prunable.get_chunk(1), &chunk3);
812    }
813
814    #[test]
815    fn test_pop() {
816        let mut prunable: Prunable<4> = Prunable::new();
817
818        prunable.push(true);
819        prunable.push(false);
820        prunable.push(true);
821        assert_eq!(prunable.len(), 3);
822
823        assert!(prunable.pop());
824        assert_eq!(prunable.len(), 2);
825
826        assert!(!prunable.pop());
827        assert_eq!(prunable.len(), 1);
828
829        assert!(prunable.pop());
830        assert_eq!(prunable.len(), 0);
831        assert!(prunable.is_empty());
832
833        for i in 0..100 {
834            prunable.push(i % 3 == 0);
835        }
836        assert_eq!(prunable.len(), 100);
837
838        for i in (0..100).rev() {
839            let expected = i % 3 == 0;
840            assert_eq!(prunable.pop(), expected);
841            assert_eq!(prunable.len(), i);
842        }
843
844        assert!(prunable.is_empty());
845    }
846
847    #[test]
848    fn test_pop_chunk() {
849        let mut prunable: Prunable<4> = Prunable::new();
850        const CHUNK_SIZE: u64 = Prunable::<4>::CHUNK_SIZE_BITS;
851
852        // Test 1: Pop a single chunk and verify it returns the correct data
853        let chunk1 = hex!("0xAABBCCDD");
854        prunable.push_chunk(&chunk1);
855        assert_eq!(prunable.len(), CHUNK_SIZE);
856        let popped = prunable.pop_chunk();
857        assert_eq!(popped, chunk1);
858        assert_eq!(prunable.len(), 0);
859        assert!(prunable.is_empty());
860
861        // Test 2: Pop multiple chunks in reverse order
862        let chunk2 = hex!("0x11223344");
863        let chunk3 = hex!("0x55667788");
864        let chunk4 = hex!("0x99AABBCC");
865
866        prunable.push_chunk(&chunk2);
867        prunable.push_chunk(&chunk3);
868        prunable.push_chunk(&chunk4);
869        assert_eq!(prunable.len(), CHUNK_SIZE * 3);
870
871        assert_eq!(prunable.pop_chunk(), chunk4);
872        assert_eq!(prunable.len(), CHUNK_SIZE * 2);
873
874        assert_eq!(prunable.pop_chunk(), chunk3);
875        assert_eq!(prunable.len(), CHUNK_SIZE);
876
877        assert_eq!(prunable.pop_chunk(), chunk2);
878        assert_eq!(prunable.len(), 0);
879
880        // Test 3: Verify data integrity when popping chunks
881        prunable = Prunable::new();
882        let first_chunk = hex!("0xAABBCCDD");
883        let second_chunk = hex!("0x11223344");
884        prunable.push_chunk(&first_chunk);
885        prunable.push_chunk(&second_chunk);
886
887        // Pop the second chunk, verify it and that first chunk is intact
888        assert_eq!(prunable.pop_chunk(), second_chunk);
889        assert_eq!(prunable.len(), CHUNK_SIZE);
890
891        for i in 0..CHUNK_SIZE {
892            let byte_idx = (i / 8) as usize;
893            let bit_idx = i % 8;
894            let expected = (first_chunk[byte_idx] >> bit_idx) & 1 == 1;
895            assert_eq!(prunable.get_bit(i), expected);
896        }
897
898        assert_eq!(prunable.pop_chunk(), first_chunk);
899        assert_eq!(prunable.len(), 0);
900    }
901
902    #[test]
903    #[should_panic(expected = "cannot pop chunk when not chunk aligned")]
904    fn test_pop_chunk_not_aligned() {
905        let mut prunable: Prunable<4> = Prunable::new();
906
907        // Push a full chunk plus one bit
908        prunable.push_chunk(&[0xFF; 4]);
909        prunable.push(true);
910
911        // Should panic because not chunk-aligned
912        prunable.pop_chunk();
913    }
914
915    #[test]
916    #[should_panic(expected = "cannot pop chunk: bitmap has fewer than CHUNK_SIZE_BITS bits")]
917    fn test_pop_chunk_insufficient_bits() {
918        let mut prunable: Prunable<4> = Prunable::new();
919
920        // Push only a few bits (less than a full chunk)
921        prunable.push(true);
922        prunable.push(false);
923
924        // Should panic because we don't have a full chunk to pop
925        prunable.pop_chunk();
926    }
927
928    #[test]
929    fn test_write_read_empty() {
930        let original: Prunable<4> = Prunable::new();
931        let encoded = original.encode();
932
933        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
934        assert_eq!(decoded.len(), original.len());
935        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
936        assert!(decoded.is_empty());
937    }
938
939    #[test]
940    fn test_write_read_non_empty() {
941        let mut original: Prunable<4> = Prunable::new();
942        original.push_chunk(&hex!("0xAABBCCDD"));
943        original.push_chunk(&hex!("0x11223344"));
944        original.push(true);
945        original.push(false);
946        original.push(true);
947
948        let encoded = original.encode();
949        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
950
951        assert_eq!(decoded.len(), original.len());
952        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
953        assert_eq!(decoded.len(), 67);
954
955        // Verify all bits match
956        for i in 0..original.len() {
957            assert_eq!(decoded.get_bit(i), original.get_bit(i));
958        }
959    }
960
961    #[test]
962    fn test_write_read_with_pruning() {
963        let mut original: Prunable<4> = Prunable::new();
964        original.push_chunk(&hex!("0x01020304"));
965        original.push_chunk(&hex!("0x05060708"));
966        original.push_chunk(&hex!("0x090A0B0C"));
967
968        // Prune first chunk
969        original.prune_to_bit(32);
970        assert_eq!(original.pruned_chunks(), 1);
971        assert_eq!(original.len(), 96);
972
973        let encoded = original.encode();
974        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
975
976        assert_eq!(decoded.len(), original.len());
977        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
978        assert_eq!(decoded.pruned_chunks(), 1);
979        assert_eq!(decoded.len(), 96);
980
981        // Verify remaining chunks match
982        assert_eq!(decoded.get_chunk_containing(32), &hex!("0x05060708"));
983        assert_eq!(decoded.get_chunk_containing(64), &hex!("0x090A0B0C"));
984    }
985
986    #[test]
987    fn test_write_read_with_pruning_2() {
988        let mut original: Prunable<4> = Prunable::new();
989
990        // Add several chunks
991        for i in 0..5 {
992            let chunk = [
993                (i * 4) as u8,
994                (i * 4 + 1) as u8,
995                (i * 4 + 2) as u8,
996                (i * 4 + 3) as u8,
997            ];
998            original.push_chunk(&chunk);
999        }
1000
1001        // Keep only last two chunks
1002        original.prune_to_bit(96); // Prune first 3 chunks
1003        assert_eq!(original.pruned_chunks(), 3);
1004        assert_eq!(original.len(), 160);
1005
1006        let encoded = original.encode();
1007        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
1008
1009        assert_eq!(decoded.len(), original.len());
1010        assert_eq!(decoded.pruned_chunks(), 3);
1011
1012        // Verify remaining accessible bits match
1013        for i in 96..original.len() {
1014            assert_eq!(decoded.get_bit(i), original.get_bit(i));
1015        }
1016    }
1017
1018    #[test]
1019    fn test_encode_size_matches() {
1020        let mut prunable: Prunable<4> = Prunable::new();
1021        prunable.push_chunk(&[1, 2, 3, 4]);
1022        prunable.push_chunk(&[5, 6, 7, 8]);
1023        prunable.push(true);
1024
1025        let size = prunable.encode_size();
1026        let encoded = prunable.encode();
1027
1028        assert_eq!(size, encoded.len());
1029    }
1030
1031    #[test]
1032    fn test_encode_size_with_pruning() {
1033        let mut prunable: Prunable<4> = Prunable::new();
1034        prunable.push_chunk(&[1, 2, 3, 4]);
1035        prunable.push_chunk(&[5, 6, 7, 8]);
1036        prunable.push_chunk(&[9, 10, 11, 12]);
1037
1038        prunable.prune_to_bit(32);
1039
1040        let size = prunable.encode_size();
1041        let encoded = prunable.encode();
1042
1043        assert_eq!(size, encoded.len());
1044    }
1045
1046    #[test]
1047    fn test_read_max_len_validation() {
1048        let mut original: Prunable<4> = Prunable::new();
1049        for _ in 0..10 {
1050            original.push(true);
1051        }
1052
1053        let encoded = original.encode();
1054
1055        // Should succeed with sufficient max_len
1056        assert!(Prunable::<4>::read_cfg(&mut encoded.as_ref(), &100).is_ok());
1057
1058        // Should fail with insufficient max_len
1059        let result = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &5);
1060        assert!(result.is_err());
1061    }
1062
1063    #[test]
1064    fn test_codec_roundtrip_different_chunk_sizes() {
1065        // Test with different chunk sizes
1066        let mut p8: Prunable<8> = Prunable::new();
1067        let mut p16: Prunable<16> = Prunable::new();
1068        let mut p32: Prunable<32> = Prunable::new();
1069
1070        for i in 0..100 {
1071            let bit = i % 3 == 0;
1072            p8.push(bit);
1073            p16.push(bit);
1074            p32.push(bit);
1075        }
1076
1077        // Roundtrip each
1078        let encoded8 = p8.encode();
1079        let decoded8 = Prunable::<8>::read_cfg(&mut encoded8.as_ref(), &u64::MAX).unwrap();
1080        assert_eq!(decoded8.len(), p8.len());
1081
1082        let encoded16 = p16.encode();
1083        let decoded16 = Prunable::<16>::read_cfg(&mut encoded16.as_ref(), &u64::MAX).unwrap();
1084        assert_eq!(decoded16.len(), p16.len());
1085
1086        let encoded32 = p32.encode();
1087        let decoded32 = Prunable::<32>::read_cfg(&mut encoded32.as_ref(), &u64::MAX).unwrap();
1088        assert_eq!(decoded32.len(), p32.len());
1089    }
1090
1091    #[test]
1092    fn test_read_pruned_chunks_overflow() {
1093        let mut buf = BytesMut::new();
1094
1095        // Write a pruned_chunks value that would overflow when multiplied by CHUNK_SIZE_BITS
1096        let overflowing_pruned_chunks = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) + 1;
1097        overflowing_pruned_chunks.write(&mut buf);
1098
1099        // Write a valid bitmap (empty)
1100        0u64.write(&mut buf); // len = 0
1101
1102        // Try to read - should fail with overflow error
1103        let result = Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX);
1104        match result {
1105            Err(CodecError::Invalid(type_name, msg)) => {
1106                assert_eq!(type_name, "Prunable");
1107                assert_eq!(
1108                    msg,
1109                    "pruned_chunks would overflow when computing pruned_bits"
1110                );
1111            }
1112            Ok(_) => panic!("Expected error but got Ok"),
1113            Err(e) => panic!("Expected Invalid error for pruned_bits overflow, got: {e:?}"),
1114        }
1115    }
1116
1117    #[test]
1118    fn test_read_total_length_overflow() {
1119        let mut buf = BytesMut::new();
1120
1121        // Make pruned_bits as large as possible without overflowing
1122        let max_safe_pruned_chunks = u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS;
1123        let pruned_bits = max_safe_pruned_chunks * Prunable::<4>::CHUNK_SIZE_BITS;
1124
1125        // Make bitmap_len large enough that adding it overflows
1126        let remaining_space = u64::MAX - pruned_bits;
1127        let bitmap_len = remaining_space + 1; // Go over by 1 to trigger overflow
1128
1129        // Write the serialized data
1130        max_safe_pruned_chunks.write(&mut buf);
1131        bitmap_len.write(&mut buf);
1132
1133        // Write bitmap chunk data
1134        let num_chunks = bitmap_len.div_ceil(Prunable::<4>::CHUNK_SIZE_BITS);
1135        for _ in 0..(num_chunks * 4) {
1136            0u8.write(&mut buf);
1137        }
1138
1139        // Try to read - should fail because pruned_bits + bitmap_len overflows u64
1140        let result = Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX);
1141        match result {
1142            Err(CodecError::Invalid(type_name, msg)) => {
1143                assert_eq!(type_name, "Prunable");
1144                assert_eq!(
1145                    msg,
1146                    "total bitmap length (pruned + unpruned) would overflow u64"
1147                );
1148            }
1149            Ok(_) => panic!("Expected error but got Ok"),
1150            Err(e) => panic!("Expected Invalid error for total length overflow, got: {e:?}"),
1151        }
1152    }
1153
1154    #[test]
1155    fn test_is_chunk_aligned() {
1156        // Empty bitmap is chunk aligned
1157        let prunable: Prunable<4> = Prunable::new();
1158        assert!(prunable.is_chunk_aligned());
1159
1160        // Add bits one at a time and check alignment
1161        let mut prunable: Prunable<4> = Prunable::new();
1162        for i in 1..=32 {
1163            prunable.push(i % 2 == 0);
1164            if i == 32 {
1165                assert!(prunable.is_chunk_aligned()); // Exactly one chunk
1166            } else {
1167                assert!(!prunable.is_chunk_aligned()); // Partial chunk
1168            }
1169        }
1170
1171        // Add another full chunk
1172        for i in 33..=64 {
1173            prunable.push(i % 2 == 0);
1174            if i == 64 {
1175                assert!(prunable.is_chunk_aligned()); // Exactly two chunks
1176            } else {
1177                assert!(!prunable.is_chunk_aligned()); // Partial chunk
1178            }
1179        }
1180
1181        // Test with push_chunk
1182        let mut prunable: Prunable<4> = Prunable::new();
1183        assert!(prunable.is_chunk_aligned());
1184        prunable.push_chunk(&[1, 2, 3, 4]);
1185        assert!(prunable.is_chunk_aligned()); // 32 bits = 1 chunk
1186        prunable.push_chunk(&[5, 6, 7, 8]);
1187        assert!(prunable.is_chunk_aligned()); // 64 bits = 2 chunks
1188        prunable.push(true);
1189        assert!(!prunable.is_chunk_aligned()); // 65 bits = partial chunk
1190
1191        // Test alignment with pruning
1192        let mut prunable: Prunable<4> = Prunable::new();
1193        prunable.push_chunk(&[1, 2, 3, 4]);
1194        prunable.push_chunk(&[5, 6, 7, 8]);
1195        prunable.push_chunk(&[9, 10, 11, 12]);
1196        assert!(prunable.is_chunk_aligned()); // 96 bits = 3 chunks
1197
1198        // Prune first chunk - still aligned (64 bits remaining)
1199        prunable.prune_to_bit(32);
1200        assert!(prunable.is_chunk_aligned());
1201        assert_eq!(prunable.len(), 96);
1202
1203        // Add a partial chunk
1204        prunable.push(true);
1205        prunable.push(false);
1206        assert!(!prunable.is_chunk_aligned()); // 98 bits total
1207
1208        // Prune to align again
1209        prunable.prune_to_bit(64);
1210        assert!(!prunable.is_chunk_aligned()); // 98 bits total (34 bits remaining)
1211
1212        // Test with new_with_pruned_chunks
1213        let prunable: Prunable<4> = Prunable::new_with_pruned_chunks(2).unwrap();
1214        assert!(prunable.is_chunk_aligned()); // 64 bits pruned, 0 bits in bitmap
1215
1216        let mut prunable: Prunable<4> = Prunable::new_with_pruned_chunks(1).unwrap();
1217        assert!(prunable.is_chunk_aligned()); // 32 bits pruned, 0 bits in bitmap
1218        prunable.push(true);
1219        assert!(!prunable.is_chunk_aligned()); // 33 bits total
1220
1221        // Test with push_byte
1222        let mut prunable: Prunable<4> = Prunable::new();
1223        for _ in 0..4 {
1224            prunable.push_byte(0xFF);
1225        }
1226        assert!(prunable.is_chunk_aligned()); // 32 bits = 1 chunk
1227
1228        // Test after pop
1229        prunable.pop();
1230        assert!(!prunable.is_chunk_aligned()); // 31 bits
1231
1232        // Pop back to alignment
1233        for _ in 0..31 {
1234            prunable.pop();
1235        }
1236        assert!(prunable.is_chunk_aligned()); // 0 bits
1237    }
1238
    // Codec conformance tests for `Prunable<16>`, compiled only when the
    // `arbitrary` feature is enabled.
    // NOTE(review): the tests themselves are generated by
    // `commonware_conformance::conformance_tests!` using the `CodecConformance`
    // harness; what they check is defined there, not in this file.
    #[cfg(feature = "arbitrary")]
    mod conformance {
        use super::*;
        use commonware_codec::conformance::CodecConformance;

        commonware_conformance::conformance_tests! {
            CodecConformance<Prunable<16>>,
        }
    }
1248}