use std::collections::HashSet;
use std::fmt::Debug;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::{AtomicPtr, AtomicUsize, Ordering};
use crate::api::{KeyEncodingId, TreeId};
use crate::bplustree::NodeView;
use crate::bplustree::node_view::NodeViewError;
use crate::database::metadata::Metadata;
use crate::keyfmt::KeyFormat;
use crate::page::PageError;
use crate::storage::epoch::COMMIT_COUNT;
use crate::storage::epoch::EpochManager;
use crate::storage::metadata_manager::{MetadataError, MetadataManager};
use crate::storage::{HasEpoch, NodeStorage, PageStorage, StorageError};
use std::result::Result;
use thiserror::Error;
use zerocopy::AsBytes;
/// Identifier of a tree node as returned by / passed to the node storage layer.
pub type NodeId = u64;
/// One step of a root-to-leaf descent: the node visited and the slot index chosen inside it.
pub type PathNode = (NodeId, usize);
/// Outcome of a delete operation.
///
/// `Deleted` carries a payload (the new root id when produced by `delete_inner`);
/// `NotFound` reports that the key was not present.
// NOTE(review): `Error` is derived here without any `#[error(...)]` messages, and the
// type is used as a success value rather than an error — confirm the derive is intended.
#[derive(Debug, Error)]
pub enum DeleteResult<N> {
Deleted(N),
NotFound,
}
/// Result of splitting a node into two halves.
pub enum SplitResult<N> {
/// The two halves plus the separator key that must be pushed into the parent.
SplitNodes {
// Lower half (the original node after the upper entries were split off).
left_node: N,
// Upper half produced by the split.
right_node: N,
// Key separating `left_node` from `right_node`.
split_key: Vec<u8>,
},
}
/// Errors surfaced by tree read and write operations.
#[derive(Debug, Error)]
pub enum TreeError {
/// Invalid input, with a human-readable description.
#[error("bad input: {0}")]
BadInput(String),
/// A structural expectation about the tree did not hold (e.g. wrong node kind on a path).
#[error("tree invariant violated: {0}")]
Invariant(&'static str),
/// A node that was expected to exist could not be read; also used by the delete
/// wrappers to report a missing key.
#[error("node not found: {0}")]
NodeNotFound(String),
/// Error forwarded unchanged from the storage layer.
#[error(transparent)]
Storage(#[from] StorageError),
/// Error forwarded unchanged from the metadata manager.
#[error(transparent)]
Metadata(#[from] MetadataError),
/// Error forwarded unchanged from a node-view operation.
#[error(transparent)]
NodeView(#[from] NodeViewError),
/// The combined key and value size exceeds `max_len` (see `MAX_ENTRY_PAYLOAD`).
#[error(
"entry too large: key ({key_len} bytes) + value ({val_len} bytes) exceeds max {max_len} bytes"
)]
EntryTooLarge {
key_len: usize,
val_len: usize,
max_len: usize,
},
}
/// Per-entry bookkeeping bytes subtracted from the payload budget: one `u16` plus one leaf slot.
// NOTE(review): presumably the `u16` is a length prefix stored with each entry — confirm against the leaf page layout.
const LEAF_PER_ENTRY_OVERHEAD: usize = std::mem::size_of::<u16>() + crate::page::leaf::SLOT_SIZE;
/// Largest combined key + value size (in bytes) accepted by `put_inner`:
/// half of the leaf buffer minus the per-entry overhead.
pub const MAX_ENTRY_PAYLOAD: usize = crate::page::leaf::BUFFER_SIZE / 2 - LEAF_PER_ENTRY_OVERHEAD;
/// Errors returned when publishing a new tree version.
#[derive(Debug, Error)]
pub enum CommitError {
/// Persisting the metadata failed.
#[error("metadata write failed: {0}")]
Metadata(#[from] MetadataError),
/// The storage layer failed during the commit.
#[error("storage error during commit: {0}")]
Storage(#[from] StorageError),
/// The committed version changed since the caller's base was captured.
#[error("stale base — rebase and retry")]
RebaseRequired,
/// Failure injected through a fail point (test / `testing` feature builds).
#[error("commit aborted (test injection)")]
Injected,
}
/// Bookkeeping sink used by write operations to record what a transaction touched.
pub trait TxnTracker {
/// Records a node id that the in-progress write has replaced or dropped.
fn reclaim(&mut self, node_id: NodeId);
/// Records a node id that was newly written by the in-progress write.
fn add_new(&mut self, node_id: NodeId);
/// Stores the tree height that would result from the in-progress write.
fn record_staged_height(&mut self, height: u64);
/// Stores the entry count that would result from the in-progress write.
fn record_staged_size(&mut self, size: u64);
/// Returns the staged entry count, if one has been recorded.
fn staged_size(&self) -> Option<u64>;
/// Returns the staged height, if one has been recorded.
fn staged_height(&self) -> Option<u64>;
/// Reports whether `node_id` has been marked dirty; dirty nodes are rewritten in place
/// by the tree instead of being copied to a new id.
fn is_dirty(&self, node_id: NodeId) -> bool;
/// Marks `node_id` as dirty.
fn mark_dirty(&mut self, node_id: NodeId);
}
/// Copy of the tree's root id, height and size taken at one point in time.
#[derive(Debug, Clone)]
pub struct MetadataSnapshot {
pub root_id: NodeId,
pub height: u64,
pub size: u64,
}
/// Root id, height and size that a writer asks `try_commit` to publish.
#[derive(Debug, Clone)]
pub struct StagedMetadata {
pub root_id: NodeId,
pub height: u64,
pub size: u64,
}
/// Identifies the committed version a writer started from.
///
/// `try_commit` uses `committed_ptr` as the expected value of its compare-exchange
/// on the tree's committed-metadata pointer.
pub struct BaseVersion {
pub committed_ptr: *const Metadata,
}
/// Wrapper around a raw `Metadata` pointer kept in the tree's `retired_meta` list
/// and freed when the tree is dropped.
struct RetiredPtr(*mut Metadata);
// Raw pointers are not `Send` by default; this opt-in lets the list live inside a `Mutex`
// on a tree that is shared across threads.
// NOTE(review): soundness relies on nothing dereferencing a retired pointer concurrently
// with the final free — confirm against the code that pushes into `retired_meta`.
unsafe impl Send for RetiredPtr {}
/// B+ tree handle over a node storage `S` and a page storage `P`.
///
/// The currently committed metadata is published through the `committed` atomic pointer;
/// node-count limits are derived from the metadata's `order` in `open`.
pub struct BPlusTree<S, P>
where
S: NodeStorage + Send + Sync + 'static,
P: PageStorage + Send + Sync + 'static,
{
// Tree identifier copied from the metadata passed to `open`.
id: TreeId,
// Storage used to read and write node views.
storage: Arc<S>,
// Storage handed to the metadata manager when metadata is persisted.
page_storage: Arc<P>,
// Epoch manager used to pin readers and to reclaim deferred nodes.
epoch_mgr: Arc<EpochManager>,
#[allow(dead_code)]
key_encoding: KeyEncodingId,
#[allow(dead_code)]
encoding_version: u16,
// Key format given to newly created internal nodes.
key_format: KeyFormat,
// Metadata slots: even transaction ids are written to `meta_a`, odd ones to `meta_b`.
meta_a: u64,
meta_b: u64,
// Maximum number of keys a node may hold (`order - 1`).
max_keys: usize,
// Minimum key counts below which an internal / leaf node is considered underfull.
min_internal_keys: usize,
min_leaf_keys: usize,
// Number of commits performed through the test-only `commit` path.
commit_count: AtomicUsize,
// Pointer to the currently committed metadata (allocated with `Box::into_raw`).
committed: AtomicPtr<Metadata>,
// Metadata pointers awaiting deallocation; freed in `Drop`.
retired_meta: Mutex<Vec<RetiredPtr>>,
}
impl<S, P> Drop for BPlusTree<S, P>
where
S: NodeStorage + Send + Sync + 'static,
P: PageStorage + Send + Sync + 'static,
{
fn drop(&mut self) {
let ptr = self.committed.load(Ordering::Acquire);
if !ptr.is_null() {
unsafe {
drop(Box::from_raw(ptr));
}
}
for rp in self.retired_meta.lock().unwrap().drain(..) {
unsafe {
drop(Box::from_raw(rp.0));
}
}
}
}
/// Default `TxnTracker` implementation that accumulates everything in memory.
#[derive(Default)]
pub struct TransactionTracker {
// Node ids reported through `reclaim`.
pub reclaimed: Vec<NodeId>,
// Node ids reported through `add_new`.
pub added: Vec<NodeId>,
// Last height recorded through `record_staged_height`, if any.
pub staged_height: Option<u64>,
// Last size recorded through `record_staged_size`, if any.
pub staged_size: Option<u64>,
// Node ids marked dirty through `mark_dirty`.
pub dirty_pages: HashSet<NodeId>,
}
impl TransactionTracker {
pub fn new() -> Self {
Self {
reclaimed: Vec::new(),
added: Vec::new(),
staged_height: None,
staged_size: None,
dirty_pages: HashSet::new(),
}
}
}
impl TxnTracker for TransactionTracker {
    /// Appends `node_id` to the `reclaimed` list.
    fn reclaim(&mut self, node_id: NodeId) {
        self.reclaimed.push(node_id)
    }

    /// Appends `node_id` to the `added` list.
    fn add_new(&mut self, node_id: NodeId) {
        self.added.push(node_id)
    }

    /// Overwrites the staged height with `height`.
    fn record_staged_height(&mut self, height: u64) {
        let staged = Some(height);
        self.staged_height = staged;
    }

    /// Overwrites the staged size with `size`.
    fn record_staged_size(&mut self, size: u64) {
        let staged = Some(size);
        self.staged_size = staged;
    }

    /// Returns the staged size, if one was recorded.
    fn staged_size(&self) -> Option<u64> {
        self.staged_size
    }

    /// Returns the staged height, if one was recorded.
    fn staged_height(&self) -> Option<u64> {
        self.staged_height
    }

    /// Reports whether `node_id` is in the dirty set.
    fn is_dirty(&self, node_id: NodeId) -> bool {
        let id = &node_id;
        self.dirty_pages.contains(id)
    }

    /// Adds `node_id` to the dirty set (a no-op if it is already present).
    fn mark_dirty(&mut self, node_id: NodeId) {
        let _newly_added = self.dirty_pages.insert(node_id);
    }
}
/// Summary of a tracked put or delete, assembled from the tree's return value and the tracker.
#[derive(Debug)]
pub struct WriteResult {
// Root id of the tree version produced by the write.
pub new_root_id: NodeId,
// Node ids the write replaced (drained from the tracker's `reclaimed` list).
pub reclaimed_nodes: Vec<NodeId>,
// Node ids the write created (drained from the tracker's `added` list).
pub staged_nodes: Vec<NodeId>,
// Staged height, or the committed height when none was staged.
pub new_height: u64,
// Staged size, or the committed size when none was staged.
pub new_size: u64,
}
/// Cheaply cloneable handle that shares one `BPlusTree` behind an `Arc`.
pub struct SharedBPlusTree<S, P>
where
S: NodeStorage + Send + Sync + 'static,
P: PageStorage + Send + Sync + 'static,
{
// The shared tree; every clone of the handle points at the same instance.
inner: Arc<BPlusTree<S, P>>,
}
impl<S, P> Clone for SharedBPlusTree<S, P>
where
    S: NodeStorage + HasEpoch + Send + Sync + 'static,
    P: PageStorage + Send + Sync + 'static,
{
    /// Returns another handle to the same underlying tree (reference-count bump only).
    fn clone(&self) -> Self {
        let inner = Arc::clone(&self.inner);
        Self { inner }
    }
}
impl<S, P> SharedBPlusTree<S, P>
where
    S: NodeStorage + HasEpoch + Send + Sync + 'static,
    P: PageStorage + Send + Sync + 'static,
{
    /// Wraps `tree` in an `Arc` so it can be shared between handles.
    pub fn new(tree: BPlusTree<S, P>) -> Self {
        Self {
            inner: Arc::new(tree),
        }
    }

    /// Builds the `WriteResult` for a finished write.
    ///
    /// Drains the tracker's `reclaimed` and `added` lists (leaving them empty) and
    /// falls back to the committed height / size only when nothing was staged.
    fn drain_into_write_result(
        &self,
        new_root_id: NodeId,
        collector: &mut TransactionTracker,
    ) -> WriteResult {
        WriteResult {
            new_root_id,
            reclaimed_nodes: std::mem::take(&mut collector.reclaimed),
            staged_nodes: std::mem::take(&mut collector.added),
            new_height: collector
                .staged_height
                .unwrap_or_else(|| self.inner.get_height()),
            new_size: collector
                .staged_size
                .unwrap_or_else(|| self.inner.get_size()),
        }
    }

    /// Inserts or replaces `key` in the tree version rooted at `root_id`, recording
    /// the touched nodes in `collector`.
    ///
    /// # Errors
    /// Propagates any `TreeError` from `put_inner` (e.g. `EntryTooLarge`).
    pub fn put_with_root_tracked<K: AsRef<[u8]>, V: AsRef<[u8]>>(
        &self,
        key: K,
        value: V,
        root_id: NodeId,
        collector: &mut TransactionTracker,
    ) -> Result<WriteResult, TreeError> {
        let new_root_id = self.inner.put_inner(key, value, root_id, collector)?;
        Ok(self.drain_into_write_result(new_root_id, collector))
    }

    /// Deletes `key` from the tree version rooted at `root_id`, recording the touched
    /// nodes in `collector`.
    ///
    /// # Errors
    /// Returns `TreeError::NodeNotFound` when the key is absent (the collector is left
    /// undrained in that case) and propagates any error from `delete_inner`.
    pub fn delete_with_root_tracked<K: AsRef<[u8]>>(
        &self,
        key: &K,
        root_id: NodeId,
        collector: &mut TransactionTracker,
    ) -> Result<WriteResult, TreeError> {
        let DeleteResult::Deleted(new_root_id) =
            self.inner.delete_inner(key, root_id, collector)?
        else {
            return Err(TreeError::NodeNotFound(
                "key not found for deletion".to_string(),
            ));
        };
        Ok(self.drain_into_write_result(new_root_id, collector))
    }

    /// Looks up `key` in the committed tree and returns a copy of its value, if any.
    pub fn search<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<Vec<u8>>, TreeError> {
        self.inner.get(key)
    }

    /// Reports whether `key` is present in the committed tree.
    pub fn contains_key<K: AsRef<[u8]>>(&self, key: K) -> Result<bool, TreeError> {
        self.inner.contains_key(key)
    }

    /// Returns the root id as reported by the shared tree.
    pub fn get_root_id(&self) -> NodeId {
        self.inner.get_root_id()
    }

    /// Returns the height as reported by the shared tree.
    pub fn get_height(&self) -> u64 {
        self.inner.get_height()
    }

    /// Returns the size as reported by the shared tree.
    pub fn get_size(&self) -> u64 {
        self.inner.get_size()
    }

    /// Returns the shared tree's metadata snapshot.
    pub fn get_snapshot(&self) -> MetadataSnapshot {
        self.inner.get_snapshot()
    }

    /// Attempts to publish `new_metadata` on top of the base `version`; delegates to
    /// `BPlusTree::try_commit`.
    pub fn try_commit(
        &self,
        version: &BaseVersion,
        new_metadata: StagedMetadata,
    ) -> Result<(), CommitError> {
        self.inner.try_commit(version, new_metadata)
    }

    /// Returns the raw pointer to the currently committed metadata, suitable for
    /// building a `BaseVersion`.
    pub fn get_metadata_ptr(&self) -> *const Metadata {
        self.inner.committed.load(Ordering::SeqCst)
    }

    /// Inherent counterpart of `Clone::clone`: returns another handle to the same tree.
    // The lint is silenced because the inherent method intentionally mirrors the trait method.
    #[allow(clippy::should_implement_trait)]
    pub fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }

    /// Test helper: looks up `key` starting from an explicit `root_id`.
    #[cfg(test)]
    pub fn search_with_root<K: AsRef<[u8]>>(
        &self,
        key: &K,
        root_id: NodeId,
    ) -> Result<Option<Vec<u8>>, TreeError> {
        self.inner.get_inner(key, root_id)
    }

    /// Test helper: tracked put with a throwaway tracker.
    #[cfg(test)]
    pub fn put_with_root<K: AsRef<[u8]>, V: AsRef<[u8]>>(
        &self,
        key: K,
        value: V,
        root_id: NodeId,
    ) -> Result<WriteResult, TreeError> {
        let mut collector = TransactionTracker::new();
        self.put_with_root_tracked(key, value, root_id, &mut collector)
    }

    /// Test helper: put against the committed root while holding an epoch pin.
    #[cfg(test)]
    pub fn put<K: AsRef<[u8]>, V: AsRef<[u8]>>(
        &self,
        key: K,
        value: V,
    ) -> Result<WriteResult, TreeError> {
        let _guard = self.inner.epoch_mgr.pin();
        let root_id = self.inner.get_root_id();
        self.put_with_root(key, value, root_id)
    }

    /// Test helper: tracked delete with a throwaway tracker.
    #[cfg(test)]
    pub fn delete_with_root<K: AsRef<[u8]>>(
        &self,
        key: &K,
        root_id: NodeId,
    ) -> Result<WriteResult, TreeError> {
        let mut collector = TransactionTracker::new();
        self.delete_with_root_tracked(key, root_id, &mut collector)
    }

    /// Test helper: alias for `put_with_root`.
    #[cfg(test)]
    pub fn insert_with_root<K: AsRef<[u8]>, V: AsRef<[u8]>>(
        &self,
        key: K,
        value: V,
        root_id: NodeId,
    ) -> Result<WriteResult, TreeError> {
        self.put_with_root(key, value, root_id)
    }

    /// Test helper: alias for `put`.
    #[cfg(test)]
    pub fn insert<K: AsRef<[u8]>, V: AsRef<[u8]>>(
        &self,
        key: K,
        value: V,
    ) -> Result<WriteResult, TreeError> {
        self.put(key, value)
    }

    /// Test helper: borrows the currently committed metadata.
    #[cfg(test)]
    pub fn get_metadata(&self) -> &Metadata {
        // SAFETY: `committed` is initialised with a valid boxed pointer in `open`.
        // NOTE(review): the reference may outlive the allocation if a later commit
        // retires and frees it — acceptable for tests only; confirm.
        unsafe { &*self.inner.committed.load(Ordering::Acquire) }
    }

    /// Creates an iterator over the committed tree starting at `start` and bounded by
    /// the optional `end`.
    // NOTE(review): no epoch pin is taken here, unlike `get`; presumably the iterator
    // pins through the epoch manager it receives — confirm.
    pub fn search_range(
        &self,
        start: &[u8],
        end: Option<&[u8]>,
    ) -> Result<super::iterator::BPlusTreeIter<'_, S>, TreeError> {
        let root_id = self.inner.get_root_id();
        super::iterator::BPlusTreeIter::new(
            &self.inner.storage,
            root_id,
            &self.inner.epoch_mgr,
            start,
            end,
        )
    }

    /// Returns a new handle to the tree's epoch manager.
    pub fn epoch_mgr(&self) -> Arc<EpochManager> {
        Arc::clone(&self.inner.epoch_mgr)
    }

    /// Frees every node the epoch manager reports as reclaimable up to the oldest
    /// active epoch.
    ///
    /// # Errors
    /// Stops at, and returns, the first storage error raised by `free_node`.
    pub fn reclaim_deferred(&self) -> Result<(), TreeError> {
        let safe_epoch = self.inner.epoch_mgr.oldest_active();
        let reclaimed = self.inner.epoch_mgr.reclaim(safe_epoch);
        for pid in reclaimed {
            self.inner.storage.free_node(pid)?;
        }
        Ok(())
    }
}
impl<S, P> BPlusTree<S, P>
where
S: NodeStorage + Send + Sync + 'static,
P: PageStorage + Send + Sync + 'static,
{
/// Builds a tree handle around already-loaded metadata.
///
/// The metadata is boxed and published through the `committed` pointer. Key-count
/// limits are derived from `meta.order`:
/// * `max_keys = order - 1`
/// * `min_internal_keys = ceil(order / 2) - 1`
/// * `min_leaf_keys = ceil((order - 1) / 2)`
///
/// `encoding_version` is fixed at 1, the commit counter starts at zero and the
/// retired-metadata list starts empty.
// NOTE(review): an `order` of 0 would underflow the subtractions below — presumably
// callers guarantee a sane order; confirm.
#[allow(clippy::too_many_arguments)]
pub fn open(
    storage: Arc<S>,
    page_storage: Arc<P>,
    meta: Metadata,
    meta_a: u64,
    meta_b: u64,
    key_format: KeyFormat,
    key_encoding: KeyEncodingId,
    epoch_mgr: Arc<EpochManager>,
) -> BPlusTree<S, P> {
    let id = meta.id;
    let order = meta.order as usize;
    // Ownership of the metadata moves to the raw pointer; it is reclaimed in `Drop`.
    let committed_ptr = Box::into_raw(Box::new(meta));

    let max_keys = order - 1;
    let min_internal_keys = order.div_ceil(2) - 1;
    let min_leaf_keys = max_keys.div_ceil(2);

    Self {
        id,
        storage,
        page_storage,
        epoch_mgr,
        key_encoding,
        key_format,
        encoding_version: 1,
        meta_a,
        meta_b,
        max_keys,
        min_internal_keys,
        min_leaf_keys,
        commit_count: AtomicUsize::new(0),
        committed: AtomicPtr::new(committed_ptr),
        retired_meta: Mutex::new(Vec::new()),
    }
}
/// Reads the node stored under `id`, converting storage errors into `TreeError`.
///
/// Returns `Ok(None)` when the storage reports no node for `id`.
#[allow(dead_code)]
fn read_node_view(&self, id: NodeId) -> Result<Option<NodeView>, TreeError> {
    let maybe_view = self.storage.read_node_view(id)?;
    Ok(maybe_view)
}
/// Writes `node` to storage under a fresh id and registers that id with the tracker
/// as both newly added and dirty.
///
/// Returns the id assigned by the storage layer.
fn write_node_view(
    &self,
    node: &NodeView,
    tracker: &mut impl TxnTracker,
) -> Result<u64, TreeError> {
    let fresh_id = self.storage.write_node_view(node)?;
    // Register first as "new", then as "dirty" so later updates go in place.
    tracker.add_new(fresh_id);
    tracker.mark_dirty(fresh_id);
    Ok(fresh_id)
}
/// Descends from `root_id` to the leaf responsible for `key`.
///
/// Returns the visited `(node id, slot)` pairs from root to leaf together with a flag
/// telling whether the key already exists in that leaf. For internal nodes the slot is
/// the child index followed; for the leaf it is the key's position (if found) or the
/// position where it would be inserted.
///
/// # Errors
/// `TreeError::Invariant` if a node on the path cannot be read; node-view and storage
/// errors are propagated.
pub fn get_insertion_path<K: AsRef<[u8]>>(
    &self,
    key: K,
    root_id: NodeId,
) -> Result<(Vec<PathNode>, bool), TreeError> {
    let needle = key.as_ref();
    let mut trail: Vec<PathNode> = Vec::new();
    let mut cursor = root_id;
    loop {
        let Some(node) = self.storage.read_node_view(cursor)? else {
            return Err(TreeError::Invariant("node not found while traversing path"));
        };
        let probe = node.lower_bound(needle);
        match &node {
            NodeView::Leaf { .. } => {
                let found = probe.is_ok();
                let (Ok(slot) | Err(slot)) = probe;
                trail.push((cursor, slot));
                return Ok((trail, found));
            }
            NodeView::Internal { .. } => {
                // An exact separator match routes to the child on its right.
                let slot = match probe {
                    Ok(hit) => hit + 1,
                    Err(gap) => gap,
                };
                trail.push((cursor, slot));
                cursor = node.child_ptr_at(slot)?;
            }
        }
    }
}
/// Test helper: inserts or replaces `key` against the committed root while holding an
/// epoch pin, and returns the new root id.
#[cfg(test)]
pub fn put<K: AsRef<[u8]>, V: AsRef<[u8]>>(
    &self,
    key: K,
    value: V,
    track: &mut impl TxnTracker,
) -> Result<NodeId, TreeError> {
    // Keep the pin alive for the whole write.
    let _guard = self.epoch_mgr.pin();
    let committed_root = self.get_root_id();
    self.put_inner(key, value, committed_root, track)
}
/// Inserts `key` → `value` (or replaces the value of an existing key) in the tree
/// version rooted at `root_id` and returns the resulting root id.
///
/// If the target leaf page is physically full, the leaf is split first and the
/// operation is retried against the root produced by that split. After a successful
/// in-memory edit the staged size (only for new keys) and height are recorded, then the
/// leaf is either split (key count above `max_keys`) or written and propagated upward.
///
/// # Errors
/// `TreeError::EntryTooLarge` when key + value exceed `MAX_ENTRY_PAYLOAD`;
/// `TreeError::Invariant` / `TreeError::NodeNotFound` when the path is malformed;
/// other node-view and storage errors are propagated.
pub fn put_inner<K: AsRef<[u8]>, V: AsRef<[u8]>>(
    &self,
    key: K,
    value: V,
    root_id: NodeId,
    track: &mut impl TxnTracker,
) -> Result<NodeId, TreeError> {
    let (key_bytes, val_bytes) = (key.as_ref(), value.as_ref());
    if key_bytes.len() + val_bytes.len() > MAX_ENTRY_PAYLOAD {
        return Err(TreeError::EntryTooLarge {
            key_len: key_bytes.len(),
            val_len: val_bytes.len(),
            max_len: MAX_ENTRY_PAYLOAD,
        });
    }

    let mut current_root = root_id;
    loop {
        let (mut path, found) = self.get_insertion_path(key_bytes, current_root)?;
        let Some((leaf_node_id, idx)) = path.pop() else {
            return Err(TreeError::Invariant("insertion path is empty"));
        };
        let Some(mut leaf_node) = self.storage.read_node_view(leaf_node_id)? else {
            return Err(TreeError::NodeNotFound(format!(
                "Leaf node with ID {} not found",
                leaf_node_id
            )));
        };
        if !matches!(leaf_node, NodeView::Leaf { .. }) {
            return Err(TreeError::Invariant(
                "expected leaf node at insertion point",
            ));
        }

        // Existing key: overwrite the value; new key: insert at the computed slot.
        let edit = if found {
            leaf_node.replace_at(idx, val_bytes)
        } else {
            leaf_node.insert_at(idx, key_bytes, val_bytes)
        };
        match edit {
            Ok(()) => {}
            Err(NodeViewError::Page(PageError::PageFull {})) => {
                // No room on the page: split the untouched leaf and retry from the new root.
                current_root = self.handle_leaf_split(path, leaf_node, track)?;
                continue;
            }
            Err(other) => return Err(other.into()),
        }

        let base_size = track.staged_size().unwrap_or_else(|| self.get_size());
        if !found {
            track.record_staged_size(base_size + 1);
        }
        let base_height = track.staged_height().unwrap_or_else(|| self.get_height());
        track.record_staged_height(base_height);

        return if leaf_node.keys_len() > self.max_keys {
            self.handle_leaf_split(path, leaf_node, track)
        } else {
            self.write_and_propagate(path, &leaf_node, track)
        };
    }
}
/// Splits `leaf_node`, writes both halves (right half first, then left) and pushes the
/// separator up through `path`. Returns the resulting root id.
fn handle_leaf_split(
    &self,
    path: Vec<(NodeId, usize)>,
    leaf_node: NodeView,
    track: &mut impl TxnTracker,
) -> Result<NodeId, TreeError> {
    let halves = self.split_leaf_node(leaf_node)?;
    let SplitResult::SplitNodes {
        left_node,
        right_node,
        split_key,
    } = halves;

    // The write order (right, then left) determines the ids handed out by storage.
    let right_id = self.write_node_view(&right_node, track)?;
    let left_id = self.write_node_view(&left_node, track)?;

    self.propagate_split(path, left_id, right_id, split_key, track)
}
/// Splits a leaf at the midpoint of its keys.
///
/// The original node keeps the lower half; the upper half becomes the right node and
/// its first key is returned as the separator.
///
/// # Errors
/// `TreeError::Invariant` if `leaf_node` is not a leaf; node-view errors are propagated.
fn split_leaf_node(&self, mut leaf_node: NodeView) -> Result<SplitResult<NodeView>, TreeError> {
    if !matches!(leaf_node, NodeView::Leaf { .. }) {
        return Err(TreeError::Invariant("expected leaf node for splitting"));
    }
    let pivot = leaf_node.keys_len() / 2;
    let upper_half = leaf_node.split_off(pivot)?;
    let split_key = upper_half.first_key()?;
    Ok(SplitResult::SplitNodes {
        left_node: leaf_node,
        right_node: upper_half,
        split_key,
    })
}
/// Splits an internal node around its middle key.
///
/// Everything from index `keys_len / 2 + 1` onward moves to the right node; the key
/// then left at the end of the original node is popped and returned as the separator,
/// so it ends up in neither half.
///
/// # Errors
/// `TreeError::Invariant` if the node is not internal or has no key to promote;
/// node-view errors are propagated.
fn split_internal_node(
    &self,
    mut internal_node: NodeView,
) -> Result<SplitResult<NodeView>, TreeError> {
    if !matches!(internal_node, NodeView::Internal { .. }) {
        return Err(TreeError::Invariant("expected internal node for splitting"));
    }
    let split_idx = internal_node.keys_len() / 2 + 1;
    let right_node = internal_node.split_off(split_idx)?;
    let Some(split_key) = internal_node.pop_key()? else {
        return Err(TreeError::Invariant(
            "internal node has no mid key for split",
        ));
    };
    Ok(SplitResult::SplitNodes {
        left_node: internal_node,
        right_node,
        split_key,
    })
}
/// Persists `node` and makes its ancestors in `path` point at it.
///
/// If the node already has a page id that the tracker marks dirty, it is overwritten
/// in place and the existing root (first path entry, or the node itself when the path
/// is empty) is returned unchanged. Otherwise the node is written under a new id and
/// that id is propagated up the path; the new root id is returned.
fn write_and_propagate(
    &self,
    path: Vec<(u64, usize)>,
    node: &NodeView,
    track: &mut impl TxnTracker,
) -> Result<NodeId, TreeError> {
    match node.page_id() {
        Some(pid) if track.is_dirty(pid) => {
            // Already owned by this transaction: update in place, parents stay valid.
            self.storage.write_node_view_at_offset(node, pid)?;
            return Ok(path.first().map_or(pid, |&(root, _)| root));
        }
        _ => {}
    }

    let new_node_id = self.write_node_view(node, track)?;
    if path.is_empty() {
        return Ok(new_node_id);
    }
    self.propagate_node_update(path, new_node_id, track)
}
/// Rewrites the ancestors in `path` (deepest first) so that they reference
/// `updated_child_id` instead of the child previously stored at the recorded slot.
///
/// Each replaced child id is handed to `track.reclaim`. A parent that the tracker
/// marks dirty is overwritten in place, which ends the walk early and leaves the
/// original root id valid; otherwise the parent is written under a new id and the walk
/// continues with that id. Returns the resulting root id.
///
/// # Errors
/// `TreeError::NodeNotFound` if a parent cannot be read, `TreeError::Invariant` if a
/// path entry is not an internal node or its slot is outside the node's children;
/// node-view and storage errors are propagated.
fn propagate_node_update(
    &self,
    mut path: Vec<(NodeId, usize)>,
    mut updated_child_id: NodeId,
    track: &mut impl TxnTracker,
) -> Result<NodeId, TreeError> {
    let root_id = path.first().map(|(id, _)| *id).unwrap_or(updated_child_id);
    while let Some((parent_id, insert_pos)) = path.pop() {
        let mut parent_node = self.storage.read_node_view(parent_id)?.ok_or_else(|| {
            TreeError::NodeNotFound(format!("Parent node {} not found", parent_id))
        })?;
        let NodeView::Internal { .. } = parent_node else {
            return Err(TreeError::Invariant(
                "expected internal node while updating parents",
            ));
        };
        // A node with k keys has k + 1 children, so the valid child slots are 0..=k.
        if insert_pos > parent_node.keys_len() {
            return Err(TreeError::Invariant(
                "insert position out of bounds for parent node",
            ));
        }
        track.reclaim(parent_node.child_ptr_at(insert_pos)?);
        parent_node.replace_child_at(insert_pos, updated_child_id)?;
        if track.is_dirty(parent_id) {
            // Parent already belongs to this transaction: overwrite it and stop climbing.
            self.storage
                .write_node_view_at_offset(&parent_node, parent_id)?;
            return Ok(root_id);
        }
        updated_child_id = self.write_node_view(&parent_node, track)?;
    }
    Ok(updated_child_id)
}
/// Pushes a split upward: inserts separator `key` with children `left` / `right` into
/// each ancestor in `path` (deepest first), splitting ancestors that overflow.
///
/// Returns the resulting root id. If the split reaches past the top of the path, a new
/// internal root is created and the staged height is increased by one.
///
/// # Errors
/// `TreeError::NodeNotFound` if a parent cannot be read, `TreeError::Invariant` if a
/// path entry is not an internal node; node-view and storage errors are propagated.
fn propagate_split(
&self,
mut path: Vec<(NodeId, usize)>,
mut left: NodeId,
mut right: NodeId,
mut key: Vec<u8>,
track: &mut impl TxnTracker,
) -> Result<NodeId, TreeError> {
while let Some((parent_id, insert_pos)) = path.pop() {
let Some(mut node) = self.storage.read_node_view(parent_id)? else {
return Err(TreeError::NodeNotFound(format!(
"Parent node {} not found",
parent_id
)));
};
let NodeView::Internal { .. } = &mut node else {
return Err(TreeError::Invariant(
"expected internal node in propagation path",
));
};
// The child that was split is replaced by the new left half and queued for reclaim.
let left_child_prev = node.child_ptr_at(insert_pos)?;
track.reclaim(left_child_prev);
node.replace_child_at(insert_pos, left)?;
match node.insert_separator_at(insert_pos, &key, right) {
Ok(()) => {
// Separator fit: done unless the key count now exceeds the logical limit.
if node.keys_len() <= self.max_keys {
return self.write_and_propagate(path, &node, track);
}
let SplitResult::SplitNodes {
left_node,
right_node,
split_key,
} = self.split_internal_node(node)?;
left = self.write_node_view(&left_node, track)?;
right = self.write_node_view(&right_node, track)?;
key = split_key;
}
Err(NodeViewError::Page(PageError::PageFull {})) => {
// Page physically full: split first, then place the pending separator in
// whichever half covers its key range.
let SplitResult::SplitNodes {
mut left_node,
mut right_node,
split_key,
} = self.split_internal_node(node)?;
if key.as_slice() < split_key.as_slice() {
let idx = match left_node.lower_bound(&key) {
Ok(i) => i + 1,
Err(i) => i,
};
left_node.insert_separator_at(idx, &key, right)?;
} else {
let idx = match right_node.lower_bound(&key) {
Ok(i) => i + 1,
Err(i) => i,
};
right_node.insert_separator_at(idx, &key, right)?;
}
left = self.write_node_view(&left_node, track)?;
right = self.write_node_view(&right_node, track)?;
key = split_key;
}
Err(e) => return Err(e.into()),
}
}
// The path is exhausted while a split is still pending: grow the tree by one level.
let mut new_root = NodeView::new_internal(self.key_format);
new_root.write_leftmost_child(left)?;
new_root.insert_separator_at(0, &key, right)?;
let new_root_id = self.write_node_view(&new_root, track)?;
let base_height = track.staged_height().unwrap_or_else(|| self.get_height());
track.record_staged_height(base_height + 1);
Ok(new_root_id)
}
/// Looks up `key` in the committed tree while holding an epoch pin and returns a copy
/// of its value, or `None` if the key is absent.
pub fn get<K: AsRef<[u8]>>(&self, key: K) -> Result<Option<Vec<u8>>, TreeError> {
    // The pin must outlive the traversal.
    let _guard = self.epoch_mgr.pin();
    let committed_root = self.get_root_id();
    self.get_inner(key, committed_root)
}
/// Reports whether `key` exists in the committed tree, holding an epoch pin for the
/// duration of the lookup.
pub fn contains_key<K: AsRef<[u8]>>(&self, key: K) -> Result<bool, TreeError> {
    // The pin must outlive the traversal.
    let _guard = self.epoch_mgr.pin();
    let committed_root = self.get_root_id();
    self.contains_key_inner(key, committed_root)
}
/// Reports whether `key` exists in the tree version rooted at `root_id`.
///
/// # Errors
/// `TreeError::Invariant` if a node on the path cannot be read; node-view and storage
/// errors are propagated.
pub fn contains_key_inner<K: AsRef<[u8]>>(
    &self,
    key: K,
    root_id: NodeId,
) -> Result<bool, TreeError> {
    let needle = key.as_ref();
    let mut cursor = root_id;
    loop {
        let Some(node) = self.storage.read_node_view(cursor)? else {
            return Err(TreeError::Invariant("node not found during search"));
        };
        let probe = node.lower_bound(needle);
        match &node {
            NodeView::Leaf { .. } => return Ok(probe.is_ok()),
            NodeView::Internal { .. } => {
                // An exact separator match routes to the child on its right.
                let child_slot = match probe {
                    Ok(hit) => hit + 1,
                    Err(gap) => gap,
                };
                cursor = node.child_ptr_at(child_slot)?;
            }
        }
    }
}
/// Looks up `key` in the tree version rooted at `root_id` and returns a copy of its
/// value, or `None` if the key is absent.
///
/// # Errors
/// `TreeError::Invariant` if a node on the path cannot be read; node-view and storage
/// errors are propagated.
pub fn get_inner<K: AsRef<[u8]>>(
    &self,
    key: K,
    root_id: NodeId,
) -> Result<Option<Vec<u8>>, TreeError> {
    let needle = key.as_ref();
    let mut cursor = root_id;
    loop {
        let Some(node) = self.storage.read_node_view(cursor)? else {
            return Err(TreeError::Invariant("node not found during search"));
        };
        let probe = node.lower_bound(needle);
        match &node {
            NodeView::Leaf { .. } => {
                let Ok(slot) = probe else {
                    return Ok(None);
                };
                let bytes = node.value_bytes_at(slot)?;
                return Ok(Some(bytes.to_vec()));
            }
            NodeView::Internal { .. } => {
                // An exact separator match routes to the child on its right.
                let child_slot = match probe {
                    Ok(hit) => hit + 1,
                    Err(gap) => gap,
                };
                cursor = node.child_ptr_at(child_slot)?;
            }
        }
    }
}
/// Test helper: deletes `key` from the tree rooted at `root_id` while holding an epoch
/// pin and returns the new root id.
///
/// # Errors
/// `TreeError::NodeNotFound` when the key is absent; errors from `delete_inner` are
/// propagated.
#[cfg(test)]
pub fn delete<K: AsRef<[u8]>>(
    &mut self,
    key: K,
    root_id: NodeId,
    track: &mut impl TxnTracker,
) -> Result<NodeId, TreeError> {
    let _guard = self.epoch_mgr.pin();
    match self.delete_inner(key, root_id, track)? {
        DeleteResult::Deleted(new_root_id) => Ok(new_root_id),
        DeleteResult::NotFound => Err(TreeError::NodeNotFound("key not found".to_string())),
    }
}
/// Removes `key` from the tree version rooted at `root_id`.
///
/// Returns `DeleteResult::NotFound` without touching anything when the key is absent.
/// Otherwise the entry is removed from its leaf, the staged size is decremented
/// (saturating at zero) and the staged height recorded. A leaf that still meets
/// `min_leaf_keys`, or that is the root, is written and propagated; an underfull
/// non-root leaf is handed to `handle_underflow`.
///
/// # Errors
/// `TreeError::Invariant` / `TreeError::NodeNotFound` when the path is malformed;
/// node-view and storage errors are propagated.
pub fn delete_inner<K: AsRef<[u8]>>(
    &self,
    key: K,
    root_id: NodeId,
    track: &mut impl TxnTracker,
) -> Result<DeleteResult<NodeId>, TreeError> {
    let (mut path, found) = self.get_insertion_path(key, root_id)?;
    let Some((leaf_node_id, idx)) = path.pop() else {
        return Err(TreeError::Invariant("insertion path is empty"));
    };
    if !found {
        return Ok(DeleteResult::NotFound);
    }

    let Some(mut leaf_node) = self.storage.read_node_view(leaf_node_id)? else {
        return Err(TreeError::NodeNotFound(format!(
            "Leaf node with ID {} not found",
            leaf_node_id
        )));
    };
    if !matches!(leaf_node, NodeView::Leaf { .. }) {
        return Err(TreeError::Invariant("expected leaf node at deletion point"));
    }
    leaf_node.delete_at(idx)?;

    let base_size = track.staged_size().unwrap_or_else(|| self.get_size());
    track.record_staged_size(base_size.saturating_sub(1));
    let base_height = track.staged_height().unwrap_or_else(|| self.get_height());
    track.record_staged_height(base_height);

    let new_root_id = if leaf_node.entry_count() >= self.min_leaf_keys || path.is_empty() {
        // Still legal (or the leaf is the root): a plain write is enough.
        self.write_and_propagate(path, &leaf_node, track)?
    } else {
        self.handle_underflow(path, leaf_node, track)?
    };
    Ok(DeleteResult::Deleted(new_root_id))
}
/// Repairs an underfull `node` by walking up `path` (deepest parent first).
///
/// At each level the strategies are tried in this order: borrow from the left sibling,
/// borrow from the right sibling, merge with the left sibling, merge with the right
/// sibling. A successful borrow ends the repair. After a merge the parent is checked:
/// if it is itself underfull and not the root, the loop continues with the parent as
/// the underfull node. If nothing applies, the node is written as-is and the parent is
/// updated to point at it. Returns the resulting root id.
///
/// # Errors
/// `TreeError::NodeNotFound` if a parent cannot be read, `TreeError::Invariant` if a
/// path entry is not an internal node or the path is empty on entry; node-view and
/// storage errors are propagated.
fn handle_underflow(
&self,
mut path: Vec<(NodeId, usize)>,
mut node: NodeView,
track: &mut impl TxnTracker,
) -> Result<NodeId, TreeError> {
while let Some((parent_id, idx)) = path.pop() {
let Some(mut parent_node) = self.storage.read_node_view(parent_id)? else {
return Err(TreeError::NodeNotFound(format!(
"Parent node {} not found",
parent_id
)));
};
{
let NodeView::Internal { .. } = &mut parent_node else {
return Err(TreeError::Invariant("expected internal node as parent"));
};
// NOTE(review): the looked-up child id is discarded, so this block only has an
// effect when `child_ptr_at(0)` fails — looks like leftover code; confirm.
if path.is_empty() && parent_node.children_len()? == 1 {
parent_node.child_ptr_at(0)?;
}
// Borrowing keeps the parent's key count unchanged, so no further repair is needed.
if idx > 0 && self.try_borrow_from_left(&mut node, &mut parent_node, idx, track)? {
return self.write_and_propagate(path, &parent_node, track);
}
if (idx < parent_node.keys_len())
&& self.try_borrow_from_right(&mut node, &mut parent_node, idx, track)?
{
return self.write_and_propagate(path, &parent_node, track);
}
let mut merged = None;
if let Some(id) =
self.try_merge_with_left(&mut node, &mut parent_node, idx, track)?
{
merged = Some(id);
} else if let Some(id) =
self.try_merge_with_right(&mut node, &mut parent_node, idx, track)?
{
merged = Some(id);
}
if merged.is_some() {
// A merge removed one key from the parent; it may now be underfull itself.
if parent_node.keys_len() < self.min_internal_keys {
if path.is_empty() {
// The parent is the root: collapse it if only one child is left.
if parent_node.children_len()? == 1 {
track.reclaim(parent_id);
let bh = track.staged_height().unwrap_or_else(|| self.get_height());
track.record_staged_height(bh.saturating_sub(1));
return Ok(parent_node.child_ptr_at(0)?);
} else {
return self.write_and_propagate(path, &parent_node, track);
}
}
// Continue one level up with the (unwritten) parent as the underfull node.
node = parent_node;
continue;
} else {
return self.write_and_propagate(path, &parent_node, track);
}
}
// Neither borrow nor merge applied: persist the underfull node unchanged in shape.
let new_node_id = self.write_node_view(&node, track)?;
let current_child_id = parent_node.child_ptr_at(idx)?;
track.reclaim(current_child_id);
parent_node.replace_child_at(idx, new_node_id)?;
return self.write_and_propagate(path, &parent_node, track);
}
}
Err(TreeError::Invariant("node underflow could not be resolved"))
}
/// Tries to move one entry from the left sibling of `node` (child `idx` of
/// `parent_node`) into `node`.
///
/// Returns `Ok(false)` without writing anything when there is no left sibling, the
/// sibling has no entry to spare, or the borrowed entry does not fit on `node`'s page.
/// On success both nodes are written under new ids, the old ids are reclaimed, the
/// parent's separator and child pointers are updated in memory, and `Ok(true)` is
/// returned; persisting the parent is left to the caller.
///
/// # Errors
/// `TreeError::NodeNotFound` if the sibling cannot be read, `TreeError::Invariant` if
/// the two nodes are of different kinds; node-view and storage errors are propagated.
fn try_borrow_from_left(
&self,
node: &mut NodeView,
parent_node: &mut NodeView,
idx: usize,
track: &mut impl TxnTracker,
) -> Result<bool, TreeError> {
if idx == 0 {
return Ok(false);
}
// The separator between the left sibling and `node` sits at key index `idx - 1`.
let parent_key_idx = idx - 1;
let left_child_idx = idx - 1;
let left_sibling_id = parent_node.child_ptr_at(left_child_idx)?;
let Some(mut left_sibling) = self.storage.read_node_view(left_sibling_id)? else {
return Err(TreeError::NodeNotFound(
format!("Left sibling id: {} not found", left_sibling_id).to_string(),
));
};
match (&mut left_sibling, &mut *node) {
(NodeView::Leaf { .. }, NodeView::Leaf { .. }) => {
if left_sibling.keys_len() > self.min_leaf_keys {
// Move the sibling's last entry to the front of `node`; that key becomes
// the new separator in the parent.
let borrowed_key = left_sibling.key_bytes_at(left_sibling.keys_len() - 1)?;
let borrowed_value =
left_sibling.value_bytes_at(left_sibling.keys_len() - 1)?;
match node.insert_at(0, borrowed_key, borrowed_value) {
Ok(()) => {}
Err(NodeViewError::Page(PageError::PageFull {})) => return Ok(false),
Err(e) => return Err(e.into()),
}
parent_node.replace_key_at(parent_key_idx, borrowed_key)?;
left_sibling.delete_at(left_sibling.keys_len() - 1)?;
} else {
return Ok(false);
}
}
(NodeView::Internal { .. }, NodeView::Internal { .. }) => {
if left_sibling.keys_len() > self.min_internal_keys {
// Rotate through the parent: the old separator moves down into `node` with
// the sibling's last child, and the sibling's last key moves up.
let borrowed_key = left_sibling.key_bytes_at(left_sibling.keys_len() - 1)?;
let borrowed_child =
left_sibling.child_ptr_at(left_sibling.children_len()? - 1)?;
let separator_key = parent_node.key_bytes_at(parent_key_idx)?;
match node.push_front(separator_key, borrowed_child) {
Ok(()) => {}
Err(NodeViewError::Page(PageError::PageFull {})) => return Ok(false),
Err(e) => return Err(e.into()),
}
parent_node.replace_key_at(parent_key_idx, borrowed_key)?;
// NOTE(review): presumably `delete_at` on an internal node removes the last key
// together with its right child — confirm against `NodeView::delete_at`.
left_sibling.delete_at(left_sibling.keys_len() - 1)?;
} else {
return Ok(false);
}
}
_ => {
return Err(TreeError::Invariant("mismatched node types for borrow"));
}
};
// Persist both modified nodes and swing the parent's pointers to the new ids.
let new_node_id = self.write_node_view(node, track)?;
let new_left_node_id = self.write_node_view(&left_sibling, track)?;
track.reclaim(left_sibling_id);
parent_node.replace_child_at(left_child_idx, new_left_node_id)?;
let current_child_id = parent_node.child_ptr_at(idx)?;
track.reclaim(current_child_id);
parent_node.replace_child_at(idx, new_node_id)?;
Ok(true)
}
/// Tries to move one entry from the right sibling of `node` (child `idx` of
/// `parent_node`) into `node`.
///
/// Returns `Ok(false)` when there is no right sibling, the sibling has no entry to
/// spare, or (leaf case) the borrowed entry does not fit on `node`'s page. On success
/// both nodes are written under new ids, the old ids are reclaimed, the parent's
/// separator and child pointers are updated in memory, and `Ok(true)` is returned;
/// persisting the parent is left to the caller.
///
/// # Errors
/// `TreeError::NodeNotFound` if the sibling cannot be read, `TreeError::Invariant` if
/// the two nodes are of different kinds; node-view and storage errors are propagated.
fn try_borrow_from_right(
&self,
node: &mut NodeView,
parent_node: &mut NodeView,
idx: usize,
track: &mut impl TxnTracker,
) -> Result<bool, TreeError> {
// Child `idx` has a right sibling only if a separator exists at key index `idx`.
if idx >= parent_node.keys_len() {
return Ok(false);
}
let parent_key_idx = idx;
let right_sibling_id = parent_node.child_ptr_at(idx + 1)?;
let Some(mut right_sibling) = self.storage.read_node_view(right_sibling_id)? else {
return Err(TreeError::NodeNotFound(
format!("Right sibling id: {} not found", right_sibling_id).to_string(),
));
};
match (&mut *node, &mut right_sibling) {
(NodeView::Leaf { .. }, NodeView::Leaf { .. }) => {
if right_sibling.keys_len() > self.min_leaf_keys {
// Append the sibling's first entry to `node`; the sibling's new first key
// becomes the separator in the parent.
let borrowed_key = right_sibling.key_bytes_at(0)?;
let borrowed_value = right_sibling.value_bytes_at(0)?;
match node.insert_at(
node.keys_len(),
borrowed_key.as_bytes(),
borrowed_value.as_bytes(),
) {
Ok(()) => {}
Err(NodeViewError::Page(PageError::PageFull {})) => return Ok(false),
Err(e) => return Err(e.into()),
}
right_sibling.delete_at(0)?;
let separator_key = right_sibling.key_bytes_at(0)?;
parent_node.replace_key_at(parent_key_idx, separator_key.as_bytes())?;
} else {
return Ok(false);
}
}
(NodeView::Internal { .. }, NodeView::Internal { .. }) => {
if right_sibling.keys_len() > self.min_internal_keys {
// Rotate through the parent: the old separator moves down to the end of
// `node` with the sibling's first child, and the sibling's first key moves up.
// NOTE(review): unlike the leaf case, a `PageFull` from `insert_separator_at`
// is returned as an error here, after the parent and sibling were already
// modified in memory — confirm this is intended.
let separator_key = parent_node.key_at(parent_key_idx)?;
let right_first_key = right_sibling.delete_key_at(0)?;
parent_node.replace_key_at(parent_key_idx, right_first_key.as_bytes())?;
let borrowed_child = right_sibling.child_ptr_at(0)?;
node.insert_separator_at(node.keys_len(), &separator_key, borrowed_child)?;
right_sibling.delete_child_at(0)?;
} else {
return Ok(false);
}
}
_ => {
return Err(TreeError::Invariant("mismatched node types for borrow"));
}
}
// Persist both modified nodes and swing the parent's pointers to the new ids.
let new_node_id = self.write_node_view(node, track)?;
let new_right_node_id = self.write_node_view(&right_sibling, track)?;
track.reclaim(right_sibling_id);
parent_node.replace_child_at(idx + 1, new_right_node_id)?;
let current_child_id = parent_node.child_ptr_at(idx)?;
track.reclaim(current_child_id);
parent_node.replace_child_at(idx, new_node_id)?;
Ok(true)
}
/// Tries to merge `node` (child `idx` of `parent_node`) into its left sibling.
///
/// Returns `Ok(None)` when there is no left sibling or the merged node would not fit
/// (logically, by key count, or physically, per `can_merge_physically`). On success
/// the merged node is written under a new id, both old children are reclaimed, the
/// separator and the right-hand child slot are removed from the parent in memory, and
/// the merged node's id is returned; persisting the parent is left to the caller.
///
/// For internal nodes the parent's separator is pulled down between the two halves,
/// so the merged node holds `left + 1 + right` keys; that extra key is included in the
/// capacity check so an over-full merge is declined *before* the parent is modified.
///
/// # Errors
/// `TreeError::NodeNotFound` if the sibling cannot be read, `TreeError::Invariant` if
/// the two nodes are of different kinds; node-view and storage errors are propagated.
fn try_merge_with_left(
    &self,
    node: &mut NodeView,
    parent_node: &mut NodeView,
    idx: usize,
    track: &mut impl TxnTracker,
) -> Result<Option<NodeId>, TreeError> {
    if idx == 0 {
        return Ok(None);
    }
    let left_child_idx = idx - 1;
    let left_sibling_id = parent_node.child_ptr_at(left_child_idx)?;
    // The separator between the left sibling and `node` sits at key index `idx - 1`.
    let parent_key_idx = idx - 1;
    let Some(mut left_sibling) = self.storage.read_node_view(left_sibling_id)? else {
        return Err(TreeError::NodeNotFound(format!(
            "Left sibling id: {} not found",
            left_sibling_id
        )));
    };
    match (&mut left_sibling, &mut *node) {
        (NodeView::Leaf { .. }, NodeView::Leaf { .. }) => {
            if left_sibling.keys_len() + node.keys_len() > self.max_keys
                || !left_sibling.can_merge_physically(node)
            {
                return Ok(None);
            }
            let merged_node = self.merge_nodes_view(&mut left_sibling, node)?;
            let merged_node_id = self.write_node_view(merged_node, track)?;
            track.reclaim(parent_node.child_ptr_at(idx)?);
            parent_node.delete_child_at(idx)?;
            track.reclaim(left_sibling_id);
            parent_node.replace_child_at(left_child_idx, merged_node_id)?;
            if parent_node.keys_len() > 0 {
                parent_node.delete_key_at(parent_key_idx)?;
            }
            Ok(Some(merged_node_id))
        }
        (NodeView::Internal { .. }, NodeView::Internal { .. }) => {
            // `+ 1` accounts for the separator that is pulled down from the parent.
            if left_sibling.keys_len() + node.keys_len() + 1 > self.max_keys
                || !left_sibling.can_merge_physically(node)
            {
                return Ok(None);
            }
            let separator_key = parent_node.delete_key_at(parent_key_idx)?;
            // Append the separator at the end of the *left sibling* (mirrors
            // `try_merge_with_right`, which appends at the end of its left-hand node),
            // paired with the first child of `node`.
            left_sibling.insert_separator_at(
                left_sibling.keys_len(),
                separator_key.as_bytes(),
                node.child_ptr_at(0)?,
            )?;
            let merged_node = self.merge_nodes_view(&mut left_sibling, node)?;
            let merged_node_id = self.write_node_view(merged_node, track)?;
            track.reclaim(parent_node.child_ptr_at(idx)?);
            parent_node.delete_child_at(idx)?;
            track.reclaim(left_sibling_id);
            parent_node.replace_child_at(left_child_idx, merged_node_id)?;
            Ok(Some(merged_node_id))
        }
        _ => Err(TreeError::Invariant("mismatched node types for merge")),
    }
}
/// Tries to merge the right sibling of `node` (child `idx` of `parent_node`) into
/// `node`.
///
/// Returns `Ok(None)` when there is no right sibling or the merged node would not fit
/// (logically, by key count, or physically, per `can_merge_physically`). On success
/// the merged node is written under a new id, both old children are reclaimed, the
/// separator and the right-hand child slot are removed from the parent in memory, and
/// the merged node's id is returned; persisting the parent is left to the caller.
///
/// For internal nodes the parent's separator is pulled down between the two halves,
/// so the merged node holds `left + 1 + right` keys; that extra key is included in the
/// capacity check so an over-full merge is declined *before* the parent is modified.
///
/// # Errors
/// `TreeError::NodeNotFound` if the sibling cannot be read, `TreeError::Invariant` if
/// the two nodes are of different kinds; node-view and storage errors are propagated.
fn try_merge_with_right(
    &self,
    node: &mut NodeView,
    parent_node: &mut NodeView,
    idx: usize,
    track: &mut impl TxnTracker,
) -> Result<Option<NodeId>, TreeError> {
    let right_idx = idx + 1;
    if right_idx >= parent_node.children_len()? {
        return Ok(None);
    }
    let right_sibling_id = parent_node.child_ptr_at(right_idx)?;
    // The separator between `node` and its right sibling sits at key index `idx`.
    let parent_key_idx = idx;
    let Some(mut right_sibling) = self.storage.read_node_view(right_sibling_id)? else {
        return Err(TreeError::NodeNotFound(format!(
            "Right sibling id: {} not found",
            right_sibling_id
        )));
    };
    match (&mut *node, &mut right_sibling) {
        (NodeView::Leaf { .. }, NodeView::Leaf { .. }) => {
            if node.keys_len() + right_sibling.keys_len() > self.max_keys
                || !node.can_merge_physically(&right_sibling)
            {
                return Ok(None);
            }
            let merged_node = self.merge_nodes_view(node, &mut right_sibling)?;
            let merged_node_id = self.write_node_view(merged_node, track)?;
            track.reclaim(parent_node.child_ptr_at(right_idx)?);
            parent_node.delete_child_at(right_idx)?;
            track.reclaim(parent_node.child_ptr_at(idx)?);
            parent_node.replace_child_at(idx, merged_node_id)?;
            if parent_node.keys_len() > 0 {
                parent_node.delete_key_at(parent_key_idx)?;
            }
            Ok(Some(merged_node_id))
        }
        (NodeView::Internal { .. }, NodeView::Internal { .. }) => {
            // `+ 1` accounts for the separator that is pulled down from the parent.
            if node.keys_len() + right_sibling.keys_len() + 1 > self.max_keys
                || !node.can_merge_physically(&right_sibling)
            {
                return Ok(None);
            }
            let separator_key = parent_node.delete_key_at(parent_key_idx)?;
            // Append the separator to `node`, paired with the sibling's first child.
            node.insert_separator_at(
                node.keys_len(),
                separator_key.as_bytes(),
                right_sibling.child_ptr_at(0)?,
            )?;
            let merged_node = self.merge_nodes_view(node, &mut right_sibling)?;
            let merged_node_id = self.write_node_view(merged_node, track)?;
            track.reclaim(parent_node.child_ptr_at(right_idx)?);
            parent_node.delete_child_at(right_idx)?;
            track.reclaim(parent_node.child_ptr_at(idx)?);
            parent_node.replace_child_at(idx, merged_node_id)?;
            Ok(Some(merged_node_id))
        }
        _ => Err(TreeError::Invariant("mismatched node types for merge")),
    }
}
/// Merges `right_node` into `left_node` and returns a shared reference to the merged
/// (left) node.
///
/// # Errors
/// `TreeError::Invariant` if the nodes are of different kinds or their combined key
/// count exceeds `max_keys`; node-view errors from `merge_into` are propagated.
pub fn merge_nodes_view<'a>(
    &'a self,
    left_node: &'a mut NodeView,
    right_node: &'a mut NodeView,
) -> Result<&'a NodeView, TreeError> {
    // Leaf/leaf and internal/internal are handled identically; anything else is rejected.
    let same_kind = matches!(
        (&*left_node, &*right_node),
        (NodeView::Leaf { .. }, NodeView::Leaf { .. })
            | (NodeView::Internal { .. }, NodeView::Internal { .. })
    );
    if !same_kind {
        return Err(TreeError::Invariant("mismatched node types for merge"));
    }

    let combined_keys = left_node.keys_len() + right_node.keys_len();
    if combined_keys > self.max_keys {
        return Err(TreeError::Invariant("merge would exceed max keys"));
    }

    left_node.merge_into(right_node)?;
    Ok(left_node)
}
/// Test-only commit: persists metadata pointing at `new_root_id`, flushes storage,
/// updates the in-memory committed metadata in place, advances the epoch and frees
/// every node the epoch manager reports as reclaimable.
///
/// The metadata slot alternates with the transaction id: even ids go to `meta_a`, odd
/// ids to `meta_b`.
// NOTE(review): `_height` and `_size` are ignored — the currently committed height and
// size are written instead, and only `root_node_id` / `txn_id` are updated in memory.
// Confirm this is intended for tests.
#[cfg(test)]
pub fn commit(&self, new_root_id: NodeId, _height: u64, _size: u64) -> Result<(), TreeError> {
let current_meta = unsafe { &*self.committed.load(Ordering::Acquire) };
let new_txn_id = current_meta.txn_id + 1;
let target_slot = if new_txn_id % 2 == 0 {
self.meta_a
} else {
self.meta_b
};
// NOTE(review): the metadata is written before `storage.flush()` below — confirm the
// ordering is acceptable for the scenarios these tests cover.
MetadataManager::commit_metadata(
&*self.page_storage,
target_slot,
new_txn_id,
self.id,
new_root_id,
self.get_height(),
self.get_order(),
self.get_size(),
)?;
let current_ptr = self.committed.load(Ordering::Acquire);
// NOTE(review): creates a `&mut` through a pointer that other threads may be reading;
// presumably tests call this single-threaded — confirm.
let current = unsafe { &mut *current_ptr };
self.storage.flush()?;
current.root_node_id = new_root_id;
current.txn_id = new_txn_id;
self.commit_count.fetch_add(1, Ordering::Relaxed);
let _new_epoch = self.epoch_mgr.advance();
let safe_epoch = self.epoch_mgr.oldest_active();
let reclaimed = self.epoch_mgr.reclaim(safe_epoch);
for pid in reclaimed {
self.storage.free_node(pid)?;
}
// Every `COMMIT_COUNT` commits the epoch is advanced one extra time.
if (self.commit_count.load(Ordering::Relaxed) as u64) % COMMIT_COUNT == 0 {
self.epoch_mgr.advance(); }
Ok(())
}
/// Attempts to publish `new_meta` on top of `base_version`.
///
/// The in-memory committed pointer is replaced with a compare-and-exchange
/// against `base_version.committed_ptr`. Only the winner goes on to write the
/// metadata slot (`meta_a` for even transaction ids, `meta_b` for odd ones),
/// flush storage, advance the epoch and free reclaimable nodes.
///
/// # Errors
/// - `CommitError::RebaseRequired` when the committed pointer no longer equals
///   `base_version.committed_ptr`; nothing is published.
/// - `CommitError::Metadata` when writing the metadata slot fails; the
///   in-memory publish is rolled back if it is still the current value.
/// - `CommitError::Storage` when `flush` or `free_node` fails; the new
///   metadata stays published in that case.
/// - `CommitError::Injected` when the
///   `tree::commit::try_commit_failure` fail point is armed (test/testing
///   builds only).
///
/// # Panics
/// Panics if the `retired_meta` mutex is poisoned.
pub fn try_commit(
    &self,
    base_version: &BaseVersion,
    new_meta: StagedMetadata,
) -> Result<(), CommitError> {
    #[cfg(any(test, feature = "testing"))]
    {
        // The two-argument form of `fail_point!` returns the closure's value
        // from the enclosing function when the fail point is armed, so the
        // closure must evaluate to this function's return type.
        fail::fail_point!("tree::commit::try_commit_failure", |_| {
            Err(CommitError::Injected)
        });
    }

    let expected = base_version.committed_ptr as *mut Metadata;
    let current_ptr = self.committed.load(Ordering::Acquire);
    // SAFETY (assumed — TODO confirm at the construction site): `committed`
    // always holds a non-null pointer produced by `Box::into_raw`, and
    // superseded pointers are retired instead of freed while `self` is alive.
    let current = unsafe { &*current_ptr };
    let new_txn_id = current.txn_id + 1;
    let metadata = Metadata {
        root_node_id: new_meta.root_id,
        id: self.id,
        height: new_meta.height,
        size: new_meta.size,
        txn_id: new_txn_id,
        checksum: 0,
        order: current.order,
    };
    let new_ptr = Box::into_raw(Box::new(metadata));

    let old_ptr = match self.committed.compare_exchange(
        expected,
        new_ptr,
        Ordering::SeqCst,
        Ordering::Relaxed,
    ) {
        Ok(previous) => previous,
        Err(_) => {
            // SAFETY: `new_ptr` came from `Box::into_raw` just above and the
            // failed exchange never published it, so this is its only owner.
            unsafe { drop(Box::from_raw(new_ptr)) };
            return Err(CommitError::RebaseRequired);
        }
    };

    // Even transaction ids go to slot A, odd ones to slot B.
    let slot = if new_txn_id % 2 == 0 {
        self.meta_a
    } else {
        self.meta_b
    };
    if let Err(e) =
        MetadataManager::commit_metadata_with_object(&*self.page_storage, slot, &metadata)
    {
        // Undo the publish only while `new_ptr` is still the committed value.
        // A blind store could overwrite a commit that landed on top of ours.
        let rolled_back = self
            .committed
            .compare_exchange(new_ptr, old_ptr, Ordering::SeqCst, Ordering::Relaxed)
            .is_ok();
        if rolled_back {
            // `new_ptr` was visible for a moment, so a reader may still hold a
            // reference to it: retire it instead of freeing it right away.
            self.retired_meta.lock().unwrap().push(RetiredPtr(new_ptr));
        }
        // If the rollback lost the race, whoever replaced `new_ptr` received
        // it as their previous pointer and is responsible for retiring it.
        return Err(CommitError::Metadata(e));
    }

    // Retire the superseded metadata before any further fallible step so that
    // it is not leaked when `flush` or `free_node` returns early.
    self.retired_meta.lock().unwrap().push(RetiredPtr(old_ptr));

    self.storage.flush()?;
    self.epoch_mgr.advance();
    let safe_epoch = self.epoch_mgr.oldest_active();
    for nid in self.epoch_mgr.reclaim(safe_epoch) {
        self.storage.free_node(nid)?;
    }
    // NOTE(review): `commit_count` is only read here, never incremented on
    // this path (the test-only `commit` does `fetch_add`). Unless it is bumped
    // elsewhere this condition does not vary between commits — confirm intent.
    if (self.commit_count.load(Ordering::Relaxed) as u64) % COMMIT_COUNT == 0 {
        self.epoch_mgr.advance();
    }
    Ok(())
}
/// Borrows the metadata that `committed` points to at the time of the call.
pub fn metadata(&self) -> &Metadata {
    let raw = self.committed.load(Ordering::Acquire);
    // SAFETY (assumed — TODO confirm): `committed` holds a valid, non-null
    // pointer that stays allocated for as long as `self` is borrowed.
    unsafe { &*raw }
}
/// Returns the raw pointer currently stored in `committed`.
///
/// The pointer is handed out as an address only (for example to fill
/// `BaseVersion::committed_ptr`); it is not dereferenced here, so no
/// `unsafe` is required and no validity claim is made about the pointee.
pub fn metadata_ptr(&self) -> *const Metadata {
    // `*mut Metadata` coerces to `*const Metadata`; going through a reference
    // first would assert validity of the pointee for no benefit.
    let raw: *const Metadata = self.committed.load(Ordering::Acquire);
    raw
}
/// Returns a by-value copy of the currently committed metadata.
pub fn snapshot(&self) -> Metadata {
    // SAFETY (assumed — TODO confirm): `committed` holds a valid, non-null
    // pointer to a live `Metadata`.
    let committed: &Metadata = unsafe { &*self.committed.load(Ordering::Acquire) };
    *committed
}
pub fn get_snapshot(&self) -> MetadataSnapshot {
let ptr = self.committed.load(Ordering::Acquire);
let meta = unsafe { &*ptr };
MetadataSnapshot {
root_id: meta.root_node_id,
height: meta.height,
size: meta.size,
}
}
/// Root node id recorded in the currently committed metadata.
pub fn get_root_id(&self) -> NodeId {
    let committed = self.snapshot();
    committed.root_node_id
}
/// Tree height recorded in the currently committed metadata.
pub fn get_height(&self) -> u64 {
    let committed = self.snapshot();
    committed.height
}
/// Size value recorded in the currently committed metadata.
pub fn get_size(&self) -> u64 {
    let committed = self.snapshot();
    committed.size
}
/// Order value recorded in the currently committed metadata.
pub fn get_order(&self) -> u64 {
    let committed = self.snapshot();
    committed.order
}
/// Registers `node_id` as a reclaim candidate under the calling thread's
/// current epoch.
///
/// # Errors
/// Returns `TreeError::Invariant` when the epoch manager has no epoch for the
/// current thread.
pub fn reclaim_node(&self, node_id: NodeId) -> Result<(), TreeError> {
    match self.epoch_mgr.get_current_thread_epoch() {
        Some(epoch) => {
            self.epoch_mgr.add_reclaim_candidate(epoch, node_id);
            Ok(())
        }
        None => Err(TreeError::Invariant(
            "failed to get epoch for current thread",
        )),
    }
}
/// Test helper: unconditionally publishes a copy of `metadata` as the
/// committed metadata, bypassing the compare-and-exchange in `try_commit`.
///
/// The previously published pointer is retired into `retired_meta` (the same
/// thing `try_commit` does with superseded metadata) instead of being freed on
/// the spot: callers may still hold a `&Metadata` obtained from `metadata()`
/// or a `BaseVersion` carrying the old address, and freeing immediately would
/// leave those dangling and let the allocator reuse the address.
///
/// # Panics
/// Panics if the `retired_meta` mutex is poisoned.
#[cfg(any(test, feature = "testing"))]
pub fn test_force_publish(&self, metadata: &Metadata) {
    let fresh = Box::into_raw(Box::new(*metadata));
    let old_ptr = self.committed.swap(fresh, Ordering::SeqCst);
    if !old_ptr.is_null() {
        self.retired_meta.lock().unwrap().push(RetiredPtr(old_ptr));
    }
}
/// Test helper: hands out another `Arc` handle to this tree's epoch manager.
#[cfg(any(test, feature = "testing"))]
pub fn get_epoch_mgr(&self) -> Arc<EpochManager> {
    let shared = &self.epoch_mgr;
    Arc::clone(shared)
}
}
#[cfg(test)]
mod tests {
    //! Tests for `try_commit`: happy path, stale-base conflicts, slot selection
    //! and the two storage failure modes (metadata write and flush).

    use super::*;
    use crate::tests::common::{test_storage::TestStorage, test_tree};

    /// Staged metadata used by most tests: root 42, height 3, size 10.
    fn staged_42() -> StagedMetadata {
        StagedMetadata {
            root_id: 42,
            height: 3,
            size: 10,
        }
    }

    #[test]
    fn commit_happy_path() {
        let storage = TestStorage::new();
        let test_harness = test_tree::<TestStorage>(storage, 128);
        let tree = test_harness.tree;
        let base = BaseVersion {
            committed_ptr: tree.metadata_ptr(),
        };
        let res = tree.try_commit(&base, staged_42());
        assert!(res.is_ok(), "Commit should succeed");
        let m = tree.metadata();
        assert_eq!(m.root_node_id, 42);
        assert_eq!(m.txn_id, 2);
    }

    #[test]
    fn commit_happy_path_2() {
        let storage = TestStorage::new();
        let h = test_tree::<TestStorage>(storage, 128);
        let base = BaseVersion {
            committed_ptr: h.tree.metadata_ptr(),
        };
        h.tree.try_commit(&base, staged_42()).expect("commit ok");

        // In-memory view reflects the staged values.
        let m = h.tree.metadata();
        assert_eq!(m.root_node_id, 42);
        assert_eq!(m.height, 3);
        assert_eq!(m.size, 10);
        assert_eq!(m.txn_id, 2);

        // The storage double saw exactly one matching metadata write + flush.
        let (slot, txn, rid, hgt, _ord, sz) = h.storage.last_commit().unwrap();
        assert_eq!(slot, (txn % 2) as u8);
        assert_eq!(txn, 2);
        assert_eq!(rid, 42);
        assert_eq!(hgt, 3);
        assert_eq!(sz, 10);
        assert_eq!(h.storage.flush_count(), 1);
    }

    /// A second commit that reuses an already-superseded base must be rejected
    /// with `RebaseRequired` and must leave the first commit in place.
    #[test]
    fn commit_aborts_on_conflict() {
        let storage = TestStorage::new();
        let h = test_tree::<TestStorage>(storage, 128);
        let stale_base = BaseVersion {
            committed_ptr: h.tree.metadata_ptr(),
        };
        h.tree
            .try_commit(&stale_base, staged_42())
            .expect("first commit ok");

        let result = h.tree.try_commit(
            &stale_base,
            StagedMetadata {
                root_id: 43,
                height: 3,
                size: 11,
            },
        );
        assert!(
            matches!(result, Err(CommitError::RebaseRequired)),
            "stale base should require a rebase, got {:?}",
            result
        );

        let m = h.tree.metadata();
        assert_eq!(m.root_node_id, 42);
        assert_eq!(m.txn_id, 2);
    }

    #[test]
    fn txn_id_is_strictly_monotonic() {
        let storage = TestStorage::new();
        let h = test_tree::<TestStorage>(storage, 128);
        let mut prev = h.tree.metadata().txn_id;
        for i in 0..5 {
            let base = BaseVersion {
                committed_ptr: h.tree.metadata_ptr(),
            };
            // There is no contention here, so fail fast instead of retrying
            // forever if a commit is unexpectedly rejected.
            h.tree
                .try_commit(
                    &base,
                    StagedMetadata {
                        root_id: 100 + i,
                        height: 3,
                        size: i,
                    },
                )
                .expect("uncontended commit should succeed");
            let now = h.tree.metadata().txn_id;
            assert_eq!(now, prev + 1);
            prev = now;
        }
    }

    #[test]
    fn slot_follows_txn_mod2() {
        let storage = TestStorage::new();
        let h = test_tree::<TestStorage>(storage, 128);
        for i in 0..6 {
            let base = BaseVersion {
                committed_ptr: h.tree.metadata_ptr(),
            };
            h.tree
                .try_commit(
                    &base,
                    StagedMetadata {
                        root_id: 200 + i,
                        height: 3,
                        size: i,
                    },
                )
                .unwrap();
            let (slot, txn, ..) = h.storage.last_commit().unwrap();
            assert_eq!(slot, (txn % 2) as u8);
        }
    }

    #[test]
    fn commit_metadata_write_failure_is_abort() {
        let storage = TestStorage::new();
        storage.inject_commit_failure(true);
        let test_harness = test_tree::<TestStorage>(storage, 128);
        let tree = test_harness.tree;
        let _mocks = test_harness.storage;
        let base = BaseVersion {
            committed_ptr: tree.metadata_ptr(),
        };
        // Copy the value out instead of holding `&Metadata` across the commit.
        let root_before = tree.metadata().root_node_id;
        let result = tree.try_commit(&base, staged_42());
        assert!(result.is_err(), "Commit should fail due to storage failure");
        let root_after = tree.metadata().root_node_id;
        assert_eq!(
            root_before, root_after,
            "Root node ID should not change on commit failure"
        );
    }

    #[test]
    fn flush_failure_after_cas_keeps_published_state() {
        let storage = TestStorage::new();
        storage.inject_flush_failure(true);
        let test_harness = test_tree::<TestStorage>(storage, 128);
        let tree = test_harness.tree;
        let _mocks = test_harness.storage;
        let base = BaseVersion {
            committed_ptr: tree.metadata_ptr(),
        };
        // Copy the value out instead of holding `&Metadata` across the commit.
        let root_before = tree.metadata().root_node_id;
        let result = tree.try_commit(&base, staged_42());
        assert!(result.is_err(), "Commit should fail due to flush failure");
        let root_after = tree.metadata().root_node_id;
        assert_ne!(
            root_before, root_after,
            "Metadata should be published regardless of flush failure"
        );
    }
}