use std::sync::{Arc, Mutex};
/// Branching parameter of the tree: a node counts as full (and is split
/// before an insert descends into it) once it holds `2 * B - 1` keys.
pub const B: usize = 8;
/// Key count at which `Node::is_full` reports a node as full.
const MAX_KEYS: usize = B * 2 - 1;
/// Shared handle to a node; each node is guarded by its own lock.
type NodeArc<K, V> = Arc<Mutex<Node<K, V>>>;
/// A node of the B+ tree.
///
/// Internal nodes only route; all key/value pairs live in leaves, which are
/// chained left-to-right through `next` so range scans can walk them in order.
enum Node<K, V> {
    /// Routing node with `keys.len() + 1` children. Keys are sorted, and a key
    /// `k` is routed to child `i` where `keys[i - 1] <= k < keys[i]`.
    Internal {
        keys: Vec<K>,
        children: Vec<NodeArc<K, V>>,
    },
    /// Data node: `keys` is sorted and `values[i]` is the value for `keys[i]`.
    /// `next` points at the right-hand neighbouring leaf, if any.
    Leaf {
        keys: Vec<K>,
        values: Vec<V>,
        next: Option<NodeArc<K, V>>,
    },
}
impl<K: Clone, V: Clone> Node<K, V> {
    /// Builds an empty leaf with no right sibling, pre-sized for `MAX_KEYS` entries.
    fn new_leaf() -> Self {
        let keys = Vec::with_capacity(MAX_KEYS);
        let values = Vec::with_capacity(MAX_KEYS);
        Node::Leaf {
            keys,
            values,
            next: None,
        }
    }

    /// Number of keys held by this node, whichever variant it is.
    fn key_count(&self) -> usize {
        match self {
            Node::Leaf { keys, .. } | Node::Internal { keys, .. } => keys.len(),
        }
    }

    /// `true` once the node holds at least `MAX_KEYS` keys and should be split.
    fn is_full(&self) -> bool {
        MAX_KEYS <= self.key_count()
    }
}
/// What splitting a node in place produces: a freshly allocated right-hand
/// sibling and the separator key the parent must store in front of it.
struct SplitResult<K, V> {
    /// New node holding the upper part of the split node's contents.
    right: NodeArc<K, V>,
    /// Separator key; every key stored in `right` compares `>=` to it.
    median: K,
}
/// Splits a leaf in place, moving its upper half into a new right sibling.
///
/// The leaf keeps its first `len / 2` entries. The new leaf receives the rest
/// and inherits the old `next` link, and the original leaf's `next` is pointed
/// at the new leaf so the leaf chain stays in key order. The returned median is
/// a clone of the new leaf's first key.
///
/// # Panics
/// Panics if `node` is not a leaf, or if it holds no keys.
fn split_leaf<K: Clone, V: Clone>(node: &mut Node<K, V>) -> SplitResult<K, V> {
    let (keys, values, next) = match node {
        Node::Leaf { keys, values, next } => (keys, values, next),
        Node::Internal { .. } => unreachable!("split_leaf on non-leaf"),
    };
    let cut = keys.len() / 2;
    let upper_keys = keys.split_off(cut);
    let upper_values = values.split_off(cut);
    let median = upper_keys[0].clone();
    let right = Arc::new(Mutex::new(Node::Leaf {
        keys: upper_keys,
        values: upper_values,
        next: next.take(),
    }));
    // Splice the new leaf into the chain directly after this one.
    *next = Some(Arc::clone(&right));
    SplitResult { median, right }
}
/// Splits an internal node in place around its middle key.
///
/// With `mid = keys.len() / 2`, this node keeps `keys[..mid]` and
/// `children[..=mid]`. The new right sibling gets `keys[mid + 1..]` and
/// `children[mid + 1..]`. `keys[mid]` is removed from both halves and
/// returned as the median for the parent.
///
/// # Panics
/// Panics if `node` is not an internal node, or if it holds no keys.
fn split_internal<K: Clone, V: Clone>(node: &mut Node<K, V>) -> SplitResult<K, V> {
    let (keys, children) = match node {
        Node::Internal { keys, children } => (keys, children),
        Node::Leaf { .. } => unreachable!("split_internal on non-internal"),
    };
    let mid = keys.len() / 2;
    let median = keys[mid].clone();
    let upper_keys = keys.split_off(mid + 1);
    // The median moves up into the parent, so it is dropped from the left half.
    keys.truncate(mid);
    let upper_children = children.split_off(mid + 1);
    let right = Node::Internal {
        keys: upper_keys,
        children: upper_children,
    };
    SplitResult {
        median,
        right: Arc::new(Mutex::new(right)),
    }
}
/// A B+ tree whose nodes each sit behind their own [`Mutex`] and are shared
/// through [`Arc`], so every operation works through `&self`.
pub struct ConcurrentBTree<K, V> {
    /// Handle to the root node. The handle itself is never replaced; a root
    /// split rewrites the root node's contents in place.
    root: NodeArc<K, V>,
}
impl<K: Ord + Clone, V: Clone> ConcurrentBTree<K, V> {
pub fn new() -> Self {
ConcurrentBTree {
root: Arc::new(Mutex::new(Node::new_leaf())),
}
}
pub fn lookup(&self, key: &K) -> Option<V> {
let mut cur = Arc::clone(&self.root);
loop {
let next: Result<Option<NodeArc<K, V>>, V> = {
let g = cur.lock().ok()?;
match &*g {
Node::Leaf { keys, values, .. } => {
return match keys.binary_search(key) {
Ok(i) => Some(values[i].clone()),
Err(_) => None,
};
}
Node::Internal { keys, children } => {
let idx = upper_bound(keys, key);
Ok(Some(Arc::clone(&children[idx])))
}
}
};
cur = next.ok()??;
}
}
pub fn insert(&self, key: K, value: V) {
{
let root_full = self.root.lock().map(|g| g.is_full()).unwrap_or(false);
if root_full {
let mut root_g = match self.root.lock() {
Ok(g) => g,
Err(_) => return,
};
if root_g.is_full() {
let is_leaf = matches!(*root_g, Node::Leaf { .. });
let SplitResult { median, right } = if is_leaf {
split_leaf(&mut root_g)
} else {
split_internal(&mut root_g)
};
let left_data = std::mem::replace(&mut *root_g, Node::new_leaf());
let left_arc: NodeArc<K, V> = Arc::new(Mutex::new(left_data));
*root_g = Node::Internal {
keys: vec![median],
children: vec![left_arc, right],
};
}
}
}
insert_non_full(&self.root, key, value);
}
pub fn delete(&self, key: &K) -> Option<V> {
delete_rec(&self.root, key)
}
pub fn range_scan(&self, lo: &K, hi: &K) -> Vec<(K, V)> {
let mut result = Vec::new();
let leaf = match find_leftmost_leaf(&self.root, lo) {
Some(a) => a,
None => return result,
};
let mut cur = leaf;
loop {
let nxt: Option<NodeArc<K, V>> = {
let g = match cur.lock() {
Ok(g) => g,
Err(_) => break,
};
match &*g {
Node::Leaf { keys, values, next } => {
for (k, v) in keys.iter().zip(values.iter()) {
if k > hi {
return result;
}
if k >= lo {
result.push((k.clone(), v.clone()));
}
}
next.clone()
}
_ => break,
}
};
match nxt {
Some(n) => cur = n,
None => break,
}
}
result
}
}
/// Counts the leading keys of the sorted slice `keys` that are `<= key`.
///
/// For an internal node, this is the index of the child to descend into. A key
/// equal to a separator is routed to the child right of that separator.
fn upper_bound<K: Ord>(keys: &[K], key: &K) -> usize {
    keys.partition_point(|probe| key >= probe)
}
fn insert_non_full<K: Ord + Clone, V: Clone>(node_arc: &NodeArc<K, V>, key: K, value: V) {
let is_leaf = node_arc
.lock()
.map(|g| matches!(*g, Node::Leaf { .. }))
.unwrap_or(true);
if is_leaf {
if let Ok(mut g) = node_arc.lock() {
if let Node::Leaf { keys, values, .. } = &mut *g {
let pos = keys.partition_point(|k| k < &key);
if pos < keys.len() && keys[pos] == key {
values[pos] = value;
} else {
keys.insert(pos, key);
values.insert(pos, value);
}
}
}
return;
}
let (child_idx, child_arc) = {
let g = match node_arc.lock() {
Ok(g) => g,
Err(_) => return,
};
if let Node::Internal { keys, children } = &*g {
let idx = upper_bound(keys, &key);
let child_idx = idx.min(children.len() - 1);
(child_idx, Arc::clone(&children[child_idx]))
} else {
return;
}
};
let child_full = child_arc.lock().map(|g| g.is_full()).unwrap_or(false);
if child_full {
let SplitResult { median, right } = {
let mut cg = match child_arc.lock() {
Ok(g) => g,
Err(_) => return,
};
let is_leaf_child = matches!(*cg, Node::Leaf { .. });
if is_leaf_child {
split_leaf(&mut cg)
} else {
split_internal(&mut cg)
}
};
if let Ok(mut pg) = node_arc.lock() {
if let Node::Internal { keys, children } = &mut *pg {
keys.insert(child_idx, median.clone());
children.insert(child_idx + 1, right);
}
}
let target = {
let g = match node_arc.lock() {
Ok(g) => g,
Err(_) => return,
};
if let Node::Internal { keys, children } = &*g {
let idx = upper_bound(keys, &key);
let idx = idx.min(children.len() - 1);
Arc::clone(&children[idx])
} else {
return;
}
};
insert_non_full(&target, key, value);
} else {
insert_non_full(&child_arc, key, value);
}
}
fn delete_rec<K: Ord + Clone, V: Clone>(node_arc: &NodeArc<K, V>, key: &K) -> Option<V> {
let is_leaf = node_arc
.lock()
.map(|g| matches!(*g, Node::Leaf { .. }))
.unwrap_or(true);
if is_leaf {
let mut g = node_arc.lock().ok()?;
if let Node::Leaf { keys, values, .. } = &mut *g {
let pos = keys.binary_search(key).ok()?;
keys.remove(pos);
return Some(values.remove(pos));
}
return None;
}
let child = {
let g = node_arc.lock().ok()?;
if let Node::Internal { keys, children } = &*g {
let idx = upper_bound(keys, key);
let idx = idx.min(children.len() - 1);
Arc::clone(&children[idx])
} else {
return None;
}
};
delete_rec(&child, key)
}
fn find_leftmost_leaf<K: Ord + Clone, V: Clone>(
node_arc: &NodeArc<K, V>,
key: &K,
) -> Option<NodeArc<K, V>> {
let is_leaf = node_arc
.lock()
.map(|g| matches!(*g, Node::Leaf { .. }))
.unwrap_or(true);
if is_leaf {
return Some(Arc::clone(node_arc));
}
let child = {
let g = node_arc.lock().ok()?;
if let Node::Internal { keys, children } = &*g {
let idx = upper_bound(keys, key);
let idx = idx.saturating_sub(1).min(children.len() - 1);
Arc::clone(&children[idx])
} else {
return None;
}
};
find_leftmost_leaf(&child, key)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_insert_lookup() {
        let tree = ConcurrentBTree::<i32, i32>::new();
        for i in 0..50 {
            tree.insert(i, i * 2);
        }
        for i in 0..50 {
            assert_eq!(tree.lookup(&i), Some(i * 2), "key {i}");
        }
        assert_eq!(tree.lookup(&100), None);
    }

    #[test]
    fn test_range_scan() {
        let tree = ConcurrentBTree::<i32, i32>::new();
        for i in 0..30 {
            tree.insert(i, i);
        }
        let result = tree.range_scan(&10, &20);
        let keys: Vec<i32> = result.iter().map(|(k, _)| *k).collect();
        assert_eq!(keys, (10..=20).collect::<Vec<_>>());
    }

    #[test]
    fn test_split_correct() {
        let tree = ConcurrentBTree::<i32, i32>::new();
        for i in (0..100).rev() {
            tree.insert(i, i);
        }
        for i in 0..100 {
            assert_eq!(tree.lookup(&i), Some(i), "key {i} missing");
        }
    }

    #[test]
    fn test_delete() {
        let tree = ConcurrentBTree::<i32, i32>::new();
        for i in 0..20 {
            tree.insert(i, i * 10);
        }
        let v = tree.delete(&5);
        assert_eq!(v, Some(50));
        assert_eq!(tree.lookup(&5), None);
        assert_eq!(tree.lookup(&6), Some(60));
    }

    #[test]
    fn test_update_existing() {
        let tree = ConcurrentBTree::<i32, i32>::new();
        tree.insert(1, 100);
        tree.insert(1, 200);
        assert_eq!(tree.lookup(&1), Some(200));
    }

    #[test]
    fn test_empty_tree() {
        let tree = ConcurrentBTree::<i32, i32>::new();
        assert_eq!(tree.lookup(&0), None);
        assert_eq!(tree.delete(&0), None);
        assert!(tree.range_scan(&0, &100).is_empty());
    }

    #[test]
    fn test_range_scan_matches_brute_force() {
        // 300 entries force a three-level tree. The insertion order is a
        // permutation (37 is coprime to 300), so splits happen at varied
        // positions. Only even keys are stored, so odd bounds fall between keys.
        let tree = ConcurrentBTree::<i32, i32>::new();
        for i in 0..300 {
            let k = (i * 37) % 300;
            tree.insert(k * 2, k);
        }
        for (lo, hi) in [(0, 0), (-5, 3), (7, 7), (13, 250), (598, 700), (100, 50)] {
            let expected: Vec<(i32, i32)> = (0..300)
                .map(|k| (k * 2, k))
                .filter(|&(key, _)| key >= lo && key <= hi)
                .collect();
            assert_eq!(tree.range_scan(&lo, &hi), expected, "range {lo}..={hi}");
        }
    }

    #[test]
    fn test_delete_then_reinsert() {
        let tree = ConcurrentBTree::<i32, i32>::new();
        for i in 0..100 {
            tree.insert(i, i);
        }
        for i in (0..100).step_by(2) {
            assert_eq!(tree.delete(&i), Some(i), "key {i}");
        }
        assert_eq!(tree.delete(&0), None);
        for i in 0..100 {
            let expected = if i % 2 == 0 { None } else { Some(i) };
            assert_eq!(tree.lookup(&i), expected, "key {i}");
        }
        tree.insert(0, 42);
        assert_eq!(tree.lookup(&0), Some(42));
        let odds: Vec<i32> = tree.range_scan(&1, &9).into_iter().map(|(k, _)| k).collect();
        assert_eq!(odds, vec![1, 3, 5, 7, 9]);
    }

    #[test]
    fn test_concurrent_readers() {
        let tree = std::sync::Arc::new(ConcurrentBTree::<i32, i32>::new());
        for i in 0..200 {
            tree.insert(i, -i);
        }
        let handles: Vec<_> = (0..4)
            .map(|t| {
                let tree = std::sync::Arc::clone(&tree);
                std::thread::spawn(move || {
                    for i in (t..200).step_by(4) {
                        assert_eq!(tree.lookup(&i), Some(-i), "key {i}");
                    }
                })
            })
            .collect();
        for handle in handles {
            handle.join().expect("reader thread panicked");
        }
    }
}