Skip to main content

commonware_storage/bitmap/
authenticated.rs

1//! An authenticated bitmap.
2//!
3//! The authenticated bitmap is an in-memory data structure that does not persist its contents other
4//! than the data corresponding to its "pruned" section, allowing full restoration by "replaying"
5//! all retained elements.
6//!
7//! Authentication is provided by a Merkle tree that is maintained over the bitmap, with each leaf
8//! covering a chunk of N bytes. This Merkle tree isn't balanced, but instead mimics the structure
9//! of an MMR with an equivalent number of leaves. This structure reduces overhead of updating the
10//! most recently added elements, and (more importantly) simplifies aligning the bitmap with an MMR
11//! over elements whose activity state is reflected by the bitmap.
12
13use crate::{
14    merkle::{
15        batch::MIN_TO_PARALLELIZE,
16        hasher::Hasher,
17        mmr::{
18            self,
19            mem::{Config, Mmr},
20            verification, Error, Location, Position, Proof,
21        },
22        storage::Storage,
23        Family as _,
24    },
25    metadata::{Config as MConfig, Metadata},
26    Context,
27};
28use commonware_codec::DecodeExt;
29use commonware_cryptography::Digest;
30use commonware_parallel::ThreadPool;
31use commonware_utils::{
32    bitmap::{BitMap as UtilsBitMap, Prunable as PrunableBitMap},
33    sequence::prefixed_u64::U64,
34};
35use rayon::prelude::*;
36use std::collections::HashSet;
37use tracing::{debug, error, warn};
38
39/// Returns a root digest that incorporates bits not yet part of the MMR because they
40/// belong to the last (unfilled) chunk.
41pub(crate) fn partial_chunk_root<H: Hasher<mmr::Family>, const N: usize>(
42    hasher: &H,
43    mmr_root: &H::Digest,
44    next_bit: u64,
45    last_chunk_digest: &H::Digest,
46) -> H::Digest {
47    assert!(next_bit > 0);
48    assert!(next_bit < UtilsBitMap::<N>::CHUNK_SIZE_BITS);
49    let next_bit = next_bit.to_be_bytes();
50    hasher.hash([
51        mmr_root.as_ref(),
52        next_bit.as_slice(),
53        last_chunk_digest.as_ref(),
54    ])
55}
56
/// Private module hosting the sealing trait so that downstream crates cannot implement
/// traits that require it as a supertrait.
mod private {
    /// Marker trait; implementable only from within this file's module tree.
    pub trait Sealed {}
}
60
61/// Trait for valid [BitMap] type states.
62pub trait State<D: Digest>: private::Sealed + Sized + Send + Sync {}
63
64/// Merkleized state: the bitmap has been merkleized and the root is cached.
65pub struct Merkleized<D: Digest> {
66    /// The cached root of the bitmap.
67    root: D,
68}
69
70impl<D: Digest> private::Sealed for Merkleized<D> {}
71impl<D: Digest> State<D> for Merkleized<D> {}
72
73/// Unmerkleized state: the bitmap has pending changes not yet merkleized.
74pub struct Unmerkleized {
75    /// Bitmap chunks that have been changed but whose changes are not yet reflected in the
76    /// root digest.
77    ///
78    /// Each dirty chunk is identified by its absolute index, including pruned chunks.
79    ///
80    /// Invariant: Indices are always in the range [pruned_chunks, authenticated_len).
81    dirty_chunks: HashSet<usize>,
82}
83
84impl private::Sealed for Unmerkleized {}
85impl<D: Digest> State<D> for Unmerkleized {}
86
87/// A merkleized bitmap whose root digest has been computed and cached.
88pub type MerkleizedBitMap<E, D, const N: usize> = BitMap<E, D, N, Merkleized<D>>;
89
90/// An unmerkleized bitmap whose root digest has not been computed.
91pub type UnmerkleizedBitMap<E, D, const N: usize> = BitMap<E, D, N, Unmerkleized>;
92
93/// A bitmap supporting inclusion proofs through Merkelization.
94///
95/// Merkelization of the bitmap is performed over chunks of N bytes. If the goal is to minimize
96/// proof sizes, choose an N that is equal to the size or double the size of the hasher's digest.
97///
98/// # Type States
99///
100/// The bitmap uses the type-state pattern to enforce at compile-time whether the bitmap has
101/// pending updates that must be merkleized before computing proofs. [MerkleizedBitMap] represents
102/// a bitmapwhose root digest has been computed and cached. [UnmerkleizedBitMap] represents a
103/// bitmap with pending updates. An unmerkleized bitmap can be converted into a merkleized bitmap
104/// by calling [UnmerkleizedBitMap::merkleize].
105///
106/// # Warning
107///
108/// Even though we use u64 identifiers for bits, on 32-bit machines, the maximum addressable bit is
109/// limited to (u32::MAX * N * 8).
110pub struct BitMap<E: Context, D: Digest, const N: usize, S: State<D> = Merkleized<D>> {
111    /// The underlying bitmap.
112    bitmap: PrunableBitMap<N>,
113
114    /// Invariant: Chunks in range [0, authenticated_len) are in `mmr`.
115    /// This is an absolute index that includes pruned chunks.
116    authenticated_len: usize,
117
118    /// A Merkle tree with each leaf representing an N*8 bit "chunk" of the bitmap.
119    ///
120    /// After calling `merkleize` all chunks are guaranteed to be included in the Merkle tree. The
121    /// last chunk of the bitmap is never part of the tree.
122    ///
123    /// Because leaf elements can be updated when bits in the bitmap are flipped, this tree, while
124    /// based on an MMR structure, is not an MMR but a Merkle tree. The MMR structure results in
125    /// reduced update overhead for elements being appended or updated near the tip compared to a
126    /// more typical balanced Merkle tree.
127    mmr: Mmr<D>,
128
129    /// The thread pool to use for parallelization.
130    pool: Option<ThreadPool>,
131
132    /// Merkleization-dependent state.
133    state: S,
134
135    /// Metadata for persisting pruned state.
136    metadata: Metadata<E, U64, Vec<u8>>,
137}
138
/// Metadata key prefix under which pinned node digests are stored.
const NODE_PREFIX: u8 = 0;

/// Metadata key prefix under which the number of pruned chunks is stored.
const PRUNED_CHUNKS_PREFIX: u8 = 1;
144
145impl<E: Context, D: Digest, const N: usize, S: State<D>> BitMap<E, D, N, S> {
146    /// The size of a chunk in bits.
147    pub const CHUNK_SIZE_BITS: u64 = PrunableBitMap::<N>::CHUNK_SIZE_BITS;
148
149    /// Return the size of the bitmap in bits.
150    #[inline]
151    pub fn size(&self) -> Position {
152        self.mmr.size()
153    }
154
155    /// Return the number of bits currently stored in the bitmap, irrespective of any pruning.
156    #[inline]
157    pub const fn len(&self) -> u64 {
158        self.bitmap.len()
159    }
160
161    /// Returns true if the bitmap is empty.
162    #[inline]
163    pub const fn is_empty(&self) -> bool {
164        self.len() == 0
165    }
166
167    /// Return the number of bits that have been pruned from this bitmap.
168    #[inline]
169    pub const fn pruned_bits(&self) -> u64 {
170        self.bitmap.pruned_bits()
171    }
172
173    /// Returns the number of complete chunks (excludes partial chunk at end, if any).
174    /// The returned index is absolute and includes pruned chunks.
175    #[inline]
176    fn complete_chunks(&self) -> usize {
177        let chunks_len = self.bitmap.chunks_len();
178        if self.bitmap.is_chunk_aligned() {
179            chunks_len
180        } else {
181            // Last chunk is partial
182            chunks_len.checked_sub(1).unwrap()
183        }
184    }
185
186    /// Return the last chunk of the bitmap and its size in bits. The size can be 0 (meaning the
187    /// last chunk is empty).
188    #[inline]
189    pub fn last_chunk(&self) -> (&[u8; N], u64) {
190        self.bitmap.last_chunk()
191    }
192
193    /// Returns the bitmap chunk containing the specified bit.
194    ///
195    /// # Warning
196    ///
197    /// Panics if the bit doesn't exist or has been pruned.
198    #[inline]
199    pub fn get_chunk_containing(&self, bit: u64) -> &[u8; N] {
200        self.bitmap.get_chunk_containing(bit)
201    }
202
203    /// Get the value of a bit.
204    ///
205    /// # Warning
206    ///
207    /// Panics if the bit doesn't exist or has been pruned.
208    #[inline]
209    pub fn get_bit(&self, bit: u64) -> bool {
210        self.bitmap.get_bit(bit)
211    }
212
213    /// Get the value of a bit from its chunk.
214    /// `bit` is an index into the entire bitmap, not just the chunk.
215    #[inline]
216    pub const fn get_bit_from_chunk(chunk: &[u8; N], bit: u64) -> bool {
217        PrunableBitMap::<N>::get_bit_from_chunk(chunk, bit)
218    }
219
220    /// Verify whether `proof` proves that the `chunk` containing the given bit belongs to the
221    /// bitmap corresponding to `root`.
222    pub fn verify_bit_inclusion(
223        hasher: &impl Hasher<mmr::Family, Digest = D>,
224        proof: &Proof<D>,
225        chunk: &[u8; N],
226        bit: u64,
227        root: &D,
228    ) -> bool {
229        let bit_len = *proof.leaves;
230        if bit >= bit_len {
231            debug!(bit_len, bit, "tried to verify non-existent bit");
232            return false;
233        }
234
235        // The chunk index should always be < MAX_LEAVES.
236        let chunked_leaves = Location::new(PrunableBitMap::<N>::to_chunk_index(bit_len) as u64);
237        let mut mmr_proof = Proof {
238            leaves: chunked_leaves,
239            digests: proof.digests.clone(),
240        };
241
242        let loc = Location::new(PrunableBitMap::<N>::to_chunk_index(bit) as u64);
243        if bit_len.is_multiple_of(Self::CHUNK_SIZE_BITS) {
244            return mmr_proof.verify_element_inclusion(hasher, chunk, loc, root);
245        }
246
247        if proof.digests.is_empty() {
248            debug!("proof has no digests");
249            return false;
250        }
251        let last_digest = mmr_proof.digests.pop().unwrap();
252
253        if chunked_leaves == loc {
254            // The proof is over a bit in the partial chunk. In this case the proof's only digest
255            // should be the MMR's root, otherwise it is invalid. Since we've popped off the last
256            // digest already, there should be no remaining digests.
257            if !mmr_proof.digests.is_empty() {
258                debug!(
259                    digests = mmr_proof.digests.len() + 1,
260                    "proof over partial chunk should have exactly 1 digest"
261                );
262                return false;
263            }
264            let last_chunk_digest = hasher.digest(chunk);
265            let next_bit = bit_len % Self::CHUNK_SIZE_BITS;
266            let reconstructed_root =
267                partial_chunk_root::<_, N>(hasher, &last_digest, next_bit, &last_chunk_digest);
268            return reconstructed_root == *root;
269        };
270
271        // For the case where the proof is over a bit in a full chunk, `last_digest` contains the
272        // digest of that chunk.
273        let mmr_root = match mmr_proof.reconstruct_root(hasher, &[chunk], loc) {
274            Ok(root) => root,
275            Err(error) => {
276                debug!(error = ?error, "invalid proof input");
277                return false;
278            }
279        };
280
281        let next_bit = bit_len % Self::CHUNK_SIZE_BITS;
282        let reconstructed_root =
283            partial_chunk_root::<_, N>(hasher, &mmr_root, next_bit, &last_digest);
284
285        reconstructed_root == *root
286    }
287}
288
289impl<E: Context, D: Digest, const N: usize> MerkleizedBitMap<E, D, N> {
290    /// Initialize a bitmap from the metadata in the given partition. If the partition is empty,
291    /// returns an empty bitmap. Otherwise restores the pruned state (the caller must replay
292    /// retained elements to restore its full state).
293    ///
294    /// Returns an error if the bitmap could not be restored, e.g. because of data corruption or
295    /// underlying storage error.
296    pub async fn init(
297        context: E,
298        partition: &str,
299        pool: Option<ThreadPool>,
300        hasher: &impl Hasher<mmr::Family, Digest = D>,
301    ) -> Result<Self, Error> {
302        let metadata_cfg = MConfig {
303            partition: partition.into(),
304            codec_config: ((0..).into(), ()),
305        };
306        let metadata =
307            Metadata::<_, U64, Vec<u8>>::init(context.with_label("metadata"), metadata_cfg).await?;
308
309        let key: U64 = U64::new(PRUNED_CHUNKS_PREFIX, 0);
310        let pruned_chunks = match metadata.get(&key) {
311            Some(bytes) => u64::from_be_bytes(bytes.as_slice().try_into().map_err(|_| {
312                error!("pruned chunks value not a valid u64");
313                Error::DataCorrupted("pruned chunks value not a valid u64")
314            })?),
315            None => {
316                warn!("bitmap metadata does not contain pruned chunks, initializing as empty");
317                0
318            }
319        } as usize;
320        if pruned_chunks == 0 {
321            let mmr = Mmr::new(hasher);
322            let cached_root = *mmr.root();
323            return Ok(Self {
324                bitmap: PrunableBitMap::new(),
325                authenticated_len: 0,
326                mmr,
327                pool,
328                metadata,
329                state: Merkleized { root: cached_root },
330            });
331        }
332        let pruned_loc = Location::new(pruned_chunks as u64);
333        if !pruned_loc.is_valid() {
334            return Err(Error::DataCorrupted("pruned chunks exceeds MAX_LEAVES"));
335        }
336
337        let mut pinned_nodes = Vec::new();
338        for (index, pos) in mmr::Family::nodes_to_pin(pruned_loc).enumerate() {
339            let Some(bytes) = metadata.get(&U64::new(NODE_PREFIX, index as u64)) else {
340                error!(?pruned_loc, ?pos, "missing pinned node");
341                return Err(Error::MissingNode(pos));
342            };
343            let digest = D::decode(bytes.as_ref());
344            let Ok(digest) = digest else {
345                error!(?pruned_loc, ?pos, "could not convert node bytes to digest");
346                return Err(Error::MissingNode(pos));
347            };
348            pinned_nodes.push(digest);
349        }
350
351        let mmr = Mmr::init(
352            Config {
353                nodes: Vec::new(),
354                pruning_boundary: Location::new(pruned_chunks as u64),
355                pinned_nodes,
356            },
357            hasher,
358        )?;
359
360        let bitmap = PrunableBitMap::new_with_pruned_chunks(pruned_chunks)
361            .expect("pruned_chunks should never overflow");
362        let cached_root = *mmr.root();
363        Ok(Self {
364            bitmap,
365            // Pruned chunks are already authenticated in the MMR
366            authenticated_len: pruned_chunks,
367            mmr,
368            pool,
369            metadata,
370            state: Merkleized { root: cached_root },
371        })
372    }
373
374    pub fn get_node(&self, position: Position) -> Option<D> {
375        self.mmr.get_node(position)
376    }
377
378    /// Write the information necessary to restore the bitmap in its fully pruned state at its last
379    /// pruning boundary. Restoring the entire bitmap state is then possible by replaying the
380    /// retained elements.
381    pub async fn write_pruned(&mut self) -> Result<(), Error> {
382        self.metadata.clear();
383
384        // Write the number of pruned chunks.
385        let key = U64::new(PRUNED_CHUNKS_PREFIX, 0);
386        self.metadata
387            .put(key, self.bitmap.pruned_chunks().to_be_bytes().to_vec());
388
389        // Write the pinned nodes.
390        let pruned_loc = Location::new(self.bitmap.pruned_chunks() as u64);
391        assert!(
392            pruned_loc.is_valid(),
393            "expected valid location from pruned_chunks"
394        );
395        for (i, digest) in mmr::Family::nodes_to_pin(pruned_loc).enumerate() {
396            let digest = self.mmr.get_node_unchecked(digest);
397            let key = U64::new(NODE_PREFIX, i as u64);
398            self.metadata.put(key, digest.to_vec());
399        }
400
401        self.metadata.sync().await.map_err(Error::Metadata)
402    }
403
404    /// Destroy the bitmap metadata from disk.
405    pub async fn destroy(self) -> Result<(), Error> {
406        self.metadata.destroy().await.map_err(Error::Metadata)
407    }
408
409    /// Prune all complete chunks before the chunk containing the given bit.
410    ///
411    /// The chunk containing `bit` and all subsequent chunks are retained. All chunks
412    /// before it are pruned from the bitmap and the underlying MMR.
413    ///
414    /// If `bit` equals the bitmap length, this prunes all complete chunks while retaining
415    /// the empty trailing chunk, preparing the bitmap for appending new data.
416    pub fn prune_to_bit(&mut self, bit: u64) -> Result<(), Error> {
417        let chunk = PrunableBitMap::<N>::to_chunk_index(bit);
418        if chunk < self.bitmap.pruned_chunks() {
419            return Ok(());
420        }
421
422        // Prune inner bitmap
423        self.bitmap.prune_to_bit(bit);
424
425        // Update authenticated length
426        self.authenticated_len = self.complete_chunks();
427
428        self.mmr.prune(Location::new(chunk as u64))?;
429        Ok(())
430    }
431
432    /// Return the cached root digest against which inclusion proofs can be verified.
433    ///
434    /// # Format
435    ///
436    /// The root digest is simply that of the underlying MMR whenever the bit count falls on a chunk
437    /// boundary. Otherwise, the root is computed as follows in order to capture the bits that are
438    /// not yet part of the MMR:
439    ///
440    /// hash(mmr_root || next_bit as u64 be_bytes || last_chunk_digest)
441    ///
442    /// The root is computed during merkleization and cached, so this method is cheap to call.
443    pub const fn root(&self) -> D {
444        self.state.root
445    }
446
447    /// Return an inclusion proof for the specified bit, along with the chunk of the bitmap
448    /// containing that bit. The proof can be used to prove any bit in the chunk.
449    ///
450    /// The bitmap proof stores the number of bits in the bitmap within the `size` field of the
451    /// proof instead of MMR size since the underlying MMR's size does not reflect the number of
452    /// bits in any partial chunk. The underlying MMR size can be derived from the number of
453    /// bits as `leaf_num_to_pos(proof.size / BitMap<_, N>::CHUNK_SIZE_BITS)`.
454    ///
455    /// # Errors
456    ///
457    /// Returns [Error::BitOutOfBounds] if `bit` is out of bounds.
458    pub async fn proof(
459        &self,
460        hasher: &impl Hasher<mmr::Family, Digest = D>,
461        bit: u64,
462    ) -> Result<(Proof<D>, [u8; N]), Error> {
463        if bit >= self.len() {
464            return Err(Error::BitOutOfBounds(bit, self.len()));
465        }
466
467        let chunk = *self.get_chunk_containing(bit);
468        let chunk_loc = Location::from(PrunableBitMap::<N>::to_chunk_index(bit));
469        let (last_chunk, next_bit) = self.bitmap.last_chunk();
470
471        if chunk_loc == self.mmr.leaves() {
472            assert!(next_bit > 0);
473            // Proof is over a bit in the partial chunk. In this case only a single digest is
474            // required in the proof: the mmr's root.
475            return Ok((
476                Proof {
477                    leaves: Location::new(self.len()),
478                    digests: vec![*self.mmr.root()],
479                },
480                chunk,
481            ));
482        }
483
484        let range = chunk_loc..chunk_loc + 1;
485        let mut proof = verification::range_proof(hasher, &self.mmr, range).await?;
486        proof.leaves = Location::new(self.len());
487        if next_bit == Self::CHUNK_SIZE_BITS {
488            // Bitmap is chunk aligned.
489            return Ok((proof, chunk));
490        }
491
492        // Since the bitmap wasn't chunk aligned, we'll need to include the digest of the last chunk
493        // in the proof to be able to re-derive the root.
494        let last_chunk_digest = hasher.digest(last_chunk);
495        proof.digests.push(last_chunk_digest);
496
497        Ok((proof, chunk))
498    }
499
500    /// Convert this merkleized bitmap into an unmerkleized bitmap without making any changes to it.
501    pub fn into_dirty(self) -> UnmerkleizedBitMap<E, D, N> {
502        UnmerkleizedBitMap {
503            bitmap: self.bitmap,
504            authenticated_len: self.authenticated_len,
505            mmr: self.mmr,
506            pool: self.pool,
507            state: Unmerkleized {
508                dirty_chunks: HashSet::new(),
509            },
510            metadata: self.metadata,
511        }
512    }
513}
514
515impl<E: Context, D: Digest, const N: usize> UnmerkleizedBitMap<E, D, N> {
516    /// Add a single bit to the end of the bitmap.
517    ///
518    /// # Warning
519    ///
520    /// The update will not affect the root until `merkleize` is called.
521    pub fn push(&mut self, bit: bool) {
522        self.bitmap.push(bit);
523    }
524
525    /// Set the value of the given bit.
526    ///
527    /// # Warning
528    ///
529    /// The update will not impact the root until `merkleize` is called.
530    pub fn set_bit(&mut self, bit: u64, value: bool) {
531        // Apply the change to the inner bitmap
532        self.bitmap.set_bit(bit, value);
533
534        // If the updated chunk is already in the MMR, mark it as dirty.
535        let chunk = PrunableBitMap::<N>::to_chunk_index(bit);
536        if chunk < self.authenticated_len {
537            self.state.dirty_chunks.insert(chunk);
538        }
539    }
540
541    /// The chunks that have been modified or added since the last call to `merkleize`.
542    pub fn dirty_chunks(&self) -> Vec<Location> {
543        let mut chunks: Vec<Location> = self
544            .state
545            .dirty_chunks
546            .iter()
547            .map(|&chunk| Location::new(chunk as u64))
548            .collect();
549
550        // Include complete chunks that haven't been authenticated yet
551        for i in self.authenticated_len..self.complete_chunks() {
552            chunks.push(Location::new(i as u64));
553        }
554
555        chunks
556    }
557
558    /// Merkleize all updates not yet reflected in the bitmap's root.
559    pub fn merkleize(
560        mut self,
561        hasher: &impl Hasher<mmr::Family, Digest = D>,
562    ) -> Result<MerkleizedBitMap<E, D, N>, Error> {
563        // Add newly pushed complete chunks to the batch.
564        let mut batch = self.mmr.new_batch().with_pool(self.pool.clone());
565        let start = self.authenticated_len;
566        let end = self.complete_chunks();
567        for i in start..end {
568            batch = batch.add(hasher, self.bitmap.get_chunk(i));
569        }
570        self.authenticated_len = end;
571
572        // Pre-hash dirty chunks into digests and update in the batch.
573        let dirty: Vec<(Location, D)> = {
574            let updates: Vec<(Location, &[u8; N])> = self
575                .state
576                .dirty_chunks
577                .iter()
578                .map(|&chunk| {
579                    let loc = Location::new(chunk as u64);
580                    (loc, self.bitmap.get_chunk(chunk))
581                })
582                .collect();
583
584            match self.pool.as_ref() {
585                Some(pool) if updates.len() >= MIN_TO_PARALLELIZE => pool.install(|| {
586                    updates
587                        .par_iter()
588                        .map_init(
589                            || hasher.clone(),
590                            |h, &(loc, chunk)| {
591                                let pos = Position::try_from(loc).unwrap();
592                                (loc, h.leaf_digest(pos, chunk.as_ref()))
593                            },
594                        )
595                        .collect()
596                }),
597                _ => updates
598                    .iter()
599                    .map(|&(loc, chunk)| {
600                        let pos = Position::try_from(loc).unwrap();
601                        (loc, hasher.leaf_digest(pos, chunk.as_ref()))
602                    })
603                    .collect(),
604            }
605        };
606        batch = batch.update_leaf_batched(&dirty)?;
607
608        // Merkleize and apply.
609        let batch = batch.merkleize(&self.mmr, hasher);
610        self.mmr.apply_batch(&batch)?;
611
612        // Compute the bitmap root.
613        let mmr_root = *self.mmr.root();
614        let cached_root = if self.bitmap.is_chunk_aligned() {
615            mmr_root
616        } else {
617            let (last_chunk, next_bit) = self.bitmap.last_chunk();
618            let last_chunk_digest = hasher.digest(last_chunk);
619            partial_chunk_root::<_, N>(hasher, &mmr_root, next_bit, &last_chunk_digest)
620        };
621
622        Ok(MerkleizedBitMap {
623            bitmap: self.bitmap,
624            authenticated_len: self.authenticated_len,
625            mmr: self.mmr,
626            pool: self.pool,
627            metadata: self.metadata,
628            state: Merkleized { root: cached_root },
629        })
630    }
631}
632
633impl<E: Context, D: Digest, const N: usize> Storage<mmr::Family> for MerkleizedBitMap<E, D, N> {
634    type Digest = D;
635
636    async fn size(&self) -> Position {
637        self.size()
638    }
639
640    async fn get_node(&self, position: Position) -> Result<Option<D>, Error> {
641        Ok(self.get_node(position))
642    }
643}
644
645#[cfg(test)]
646mod tests {
647    use super::*;
648    use commonware_codec::FixedSize;
649    use commonware_cryptography::{sha256, Hasher, Sha256};
650    use commonware_macros::test_traced;
651    use commonware_runtime::{deterministic, Metrics, Runner as _};
652    use mmr::StandardHasher;
653
654    const SHA256_SIZE: usize = sha256::Digest::SIZE;
655
656    type TestContext = deterministic::Context;
657    type TestMerkleizedBitMap<const N: usize> = MerkleizedBitMap<TestContext, sha256::Digest, N>;
658
659    impl<E: Context, D: Digest, const N: usize> UnmerkleizedBitMap<E, D, N> {
660        // Add a byte's worth of bits to the bitmap.
661        //
662        // # Warning
663        //
664        // - The update will not impact the root until `merkleize` is called.
665        //
666        // - Assumes self.next_bit is currently byte aligned, and panics otherwise.
667        fn push_byte(&mut self, byte: u8) {
668            self.bitmap.push_byte(byte);
669        }
670
671        /// Add a chunk of bits to the bitmap.
672        ///
673        /// # Warning
674        ///
675        /// - The update will not impact the root until `merkleize` is called.
676        ///
677        /// - Panics if self.next_bit is not chunk aligned.
678        fn push_chunk(&mut self, chunk: &[u8; N]) {
679            self.bitmap.push_chunk(chunk);
680        }
681    }
682
683    fn test_chunk<const N: usize>(s: &[u8]) -> [u8; N] {
684        assert_eq!(N % 32, 0);
685        let mut vec: Vec<u8> = Vec::new();
686        for _ in 0..N / 32 {
687            vec.extend(Sha256::hash(s).iter());
688        }
689
690        vec.try_into().unwrap()
691    }
692
693    #[test_traced]
694    fn test_bitmap_verify_empty_proof() {
695        let executor = deterministic::Runner::default();
696        executor.start(|_context| async move {
697            let hasher = StandardHasher::<Sha256>::new();
698            let proof = Proof {
699                leaves: Location::new(100),
700                digests: Vec::new(),
701            };
702            assert!(
703                !TestMerkleizedBitMap::<SHA256_SIZE>::verify_bit_inclusion(
704                    &hasher,
705                    &proof,
706                    &[0u8; SHA256_SIZE],
707                    0,
708                    &Sha256::fill(0x00),
709                ),
710                "proof without digests shouldn't verify or panic"
711            );
712        });
713    }
714
715    #[test_traced]
716    fn test_bitmap_empty_then_one() {
717        let executor = deterministic::Runner::default();
718        executor.start(|context| async move {
719            let hasher = StandardHasher::<Sha256>::new();
720            let mut bitmap: TestMerkleizedBitMap<SHA256_SIZE> =
721                TestMerkleizedBitMap::init(context.with_label("bitmap"), "test", None, &hasher)
722                    .await
723                    .unwrap();
724            assert_eq!(bitmap.len(), 0);
725            assert_eq!(bitmap.bitmap.pruned_chunks(), 0);
726            bitmap.prune_to_bit(0).unwrap();
727            assert_eq!(bitmap.bitmap.pruned_chunks(), 0);
728
729            // Add a single bit
730            let root = bitmap.root();
731            let mut dirty = bitmap.into_dirty();
732            dirty.push(true);
733            bitmap = dirty.merkleize(&hasher).unwrap();
734            // Root should change
735            let new_root = bitmap.root();
736            assert_ne!(root, new_root);
737            let root = new_root;
738            bitmap.prune_to_bit(1).unwrap();
739            assert_eq!(bitmap.len(), 1);
740            assert_ne!(bitmap.last_chunk().0, &[0u8; SHA256_SIZE]);
741            assert_eq!(bitmap.last_chunk().1, 1);
742            // Pruning should be a no-op since we're not beyond a chunk boundary.
743            assert_eq!(bitmap.bitmap.pruned_chunks(), 0);
744            assert_eq!(root, bitmap.root());
745
746            // Fill up a full chunk
747            let mut dirty = bitmap.into_dirty();
748            for i in 0..(TestMerkleizedBitMap::<SHA256_SIZE>::CHUNK_SIZE_BITS - 1) {
749                dirty.push(i % 2 != 0);
750            }
751            bitmap = dirty.merkleize(&hasher).unwrap();
752            assert_eq!(bitmap.len(), 256);
753            assert_ne!(root, bitmap.root());
754            let root = bitmap.root();
755
756            // Chunk should be provable.
757            let (proof, chunk) = bitmap.proof(&hasher, 0).await.unwrap();
758            assert!(
759                TestMerkleizedBitMap::<SHA256_SIZE>::verify_bit_inclusion(
760                    &hasher, &proof, &chunk, 255, &root
761                ),
762                "failed to prove bit in only chunk"
763            );
764            // bit outside range should not verify
765            assert!(
766                !TestMerkleizedBitMap::<SHA256_SIZE>::verify_bit_inclusion(
767                    &hasher, &proof, &chunk, 256, &root
768                ),
769                "should not be able to prove bit outside of chunk"
770            );
771
772            // Now pruning all bits should matter.
773            bitmap.prune_to_bit(256).unwrap();
774            assert_eq!(bitmap.len(), 256);
775            assert_eq!(bitmap.bitmap.pruned_chunks(), 1);
776            assert_eq!(bitmap.bitmap.pruned_bits(), 256);
777            assert_eq!(root, bitmap.root());
778
779            // Pruning to an earlier point should be a no-op.
780            bitmap.prune_to_bit(10).unwrap();
781            assert_eq!(root, bitmap.root());
782        });
783    }
784
785    #[test_traced]
786    fn test_bitmap_building() {
787        // Build the same bitmap with 2 chunks worth of bits in multiple ways and make sure they are
788        // equivalent based on their roots.
789        let executor = deterministic::Runner::default();
790        executor.start(|context| async move {
791            let test_chunk = test_chunk(b"test");
792            let hasher: StandardHasher<Sha256> = StandardHasher::new();
793
794            // Add each bit one at a time after the first chunk.
795            let bitmap: TestMerkleizedBitMap<SHA256_SIZE> =
796                TestMerkleizedBitMap::init(context.with_label("bitmap1"), "test1", None, &hasher)
797                    .await
798                    .unwrap();
799            let mut dirty = bitmap.into_dirty();
800            dirty.push_chunk(&test_chunk);
801            for b in test_chunk {
802                for j in 0..8 {
803                    let mask = 1 << j;
804                    let bit = (b & mask) != 0;
805                    dirty.push(bit);
806                }
807            }
808            assert_eq!(dirty.len(), 256 * 2);
809
810            let bitmap = dirty.merkleize(&hasher).unwrap();
811            let root = bitmap.root();
812            let inner_root = *bitmap.mmr.root();
813            assert_eq!(root, inner_root);
814
815            {
816                // Repeat the above MMR build only using push_chunk instead, and make
817                // sure root digests match.
818                let bitmap: TestMerkleizedBitMap<SHA256_SIZE> = TestMerkleizedBitMap::init(
819                    context.with_label("bitmap2"),
820                    "test2",
821                    None,
822                    &hasher,
823                )
824                .await
825                .unwrap();
826                let mut dirty = bitmap.into_dirty();
827                dirty.push_chunk(&test_chunk);
828                dirty.push_chunk(&test_chunk);
829                let bitmap = dirty.merkleize(&hasher).unwrap();
830                let same_root = bitmap.root();
831                assert_eq!(root, same_root);
832            }
833            {
834                // Repeat build again using push_byte this time.
835                let bitmap: TestMerkleizedBitMap<SHA256_SIZE> = TestMerkleizedBitMap::init(
836                    context.with_label("bitmap3"),
837                    "test3",
838                    None,
839                    &hasher,
840                )
841                .await
842                .unwrap();
843                let mut dirty = bitmap.into_dirty();
844                dirty.push_chunk(&test_chunk);
845                for b in test_chunk {
846                    dirty.push_byte(b);
847                }
848                let bitmap = dirty.merkleize(&hasher).unwrap();
849                let same_root = bitmap.root();
850                assert_eq!(root, same_root);
851            }
852        });
853    }
854
855    #[test_traced]
856    #[should_panic(expected = "cannot add chunk")]
857    fn test_bitmap_build_chunked_panic() {
858        let executor = deterministic::Runner::default();
859        executor.start(|context| async move {
860            let hasher: StandardHasher<Sha256> = StandardHasher::new();
861            let bitmap: TestMerkleizedBitMap<SHA256_SIZE> =
862                TestMerkleizedBitMap::init(context.with_label("bitmap"), "test", None, &hasher)
863                    .await
864                    .unwrap();
865            let mut dirty = bitmap.into_dirty();
866            dirty.push_chunk(&test_chunk(b"test"));
867            dirty.push(true);
868            dirty.push_chunk(&test_chunk(b"panic"));
869        });
870    }
871
872    #[test_traced]
873    #[should_panic(expected = "cannot add byte")]
874    fn test_bitmap_build_byte_panic() {
875        let executor = deterministic::Runner::default();
876        executor.start(|context| async move {
877            let hasher: StandardHasher<Sha256> = StandardHasher::new();
878            let bitmap: TestMerkleizedBitMap<SHA256_SIZE> =
879                TestMerkleizedBitMap::init(context.with_label("bitmap"), "test", None, &hasher)
880                    .await
881                    .unwrap();
882            let mut dirty = bitmap.into_dirty();
883            dirty.push_chunk(&test_chunk(b"test"));
884            dirty.push(true);
885            dirty.push_byte(0x01);
886        });
887    }
888
889    #[test_traced]
890    #[should_panic(expected = "out of bounds")]
891    fn test_bitmap_get_out_of_bounds_bit_panic() {
892        let executor = deterministic::Runner::default();
893        executor.start(|context| async move {
894            let hasher: StandardHasher<Sha256> = StandardHasher::new();
895            let bitmap: TestMerkleizedBitMap<SHA256_SIZE> =
896                TestMerkleizedBitMap::init(context.with_label("bitmap"), "test", None, &hasher)
897                    .await
898                    .unwrap();
899            let mut dirty = bitmap.into_dirty();
900            dirty.push_chunk(&test_chunk(b"test"));
901            dirty.get_bit(256);
902        });
903    }
904
905    #[test_traced]
906    #[should_panic(expected = "pruned")]
907    fn test_bitmap_get_pruned_bit_panic() {
908        let executor = deterministic::Runner::default();
909        executor.start(|context| async move {
910            let hasher: StandardHasher<Sha256> = StandardHasher::new();
911            let bitmap: TestMerkleizedBitMap<SHA256_SIZE> =
912                TestMerkleizedBitMap::init(context.with_label("bitmap"), "test", None, &hasher)
913                    .await
914                    .unwrap();
915            let mut dirty = bitmap.into_dirty();
916            dirty.push_chunk(&test_chunk(b"test"));
917            dirty.push_chunk(&test_chunk(b"test2"));
918            let mut bitmap = dirty.merkleize(&hasher).unwrap();
919
920            bitmap.prune_to_bit(256).unwrap();
921            bitmap.get_bit(255);
922        });
923    }
924
925    #[test_traced]
926    fn test_bitmap_root_boundaries() {
927        let executor = deterministic::Runner::default();
928        executor.start(|context| async move {
929            // Build a starting test MMR with two chunks worth of bits.
930            let hasher = StandardHasher::<Sha256>::new();
931            let bitmap: TestMerkleizedBitMap<SHA256_SIZE> =
932                TestMerkleizedBitMap::init(context.with_label("bitmap"), "test", None, &hasher)
933                    .await
934                    .unwrap();
935            let mut dirty = bitmap.into_dirty();
936            dirty.push_chunk(&test_chunk(b"test"));
937            dirty.push_chunk(&test_chunk(b"test2"));
938            let mut bitmap = dirty.merkleize(&hasher).unwrap();
939
940            let root = bitmap.root();
941
942            // Confirm that root changes if we add a 1 bit, even though we won't fill a chunk.
943            let mut dirty = bitmap.into_dirty();
944            dirty.push(true);
945            bitmap = dirty.merkleize(&hasher).unwrap();
946            let new_root = bitmap.root();
947            assert_ne!(root, new_root);
948            assert_eq!(bitmap.mmr.size(), 3); // shouldn't include the trailing bits
949
950            // Add 0 bits to fill up entire chunk.
951            for _ in 0..(SHA256_SIZE * 8 - 1) {
952                let mut dirty = bitmap.into_dirty();
953                dirty.push(false);
954                bitmap = dirty.merkleize(&hasher).unwrap();
955                let newer_root = bitmap.root();
956                // root will change when adding 0s within the same chunk
957                assert_ne!(new_root, newer_root);
958            }
959            assert_eq!(bitmap.mmr.size(), 4); // chunk we filled should have been added to mmr
960
961            // Confirm the root changes when we add the next 0 bit since it's part of a new chunk.
962            let mut dirty = bitmap.into_dirty();
963            dirty.push(false);
964            assert_eq!(dirty.len(), 256 * 3 + 1);
965            bitmap = dirty.merkleize(&hasher).unwrap();
966            let newer_root = bitmap.root();
967            assert_ne!(new_root, newer_root);
968
969            // Confirm pruning everything doesn't affect the root.
970            bitmap.prune_to_bit(bitmap.len()).unwrap();
971            assert_eq!(bitmap.bitmap.pruned_chunks(), 3);
972            assert_eq!(bitmap.len(), 256 * 3 + 1);
973            assert_eq!(newer_root, bitmap.root());
974        });
975    }
976
977    #[test_traced]
978    fn test_bitmap_get_set_bits() {
979        let executor = deterministic::Runner::default();
980        executor.start(|context| async move {
981            // Build a test MMR with a few chunks worth of bits.
982            let hasher = StandardHasher::<Sha256>::new();
983            let bitmap: TestMerkleizedBitMap<SHA256_SIZE> =
984                TestMerkleizedBitMap::init(context.with_label("bitmap"), "test", None, &hasher)
985                    .await
986                    .unwrap();
987            let mut dirty = bitmap.into_dirty();
988            dirty.push_chunk(&test_chunk(b"test"));
989            dirty.push_chunk(&test_chunk(b"test2"));
990            dirty.push_chunk(&test_chunk(b"test3"));
991            dirty.push_chunk(&test_chunk(b"test4"));
992            // Add a few extra bits to exercise not being on a chunk or byte boundary.
993            dirty.push_byte(0xF1);
994            dirty.push(true);
995            dirty.push(false);
996            dirty.push(true);
997
998            let mut bitmap = dirty.merkleize(&hasher).unwrap();
999            let root = bitmap.root();
1000
1001            // Flip each bit and confirm the root changes, then flip it back to confirm it is safely
1002            // restored.
1003            for bit_pos in (0..bitmap.len()).rev() {
1004                let bit = bitmap.get_bit(bit_pos);
1005                let mut dirty = bitmap.into_dirty();
1006                dirty.set_bit(bit_pos, !bit);
1007                bitmap = dirty.merkleize(&hasher).unwrap();
1008                let new_root = bitmap.root();
1009                assert_ne!(root, new_root, "failed at bit {bit_pos}");
1010                // flip it back
1011                let mut dirty = bitmap.into_dirty();
1012                dirty.set_bit(bit_pos, bit);
1013                bitmap = dirty.merkleize(&hasher).unwrap();
1014                let new_root = bitmap.root();
1015                assert_eq!(root, new_root);
1016            }
1017
1018            // Repeat the test after pruning.
1019            let start_bit = (SHA256_SIZE * 8 * 2) as u64;
1020            bitmap.prune_to_bit(start_bit).unwrap();
1021            for bit_pos in (start_bit..bitmap.len()).rev() {
1022                let bit = bitmap.get_bit(bit_pos);
1023                let mut dirty = bitmap.into_dirty();
1024                dirty.set_bit(bit_pos, !bit);
1025                bitmap = dirty.merkleize(&hasher).unwrap();
1026                let new_root = bitmap.root();
1027                assert_ne!(root, new_root, "failed at bit {bit_pos}");
1028                // flip it back
1029                let mut dirty = bitmap.into_dirty();
1030                dirty.set_bit(bit_pos, bit);
1031                bitmap = dirty.merkleize(&hasher).unwrap();
1032                let new_root = bitmap.root();
1033                assert_eq!(root, new_root);
1034            }
1035        });
1036    }
1037
1038    fn flip_bit<const N: usize>(bit: u64, chunk: &[u8; N]) -> [u8; N] {
1039        let byte = PrunableBitMap::<N>::chunk_byte_offset(bit);
1040        let mask = PrunableBitMap::<N>::chunk_byte_bitmask(bit);
1041        let mut tmp = chunk.to_vec();
1042        tmp[byte] ^= mask;
1043        tmp.try_into().unwrap()
1044    }
1045
1046    #[test_traced]
1047    fn test_bitmap_mmr_proof_verification() {
1048        test_bitmap_mmr_proof_verification_n::<32>();
1049        test_bitmap_mmr_proof_verification_n::<64>();
1050    }
1051
1052    fn test_bitmap_mmr_proof_verification_n<const N: usize>() {
1053        let executor = deterministic::Runner::default();
1054        executor.start(|context| async move {
1055            // Build a bitmap with 10 chunks worth of bits.
1056            let hasher = StandardHasher::<Sha256>::new();
1057            let bitmap: MerkleizedBitMap<TestContext, sha256::Digest, N> =
1058                MerkleizedBitMap::init(context.with_label("bitmap"), "test", None, &hasher)
1059                    .await
1060                    .unwrap();
1061            let mut dirty = bitmap.into_dirty();
1062            for i in 0u32..10 {
1063                dirty.push_chunk(&test_chunk(format!("test{i}").as_bytes()));
1064            }
1065            // Add a few extra bits to exercise not being on a chunk or byte boundary.
1066            dirty.push_byte(0xA6);
1067            dirty.push(true);
1068            dirty.push(false);
1069            dirty.push(true);
1070            dirty.push(true);
1071            dirty.push(false);
1072
1073            let mut bitmap = dirty.merkleize(&hasher).unwrap();
1074            let root = bitmap.root();
1075
1076            // Make sure every bit is provable, even after pruning in intervals of 251 bits (251 is
1077            // the largest prime that is less than the size of one 32-byte chunk in bits).
1078            for prune_to_bit in (0..bitmap.len()).step_by(251) {
1079                assert_eq!(bitmap.root(), root);
1080                bitmap.prune_to_bit(prune_to_bit).unwrap();
1081                for i in prune_to_bit..bitmap.len() {
1082                    let (proof, chunk) = bitmap.proof(&hasher, i).await.unwrap();
1083
1084                    // Proof should verify for the original chunk containing the bit.
1085                    assert!(
1086                        MerkleizedBitMap::<TestContext, _, N>::verify_bit_inclusion(
1087                            &hasher, &proof, &chunk, i, &root
1088                        ),
1089                        "failed to prove bit {i}",
1090                    );
1091
1092                    // Flip the bit in the chunk and make sure the proof fails.
1093                    let corrupted = flip_bit(i, &chunk);
1094                    assert!(
1095                        !MerkleizedBitMap::<TestContext, _, N>::verify_bit_inclusion(
1096                            &hasher, &proof, &corrupted, i, &root
1097                        ),
1098                        "proving bit {i} after flipping should have failed",
1099                    );
1100                }
1101            }
1102        })
1103    }
1104
1105    #[test_traced]
1106    fn test_bitmap_persistence() {
1107        const PARTITION: &str = "bitmap-test";
1108        const FULL_CHUNK_COUNT: usize = 100;
1109
1110        let executor = deterministic::Runner::default();
1111        executor.start(|context| async move {
1112            let hasher = StandardHasher::<Sha256>::new();
1113            // Initializing from an empty partition should result in an empty bitmap.
1114            let bitmap: TestMerkleizedBitMap<SHA256_SIZE> =
1115                TestMerkleizedBitMap::init(context.with_label("initial"), PARTITION, None, &hasher)
1116                    .await
1117                    .unwrap();
1118            assert_eq!(bitmap.len(), 0);
1119
1120            // Add a non-trivial amount of data.
1121            let mut dirty = bitmap.into_dirty();
1122            for i in 0..FULL_CHUNK_COUNT {
1123                dirty.push_chunk(&test_chunk(format!("test{i}").as_bytes()));
1124            }
1125            let mut bitmap = dirty.merkleize(&hasher).unwrap();
1126            let chunk_aligned_root = bitmap.root();
1127
1128            // Add a few extra bits beyond the last chunk boundary.
1129            let mut dirty = bitmap.into_dirty();
1130            dirty.push_byte(0xA6);
1131            dirty.push(true);
1132            dirty.push(false);
1133            dirty.push(true);
1134            bitmap = dirty.merkleize(&hasher).unwrap();
1135            let root = bitmap.root();
1136
1137            // prune 10 chunks at a time and make sure replay will restore the bitmap every time.
1138            for i in (10..=FULL_CHUNK_COUNT).step_by(10) {
1139                bitmap
1140                    .prune_to_bit(
1141                        (i * TestMerkleizedBitMap::<SHA256_SIZE>::CHUNK_SIZE_BITS as usize) as u64,
1142                    )
1143                    .unwrap();
1144                bitmap.write_pruned().await.unwrap();
1145                bitmap = TestMerkleizedBitMap::init(
1146                    context.with_label(&format!("restore_{i}")),
1147                    PARTITION,
1148                    None,
1149                    &hasher,
1150                )
1151                .await
1152                .unwrap();
1153                let _ = bitmap.root();
1154
1155                // Replay missing chunks.
1156                let mut dirty = bitmap.into_dirty();
1157                for j in i..FULL_CHUNK_COUNT {
1158                    dirty.push_chunk(&test_chunk(format!("test{j}").as_bytes()));
1159                }
1160                assert_eq!(dirty.bitmap.pruned_chunks(), i);
1161                assert_eq!(dirty.len(), FULL_CHUNK_COUNT as u64 * 256);
1162                bitmap = dirty.merkleize(&hasher).unwrap();
1163                assert_eq!(bitmap.root(), chunk_aligned_root);
1164
1165                // Replay missing partial chunk.
1166                let mut dirty = bitmap.into_dirty();
1167                dirty.push_byte(0xA6);
1168                dirty.push(true);
1169                dirty.push(false);
1170                dirty.push(true);
1171                bitmap = dirty.merkleize(&hasher).unwrap();
1172                assert_eq!(bitmap.root(), root);
1173            }
1174        });
1175    }
1176
1177    #[test_traced]
1178    fn test_bitmap_proof_out_of_bounds() {
1179        let executor = deterministic::Runner::default();
1180        executor.start(|context| async move {
1181            let hasher = StandardHasher::<Sha256>::new();
1182            let bitmap: TestMerkleizedBitMap<SHA256_SIZE> =
1183                TestMerkleizedBitMap::init(context.with_label("bitmap"), "test", None, &hasher)
1184                    .await
1185                    .unwrap();
1186            let mut dirty = bitmap.into_dirty();
1187            dirty.push_chunk(&test_chunk(b"test"));
1188            let bitmap = dirty.merkleize(&hasher).unwrap();
1189
1190            // Proof for bit_offset >= bit_count should fail
1191            let result = bitmap.proof(&hasher, 256).await;
1192            assert!(matches!(result, Err(Error::BitOutOfBounds(offset, size))
1193                    if offset == 256 && size == 256));
1194
1195            let result = bitmap.proof(&hasher, 1000).await;
1196            assert!(matches!(result, Err(Error::BitOutOfBounds(offset, size))
1197                    if offset == 1000 && size == 256));
1198
1199            // Valid proof should work
1200            assert!(bitmap.proof(&hasher, 0).await.is_ok());
1201            assert!(bitmap.proof(&hasher, 255).await.is_ok());
1202        });
1203    }
1204}