commonware_utils/bitmap/
prunable.rs

1//! A prunable wrapper around BitMap that tracks pruned chunks.
2
3use super::BitMap;
4use bytes::{Buf, BufMut};
5use commonware_codec::{EncodeSize, Error as CodecError, Read, ReadExt, Write};
6use thiserror::Error;
7
/// Errors that can occur when working with a prunable bitmap.
///
/// Currently only produced by [`Prunable::new_with_pruned_chunks`], which validates
/// the requested number of pruned chunks before constructing the bitmap.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum Error {
    /// The provided pruned_chunks value would overflow: multiplying it by
    /// `CHUNK_SIZE_BITS` does not fit in a `u64`.
    #[error("pruned_chunks * CHUNK_SIZE_BITS overflows u64")]
    PrunedChunksOverflow,
}
15
/// A prunable bitmap that stores data in chunks of N bytes.
///
/// Bit indices are absolute: pruning leading chunks does not renumber the bits that
/// remain, and the reported length keeps counting the pruned bits. Accessors translate
/// an absolute bit index into an index into the retained `bitmap` by subtracting the
/// number of pruned bits.
///
/// # Panics
///
/// Operations panic if `bit / CHUNK_SIZE_BITS > usize::MAX`. On 32-bit systems
/// with N=32, this occurs at bit >= 1,099,511,627,776.
#[derive(Clone, Debug)]
pub struct Prunable<const N: usize> {
    /// The underlying BitMap storing the actual (unpruned) bits.
    bitmap: BitMap<N>,

    /// The number of bitmap chunks that have been pruned.
    ///
    /// # Invariant
    ///
    /// Must satisfy: `pruned_chunks as u64 * CHUNK_SIZE_BITS + bitmap.len() <= u64::MAX`
    pruned_chunks: usize,
}
34
35impl<const N: usize> Prunable<N> {
36    /// The size of a chunk in bits.
37    pub const CHUNK_SIZE_BITS: u64 = BitMap::<N>::CHUNK_SIZE_BITS;
38
39    /* Constructors */
40
41    /// Create a new empty prunable bitmap.
42    pub fn new() -> Self {
43        Self {
44            bitmap: BitMap::new(),
45            pruned_chunks: 0,
46        }
47    }
48
49    /// Create a new empty prunable bitmap with the given number of pruned chunks.
50    ///
51    /// # Errors
52    ///
53    /// Returns an error if `pruned_chunks` violates the invariant that
54    /// `pruned_chunks as u64 * CHUNK_SIZE_BITS` must not overflow u64.
55    pub fn new_with_pruned_chunks(pruned_chunks: usize) -> Result<Self, Error> {
56        // Validate the invariant: pruned_chunks * CHUNK_SIZE_BITS must fit in u64
57        let pruned_chunks_u64 = pruned_chunks as u64;
58        pruned_chunks_u64
59            .checked_mul(Self::CHUNK_SIZE_BITS)
60            .ok_or(Error::PrunedChunksOverflow)?;
61
62        Ok(Self {
63            bitmap: BitMap::new(),
64            pruned_chunks,
65        })
66    }
67
68    /* Length */
69
70    /// Return the number of bits in the bitmap, irrespective of any pruning.
71    #[inline]
72    pub fn len(&self) -> u64 {
73        let pruned_bits = (self.pruned_chunks as u64)
74            .checked_mul(Self::CHUNK_SIZE_BITS)
75            .expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64");
76
77        pruned_bits
78            .checked_add(self.bitmap.len())
79            .expect("invariant violated: pruned_bits + bitmap.len() overflows u64")
80    }
81
82    /// Return true if the bitmap is empty.
83    #[inline]
84    pub fn is_empty(&self) -> bool {
85        self.len() == 0
86    }
87
88    /// Returns true if the bitmap length is aligned to a chunk boundary.
89    #[inline]
90    pub fn is_chunk_aligned(&self) -> bool {
91        self.len().is_multiple_of(Self::CHUNK_SIZE_BITS)
92    }
93
94    /// Return the number of unpruned chunks in the bitmap.
95    #[inline]
96    pub fn chunks_len(&self) -> usize {
97        self.bitmap.chunks_len()
98    }
99
100    /// Return the number of pruned chunks.
101    #[inline]
102    pub fn pruned_chunks(&self) -> usize {
103        self.pruned_chunks
104    }
105
106    /// Return the number of pruned bits.
107    #[inline]
108    pub fn pruned_bits(&self) -> u64 {
109        (self.pruned_chunks as u64)
110            .checked_mul(Self::CHUNK_SIZE_BITS)
111            .expect("invariant violated: pruned_chunks * CHUNK_SIZE_BITS overflows u64")
112    }
113
114    /* Getters */
115
116    /// Get the value of a bit.
117    ///
118    /// # Warning
119    ///
120    /// Panics if the bit doesn't exist or has been pruned.
121    #[inline]
122    pub fn get_bit(&self, bit: u64) -> bool {
123        let chunk_num = Self::unpruned_chunk(bit);
124        assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
125
126        // Adjust bit to account for pruning
127        self.bitmap.get(bit - self.pruned_bits())
128    }
129
130    /// Returns the bitmap chunk containing the specified bit.
131    ///
132    /// # Warning
133    ///
134    /// Panics if the bit doesn't exist or has been pruned.
135    #[inline]
136    pub fn get_chunk_containing(&self, bit: u64) -> &[u8; N] {
137        let chunk_num = Self::unpruned_chunk(bit);
138        assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
139
140        // Adjust bit to account for pruning
141        self.bitmap.get_chunk_containing(bit - self.pruned_bits())
142    }
143
144    /// Get the value of a bit from its chunk.
145    /// `bit` is an index into the entire bitmap, not just the chunk.
146    #[inline]
147    pub fn get_bit_from_chunk(chunk: &[u8; N], bit: u64) -> bool {
148        BitMap::<N>::get_from_chunk(chunk, bit)
149    }
150
151    /// Return the last chunk of the bitmap and its size in bits.
152    #[inline]
153    pub fn last_chunk(&self) -> (&[u8; N], u64) {
154        self.bitmap.last_chunk()
155    }
156
157    /* Setters */
158
159    /// Set the value of the given bit.
160    ///
161    /// # Warning
162    ///
163    /// Panics if the bit doesn't exist or has been pruned.
164    pub fn set_bit(&mut self, bit: u64, value: bool) {
165        let chunk_num = Self::unpruned_chunk(bit);
166        assert!(chunk_num >= self.pruned_chunks, "bit pruned: {bit}");
167
168        // Adjust bit to account for pruning
169        self.bitmap.set(bit - self.pruned_bits(), value);
170    }
171
172    /// Add a single bit to the end of the bitmap.
173    pub fn push(&mut self, bit: bool) {
174        self.bitmap.push(bit);
175    }
176
177    /// Remove and return the last bit from the bitmap.
178    ///
179    /// # Warning
180    ///
181    /// Panics if the bitmap is empty.
182    pub fn pop(&mut self) -> bool {
183        self.bitmap.pop()
184    }
185
186    /// Add a byte to the bitmap.
187    ///
188    /// # Warning
189    ///
190    /// Panics if self.next_bit is not byte aligned.
191    pub fn push_byte(&mut self, byte: u8) {
192        self.bitmap.push_byte(byte);
193    }
194
195    /// Add a chunk of bits to the bitmap.
196    ///
197    /// # Warning
198    ///
199    /// Panics if self.next_bit is not chunk aligned.
200    pub fn push_chunk(&mut self, chunk: &[u8; N]) {
201        self.bitmap.push_chunk(chunk);
202    }
203
204    /// Remove and return the last complete chunk from the bitmap.
205    ///
206    /// # Warning
207    ///
208    /// Panics if the bitmap has fewer than `CHUNK_SIZE_BITS` bits or if not chunk-aligned.
209    pub fn pop_chunk(&mut self) -> [u8; N] {
210        self.bitmap.pop_chunk()
211    }
212
213    /* Pruning */
214
215    /// Prune all complete chunks before the chunk containing the given bit.
216    ///
217    /// The chunk containing `bit` and all subsequent chunks are retained. All chunks
218    /// before it are pruned.
219    ///
220    /// If `bit` equals the bitmap length, this prunes all complete chunks while retaining
221    /// the empty trailing chunk, preparing the bitmap for appending new data.
222    ///
223    /// # Warning
224    ///
225    /// Panics if `bit` is greater than the bitmap length.
226    pub fn prune_to_bit(&mut self, bit: u64) {
227        assert!(
228            bit <= self.len(),
229            "bit {} out of bounds (len: {})",
230            bit,
231            self.len()
232        );
233
234        let chunk = Self::unpruned_chunk(bit);
235        if chunk < self.pruned_chunks {
236            return;
237        }
238
239        let chunks_to_prune = chunk - self.pruned_chunks;
240        self.bitmap.prune_chunks(chunks_to_prune);
241        self.pruned_chunks = chunk;
242    }
243
244    /* Indexing Helpers */
245
246    /// Convert a bit into a bitmask for the byte containing that bit.
247    #[inline]
248    pub fn chunk_byte_bitmask(bit: u64) -> u8 {
249        BitMap::<N>::chunk_byte_bitmask(bit)
250    }
251
252    /// Convert a bit into the index of the byte within a chunk containing the bit.
253    #[inline]
254    pub fn chunk_byte_offset(bit: u64) -> usize {
255        BitMap::<N>::chunk_byte_offset(bit)
256    }
257
258    /// Convert a bit into the index of the chunk it belongs to within the bitmap,
259    /// taking pruned chunks into account. That is, the returned value is a valid index into
260    /// the inner bitmap.
261    ///
262    /// # Warning
263    ///
264    /// Panics if the bit doesn't exist or has been pruned.
265    #[inline]
266    pub fn pruned_chunk(&self, bit: u64) -> usize {
267        assert!(bit < self.len(), "out of bounds: {bit}");
268        let chunk = Self::unpruned_chunk(bit);
269        assert!(chunk >= self.pruned_chunks, "bit pruned: {bit}");
270
271        chunk - self.pruned_chunks
272    }
273
274    /// Convert a bit into the number of the chunk it belongs to,
275    /// ignoring any pruning.
276    ///
277    /// # Panics
278    ///
279    /// Panics if `bit / CHUNK_SIZE_BITS > usize::MAX`.
280    #[inline]
281    pub fn unpruned_chunk(bit: u64) -> usize {
282        BitMap::<N>::chunk(bit)
283    }
284
285    /// Get a reference to a chunk by its index in the current bitmap
286    /// Note this is an index into the chunks, not a bit.
287    #[inline]
288    pub fn get_chunk(&self, chunk: usize) -> &[u8; N] {
289        self.bitmap.get_chunk(chunk)
290    }
291
292    /// Overwrite a chunk's data by its raw (unpruned) chunk index.
293    ///
294    /// # Panics
295    ///
296    /// Panics if the chunk is pruned or out of bounds.
297    pub(super) fn set_chunk_by_index(&mut self, chunk_index: usize, chunk_data: &[u8; N]) {
298        assert!(
299            chunk_index >= self.pruned_chunks,
300            "cannot set pruned chunk {chunk_index} (pruned_chunks: {})",
301            self.pruned_chunks
302        );
303        let bitmap_chunk_idx = chunk_index - self.pruned_chunks;
304        self.bitmap.set_chunk_by_index(bitmap_chunk_idx, chunk_data);
305    }
306
307    /// Unprune chunks by prepending them back to the front of the bitmap.
308    ///
309    /// The caller must provide chunks in **reverse** order: to restore chunks with
310    /// indices [0, 1, 2], pass them as [2, 1, 0]. This is necessary because each chunk
311    /// is prepended to the front, so the last chunk provided becomes the first chunk
312    /// in the bitmap.
313    ///
314    /// # Panics
315    ///
316    /// Panics if chunks.len() > self.pruned_chunks.
317    pub(super) fn unprune_chunks(&mut self, chunks: &[[u8; N]]) {
318        assert!(
319            chunks.len() <= self.pruned_chunks,
320            "cannot unprune {} chunks (only {} pruned)",
321            chunks.len(),
322            self.pruned_chunks
323        );
324
325        for chunk in chunks.iter() {
326            self.bitmap.prepend_chunk(chunk);
327        }
328
329        self.pruned_chunks -= chunks.len();
330    }
331}
332
impl<const N: usize> Default for Prunable<N> {
    /// Equivalent to [`Prunable::new`]: an empty bitmap with no pruned chunks.
    fn default() -> Self {
        Self::new()
    }
}
338
339impl<const N: usize> Write for Prunable<N> {
340    fn write(&self, buf: &mut impl BufMut) {
341        (self.pruned_chunks as u64).write(buf);
342        self.bitmap.write(buf);
343    }
344}
345
346impl<const N: usize> Read for Prunable<N> {
347    // Max length for the unpruned portion of the bitmap.
348    type Cfg = u64;
349
350    fn read_cfg(buf: &mut impl Buf, max_len: &Self::Cfg) -> Result<Self, CodecError> {
351        let pruned_chunks_u64 = u64::read(buf)?;
352
353        // Validate that pruned_chunks * CHUNK_SIZE_BITS doesn't overflow u64
354        let pruned_bits =
355            pruned_chunks_u64
356                .checked_mul(Self::CHUNK_SIZE_BITS)
357                .ok_or(CodecError::Invalid(
358                    "Prunable",
359                    "pruned_chunks would overflow when computing pruned_bits",
360                ))?;
361
362        let pruned_chunks = usize::try_from(pruned_chunks_u64)
363            .map_err(|_| CodecError::Invalid("Prunable", "pruned_chunks doesn't fit in usize"))?;
364
365        let bitmap = BitMap::<N>::read_cfg(buf, max_len)?;
366
367        // Validate that total length (pruned_bits + bitmap.len()) doesn't overflow u64
368        pruned_bits
369            .checked_add(bitmap.len())
370            .ok_or(CodecError::Invalid(
371                "Prunable",
372                "total bitmap length (pruned + unpruned) would overflow u64",
373            ))?;
374
375        Ok(Self {
376            bitmap,
377            pruned_chunks,
378        })
379    }
380}
381
382impl<const N: usize> EncodeSize for Prunable<N> {
383    fn encode_size(&self) -> usize {
384        (self.pruned_chunks as u64).encode_size() + self.bitmap.encode_size()
385    }
386}
387
388#[cfg(test)]
389mod tests {
390    use super::*;
391    use crate::hex;
392    use bytes::BytesMut;
393    use commonware_codec::Encode;
394
395    #[test]
396    fn test_new() {
397        let prunable: Prunable<32> = Prunable::new();
398        assert_eq!(prunable.len(), 0);
399        assert_eq!(prunable.pruned_bits(), 0);
400        assert_eq!(prunable.pruned_chunks(), 0);
401        assert!(prunable.is_empty());
402        assert_eq!(prunable.chunks_len(), 0); // No chunks when empty
403    }
404
405    #[test]
406    fn test_new_with_pruned_chunks() {
407        let prunable: Prunable<2> = Prunable::new_with_pruned_chunks(1).unwrap();
408        assert_eq!(prunable.len(), 16);
409        assert_eq!(prunable.pruned_bits(), 16);
410        assert_eq!(prunable.pruned_chunks(), 1);
411        assert_eq!(prunable.chunks_len(), 0);
412    }
413
414    #[test]
415    fn test_new_with_pruned_chunks_overflow() {
416        // Try to create a Prunable with pruned_chunks that would overflow
417        let overflowing_pruned_chunks = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) as usize + 1;
418        let result = Prunable::<4>::new_with_pruned_chunks(overflowing_pruned_chunks);
419
420        assert!(matches!(result, Err(Error::PrunedChunksOverflow)));
421    }
422
423    #[test]
424    fn test_push_and_get_bits() {
425        let mut prunable: Prunable<4> = Prunable::new();
426
427        // Add some bits
428        prunable.push(true);
429        prunable.push(false);
430        prunable.push(true);
431
432        assert_eq!(prunable.len(), 3);
433        assert!(!prunable.is_empty());
434        assert!(prunable.get_bit(0));
435        assert!(!prunable.get_bit(1));
436        assert!(prunable.get_bit(2));
437    }
438
439    #[test]
440    fn test_push_byte() {
441        let mut prunable: Prunable<4> = Prunable::new();
442
443        // Add a byte
444        prunable.push_byte(0xFF);
445        assert_eq!(prunable.len(), 8);
446
447        // All bits should be set
448        for i in 0..8 {
449            assert!(prunable.get_bit(i as u64));
450        }
451
452        prunable.push_byte(0x00);
453        assert_eq!(prunable.len(), 16);
454
455        // Next 8 bits should be clear
456        for i in 8..16 {
457            assert!(!prunable.get_bit(i as u64));
458        }
459    }
460
461    #[test]
462    fn test_push_chunk() {
463        let mut prunable: Prunable<4> = Prunable::new();
464        let chunk = hex!("0xAABBCCDD");
465
466        prunable.push_chunk(&chunk);
467        assert_eq!(prunable.len(), 32); // 4 bytes * 8 bits
468
469        let retrieved_chunk = prunable.get_chunk_containing(0);
470        assert_eq!(retrieved_chunk, &chunk);
471    }
472
473    #[test]
474    fn test_set_bit() {
475        let mut prunable: Prunable<4> = Prunable::new();
476
477        // Add some bits
478        prunable.push(false);
479        prunable.push(false);
480        prunable.push(false);
481
482        assert!(!prunable.get_bit(1));
483
484        // Set a bit
485        prunable.set_bit(1, true);
486        assert!(prunable.get_bit(1));
487
488        // Set it back
489        prunable.set_bit(1, false);
490        assert!(!prunable.get_bit(1));
491    }
492
493    #[test]
494    fn test_pruning_basic() {
495        let mut prunable: Prunable<4> = Prunable::new();
496
497        // Add multiple chunks (4 bytes each)
498        let chunk1 = hex!("0x01020304");
499        let chunk2 = hex!("0x05060708");
500        let chunk3 = hex!("0x090A0B0C");
501
502        prunable.push_chunk(&chunk1);
503        prunable.push_chunk(&chunk2);
504        prunable.push_chunk(&chunk3);
505
506        assert_eq!(prunable.len(), 96); // 3 chunks * 32 bits
507        assert_eq!(prunable.pruned_chunks(), 0);
508
509        // Prune to second chunk (bit 32 is start of second chunk)
510        prunable.prune_to_bit(32);
511        assert_eq!(prunable.pruned_chunks(), 1);
512        assert_eq!(prunable.pruned_bits(), 32);
513        assert_eq!(prunable.len(), 96); // Total count unchanged
514
515        // Can still access non-pruned bits
516        assert_eq!(prunable.get_chunk_containing(32), &chunk2);
517        assert_eq!(prunable.get_chunk_containing(64), &chunk3);
518
519        // Prune to third chunk
520        prunable.prune_to_bit(64);
521        assert_eq!(prunable.pruned_chunks(), 2);
522        assert_eq!(prunable.pruned_bits(), 64);
523        assert_eq!(prunable.len(), 96);
524
525        // Can still access the third chunk
526        assert_eq!(prunable.get_chunk_containing(64), &chunk3);
527    }
528
529    #[test]
530    #[should_panic(expected = "bit pruned")]
531    fn test_get_pruned_bit_panics() {
532        let mut prunable: Prunable<4> = Prunable::new();
533
534        // Add two chunks
535        prunable.push_chunk(&[1, 2, 3, 4]);
536        prunable.push_chunk(&[5, 6, 7, 8]);
537
538        // Prune first chunk
539        prunable.prune_to_bit(32);
540
541        // Try to access pruned bit - should panic
542        prunable.get_bit(0);
543    }
544
545    #[test]
546    #[should_panic(expected = "bit pruned")]
547    fn test_get_pruned_chunk_panics() {
548        let mut prunable: Prunable<4> = Prunable::new();
549
550        // Add two chunks
551        prunable.push_chunk(&[1, 2, 3, 4]);
552        prunable.push_chunk(&[5, 6, 7, 8]);
553
554        // Prune first chunk
555        prunable.prune_to_bit(32);
556
557        // Try to access pruned chunk - should panic
558        prunable.get_chunk_containing(0);
559    }
560
561    #[test]
562    #[should_panic(expected = "bit pruned")]
563    fn test_set_pruned_bit_panics() {
564        let mut prunable: Prunable<4> = Prunable::new();
565
566        // Add two chunks
567        prunable.push_chunk(&[1, 2, 3, 4]);
568        prunable.push_chunk(&[5, 6, 7, 8]);
569
570        // Prune first chunk
571        prunable.prune_to_bit(32);
572
573        // Try to set pruned bit - should panic
574        prunable.set_bit(0, true);
575    }
576
577    #[test]
578    #[should_panic(expected = "bit 25 out of bounds (len: 24)")]
579    fn test_prune_to_bit_out_of_bounds() {
580        let mut prunable: Prunable<1> = Prunable::new();
581
582        // Add 3 bytes (24 bits total)
583        prunable.push_byte(1);
584        prunable.push_byte(2);
585        prunable.push_byte(3);
586
587        // Try to prune to a bit beyond the bitmap
588        prunable.prune_to_bit(25);
589    }
590
591    #[test]
592    fn test_pruning_with_partial_chunk() {
593        let mut prunable: Prunable<4> = Prunable::new();
594
595        // Add two full chunks and some partial bits
596        prunable.push_chunk(&[0xFF; 4]);
597        prunable.push_chunk(&[0xAA; 4]);
598        prunable.push(true);
599        prunable.push(false);
600        prunable.push(true);
601
602        assert_eq!(prunable.len(), 67); // 64 + 3 bits
603
604        // Prune to second chunk
605        prunable.prune_to_bit(32);
606        assert_eq!(prunable.pruned_chunks(), 1);
607        assert_eq!(prunable.len(), 67);
608
609        // Can still access the partial bits
610        assert!(prunable.get_bit(64));
611        assert!(!prunable.get_bit(65));
612        assert!(prunable.get_bit(66));
613    }
614
615    #[test]
616    fn test_prune_idempotent() {
617        let mut prunable: Prunable<4> = Prunable::new();
618
619        // Add chunks
620        prunable.push_chunk(&[1, 2, 3, 4]);
621        prunable.push_chunk(&[5, 6, 7, 8]);
622
623        // Prune to bit 32
624        prunable.prune_to_bit(32);
625        assert_eq!(prunable.pruned_chunks(), 1);
626
627        // Pruning to same or earlier point should be no-op
628        prunable.prune_to_bit(32);
629        assert_eq!(prunable.pruned_chunks(), 1);
630
631        prunable.prune_to_bit(16);
632        assert_eq!(prunable.pruned_chunks(), 1);
633    }
634
635    #[test]
636    fn test_push_after_pruning() {
637        let mut prunable: Prunable<4> = Prunable::new();
638
639        // Add initial chunks
640        prunable.push_chunk(&[1, 2, 3, 4]);
641        prunable.push_chunk(&[5, 6, 7, 8]);
642
643        // Prune first chunk
644        prunable.prune_to_bit(32);
645        assert_eq!(prunable.len(), 64);
646        assert_eq!(prunable.pruned_chunks(), 1);
647
648        // Add more data
649        prunable.push_chunk(&[9, 10, 11, 12]);
650        assert_eq!(prunable.len(), 96); // 32 pruned + 64 active
651
652        // New chunk should be accessible
653        assert_eq!(prunable.get_chunk_containing(64), &[9, 10, 11, 12]);
654    }
655
656    #[test]
657    fn test_chunk_calculations() {
658        // Test chunk_num calculation
659        assert_eq!(Prunable::<4>::unpruned_chunk(0), 0);
660        assert_eq!(Prunable::<4>::unpruned_chunk(31), 0);
661        assert_eq!(Prunable::<4>::unpruned_chunk(32), 1);
662        assert_eq!(Prunable::<4>::unpruned_chunk(63), 1);
663        assert_eq!(Prunable::<4>::unpruned_chunk(64), 2);
664
665        // Test chunk_byte_offset
666        assert_eq!(Prunable::<4>::chunk_byte_offset(0), 0);
667        assert_eq!(Prunable::<4>::chunk_byte_offset(8), 1);
668        assert_eq!(Prunable::<4>::chunk_byte_offset(16), 2);
669        assert_eq!(Prunable::<4>::chunk_byte_offset(24), 3);
670        assert_eq!(Prunable::<4>::chunk_byte_offset(32), 0); // Wraps to next chunk
671
672        // Test chunk_byte_bitmask
673        assert_eq!(Prunable::<4>::chunk_byte_bitmask(0), 0b00000001);
674        assert_eq!(Prunable::<4>::chunk_byte_bitmask(1), 0b00000010);
675        assert_eq!(Prunable::<4>::chunk_byte_bitmask(7), 0b10000000);
676        assert_eq!(Prunable::<4>::chunk_byte_bitmask(8), 0b00000001); // Next byte
677    }
678
679    #[test]
680    fn test_pruned_chunk() {
681        let mut prunable: Prunable<4> = Prunable::new();
682
683        // Add three chunks
684        for i in 0..3 {
685            let chunk = [
686                (i * 4) as u8,
687                (i * 4 + 1) as u8,
688                (i * 4 + 2) as u8,
689                (i * 4 + 3) as u8,
690            ];
691            prunable.push_chunk(&chunk);
692        }
693
694        // Before pruning
695        assert_eq!(prunable.pruned_chunk(0), 0);
696        assert_eq!(prunable.pruned_chunk(32), 1);
697        assert_eq!(prunable.pruned_chunk(64), 2);
698
699        // After pruning first chunk
700        prunable.prune_to_bit(32);
701        assert_eq!(prunable.pruned_chunk(32), 0); // Now at index 0
702        assert_eq!(prunable.pruned_chunk(64), 1); // Now at index 1
703    }
704
705    #[test]
706    fn test_last_chunk_with_pruning() {
707        let mut prunable: Prunable<4> = Prunable::new();
708
709        // Add chunks
710        prunable.push_chunk(&[1, 2, 3, 4]);
711        prunable.push_chunk(&[5, 6, 7, 8]);
712        prunable.push(true);
713        prunable.push(false);
714
715        let (_, next_bit) = prunable.last_chunk();
716        assert_eq!(next_bit, 2);
717
718        // Store the chunk data for comparison
719        let chunk_data = *prunable.last_chunk().0;
720
721        // Pruning shouldn't affect last_chunk
722        prunable.prune_to_bit(32);
723        let (chunk2, next_bit2) = prunable.last_chunk();
724        assert_eq!(next_bit2, 2);
725        assert_eq!(&chunk_data, chunk2);
726    }
727
728    #[test]
729    fn test_different_chunk_sizes() {
730        // Test with different chunk sizes
731        let mut p8: Prunable<8> = Prunable::new();
732        let mut p16: Prunable<16> = Prunable::new();
733        let mut p32: Prunable<32> = Prunable::new();
734
735        // Add same pattern to each
736        for i in 0..10 {
737            p8.push(i % 2 == 0);
738            p16.push(i % 2 == 0);
739            p32.push(i % 2 == 0);
740        }
741
742        // All should have same bit count
743        assert_eq!(p8.len(), 10);
744        assert_eq!(p16.len(), 10);
745        assert_eq!(p32.len(), 10);
746
747        // All should have same bit values
748        for i in 0..10 {
749            let expected = i % 2 == 0;
750            if expected {
751                assert!(p8.get_bit(i));
752                assert!(p16.get_bit(i));
753                assert!(p32.get_bit(i));
754            } else {
755                assert!(!p8.get_bit(i));
756                assert!(!p16.get_bit(i));
757                assert!(!p32.get_bit(i));
758            }
759        }
760    }
761
762    #[test]
763    fn test_get_bit_from_chunk() {
764        let chunk: [u8; 4] = [0b10101010, 0b11001100, 0b11110000, 0b00001111];
765
766        // Test first byte
767        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 0));
768        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 1));
769        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 2));
770        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 3));
771
772        // Test second byte
773        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 8));
774        assert!(!Prunable::<4>::get_bit_from_chunk(&chunk, 9));
775        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 10));
776        assert!(Prunable::<4>::get_bit_from_chunk(&chunk, 11));
777    }
778
779    #[test]
780    fn test_get_chunk() {
781        let mut prunable: Prunable<4> = Prunable::new();
782        let chunk1 = hex!("0x11223344");
783        let chunk2 = hex!("0x55667788");
784        let chunk3 = hex!("0x99AABBCC");
785
786        prunable.push_chunk(&chunk1);
787        prunable.push_chunk(&chunk2);
788        prunable.push_chunk(&chunk3);
789
790        // Before pruning
791        assert_eq!(prunable.get_chunk(0), &chunk1);
792        assert_eq!(prunable.get_chunk(1), &chunk2);
793        assert_eq!(prunable.get_chunk(2), &chunk3);
794
795        // After pruning
796        prunable.prune_to_bit(32);
797        assert_eq!(prunable.get_chunk(0), &chunk2);
798        assert_eq!(prunable.get_chunk(1), &chunk3);
799    }
800
801    #[test]
802    fn test_pop() {
803        let mut prunable: Prunable<4> = Prunable::new();
804
805        prunable.push(true);
806        prunable.push(false);
807        prunable.push(true);
808        assert_eq!(prunable.len(), 3);
809
810        assert!(prunable.pop());
811        assert_eq!(prunable.len(), 2);
812
813        assert!(!prunable.pop());
814        assert_eq!(prunable.len(), 1);
815
816        assert!(prunable.pop());
817        assert_eq!(prunable.len(), 0);
818        assert!(prunable.is_empty());
819
820        for i in 0..100 {
821            prunable.push(i % 3 == 0);
822        }
823        assert_eq!(prunable.len(), 100);
824
825        for i in (0..100).rev() {
826            let expected = i % 3 == 0;
827            assert_eq!(prunable.pop(), expected);
828            assert_eq!(prunable.len(), i);
829        }
830
831        assert!(prunable.is_empty());
832    }
833
834    #[test]
835    fn test_pop_chunk() {
836        let mut prunable: Prunable<4> = Prunable::new();
837        const CHUNK_SIZE: u64 = Prunable::<4>::CHUNK_SIZE_BITS;
838
839        // Test 1: Pop a single chunk and verify it returns the correct data
840        let chunk1 = hex!("0xAABBCCDD");
841        prunable.push_chunk(&chunk1);
842        assert_eq!(prunable.len(), CHUNK_SIZE);
843        let popped = prunable.pop_chunk();
844        assert_eq!(popped, chunk1);
845        assert_eq!(prunable.len(), 0);
846        assert!(prunable.is_empty());
847
848        // Test 2: Pop multiple chunks in reverse order
849        let chunk2 = hex!("0x11223344");
850        let chunk3 = hex!("0x55667788");
851        let chunk4 = hex!("0x99AABBCC");
852
853        prunable.push_chunk(&chunk2);
854        prunable.push_chunk(&chunk3);
855        prunable.push_chunk(&chunk4);
856        assert_eq!(prunable.len(), CHUNK_SIZE * 3);
857
858        assert_eq!(prunable.pop_chunk(), chunk4);
859        assert_eq!(prunable.len(), CHUNK_SIZE * 2);
860
861        assert_eq!(prunable.pop_chunk(), chunk3);
862        assert_eq!(prunable.len(), CHUNK_SIZE);
863
864        assert_eq!(prunable.pop_chunk(), chunk2);
865        assert_eq!(prunable.len(), 0);
866
867        // Test 3: Verify data integrity when popping chunks
868        prunable = Prunable::new();
869        let first_chunk = hex!("0xAABBCCDD");
870        let second_chunk = hex!("0x11223344");
871        prunable.push_chunk(&first_chunk);
872        prunable.push_chunk(&second_chunk);
873
874        // Pop the second chunk, verify it and that first chunk is intact
875        assert_eq!(prunable.pop_chunk(), second_chunk);
876        assert_eq!(prunable.len(), CHUNK_SIZE);
877
878        for i in 0..CHUNK_SIZE {
879            let byte_idx = (i / 8) as usize;
880            let bit_idx = i % 8;
881            let expected = (first_chunk[byte_idx] >> bit_idx) & 1 == 1;
882            assert_eq!(prunable.get_bit(i), expected);
883        }
884
885        assert_eq!(prunable.pop_chunk(), first_chunk);
886        assert_eq!(prunable.len(), 0);
887    }
888
889    #[test]
890    #[should_panic(expected = "cannot pop chunk when not chunk aligned")]
891    fn test_pop_chunk_not_aligned() {
892        let mut prunable: Prunable<4> = Prunable::new();
893
894        // Push a full chunk plus one bit
895        prunable.push_chunk(&[0xFF; 4]);
896        prunable.push(true);
897
898        // Should panic because not chunk-aligned
899        prunable.pop_chunk();
900    }
901
902    #[test]
903    #[should_panic(expected = "cannot pop chunk: bitmap has fewer than CHUNK_SIZE_BITS bits")]
904    fn test_pop_chunk_insufficient_bits() {
905        let mut prunable: Prunable<4> = Prunable::new();
906
907        // Push only a few bits (less than a full chunk)
908        prunable.push(true);
909        prunable.push(false);
910
911        // Should panic because we don't have a full chunk to pop
912        prunable.pop_chunk();
913    }
914
915    #[test]
916    fn test_write_read_empty() {
917        let original: Prunable<4> = Prunable::new();
918        let encoded = original.encode();
919
920        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
921        assert_eq!(decoded.len(), original.len());
922        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
923        assert!(decoded.is_empty());
924    }
925
926    #[test]
927    fn test_write_read_non_empty() {
928        let mut original: Prunable<4> = Prunable::new();
929        original.push_chunk(&hex!("0xAABBCCDD"));
930        original.push_chunk(&hex!("0x11223344"));
931        original.push(true);
932        original.push(false);
933        original.push(true);
934
935        let encoded = original.encode();
936        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
937
938        assert_eq!(decoded.len(), original.len());
939        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
940        assert_eq!(decoded.len(), 67);
941
942        // Verify all bits match
943        for i in 0..original.len() {
944            assert_eq!(decoded.get_bit(i), original.get_bit(i));
945        }
946    }
947
948    #[test]
949    fn test_write_read_with_pruning() {
950        let mut original: Prunable<4> = Prunable::new();
951        original.push_chunk(&hex!("0x01020304"));
952        original.push_chunk(&hex!("0x05060708"));
953        original.push_chunk(&hex!("0x090A0B0C"));
954
955        // Prune first chunk
956        original.prune_to_bit(32);
957        assert_eq!(original.pruned_chunks(), 1);
958        assert_eq!(original.len(), 96);
959
960        let encoded = original.encode();
961        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
962
963        assert_eq!(decoded.len(), original.len());
964        assert_eq!(decoded.pruned_chunks(), original.pruned_chunks());
965        assert_eq!(decoded.pruned_chunks(), 1);
966        assert_eq!(decoded.len(), 96);
967
968        // Verify remaining chunks match
969        assert_eq!(decoded.get_chunk_containing(32), &hex!("0x05060708"));
970        assert_eq!(decoded.get_chunk_containing(64), &hex!("0x090A0B0C"));
971    }
972
973    #[test]
974    fn test_write_read_with_pruning_2() {
975        let mut original: Prunable<4> = Prunable::new();
976
977        // Add several chunks
978        for i in 0..5 {
979            let chunk = [
980                (i * 4) as u8,
981                (i * 4 + 1) as u8,
982                (i * 4 + 2) as u8,
983                (i * 4 + 3) as u8,
984            ];
985            original.push_chunk(&chunk);
986        }
987
988        // Keep only last two chunks
989        original.prune_to_bit(96); // Prune first 3 chunks
990        assert_eq!(original.pruned_chunks(), 3);
991        assert_eq!(original.len(), 160);
992
993        let encoded = original.encode();
994        let decoded = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &u64::MAX).unwrap();
995
996        assert_eq!(decoded.len(), original.len());
997        assert_eq!(decoded.pruned_chunks(), 3);
998
999        // Verify remaining accessible bits match
1000        for i in 96..original.len() {
1001            assert_eq!(decoded.get_bit(i), original.get_bit(i));
1002        }
1003    }
1004
1005    #[test]
1006    fn test_encode_size_matches() {
1007        let mut prunable: Prunable<4> = Prunable::new();
1008        prunable.push_chunk(&[1, 2, 3, 4]);
1009        prunable.push_chunk(&[5, 6, 7, 8]);
1010        prunable.push(true);
1011
1012        let size = prunable.encode_size();
1013        let encoded = prunable.encode();
1014
1015        assert_eq!(size, encoded.len());
1016    }
1017
1018    #[test]
1019    fn test_encode_size_with_pruning() {
1020        let mut prunable: Prunable<4> = Prunable::new();
1021        prunable.push_chunk(&[1, 2, 3, 4]);
1022        prunable.push_chunk(&[5, 6, 7, 8]);
1023        prunable.push_chunk(&[9, 10, 11, 12]);
1024
1025        prunable.prune_to_bit(32);
1026
1027        let size = prunable.encode_size();
1028        let encoded = prunable.encode();
1029
1030        assert_eq!(size, encoded.len());
1031    }
1032
1033    #[test]
1034    fn test_read_max_len_validation() {
1035        let mut original: Prunable<4> = Prunable::new();
1036        for _ in 0..10 {
1037            original.push(true);
1038        }
1039
1040        let encoded = original.encode();
1041
1042        // Should succeed with sufficient max_len
1043        assert!(Prunable::<4>::read_cfg(&mut encoded.as_ref(), &100).is_ok());
1044
1045        // Should fail with insufficient max_len
1046        let result = Prunable::<4>::read_cfg(&mut encoded.as_ref(), &5);
1047        assert!(result.is_err());
1048    }
1049
1050    #[test]
1051    fn test_codec_roundtrip_different_chunk_sizes() {
1052        // Test with different chunk sizes
1053        let mut p8: Prunable<8> = Prunable::new();
1054        let mut p16: Prunable<16> = Prunable::new();
1055        let mut p32: Prunable<32> = Prunable::new();
1056
1057        for i in 0..100 {
1058            let bit = i % 3 == 0;
1059            p8.push(bit);
1060            p16.push(bit);
1061            p32.push(bit);
1062        }
1063
1064        // Roundtrip each
1065        let encoded8 = p8.encode();
1066        let decoded8 = Prunable::<8>::read_cfg(&mut encoded8.as_ref(), &u64::MAX).unwrap();
1067        assert_eq!(decoded8.len(), p8.len());
1068
1069        let encoded16 = p16.encode();
1070        let decoded16 = Prunable::<16>::read_cfg(&mut encoded16.as_ref(), &u64::MAX).unwrap();
1071        assert_eq!(decoded16.len(), p16.len());
1072
1073        let encoded32 = p32.encode();
1074        let decoded32 = Prunable::<32>::read_cfg(&mut encoded32.as_ref(), &u64::MAX).unwrap();
1075        assert_eq!(decoded32.len(), p32.len());
1076    }
1077
1078    #[test]
1079    fn test_read_pruned_chunks_overflow() {
1080        let mut buf = BytesMut::new();
1081
1082        // Write a pruned_chunks value that would overflow when multiplied by CHUNK_SIZE_BITS
1083        let overflowing_pruned_chunks = (u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS) + 1;
1084        overflowing_pruned_chunks.write(&mut buf);
1085
1086        // Write a valid bitmap (empty)
1087        0u64.write(&mut buf); // len = 0
1088
1089        // Try to read - should fail with overflow error
1090        let result = Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX);
1091        match result {
1092            Err(CodecError::Invalid(type_name, msg)) => {
1093                assert_eq!(type_name, "Prunable");
1094                assert_eq!(
1095                    msg,
1096                    "pruned_chunks would overflow when computing pruned_bits"
1097                );
1098            }
1099            Ok(_) => panic!("Expected error but got Ok"),
1100            Err(e) => panic!("Expected Invalid error for pruned_bits overflow, got: {e:?}"),
1101        }
1102    }
1103
1104    #[test]
1105    fn test_read_total_length_overflow() {
1106        let mut buf = BytesMut::new();
1107
1108        // Make pruned_bits as large as possible without overflowing
1109        let max_safe_pruned_chunks = u64::MAX / Prunable::<4>::CHUNK_SIZE_BITS;
1110        let pruned_bits = max_safe_pruned_chunks * Prunable::<4>::CHUNK_SIZE_BITS;
1111
1112        // Make bitmap_len large enough that adding it overflows
1113        let remaining_space = u64::MAX - pruned_bits;
1114        let bitmap_len = remaining_space + 1; // Go over by 1 to trigger overflow
1115
1116        // Write the serialized data
1117        max_safe_pruned_chunks.write(&mut buf);
1118        bitmap_len.write(&mut buf);
1119
1120        // Write bitmap chunk data
1121        let num_chunks = bitmap_len.div_ceil(Prunable::<4>::CHUNK_SIZE_BITS);
1122        for _ in 0..(num_chunks * 4) {
1123            0u8.write(&mut buf);
1124        }
1125
1126        // Try to read - should fail because pruned_bits + bitmap_len overflows u64
1127        let result = Prunable::<4>::read_cfg(&mut buf.as_ref(), &u64::MAX);
1128        match result {
1129            Err(CodecError::Invalid(type_name, msg)) => {
1130                assert_eq!(type_name, "Prunable");
1131                assert_eq!(
1132                    msg,
1133                    "total bitmap length (pruned + unpruned) would overflow u64"
1134                );
1135            }
1136            Ok(_) => panic!("Expected error but got Ok"),
1137            Err(e) => panic!("Expected Invalid error for total length overflow, got: {e:?}"),
1138        }
1139    }
1140
1141    #[test]
1142    fn test_is_chunk_aligned() {
1143        // Empty bitmap is chunk aligned
1144        let prunable: Prunable<4> = Prunable::new();
1145        assert!(prunable.is_chunk_aligned());
1146
1147        // Add bits one at a time and check alignment
1148        let mut prunable: Prunable<4> = Prunable::new();
1149        for i in 1..=32 {
1150            prunable.push(i % 2 == 0);
1151            if i == 32 {
1152                assert!(prunable.is_chunk_aligned()); // Exactly one chunk
1153            } else {
1154                assert!(!prunable.is_chunk_aligned()); // Partial chunk
1155            }
1156        }
1157
1158        // Add another full chunk
1159        for i in 33..=64 {
1160            prunable.push(i % 2 == 0);
1161            if i == 64 {
1162                assert!(prunable.is_chunk_aligned()); // Exactly two chunks
1163            } else {
1164                assert!(!prunable.is_chunk_aligned()); // Partial chunk
1165            }
1166        }
1167
1168        // Test with push_chunk
1169        let mut prunable: Prunable<4> = Prunable::new();
1170        assert!(prunable.is_chunk_aligned());
1171        prunable.push_chunk(&[1, 2, 3, 4]);
1172        assert!(prunable.is_chunk_aligned()); // 32 bits = 1 chunk
1173        prunable.push_chunk(&[5, 6, 7, 8]);
1174        assert!(prunable.is_chunk_aligned()); // 64 bits = 2 chunks
1175        prunable.push(true);
1176        assert!(!prunable.is_chunk_aligned()); // 65 bits = partial chunk
1177
1178        // Test alignment with pruning
1179        let mut prunable: Prunable<4> = Prunable::new();
1180        prunable.push_chunk(&[1, 2, 3, 4]);
1181        prunable.push_chunk(&[5, 6, 7, 8]);
1182        prunable.push_chunk(&[9, 10, 11, 12]);
1183        assert!(prunable.is_chunk_aligned()); // 96 bits = 3 chunks
1184
1185        // Prune first chunk - still aligned (64 bits remaining)
1186        prunable.prune_to_bit(32);
1187        assert!(prunable.is_chunk_aligned());
1188        assert_eq!(prunable.len(), 96);
1189
1190        // Add a partial chunk
1191        prunable.push(true);
1192        prunable.push(false);
1193        assert!(!prunable.is_chunk_aligned()); // 98 bits total
1194
1195        // Prune to align again
1196        prunable.prune_to_bit(64);
1197        assert!(!prunable.is_chunk_aligned()); // 98 bits total (34 bits remaining)
1198
1199        // Test with new_with_pruned_chunks
1200        let prunable: Prunable<4> = Prunable::new_with_pruned_chunks(2).unwrap();
1201        assert!(prunable.is_chunk_aligned()); // 64 bits pruned, 0 bits in bitmap
1202
1203        let mut prunable: Prunable<4> = Prunable::new_with_pruned_chunks(1).unwrap();
1204        assert!(prunable.is_chunk_aligned()); // 32 bits pruned, 0 bits in bitmap
1205        prunable.push(true);
1206        assert!(!prunable.is_chunk_aligned()); // 33 bits total
1207
1208        // Test with push_byte
1209        let mut prunable: Prunable<4> = Prunable::new();
1210        for _ in 0..4 {
1211            prunable.push_byte(0xFF);
1212        }
1213        assert!(prunable.is_chunk_aligned()); // 32 bits = 1 chunk
1214
1215        // Test after pop
1216        prunable.pop();
1217        assert!(!prunable.is_chunk_aligned()); // 31 bits
1218
1219        // Pop back to alignment
1220        for _ in 0..31 {
1221            prunable.pop();
1222        }
1223        assert!(prunable.is_chunk_aligned()); // 0 bits
1224    }
1225}