use crate::simd;
use crate::RetrieveError;
/// Approximate nearest-neighbour index made of several random-projection trees
/// built over one shared pool of vectors.
pub struct AnnoyIndex {
    /// All stored vectors laid end to end: vector `i` occupies
    /// `vectors[i * dimension..(i + 1) * dimension]`.
    pub(crate) vectors: Vec<f32>,
    /// Number of components every stored vector and every query must have.
    pub(crate) dimension: usize,
    /// How many vectors have been appended to `vectors` so far.
    pub(crate) num_vectors: usize,
    /// Build-time configuration (tree count and per-tree settings).
    params: AnnoyParams,
    /// Becomes `true` once `build` has succeeded; `add` is refused from then
    /// on and `search` is refused before.
    built: bool,
    /// The trees produced by `build`; empty until then.
    pub(crate) trees: Vec<RPTree>,
}
/// Configuration for an [`AnnoyIndex`].
#[derive(Clone, Debug)]
pub struct AnnoyParams {
    /// Number of random-projection trees that `AnnoyIndex::build` constructs.
    pub num_trees: usize,
    /// Settings applied to every individual tree.
    pub tree_params: RPTreeParams,
}
impl Default for AnnoyParams {
fn default() -> Self {
Self {
num_trees: 10,
tree_params: RPTreeParams::default(),
}
}
}
/// A single random-projection tree of the forest.
pub(crate) struct RPTree {
    /// Top node of the tree; `None` when the tree was built over no vectors.
    root: Option<TreeNode>,
}
/// Node of a random-projection tree.
enum TreeNode {
    /// Terminal node listing the vectors that ended up in this region.
    Leaf {
        /// Positions of the vectors in the index's flat storage.
        indices: Vec<u32>,
    },
    /// Split node: a vector whose projection onto `hyperplane` is negative
    /// belongs to `left`, any other vector belongs to `right`.
    Internal {
        /// Direction the vectors are projected onto.
        hyperplane: Vec<f32>,
        /// Stored split offset. The builder in this file always writes `0.0`
        /// and the search in this file does not read it.
        threshold: f32,
        /// Subtree for negative projections.
        left: Box<TreeNode>,
        /// Subtree for zero or positive projections.
        right: Box<TreeNode>,
    },
}
/// Settings for one random-projection tree.
#[derive(Clone, Debug)]
pub struct RPTreeParams {
    /// Node size at or below which the tree builder stops splitting and
    /// stores the remaining vectors in a leaf.
    pub max_leaf_size: usize,
}
impl Default for RPTreeParams {
fn default() -> Self {
Self { max_leaf_size: 10 }
}
}
impl AnnoyIndex {
pub fn new(dimension: usize, params: AnnoyParams) -> Result<Self, RetrieveError> {
if dimension == 0 {
return Err(RetrieveError::EmptyQuery);
}
Ok(Self {
vectors: Vec::new(),
dimension,
num_vectors: 0,
params,
built: false,
trees: Vec::new(),
})
}
pub fn add(&mut self, _doc_id: u32, vector: Vec<f32>) -> Result<(), RetrieveError> {
if self.built {
return Err(RetrieveError::Other(
"Cannot add vectors after index is built".to_string(),
));
}
if vector.len() != self.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: self.dimension,
doc_dim: vector.len(),
});
}
self.vectors.extend_from_slice(&vector);
self.num_vectors += 1;
Ok(())
}
pub fn build(&mut self) -> Result<(), RetrieveError> {
if self.built {
return Ok(());
}
if self.num_vectors == 0 {
return Err(RetrieveError::EmptyIndex);
}
self.trees = Vec::new();
for _ in 0..self.params.num_trees {
let tree = self.build_tree()?;
self.trees.push(tree);
}
self.built = true;
Ok(())
}
fn build_tree(&self) -> Result<RPTree, RetrieveError> {
use rand::Rng;
let mut rng = rand::rng();
let mut hyperplane = Vec::new();
let mut norm = 0.0;
for _ in 0..self.dimension {
let val = rng.random::<f32>() * 2.0 - 1.0;
norm += val * val;
hyperplane.push(val);
}
let norm = norm.sqrt();
if norm > 0.0 {
for val in &mut hyperplane {
*val /= norm;
}
}
let indices: Vec<u32> = (0..self.num_vectors as u32).collect();
let root = self.build_tree_recursive(&indices, &hyperplane, 0)?;
Ok(RPTree { root })
}
fn build_tree_recursive(
&self,
indices: &[u32],
hyperplane: &[f32],
depth: usize,
) -> Result<Option<TreeNode>, RetrieveError> {
if indices.is_empty() {
return Ok(None);
}
if indices.len() <= self.params.tree_params.max_leaf_size {
return Ok(Some(TreeNode::Leaf {
indices: indices.to_vec(),
}));
}
let mut left_indices = Vec::new();
let mut right_indices = Vec::new();
for &idx in indices {
let vec = self.get_vector(idx as usize);
let projection = simd::dot(vec, hyperplane);
if projection < 0.0 {
left_indices.push(idx);
} else {
right_indices.push(idx);
}
}
use rand::Rng;
let mut rng = rand::rng();
let mut new_hyperplane = Vec::new();
let mut norm = 0.0;
for _ in 0..self.dimension {
let val = rng.random::<f32>() * 2.0 - 1.0;
norm += val * val;
new_hyperplane.push(val);
}
let norm = norm.sqrt();
if norm > 0.0 {
for val in &mut new_hyperplane {
*val /= norm;
}
}
let left = self.build_tree_recursive(&left_indices, &new_hyperplane, depth + 1)?;
let right = self.build_tree_recursive(&right_indices, &new_hyperplane, depth + 1)?;
Ok(Some(TreeNode::Internal {
hyperplane: hyperplane.to_vec(),
threshold: 0.0,
left: Box::new(left.unwrap_or(TreeNode::Leaf {
indices: Vec::new(),
})),
right: Box::new(right.unwrap_or(TreeNode::Leaf {
indices: Vec::new(),
})),
}))
}
pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<(u32, f32)>, RetrieveError> {
if !self.built {
return Err(RetrieveError::Other(
"Index must be built before search".to_string(),
));
}
if query.len() != self.dimension {
return Err(RetrieveError::DimensionMismatch {
query_dim: self.dimension,
doc_dim: query.len(),
});
}
let mut candidate_set = std::collections::HashSet::new();
for tree in &self.trees {
if let Some(ref root) = tree.root {
let candidates = self.search_tree(root, query);
for idx in candidates {
candidate_set.insert(idx);
}
}
}
let mut results: Vec<(u32, f32)> = candidate_set
.iter()
.map(|&idx| {
let vec = self.get_vector(idx as usize);
let dist = 1.0 - simd::dot(query, vec);
(idx, dist)
})
.collect();
results.sort_unstable_by(|a, b| a.1.total_cmp(&b.1)); Ok(results.into_iter().take(k).collect())
}
fn search_tree(&self, node: &TreeNode, query: &[f32]) -> Vec<u32> {
match node {
TreeNode::Leaf { indices } => indices.clone(),
TreeNode::Internal {
hyperplane,
threshold: _,
left,
right,
} => {
let projection = simd::dot(query, hyperplane);
if projection < 0.0 {
self.search_tree(left, query)
} else {
self.search_tree(right, query)
}
}
}
}
fn get_vector(&self, idx: usize) -> &[f32] {
let start = idx * self.dimension;
let end = start + self.dimension;
&self.vectors[start..end]
}
}