use anyhow::{anyhow, Result};
use super::gapped_node::GappedNode;
use super::linear_model::LinearModel;
/// A learned-index tree made of an optional routing root and a flat list of leaves.
///
/// Lookups and inserts first resolve a leaf index (through `root` when it is
/// present, otherwise by scanning `leaves`) and then delegate to that leaf.
pub struct MultiLevelAlexTree {
/// Routing node over `leaves`; `None` for an empty tree, for a single-leaf
/// bulk build, and for trees grown only through `insert`.
root: Option<Box<InnerNode>>,
/// Leaf storage. Inner nodes refer to leaves by their index in this vector.
leaves: Vec<GappedNode>,
/// Height recorded at construction: 0 when empty, 1 after the first insert,
/// or the value picked by `calculate_height` during a bulk build.
height: usize,
/// Number of keys the tree currently accounts for across all leaves.
num_keys: usize,
}
/// Routing node that maps a key to one of its children.
pub struct InnerNode {
/// Linear model trained on the `(key, leaf index)` pairs the node was built from.
model: LinearModel,
/// The child nodes, or the leaf indices, this node routes to.
children: InnerNodeChildren,
/// Boundary keys searched to pick a child; expected to be sorted.
/// At build time entry `i` is the first key of child `i + 1`.
split_keys: Vec<i64>,
/// Number of `(key, leaf index)` entries the node was built from.
/// Not read by any code visible in this file.
_num_keys: usize,
/// Level passed in at construction (0 for the simple root).
/// Not read by any code visible in this file.
_level: usize,
}
/// The two kinds of children an [`InnerNode`] can hold.
pub enum InnerNodeChildren {
/// Further inner nodes, one routing level down.
Inner(Vec<Box<InnerNode>>),
/// Indices into the owning tree's leaf vector.
Leaves(Vec<usize>),
}
// The fan-out limits below are only referenced by `calculate_fanout`, which is
// itself currently unused; hence the `dead_code` allowances.

/// Smallest fan-out `calculate_fanout` will return.
#[allow(dead_code)]
const MIN_FANOUT: usize = 16;

/// Upper bound applied to the bulk-build fan-out.
#[allow(dead_code)]
const MAX_FANOUT: usize = 256;

/// Fan-out used once a node has more children than this value.
#[allow(dead_code)]
const BULK_BUILD_FANOUT: usize = 64;
/// `Default` yields the same empty tree as [`MultiLevelAlexTree::new`].
impl Default for MultiLevelAlexTree {
fn default() -> Self {
Self::new()
}
}
impl MultiLevelAlexTree {
/// Creates an empty tree: no root, no leaves, height 0 and zero keys.
#[must_use]
pub const fn new() -> Self {
Self {
root: None,
leaves: Vec::new(),
height: 0,
num_keys: 0,
}
}
pub fn bulk_build(mut data: Vec<(i64, Vec<u8>)>) -> Result<Self> {
if data.is_empty() {
return Ok(Self::new());
}
data.sort_unstable_by_key(|(k, _)| *k);
let num_keys = data.len();
let height = Self::calculate_height(num_keys);
let leaves = Self::build_leaves(&data)?;
let total_leaf_keys: usize = leaves
.iter()
.map(super::gapped_node::GappedNode::num_keys)
.sum();
if total_leaf_keys != num_keys {
tracing::warn!(
total_leaf_keys,
num_keys,
"Not all keys inserted into leaves during bulk build"
);
}
if leaves.len() == 1 {
return Ok(Self {
root: None,
leaves,
height: 1,
num_keys: total_leaf_keys,
});
}
let root = Self::build_inner_tree(&leaves, height - 1)?;
Ok(Self {
root: Some(root),
leaves,
height,
num_keys: total_leaf_keys,
})
}
/// Picks the nominal tree height for a bulk build from the number of keys:
/// 1 up to 10 000 keys, 2 up to 10 000 000 keys, and 3 beyond that.
const fn calculate_height(num_keys: usize) -> usize {
    match num_keys {
        0..=10_000 => 1,
        10_001..=10_000_000 => 2,
        _ => 3,
    }
}
/// Packs `data` into leaves of at most 64 entries each, in input order.
///
/// Every leaf is created with twice as many slots as the entries it receives,
/// filled with one batch insert and then retrained.
///
/// # Errors
/// Fails if a leaf reports that it could not take its batch, or if the batch
/// insert or the retraining itself returns an error.
fn build_leaves(data: &[(i64, Vec<u8>)]) -> Result<Vec<GappedNode>> {
    const KEYS_PER_LEAF: usize = 64;

    data.chunks(KEYS_PER_LEAF)
        .map(|chunk| -> Result<GappedNode> {
            let mut leaf = GappedNode::new(chunk.len() * 2, 1.5);
            let entries: Vec<(i64, Vec<u8>)> = chunk.to_vec();
            if !leaf.insert_batch(&entries)? {
                return Err(anyhow!("Failed to insert batch into leaf"));
            }
            leaf.retrain()?;
            Ok(leaf)
        })
        .collect()
}
/// Builds the routing root over `leaves`.
///
/// Each non-empty leaf contributes a `(smallest key, leaf index)` entry; empty
/// leaves are left out. `_target_height` is accepted but currently unused: the
/// result is always a single root pointing directly at leaves.
///
/// # Errors
/// Fails when none of the leaves holds a key.
fn build_inner_tree(leaves: &[GappedNode], _target_height: usize) -> Result<Box<InnerNode>> {
    let mut routing_entries: Vec<(i64, usize)> = Vec::new();
    for (position, leaf) in leaves.iter().enumerate() {
        if let Some(smallest) = leaf.min_key() {
            routing_entries.push((smallest, position));
        }
    }

    if routing_entries.is_empty() {
        return Err(anyhow!("No keys in leaves"));
    }

    Ok(Box::new(InnerNode::build_simple_root(&routing_entries)))
}
/// Looks up `key` and returns its value, or `None` when the key is absent.
///
/// An empty tree answers `None` without routing.
///
/// # Errors
/// Propagates routing errors and errors from the leaf lookup.
pub fn get(&self, key: i64) -> Result<Option<Vec<u8>>> {
    if self.leaves.is_empty() {
        return Ok(None);
    }
    self.route_to_leaf(key)
        .and_then(|leaf_idx| self.leaves[leaf_idx].get(key))
}
pub fn insert(&mut self, key: i64, value: Vec<u8>) -> Result<()> {
if self.leaves.is_empty() {
let mut leaf = GappedNode::new(64, 1.5);
leaf.insert(key, value)?;
self.leaves.push(leaf);
self.num_keys = 1;
self.height = 1;
return Ok(());
}
let leaf_idx = self.route_to_leaf(key)?;
if self.leaves[leaf_idx].insert(key, value.clone())? {
self.num_keys += 1;
Ok(())
} else {
self.split_leaf(leaf_idx, key, value)
}
}
fn route_to_leaf(&self, key: i64) -> Result<usize> {
if let Some(root) = &self.root {
return root.route_to_leaf(key);
}
if self.leaves.len() == 1 {
return Ok(0);
}
for (i, leaf) in self.leaves.iter().enumerate() {
if let Some(max_key) = leaf.max_key() {
if key <= max_key {
return Ok(i);
}
}
}
Ok(self.leaves.len() - 1)
}
fn split_leaf(&mut self, leaf_idx: usize, key: i64, value: Vec<u8>) -> Result<()> {
let (split_key, mut new_leaf) = self.leaves[leaf_idx].split()?;
if key < split_key {
self.leaves[leaf_idx].insert(key, value)?;
} else {
new_leaf.insert(key, value)?;
}
self.leaves.push(new_leaf);
self.num_keys += 1;
if let Some(root) = &mut self.root {
root.handle_leaf_split(leaf_idx, split_key, self.leaves.len() - 1);
}
Ok(())
}
/// Returns the number of keys the tree accounts for.
#[must_use]
pub const fn len(&self) -> usize {
self.num_keys
}
/// Returns `true` when the tree accounts for no keys.
#[must_use]
pub const fn is_empty(&self) -> bool {
self.num_keys == 0
}
/// Returns the number of leaf nodes currently held by the tree.
#[must_use]
pub const fn num_leaves(&self) -> usize {
self.leaves.len()
}
/// Returns the height recorded for the tree (0 for an empty tree).
#[must_use]
pub const fn height(&self) -> usize {
self.height
}
}
impl InnerNode {
/// Builds a level-0 node that points directly at leaves.
///
/// `leaf_keys` holds one `(first key, leaf index)` entry per leaf. The model is
/// trained on those entries, every entry's index becomes a child, and the
/// first key of every entry except the first becomes a split key.
fn build_simple_root(leaf_keys: &[(i64, usize)]) -> Self {
    let mut model = LinearModel::new();
    model.train(leaf_keys);

    let split_keys: Vec<i64> = leaf_keys.iter().skip(1).map(|entry| entry.0).collect();
    let leaf_indices: Vec<usize> = leaf_keys.iter().map(|entry| entry.1).collect();

    Self {
        model,
        children: InnerNodeChildren::Leaves(leaf_indices),
        split_keys,
        _num_keys: leaf_keys.len(),
        _level: 0,
    }
}
/// Recursively builds a routing node over `leaf_keys`.
///
/// The entries are partitioned into groups according to the computed fan-out.
/// While another inner level still fits below `target_height`, each group
/// becomes an inner child; otherwise the node points directly at the leaf
/// indices of all groups.
///
/// # Errors
/// Propagates the first error hit while building a child node.
// Not called yet: `build_inner_tree` still uses `build_simple_root`.
#[allow(dead_code)]
fn build_from_leaves(
    leaf_keys: &[(i64, usize)],
    level: usize,
    target_height: usize,
) -> Result<Self> {
    let mut model = LinearModel::new();
    model.train(leaf_keys);

    let fanout = Self::calculate_fanout(leaf_keys.len());
    let groups = Self::partition_leaves(leaf_keys, fanout);
    let split_keys = Self::extract_split_keys(&groups);

    // Written as `level + 1 < target_height` so that a `target_height` of 0
    // cannot underflow; for every other height this equals the former
    // `level < target_height - 1`.
    let children = if level + 1 < target_height {
        let inner_children = groups
            .into_iter()
            .map(|group| Self::build_from_leaves(&group, level + 1, target_height).map(Box::new))
            .collect::<Result<Vec<Box<Self>>>>()?;
        InnerNodeChildren::Inner(inner_children)
    } else {
        let leaf_indices: Vec<usize> = groups
            .into_iter()
            .flatten()
            .map(|(_, idx)| idx)
            .collect();
        InnerNodeChildren::Leaves(leaf_indices)
    };

    Ok(Self {
        model,
        children,
        split_keys,
        _num_keys: leaf_keys.len(),
        _level: level,
    })
}
/// Chooses the fan-out for a node with `num_children` entries: at least
/// `MIN_FANOUT`, the entry count itself up to `BULK_BUILD_FANOUT`, and the
/// smaller of `BULK_BUILD_FANOUT` and `MAX_FANOUT` beyond that.
// Only used by the currently unused recursive builder.
#[allow(dead_code)]
fn calculate_fanout(num_children: usize) -> usize {
    match num_children {
        n if n <= MIN_FANOUT => MIN_FANOUT,
        n if n <= BULK_BUILD_FANOUT => n,
        _ => BULK_BUILD_FANOUT.min(MAX_FANOUT),
    }
}
/// Splits `leaves` into at most `fanout` consecutive groups of equal size
/// (the last group may be shorter).
///
/// An empty input yields no groups, and a `fanout` of 0 is treated as 1.
/// Without these guards the chunk size would be 0 (which `chunks` rejects
/// with a panic) or the size computation would divide by zero.
// Only used by the currently unused recursive builder.
#[allow(dead_code)]
fn partition_leaves(leaves: &[(i64, usize)], fanout: usize) -> Vec<Vec<(i64, usize)>> {
    if leaves.is_empty() {
        return Vec::new();
    }
    // Non-empty input and a divisor of at least 1 guarantee a chunk size >= 1.
    let chunk_size = leaves.len().div_ceil(fanout.max(1));
    leaves.chunks(chunk_size).map(<[_]>::to_vec).collect()
}
/// Collects the first key of every group except the first one; groups that
/// are empty contribute nothing.
// Only used by the currently unused recursive builder.
#[allow(dead_code)]
fn extract_split_keys(groups: &[Vec<(i64, usize)>]) -> Vec<i64> {
    groups
        .iter()
        .skip(1)
        .filter_map(|group| group.first().map(|entry| entry.0))
        .collect()
}
/// Resolves `key` to a leaf index, descending through inner children until a
/// node that points at leaves is reached.
///
/// The model's prediction is computed and handed to `find_child`, which
/// picks the child slot.
fn route_to_leaf(&self, key: i64) -> Result<usize> {
    let slot = self.find_child(key, self.model.predict(key));
    match &self.children {
        InnerNodeChildren::Leaves(leaf_ids) => Ok(leaf_ids[slot]),
        InnerNodeChildren::Inner(nodes) => nodes[slot].route_to_leaf(key),
    }
}
/// Returns the position of the child responsible for `key`.
///
/// `split_keys[i]` is the first key that belongs to child `i + 1`, so the
/// responsible child is the number of split keys that are `<= key`. A key
/// that equals a split key therefore goes to the child on the right of that
/// boundary (the child that actually starts with it), not to the one on the
/// left. The result is clamped to the last child. `_predicted` is currently
/// unused.
fn find_child(&self, key: i64, _predicted: usize) -> usize {
    let child_idx = self.split_keys.partition_point(|&split| split <= key);
    child_idx.min(self.num_children().saturating_sub(1))
}
/// Returns how many children this node has, whichever kind they are.
const fn num_children(&self) -> usize {
match &self.children {
InnerNodeChildren::Inner(children) => children.len(),
InnerNodeChildren::Leaves(indices) => indices.len(),
}
}
/// Registers a leaf split with the routing structure.
///
/// `old_leaf` kept the keys below `split_key`; `new_leaf` is the index of the
/// leaf that received the rest. In a node that points at leaves, the new leaf
/// is placed directly after the old one and `split_key` becomes the boundary
/// between the two. If `old_leaf` is not listed in this node, the slot is
/// derived from `split_key` instead of panicking on an out-of-range insert.
/// A node with inner children forwards the update to the child whose key
/// range contains `split_key`.
fn handle_leaf_split(&mut self, old_leaf: usize, split_key: i64, new_leaf: usize) {
    match &mut self.children {
        InnerNodeChildren::Leaves(indices) => {
            let old_pos = match indices.iter().position(|&idx| idx == old_leaf) {
                Some(pos) => pos,
                // Unknown leaf: fall back to the slot whose range holds the
                // split key, which keeps `split_keys` sorted.
                None => self.split_keys.partition_point(|&k| k < split_key),
            };
            // The boundary between child `p` and child `p + 1` lives at
            // `split_keys[p]`; clamp both positions so neither insert can
            // exceed its vector.
            let child_pos = (old_pos + 1).min(indices.len());
            let key_pos = old_pos.min(self.split_keys.len());
            indices.insert(child_pos, new_leaf);
            self.split_keys.insert(key_pos, split_key);
        }
        InnerNodeChildren::Inner(children) => {
            if children.is_empty() {
                return;
            }
            // The number of inner children does not change, so this node's
            // own split keys stay as they are.
            let child_pos = self
                .split_keys
                .partition_point(|&k| k <= split_key)
                .min(children.len() - 1);
            children[child_pos].handle_leaf_split(old_leaf, split_key, new_leaf);
        }
    }
}
}
/// Key-range helpers the tree uses to route keys between leaves.
impl GappedNode {
    /// Returns the smallest key among the node's pairs, or `None` when it
    /// holds no pairs. Every call walks all pairs returned by `pairs()`.
    #[must_use]
    pub fn min_key(&self) -> Option<i64> {
        self.pairs().iter().map(|(k, _)| *k).min()
    }

    /// Returns the largest key among the node's pairs, or `None` when it
    /// holds no pairs. Every call walks all pairs returned by `pairs()`.
    #[must_use]
    pub fn max_key(&self) -> Option<i64> {
        self.pairs().iter().map(|(k, _)| *k).max()
    }
}
// Unit tests for the basic construction, insert and lookup paths.
#[cfg(test)]
mod tests {
use super::*;
// A freshly created tree holds nothing and reports height 0.
#[test]
fn test_empty_tree() {
let tree = MultiLevelAlexTree::new();
assert_eq!(tree.len(), 0);
assert!(tree.is_empty());
assert_eq!(tree.height(), 0);
}
// The first insert creates a single leaf (height 1) and the value is readable.
#[test]
fn test_single_insert() {
let mut tree = MultiLevelAlexTree::new();
tree.insert(42, vec![1, 2, 3]).unwrap();
assert_eq!(tree.len(), 1);
assert!(!tree.is_empty());
assert_eq!(tree.height(), 1);
let result = tree.get(42).unwrap();
assert_eq!(result, Some(vec![1, 2, 3]));
}
// Five keys fit into one leaf, so this exercises the rootless bulk-build path
// plus a lookup of a key that was never stored.
#[test]
fn test_bulk_build() {
let data = vec![
(1, vec![1]),
(10, vec![10]),
(20, vec![20]),
(30, vec![30]),
(40, vec![40]),
];
let tree = MultiLevelAlexTree::bulk_build(data).unwrap();
assert_eq!(tree.len(), 5);
assert_eq!(tree.get(1).unwrap(), Some(vec![1]));
assert_eq!(tree.get(20).unwrap(), Some(vec![20]));
assert_eq!(tree.get(40).unwrap(), Some(vec![40]));
assert_eq!(tree.get(100).unwrap(), None);
}
// Inserts 100 keys in ascending order and reads every one of them back.
#[test]
fn test_multiple_inserts() {
let mut tree = MultiLevelAlexTree::new();
for i in 0..100 {
tree.insert(i, vec![i as u8]).unwrap();
}
assert_eq!(tree.len(), 100);
for i in 0..100 {
assert_eq!(tree.get(i).unwrap(), Some(vec![i as u8]));
}
}
}