use crate::{Vector, VectorIndex};
use anyhow::Result;
use oxirs_core::simd::SimdOps;
use scirs2_core::random::{Random, Rng, RngExt};
use std::cmp::Ordering;
use std::collections::BinaryHeap;
#[derive(Debug, Clone)]
pub struct TreeIndexConfig {
pub tree_type: TreeType,
pub max_leaf_size: usize,
pub random_seed: Option<u64>,
pub parallel_construction: bool,
pub distance_metric: DistanceMetric,
}
impl Default for TreeIndexConfig {
fn default() -> Self {
Self {
tree_type: TreeType::BallTree,
max_leaf_size: 16, random_seed: None,
parallel_construction: true,
distance_metric: DistanceMetric::Euclidean,
}
}
}
#[derive(Debug, Clone, Copy)]
pub enum TreeType {
BallTree,
KdTree,
VpTree,
CoverTree,
RandomProjectionTree,
}
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
Euclidean,
Manhattan,
Cosine,
Minkowski(f32),
}
impl DistanceMetric {
fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
match self {
DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
DistanceMetric::Cosine => f32::cosine_distance(a, b),
DistanceMetric::Minkowski(p) => a
.iter()
.zip(b.iter())
.map(|(x, y)| (x - y).abs().powf(*p))
.sum::<f32>()
.powf(1.0 / p),
}
}
}
#[derive(Debug, Clone)]
struct SearchResult {
index: usize,
distance: f32,
}
impl PartialEq for SearchResult {
fn eq(&self, other: &Self) -> bool {
self.distance == other.distance
}
}
impl Eq for SearchResult {}
impl PartialOrd for SearchResult {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for SearchResult {
fn cmp(&self, other: &Self) -> Ordering {
self.partial_cmp(other).unwrap_or(Ordering::Equal)
}
}
pub struct BallTree {
root: Option<Box<BallNode>>,
data: Vec<(String, Vector)>,
config: TreeIndexConfig,
}
#[derive(Clone)]
struct BallNode {
center: Vec<f32>,
radius: f32,
left: Option<Box<BallNode>>,
right: Option<Box<BallNode>>,
indices: Vec<usize>,
}
impl BallTree {
pub fn new(config: TreeIndexConfig) -> Self {
Self {
root: None,
data: Vec::new(),
config,
}
}
pub fn build(&mut self) -> Result<()> {
if self.data.is_empty() {
return Ok(());
}
let indices: Vec<usize> = (0..self.data.len()).collect();
let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
self.root = Some(Box::new(self.build_node_safe(&points, indices, 0)?));
Ok(())
}
fn build_node_safe(
&self,
points: &[Vec<f32>],
indices: Vec<usize>,
depth: usize,
) -> Result<BallNode> {
const MAX_DEPTH: usize = 20;
if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= MAX_DEPTH {
let center = self.compute_centroid(points, &indices);
let radius = self.compute_radius(points, &indices, ¢er);
return Ok(BallNode {
center,
radius,
left: None,
right: None,
indices,
});
}
let split_dim = self.find_split_dimension(points, &indices);
let (left_indices, right_indices) = self.partition_indices(points, &indices, split_dim);
if left_indices.is_empty() || right_indices.is_empty() {
let center = self.compute_centroid(points, &indices);
let radius = self.compute_radius(points, &indices, ¢er);
return Ok(BallNode {
center,
radius,
left: None,
right: None,
indices,
});
}
let left_node = self.build_node_safe(points, left_indices, depth + 1)?;
let right_node = self.build_node_safe(points, right_indices, depth + 1)?;
let all_centers = vec![left_node.center.clone(), right_node.center.clone()];
let center = self.compute_centroid_of_centers(&all_centers);
let radius = left_node.radius.max(right_node.radius)
+ self
.config
.distance_metric
.distance(¢er, &left_node.center);
Ok(BallNode {
center,
radius,
left: Some(Box::new(left_node)),
right: Some(Box::new(right_node)),
indices: Vec::new(),
})
}
fn compute_centroid(&self, points: &[Vec<f32>], indices: &[usize]) -> Vec<f32> {
let dim = points[0].len();
let mut centroid = vec![0.0; dim];
for &idx in indices {
for (i, &val) in points[idx].iter().enumerate() {
centroid[i] += val;
}
}
let n = indices.len() as f32;
for val in &mut centroid {
*val /= n;
}
centroid
}
fn compute_radius(&self, points: &[Vec<f32>], indices: &[usize], center: &[f32]) -> f32 {
indices
.iter()
.map(|&idx| self.config.distance_metric.distance(&points[idx], center))
.fold(0.0f32, f32::max)
}
fn find_split_dimension(&self, points: &[Vec<f32>], indices: &[usize]) -> usize {
let dim = points[0].len();
let mut max_spread = 0.0;
let mut split_dim = 0;
#[allow(clippy::needless_range_loop)]
for d in 0..dim {
let values: Vec<f32> = indices.iter().map(|&idx| points[idx][d]).collect();
let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let spread = max_val - min_val;
if spread > max_spread {
max_spread = spread;
split_dim = d;
}
}
split_dim
}
fn partition_indices(
&self,
points: &[Vec<f32>],
indices: &[usize],
dim: usize,
) -> (Vec<usize>, Vec<usize>) {
let mut values: Vec<(f32, usize)> =
indices.iter().map(|&idx| (points[idx][dim], idx)).collect();
values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
let mid = values.len() / 2;
let left_indices: Vec<usize> = values[..mid].iter().map(|(_, idx)| *idx).collect();
let right_indices: Vec<usize> = values[mid..].iter().map(|(_, idx)| *idx).collect();
(left_indices, right_indices)
}
fn compute_centroid_of_centers(&self, centers: &[Vec<f32>]) -> Vec<f32> {
let dim = centers[0].len();
let mut centroid = vec![0.0; dim];
for center in centers {
for (i, &val) in center.iter().enumerate() {
centroid[i] += val;
}
}
let n = centers.len() as f32;
for val in &mut centroid {
*val /= n;
}
centroid
}
pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
if self.root.is_none() {
return Vec::new();
}
let mut heap: BinaryHeap<SearchResult> = BinaryHeap::new();
let mut stack: Vec<&BallNode> = vec![self
.root
.as_ref()
.expect("tree should have root after build")];
while let Some(node) = stack.pop() {
let dist_to_center = self.config.distance_metric.distance(query, &node.center);
if heap.len() >= k {
let worst_dist = heap.peek().expect("heap should have k elements").distance;
if dist_to_center - node.radius > worst_dist {
continue; }
}
if node.indices.is_empty() {
if let (Some(left), Some(right)) = (&node.left, &node.right) {
let left_dist = self.config.distance_metric.distance(query, &left.center);
let right_dist = self.config.distance_metric.distance(query, &right.center);
if left_dist < right_dist {
stack.push(right);
stack.push(left);
} else {
stack.push(left);
stack.push(right);
}
}
} else {
for &idx in &node.indices {
let point = &self.data[idx].1.as_f32();
let dist = self.config.distance_metric.distance(query, point);
if heap.len() < k {
heap.push(SearchResult {
index: idx,
distance: dist,
});
} else if dist < heap.peek().expect("heap should have k elements").distance {
heap.pop();
heap.push(SearchResult {
index: idx,
distance: dist,
});
}
}
}
}
let mut results: Vec<(usize, f32)> =
heap.into_iter().map(|r| (r.index, r.distance)).collect();
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
results
}
}
pub struct KdTree {
root: Option<Box<KdNode>>,
data: Vec<(String, Vector)>,
config: TreeIndexConfig,
}
struct KdNode {
split_dim: usize,
split_value: f32,
left: Option<Box<KdNode>>,
right: Option<Box<KdNode>>,
indices: Vec<usize>,
}
impl KdTree {
pub fn new(config: TreeIndexConfig) -> Self {
Self {
root: None,
data: Vec::new(),
config,
}
}
pub fn build(&mut self) -> Result<()> {
if self.data.is_empty() {
return Ok(());
}
let indices: Vec<usize> = (0..self.data.len()).collect();
let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
self.root = Some(Box::new(self.build_node(&points, indices, 0)?));
Ok(())
}
fn build_node(&self, points: &[Vec<f32>], indices: Vec<usize>, depth: usize) -> Result<KdNode> {
let max_depth = if !self.data.is_empty() {
((self.data.len() as f32).log2() * 2.0) as usize + 10
} else {
50
};
if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= max_depth {
return Ok(KdNode {
split_dim: 0,
split_value: 0.0,
left: None,
right: None,
indices,
});
}
let dimensions = points[0].len();
let split_dim = depth % dimensions;
let mut values: Vec<(f32, usize)> = indices
.iter()
.map(|&idx| (points[idx][split_dim], idx))
.collect();
values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
let median_idx = values.len() / 2;
let split_value = values[median_idx].0;
let left_indices: Vec<usize> = values[..median_idx].iter().map(|(_, idx)| *idx).collect();
let right_indices: Vec<usize> = values[median_idx..].iter().map(|(_, idx)| *idx).collect();
if left_indices.is_empty() || right_indices.is_empty() {
return Ok(KdNode {
split_dim: 0,
split_value: 0.0,
left: None,
right: None,
indices,
});
}
let left = Some(Box::new(self.build_node(
points,
left_indices,
depth + 1,
)?));
let right = Some(Box::new(self.build_node(
points,
right_indices,
depth + 1,
)?));
Ok(KdNode {
split_dim,
split_value,
left,
right,
indices: Vec::new(),
})
}
pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
if self.root.is_none() {
return Vec::new();
}
let mut heap = BinaryHeap::new();
self.search_node(
self.root
.as_ref()
.expect("tree should have root after build"),
query,
k,
&mut heap,
);
let mut results: Vec<(usize, f32)> =
heap.into_iter().map(|r| (r.index, r.distance)).collect();
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
results
}
fn search_node(
&self,
node: &KdNode,
query: &[f32],
k: usize,
heap: &mut BinaryHeap<SearchResult>,
) {
if !node.indices.is_empty() {
for &idx in &node.indices {
let point = &self.data[idx].1.as_f32();
let dist = self.config.distance_metric.distance(query, point);
if heap.len() < k {
heap.push(SearchResult {
index: idx,
distance: dist,
});
} else if dist < heap.peek().expect("heap should have k elements").distance {
heap.pop();
heap.push(SearchResult {
index: idx,
distance: dist,
});
}
}
return;
}
let go_left = query[node.split_dim] <= node.split_value;
let (first, second) = if go_left {
(&node.left, &node.right)
} else {
(&node.right, &node.left)
};
if let Some(child) = first {
self.search_node(child, query, k, heap);
}
if heap.len() < k || {
let split_dist = (query[node.split_dim] - node.split_value).abs();
split_dist < heap.peek().expect("heap should have k elements").distance
} {
if let Some(child) = second {
self.search_node(child, query, k, heap);
}
}
}
}
pub struct VpTree {
root: Option<Box<VpNode>>,
data: Vec<(String, Vector)>,
config: TreeIndexConfig,
}
struct VpNode {
vantage_point: usize,
median_distance: f32,
inside: Option<Box<VpNode>>,
outside: Option<Box<VpNode>>,
indices: Vec<usize>,
}
impl VpTree {
pub fn new(config: TreeIndexConfig) -> Self {
Self {
root: None,
data: Vec::new(),
config,
}
}
pub fn build(&mut self) -> Result<()> {
if self.data.is_empty() {
return Ok(());
}
let indices: Vec<usize> = (0..self.data.len()).collect();
let mut rng = if let Some(seed) = self.config.random_seed {
Random::seed(seed)
} else {
Random::seed(42)
};
self.root = Some(Box::new(self.build_node(indices, &mut rng)?));
Ok(())
}
fn build_node<R: Rng>(&self, indices: Vec<usize>, rng: &mut R) -> Result<VpNode> {
self.build_node_safe(indices, rng, 0)
}
#[allow(deprecated)]
fn build_node_safe<R: Rng>(
&self,
mut indices: Vec<usize>,
rng: &mut R,
depth: usize,
) -> Result<VpNode> {
let max_depth = 30;
if indices.len() <= self.config.max_leaf_size
|| indices.len() <= 2 || depth >= max_depth
{
return Ok(VpNode {
vantage_point: if indices.is_empty() { 0 } else { indices[0] },
median_distance: 0.0,
inside: None,
outside: None,
indices,
});
}
let vp_idx = if indices.len() > 1 {
rng.random_range(0..indices.len())
} else {
0
};
let vantage_point = indices[vp_idx];
indices.remove(vp_idx);
let vp_data = &self.data[vantage_point].1.as_f32();
let mut distances: Vec<(f32, usize)> = indices
.iter()
.map(|&idx| {
let point = &self.data[idx].1.as_f32();
let dist = self.config.distance_metric.distance(vp_data, point);
(dist, idx)
})
.collect();
distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
let median_idx = distances.len() / 2;
let median_distance = distances[median_idx].0;
let inside_indices: Vec<usize> = distances[..median_idx]
.iter()
.map(|(_, idx)| *idx)
.collect();
let outside_indices: Vec<usize> = distances[median_idx..]
.iter()
.map(|(_, idx)| *idx)
.collect();
if inside_indices.is_empty() || outside_indices.is_empty() {
return Ok(VpNode {
vantage_point: if indices.is_empty() { 0 } else { indices[0] },
median_distance: 0.0,
inside: None,
outside: None,
indices,
});
}
let inside = Some(Box::new(self.build_node_safe(
inside_indices,
rng,
depth + 1,
)?));
let outside = Some(Box::new(self.build_node_safe(
outside_indices,
rng,
depth + 1,
)?));
Ok(VpNode {
vantage_point,
median_distance,
inside,
outside,
indices: Vec::new(),
})
}
pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
if self.root.is_none() {
return Vec::new();
}
let mut heap = BinaryHeap::new();
self.search_node(
self.root
.as_ref()
.expect("tree should have root after build"),
query,
k,
&mut heap,
f32::INFINITY,
);
let mut results: Vec<(usize, f32)> =
heap.into_iter().map(|r| (r.index, r.distance)).collect();
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
results
}
fn search_node(
&self,
node: &VpNode,
query: &[f32],
k: usize,
heap: &mut BinaryHeap<SearchResult>,
tau: f32,
) -> f32 {
let mut tau = tau;
if !node.indices.is_empty() {
for &idx in &node.indices {
let point = &self.data[idx].1.as_f32();
let dist = self.config.distance_metric.distance(query, point);
if dist < tau {
if heap.len() < k {
heap.push(SearchResult {
index: idx,
distance: dist,
});
} else if dist < heap.peek().expect("heap should have k elements").distance {
heap.pop();
heap.push(SearchResult {
index: idx,
distance: dist,
});
}
if heap.len() >= k {
tau = heap.peek().expect("heap should have k elements").distance;
}
}
}
return tau;
}
let vp_data = &self.data[node.vantage_point].1.as_f32();
let dist_to_vp = self.config.distance_metric.distance(query, vp_data);
if dist_to_vp < tau {
if heap.len() < k {
heap.push(SearchResult {
index: node.vantage_point,
distance: dist_to_vp,
});
} else if dist_to_vp < heap.peek().expect("heap should have k elements").distance {
heap.pop();
heap.push(SearchResult {
index: node.vantage_point,
distance: dist_to_vp,
});
}
if heap.len() >= k {
tau = heap.peek().expect("heap should have k elements").distance;
}
}
if dist_to_vp < node.median_distance {
if let Some(inside) = &node.inside {
tau = self.search_node(inside, query, k, heap, tau);
}
if dist_to_vp + tau >= node.median_distance {
if let Some(outside) = &node.outside {
tau = self.search_node(outside, query, k, heap, tau);
}
}
} else {
if let Some(outside) = &node.outside {
tau = self.search_node(outside, query, k, heap, tau);
}
if dist_to_vp - tau <= node.median_distance {
if let Some(inside) = &node.inside {
tau = self.search_node(inside, query, k, heap, tau);
}
}
}
tau
}
}
pub struct CoverTree {
root: Option<Box<CoverNode>>,
data: Vec<(String, Vector)>,
config: TreeIndexConfig,
base: f32,
}
struct CoverNode {
point: usize,
level: i32,
#[allow(clippy::vec_box)] children: Vec<Box<CoverNode>>,
}
impl CoverTree {
pub fn new(config: TreeIndexConfig) -> Self {
Self {
root: None,
data: Vec::new(),
config,
base: 2.0, }
}
pub fn build(&mut self) -> Result<()> {
if self.data.is_empty() {
return Ok(());
}
self.root = Some(Box::new(CoverNode {
point: 0,
level: self.get_level(0),
children: Vec::new(),
}));
for idx in 1..self.data.len() {
self.insert(idx)?;
}
Ok(())
}
fn get_level(&self, _point_idx: usize) -> i32 {
((self.data.len() as f32).log2() as i32).max(0)
}
fn insert(&mut self, point_idx: usize) -> Result<()> {
let level = self.get_level(point_idx);
if let Some(root) = &mut self.root {
root.children.push(Box::new(CoverNode {
point: point_idx,
level,
children: Vec::new(),
}));
}
Ok(())
}
pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
if self.root.is_none() {
return Vec::new();
}
let mut results = Vec::new();
self.search_node(
self.root
.as_ref()
.expect("tree should have root after build"),
query,
k,
&mut results,
);
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
results.truncate(k);
results
}
#[allow(clippy::only_used_in_recursion)]
fn search_node(
&self,
node: &CoverNode,
query: &[f32],
k: usize,
results: &mut Vec<(usize, f32)>,
) {
if results.len() >= k * 10 {
return;
}
let point_data = &self.data[node.point].1.as_f32();
let dist = self.config.distance_metric.distance(query, point_data);
results.push((node.point, dist));
for child in &node.children {
self.search_node(child, query, k, results);
}
}
}
pub struct RandomProjectionTree {
root: Option<Box<RpNode>>,
data: Vec<(String, Vector)>,
config: TreeIndexConfig,
}
struct RpNode {
projection: Vec<f32>,
threshold: f32,
left: Option<Box<RpNode>>,
right: Option<Box<RpNode>>,
indices: Vec<usize>,
}
impl RandomProjectionTree {
pub fn new(config: TreeIndexConfig) -> Self {
Self {
root: None,
data: Vec::new(),
config,
}
}
pub fn build(&mut self) -> Result<()> {
if self.data.is_empty() {
return Ok(());
}
let indices: Vec<usize> = (0..self.data.len()).collect();
let dimensions = self.data[0].1.dimensions;
let mut rng = if let Some(seed) = self.config.random_seed {
Random::seed(seed)
} else {
Random::seed(42)
};
self.root = Some(Box::new(self.build_node(indices, dimensions, &mut rng)?));
Ok(())
}
fn build_node<R: Rng>(
&self,
indices: Vec<usize>,
dimensions: usize,
rng: &mut R,
) -> Result<RpNode> {
self.build_node_safe(indices, dimensions, rng, 0)
}
#[allow(deprecated)]
fn build_node_safe<R: Rng>(
&self,
indices: Vec<usize>,
dimensions: usize,
rng: &mut R,
depth: usize,
) -> Result<RpNode> {
if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= 5 {
return Ok(RpNode {
projection: Vec::new(),
threshold: 0.0,
left: None,
right: None,
indices,
});
}
let projection: Vec<f32> = (0..dimensions)
.map(|_| rng.random_range(-1.0..1.0))
.collect();
let norm = (projection.iter().map(|&x| x * x).sum::<f32>()).sqrt();
let projection: Vec<f32> = if norm > 0.0 {
projection.iter().map(|&x| x / norm).collect()
} else {
projection
};
let mut projections: Vec<(f32, usize)> = indices
.iter()
.map(|&idx| {
let point = &self.data[idx].1.as_f32();
let proj_val = f32::dot(point, &projection);
(proj_val, idx)
})
.collect();
projections.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
let median_idx = projections.len() / 2;
let threshold = projections[median_idx].0;
let left_indices: Vec<usize> = projections[..median_idx]
.iter()
.map(|(_, idx)| *idx)
.collect();
let right_indices: Vec<usize> = projections[median_idx..]
.iter()
.map(|(_, idx)| *idx)
.collect();
if left_indices.is_empty() || right_indices.is_empty() {
return Ok(RpNode {
projection: Vec::new(),
threshold: 0.0,
left: None,
right: None,
indices,
});
}
let left = Some(Box::new(self.build_node_safe(
left_indices,
dimensions,
rng,
depth + 1,
)?));
let right = Some(Box::new(self.build_node_safe(
right_indices,
dimensions,
rng,
depth + 1,
)?));
Ok(RpNode {
projection,
threshold,
left,
right,
indices: Vec::new(),
})
}
pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
if self.root.is_none() {
return Vec::new();
}
let mut heap = BinaryHeap::new();
self.search_node(
self.root
.as_ref()
.expect("tree should have root after build"),
query,
k,
&mut heap,
);
let mut results: Vec<(usize, f32)> =
heap.into_iter().map(|r| (r.index, r.distance)).collect();
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
results
}
fn search_node(
&self,
node: &RpNode,
query: &[f32],
k: usize,
heap: &mut BinaryHeap<SearchResult>,
) {
if !node.indices.is_empty() {
for &idx in &node.indices {
let point = &self.data[idx].1.as_f32();
let dist = self.config.distance_metric.distance(query, point);
if heap.len() < k {
heap.push(SearchResult {
index: idx,
distance: dist,
});
} else if dist < heap.peek().expect("heap should have k elements").distance {
heap.pop();
heap.push(SearchResult {
index: idx,
distance: dist,
});
}
}
return;
}
let query_projection = f32::dot(query, &node.projection);
let go_left = query_projection <= node.threshold;
let (first, second) = if go_left {
(&node.left, &node.right)
} else {
(&node.right, &node.left)
};
if let Some(child) = first {
self.search_node(child, query, k, heap);
}
if let Some(child) = second {
self.search_node(child, query, k, heap);
}
}
}
pub struct TreeIndex {
tree_type: TreeType,
ball_tree: Option<BallTree>,
kd_tree: Option<KdTree>,
vp_tree: Option<VpTree>,
cover_tree: Option<CoverTree>,
rp_tree: Option<RandomProjectionTree>,
}
impl TreeIndex {
pub fn new(config: TreeIndexConfig) -> Self {
let tree_type = config.tree_type;
let (ball_tree, kd_tree, vp_tree, cover_tree, rp_tree) = match tree_type {
TreeType::BallTree => (Some(BallTree::new(config)), None, None, None, None),
TreeType::KdTree => (None, Some(KdTree::new(config)), None, None, None),
TreeType::VpTree => (None, None, Some(VpTree::new(config)), None, None),
TreeType::CoverTree => (None, None, None, Some(CoverTree::new(config)), None),
TreeType::RandomProjectionTree => (
None,
None,
None,
None,
Some(RandomProjectionTree::new(config)),
),
};
Self {
tree_type,
ball_tree,
kd_tree,
vp_tree,
cover_tree,
rp_tree,
}
}
pub fn build(&mut self) -> Result<()> {
match self.tree_type {
TreeType::BallTree => self
.ball_tree
.as_mut()
.expect("ball_tree should be initialized for BallTree type")
.build(),
TreeType::KdTree => self
.kd_tree
.as_mut()
.expect("kd_tree should be initialized for KdTree type")
.build(),
TreeType::VpTree => self
.vp_tree
.as_mut()
.expect("vp_tree should be initialized for VpTree type")
.build(),
TreeType::CoverTree => self
.cover_tree
.as_mut()
.expect("cover_tree should be initialized for CoverTree type")
.build(),
TreeType::RandomProjectionTree => self
.rp_tree
.as_mut()
.expect("rp_tree should be initialized for RandomProjectionTree type")
.build(),
}
}
fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
match self.tree_type {
TreeType::BallTree => self
.ball_tree
.as_ref()
.expect("ball_tree should be initialized for BallTree type")
.search(query, k),
TreeType::KdTree => self
.kd_tree
.as_ref()
.expect("kd_tree should be initialized for KdTree type")
.search(query, k),
TreeType::VpTree => self
.vp_tree
.as_ref()
.expect("vp_tree should be initialized for VpTree type")
.search(query, k),
TreeType::CoverTree => self
.cover_tree
.as_ref()
.expect("cover_tree should be initialized for CoverTree type")
.search(query, k),
TreeType::RandomProjectionTree => self
.rp_tree
.as_ref()
.expect("rp_tree should be initialized for RandomProjectionTree type")
.search(query, k),
}
}
}
impl VectorIndex for TreeIndex {
fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
let data = match self.tree_type {
TreeType::BallTree => {
&mut self
.ball_tree
.as_mut()
.expect("ball_tree should be initialized for BallTree type")
.data
}
TreeType::KdTree => {
&mut self
.kd_tree
.as_mut()
.expect("kd_tree should be initialized for KdTree type")
.data
}
TreeType::VpTree => {
&mut self
.vp_tree
.as_mut()
.expect("vp_tree should be initialized for VpTree type")
.data
}
TreeType::CoverTree => {
&mut self
.cover_tree
.as_mut()
.expect("cover_tree should be initialized for CoverTree type")
.data
}
TreeType::RandomProjectionTree => {
&mut self
.rp_tree
.as_mut()
.expect("rp_tree should be initialized for RandomProjectionTree type")
.data
}
};
data.push((uri, vector));
Ok(())
}
fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
let query_f32 = query.as_f32();
let results = self.search_internal(&query_f32, k);
let data = match self.tree_type {
TreeType::BallTree => {
&self
.ball_tree
.as_ref()
.expect("ball_tree should be initialized for BallTree type")
.data
}
TreeType::KdTree => {
&self
.kd_tree
.as_ref()
.expect("kd_tree should be initialized for KdTree type")
.data
}
TreeType::VpTree => {
&self
.vp_tree
.as_ref()
.expect("vp_tree should be initialized for VpTree type")
.data
}
TreeType::CoverTree => {
&self
.cover_tree
.as_ref()
.expect("cover_tree should be initialized for CoverTree type")
.data
}
TreeType::RandomProjectionTree => {
&self
.rp_tree
.as_ref()
.expect("rp_tree should be initialized for RandomProjectionTree type")
.data
}
};
Ok(results
.into_iter()
.map(|(idx, dist)| (data[idx].0.clone(), dist))
.collect())
}
fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
let query_f32 = query.as_f32();
let all_results = self.search_internal(&query_f32, 1000);
let data = match self.tree_type {
TreeType::BallTree => {
&self
.ball_tree
.as_ref()
.expect("ball_tree should be initialized for BallTree type")
.data
}
TreeType::KdTree => {
&self
.kd_tree
.as_ref()
.expect("kd_tree should be initialized for KdTree type")
.data
}
TreeType::VpTree => {
&self
.vp_tree
.as_ref()
.expect("vp_tree should be initialized for VpTree type")
.data
}
TreeType::CoverTree => {
&self
.cover_tree
.as_ref()
.expect("cover_tree should be initialized for CoverTree type")
.data
}
TreeType::RandomProjectionTree => {
&self
.rp_tree
.as_ref()
.expect("rp_tree should be initialized for RandomProjectionTree type")
.data
}
};
Ok(all_results
.into_iter()
.filter(|(_, dist)| *dist <= threshold)
.map(|(idx, dist)| (data[idx].0.clone(), dist))
.collect())
}
fn get_vector(&self, uri: &str) -> Option<&Vector> {
let data = match self.tree_type {
TreeType::BallTree => {
&self
.ball_tree
.as_ref()
.expect("ball_tree should be initialized for BallTree type")
.data
}
TreeType::KdTree => {
&self
.kd_tree
.as_ref()
.expect("kd_tree should be initialized for KdTree type")
.data
}
TreeType::VpTree => {
&self
.vp_tree
.as_ref()
.expect("vp_tree should be initialized for VpTree type")
.data
}
TreeType::CoverTree => {
&self
.cover_tree
.as_ref()
.expect("cover_tree should be initialized for CoverTree type")
.data
}
TreeType::RandomProjectionTree => {
&self
.rp_tree
.as_ref()
.expect("rp_tree should be initialized for RandomProjectionTree type")
.data
}
};
data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
}
}
async fn spawn_task<F, T>(f: F) -> T
where
F: FnOnce() -> T + Send + 'static,
T: Send + 'static,
{
f()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
#[ignore = "Tree indices are experimental - see module documentation for alternatives"]
fn test_ball_tree() -> Result<()> {
let config = TreeIndexConfig {
tree_type: TreeType::BallTree,
max_leaf_size: 10,
..Default::default()
};
let mut ball_tree = BallTree::new(config);
for i in 0..100 {
let vector = Vector::new(vec![i as f32, (i * 2) as f32]);
ball_tree.data.push((format!("vec_{i}"), vector));
}
ball_tree.build()?;
assert!(ball_tree.root.is_some());
let query = vec![50.0, 100.0];
let results = ball_tree.search(&query, 5);
assert!(results.len() <= 5);
assert!(!results.is_empty());
Ok(())
}
#[test]
#[ignore = "Investigating stack overflow with recursive tree construction"]
fn test_kd_tree() -> Result<()> {
let config = TreeIndexConfig {
tree_type: TreeType::KdTree,
max_leaf_size: 50, ..Default::default()
};
let mut index = TreeIndex::new(config);
for i in 0..3 {
let vector = Vector::new(vec![i as f32, (3 - i) as f32]);
index.insert(format!("vec_{i}"), vector)?;
}
index.build()?;
let query = Vector::new(vec![1.0, 2.0]);
let results = index.search_knn(&query, 2)?;
assert_eq!(results.len(), 2);
Ok(())
}
#[test]
#[ignore = "Investigating stack overflow with recursive tree construction"]
fn test_vp_tree() -> Result<()> {
let config = TreeIndexConfig {
tree_type: TreeType::VpTree,
random_seed: Some(42),
max_leaf_size: 50, ..Default::default()
};
let mut index = TreeIndex::new(config);
for i in 0..3 {
let angle = (i as f32) * std::f32::consts::PI / 4.0;
let vector = Vector::new(vec![angle.cos(), angle.sin()]);
index.insert(format!("vec_{i}"), vector)?;
}
index.build()?;
let query = Vector::new(vec![1.0, 0.0]);
let results = index.search_knn(&query, 2)?;
assert_eq!(results.len(), 2);
Ok(())
}
}