use bytes::Bytes;
use kimberlite_types::Offset;
use crate::Key;
use crate::error::StoreError;
use crate::page::{Page, PageType};
use crate::types::PageId;
use crate::version::{RowVersion, VersionChain};
/// Encodes `key` as a 2-byte little-endian length prefix followed by its raw bytes.
///
/// The prefix comes from an `as u16` cast, so a key longer than `u16::MAX`
/// bytes would be written with a wrapped length; nothing here guards against
/// that, so callers must keep keys within that bound.
fn serialize_key(key: &Key) -> Vec<u8> {
    let prefix = (key.len() as u16).to_le_bytes();
    let raw = key.as_bytes();
    let mut out = Vec::with_capacity(prefix.len() + key.len());
    out.extend_from_slice(&prefix);
    out.extend_from_slice(raw);
    out
}
/// Decodes a length-prefixed key (see `serialize_key`) from the front of `data`.
///
/// Returns the decoded key together with the bytes that follow it.
///
/// # Errors
///
/// Returns [`StoreError::BTreeInvariant`] if the 2-byte length prefix is
/// missing or if fewer key bytes follow than the prefix announces.
fn deserialize_key(data: &[u8]) -> Result<(Key, &[u8]), StoreError> {
    let prefix = match data.get(..2) {
        Some(bytes) => [bytes[0], bytes[1]],
        None => return Err(StoreError::BTreeInvariant("key length truncated".into())),
    };
    let key_len = usize::from(u16::from_le_bytes(prefix));
    let body = &data[2..];
    match body.get(..key_len) {
        Some(key_bytes) => Ok((Key::from(key_bytes), &body[key_len..])),
        None => Err(StoreError::BTreeInvariant("key data truncated".into())),
    }
}
/// Encodes a [`RowVersion`] as a fixed 20-byte header followed by its payload.
///
/// Layout (all little-endian): `created_at: u64`, `deleted_at: u64`,
/// `data_len: u32`, then `data_len` payload bytes. The length is written with
/// an `as u32` cast, so payloads larger than `u32::MAX` bytes are not
/// representable.
fn serialize_version(version: &RowVersion) -> Vec<u8> {
    let payload: &[u8] = &version.data;
    let created = version.created_at.as_u64().to_le_bytes();
    let deleted = version.deleted_at.as_u64().to_le_bytes();
    let payload_len = (payload.len() as u32).to_le_bytes();
    let parts: [&[u8]; 4] = [&created, &deleted, &payload_len, payload];
    let mut out =
        Vec::with_capacity(created.len() + deleted.len() + payload_len.len() + payload.len());
    for part in parts {
        out.extend_from_slice(part);
    }
    out
}
/// Decodes one [`RowVersion`] from the front of `data`, returning it together
/// with the unconsumed tail.
///
/// Layout (all little-endian): `created_at: u64`, `deleted_at: u64`,
/// `data_len: u32`, then `data_len` payload bytes.
///
/// # Errors
///
/// Returns [`StoreError::BTreeInvariant`] if the 20-byte header is truncated
/// or if fewer payload bytes remain than `data_len` announces.
fn deserialize_version(data: &[u8]) -> Result<(RowVersion, &[u8]), StoreError> {
    const HEADER_LEN: usize = 20;
    if data.len() < HEADER_LEN {
        return Err(StoreError::BTreeInvariant(
            "version header truncated".into(),
        ));
    }
    let (header, body) = data.split_at(HEADER_LEN);
    let created_at = Offset::new(u64::from_le_bytes(header[0..8].try_into().unwrap()));
    let deleted_at = Offset::new(u64::from_le_bytes(header[8..16].try_into().unwrap()));
    let data_len = u32::from_le_bytes(header[16..20].try_into().unwrap()) as usize;
    // Compare against the remaining body instead of computing `20 + data_len`:
    // `data_len` comes straight from the (possibly corrupt) bytes, and on
    // 32-bit targets that addition can overflow `usize` (panicking in debug,
    // wrapping past the check and then panicking on the slice in release).
    if body.len() < data_len {
        return Err(StoreError::BTreeInvariant("version data truncated".into()));
    }
    let (payload, rest) = body.split_at(data_len);
    let version = RowVersion {
        created_at,
        deleted_at,
        data: Bytes::copy_from_slice(payload),
    };
    Ok((version, rest))
}
#[derive(Debug, Clone)]
pub struct LeafEntry {
pub key: Key,
pub versions: VersionChain,
}
impl LeafEntry {
pub fn new(key: Key, version: RowVersion) -> Self {
Self {
key,
versions: VersionChain::single(version),
}
}
pub fn serialize(&self) -> Vec<u8> {
let versions: Vec<&RowVersion> = self.versions.iter().collect();
let version_data: Vec<Vec<u8>> = versions.iter().map(|v| serialize_version(v)).collect();
let version_total: usize = version_data.iter().map(Vec::len).sum();
let mut buf = Vec::with_capacity(2 + self.key.len() + 2 + version_total);
buf.extend_from_slice(&serialize_key(&self.key));
buf.extend_from_slice(&(versions.len() as u16).to_le_bytes());
for v in version_data {
buf.extend_from_slice(&v);
}
buf
}
pub fn deserialize(data: &[u8]) -> Result<Self, StoreError> {
let (key, remaining) = deserialize_key(data)?;
if remaining.len() < 2 {
return Err(StoreError::BTreeInvariant("version count truncated".into()));
}
let version_count = u16::from_le_bytes(remaining[0..2].try_into().unwrap()) as usize;
let mut remaining = &remaining[2..];
let mut version_vec = Vec::with_capacity(version_count);
for _ in 0..version_count {
let (version, rest) = deserialize_version(remaining)?;
version_vec.push(version);
remaining = rest;
}
let versions = VersionChain::from_vec(version_vec);
Ok(Self { key, versions })
}
pub fn serialized_size(&self) -> usize {
2 + self.key.len() + 2 + self.versions.total_size() + (self.versions.len() * 20)
}
}
/// A separator key in an internal node paired with the child page id stored next to it.
#[derive(Debug, Clone)]
pub struct InternalEntry {
    /// Separator key.
    pub key: Key,
    /// Child page associated with `key`.
    pub child: PageId,
}

impl InternalEntry {
    /// Pairs `key` with `child`.
    pub fn new(key: Key, child: PageId) -> Self {
        InternalEntry { key, child }
    }

    /// Encodes the entry as a length-prefixed key followed by the child page
    /// id as a little-endian `u64`.
    pub fn serialize(&self) -> Vec<u8> {
        let mut out = serialize_key(&self.key);
        out.reserve(8);
        out.extend_from_slice(&self.child.as_u64().to_le_bytes());
        out
    }

    /// Decodes an entry produced by [`serialize`](Self::serialize).
    ///
    /// # Errors
    ///
    /// Returns [`StoreError::BTreeInvariant`] if the key is truncated or fewer
    /// than 8 bytes remain for the child page id.
    pub fn deserialize(data: &[u8]) -> Result<Self, StoreError> {
        let (key, tail) = deserialize_key(data)?;
        match tail.get(..8) {
            Some(raw) => {
                let child = PageId::new(u64::from_le_bytes(raw.try_into().unwrap()));
                Ok(InternalEntry { key, child })
            }
            None => Err(StoreError::BTreeInvariant("child page id truncated".into())),
        }
    }

    /// Encoded length: 2-byte key prefix, key bytes, 8-byte child id.
    // Kept for symmetry with `LeafEntry::serialized_size`; currently unused.
    #[allow(dead_code)]
    pub fn serialized_size(&self) -> usize {
        self.key.len() + 2 + 8
    }
}
#[derive(Debug)]
pub struct LeafNode {
entries: Vec<LeafEntry>,
pub next_leaf: Option<PageId>,
}
impl LeafNode {
pub fn new() -> Self {
Self {
entries: Vec::new(),
next_leaf: None,
}
}
pub fn from_page(page: &Page) -> Result<Self, StoreError> {
debug_assert_eq!(page.page_type(), PageType::Leaf, "expected leaf page");
let mut entries = Vec::with_capacity(page.item_count());
let next_leaf = if page.item_count() > 0 {
let first_item = page.get_item(0);
if first_item.len() == 8 {
let page_id = u64::from_le_bytes(first_item.try_into().unwrap());
if page_id == u64::MAX {
None
} else {
Some(PageId::new(page_id))
}
} else {
None
}
} else {
None
};
for i in 1..page.item_count() {
let entry = LeafEntry::deserialize(page.get_item(i))?;
entries.push(entry);
}
Ok(Self { entries, next_leaf })
}
pub fn to_page(&self, page: &mut Page) -> Result<(), StoreError> {
debug_assert_eq!(page.page_type(), PageType::Leaf, "expected leaf page");
while page.item_count() > 0 {
page.remove_item(page.item_count() - 1);
}
let next_leaf_bytes = match self.next_leaf {
Some(id) => id.as_u64().to_le_bytes(),
None => u64::MAX.to_le_bytes(),
};
page.insert_item(0, &next_leaf_bytes)?;
for (i, entry) in self.entries.iter().enumerate() {
let data = entry.serialize();
page.insert_item(i + 1, &data)?;
}
Ok(())
}
pub fn len(&self) -> usize {
self.entries.len()
}
#[allow(dead_code)]
pub fn is_empty(&self) -> bool {
self.entries.is_empty()
}
pub fn size_on_page(&self) -> usize {
const SLOT_SIZE: usize = 4;
let mut total = SLOT_SIZE + 8; for entry in &self.entries {
total += SLOT_SIZE + entry.serialized_size();
}
total
}
fn find_key_index(&self, key: &Key) -> Result<usize, usize> {
self.entries.binary_search_by(|e| e.key.cmp(key))
}
pub fn get(&self, key: &Key) -> Option<&LeafEntry> {
match self.find_key_index(key) {
Ok(idx) => Some(&self.entries[idx]),
Err(_) => None,
}
}
#[allow(dead_code)]
pub fn get_mut(&mut self, key: &Key) -> Option<&mut LeafEntry> {
match self.find_key_index(key) {
Ok(idx) => Some(&mut self.entries[idx]),
Err(_) => None,
}
}
pub fn insert(&mut self, key: Key, version: RowVersion) -> bool {
match self.find_key_index(&key) {
Ok(idx) => {
self.entries[idx].versions.add(version);
false
}
Err(idx) => {
self.entries.insert(idx, LeafEntry::new(key, version));
true
}
}
}
pub fn delete(&mut self, key: &Key, pos: Offset) -> bool {
match self.find_key_index(key) {
Ok(idx) => self.entries[idx].versions.delete_at(pos),
Err(_) => false,
}
}
#[allow(dead_code)]
pub fn iter(&self) -> impl Iterator<Item = &LeafEntry> {
self.entries.iter()
}
pub fn range(&self, start: &Key, end: &Key) -> impl Iterator<Item = &LeafEntry> {
let start_idx = match self.find_key_index(start) {
Ok(i) | Err(i) => i,
};
let end_idx = match self.find_key_index(end) {
Ok(i) => i + 1, Err(i) => i,
};
let len = self.entries.len();
let lo = start_idx.min(end_idx).min(len);
let hi = end_idx.min(len);
self.entries[lo..hi].iter()
}
pub fn split(&mut self) -> (Key, LeafNode) {
let mid = self.entries.len() / 2;
let right_entries = self.entries.split_off(mid);
let split_key = right_entries[0].key.clone();
let right = LeafNode {
entries: right_entries,
next_leaf: self.next_leaf,
};
(split_key, right)
}
#[allow(dead_code)]
pub fn first_key(&self) -> Option<&Key> {
self.entries.first().map(|e| &e.key)
}
#[allow(dead_code)]
pub fn last_key(&self) -> Option<&Key> {
self.entries.last().map(|e| &e.key)
}
}
impl Default for LeafNode {
fn default() -> Self {
Self::new()
}
}
#[derive(Debug)]
pub struct InternalNode {
keys: Vec<Key>,
children: Vec<PageId>,
}
impl InternalNode {
#[allow(dead_code)]
pub fn new(first_child: PageId) -> Self {
Self {
keys: Vec::new(),
children: vec![first_child],
}
}
pub fn from_split(left: PageId, key: Key, right: PageId) -> Self {
Self {
keys: vec![key],
children: vec![left, right],
}
}
pub fn from_page(page: &Page) -> Result<Self, StoreError> {
debug_assert_eq!(
page.page_type(),
PageType::Internal,
"expected internal page"
);
if page.item_count() == 0 {
return Err(StoreError::BTreeInvariant("empty internal node".into()));
}
let first_item = page.get_item(0);
if first_item.len() != 8 {
return Err(StoreError::BTreeInvariant("invalid first child".into()));
}
let first_child = PageId::new(u64::from_le_bytes(first_item.try_into().unwrap()));
let mut keys = Vec::new();
let mut children = vec![first_child];
for i in 1..page.item_count() {
let entry = InternalEntry::deserialize(page.get_item(i))?;
keys.push(entry.key);
children.push(entry.child);
}
Ok(Self { keys, children })
}
pub fn to_page(&self, page: &mut Page) -> Result<(), StoreError> {
debug_assert_eq!(
page.page_type(),
PageType::Internal,
"expected internal page"
);
while page.item_count() > 0 {
page.remove_item(page.item_count() - 1);
}
page.insert_item(0, &self.children[0].as_u64().to_le_bytes())?;
for (i, (key, child)) in self
.keys
.iter()
.zip(self.children.iter().skip(1))
.enumerate()
{
let entry = InternalEntry::new(key.clone(), *child);
page.insert_item(i + 1, &entry.serialize())?;
}
Ok(())
}
pub fn key_count(&self) -> usize {
self.keys.len()
}
#[allow(dead_code)]
pub fn child_count(&self) -> usize {
self.children.len()
}
pub fn find_child_index(&self, key: &Key) -> usize {
match self.keys.binary_search_by(|k| k.cmp(key)) {
Ok(i) => i + 1, Err(i) => i, }
}
pub fn find_child(&self, key: &Key) -> PageId {
let idx = self.find_child_index(key);
self.children[idx]
}
pub fn insert(&mut self, key: Key, right_child: PageId) {
let idx = match self.keys.binary_search_by(|k| k.cmp(&key)) {
Ok(i) | Err(i) => i,
};
self.keys.insert(idx, key);
self.children.insert(idx + 1, right_child);
}
pub fn split(&mut self) -> (Key, InternalNode) {
let mid = self.keys.len() / 2;
let split_key = self.keys.remove(mid);
let right_keys = self.keys.split_off(mid);
let right_children = self.children.split_off(mid + 1);
let right = InternalNode {
keys: right_keys,
children: right_children,
};
(split_key, right)
}
#[allow(dead_code)]
pub fn first_key(&self) -> Option<&Key> {
self.keys.first()
}
#[allow(dead_code)]
pub fn first_child(&self) -> PageId {
self.children[0]
}
#[allow(dead_code)]
pub fn children(&self) -> &[PageId] {
&self.children
}
}
#[cfg(test)]
mod node_tests {
    use super::*;

    /// A leaf entry must survive a serialize/deserialize round trip intact,
    /// including every version field.
    #[test]
    fn test_leaf_entry_serialization() {
        let entry = LeafEntry::new(
            Key::from("test-key"),
            RowVersion::new(Offset::new(5), Bytes::from("test-value")),
        );
        let serialized = entry.serialize();
        let deserialized = LeafEntry::deserialize(&serialized).unwrap();
        assert_eq!(deserialized.key, entry.key);
        assert_eq!(deserialized.versions.len(), entry.versions.len());
        for (decoded, original) in deserialized.versions.iter().zip(entry.versions.iter()) {
            assert_eq!(decoded.created_at.as_u64(), original.created_at.as_u64());
            assert_eq!(decoded.deleted_at.as_u64(), original.deleted_at.as_u64());
            assert_eq!(decoded.data, original.data);
        }
    }

    #[test]
    fn test_internal_entry_serialization() {
        let entry = InternalEntry::new(Key::from("separator"), PageId::new(42));
        let serialized = entry.serialize();
        let deserialized = InternalEntry::deserialize(&serialized).unwrap();
        assert_eq!(deserialized.key, entry.key);
        assert_eq!(deserialized.child, entry.child);
    }

    /// Every strict prefix of a valid encoding must be rejected, not misparsed.
    #[test]
    fn test_truncated_entries_are_rejected() {
        let leaf = LeafEntry::new(
            Key::from("k"),
            RowVersion::new(Offset::new(1), Bytes::from("v")),
        )
        .serialize();
        for cut in 0..leaf.len() {
            assert!(
                LeafEntry::deserialize(&leaf[..cut]).is_err(),
                "leaf prefix of length {cut} must fail"
            );
        }

        let internal = InternalEntry::new(Key::from("sep"), PageId::new(9)).serialize();
        for cut in 0..internal.len() {
            assert!(
                InternalEntry::deserialize(&internal[..cut]).is_err(),
                "internal prefix of length {cut} must fail"
            );
        }
    }

    #[test]
    fn test_leaf_range_inverted_returns_empty() {
        let mut node = LeafNode::new();
        for k in ["a", "b", "c"] {
            node.insert(
                Key::from(k),
                RowVersion::new(Offset::new(1), Bytes::from(k)),
            );
        }
        let inverted: Vec<&Key> = node
            .range(&Key::from("zzz"), &Key::from("0"))
            .map(|e| &e.key)
            .collect();
        assert!(inverted.is_empty(), "inverted range must return empty");
        let past: Vec<&Key> = node
            .range(&Key::from("c"), &Key::from("a"))
            .map(|e| &e.key)
            .collect();
        assert!(past.is_empty(), "start > end within key set must be empty");
        let normal: Vec<&Key> = node
            .range(&Key::from("a"), &Key::from("c"))
            .map(|e| &e.key)
            .collect();
        assert_eq!(
            normal,
            vec![&Key::from("a"), &Key::from("b"), &Key::from("c")]
        );
    }

    #[test]
    fn test_leaf_node_operations() {
        let mut node = LeafNode::new();
        node.insert(
            Key::from("c"),
            RowVersion::new(Offset::new(1), Bytes::from("C")),
        );
        node.insert(
            Key::from("a"),
            RowVersion::new(Offset::new(2), Bytes::from("A")),
        );
        node.insert(
            Key::from("b"),
            RowVersion::new(Offset::new(3), Bytes::from("B")),
        );
        assert_eq!(node.len(), 3);
        let keys: Vec<&Key> = node.iter().map(|e| &e.key).collect();
        assert_eq!(keys[0], &Key::from("a"));
        assert_eq!(keys[1], &Key::from("b"));
        assert_eq!(keys[2], &Key::from("c"));
        assert!(node.get(&Key::from("b")).is_some());
        assert!(node.get(&Key::from("d")).is_none());
    }

    /// Splitting a leaf moves the upper half right, promotes the right
    /// half's first key, and hands the old sibling pointer to the new node.
    #[test]
    fn test_leaf_split_moves_upper_half_right() {
        let mut node = LeafNode::new();
        for k in ["a", "b", "c", "d"] {
            node.insert(
                Key::from(k),
                RowVersion::new(Offset::new(1), Bytes::from(k)),
            );
        }
        node.next_leaf = Some(PageId::new(7));
        let (split_key, right) = node.split();
        assert_eq!(split_key, Key::from("c"));
        assert_eq!(node.len(), 2);
        assert_eq!(right.len(), 2);
        assert_eq!(node.last_key(), Some(&Key::from("b")));
        assert_eq!(right.first_key(), Some(&Key::from("c")));
        assert_eq!(right.next_leaf, Some(PageId::new(7)));
    }

    #[test]
    fn test_internal_node_find_child() {
        let mut node = InternalNode::new(PageId::new(0));
        node.insert(Key::from("m"), PageId::new(1));
        node.insert(Key::from("t"), PageId::new(2));
        assert_eq!(node.find_child(&Key::from("a")), PageId::new(0));
        assert_eq!(node.find_child(&Key::from("m")), PageId::new(1));
        assert_eq!(node.find_child(&Key::from("n")), PageId::new(1));
        assert_eq!(node.find_child(&Key::from("t")), PageId::new(2));
        assert_eq!(node.find_child(&Key::from("z")), PageId::new(2));
    }

    /// Splitting an internal node promotes the middle separator (kept by
    /// neither half) and keeps each child attached to its original range.
    #[test]
    fn test_internal_node_split_promotes_middle_key() {
        let mut node = InternalNode::new(PageId::new(0));
        for (k, child) in [("b", 1), ("d", 2), ("f", 3), ("h", 4)] {
            node.insert(Key::from(k), PageId::new(child));
        }
        assert_eq!(node.key_count(), 4);

        let (promoted, right) = node.split();
        assert_eq!(promoted, Key::from("f"));

        assert_eq!(node.key_count(), 2);
        assert_eq!(
            node.children().to_vec(),
            vec![PageId::new(0), PageId::new(1), PageId::new(2)]
        );
        assert_eq!(right.key_count(), 1);
        assert_eq!(right.first_key(), Some(&Key::from("h")));
        assert_eq!(
            right.children().to_vec(),
            vec![PageId::new(3), PageId::new(4)]
        );

        assert_eq!(node.find_child(&Key::from("c")), PageId::new(1));
        assert_eq!(right.find_child(&Key::from("z")), PageId::new(4));
    }
}