commonware_utils/bitmap/historical/bitmap.rs

use super::{batch::Batch, Error};
use crate::bitmap::{historical::BatchGuard, Prunable};
#[cfg(not(feature = "std"))]
use alloc::{collections::BTreeMap, vec::Vec};
#[cfg(feature = "std")]
use std::collections::BTreeMap;

/// Type of change to a chunk.
#[derive(Clone, Debug)]
pub(super) enum ChunkDiff<const N: usize> {
    /// Chunk was modified (contains old value before the change).
    Modified([u8; N]),
    /// Chunk was removed from the right side (contains old value before removal).
    Removed([u8; N]),
    /// Chunk was added (did not exist before).
    Added,
    /// Chunk was pruned from the left side (contains old value before pruning).
    Pruned([u8; N]),
}

/// A reverse diff that describes the state before a commit.
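///
/// For example, if commit 7 appends three bits to an empty bitmap, its reverse
/// diff records `len: 0`, `pruned_chunks: 0`, and chunk 0 as `ChunkDiff::Added`:
/// exactly what is needed to undo that commit during reconstruction.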
#[derive(Clone, Debug)]
pub(super) struct CommitDiff<const N: usize> {
    /// Total length in bits before this commit.
    pub(super) len: u64,
    /// Number of pruned chunks before this commit.
    pub(super) pruned_chunks: usize,
    /// Chunk-level changes.
    pub(super) chunk_diffs: BTreeMap<usize, ChunkDiff<N>>,
}

/// A historical bitmap that maintains one actual bitmap plus diffs for history and batching.
///
/// Commit numbers must be strictly increasing and less than `u64::MAX`.
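///
/// # Examples
///
/// A typical lifecycle (batch, commit, reconstruct, prune):
///
/// ```
/// # use commonware_utils::bitmap::historical::BitMap;
/// let mut bitmap: BitMap<4> = BitMap::new();
///
/// bitmap.with_batch(1, |batch| { batch.push(true); }).unwrap();
/// bitmap.with_batch(2, |batch| { batch.set_bit(0, false); }).unwrap();
///
/// // Reconstruct the state at commit 1, then drop it from history.
/// assert!(bitmap.get_at_commit(1).unwrap().get_bit(0));
/// bitmap.prune_commits_before(2);
/// assert!(bitmap.get_at_commit(1).is_none());
/// ```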
pub struct BitMap<const N: usize> {
    /// The current/HEAD state - the one and only full bitmap.
    pub(super) current: Prunable<N>,

    /// Historical commits: commit_number -> reverse diff from that commit.
    pub(super) commits: BTreeMap<u64, CommitDiff<N>>,

    /// Active batch (if any).
    pub(super) active_batch: Option<Batch<N>>,
}

impl<const N: usize> BitMap<N> {
    /// Create a new empty historical bitmap.
    pub fn new() -> Self {
        Self {
            current: Prunable::new(),
            commits: BTreeMap::new(),
            active_batch: None,
        }
    }

    /// Create a new historical bitmap with the given number of pruned chunks.
    pub fn new_with_pruned_chunks(pruned_chunks: usize) -> Result<Self, Error> {
        Ok(Self {
            current: Prunable::new_with_pruned_chunks(pruned_chunks)?,
            commits: BTreeMap::new(),
            active_batch: None,
        })
    }

    /// Start a new batch for making mutations.
    ///
    /// The returned [BatchGuard] must be either committed or dropped. All mutations
    /// are applied to the guard's diff layer and do not affect the current bitmap
    /// until commit.
    ///
    /// # Examples
    ///
    /// ```
    /// # use commonware_utils::bitmap::historical::BitMap;
    /// let mut bitmap: BitMap<4> = BitMap::new();
    ///
    /// let mut batch = bitmap.start_batch();
    /// batch.push(true);
    /// batch.push(false);
    /// batch.commit(1).unwrap();
    ///
    /// assert_eq!(bitmap.len(), 2);
    /// ```
    ///
    /// # Panics
    ///
    /// Panics if a batch is already active.
    pub fn start_batch(&mut self) -> BatchGuard<'_, N> {
        assert!(
            self.active_batch.is_none(),
            "cannot start batch: batch already active"
        );

        let batch = Batch {
            base_len: self.current.len(),
            base_pruned_chunks: self.current.pruned_chunks(),
            projected_len: self.current.len(),
            projected_pruned_chunks: self.current.pruned_chunks(),
            modified_bits: BTreeMap::new(),
            appended_bits: Vec::new(),
            chunks_to_prune: BTreeMap::new(),
        };

        self.active_batch = Some(batch);

        BatchGuard {
            bitmap: self,
            committed: false,
        }
    }

    /// Execute a closure with a batch and commit it at the given commit number.
    ///
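    /// # Examples
    ///
    /// Equivalent to [Self::start_batch] followed by committing the returned
    /// guard:
    ///
    /// ```
    /// # use commonware_utils::bitmap::historical::BitMap;
    /// let mut bitmap: BitMap<4> = BitMap::new();
    ///
    /// bitmap.with_batch(1, |batch| {
    ///     batch.push(true);
    ///     batch.push(true);
    /// }).unwrap();
    ///
    /// assert_eq!(bitmap.len(), 2);
    /// assert!(bitmap.get_bit(0));
    /// ```
    ///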
    /// # Errors
    ///
    /// Returns [Error::NonMonotonicCommit] if the commit number is not
    /// greater than the previous commit.
    ///
    /// Returns [Error::ReservedCommitNumber] if the commit number is `u64::MAX`.
    ///
    /// # Panics
    ///
    /// Panics if a batch is already active.
    pub fn with_batch<F>(&mut self, commit_number: u64, f: F) -> Result<(), Error>
    where
        F: FnOnce(&mut BatchGuard<'_, N>),
    {
        let mut guard = self.start_batch();
        f(&mut guard);
        guard.commit(commit_number)
    }

    /// Get the bitmap state as it existed at a specific commit.
    ///
    /// Returns `None` if the commit does not exist or if `commit_number` is `u64::MAX`
    /// (which is reserved and cannot be used as a commit number).
    ///
    /// This reconstructs the historical state by applying reverse diffs backward from
    /// the current state. Each commit's reverse diff describes the state before that
    /// commit, so we "undo" commits one by one until we reach the target.
    ///
    /// # Examples
    ///
    /// ```
    /// # use commonware_utils::bitmap::historical::BitMap;
    /// let mut bitmap: BitMap<4> = BitMap::new();
    ///
    /// bitmap.with_batch(1, |batch| {
    ///     batch.push(true);
    ///     batch.push(false);
    /// }).unwrap();
    ///
    /// bitmap.with_batch(2, |batch| {
    ///     batch.set_bit(0, false);
    ///     batch.push(true);
    /// }).unwrap();
    ///
    /// // Get state as it was at commit 1
    /// let state_at_1 = bitmap.get_at_commit(1).unwrap();
    /// assert_eq!(state_at_1.len(), 2);
    /// assert!(state_at_1.get_bit(0));
    /// assert!(!state_at_1.get_bit(1));
    ///
    /// // Current state is different
    /// assert_eq!(bitmap.len(), 3);
    /// assert!(!bitmap.get_bit(0));
    /// ```
    pub fn get_at_commit(&self, commit_number: u64) -> Option<Prunable<N>> {
        // Check if the commit exists and is valid
        if commit_number == u64::MAX || !self.commits.contains_key(&commit_number) {
            return None;
        }

        // Start with current state
        let mut state = self.current.clone();

        // Apply reverse diffs from newest down to target (exclusive). Each
        // reverse diff at commit N describes the state before commit N. The
        // addition below can't overflow because commit_number < u64::MAX.
        for (_commit, diff) in self.commits.range(commit_number + 1..).rev() {
            self.apply_reverse_diff(&mut state, diff);
        }

        Some(state)
    }

    /// Push bits to extend the bitmap to target length.
    /// Optimized to push entire chunks when possible.
    fn push_to_length(&self, state: &mut Prunable<N>, target_len: u64) {
        while state.len() < target_len {
            let remaining = target_len - state.len();
            let next_bit = state.len() % Prunable::<N>::CHUNK_SIZE_BITS;

            // If we're at a chunk boundary and need at least a full chunk, push an entire chunk
            if next_bit == 0 && remaining >= Prunable::<N>::CHUNK_SIZE_BITS {
                state.push_chunk(&[0u8; N]);
            } else {
                // Otherwise push individual bits
                state.push(false);
            }
        }
    }

    /// Pop bits to shrink the bitmap to target length.
    /// Optimized to pop entire chunks when possible.
    fn pop_to_length(&self, state: &mut Prunable<N>, target_len: u64) {
        while state.len() > target_len {
            let excess = state.len() - target_len;
            let next_bit = state.len() % Prunable::<N>::CHUNK_SIZE_BITS;

            // If at chunk boundary and we need to remove at least a full chunk, pop entire chunk
            if next_bit == 0 && excess >= Prunable::<N>::CHUNK_SIZE_BITS {
                state.pop_chunk();
            } else {
                // Otherwise pop individual bits
                state.pop();
            }
        }
    }

    /// Apply a reverse diff to transform newer_state into the previous state (in-place).
    ///
    /// Algorithm:
    /// 1. Restore pruned chunks by prepending them back (unprune)
    /// 2. Adjust bitmap structure to target length (extend/shrink as needed)
    /// 3. Update chunk data for Modified and Removed chunks
    /// 4. Verify the pruned count and length match the target exactly
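    ///
    /// For example, if a commit only appended bits, its reverse diff records the
    /// shorter prior `len`: phase 2 pops the state back down to that length, and
    /// phase 3 restores the original bytes of any partially-filled chunk from
    /// its `ChunkDiff::Modified` entry.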
    fn apply_reverse_diff(&self, newer_state: &mut Prunable<N>, diff: &CommitDiff<N>) {
        let target_len = diff.len;
        let target_pruned = diff.pruned_chunks;
        let newer_pruned = newer_state.pruned_chunks();

        // Phase 1: Restore pruned chunks
        assert!(
            target_pruned <= newer_pruned,
            "invariant violation: target_pruned ({target_pruned}) > newer_pruned ({newer_pruned})"
        );
        let mut chunks_to_unprune = Vec::with_capacity(newer_pruned - target_pruned);
        for chunk_index in (target_pruned..newer_pruned).rev() {
            let Some(ChunkDiff::Pruned(chunk)) = diff.chunk_diffs.get(&chunk_index) else {
                panic!("chunk {chunk_index} should be Pruned in diff");
            };
            chunks_to_unprune.push(*chunk);
        }
        newer_state.unprune_chunks(&chunks_to_unprune);

        // Phase 2: Adjust bitmap structure to target length
        if newer_state.len() < target_len {
            self.push_to_length(newer_state, target_len);
        } else if newer_state.len() > target_len {
            self.pop_to_length(newer_state, target_len);
        }

        // Phase 3: Update chunk data. Entries below newer_pruned are the Pruned
        // chunks already restored in phase 1, so skip them.
        for (&chunk_index, change) in diff
            .chunk_diffs
            .iter()
            .filter(|(chunk_index, _)| **chunk_index >= newer_pruned)
        {
            match change {
                ChunkDiff::Modified(old_data) | ChunkDiff::Removed(old_data) => {
                    // Both cases: chunk exists in target, just update its data
                    newer_state.set_chunk_by_index(chunk_index, old_data);
                }
                ChunkDiff::Added => {
                    // Chunk didn't exist in target; its removal was already handled by
                    // pop_to_length. We can break because BTreeMap iteration is ordered:
                    // Added entries only occur after all Modified entries, and a diff
                    // can't contain both Added and Removed chunks.
                    break;
                }
                ChunkDiff::Pruned(_) => {
                    panic!("pruned chunk found at unexpected index {chunk_index}")
                }
            }
        }

        assert_eq!(newer_state.pruned_chunks(), target_pruned);
        assert_eq!(newer_state.len(), target_len);
    }

    /// Check if a commit exists.
    pub fn commit_exists(&self, commit_number: u64) -> bool {
        self.commits.contains_key(&commit_number)
    }

    /// Get an iterator over all commit numbers in ascending order.
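    ///
    /// # Examples
    ///
    /// Commit numbers need only increase; they need not be contiguous:
    ///
    /// ```
    /// # use commonware_utils::bitmap::historical::BitMap;
    /// let mut bitmap: BitMap<4> = BitMap::new();
    /// bitmap.with_batch(1, |batch| { batch.push(true); }).unwrap();
    /// bitmap.with_batch(5, |batch| { batch.push(false); }).unwrap();
    ///
    /// assert_eq!(bitmap.commits().collect::<Vec<_>>(), vec![1, 5]);
    /// assert_eq!(bitmap.earliest_commit(), Some(1));
    /// assert_eq!(bitmap.latest_commit(), Some(5));
    /// ```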
    pub fn commits(&self) -> impl Iterator<Item = u64> + '_ {
        self.commits.keys().copied()
    }

    /// Get the latest commit number, if any commits exist.
    pub fn latest_commit(&self) -> Option<u64> {
        self.commits.keys().next_back().copied()
    }

    /// Get the earliest commit number, if any commits exist.
    pub fn earliest_commit(&self) -> Option<u64> {
        self.commits.keys().next().copied()
    }

    /// Get a reference to the current bitmap state.
    pub fn current(&self) -> &Prunable<N> {
        &self.current
    }

    /// Number of bits in the current bitmap.
    #[inline]
    pub fn len(&self) -> u64 {
        self.current.len()
    }

    /// Returns true if the current bitmap is empty.
    #[inline]
    pub fn is_empty(&self) -> bool {
        self.current.is_empty()
    }

    /// Get the value of a bit in the current bitmap.
    #[inline]
    pub fn get_bit(&self, bit: u64) -> bool {
        self.current.get_bit(bit)
    }

    /// Get the chunk containing a bit in the current bitmap.
    #[inline]
    pub fn get_chunk_containing(&self, bit: u64) -> &[u8; N] {
        self.current.get_chunk_containing(bit)
    }

    /// Number of pruned chunks in the current bitmap.
    #[inline]
    pub fn pruned_chunks(&self) -> usize {
        self.current.pruned_chunks()
    }

    /// Remove all commits with numbers strictly less than `commit_number`.
    ///
    /// Returns the number of commits removed.
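    ///
    /// # Examples
    ///
    /// Pruning drops history strictly before `commit_number`; the commit itself
    /// is retained:
    ///
    /// ```
    /// # use commonware_utils::bitmap::historical::BitMap;
    /// let mut bitmap: BitMap<4> = BitMap::new();
    /// bitmap.with_batch(1, |batch| { batch.push(true); }).unwrap();
    /// bitmap.with_batch(2, |batch| { batch.push(false); }).unwrap();
    ///
    /// assert_eq!(bitmap.prune_commits_before(2), 1);
    /// assert!(!bitmap.commit_exists(1));
    /// assert!(bitmap.commit_exists(2));
    /// ```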
    pub fn prune_commits_before(&mut self, commit_number: u64) -> usize {
        let count = self.commits.len();
        self.commits = self.commits.split_off(&commit_number);
        count - self.commits.len()
    }

    /// Clear all historical commits.
    pub fn clear_history(&mut self) {
        self.commits.clear();
    }

    /// Apply a batch's changes to the current bitmap.
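    ///
    /// For example, a batch that popped one bit and then pushed two onto a
    /// bitmap of length 10 ends with `projected_len == 11` and two entries in
    /// `appended_bits`: step 1 below pops `current` down to 9, and step 2
    /// pushes the two new bits back up to 11.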
    pub(super) fn apply_batch_to_current(&mut self, batch: &Batch<N>) {
        // Step 1: Shrink to length before appends (handles net pops)
        let target_len_before_appends = batch.projected_len - batch.appended_bits.len() as u64;

        while self.current.len() > target_len_before_appends {
            self.current.pop();
        }

        // Step 2: Grow by appending new bits
        for &bit in &batch.appended_bits {
            self.current.push(bit);
        }
        assert_eq!(self.current.len(), batch.projected_len);

        // Step 3: Modify existing base bits (not appended bits). This runs after
        // the length adjustments so every targeted bit is guaranteed to exist.
        for (&bit, &value) in &batch.modified_bits {
            self.current.set_bit(bit, value);
        }

        // Step 4: Prune chunks from the beginning
        if batch.projected_pruned_chunks > batch.base_pruned_chunks {
            let prune_to_bit =
                batch.projected_pruned_chunks as u64 * Prunable::<N>::CHUNK_SIZE_BITS;
            self.current.prune_to_bit(prune_to_bit);
        }
    }

    /// Build a reverse diff from a batch.
    ///
    /// Capture order matters: bit modifications are recorded first, appends and
    /// pops use `or_insert_with` so they never overwrite an earlier capture, and
    /// pruned chunks are recorded last with `insert` so they always win.
    pub(super) fn build_reverse_diff(&self, batch: &Batch<N>) -> CommitDiff<N> {
        let mut changes = BTreeMap::new();
        self.capture_modified_chunks(batch, &mut changes);
        self.capture_appended_chunks(batch, &mut changes);
        self.capture_popped_chunks(batch, &mut changes);
        self.capture_pruned_chunks(batch, &mut changes);
        CommitDiff {
            len: batch.base_len,
            pruned_chunks: batch.base_pruned_chunks,
            chunk_diffs: changes,
        }
    }

    /// Capture chunks affected by bit modifications.
    ///
    /// For each chunk containing modified bits, we store its original value so we can
    /// restore it when reconstructing historical states.
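    ///
    /// For example, with `N = 4` (32-bit chunks), a batch that sets bits 3 and
    /// 70 captures the original bytes of chunks 0 and 2, once each.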
    fn capture_modified_chunks(
        &self,
        batch: &Batch<N>,
        changes: &mut BTreeMap<usize, ChunkDiff<N>>,
    ) {
        for &bit in batch.modified_bits.keys() {
            let chunk_idx = Prunable::<N>::unpruned_chunk(bit);
            changes.entry(chunk_idx).or_insert_with(|| {
                // `modified_bits` only contains bits from the base region that existed
                // at batch creation. Since current hasn't changed yet (we're still
                // building the diff), the chunk MUST exist.
                let old_chunk = self
                    .get_chunk(chunk_idx)
                    .expect("chunk must exist for modified bit");
                ChunkDiff::Modified(old_chunk)
            });
        }
    }

    /// Capture chunks affected by appended bits.
    ///
    /// When bits are appended, they may:
    /// - Extend an existing partial chunk (mark as Modified with old data)
    /// - Create entirely new chunks (mark as Added)
    fn capture_appended_chunks(
        &self,
        batch: &Batch<N>,
        changes: &mut BTreeMap<usize, ChunkDiff<N>>,
    ) {
        if batch.appended_bits.is_empty() {
            return;
        }

        // Calculate which chunks will be affected by appends.
        // Note: append_start_bit accounts for any net pops before the pushes.
        let append_start_bit = batch.projected_len - batch.appended_bits.len() as u64;
        let start_chunk = Prunable::<N>::unpruned_chunk(append_start_bit);
        let end_chunk = Prunable::<N>::unpruned_chunk(batch.projected_len.saturating_sub(1));

        for chunk_idx in start_chunk..=end_chunk {
            // Use or_insert_with so we don't overwrite chunks already captured
            // by capture_modified_chunks (which runs first and takes precedence).
            changes.entry(chunk_idx).or_insert_with(|| {
                if let Some(old_chunk) = self.get_chunk(chunk_idx) {
                    // Chunk existed before: store its old data
                    ChunkDiff::Modified(old_chunk)
                } else {
                    // Chunk is brand new: mark as Added
                    ChunkDiff::Added
                }
            });
        }
    }

    /// Capture chunks affected by pop operations.
    ///
    /// When bits are popped (projected_len < base_len), we need to capture the original
    /// data of chunks that will be truncated or fully removed. This allows reconstruction
    /// to restore the bits that were popped.
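    ///
    /// For example, with `N = 4` (32-bit chunks), shrinking from length 40 to
    /// 16 records chunk 0 as Modified (its popped upper bits are restored on
    /// undo) and chunk 1 as Removed (it disappears entirely).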
    fn capture_popped_chunks(&self, batch: &Batch<N>, changes: &mut BTreeMap<usize, ChunkDiff<N>>) {
        if batch.projected_len >= batch.base_len || batch.base_len == 0 {
            return; // No net pops
        }

        // Identify the range of chunks affected by length reduction.
        let old_last_chunk = Prunable::<N>::unpruned_chunk(batch.base_len - 1);
        let new_last_chunk = if batch.projected_len > 0 {
            Prunable::<N>::unpruned_chunk(batch.projected_len - 1)
        } else {
            0
        };

        // Capture all chunks between the new and old endpoints.
        // Skip chunks that were already pruned before this batch started.
        for chunk_idx in new_last_chunk..=old_last_chunk {
            if chunk_idx < batch.base_pruned_chunks {
                // This chunk was already pruned before the batch, skip it
                continue;
            }

            changes.entry(chunk_idx).or_insert_with(|| {
                let old_chunk = self
                    .get_chunk(chunk_idx)
                    .expect("chunk must exist in base bitmap for popped bits");

                // Determine if this chunk is partially kept or completely removed
                let chunk_start_bit = chunk_idx as u64 * Prunable::<N>::CHUNK_SIZE_BITS;

                if batch.projected_len > chunk_start_bit {
                    // Chunk spans the new length boundary → partially kept (Modified)
                    ChunkDiff::Modified(old_chunk)
                } else {
                    // Chunk is completely beyond the new length → fully removed (Removed)
                    ChunkDiff::Removed(old_chunk)
                }
            });
        }
    }

    /// Capture chunks that will be pruned.
    ///
    /// The batch's `prune_to_bit` method already captured the old chunk data,
    /// so we simply copy it into the reverse diff.
    fn capture_pruned_chunks(&self, batch: &Batch<N>, changes: &mut BTreeMap<usize, ChunkDiff<N>>) {
        for (&chunk_idx, &chunk_data) in &batch.chunks_to_prune {
            changes.insert(chunk_idx, ChunkDiff::Pruned(chunk_data));
        }
    }

    /// Get chunk data from current state if it exists.
    ///
    /// Returns `Some(chunk_data)` if the chunk exists in the current bitmap,
    /// or `None` if it's out of bounds or pruned.
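    ///
    /// For example, with 3 pruned chunks, logical chunk index 5 maps to storage
    /// index 2 within `current`.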
    fn get_chunk(&self, chunk_idx: usize) -> Option<[u8; N]> {
        let current_pruned = self.current.pruned_chunks();
        if chunk_idx >= current_pruned {
            let bitmap_idx = chunk_idx - current_pruned;
            if bitmap_idx < self.current.chunks_len() {
                return Some(*self.current.get_chunk(bitmap_idx));
            }
        }
        None
    }
}

impl<const N: usize> Default for BitMap<N> {
    fn default() -> Self {
        Self::new()
    }
}