use crate::tree_indices_types::{SearchResult, TreeIndexConfig};
use crate::Vector;
use anyhow::Result;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
pub struct KdTree {
pub(crate) root: Option<Box<KdNode>>,
pub(crate) data: Vec<(String, Vector)>,
pub(crate) config: TreeIndexConfig,
}
pub(crate) struct KdNode {
split_dim: usize,
split_value: f32,
left: Option<Box<KdNode>>,
right: Option<Box<KdNode>>,
indices: Vec<usize>,
}
impl KdTree {
pub fn new(config: TreeIndexConfig) -> Self {
Self {
root: None,
data: Vec::new(),
config,
}
}
pub fn build(&mut self) -> Result<()> {
if self.data.is_empty() {
return Ok(());
}
let indices: Vec<usize> = (0..self.data.len()).collect();
let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
self.root = Some(Box::new(self.build_node(&points, indices, 0)?));
Ok(())
}
fn build_node(&self, points: &[Vec<f32>], indices: Vec<usize>, depth: usize) -> Result<KdNode> {
let max_depth = if !self.data.is_empty() {
((self.data.len() as f32).log2() * 2.0) as usize + 10
} else {
50
};
if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= max_depth {
return Ok(KdNode {
split_dim: 0,
split_value: 0.0,
left: None,
right: None,
indices,
});
}
let dimensions = points[0].len();
let split_dim = depth % dimensions;
let mut values: Vec<(f32, usize)> = indices
.iter()
.map(|&idx| (points[idx][split_dim], idx))
.collect();
values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
let median_idx = values.len() / 2;
let split_value = values[median_idx].0;
let left_indices: Vec<usize> = values[..median_idx].iter().map(|(_, idx)| *idx).collect();
let right_indices: Vec<usize> = values[median_idx..].iter().map(|(_, idx)| *idx).collect();
if left_indices.is_empty() || right_indices.is_empty() {
return Ok(KdNode {
split_dim: 0,
split_value: 0.0,
left: None,
right: None,
indices,
});
}
let left = Some(Box::new(self.build_node(
points,
left_indices,
depth + 1,
)?));
let right = Some(Box::new(self.build_node(
points,
right_indices,
depth + 1,
)?));
Ok(KdNode {
split_dim,
split_value,
left,
right,
indices: Vec::new(),
})
}
pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
if self.root.is_none() {
return Vec::new();
}
let mut heap = BinaryHeap::new();
self.search_node(
self.root
.as_ref()
.expect("tree should have root after build"),
query,
k,
&mut heap,
);
let mut results: Vec<(usize, f32)> =
heap.into_iter().map(|r| (r.index, r.distance)).collect();
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
results
}
fn search_node(
&self,
node: &KdNode,
query: &[f32],
k: usize,
heap: &mut BinaryHeap<SearchResult>,
) {
if !node.indices.is_empty() {
for &idx in &node.indices {
let point = &self.data[idx].1.as_f32();
let dist = self.config.distance_metric.distance(query, point);
if heap.len() < k {
heap.push(SearchResult {
index: idx,
distance: dist,
});
} else if dist < heap.peek().expect("heap should have k elements").distance {
heap.pop();
heap.push(SearchResult {
index: idx,
distance: dist,
});
}
}
return;
}
let go_left = query[node.split_dim] <= node.split_value;
let (first, second) = if go_left {
(&node.left, &node.right)
} else {
(&node.right, &node.left)
};
if let Some(child) = first {
self.search_node(child, query, k, heap);
}
if heap.len() < k || {
let split_dist = (query[node.split_dim] - node.split_value).abs();
split_dist < heap.peek().expect("heap should have k elements").distance
} {
if let Some(child) = second {
self.search_node(child, query, k, heap);
}
}
}
}