#![allow(
clippy::doc_markdown,
clippy::many_single_char_names,
clippy::type_complexity
)]
use alloc::sync::Arc;
use alloc::vec::Vec;
/// Maximum number of children of an internal node (the order of the B-tree).
const ORDER: usize = 8;
/// Maximum number of key/value pairs a node may hold; one more triggers a split.
const MAX_ENTRIES: usize = ORDER - 1;
/// Maximum number of children an internal node may hold; one more triggers a split.
const MAX_CHILDREN: usize = ORDER;

// An overfull node holds `ORDER` entries and is split at `ORDER / 2`, leaving
// `ORDER - ORDER / 2 - 1` entries for the right half. For `ORDER < 3` that is
// zero, so splits would create empty nodes; reject such orders at compile time.
const _: () = assert!(ORDER >= 3, "ORDER must be at least 3");
/// A single node of the B-tree.
///
/// In both variants `entries` is kept sorted by key, and lookups binary-search
/// it. For `Internal`, insertion maintains `children.len() == entries.len() + 1`.
/// `children[i]` holds the keys that sort between `entries[i - 1]` and
/// `entries[i]`.
#[derive(Debug)]
enum BNode<K, V> {
    /// Bottom-level node that stores key/value pairs only.
    Leaf { entries: Vec<(K, V)> },
    /// Interior node: separator pairs interleaved with shared child subtrees.
    Internal {
        entries: Vec<(K, V)>,
        children: Vec<Arc<BNode<K, V>>>,
    },
}
impl<K: Clone, V: Clone> Clone for BNode<K, V> {
fn clone(&self) -> Self {
match self {
Self::Leaf { entries } => Self::Leaf {
entries: entries.clone(),
},
Self::Internal { entries, children } => Self::Internal {
entries: entries.clone(),
children: children.clone(),
},
}
}
}
/// An ordered map stored as a B-tree whose nodes are shared through `Arc`.
///
/// Cloning only copies the root handle and the length, so clones share every
/// node. [`insert`](Self::insert) rebuilds the nodes on the path to the key and
/// leaves `self` untouched. [`insert_mut`](Self::insert_mut) edits in place and
/// copies a node only when it is still shared with another map.
#[derive(Debug)]
pub struct PersistentBTreeMap<K, V> {
    /// Root node; an empty leaf for a freshly created map.
    root: Arc<BNode<K, V>>,
    /// Number of key/value pairs stored in the tree.
    len: usize,
}
impl<K, V> Default for PersistentBTreeMap<K, V> {
fn default() -> Self {
Self::new()
}
}
impl<K, V> Clone for PersistentBTreeMap<K, V> {
fn clone(&self) -> Self {
Self {
root: self.root.clone(),
len: self.len,
}
}
}
/// Two maps are equal when they yield the same key/value pairs in the same
/// ascending order. The internal tree shape is irrelevant.
impl<K, V> PartialEq for PersistentBTreeMap<K, V>
where
    K: PartialEq + Ord,
    V: PartialEq,
{
    fn eq(&self, other: &Self) -> bool {
        if self.len != other.len {
            // Cheap rejection before walking either tree.
            return false;
        }
        self.iter().eq(other.iter())
    }
}
/// Element-wise equality is a full equivalence relation when keys and values are `Eq`.
impl<K, V> Eq for PersistentBTreeMap<K, V>
where
    K: Eq + Ord,
    V: Eq,
{
}
impl<K, V> PersistentBTreeMap<K, V> {
    /// Creates an empty map whose root is an empty leaf.
    #[must_use]
    pub fn new() -> Self {
        let empty_leaf = BNode::Leaf {
            entries: Vec::new(),
        };
        Self {
            root: Arc::new(empty_leaf),
            len: 0,
        }
    }

    /// Returns the number of key/value pairs in the map.
    #[must_use]
    pub const fn len(&self) -> usize {
        self.len
    }

    /// Returns `true` when the map holds no entries.
    #[must_use]
    pub const fn is_empty(&self) -> bool {
        matches!(self.len, 0)
    }
}
impl<K: Ord, V> PersistentBTreeMap<K, V> {
    /// Returns a reference to the value bound to `key`, or `None` if the key
    /// is absent.
    ///
    /// At each node the sorted entries are binary-searched. A hit returns
    /// immediately; a miss descends into the child between the neighbouring
    /// separators, or ends the search at a leaf.
    pub fn get(&self, key: &K) -> Option<&V> {
        let mut current: &BNode<K, V> = &self.root;
        loop {
            // Leaves contribute no children; internal nodes do.
            let (entries, children) = match current {
                BNode::Leaf { entries } => (entries, None),
                BNode::Internal { entries, children } => (entries, Some(children)),
            };
            match entries.binary_search_by(|(probe, _)| probe.cmp(key)) {
                Ok(found) => return Some(&entries[found].1),
                // In a leaf `children?` ends the search with `None`; otherwise
                // continue in the subtree that would contain `key`.
                Err(gap) => current = &*children?[gap],
            }
        }
    }

    /// Returns an iterator over every key/value pair in ascending key order.
    pub fn iter(&self) -> Iter<'_, K, V> {
        // The walk starts at the root with a zero cursor. The initial capacity
        // is only a pre-allocation; the stack grows if the tree is deeper.
        let mut stack: Vec<(&Arc<BNode<K, V>>, usize)> = Vec::with_capacity(8);
        stack.push((&self.root, 0));
        Iter { stack }
    }
}
impl<K: Ord + Clone, V: Clone> PersistentBTreeMap<K, V> {
    /// Returns a copy of the map with `key` bound to `value`, together with the
    /// value `key` was previously bound to (if any).
    ///
    /// `self` is not modified. The nodes on the path from the root to `key` are
    /// rebuilt, and every other node is shared between the two maps.
    #[must_use]
    pub fn insert(&self, key: K, value: V) -> (Self, Option<V>) {
        let (left, split, previous) = insert_helper(&self.root, key, value);
        let root = match split {
            // The old root overflowed. Its halves become the two children of a
            // new root holding only the median, so the tree grows by one level.
            Some((right, median)) => Arc::new(BNode::Internal {
                entries: alloc::vec![median],
                children: alloc::vec![left, right],
            }),
            None => left,
        };
        // Only a brand-new key changes the entry count; an overwrite does not.
        let len = self.len + usize::from(previous.is_none());
        (Self { root, len }, previous)
    }

    /// Binds `key` to `value` in place and returns the value `key` was
    /// previously bound to (if any).
    ///
    /// Nodes still shared with other maps (for example clones of `self`) are
    /// copied before being modified, so those maps keep their contents.
    pub fn insert_mut(&mut self, key: K, value: V) -> Option<V> {
        let (split, previous) = insert_transient_helper(&mut self.root, key, value);
        if let Some((right, median)) = split {
            // `self.root` now holds the left half of the old root; hang both
            // halves under a new root. Overwriting `self.root` drops the extra
            // handle, so `left` ends up owned solely by the new root.
            let left = Arc::clone(&self.root);
            self.root = Arc::new(BNode::Internal {
                entries: alloc::vec![median],
                children: alloc::vec![left, right],
            });
        }
        if previous.is_none() {
            self.len += 1;
        }
        previous
    }
}
/// In-place insertion used by [`PersistentBTreeMap::insert_mut`].
///
/// `Arc::make_mut` gives exclusive access to `node`, copying it first if it is
/// shared. The function returns `(split, previous)`:
/// - `split` is `Some((right, median))` when `node` overflowed and was cut in
///   two. `node` keeps the left half, `right` is the new right sibling, and
///   `median` must be inserted into the parent.
/// - `previous` is the value formerly bound to `k`.
fn insert_transient_helper<K: Ord + Clone, V: Clone>(
    node: &mut Arc<BNode<K, V>>,
    k: K,
    v: V,
) -> (Option<(Arc<BNode<K, V>>, (K, V))>, Option<V>) {
    match Arc::make_mut(node) {
        BNode::Leaf { entries } => {
            let previous = match entries.binary_search_by(|(probe, _)| probe.cmp(&k)) {
                Ok(hit) => Some(core::mem::replace(&mut entries[hit].1, v)),
                Err(gap) => {
                    entries.insert(gap, (k, v));
                    None
                }
            };
            if entries.len() > MAX_ENTRIES {
                // Keep the lower half here, move the upper half into a new leaf
                // and hand the middle pair up to the parent.
                let cut = entries.len() / 2;
                let tail = entries.split_off(cut + 1);
                let median = entries.pop().expect("mid was in-bounds");
                let right = Arc::new(BNode::Leaf { entries: tail });
                (Some((right, median)), previous)
            } else {
                (None, previous)
            }
        }
        BNode::Internal { entries, children } => {
            let gap = match entries.binary_search_by(|(probe, _)| probe.cmp(&k)) {
                // The key is one of this node's separators: overwrite it; no node grows.
                Ok(hit) => {
                    return (None, Some(core::mem::replace(&mut entries[hit].1, v)));
                }
                Err(gap) => gap,
            };
            let (split, previous) = insert_transient_helper(&mut children[gap], k, v);
            if let Some((sibling, median)) = split {
                entries.insert(gap, median);
                children.insert(gap + 1, sibling);
            }
            if children.len() > MAX_CHILDREN {
                let cut = entries.len() / 2;
                let right_entries = entries.split_off(cut + 1);
                let median = entries.pop().expect("mid was in-bounds");
                let right_children = children.split_off(cut + 1);
                let right = Arc::new(BNode::Internal {
                    entries: right_entries,
                    children: right_children,
                });
                (Some((right, median)), previous)
            } else {
                (None, previous)
            }
        }
    }
}
/// Path-copying insertion used by [`PersistentBTreeMap::insert`].
///
/// `node` is never modified. Every node on the path from `node` down to the
/// slot for `k` is rebuilt, and all other subtrees are shared with the input
/// through `Arc` clones.
///
/// Returns `(left, split, previous)`:
/// - `left` replaces `node` in the caller. It is the whole rebuilt node when no
///   split happened, otherwise the left half.
/// - `split` is `Some((right, median))` when the rebuilt node overflowed. The
///   caller must insert `median` as a separator and `right` as the child that
///   follows `left`.
/// - `previous` is the value formerly bound to `k`, if any.
fn insert_helper<K: Ord + Clone, V: Clone>(
    node: &Arc<BNode<K, V>>,
    k: K,
    v: V,
) -> (
    Arc<BNode<K, V>>,
    Option<(Arc<BNode<K, V>>, (K, V))>,
    Option<V>,
) {
    match &**node {
        BNode::Leaf { entries } => {
            let pos = entries.binary_search_by(|(ek, _)| ek.cmp(&k));
            let (mut new_entries, prev_v) = match pos {
                Ok(idx) => {
                    // Same length: a plain clone already has the right capacity.
                    let mut new_entries = entries.clone();
                    let prev = core::mem::replace(&mut new_entries[idx].1, v);
                    (new_entries, Some(prev))
                }
                Err(idx) => (clone_with_inserted(entries, idx, (k, v)), None),
            };
            if new_entries.len() <= MAX_ENTRIES {
                return (
                    Arc::new(BNode::Leaf {
                        entries: new_entries,
                    }),
                    None,
                    prev_v,
                );
            }
            // Overflow: lower half stays left, upper half goes right, and the
            // middle pair is handed up to the parent.
            let mid = new_entries.len() / 2;
            let right_entries = new_entries.split_off(mid + 1);
            let median = new_entries.pop().expect("mid was in-bounds");
            let left = Arc::new(BNode::Leaf {
                entries: new_entries,
            });
            let right = Arc::new(BNode::Leaf {
                entries: right_entries,
            });
            (left, Some((right, median)), prev_v)
        }
        BNode::Internal { entries, children } => {
            let pos = entries.binary_search_by(|(ek, _)| ek.cmp(&k));
            match pos {
                Ok(idx) => {
                    // The key is a separator of this node: overwrite its value
                    // and share every child unchanged.
                    let mut new_entries = entries.clone();
                    let prev_v = core::mem::replace(&mut new_entries[idx].1, v);
                    (
                        Arc::new(BNode::Internal {
                            entries: new_entries,
                            children: children.clone(),
                        }),
                        None,
                        Some(prev_v),
                    )
                }
                Err(idx) => {
                    let (new_child, split, prev_v) = insert_helper(&children[idx], k, v);
                    // When the child split, its median and right half are
                    // inserted right after position `idx`. Build those vectors
                    // with their final size instead of clone-then-insert.
                    let (mut new_entries, mut new_children) = match split {
                        None => (entries.clone(), children.clone()),
                        Some((right_sibling, median)) => (
                            clone_with_inserted(entries, idx, median),
                            clone_with_inserted(children, idx + 1, right_sibling),
                        ),
                    };
                    new_children[idx] = new_child;
                    if new_children.len() <= MAX_CHILDREN {
                        return (
                            Arc::new(BNode::Internal {
                                entries: new_entries,
                                children: new_children,
                            }),
                            None,
                            prev_v,
                        );
                    }
                    let mid = new_entries.len() / 2;
                    let right_entries = new_entries.split_off(mid + 1);
                    let median = new_entries.pop().expect("mid was in-bounds");
                    let right_children = new_children.split_off(mid + 1);
                    let left = Arc::new(BNode::Internal {
                        entries: new_entries,
                        children: new_children,
                    });
                    let right = Arc::new(BNode::Internal {
                        entries: right_entries,
                        children: right_children,
                    });
                    (left, Some((right, median)), prev_v)
                }
            }
        }
    }
}

/// Returns a copy of `src` with `item` inserted at position `idx`.
///
/// Allocates exactly `src.len() + 1` slots once. By contrast, `clone()` followed
/// by `Vec::insert` allocates for the clone (capacity == length) and then
/// reallocates to make room, typically doubling the capacity. The resulting
/// nodes are shared and immutable, so that extra capacity would never be used.
///
/// # Panics
/// Panics if `idx > src.len()`.
fn clone_with_inserted<T: Clone>(src: &[T], idx: usize, item: T) -> Vec<T> {
    let (head, tail) = src.split_at(idx);
    let mut out = Vec::with_capacity(src.len() + 1);
    out.extend_from_slice(head);
    out.push(item);
    out.extend_from_slice(tail);
    out
}
/// Borrowing iterator over a [`PersistentBTreeMap`] in ascending key order.
///
/// Created by [`PersistentBTreeMap::iter`].
#[derive(Debug)]
pub struct Iter<'a, K, V> {
    /// Path from the root to the node currently being visited. Each element
    /// pairs a node with its resume cursor. For a leaf, the cursor is the index
    /// of the next entry to yield. For an internal node, an even cursor `2 * i`
    /// means "descend into child `i` next" and an odd cursor `2 * i + 1` means
    /// "yield entry `i` next".
    stack: Vec<(&'a Arc<BNode<K, V>>, usize)>,
}
impl<'a, K, V> Iterator for Iter<'a, K, V> {
    type Item = (&'a K, &'a V);

    /// Yields the next pair of an in-order, depth-first walk of the tree.
    ///
    /// The top of `stack` is the node being visited together with its cursor.
    /// A leaf cursor indexes the next entry. An internal-node cursor alternates
    /// between "descend into child `cursor / 2`" (even) and "yield entry
    /// `cursor / 2`" (odd). That produces child 0, entry 0, child 1, entry 1,
    /// …, last child. A node whose cursor runs past its contents is popped.
    fn next(&mut self) -> Option<(&'a K, &'a V)> {
        loop {
            // An empty stack means the traversal is finished.
            let (node, idx) = *self.stack.last()?;
            match &**node {
                BNode::Leaf { entries } => {
                    if idx < entries.len() {
                        let (k, v) = &entries[idx];
                        self.stack.last_mut().unwrap().1 = idx + 1;
                        return Some((k, v));
                    }
                    self.stack.pop();
                }
                BNode::Internal { entries, children } => {
                    let phase = idx & 1;
                    let slot = idx >> 1;
                    if phase == 0 {
                        if slot < children.len() {
                            // Resume at entry `slot` once this child is exhausted.
                            self.stack.last_mut().unwrap().1 = idx + 1;
                            self.stack.push((&children[slot], 0));
                            continue;
                        }
                        self.stack.pop();
                    } else {
                        if slot < entries.len() {
                            let (k, v) = &entries[slot];
                            self.stack.last_mut().unwrap().1 = idx + 1;
                            return Some((k, v));
                        }
                        // Past the last separator: the final child has already been visited.
                        self.stack.pop();
                    }
                }
            }
        }
    }
}

// Once the stack is empty, `next` returns `None` on every further call, which
// is exactly the `FusedIterator` contract; advertising it lets `.fuse()` be free.
impl<'a, K, V> core::iter::FusedIterator for Iter<'a, K, V> {}
impl<'a, K: Ord, V> IntoIterator for &'a PersistentBTreeMap<K, V> {
type Item = (&'a K, &'a V);
type IntoIter = Iter<'a, K, V>;
fn into_iter(self) -> Self::IntoIter {
self.iter()
}
}
#[cfg(test)]
// Test-only code: the fuzz tests reinterpret raw `u64` RNG output as `i64`
// with `as` casts, use long hex seeds, declare consts mid-function and keep
// parallel maps with parallel names, so the corresponding pedantic lints are
// silenced here.
#[allow(
    clippy::cast_possible_truncation,
    clippy::cast_possible_wrap,
    clippy::cast_sign_loss,
    clippy::cast_lossless,
    clippy::needless_range_loop,
    clippy::items_after_statements,
    clippy::manual_range_patterns,
    clippy::unreadable_literal,
    clippy::similar_names
)]
mod tests {
    use super::*;
    use alloc::collections::BTreeMap;

    #[test]
    fn empty_map_is_empty() {
        let pb: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        assert_eq!(pb.len(), 0);
        assert!(pb.is_empty());
        assert!(pb.get(&42).is_none());
    }

    #[test]
    fn insert_single_into_empty_works() {
        let (pb, prev) = PersistentBTreeMap::<i64, i64>::new().insert(1, 100);
        assert_eq!(prev, None);
        assert_eq!(pb.len(), 1);
        assert_eq!(pb.get(&1), Some(&100));
        assert_eq!(pb.get(&2), None);
    }

    #[test]
    fn insert_replace_returns_prev_keeps_len() {
        let (pb, p1) = PersistentBTreeMap::<i64, i64>::new().insert(7, 10);
        assert_eq!(p1, None);
        let (pb, p2) = pb.insert(7, 99);
        assert_eq!(p2, Some(10));
        assert_eq!(pb.len(), 1);
        assert_eq!(pb.get(&7), Some(&99));
    }

    #[test]
    fn insert_crosses_leaf_split_boundary() {
        let mut pb: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        for i in 0..20_i64 {
            pb = pb.insert(i, i * 7).0;
        }
        for i in 0..20_i64 {
            assert_eq!(pb.get(&i), Some(&(i * 7)));
        }
        assert!(pb.get(&20).is_none());
        assert_eq!(pb.len(), 20);
    }

    #[test]
    fn insert_grows_through_multiple_internal_splits() {
        let mut pb: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        for i in 0..200_i64 {
            pb = pb.insert(i, i * 11).0;
        }
        for i in 0..200_i64 {
            assert_eq!(pb.get(&i), Some(&(i * 11)));
        }
        assert_eq!(pb.len(), 200);
    }

    #[test]
    fn clone_then_insert_preserves_original() {
        let mut a: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        for i in 0..100_i64 {
            a = a.insert(i, i).0;
        }
        let b = a.clone();
        let (b, _) = b.insert(999, 999);
        assert_eq!(a.len(), 100);
        assert!(a.get(&999).is_none());
        assert_eq!(b.len(), 101);
        assert_eq!(b.get(&999), Some(&999));
        for i in 0..100_i64 {
            assert_eq!(a.get(&i), Some(&i), "A drift at {i}");
            assert_eq!(b.get(&i), Some(&i), "B drift at {i}");
        }
    }

    #[test]
    fn iter_yields_sorted_order() {
        let mut pb: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        for &k in &[7_i64, 3, 11, 1, 9, 5, 14, 2, 8, 12, 4, 6, 10, 13] {
            pb = pb.insert(k, k * 2).0;
        }
        let collected: Vec<(i64, i64)> = pb.iter().map(|(k, v)| (*k, *v)).collect();
        let expected: Vec<(i64, i64)> = (1..=14).map(|k| (k, k * 2)).collect();
        assert_eq!(collected, expected);
    }

    #[test]
    fn iter_handles_taller_tree() {
        let mut pb: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        for i in 0..500_i64 {
            pb = pb.insert(i, i).0;
        }
        let collected: Vec<i64> = pb.iter().map(|(k, _)| *k).collect();
        let expected: Vec<i64> = (0..500).collect();
        assert_eq!(collected, expected);
    }

    /// Tiny deterministic SplitMix64 generator so the fuzz tests are reproducible.
    struct Splitmix(u64);

    impl Splitmix {
        fn new(seed: u64) -> Self {
            Self(seed)
        }

        fn next(&mut self) -> u64 {
            self.0 = self.0.wrapping_add(0x9E37_79B9_7F4A_7C15);
            let mut x = self.0;
            x = (x ^ (x >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
            x = (x ^ (x >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
            x ^ (x >> 31)
        }
    }

    #[test]
    fn fuzz_oracle_against_std_btreemap() {
        let mut pb: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        let mut oracle: BTreeMap<i64, i64> = BTreeMap::new();
        let mut rng = Splitmix::new(0xC0FFEE_u64);
        const STEPS: usize = 100_000;
        const KEY_RANGE: i64 = 4096;
        for step in 0..STEPS {
            let op = rng.next() % 3;
            let key = (rng.next() as i64) % KEY_RANGE;
            match op {
                0 | 1 => {
                    let val = rng.next() as i64;
                    let (new_pb, prev_pb) = pb.insert(key, val);
                    let prev_oracle = oracle.insert(key, val);
                    assert_eq!(prev_pb, prev_oracle, "prev drift @ step {step}, key {key}");
                    pb = new_pb;
                    assert_eq!(pb.len(), oracle.len(), "len drift @ step {step}");
                }
                2 => {
                    let pb_v = pb.get(&key).copied();
                    let oracle_v = oracle.get(&key).copied();
                    assert_eq!(pb_v, oracle_v, "get drift @ step {step}, key {key}");
                }
                _ => unreachable!(),
            }
        }
        for (k, v) in &oracle {
            assert_eq!(pb.get(k), Some(v), "final drift at key {k}");
        }
        let pb_collected: Vec<(i64, i64)> = pb.iter().map(|(k, v)| (*k, *v)).collect();
        let oracle_collected: Vec<(i64, i64)> = oracle.iter().map(|(k, v)| (*k, *v)).collect();
        assert_eq!(pb_collected, oracle_collected);
    }

    #[test]
    fn fuzz_oracle_clone_isolation() {
        let mut a: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        let mut oracle_a: BTreeMap<i64, i64> = BTreeMap::new();
        let mut rng = Splitmix::new(0xDECAFBAD_u64);
        for _ in 0..1_000 {
            let k = (rng.next() as i64) % 1000;
            let v = rng.next() as i64;
            a = a.insert(k, v).0;
            oracle_a.insert(k, v);
        }
        let mut b = a.clone();
        let mut oracle_b = oracle_a.clone();
        let mut c = a.clone();
        let mut oracle_c = oracle_a.clone();
        for _ in 0..500 {
            let k = (rng.next() as i64) % 2000;
            let v = rng.next() as i64;
            b = b.insert(k, v).0;
            oracle_b.insert(k, v);
        }
        for _ in 0..300 {
            let k = (rng.next() as i64) % 500;
            let v = rng.next() as i64;
            c = c.insert(k, v).0;
            oracle_c.insert(k, v);
        }
        for (k, v) in &oracle_a {
            assert_eq!(a.get(k), Some(v), "A drift at {k}");
        }
        for (k, v) in &oracle_b {
            assert_eq!(b.get(k), Some(v), "B drift at {k}");
        }
        for (k, v) in &oracle_c {
            assert_eq!(c.get(k), Some(v), "C drift at {k}");
        }
    }

    #[test]
    fn partial_eq_compares_by_elements() {
        let mut a: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        let mut b: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        for &k in &[5_i64, 2, 8, 1, 7, 3, 6, 4] {
            a = a.insert(k, k * 10).0;
        }
        for &k in &[1_i64, 2, 3, 4, 5, 6, 7, 8] {
            b = b.insert(k, k * 10).0;
        }
        assert_eq!(a, b);
        let (a, _) = a.insert(9, 90);
        assert_ne!(a, b);
    }

    #[test]
    fn iter_on_empty_map_yields_nothing() {
        let pb: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        assert_eq!(pb.iter().next(), None);
        assert_eq!(pb, PersistentBTreeMap::default());
    }

    /// The in-place path (`insert_mut`) must agree with the persistent path
    /// (`insert`) and with `std`'s `BTreeMap`, including across root splits.
    #[test]
    fn insert_mut_agrees_with_insert_and_oracle() {
        let mut transient: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        let mut persistent: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        let mut oracle: BTreeMap<i64, i64> = BTreeMap::new();
        let mut rng = Splitmix::new(0x5EED_u64);
        for step in 0..5_000 {
            let key = (rng.next() as i64) % 512;
            let val = rng.next() as i64;
            let prev_mut = transient.insert_mut(key, val);
            let (updated, prev_persistent) = persistent.insert(key, val);
            persistent = updated;
            let prev_oracle = oracle.insert(key, val);
            assert_eq!(prev_mut, prev_oracle, "insert_mut prev drift @ step {step}, key {key}");
            assert_eq!(prev_persistent, prev_oracle, "insert prev drift @ step {step}, key {key}");
            assert_eq!(transient.len(), oracle.len(), "insert_mut len drift @ step {step}");
        }
        let transient_pairs: Vec<(i64, i64)> = transient.iter().map(|(k, v)| (*k, *v)).collect();
        let oracle_pairs: Vec<(i64, i64)> = oracle.iter().map(|(k, v)| (*k, *v)).collect();
        assert_eq!(transient_pairs, oracle_pairs);
        assert_eq!(transient, persistent);
    }

    /// `insert_mut` on a map whose nodes are shared with a clone must copy
    /// those nodes (`Arc::make_mut`) instead of mutating the clone.
    #[test]
    fn insert_mut_on_clone_leaves_snapshot_intact() {
        let mut a: PersistentBTreeMap<i64, i64> = PersistentBTreeMap::new();
        for i in 0..100_i64 {
            assert_eq!(a.insert_mut(i, i), None);
        }
        let snapshot = a.clone();
        for i in 0..100_i64 {
            assert_eq!(a.insert_mut(i, -i), Some(i));
        }
        for i in 100..200_i64 {
            assert_eq!(a.insert_mut(i, -i), None);
        }
        assert_eq!(snapshot.len(), 100);
        assert_eq!(a.len(), 200);
        for i in 0..100_i64 {
            assert_eq!(snapshot.get(&i), Some(&i), "snapshot drift at {i}");
            assert_eq!(a.get(&i), Some(&-i), "mutated drift at {i}");
        }
        assert!(snapshot.get(&150).is_none());
        assert_eq!(a.get(&150), Some(&-150));
    }
}