use crate::simd;
use crate::RetrieveError;
/// Ball-tree index over fixed-dimension `f32` vectors.
///
/// Usage, as implemented by the methods on this type: append vectors with
/// `add`, call `build` once, then query with `search`.
pub struct BallTreeIndex {
/// Every added vector, concatenated in insertion order (`dimension` floats each).
pub(crate) vectors: Vec<f32>,
/// Number of components every stored vector and every query must have.
pub(crate) dimension: usize,
/// Count of vectors appended so far.
pub(crate) num_vectors: usize,
/// Leaf-size and depth limits applied while building the tree.
params: BallTreeParams,
/// Set once `build` has succeeded; `add` is rejected afterwards.
built: bool,
/// Root of the built tree; `None` until `build` succeeds.
root: Option<BallNode>,
}
/// Construction limits for the ball tree.
#[derive(Clone, Debug)]
pub struct BallTreeParams {
/// A node holding at most this many vectors is stored as a leaf instead of being split.
pub max_leaf_size: usize,
/// A node at this depth (the root is depth 0) or deeper is stored as a leaf regardless of its size.
pub max_depth: usize,
}
impl Default for BallTreeParams {
    /// Default limits: leaves of up to 10 vectors and a tree depth of at most 32.
    fn default() -> Self {
        let max_leaf_size = 10;
        let max_depth = 32;
        Self {
            max_leaf_size,
            max_depth,
        }
    }
}
/// Node of the ball tree. Every node stores a bounding ball (centre and
/// radius) for the vectors beneath it.
enum BallNode {
/// Inner node with exactly two children.
Internal {
/// Component-wise mean of the vectors in this subtree.
center: Vec<f32>,
/// Largest Euclidean distance from `center` to any vector in this subtree.
radius: f32,
/// Child holding the vectors that were strictly closer to the first split seed.
left: Box<BallNode>,
/// Child holding the remaining vectors.
right: Box<BallNode>,
},
/// Terminal node that lists its vectors directly.
Leaf {
/// Positions of the member vectors (insertion order, as produced by `build`).
indices: Vec<u32>,
/// Component-wise mean of the member vectors.
center: Vec<f32>,
/// Largest Euclidean distance from `center` to any member vector.
radius: f32,
},
}
impl BallTreeIndex {
    /// Creates an empty, unbuilt index for vectors with `dimension` components.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::Other` if `dimension` is zero.
    pub fn new(dimension: usize, params: BallTreeParams) -> Result<Self, RetrieveError> {
        if dimension == 0 {
            return Err(RetrieveError::Other(
                "Dimension must be greater than 0".to_string(),
            ));
        }
        Ok(Self {
            vectors: Vec::new(),
            dimension,
            num_vectors: 0,
            params,
            built: false,
            root: None,
        })
    }

    /// Appends one vector to the index.
    ///
    /// `_doc_id` is currently ignored: `search` identifies vectors by their
    /// zero-based insertion position, not by this id.
    /// NOTE(review): callers that expect their own ids back need a mapping.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::Other` if `embedding` does not have exactly
    /// `dimension` components, or if the index has already been built.
    pub fn add(&mut self, _doc_id: u32, embedding: Vec<f32>) -> Result<(), RetrieveError> {
        if embedding.len() != self.dimension {
            return Err(RetrieveError::Other(format!(
                "Embedding dimension {} != {}",
                embedding.len(),
                self.dimension
            )));
        }
        if self.built {
            return Err(RetrieveError::Other(
                "Cannot add vectors after build".to_string(),
            ));
        }
        self.vectors.extend(embedding);
        self.num_vectors += 1;
        Ok(())
    }

    /// Builds the tree over every vector added so far. Calling it again on a
    /// built index is a no-op.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::EmptyIndex` if no vectors were added, and
    /// `RetrieveError::Other` if there are more vectors than a `u32` position
    /// can address.
    pub fn build(&mut self) -> Result<(), RetrieveError> {
        if self.built {
            return Ok(());
        }
        if self.num_vectors == 0 {
            return Err(RetrieveError::EmptyIndex);
        }
        // Positions are stored as u32; refuse to truncate silently.
        if self.num_vectors > u32::MAX as usize {
            return Err(RetrieveError::Other(format!(
                "Too many vectors for u32 indices: {}",
                self.num_vectors
            )));
        }
        let indices: Vec<u32> = (0..self.num_vectors as u32).collect();
        self.root = Some(self.build_tree(&indices, 0)?);
        self.built = true;
        Ok(())
    }

    /// Recursively builds the subtree covering `indices`, where `depth` is the
    /// depth of the node being created (root = 0).
    ///
    /// A node becomes a leaf when it is small enough or `max_depth` is reached.
    /// Otherwise its vectors are split by which of the two mutually farthest
    /// members ("seeds") they are closer to.
    fn build_tree(&self, indices: &[u32], depth: usize) -> Result<BallNode, RetrieveError> {
        if indices.is_empty() {
            return Err(RetrieveError::Other("Empty indices".to_string()));
        }
        let center = self.compute_center(indices);
        let radius = self.compute_radius(indices, &center);

        // A single vector cannot be split in two, so a leaf always holds at
        // least one vector even when `max_leaf_size` is 0.
        let leaf_capacity = self.params.max_leaf_size.max(1);
        if indices.len() <= leaf_capacity || depth >= self.params.max_depth {
            return Ok(BallNode::Leaf {
                indices: indices.to_vec(),
                center,
                radius,
            });
        }

        let (seed1_idx, seed2_idx) = self.find_farthest_pair(indices);
        let seed1 = self.get_vector(seed1_idx as usize);
        let seed2 = self.get_vector(seed2_idx as usize);
        let (mut left_indices, mut right_indices): (Vec<u32>, Vec<u32>) =
            indices.iter().copied().partition(|&idx| {
                let vec = self.get_vector(idx as usize);
                self.euclidean_distance(vec, seed1) < self.euclidean_distance(vec, seed2)
            });

        // Degenerate split (e.g. all vectors identical): move one vector across
        // so both children are non-empty. At least two vectors exist here.
        if left_indices.is_empty() {
            if let Some(moved) = right_indices.pop() {
                left_indices.push(moved);
            }
        } else if right_indices.is_empty() {
            if let Some(moved) = left_indices.pop() {
                right_indices.push(moved);
            }
        }

        let left = self.build_tree(&left_indices, depth + 1)?;
        let right = self.build_tree(&right_indices, depth + 1)?;
        Ok(BallNode::Internal {
            center,
            radius,
            left: Box::new(left),
            right: Box::new(right),
        })
    }

    /// Returns the component-wise mean of the vectors at `indices`.
    fn compute_center(&self, indices: &[u32]) -> Vec<f32> {
        let mut center = vec![0.0f32; self.dimension];
        for &idx in indices {
            for (acc, &val) in center.iter_mut().zip(self.get_vector(idx as usize)) {
                *acc += val;
            }
        }
        let count = indices.len() as f32;
        for val in &mut center {
            *val /= count;
        }
        center
    }

    /// Returns the largest Euclidean distance from `center` to any vector at `indices`.
    fn compute_radius(&self, indices: &[u32], center: &[f32]) -> f32 {
        indices
            .iter()
            .map(|&idx| self.euclidean_distance(self.get_vector(idx as usize), center))
            .fold(0.0f32, f32::max)
    }

    /// Returns the two members of `indices` with the largest Euclidean distance
    /// between them, or `(first, first)` if no pair has a positive distance.
    ///
    /// Compares every pair, so the cost is quadratic in `indices.len()`.
    /// Panics on an empty slice; `build_tree` never passes one.
    fn find_farthest_pair(&self, indices: &[u32]) -> (u32, u32) {
        let mut max_dist = 0.0f32;
        let mut pair = (indices[0], indices[0]);
        for (i, &first) in indices.iter().enumerate() {
            let first_vec = self.get_vector(first as usize);
            for &second in &indices[i + 1..] {
                let dist = self.euclidean_distance(first_vec, self.get_vector(second as usize));
                if dist > max_dist {
                    max_dist = dist;
                    pair = (first, second);
                }
            }
        }
        pair
    }

    /// Returns up to `k` nearest stored vectors to `query` as
    /// `(insertion index, cosine distance)` pairs, closest first.
    ///
    /// Distances come from `crate::distance::cosine_distance_normalized`.
    /// NOTE(review): presumably that helper expects unit-length inputs, so the
    /// query and stored vectors should be normalised — confirm with callers.
    ///
    /// # Errors
    ///
    /// Returns `RetrieveError::Other` if the index has not been built or if
    /// `query` does not have `dimension` components.
    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
        if !self.built {
            return Err(RetrieveError::Other("Index not built".to_string()));
        }
        if query.len() != self.dimension {
            return Err(RetrieveError::Other(format!(
                "Query dimension {} != {}",
                query.len(),
                self.dimension
            )));
        }
        let root = self
            .root
            .as_ref()
            .ok_or_else(|| RetrieveError::Other("Tree not built".to_string()))?;
        if k == 0 {
            return Ok(Vec::new());
        }
        // At most `num_vectors` hits exist; clamping avoids a capacity
        // overflow panic when `k` is huge.
        let mut best_k: Vec<(f32, u32)> = Vec::with_capacity(k.min(self.num_vectors));
        let mut best_dist = f32::INFINITY;
        self.search_recursive_pruned(root, query, k, &mut best_k, &mut best_dist)?;
        let mut results: Vec<(u32, f32)> = best_k.into_iter().map(|(d, idx)| (idx, d)).collect();
        results.sort_by(|a, b| a.1.total_cmp(&b.1));
        Ok(results)
    }

    /// Depth-first search of `node` that keeps the `k` best candidates in
    /// `best_k` as `(cosine distance, index)` pairs.
    ///
    /// `best_dist` is the worst distance currently in `best_k` once it holds
    /// `k` entries, and infinity before that. A subtree is skipped when its
    /// lower bound exceeds `best_dist`.
    fn search_recursive_pruned(
        &self,
        node: &BallNode,
        query: &[f32],
        k: usize,
        best_k: &mut Vec<(f32, u32)>,
        best_dist: &mut f32,
    ) -> Result<(), RetrieveError> {
        if self.cosine_lower_bound(query, node) > *best_dist {
            return Ok(());
        }
        match node {
            BallNode::Leaf { indices, .. } => {
                for &idx in indices {
                    let dist = self.cosine_distance(query, self.get_vector(idx as usize));
                    if best_k.len() < k {
                        best_k.push((dist, idx));
                        if best_k.len() == k {
                            *best_dist = Self::worst_distance(best_k.as_slice());
                        }
                    } else if dist < *best_dist {
                        // Replace the current worst candidate.
                        let worst_slot = best_k
                            .iter()
                            .enumerate()
                            .max_by(|a, b| a.1 .0.total_cmp(&b.1 .0))
                            .map(|(slot, _)| slot);
                        if let Some(slot) = worst_slot {
                            best_k[slot] = (dist, idx);
                            *best_dist = Self::worst_distance(best_k.as_slice());
                        }
                    }
                }
            }
            BallNode::Internal { left, right, .. } => {
                // Visit the more promising child first so `best_dist` shrinks
                // early and the other child is more likely to be pruned.
                let left_bound = self.cosine_lower_bound(query, left);
                let right_bound = self.cosine_lower_bound(query, right);
                let (near, far) = if left_bound < right_bound {
                    (left, right)
                } else {
                    (right, left)
                };
                self.search_recursive_pruned(near, query, k, best_k, best_dist)?;
                self.search_recursive_pruned(far, query, k, best_k, best_dist)?;
            }
        }
        Ok(())
    }

    /// Lower bound, on the cosine-distance scale, for the distance from `query`
    /// to any vector inside `node`'s bounding ball.
    ///
    /// The ball is Euclidean, so the geometric bound is
    /// `gap = max(dist(query, center) - radius, 0)`. Candidates are ranked by
    /// cosine distance, and for unit-length vectors
    /// `1 - a·b = |a - b|² / 2`, hence the conversion.
    /// NOTE(review): assumes `cosine_distance_normalized` is `1 - a·b` on
    /// unit-length inputs — confirm against `crate::distance`.
    fn cosine_lower_bound(&self, query: &[f32], node: &BallNode) -> f32 {
        let (center, radius) = match node {
            BallNode::Internal { center, radius, .. } => (center, *radius),
            BallNode::Leaf { center, radius, .. } => (center, *radius),
        };
        let gap = (self.euclidean_distance(query, center) - radius).max(0.0);
        0.5 * gap * gap
    }

    /// Returns the largest distance among `candidates`, or negative infinity if empty.
    fn worst_distance(candidates: &[(f32, u32)]) -> f32 {
        candidates
            .iter()
            .map(|&(d, _)| d)
            .fold(f32::NEG_INFINITY, f32::max)
    }

    /// Returns the stored vector at insertion position `idx`.
    /// Panics if `idx` is out of range.
    fn get_vector(&self, idx: usize) -> &[f32] {
        let start = idx * self.dimension;
        let end = start + self.dimension;
        &self.vectors[start..end]
    }

    /// Euclidean distance computed as `sqrt(a·a + b·b - 2 a·b)`.
    fn euclidean_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        use crate::simd;
        let squared = simd::dot(a, a) + simd::dot(b, b) - 2.0 * simd::dot(a, b);
        // Rounding can push the expansion slightly below zero for nearly equal
        // vectors; clamp so the result is 0 rather than NaN. A NaN input still
        // yields NaN.
        if squared < 0.0 {
            0.0
        } else {
            squared.sqrt()
        }
    }

    /// Cosine distance as defined by `crate::distance::cosine_distance_normalized`.
    fn cosine_distance(&self, a: &[f32], b: &[f32]) -> f32 {
        crate::distance::cosine_distance_normalized(a, b)
    }
}