use std::ops::Range;
use bytes::Bytes;
use kimberlite_types::Offset;
use crate::Key;
use crate::cache::PageCache;
use crate::error::StoreError;
use crate::node::{InternalNode, LeafNode};
use crate::page::PageType;
use crate::types::{BTREE_MIN_KEYS, CRC_SIZE, PAGE_HEADER_SIZE, PAGE_SIZE, PageId};
/// Byte budget for a serialized leaf: the page size minus the page header and
/// the CRC. `insert_into_leaf` splits a leaf whose `size_on_page()` exceeds it.
const LEAF_PAGE_BYTE_BUDGET: usize = PAGE_SIZE - PAGE_HEADER_SIZE - CRC_SIZE;
use crate::version::RowVersion;
/// Upper bound on descent depth. `find_leaf` and `insert_recursive` return
/// `StoreError::BTreeInvariant("tree too deep")` once `depth` reaches it.
const MAX_TREE_DEPTH: usize = 32;
/// Root pointer and height of a B-tree, stored outside the tree so a
/// [`BTree`] handle can borrow and update it.
///
/// The derived `Default` describes an empty tree: no root and height 0.
#[derive(Debug, Clone, Default)]
pub struct BTreeMeta {
/// Page id of the root node, or `None` while the tree is empty.
pub root: Option<PageId>,
/// Tree height; `BTree::put` sets it to 1 when the first leaf is created
/// and increments it each time the root splits.
pub height: usize,
}
impl BTreeMeta {
#[allow(dead_code)]
pub fn new() -> Self {
Self::default()
}
#[allow(dead_code)]
pub fn with_root(root: PageId, height: usize) -> Self {
Self {
root: Some(root),
height,
}
}
}
/// Mutable handle over a B-tree.
///
/// Borrows, for `'a`, both the tree's [`BTreeMeta`] (root and height) and the
/// [`PageCache`] through which every page is read, written and allocated.
pub struct BTree<'a> {
/// Root/height metadata; updated by `put` when the root is created or split.
meta: &'a mut BTreeMeta,
/// Page cache used for all page lookups and allocations.
cache: &'a mut PageCache,
}
impl<'a> BTree<'a> {
pub fn new(meta: &'a mut BTreeMeta, cache: &'a mut PageCache) -> Self {
Self { meta, cache }
}
/// Returns the root page id recorded in the metadata, or `None` when no root
/// has been set.
#[allow(dead_code)]
pub fn root(&self) -> Option<PageId> {
self.meta.root
}
/// Returns the tree height recorded in the metadata.
#[allow(dead_code)]
pub fn height(&self) -> usize {
self.meta.height
}
/// Looks up the current version of `key`.
///
/// Returns `Ok(None)` when the tree is empty, when the leaf responsible for
/// `key` holds no entry for it, or when the entry has no current version.
///
/// # Errors
///
/// Propagates failures from the descent, the page cache and leaf decoding,
/// and returns `StoreError::PageNotFound` if the cache has no page for the
/// located leaf.
pub fn get(&mut self, key: &Key) -> Result<Option<Bytes>, StoreError> {
    let root = match self.meta.root {
        Some(id) => id,
        None => return Ok(None),
    };

    let leaf_id = self.find_leaf(root, key, 0)?;
    let page = self
        .cache
        .get(leaf_id)?
        .ok_or(StoreError::PageNotFound(leaf_id))?;
    let leaf = LeafNode::from_page(page)?;

    let Some(entry) = leaf.get(key) else {
        return Ok(None);
    };
    let current = entry.versions.current().map(|v| v.data.clone());
    Ok(current)
}
/// Looks up the version of `key` selected by `versions.at(pos)`.
///
/// Returns `Ok(None)` when the tree is empty, when the leaf responsible for
/// `key` holds no entry for it, or when no version is selected for `pos`.
///
/// # Errors
///
/// Propagates failures from the descent, the page cache and leaf decoding,
/// and returns `StoreError::PageNotFound` if the cache has no page for the
/// located leaf.
pub fn get_at(&mut self, key: &Key, pos: Offset) -> Result<Option<Bytes>, StoreError> {
    let root = match self.meta.root {
        Some(id) => id,
        None => return Ok(None),
    };

    let leaf_id = self.find_leaf(root, key, 0)?;
    let page = self
        .cache
        .get(leaf_id)?
        .ok_or(StoreError::PageNotFound(leaf_id))?;
    let leaf = LeafNode::from_page(page)?;

    let found = match leaf.get(key) {
        Some(entry) => entry.versions.at(pos).map(|v| v.data.clone()),
        None => None,
    };
    Ok(found)
}
/// Collects up to `limit` `(key, value)` pairs for keys in `range`, using
/// each entry's current version.
///
/// Starts at the leaf that would hold `range.start` and follows `next_leaf`
/// links. Scanning stops at the first entry whose key is `>= range.end`, when
/// `limit` results have been gathered, or when there is no next leaf.
/// Entries without a current version are skipped and do not count toward
/// `limit`.
///
/// An empty range (`start == end`) or an empty tree yields an empty vector.
/// An inverted range panics in debug builds (`debug_assert!`) and yields an
/// empty vector otherwise.
///
/// # Errors
///
/// Propagates failures from the descent, the page cache and leaf decoding,
/// and returns `StoreError::PageNotFound` if a leaf page is missing.
pub fn scan(
    &mut self,
    range: Range<Key>,
    limit: usize,
) -> Result<Vec<(Key, Bytes)>, StoreError> {
    debug_assert!(
        range.start <= range.end,
        "scan called with inverted range: start={:?} > end={:?}",
        range.start,
        range.end
    );
    if range.start >= range.end {
        return Ok(Vec::new());
    }
    let root = match self.meta.root {
        Some(id) => id,
        None => return Ok(Vec::new()),
    };

    let first_leaf = self.find_leaf(root, &range.start, 0)?;
    let mut collected = Vec::new();
    let mut cursor = Some(first_leaf);

    'leaves: while let Some(leaf_id) = cursor {
        if collected.len() >= limit {
            break;
        }
        let page = self
            .cache
            .get(leaf_id)?
            .ok_or(StoreError::PageNotFound(leaf_id))?;
        let leaf = LeafNode::from_page(page)?;

        for entry in leaf.range(&range.start, &range.end) {
            // Past the end of the range: nothing further can match.
            if entry.key >= range.end {
                break 'leaves;
            }
            let Some(version) = entry.versions.current() else {
                continue;
            };
            collected.push((entry.key.clone(), version.data.clone()));
            if collected.len() >= limit {
                break 'leaves;
            }
        }

        cursor = leaf.next_leaf;
    }

    Ok(collected)
}
/// Collects up to `limit` `(key, value)` pairs for keys in `range`, using
/// the version of each entry selected by `versions.at(pos)`.
///
/// Starts at the leaf that would hold `range.start` and follows `next_leaf`
/// links. Scanning stops at the first entry whose key is `>= range.end`, when
/// `limit` results have been gathered, or when there is no next leaf.
/// Entries with no version selected for `pos` are skipped and do not count
/// toward `limit`.
///
/// An empty range (`start == end`) or an empty tree yields an empty vector.
/// An inverted range panics in debug builds (`debug_assert!`) and yields an
/// empty vector otherwise.
///
/// # Errors
///
/// Propagates failures from the descent, the page cache and leaf decoding,
/// and returns `StoreError::PageNotFound` if a leaf page is missing.
pub fn scan_at(
    &mut self,
    range: Range<Key>,
    limit: usize,
    pos: Offset,
) -> Result<Vec<(Key, Bytes)>, StoreError> {
    debug_assert!(
        range.start <= range.end,
        "scan_at called with inverted range: start={:?} > end={:?}",
        range.start,
        range.end
    );
    if range.start >= range.end {
        return Ok(Vec::new());
    }
    let root = match self.meta.root {
        Some(id) => id,
        None => return Ok(Vec::new()),
    };

    let first_leaf = self.find_leaf(root, &range.start, 0)?;
    let mut rows = Vec::new();
    let mut next = Some(first_leaf);

    'walk: while let Some(leaf_id) = next {
        if rows.len() >= limit {
            break;
        }
        let page = self
            .cache
            .get(leaf_id)?
            .ok_or(StoreError::PageNotFound(leaf_id))?;
        let leaf = LeafNode::from_page(page)?;

        for entry in leaf.range(&range.start, &range.end) {
            // Past the end of the range: nothing further can match.
            if entry.key >= range.end {
                break 'walk;
            }
            let Some(version) = entry.versions.at(pos) else {
                continue;
            };
            rows.push((entry.key.clone(), version.data.clone()));
            if rows.len() >= limit {
                break 'walk;
            }
        }

        next = leaf.next_leaf;
    }

    Ok(rows)
}
pub fn put(&mut self, key: Key, value: Bytes, pos: Offset) -> Result<(), StoreError> {
let version = RowVersion::new(pos, value);
match self.meta.root {
None => {
let page_id = self.cache.allocate(PageType::Leaf)?;
let page = self.cache.get_mut(page_id)?.unwrap();
let mut leaf = LeafNode::new();
leaf.insert(key, version);
leaf.to_page(page)?;
self.meta.root = Some(page_id);
self.meta.height = 1;
}
Some(root) => {
if let Some((split_key, new_child)) =
self.insert_recursive(root, key, version, 0)?
{
let new_root_id = self.cache.allocate(PageType::Internal)?;
let page = self.cache.get_mut(new_root_id)?.unwrap();
let internal = InternalNode::from_split(root, split_key, new_child);
internal.to_page(page)?;
self.meta.root = Some(new_root_id);
self.meta.height += 1;
}
}
}
Ok(())
}
/// Applies `LeafNode::delete(key, pos)` to the leaf responsible for `key`.
///
/// Returns `Ok(false)` when the tree is empty or the leaf reports that
/// nothing was deleted; in that case the page is not rewritten. Returns
/// `Ok(true)` after the modified leaf has been written back to its page.
///
/// # Errors
///
/// Propagates failures from the descent, the page cache and leaf
/// (de)serialization, and returns `StoreError::PageNotFound` if the leaf
/// page is missing.
pub fn delete(&mut self, key: &Key, pos: Offset) -> Result<bool, StoreError> {
    let root = match self.meta.root {
        Some(id) => id,
        None => return Ok(false),
    };

    let leaf_id = self.find_leaf(root, key, 0)?;
    let page = self
        .cache
        .get_mut(leaf_id)?
        .ok_or(StoreError::PageNotFound(leaf_id))?;
    let mut leaf = LeafNode::from_page(page)?;

    if !leaf.delete(key, pos) {
        return Ok(false);
    }
    leaf.to_page(page)?;
    Ok(true)
}
/// Descends from `page_id` to the leaf page responsible for `key`.
///
/// `depth` is the level of `page_id` as counted by the caller; the descent
/// fails once the level reaches `MAX_TREE_DEPTH`.
///
/// # Errors
///
/// * `StoreError::BTreeInvariant` if the depth bound is reached or a free
///   page is encountered.
/// * `StoreError::PageNotFound` if the cache has no page for a visited id.
/// * Any error from the page cache or from decoding an internal node.
fn find_leaf(
    &mut self,
    page_id: PageId,
    key: &Key,
    depth: usize,
) -> Result<PageId, StoreError> {
    let mut current = page_id;
    let mut level = depth;
    loop {
        if level >= MAX_TREE_DEPTH {
            return Err(StoreError::BTreeInvariant("tree too deep".into()));
        }
        let page = self
            .cache
            .get(current)?
            .ok_or(StoreError::PageNotFound(current))?;
        match page.page_type() {
            PageType::Leaf => return Ok(current),
            PageType::Internal => {
                // Step one level down towards the child covering `key`.
                current = InternalNode::from_page(page)?.find_child(key);
                level += 1;
            }
            PageType::Free => {
                return Err(StoreError::BTreeInvariant(
                    "hit free page during search".into(),
                ));
            }
        }
    }
}
/// Inserts `version` for `key` into the subtree rooted at `page_id`.
///
/// Returns `Ok(Some((split_key, new_page_id)))` when the node at `page_id`
/// split, so the caller must link the new sibling into the parent, and
/// `Ok(None)` when the insert was absorbed without a split at this level.
///
/// `depth` is the level of `page_id`; recursion stops with an error once it
/// reaches `MAX_TREE_DEPTH`.
///
/// # Errors
///
/// * `StoreError::BTreeInvariant` if the depth bound is reached or a free
///   page is encountered.
/// * `StoreError::PageNotFound` if the cache has no page for `page_id`.
/// * Any error from the page cache, allocation or node (de)serialization.
fn insert_recursive(
    &mut self,
    page_id: PageId,
    key: Key,
    version: RowVersion,
    depth: usize,
) -> Result<Option<(Key, PageId)>, StoreError> {
    if depth >= MAX_TREE_DEPTH {
        return Err(StoreError::BTreeInvariant("tree too deep".into()));
    }
    let page = self
        .cache
        .get(page_id)?
        .ok_or(StoreError::PageNotFound(page_id))?;
    match page.page_type() {
        PageType::Leaf => self.insert_into_leaf(page_id, key, version),
        PageType::Internal => {
            // Reuse the page fetched above; the decoded node is a temporary,
            // so the cache borrow ends before the recursive call.
            let child_id = InternalNode::from_page(page)?.find_child(&key);
            match self.insert_recursive(child_id, key, version, depth + 1)? {
                Some((child_split_key, new_child_id)) => {
                    self.insert_into_internal(page_id, child_split_key, new_child_id)
                }
                None => Ok(None),
            }
        }
        PageType::Free => Err(StoreError::BTreeInvariant(
            "hit free page during insert".into(),
        )),
    }
}
/// Inserts `version` for `key` into the leaf stored at `page_id`, splitting
/// the leaf if it overflows.
///
/// The leaf overflows when it holds more than `BTREE_MIN_KEYS * 2` entries or
/// its `size_on_page()` exceeds `LEAF_PAGE_BYTE_BUDGET`. On overflow the leaf
/// is split, a new right-hand leaf page is allocated and spliced into the
/// `next_leaf` chain directly after the original leaf, and
/// `Ok(Some((split_key, right_page_id)))` is returned. Otherwise the leaf is
/// written back in place and `Ok(None)` is returned.
///
/// # Errors
///
/// Propagates failures from allocation, the page cache and leaf
/// (de)serialization. Returns `StoreError::PageNotFound` if the cache cannot
/// supply a page, including the freshly allocated one, instead of panicking.
fn insert_into_leaf(
    &mut self,
    page_id: PageId,
    key: Key,
    version: RowVersion,
) -> Result<Option<(Key, PageId)>, StoreError> {
    let page = self
        .cache
        .get_mut(page_id)?
        .ok_or(StoreError::PageNotFound(page_id))?;
    let mut leaf = LeafNode::from_page(page)?;
    leaf.insert(key, version);

    let overflows =
        leaf.len() > BTREE_MIN_KEYS * 2 || leaf.size_on_page() > LEAF_PAGE_BYTE_BUDGET;
    if !overflows {
        leaf.to_page(page)?;
        return Ok(None);
    }

    let (split_key, mut right_leaf) = leaf.split();
    let right_page_id = self.cache.allocate(PageType::Leaf)?;

    // Keep the leaf chain intact: left -> right -> former successor of left.
    right_leaf.next_leaf = leaf.next_leaf;
    leaf.next_leaf = Some(right_page_id);

    // `allocate` needed the cache mutably, so the left page is looked up again.
    let left_page = self
        .cache
        .get_mut(page_id)?
        .ok_or(StoreError::PageNotFound(page_id))?;
    leaf.to_page(left_page)?;

    let right_page = self
        .cache
        .get_mut(right_page_id)?
        .ok_or(StoreError::PageNotFound(right_page_id))?;
    right_leaf.to_page(right_page)?;

    Ok(Some((split_key, right_page_id)))
}
/// Adds the separator `key` and its child `child_id` to the internal node
/// stored at `page_id`, splitting the node if it overflows.
///
/// The node overflows when its key count exceeds `BTREE_MIN_KEYS * 2`. On
/// overflow it is split, a new internal page is allocated for the right half,
/// and `Ok(Some((split_key, right_page_id)))` is returned. Otherwise the node
/// is written back in place and `Ok(None)` is returned.
///
/// # Errors
///
/// Propagates failures from allocation, the page cache and node
/// (de)serialization. Returns `StoreError::PageNotFound` if the cache cannot
/// supply a page, including the freshly allocated one, instead of panicking.
fn insert_into_internal(
    &mut self,
    page_id: PageId,
    key: Key,
    child_id: PageId,
) -> Result<Option<(Key, PageId)>, StoreError> {
    let page = self
        .cache
        .get_mut(page_id)?
        .ok_or(StoreError::PageNotFound(page_id))?;
    let mut internal = InternalNode::from_page(page)?;
    internal.insert(key, child_id);

    if internal.key_count() <= BTREE_MIN_KEYS * 2 {
        internal.to_page(page)?;
        return Ok(None);
    }

    let (split_key, right_internal) = internal.split();
    let right_page_id = self.cache.allocate(PageType::Internal)?;

    // `allocate` needed the cache mutably, so the left page is looked up again.
    let left_page = self
        .cache
        .get_mut(page_id)?
        .ok_or(StoreError::PageNotFound(page_id))?;
    internal.to_page(left_page)?;

    let right_page = self
        .cache
        .get_mut(right_page_id)?
        .ok_or(StoreError::PageNotFound(right_page_id))?;
    right_internal.to_page(right_page)?;

    Ok(Some((split_key, right_page_id)))
}
}
#[cfg(test)]
mod btree_tests {
    use super::*;
    use tempfile::tempdir;

    /// Opens a page cache backed by a file inside a fresh temporary directory.
    ///
    /// The `TempDir` guard is returned alongside the cache so the directory
    /// stays alive for the duration of the test.
    fn create_cache() -> (tempfile::TempDir, PageCache) {
        let tmp = tempdir().unwrap();
        let db_path = tmp.path().join("btree_test.db");
        let cache = PageCache::open(&db_path, Some(100)).unwrap();
        (tmp, cache)
    }

    /// A tree with no data has no root and misses on every lookup.
    #[test]
    fn test_empty_tree() {
        let (_dir, mut cache) = create_cache();
        let mut meta = BTreeMeta::new();
        let mut tree = BTree::new(&mut meta, &mut cache);

        assert!(tree.root().is_none());
        assert_eq!(tree.get(&Key::from("key")).unwrap(), None);
    }

    /// A value written through one handle is readable through a later handle
    /// built over the same metadata and cache.
    #[test]
    fn test_single_insert_and_get() {
        let (_dir, mut cache) = create_cache();
        let mut meta = BTreeMeta::new();

        // The writer handle is a temporary and is gone after this statement.
        BTree::new(&mut meta, &mut cache)
            .put(Key::from("hello"), Bytes::from("world"), Offset::new(1))
            .unwrap();

        let mut reader = BTree::new(&mut meta, &mut cache);
        assert!(reader.root().is_some());
        assert_eq!(
            reader.get(&Key::from("hello")).unwrap(),
            Some(Bytes::from("world"))
        );
        assert_eq!(reader.get(&Key::from("missing")).unwrap(), None);
    }

    /// Ten distinct keys are all retrievable after insertion.
    #[test]
    fn test_multiple_inserts() {
        let (_dir, mut cache) = create_cache();
        let mut meta = BTreeMeta::new();

        {
            let mut writer = BTree::new(&mut meta, &mut cache);
            for i in 0_u64..10 {
                writer
                    .put(
                        Key::from(format!("key{i:02}")),
                        Bytes::from(format!("value{i}")),
                        Offset::new(i),
                    )
                    .unwrap();
            }
        }

        let mut reader = BTree::new(&mut meta, &mut cache);
        for i in 0..10 {
            let key = Key::from(format!("key{i:02}"));
            let want = Bytes::from(format!("value{i}"));
            assert_eq!(reader.get(&key).unwrap(), Some(want));
        }
    }

    /// `get_at` returns the version visible at each probed offset.
    #[test]
    fn test_mvcc_get_at() {
        let (_dir, mut cache) = create_cache();
        let mut meta = BTreeMeta::new();
        let key = Key::from("mvcc-key");
        let mut tree = BTree::new(&mut meta, &mut cache);

        for (offset, payload) in [(1_u64, "v1"), (5, "v2"), (10, "v3")] {
            tree.put(key.clone(), Bytes::from(payload), Offset::new(offset))
                .unwrap();
        }

        assert_eq!(tree.get(&key).unwrap(), Some(Bytes::from("v3")));

        // (probe offset, expected visible value)
        let probes: [(u64, Option<&'static str>); 7] = [
            (0, None),
            (1, Some("v1")),
            (3, Some("v1")),
            (5, Some("v2")),
            (8, Some("v2")),
            (10, Some("v3")),
            (100, Some("v3")),
        ];
        for (offset, want) in probes {
            assert_eq!(
                tree.get_at(&key, Offset::new(offset)).unwrap(),
                want.map(Bytes::from)
            );
        }
    }

    /// After a delete the key is gone from the current view but still
    /// readable at an offset before the delete.
    #[test]
    fn test_delete() {
        let (_dir, mut cache) = create_cache();
        let mut meta = BTreeMeta::new();
        let mut tree = BTree::new(&mut meta, &mut cache);
        let key = Key::from("key");

        tree.put(Key::from("key"), Bytes::from("value"), Offset::new(1))
            .unwrap();
        assert_eq!(tree.get(&key).unwrap(), Some(Bytes::from("value")));

        tree.delete(&key, Offset::new(5)).unwrap();
        assert_eq!(tree.get(&key).unwrap(), None);
        assert_eq!(
            tree.get_at(&key, Offset::new(3)).unwrap(),
            Some(Bytes::from("value"))
        );
    }

    /// A range scan returns exactly the keys in `[key05, key10)`.
    #[test]
    fn test_scan_range() {
        let (_dir, mut cache) = create_cache();
        let mut meta = BTreeMeta::new();
        let mut tree = BTree::new(&mut meta, &mut cache);

        for i in 0_u64..20 {
            tree.put(
                Key::from(format!("key{i:02}")),
                Bytes::from(format!("value{i}")),
                Offset::new(i),
            )
            .unwrap();
        }

        let hits = tree
            .scan(Key::from("key05")..Key::from("key10"), 100)
            .unwrap();
        assert_eq!(hits.len(), 5);
        assert_eq!(hits[0].0, Key::from("key05"));
        assert_eq!(hits[4].0, Key::from("key09"));
    }

    /// Fifty inserts leave every key retrievable.
    #[test]
    fn test_node_splitting() {
        let (_dir, mut cache) = create_cache();
        let mut meta = BTreeMeta::new();
        let mut tree = BTree::new(&mut meta, &mut cache);

        for i in 0_u64..50 {
            tree.put(
                Key::from(format!("key{i:03}")),
                Bytes::from(format!("value{i}")),
                Offset::new(i),
            )
            .unwrap();
        }

        assert!(tree.height() >= 1);

        for i in 0..50 {
            let key = Key::from(format!("key{i:03}"));
            let want = Bytes::from(format!("value{i}"));
            assert_eq!(tree.get(&key).unwrap(), Some(want), "failed for key{i:03}");
        }
    }
}