use super::Error;
use crate::bitmap::Prunable;
#[cfg(not(feature = "std"))]
use alloc::{collections::BTreeMap, vec::Vec};
#[cfg(feature = "std")]
use std::collections::BTreeMap;
/// Private module implementing the sealed-trait pattern: downstream code can
/// name [`State`] but cannot implement it, because `Sealed` is unreachable
/// outside this module tree.
mod private {
    pub trait Sealed {}
}
/// Marker trait for the typestate parameter of [`BitMap`]. Sealed, so the
/// only states are the ones defined here ([`Clean`] and [`Dirty`]).
pub trait State: private::Sealed + Sized + Send + Sync {}

/// Typestate marker: no batch in progress; the bitmap carries no staged
/// changes.
#[derive(Clone, Debug)]
pub struct Clean;

impl private::Sealed for Clean {}
impl State for Clean {}
/// Typestate holding all changes staged during a batch; nothing touches the
/// committed `Prunable` until the batch is committed.
#[derive(Clone, Debug)]
pub struct Dirty<const N: usize> {
    // Length in bits of the bitmap when the batch started.
    base_len: u64,
    // Pruned-chunk count when the batch started.
    base_pruned_chunks: usize,
    // Length in bits the bitmap will have once staged changes apply.
    projected_len: u64,
    // Pruned-chunk count once staged changes apply.
    projected_pruned_chunks: usize,
    // Staged overwrites of pre-existing bits, keyed by bit offset.
    modified_bits: BTreeMap<u64, bool>,
    // Bits appended past the (possibly popped) end, in push order.
    appended_bits: Vec<bool>,
    // Chunks staged for pruning, keyed by chunk index, with each chunk's
    // contents captured at staging time for use in the reverse diff.
    chunks_to_prune: BTreeMap<usize, [u8; N]>,
}

impl<const N: usize> private::Sealed for Dirty<N> {}
impl<const N: usize> State for Dirty<N> {}
/// Per-chunk entry in a reverse diff: what is needed to restore one chunk to
/// its pre-commit contents.
#[derive(Clone, Debug)]
pub(super) enum ChunkDiff<const N: usize> {
    /// Chunk existed before and after the commit; holds the old contents.
    Modified([u8; N]),
    /// Chunk was removed entirely by the commit; holds the old contents.
    Removed([u8; N]),
    /// Chunk did not exist before the commit (created by appends).
    Added,
    /// Chunk was pruned by the commit; holds the contents captured when the
    /// prune was staged, so it can be unpruned on reversal.
    Pruned([u8; N]),
}
/// Reverse diff recorded at each commit: applying it to the state *after*
/// the commit reconstructs the state *before* it.
#[derive(Clone, Debug)]
pub(super) struct CommitDiff<const N: usize> {
    /// Bitmap length in bits before the commit.
    pub(super) len: u64,
    /// Pruned-chunk count before the commit.
    pub(super) pruned_chunks: usize,
    /// Per-chunk restoration data, keyed by chunk index.
    pub(super) chunk_diffs: BTreeMap<usize, ChunkDiff<N>>,
}
/// A prunable bitmap with commit history, parameterized by a typestate:
/// a `Clean` bitmap is read-only until converted into a `Dirty` one, which
/// stages changes until they are committed or aborted.
#[derive(Clone, Debug)]
pub struct BitMap<const N: usize, S: State = Clean> {
    // Latest committed contents.
    current: Prunable<N>,
    // Reverse diffs keyed by commit number; replaying them newest-first
    // reconstructs older states.
    commits: BTreeMap<u64, CommitDiff<N>>,
    // Typestate payload: `Clean` (unit) or `Dirty` (staged changes).
    state: S,
}

/// A bitmap with no batch in progress.
pub type CleanBitMap<const N: usize> = BitMap<N, Clean>;

/// A bitmap with a batch of staged, uncommitted changes.
pub type DirtyBitMap<const N: usize> = BitMap<N, Dirty<N>>;
impl<const N: usize> CleanBitMap<N> {
    /// Creates an empty bitmap with no commit history.
    pub const fn new() -> Self {
        Self {
            current: Prunable::new(),
            commits: BTreeMap::new(),
            state: Clean,
        }
    }

    /// Creates an empty bitmap whose first `pruned_chunks` chunks are already
    /// pruned.
    ///
    /// # Errors
    ///
    /// Propagates any error from `Prunable::new_with_pruned_chunks`.
    pub fn new_with_pruned_chunks(pruned_chunks: usize) -> Result<Self, Error> {
        Ok(Self {
            current: Prunable::new_with_pruned_chunks(pruned_chunks)?,
            commits: BTreeMap::new(),
            state: Clean,
        })
    }

    /// Starts a batch: converts into a [`DirtyBitMap`] whose base and
    /// projected length/pruned-chunk counts both start at the current values
    /// and whose staging collections are empty.
    pub fn into_dirty(self) -> DirtyBitMap<N> {
        DirtyBitMap {
            state: Dirty {
                base_len: self.current.len(),
                base_pruned_chunks: self.current.pruned_chunks(),
                projected_len: self.current.len(),
                projected_pruned_chunks: self.current.pruned_chunks(),
                modified_bits: BTreeMap::new(),
                appended_bits: Vec::new(),
                chunks_to_prune: BTreeMap::new(),
            },
            current: self.current,
            commits: self.commits,
        }
    }

    /// Convenience wrapper: stages the changes made by `f` and commits them
    /// under `commit_number`.
    ///
    /// # Errors
    ///
    /// Returns whatever [`DirtyBitMap::commit`] returns (reserved or
    /// non-monotonic commit numbers).
    pub fn apply_batch<F>(self, commit_number: u64, f: F) -> Result<Self, Error>
    where
        F: FnOnce(&mut DirtyBitMap<N>),
    {
        let mut dirty = self.into_dirty();
        f(&mut dirty);
        dirty.commit(commit_number)
    }

    /// Reconstructs the bitmap contents as of `commit_number`.
    ///
    /// Clones the current state and applies the reverse diff of every commit
    /// newer than `commit_number`, newest first. Returns `None` when the
    /// commit number is unknown or is the reserved value `u64::MAX`.
    pub fn get_at_commit(&self, commit_number: u64) -> Option<Prunable<N>> {
        if commit_number == u64::MAX || !self.commits.contains_key(&commit_number) {
            return None;
        }
        let mut state = self.current.clone();
        for (_commit, diff) in self.commits.range(commit_number + 1..).rev() {
            self.apply_reverse_diff(&mut state, diff);
        }
        Some(state)
    }

    /// Whether a commit with this number exists in the retained history.
    pub fn commit_exists(&self, commit_number: u64) -> bool {
        self.commits.contains_key(&commit_number)
    }

    /// All retained commit numbers, in ascending order.
    pub fn commits(&self) -> impl Iterator<Item = u64> + '_ {
        self.commits.keys().copied()
    }

    /// The newest retained commit number, if any.
    pub fn latest_commit(&self) -> Option<u64> {
        self.commits.keys().next_back().copied()
    }

    /// The oldest retained commit number, if any.
    pub fn earliest_commit(&self) -> Option<u64> {
        self.commits.keys().next().copied()
    }

    /// The current (latest committed) bitmap contents.
    pub const fn current(&self) -> &Prunable<N> {
        &self.current
    }

    /// Current length in bits.
    #[inline]
    pub const fn len(&self) -> u64 {
        self.current.len()
    }

    /// Whether the current bitmap contains no bits.
    #[inline]
    pub const fn is_empty(&self) -> bool {
        self.current.is_empty()
    }

    /// Current value of the bit at offset `bit` (delegates to the underlying
    /// `Prunable`).
    #[inline]
    pub fn get_bit(&self, bit: u64) -> bool {
        self.current.get_bit(bit)
    }

    /// The chunk containing bit offset `bit` (delegates to the underlying
    /// `Prunable`).
    #[inline]
    pub fn get_chunk_containing(&self, bit: u64) -> &[u8; N] {
        self.current.get_chunk_containing(bit)
    }

    /// Number of pruned chunks in the current bitmap.
    #[inline]
    pub const fn pruned_chunks(&self) -> usize {
        self.current.pruned_chunks()
    }

    /// Drops every commit older than `commit_number` from the history and
    /// returns how many were removed. Commits at or after `commit_number`
    /// remain reconstructible.
    pub fn prune_commits_before(&mut self, commit_number: u64) -> usize {
        let count = self.commits.len();
        // split_off keeps the entries >= commit_number in self.commits.
        self.commits = self.commits.split_off(&commit_number);
        count - self.commits.len()
    }

    /// Discards the entire commit history, keeping only the current contents.
    pub fn clear_history(&mut self) {
        self.commits.clear();
    }

    /// Extends `state` with zero bits until it reaches `target_len`, pushing
    /// whole zero chunks when aligned so growth happens in chunk-sized steps.
    fn push_to_length(&self, state: &mut Prunable<N>, target_len: u64) {
        while state.len() < target_len {
            let remaining = target_len - state.len();
            let next_bit = state.len() % Prunable::<N>::CHUNK_SIZE_BITS;
            if next_bit == 0 && remaining >= Prunable::<N>::CHUNK_SIZE_BITS {
                state.push_chunk(&[0u8; N]);
            } else {
                state.push(false);
            }
        }
    }

    /// Shrinks `state` until it reaches `target_len`, popping whole chunks
    /// when aligned so shrinkage happens in chunk-sized steps.
    fn pop_to_length(&self, state: &mut Prunable<N>, target_len: u64) {
        while state.len() > target_len {
            let excess = state.len() - target_len;
            let next_bit = state.len() % Prunable::<N>::CHUNK_SIZE_BITS;
            if next_bit == 0 && excess >= Prunable::<N>::CHUNK_SIZE_BITS {
                state.pop_chunk();
            } else {
                state.pop();
            }
        }
    }

    /// Applies one reverse diff in place, transforming `newer_state` (the
    /// state *after* the diff's commit) into the state *before* it.
    ///
    /// Order matters: first unprune the chunks this commit pruned (restoring
    /// their captured contents), then adjust the length, then restore the
    /// contents of modified/removed chunks.
    ///
    /// # Panics
    ///
    /// Panics when the diff is inconsistent with `newer_state` (missing
    /// `Pruned` entries, `Pruned` entries at unexpected indices, or a final
    /// length/pruned-count mismatch) — all broken-invariant conditions.
    fn apply_reverse_diff(&self, newer_state: &mut Prunable<N>, diff: &CommitDiff<N>) {
        let target_len = diff.len;
        let target_pruned = diff.pruned_chunks;
        let newer_pruned = newer_state.pruned_chunks();
        assert!(
            target_pruned <= newer_pruned,
            "invariant violation: target_pruned ({target_pruned}) > newer_pruned ({newer_pruned})"
        );
        // Chunks in [target_pruned, newer_pruned) were pruned by this commit.
        // Collect their captured contents highest-index first — presumably
        // the order `unprune_chunks` expects; TODO confirm against Prunable.
        let mut chunks_to_unprune = Vec::with_capacity(newer_pruned - target_pruned);
        for chunk_index in (target_pruned..newer_pruned).rev() {
            let Some(ChunkDiff::Pruned(chunk)) = diff.chunk_diffs.get(&chunk_index) else {
                panic!("chunk {chunk_index} should be Pruned in diff");
            };
            chunks_to_unprune.push(*chunk);
        }
        newer_state.unprune_chunks(&chunks_to_unprune);
        if newer_state.len() < target_len {
            self.push_to_length(newer_state, target_len);
        } else if newer_state.len() > target_len {
            self.pop_to_length(newer_state, target_len);
        }
        // Entries below newer_pruned are exactly the Pruned chunks handled
        // above; the remaining entries restore pre-commit chunk contents.
        for (&chunk_index, change) in diff
            .chunk_diffs
            .iter()
            .filter(|(chunk_index, _)| **chunk_index >= newer_pruned)
        {
            match change {
                ChunkDiff::Modified(old_data) | ChunkDiff::Removed(old_data) => {
                    newer_state.set_chunk_by_index(chunk_index, old_data);
                }
                ChunkDiff::Added => {
                    // Added chunks did not exist pre-commit; their bits were
                    // already trimmed by pop_to_length. BTreeMap iterates in
                    // ascending index order, so nothing after the first Added
                    // entry is restored — NOTE(review): assumes no
                    // Modified/Removed entry can follow an Added one; confirm.
                    break;
                }
                ChunkDiff::Pruned(_) => {
                    panic!("pruned chunk found at unexpected index {chunk_index}")
                }
            }
        }
        assert_eq!(newer_state.pruned_chunks(), target_pruned);
        assert_eq!(newer_state.len(), target_len);
    }
}
impl<const N: usize> Default for CleanBitMap<N> {
    /// Equivalent to [`CleanBitMap::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<const N: usize> DirtyBitMap<N> {
#[inline]
pub const fn len(&self) -> u64 {
self.state.projected_len
}
#[inline]
pub const fn is_empty(&self) -> bool {
self.len() == 0
}
#[inline]
pub const fn pruned_chunks(&self) -> usize {
self.state.projected_pruned_chunks
}
pub fn get_bit(&self, bit: u64) -> bool {
assert!(
bit < self.state.projected_len,
"bit offset {bit} out of bounds (len: {})",
self.state.projected_len
);
let chunk_idx = Prunable::<N>::to_chunk_index(bit);
assert!(
chunk_idx >= self.state.projected_pruned_chunks,
"cannot get bit {bit}: chunk {chunk_idx} is pruned (pruned up to chunk {})",
self.state.projected_pruned_chunks
);
let appended_start = self.state.projected_len - self.state.appended_bits.len() as u64;
if bit >= appended_start {
let append_offset = (bit - appended_start) as usize;
return self.state.appended_bits[append_offset];
}
if let Some(&value) = self.state.modified_bits.get(&bit) {
return value;
}
self.current.get_bit(bit)
}
pub fn get_chunk(&self, bit: u64) -> [u8; N] {
assert!(
bit < self.state.projected_len,
"bit offset {bit} out of bounds (len: {})",
self.state.projected_len
);
let chunk_idx = Prunable::<N>::to_chunk_index(bit);
assert!(
chunk_idx >= self.state.projected_pruned_chunks,
"cannot get chunk at bit offset {bit}: chunk {chunk_idx} is pruned (pruned up to chunk {})",
self.state.projected_pruned_chunks
);
let chunk_start_bit = chunk_idx as u64 * Prunable::<N>::CHUNK_SIZE_BITS;
let chunk_end_bit = chunk_start_bit + Prunable::<N>::CHUNK_SIZE_BITS;
let appended_start = self.state.projected_len - self.state.appended_bits.len() as u64;
let chunk_entirely_past_end = chunk_start_bit >= self.state.projected_len;
let chunk_entirely_before_changes =
chunk_end_bit <= appended_start && chunk_end_bit <= self.state.projected_len;
let chunk_needs_reconstruction =
!(chunk_entirely_past_end || chunk_entirely_before_changes)
|| (chunk_start_bit..chunk_end_bit.min(self.state.base_len))
.any(|bit| self.state.modified_bits.contains_key(&bit));
if chunk_needs_reconstruction {
self.reconstruct_modified_chunk(chunk_start_bit)
} else {
*self.current.get_chunk_containing(bit)
}
}
fn reconstruct_modified_chunk(&self, chunk_start: u64) -> [u8; N] {
let mut chunk = if chunk_start < self.current.len() {
*self.current.get_chunk_containing(chunk_start)
} else {
[0u8; N]
};
let appended_start = self.state.projected_len - self.state.appended_bits.len() as u64;
for bit_in_chunk in 0..Prunable::<N>::CHUNK_SIZE_BITS {
let bit = chunk_start + bit_in_chunk;
let byte_idx = (bit_in_chunk / 8) as usize;
let bit_idx = bit_in_chunk % 8;
let mask = 1u8 << bit_idx;
if bit >= self.state.projected_len {
chunk[byte_idx] &= !mask;
} else if let Some(&value) = self.state.modified_bits.get(&bit) {
if value {
chunk[byte_idx] |= mask;
} else {
chunk[byte_idx] &= !mask;
}
} else if bit >= appended_start {
let append_offset = (bit - appended_start) as usize;
if append_offset < self.state.appended_bits.len() {
let value = self.state.appended_bits[append_offset];
if value {
chunk[byte_idx] |= mask;
} else {
chunk[byte_idx] &= !mask;
}
}
}
}
chunk
}
pub fn set_bit(&mut self, bit: u64, value: bool) -> &mut Self {
assert!(
bit < self.state.projected_len,
"cannot set bit {bit}: out of bounds (len: {})",
self.state.projected_len
);
let chunk_idx = Prunable::<N>::to_chunk_index(bit);
assert!(
chunk_idx >= self.state.projected_pruned_chunks,
"cannot set bit {bit}: chunk {chunk_idx} is pruned (pruned up to chunk {})",
self.state.projected_pruned_chunks
);
let appended_start = self.state.projected_len - self.state.appended_bits.len() as u64;
if bit >= appended_start {
let append_offset = (bit - appended_start) as usize;
self.state.appended_bits[append_offset] = value;
} else {
self.state.modified_bits.insert(bit, value);
}
self
}
pub fn push(&mut self, bit: bool) -> &mut Self {
self.state.appended_bits.push(bit);
self.state.projected_len += 1;
self
}
pub fn push_byte(&mut self, byte: u8) -> &mut Self {
for i in 0..8 {
let bit = (byte >> i) & 1 == 1;
self.push(bit);
}
self
}
pub fn push_chunk(&mut self, chunk: &[u8; N]) -> &mut Self {
for byte in chunk {
self.push_byte(*byte);
}
self
}
pub fn pop(&mut self) -> bool {
assert!(self.state.projected_len > 0, "cannot pop from empty bitmap");
let old_projected_len = self.state.projected_len;
self.state.projected_len -= 1;
let bit = self.state.projected_len;
let appended_start = old_projected_len - self.state.appended_bits.len() as u64;
if bit >= appended_start {
self.state.appended_bits.pop().unwrap()
} else {
if let Some(&modified_value) = self.state.modified_bits.get(&bit) {
self.state.modified_bits.remove(&bit);
modified_value
} else {
self.current.get_bit(bit)
}
}
}
pub fn prune_to_bit(&mut self, bit: u64) -> &mut Self {
assert!(
bit <= self.state.projected_len,
"cannot prune to bit {bit}: beyond projected length ({})",
self.state.projected_len
);
let chunk_num = Prunable::<N>::to_chunk_index(bit);
if chunk_num <= self.state.projected_pruned_chunks {
return self; }
let current_pruned = self.current.pruned_chunks();
for chunk_idx in self.state.projected_pruned_chunks..chunk_num {
if self.state.chunks_to_prune.contains_key(&chunk_idx) {
continue; }
assert!(
chunk_idx >= current_pruned,
"attempting to prune chunk {chunk_idx} which is already pruned (current pruned_chunks={current_pruned})",
);
let chunk_data = if chunk_idx < self.current.chunks_len() {
*self.current.get_chunk(chunk_idx)
} else {
let chunk_start_bit = chunk_idx as u64 * Prunable::<N>::CHUNK_SIZE_BITS;
let appended_start =
self.state.projected_len - self.state.appended_bits.len() as u64;
let mut chunk = [0u8; N];
for bit_in_chunk in 0..Prunable::<N>::CHUNK_SIZE_BITS {
let bit = chunk_start_bit + bit_in_chunk;
if bit >= self.state.projected_len {
break;
}
if bit >= appended_start {
let append_idx = (bit - appended_start) as usize;
if append_idx < self.state.appended_bits.len()
&& self.state.appended_bits[append_idx]
{
let byte_idx = (bit_in_chunk / 8) as usize;
let bit_idx = bit_in_chunk % 8;
chunk[byte_idx] |= 1u8 << bit_idx;
}
}
}
chunk
};
self.state.chunks_to_prune.insert(chunk_idx, chunk_data);
}
self.state.projected_pruned_chunks = chunk_num;
self
}
pub fn commit(mut self, commit_number: u64) -> Result<CleanBitMap<N>, Error> {
if commit_number == u64::MAX {
return Err(Error::ReservedCommitNumber);
}
if let Some(&max_commit) = self.commits.keys().next_back() {
if commit_number <= max_commit {
return Err(Error::NonMonotonicCommit {
previous: max_commit,
attempted: commit_number,
});
}
}
let reverse_diff = self.build_reverse_diff();
let target_len_before_appends =
self.state.projected_len - self.state.appended_bits.len() as u64;
while self.current.len() > target_len_before_appends {
self.current.pop();
}
for &bit in &self.state.appended_bits {
self.current.push(bit);
}
assert_eq!(self.current.len(), self.state.projected_len);
for (&bit, &value) in &self.state.modified_bits {
self.current.set_bit(bit, value);
}
if self.state.projected_pruned_chunks > self.state.base_pruned_chunks {
let prune_to_bit =
self.state.projected_pruned_chunks as u64 * Prunable::<N>::CHUNK_SIZE_BITS;
self.current.prune_to_bit(prune_to_bit);
}
self.commits.insert(commit_number, reverse_diff);
Ok(CleanBitMap {
current: self.current,
commits: self.commits,
state: Clean,
})
}
pub fn abort(self) -> CleanBitMap<N> {
CleanBitMap {
current: self.current,
commits: self.commits,
state: Clean,
}
}
fn build_reverse_diff(&self) -> CommitDiff<N> {
let mut changes = BTreeMap::new();
self.capture_modified_chunks(&mut changes);
self.capture_appended_chunks(&mut changes);
self.capture_popped_chunks(&mut changes);
self.capture_pruned_chunks(&mut changes);
CommitDiff {
len: self.state.base_len,
pruned_chunks: self.state.base_pruned_chunks,
chunk_diffs: changes,
}
}
fn capture_modified_chunks(&self, changes: &mut BTreeMap<usize, ChunkDiff<N>>) {
for &bit in self.state.modified_bits.keys() {
let chunk_idx = Prunable::<N>::to_chunk_index(bit);
changes.entry(chunk_idx).or_insert_with(|| {
let old_chunk = self
.get_chunk_from_current(chunk_idx)
.expect("chunk must exist for modified bit");
ChunkDiff::Modified(old_chunk)
});
}
}
fn capture_appended_chunks(&self, changes: &mut BTreeMap<usize, ChunkDiff<N>>) {
if self.state.appended_bits.is_empty() {
return;
}
let append_start_bit = self.state.projected_len - self.state.appended_bits.len() as u64;
let start_chunk = Prunable::<N>::to_chunk_index(append_start_bit);
let end_chunk = Prunable::<N>::to_chunk_index(self.state.projected_len.saturating_sub(1));
for chunk_idx in start_chunk..=end_chunk {
changes.entry(chunk_idx).or_insert_with(|| {
self.get_chunk_from_current(chunk_idx).map_or(
ChunkDiff::Added,
ChunkDiff::Modified,
)
});
}
}
fn capture_popped_chunks(&self, changes: &mut BTreeMap<usize, ChunkDiff<N>>) {
if self.state.projected_len >= self.state.base_len || self.state.base_len == 0 {
return; }
let old_last_chunk = Prunable::<N>::to_chunk_index(self.state.base_len - 1);
let new_last_chunk = if self.state.projected_len > 0 {
Prunable::<N>::to_chunk_index(self.state.projected_len - 1)
} else {
0
};
let start_chunk = self.state.base_pruned_chunks.max(new_last_chunk);
for chunk_idx in start_chunk..=old_last_chunk {
changes.entry(chunk_idx).or_insert_with(|| {
let old_chunk = self
.get_chunk_from_current(chunk_idx)
.expect("chunk must exist in base bitmap for popped bits");
let chunk_start_bit = chunk_idx as u64 * Prunable::<N>::CHUNK_SIZE_BITS;
if self.state.projected_len > chunk_start_bit {
ChunkDiff::Modified(old_chunk)
} else {
ChunkDiff::Removed(old_chunk)
}
});
}
}
fn capture_pruned_chunks(&self, changes: &mut BTreeMap<usize, ChunkDiff<N>>) {
for (&chunk_idx, &chunk_data) in &self.state.chunks_to_prune {
changes.insert(chunk_idx, ChunkDiff::Pruned(chunk_data));
}
}
fn get_chunk_from_current(&self, chunk_idx: usize) -> Option<[u8; N]> {
let current_pruned = self.current.pruned_chunks();
if chunk_idx >= current_pruned && chunk_idx < self.current.chunks_len() {
return Some(*self.current.get_chunk(chunk_idx));
}
None
}
}