use super::GappedNode;
use anyhow::Result;
/// An ordered map from `i64` keys to `Vec<u8>` values, stored as a flat, key-ordered list of
/// `GappedNode` leaves plus the separator keys that route lookups between them.
#[derive(Debug)]
pub struct AlexTree {
    /// Leaf nodes in ascending key order. Constructors create exactly one leaf and leaves are
    /// only ever added (by splitting), so there is always at least one.
    leaves: Vec<GappedNode>,
    /// Separator keys: `split_keys[i]` is the boundary between `leaves[i]` and `leaves[i + 1]`;
    /// a key equal to a separator is routed to the right-hand leaf. Each split adds one leaf and
    /// one separator, so this is always one shorter than `leaves`. Expected to stay sorted,
    /// since lookups binary-search it.
    split_keys: Vec<i64>,
}
impl AlexTree {
#[must_use]
pub fn new() -> Self {
Self {
leaves: vec![GappedNode::new(100, 1.0)],
split_keys: vec![],
}
}
#[must_use]
pub fn with_expansion(expansion_factor: f64) -> Self {
Self {
leaves: vec![GappedNode::new(100, expansion_factor)],
split_keys: vec![],
}
}
pub fn insert(&mut self, key: i64, value: Vec<u8>) -> Result<()> {
let leaf_idx = self.find_leaf_index(key);
let insert_result = self.leaves[leaf_idx].insert(key, value.clone())?;
if !insert_result {
let (split_key, right_leaf) = self.leaves[leaf_idx].split()?;
self.split_keys.insert(leaf_idx, split_key);
self.leaves.insert(leaf_idx + 1, right_leaf);
let new_leaf_idx = self.find_leaf_index(key);
self.leaves[new_leaf_idx].insert(key, value)?;
}
Ok(())
}
pub fn insert_batch(&mut self, mut entries: Vec<(i64, Vec<u8>)>) -> Result<()> {
if entries.is_empty() {
return Ok(());
}
entries.sort_unstable_by_key(|(k, _)| *k);
let mut leaf_groups: Vec<Vec<(i64, Vec<u8>)>> = vec![Vec::new(); self.leaves.len()];
for (key, value) in entries {
let leaf_idx = self.find_leaf_index(key);
leaf_groups[leaf_idx].push((key, value));
}
let mut modified_leaves = Vec::new();
for (leaf_idx, group) in leaf_groups.iter_mut().enumerate() {
if group.is_empty() {
continue;
}
let success = self.leaves[leaf_idx].insert_batch(group)?;
if !success {
for (key, value) in group.drain(..) {
self.insert(key, value)?;
}
}
modified_leaves.push(leaf_idx);
}
for leaf_idx in modified_leaves {
if self.leaves[leaf_idx].needs_retrain() {
self.leaves[leaf_idx].retrain()?;
}
}
Ok(())
}
pub fn get(&self, key: i64) -> Result<Option<Vec<u8>>> {
let leaf_idx = self.find_leaf_index(key);
self.leaves[leaf_idx].get(key)
}
pub fn lower_bound(&self, search_key: i64) -> Result<Option<(i64, Vec<u8>)>> {
let start_leaf_idx = self.find_leaf_index(search_key);
for leaf in &self.leaves[start_leaf_idx..] {
if let Some((key, _pos)) = leaf.lower_bound_position(search_key) {
if let Ok(Some(value)) = leaf.get(key) {
return Ok(Some((key, value)));
}
}
}
Ok(None)
}
fn find_leaf_index(&self, key: i64) -> usize {
match self.split_keys.binary_search(&key) {
Ok(idx) => idx + 1, Err(idx) => idx, }
}
pub fn len(&self) -> usize {
self.leaves
.iter()
.map(super::gapped_node::GappedNode::num_keys)
.sum()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.len() == 0
}
#[must_use]
pub const fn num_leaves(&self) -> usize {
self.leaves.len()
}
pub fn range(&self, start_key: i64, end_key: i64) -> Result<Vec<(i64, Vec<u8>)>> {
if start_key > end_key {
return Ok(Vec::new());
}
let mut results = Vec::new();
let start_leaf_idx = self.find_leaf_index(start_key);
for leaf in &self.leaves[start_leaf_idx..] {
for (key, value) in leaf.pairs() {
if key > end_key {
return Ok(results);
}
if key >= start_key {
results.push((key, value));
}
}
}
Ok(results)
}
}
impl Default for AlexTree {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Value stored for key `i` in the bulk tests: the low byte of `i`.
    fn value_for(i: i64) -> Vec<u8> {
        vec![u8::try_from(i.rem_euclid(256)).expect("rem_euclid(256) always fits in u8")]
    }

    #[test]
    fn test_empty_tree() {
        let tree = AlexTree::default();
        assert!(tree.is_empty());
        assert_eq!(tree.len(), 0);
        assert_eq!(tree.num_leaves(), 1);
        assert_eq!(tree.get(1).unwrap(), None);
        assert!(tree.range(0, 100).unwrap().is_empty());
    }

    #[test]
    fn test_basic_insert_get() {
        let mut tree = AlexTree::new();
        tree.insert(10, vec![1]).unwrap();
        tree.insert(20, vec![2]).unwrap();
        tree.insert(30, vec![3]).unwrap();
        assert!(!tree.is_empty());
        assert_eq!(tree.len(), 3);
        assert_eq!(tree.get(10).unwrap(), Some(vec![1]));
        assert_eq!(tree.get(20).unwrap(), Some(vec![2]));
        assert_eq!(tree.get(30).unwrap(), Some(vec![3]));
        assert_eq!(tree.get(40).unwrap(), None);
    }

    #[test]
    fn test_split_creates_new_leaf() {
        let mut tree = AlexTree::with_expansion(0.0);
        for i in 0..100 {
            tree.insert(i, value_for(i)).unwrap();
        }
        assert!(tree.num_leaves() > 1);
        assert_eq!(tree.len(), 100);
        for i in 0..100 {
            assert_eq!(tree.get(i).unwrap(), Some(value_for(i)), "wrong value for key {i}");
        }
    }

    #[test]
    fn test_out_of_order_inserts() {
        let entries: [(i64, u8); 5] = [(50, 5), (10, 1), (30, 3), (20, 2), (40, 4)];
        let mut tree = AlexTree::new();
        for (key, byte) in entries {
            tree.insert(key, vec![byte]).unwrap();
        }
        assert_eq!(tree.len(), 5);
        for (key, byte) in entries {
            assert_eq!(tree.get(key).unwrap(), Some(vec![byte]));
        }
    }

    #[test]
    fn test_large_scale() {
        let mut tree = AlexTree::new();
        for i in 0..10000 {
            tree.insert(i, value_for(i)).unwrap();
        }
        assert_eq!(tree.len(), 10000);
        for i in (0..10000).step_by(100) {
            assert_eq!(tree.get(i).unwrap(), Some(value_for(i)));
        }
    }

    #[test]
    fn test_lower_bound() {
        let mut tree = AlexTree::new();
        tree.insert(10, vec![1]).unwrap();
        tree.insert(20, vec![2]).unwrap();
        tree.insert(30, vec![3]).unwrap();
        assert_eq!(tree.lower_bound(5).unwrap(), Some((10, vec![1])));
        assert_eq!(tree.lower_bound(15).unwrap(), Some((20, vec![2])));
        assert_eq!(tree.lower_bound(20).unwrap(), Some((20, vec![2])));
        assert_eq!(tree.lower_bound(31).unwrap(), None);
    }

    #[test]
    fn test_insert_batch() {
        let mut tree = AlexTree::new();
        tree.insert_batch(Vec::new()).unwrap();
        assert!(tree.is_empty());

        tree.insert_batch(vec![(30, vec![3]), (10, vec![1]), (20, vec![2])])
            .unwrap();
        assert_eq!(tree.len(), 3);
        assert_eq!(tree.get(10).unwrap(), Some(vec![1]));
        assert_eq!(tree.get(20).unwrap(), Some(vec![2]));
        assert_eq!(tree.get(30).unwrap(), Some(vec![3]));
    }

    #[test]
    fn test_insert_batch_large_reversed() {
        let mut tree = AlexTree::with_expansion(0.0);
        let entries: Vec<(i64, Vec<u8>)> = (0..500).rev().map(|i| (i, value_for(i))).collect();
        tree.insert_batch(entries).unwrap();
        assert_eq!(tree.len(), 500);
        for i in 0..500 {
            assert_eq!(tree.get(i).unwrap(), Some(value_for(i)), "wrong value for key {i}");
        }
    }

    #[test]
    fn test_range_query_basic() {
        let mut tree = AlexTree::new();
        for i in 1..=5 {
            tree.insert(i * 10, value_for(i)).unwrap();
        }
        let results = tree.range(20, 40).unwrap();
        assert_eq!(
            results,
            vec![(20, value_for(2)), (30, value_for(3)), (40, value_for(4))]
        );
    }

    #[test]
    fn test_range_query_empty() {
        let mut tree = AlexTree::new();
        tree.insert(10, vec![1]).unwrap();
        tree.insert(20, vec![2]).unwrap();
        assert!(tree.range(15, 18).unwrap().is_empty());
        assert!(tree.range(30, 20).unwrap().is_empty());
    }

    #[test]
    fn test_range_query_large() {
        let mut tree = AlexTree::new();
        for i in 0..1000 {
            tree.insert(i, value_for(i)).unwrap();
        }
        let results = tree.range(100, 200).unwrap();
        assert_eq!(results.len(), 101);
        assert!(results.windows(2).all(|pair| pair[0].0 < pair[1].0));
        assert!(results.iter().all(|(k, v)| *v == value_for(*k)));
    }

    #[test]
    fn test_range_query_across_splits() {
        let mut tree = AlexTree::with_expansion(0.0);
        for i in 0..500 {
            tree.insert(i, value_for(i)).unwrap();
        }
        assert!(tree.num_leaves() > 1);
        let results = tree.range(100, 400).unwrap();
        let keys: Vec<i64> = results.iter().map(|(k, _)| *k).collect();
        assert_eq!(keys, (100..=400).collect::<Vec<i64>>());
        assert!(results.iter().all(|(k, v)| *v == value_for(*k)));
    }
}