#![forbid(unsafe_code)]
use heapless::Vec as HeaplessVec;
use crate::btree::node::{
decode_node, encode_node, max_inline_value, max_key_len, DecodedNode, InternalEntry, LeafEntry,
NodeKind, INTERNAL_LEFTMOST_CHILD_BYTES, INTERNAL_SLOT_BYTES, LEAF_SLOT_BYTES,
};
use crate::btree::{BTree, MAX_BTREE_DEPTH};
use crate::error::{Error, Result};
use crate::pager::page::{Page, PageId, PAGE_SIZE, PAGE_TRAILER_SIZE};
use crate::pager::Pager;
use crate::platform::FileBackend;
/// One step of a root-to-leaf descent, recorded so the insert can later
/// rewrite the visited nodes from the leaf back up to the root.
struct PathFrame {
/// Page the node was read from during the descent.
page_id: PageId,
/// Decoded contents of that page as seen by the descent.
node: DecodedNode,
/// Index into `node.children` that the descent followed; the leaf frame
/// stores 0 here.
child_index: usize,
}
/// Result of rewriting one node on the insert path.
enum ReplaceOutcome {
/// The rewritten node was written as a single new page, `new_id`.
Fits { new_id: PageId },
/// The rewritten node was written as two new pages. `promoted_key` is the
/// separator the parent has to insert between `left_id` and `right_id`.
Split {
left_id: PageId,
right_id: PageId,
promoted_key: Vec<u8>,
},
}
/// Byte budget that a node's `occupied_bytes()` is compared against: the page
/// size minus the page trailer and the node header.
const PAYLOAD_BYTES: usize = PAGE_SIZE - PAGE_TRAILER_SIZE - crate::btree::node::NODE_HEADER_SIZE;
impl<F: FileBackend> BTree<F> {
pub fn insert(&mut self, pager: &mut Pager<F>, key: &[u8], value: &[u8]) -> Result<()> {
check_key_value_size(key, value)?;
let path = self.descend_with_path(pager, key)?;
self.apply_insert(pager, path, key, value)
}
fn descend_with_path(
&self,
pager: &mut Pager<F>,
key: &[u8],
) -> Result<HeaplessVec<PathFrame, MAX_BTREE_DEPTH>> {
let mut path: HeaplessVec<PathFrame, MAX_BTREE_DEPTH> = HeaplessVec::new();
let mut current = self.root;
loop {
let decoded = {
let page_ref = pager.read_page(current)?;
decode_node(page_ref.as_bytes())?
};
match decoded.kind {
NodeKind::Leaf => {
let frame = PathFrame {
page_id: current,
node: decoded,
child_index: 0,
};
if path.push(frame).is_err() {
return Err(Error::BTreeDepthExceeded {
limit: MAX_BTREE_DEPTH,
});
}
return Ok(path);
}
NodeKind::Internal => {
let child_index = pivot_index(&decoded, key);
let raw = decoded.children[child_index];
let next = PageId::new(raw).ok_or(Error::BTreeInvariantViolated {
reason: "internal node had zero child page-id",
})?;
let frame = PathFrame {
page_id: current,
node: decoded,
child_index,
};
if path.push(frame).is_err() {
return Err(Error::BTreeDepthExceeded {
limit: MAX_BTREE_DEPTH,
});
}
current = next;
}
}
}
}
fn apply_insert(
&mut self,
pager: &mut Pager<F>,
mut path: HeaplessVec<PathFrame, MAX_BTREE_DEPTH>,
key: &[u8],
value: &[u8],
) -> Result<()> {
let mut freed: HeaplessVec<PageId, { MAX_BTREE_DEPTH * 2 }> = HeaplessVec::new();
let Some(leaf_frame) = path.pop() else {
return Err(Error::BTreeInvariantViolated {
reason: "insert: descend returned empty path",
});
};
let mut outcome = replace_leaf(pager, leaf_frame, key, value, &mut freed)?;
while let Some(parent_frame) = path.pop() {
outcome = replace_internal(pager, parent_frame, outcome, &mut freed)?;
}
let new_root = build_new_root(pager, outcome)?;
self.root = new_root;
for old_id in freed.iter().copied() {
pager.free_page(old_id)?;
}
Ok(())
}
}
fn replace_leaf<F: FileBackend>(
pager: &mut Pager<F>,
frame: PathFrame,
key: &[u8],
value: &[u8],
freed: &mut HeaplessVec<PageId, { MAX_BTREE_DEPTH * 2 }>,
) -> Result<ReplaceOutcome> {
let mut leaf = frame.node;
if leaf.leaves.iter().any(|e| e.key.as_slice() == key) {
return Err(Error::BTreeKeyExists);
}
let insert_at = leaf
.leaves
.iter()
.position(|e| e.key.as_slice() > key)
.unwrap_or(leaf.leaves.len());
leaf.leaves.insert(
insert_at,
LeafEntry {
key: key.to_vec(),
value: value.to_vec(),
},
);
push_freed(freed, frame.page_id)?;
if leaf.occupied_bytes() <= PAYLOAD_BYTES {
let new_id = write_new_node(pager, &leaf)?;
return Ok(ReplaceOutcome::Fits { new_id });
}
split_leaf(pager, leaf)
}
fn replace_internal<F: FileBackend>(
pager: &mut Pager<F>,
frame: PathFrame,
child_outcome: ReplaceOutcome,
freed: &mut HeaplessVec<PageId, { MAX_BTREE_DEPTH * 2 }>,
) -> Result<ReplaceOutcome> {
let mut internal = frame.node;
let idx = frame.child_index;
match child_outcome {
ReplaceOutcome::Fits { new_id } => {
internal.children[idx] = new_id.get();
}
ReplaceOutcome::Split {
left_id,
right_id,
promoted_key,
} => {
internal.children[idx] = left_id.get();
internal
.internals
.insert(idx, InternalEntry { key: promoted_key });
internal.children.insert(idx + 1, right_id.get());
}
}
push_freed(freed, frame.page_id)?;
if internal.occupied_bytes() <= PAYLOAD_BYTES {
let new_id = write_new_node(pager, &internal)?;
return Ok(ReplaceOutcome::Fits { new_id });
}
split_internal(pager, internal)
}
fn build_new_root<F: FileBackend>(pager: &mut Pager<F>, outcome: ReplaceOutcome) -> Result<PageId> {
let (left_id, right_id, promoted_key) = match outcome {
ReplaceOutcome::Fits { new_id } => return Ok(new_id),
ReplaceOutcome::Split {
left_id,
right_id,
promoted_key,
} => (left_id, right_id, promoted_key),
};
let level = node_level_after_split(pager, left_id)?;
let next_level = level.checked_add(1).ok_or(Error::BTreeDepthExceeded {
limit: MAX_BTREE_DEPTH,
})?;
let root_node = DecodedNode {
kind: NodeKind::Internal,
level: next_level,
next_sibling: 0,
children: vec![left_id.get(), right_id.get()],
leaves: Vec::new(),
internals: vec![InternalEntry { key: promoted_key }],
};
write_new_node(pager, &root_node)
}
fn push_freed(freed: &mut HeaplessVec<PageId, { MAX_BTREE_DEPTH * 2 }>, id: PageId) -> Result<()> {
freed.push(id).map_err(|_| Error::BTreeInvariantViolated {
reason: "insert: too many displaced pages to track",
})
}
/// Returns the index of the child to follow for `key`: the position of the
/// first pivot strictly greater than `key`, or the pivot count if there is
/// none. A key equal to a pivot therefore goes to the child after that pivot.
fn pivot_index(node: &DecodedNode, key: &[u8]) -> usize {
    node.internals
        .iter()
        .position(|pivot| pivot.key.as_slice() > key)
        .unwrap_or(node.internals.len())
}
fn check_key_value_size(key: &[u8], value: &[u8]) -> Result<()> {
if key.len() > max_key_len() {
return Err(Error::BTreeKeyTooLarge {
key_len: key.len(),
max: max_key_len(),
});
}
let v_max = max_inline_value(key.len());
if value.len() > v_max {
return Err(Error::BTreeValueTooLarge {
value_len: value.len(),
max: v_max,
});
}
Ok(())
}
/// Encodes `node` into a fresh page and returns the id of that page.
///
/// The node is encoded before a page is allocated, so a node that cannot be
/// encoded fails without consuming a page from the pager.
///
/// # Errors
///
/// Propagates any error from `encode_node`, `Pager::alloc_page`, or
/// `Pager::write_page`.
pub(crate) fn write_new_node<F: FileBackend>(
    pager: &mut Pager<F>,
    node: &DecodedNode,
) -> Result<PageId> {
    // Encode first: this step needs no page id, and doing it before the
    // allocation avoids leaking a page when encoding fails.
    let mut page = Page::zeroed();
    encode_node(node, &mut page)?;

    let new_id = pager.alloc_page()?;
    pager.write_page(new_id, &page)?;
    Ok(new_id)
}
fn split_leaf<F: FileBackend>(
pager: &mut Pager<F>,
mut leaf: DecodedNode,
) -> Result<ReplaceOutcome> {
let mid = leaf.leaves.len() / 2;
let original_sibling = leaf.next_sibling;
let right_entries: Vec<LeafEntry> = leaf.leaves.split_off(mid);
let promoted_key = right_entries[0].key.clone();
let right_node = DecodedNode {
kind: NodeKind::Leaf,
level: 0,
next_sibling: original_sibling,
children: Vec::new(),
leaves: right_entries,
internals: Vec::new(),
};
let right_id = write_new_node(pager, &right_node)?;
let left_node = DecodedNode {
kind: NodeKind::Leaf,
level: 0,
next_sibling: right_id.get(),
children: Vec::new(),
leaves: leaf.leaves,
internals: Vec::new(),
};
let left_id = write_new_node(pager, &left_node)?;
Ok(ReplaceOutcome::Split {
left_id,
right_id,
promoted_key,
})
}
fn split_internal<F: FileBackend>(
pager: &mut Pager<F>,
mut internal: DecodedNode,
) -> Result<ReplaceOutcome> {
let k = internal.internals.len();
debug_assert!(k >= 2, "internal split needs ≥ 2 pivots");
let mid = k / 2;
let right_pivots: Vec<InternalEntry> = internal.internals.split_off(mid + 1);
let right_children: Vec<u64> = internal.children.split_off(mid + 1);
let promoted_pivot = internal
.internals
.pop()
.ok_or(Error::BTreeInvariantViolated {
reason: "internal split: missing promoted pivot",
})?;
let level = internal.level;
let right_node = DecodedNode {
kind: NodeKind::Internal,
level,
next_sibling: 0,
children: right_children,
leaves: Vec::new(),
internals: right_pivots,
};
let right_id = write_new_node(pager, &right_node)?;
let left_node = DecodedNode {
kind: NodeKind::Internal,
level,
next_sibling: 0,
children: internal.children,
leaves: Vec::new(),
internals: internal.internals,
};
let left_id = write_new_node(pager, &left_node)?;
Ok(ReplaceOutcome::Split {
left_id,
right_id,
promoted_key: promoted_pivot.key,
})
}
/// Reads page `id` back through the pager, decodes it, and returns the node's
/// `level` field.
fn node_level_after_split<F: FileBackend>(pager: &mut Pager<F>, id: PageId) -> Result<u8> {
    let level = decode_node(pager.read_page(id)?.as_bytes())?.level;
    Ok(level)
}
// Forces `PAYLOAD_BYTES` to be evaluated at compile time as a `usize`.
const _: usize = PAYLOAD_BYTES;
// Mentions the imported layout constants in a const context without using
// their values. NOTE(review): presumably this exists only to keep the imports
// from being reported as unused — confirm.
const _UNUSED_CHECKS: () = {
let _ = INTERNAL_LEFTMOST_CHILD_BYTES;
let _ = INTERNAL_SLOT_BYTES;
let _ = LEAF_SLOT_BYTES;
};
// Unit and property tests for the insert path, run against an in-memory pager.
#[cfg(test)]
mod tests {
use super::*;
use crate::pager::{Config, Pager};
use crate::platform::FileHandle;
use proptest::prelude::*;
use rand::prelude::IndexedRandom;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;
use std::collections::BTreeMap;
// Pager configuration shared by every test: the crate default.
fn config() -> Config {
Config::default()
}
// A single inserted pair can be read back.
#[test]
fn insert_single_key_round_trip() {
let mut pager = Pager::<FileHandle>::memory(config()).expect("pager");
let mut tree = BTree::<FileHandle>::empty(&mut pager).expect("empty");
tree.insert(&mut pager, b"hello", b"world").expect("ins");
assert_eq!(
tree.get(&mut pager, b"hello").expect("get"),
Some(b"world".to_vec())
);
}
// Inserting an existing key fails with `BTreeKeyExists`.
#[test]
fn duplicate_key_errors() {
let mut pager = Pager::<FileHandle>::memory(config()).expect("pager");
let mut tree = BTree::<FileHandle>::empty(&mut pager).expect("empty");
tree.insert(&mut pager, b"k", b"v1").expect("ins");
let err = tree
.insert(&mut pager, b"k", b"v2")
.expect_err("dup must fail");
assert!(matches!(err, Error::BTreeKeyExists));
}
// 200 entries with 256-byte values: every key must stay readable and the
// root must end up at level >= 1, i.e. no longer a leaf.
#[test]
fn insert_growth_splits_root() {
let mut pager = Pager::<FileHandle>::memory(config()).expect("pager");
let mut tree = BTree::<FileHandle>::empty(&mut pager).expect("empty");
let value = vec![0xABu8; 256];
for i in 0..200u32 {
let key = format!("key-{i:08}");
tree.insert(&mut pager, key.as_bytes(), &value)
.expect("ins");
}
for i in 0..200u32 {
let key = format!("key-{i:08}");
assert_eq!(
tree.get(&mut pager, key.as_bytes()).expect("get"),
Some(value.clone()),
"key {key}"
);
}
let root = tree.root();
let page_ref = pager.read_page(root).expect("read root");
let decoded = decode_node(page_ref.as_bytes()).expect("decode root");
assert!(
decoded.level >= 1,
"expected internal root, got {decoded:?}"
);
}
// Property test: 16 cases, each running the oracle for 200 operations with
// an arbitrary seed.
proptest! {
#![proptest_config(ProptestConfig {
cases: 16,
max_shrink_iters: 32,
.. ProptestConfig::default()
})]
#[test]
fn insert_oracle_property(seed in any::<u64>()) {
run_insert_oracle(seed, 200);
}
}
// Longer oracle runs with three fixed seeds, 10,000 operations each.
#[test]
fn insert_oracle_10k() {
for seed in 0..3u64 {
run_insert_oracle(seed, 10_000);
}
}
// Drives `ops` random inserts against both the tree and a `BTreeMap` oracle.
// The RNG is seeded from `seed`, so a failing run is reproducible; the order
// of RNG draws (key, value, then the periodic sample) is part of that.
fn run_insert_oracle(seed: u64, ops: usize) {
let mut rng = ChaCha8Rng::seed_from_u64(seed);
let mut pager = Pager::<FileHandle>::memory(config()).expect("pager");
let mut tree = BTree::<FileHandle>::empty(&mut pager).expect("empty");
let mut oracle: BTreeMap<Vec<u8>, Vec<u8>> = BTreeMap::new();
for op in 0..ops {
let key = random_key(&mut rng);
let value = random_value(&mut rng);
let key_already = oracle.contains_key(&key);
let res = tree.insert(&mut pager, &key, &value);
// The tree must reject exactly the keys the oracle already holds.
if key_already {
assert!(
matches!(res, Err(Error::BTreeKeyExists)),
"seed {seed} op {op}: expected BTreeKeyExists, got {res:?}"
);
} else {
res.unwrap_or_else(|e| panic!("seed {seed} op {op}: insert err {e:?}"));
oracle.insert(key.clone(), value.clone());
}
// Every 127th operation, spot-check up to four randomly chosen known keys.
if op.is_multiple_of(127) {
let keys: Vec<&Vec<u8>> = oracle.keys().collect();
if !keys.is_empty() {
let sample: Vec<&Vec<u8>> =
keys.choose_multiple(&mut rng, 4).copied().collect();
for k in sample {
assert_eq!(
tree.get(&mut pager, k).expect("get").as_ref(),
oracle.get(k),
"seed {seed} op {op}: key {k:?}"
);
}
}
}
}
// Final sweep: every pair in the oracle must be present in the tree.
for (k, v) in &oracle {
assert_eq!(
tree.get(&mut pager, k).expect("get").as_ref(),
Some(v),
"seed {seed} final: key {k:?}"
);
}
}
// Random key of 1 to 15 lowercase ASCII letters.
fn random_key(rng: &mut ChaCha8Rng) -> Vec<u8> {
use rand::Rng;
let len = rng.random_range(1..16);
(0..len).map(|_| rng.random_range(b'a'..=b'z')).collect()
}
// Random value of 0 to 63 arbitrary bytes.
fn random_value(rng: &mut ChaCha8Rng) -> Vec<u8> {
use rand::Rng;
let len = rng.random_range(0..64);
(0..len).map(|_| rng.random()).collect()
}
}