use faer::{MatRef, RowRef};
use fixedbitset::FixedBitSet;
use rand::{rngs::StdRng, Rng, SeedableRng};
use rayon::prelude::*;
use std::collections::BinaryHeap;
use std::sync::{
atomic::{AtomicUsize, Ordering},
Arc,
};
use thousands::*;
use crate::prelude::*;
use crate::utils::tree_utils::*;
use crate::utils::*;
/// Compact node of the flattened forest stored in `AnnoyIndex::nodes`.
///
/// The meaning of `child_a`, `child_b` and `split_idx` depends on
/// `n_descendants`:
///
/// * `2` marks a split node: `child_a` / `child_b` are indices into the node
///   array and `split_idx` selects the hyperplane in the split data.
/// * `1` marks a leaf: `child_a` is the start offset into the leaf index
///   array, `child_b` is the number of items and `split_idx` is set to 0.
#[derive(Clone, Copy, Debug)]
#[repr(C)]
struct FlatNode {
/// Node kind tag: 1 for a leaf, 2 for a split node.
n_descendants: u32,
/// Split: node index of the child on the positive side of the hyperplane.
/// Leaf: start offset of the leaf's items in the leaf index array.
child_a: u32,
/// Split: node index of the child on the other side of the hyperplane.
/// Leaf: number of items in the leaf.
child_b: u32,
/// Split: index of the hyperplane record (`dim + 1` values each) in the
/// split data. Leaf: 0.
split_idx: u32,
}
/// Node of a single tree while it is being built, before the forest is
/// flattened into `FlatNode`s.
#[derive(Clone)]
enum BuildNode<T> {
/// Internal node that partitions its items with a hyperplane.
Split {
/// Hyperplane normal, one coefficient per dimension.
hyperplane: Vec<T>,
/// Threshold the dot product with `hyperplane` is compared against.
offset: T,
/// Index (within the tree's node list) of the child that received the
/// items whose dot product with `hyperplane` exceeds `offset`.
left: usize,
/// Index (within the tree's node list) of the child that received the
/// remaining items.
right: usize,
},
/// Terminal node.
Leaf {
/// Ids of the vectors stored in this leaf.
items: Vec<usize>,
},
}
/// Forest of random-hyperplane trees over a set of vectors, flattened into
/// contiguous arrays for querying.
pub struct AnnoyIndex<T> {
/// Flattened vectors, `dim` consecutive values per vector. Construction
/// may reorder the vectors relative to the input (see `original_ids`).
pub vectors_flat: Vec<T>,
/// Number of values per vector.
pub dim: usize,
/// Number of indexed vectors.
pub n: usize,
/// L2 norm of each stored vector, in the same order as `vectors_flat`.
/// Only populated for `Dist::Cosine`; empty otherwise.
norms: Vec<T>,
/// Distance metric used for tree construction and querying.
metric: Dist,
/// Nodes of all trees, concatenated tree after tree.
nodes: Vec<FlatNode>,
/// Position in `nodes` of the root of each tree.
roots: Vec<u32>,
/// Hyperplane records of all split nodes: `dim` coefficients followed by
/// the offset, i.e. `dim + 1` values per split.
split_data: Vec<T>,
/// Concatenated item lists of all leaves, expressed as positions in the
/// stored (possibly reordered) vector layout.
leaf_indices: Vec<usize>,
/// Number of trees requested at construction.
pub n_trees: usize,
/// Maps a position in the stored layout back to the row index of the
/// vector in the input data.
original_ids: Vec<usize>,
}
/// Gives the shared distance helpers access to the index's vector storage.
impl<T: AnnSearchFloat> VectorDistance<T> for AnnoyIndex<T> {
    /// Flattened vector storage, `dim` values per vector.
    fn vectors_flat(&self) -> &[T] {
        self.vectors_flat.as_slice()
    }

    /// Number of values per vector.
    fn dim(&self) -> usize {
        self.dim
    }

    /// Stored L2 norms (empty unless the metric is cosine).
    fn norms(&self) -> &[T] {
        self.norms.as_slice()
    }
}
impl<T> AnnoyIndex<T>
where
T: AnnSearchFloat,
{
/// Builds the index: grows `n_trees` random-hyperplane trees in parallel,
/// flattens them into shared arrays and reorders the stored vectors.
///
/// ### Params
///
/// * `data` - Input matrix; `matrix_to_flat` turns it into `n` vectors of
///   `dim` values each.
/// * `n_trees` - Number of trees to grow.
/// * `metric` - Distance metric. L2 norms are precomputed only for
///   `Dist::Cosine`.
/// * `seed` - Seed of the master RNG. One seed per tree is drawn from it
///   up front, and every tree is built with its own RNG.
///
/// ### Returns
///
/// The finished index. Query results are reported with the row indices of
/// `data`, even though the vectors may be stored in a different order.
///
/// ### Panics
///
/// Panics if a root offset, child node index, split index, leaf offset or
/// leaf length of the flattened forest does not fit into the `u32` fields
/// of `FlatNode`.
pub fn new(data: MatRef<T>, n_trees: usize, metric: Dist, seed: usize) -> Self {
    /// Narrows an offset to the `u32` stored in `FlatNode`, panicking
    /// instead of silently wrapping when the forest outgrows 32 bits.
    fn to_u32(value: usize, what: &str) -> u32 {
        assert!(
            value <= u32::MAX as usize,
            "AnnoyIndex: {} ({}) does not fit in u32",
            what,
            value
        );
        value as u32
    }

    let mut rng = StdRng::seed_from_u64(seed as u64);
    let (vectors_flat, n, dim) = matrix_to_flat(data);

    // Norms are only needed for the cosine distance.
    let norms = if metric == Dist::Cosine {
        (0..n)
            .map(|i| {
                let start = i * dim;
                let end = start + dim;
                T::calculate_l2_norm(&vectors_flat[start..end])
            })
            .collect()
    } else {
        Vec::new()
    };

    // Draw all tree seeds sequentially before going parallel, so each tree
    // gets its seed independently of how the work is scheduled.
    let seeds: Vec<u64> = (0..n_trees).map(|_| rng.random()).collect();
    let forest: Vec<Vec<BuildNode<T>>> = seeds
        .into_par_iter()
        .map(|tree_seed| {
            let mut tree_rng = StdRng::seed_from_u64(tree_seed);
            Self::build_tree_recursive(
                &vectors_flat,
                dim,
                (0..n).collect(),
                &mut tree_rng,
                metric,
            )
        })
        .collect();

    let total_nodes: usize = forest.iter().map(|t| t.len()).sum();
    let mut nodes = Vec::with_capacity(total_nodes);
    let mut roots = Vec::with_capacity(n_trees);
    let mut split_data = Vec::new();
    let mut leaf_indices = Vec::new();

    for tree in forest {
        // The nodes of a tree are appended contiguously, so a tree-local
        // index becomes global by adding the position of its first node.
        let root_base = nodes.len();
        roots.push(to_u32(root_base, "root offset"));
        for node in tree {
            match node {
                BuildNode::Split {
                    hyperplane,
                    offset,
                    left,
                    right,
                } => {
                    // Each split record is `dim` coefficients plus the offset.
                    let split_idx = to_u32(split_data.len() / (dim + 1), "split index");
                    split_data.extend(hyperplane);
                    split_data.push(offset);
                    nodes.push(FlatNode {
                        n_descendants: 2,
                        child_a: to_u32(root_base + left, "child node index"),
                        child_b: to_u32(root_base + right, "child node index"),
                        split_idx,
                    });
                }
                BuildNode::Leaf { items } => {
                    // Leaves reuse the child fields as (start, length) into
                    // `leaf_indices`.
                    let start = to_u32(leaf_indices.len(), "leaf offset");
                    let len = to_u32(items.len(), "leaf length");
                    leaf_indices.extend(items);
                    nodes.push(FlatNode {
                        n_descendants: 1,
                        child_a: start,
                        child_b: len,
                        split_idx: 0,
                    });
                }
            }
        }
    }

    let mut idx = AnnoyIndex {
        nodes,
        roots,
        split_data,
        leaf_indices,
        vectors_flat,
        dim,
        n_trees,
        n,
        norms,
        metric,
        original_ids: (0..n).collect(),
    };
    // Reordering returns the stored-position -> input-row mapping.
    let new_to_old = idx.optimise_memory_layout();
    idx.original_ids = new_to_old;
    idx
}
/// Builds one tree over `items` and returns its nodes in creation order.
///
/// The root is the first node pushed, so it sits at index 0 of the returned
/// list; child links inside the list are indices into that same list.
fn build_tree_recursive(
    vectors_flat: &[T],
    dim: usize,
    items: Vec<usize>,
    rng: &mut StdRng,
    metric: Dist,
) -> Vec<BuildNode<T>> {
    let mut tree: Vec<BuildNode<T>> = Vec::with_capacity(items.len());
    let _root = Self::build_node(vectors_flat, dim, items, &mut tree, rng, metric);
    tree
}
/// Recursively builds the subtree for `items`, appending its nodes to
/// `nodes`, and returns the index of the subtree's root within `nodes`.
///
/// Sets of at most `LEAF_MIN_MEMBERS` items become a leaf directly. Larger
/// sets get up to 10 attempts at a split: two distinct items are drawn at
/// random and a hyperplane separating them is derived (difference of the
/// normalised vectors through the origin for cosine, perpendicular bisector
/// for Euclidean). A split is accepted only when both sides are non-empty
/// and the left side holds between 5% and 95% of the items. If no attempt
/// succeeds, the whole set becomes a leaf.
fn build_node(
    vectors_flat: &[T],
    dim: usize,
    items: Vec<usize>,
    nodes: &mut Vec<BuildNode<T>>,
    rng: &mut StdRng,
    metric: Dist,
) -> usize {
    let member_count = items.len();

    if member_count > LEAF_MIN_MEMBERS {
        for _attempt in 0..10 {
            // Both draws always happen, even if they turn out identical.
            let anchor_a = items[rng.random_range(0..member_count)];
            let anchor_b = items[rng.random_range(0..member_count)];
            if anchor_a == anchor_b {
                continue;
            }

            let a = &vectors_flat[anchor_a * dim..anchor_a * dim + dim];
            let b = &vectors_flat[anchor_b * dim..anchor_b * dim + dim];

            let (normal, cutoff) = match metric {
                Dist::Cosine => {
                    let len_a = a.iter().fold(T::zero(), |acc, &x| acc + x * x).sqrt();
                    let len_b = b.iter().fold(T::zero(), |acc, &x| acc + x * x).sqrt();
                    // A zero vector has no direction to split on.
                    if len_a == T::zero() || len_b == T::zero() {
                        continue;
                    }
                    let normal: Vec<T> = a
                        .iter()
                        .zip(b.iter())
                        .map(|(&x, &y)| x / len_a - y / len_b)
                        .collect();
                    (normal, T::zero())
                }
                Dist::Euclidean => {
                    let normal: Vec<T> = a.iter().zip(b.iter()).map(|(&x, &y)| x - y).collect();
                    let sq_a = a.iter().fold(T::zero(), |acc, &x| acc + x * x);
                    let sq_b = b.iter().fold(T::zero(), |acc, &x| acc + x * x);
                    // Points equidistant from both anchors satisfy
                    // x . (a - b) = (|a|^2 - |b|^2) / 2.
                    (normal, (sq_a - sq_b) / T::from_f64(2.0).unwrap())
                }
            };

            // Items strictly above the threshold go left, the rest right.
            let (left_items, right_items): (Vec<usize>, Vec<usize>) =
                items.iter().copied().partition(|&item| {
                    let v = &vectors_flat[item * dim..item * dim + dim];
                    let projection = v
                        .iter()
                        .zip(normal.iter())
                        .fold(T::zero(), |acc, (&x, &w)| acc + x * w);
                    projection > cutoff
                });

            if left_items.is_empty() || right_items.is_empty() {
                continue;
            }
            let left_share = left_items.len() as f64 / member_count as f64;
            if !(0.05..=0.95).contains(&left_share) {
                continue;
            }

            // Reserve the split's slot before recursing so the parent
            // precedes its descendants; the child links are patched after.
            let split_slot = nodes.len();
            nodes.push(BuildNode::Split {
                hyperplane: normal,
                offset: cutoff,
                left: 0,
                right: 0,
            });
            let left_child = Self::build_node(vectors_flat, dim, left_items, nodes, rng, metric);
            let right_child =
                Self::build_node(vectors_flat, dim, right_items, nodes, rng, metric);
            if let BuildNode::Split { left, right, .. } = &mut nodes[split_slot] {
                *left = left_child;
                *right = right_child;
            }
            return split_slot;
        }
    }

    // Small enough, or no acceptable split found: keep everything in a leaf.
    let leaf_slot = nodes.len();
    nodes.push(BuildNode::Leaf { items });
    leaf_slot
}
/// Signed margin of `v1` relative to a hyperplane record.
///
/// `v2` holds the `dim` hyperplane coefficients followed by the offset at
/// position `dim`; the result is the dot product of `v1` with the
/// coefficients minus that offset.
#[inline(always)]
fn get_margin(v1: &[T], v2: &[T], dim: usize) -> T {
    let projection = T::dot_simd(v1, &v2[..dim]);
    projection - v2[dim]
}
/// Reorders the stored vectors (and norms) so that vectors sharing a leaf
/// of the first tree are adjacent in memory, and rewrites `leaf_indices`
/// to the new positions.
///
/// ### Returns
///
/// The permutation that was applied: element `new` holds the previous
/// position of the vector now stored at `new`. When the index has no trees
/// nothing is moved and the identity permutation is returned, so the
/// result always has `n` entries and can be used as an id mapping.
fn optimise_memory_layout(&mut self) -> Vec<usize> {
    if self.n == 0 {
        return Vec::new();
    }
    let Some(&first_root) = self.roots.first() else {
        // No tree to derive an order from: the layout stays as it is, which
        // corresponds to the identity permutation (not an empty one).
        return (0..self.n).collect();
    };

    let mut new_to_old: Vec<usize> = Vec::with_capacity(self.n);
    // `usize::MAX` marks ids that have not been assigned a new position;
    // real positions are always smaller than `n`.
    let mut old_to_new = vec![usize::MAX; self.n];

    // Depth-first walk of the first tree, visiting `child_a` before
    // `child_b`, assigns positions in leaf order.
    let mut stack = vec![first_root];
    while let Some(node_idx) = stack.pop() {
        let node = self.nodes[node_idx as usize];
        if node.n_descendants == 1 {
            let start = node.child_a as usize;
            let len = node.child_b as usize;
            for &old_id in &self.leaf_indices[start..start + len] {
                if old_to_new[old_id] == usize::MAX {
                    old_to_new[old_id] = new_to_old.len();
                    new_to_old.push(old_id);
                }
            }
        } else {
            stack.push(node.child_b);
            stack.push(node.child_a);
        }
    }

    // Any id the walk did not reach is appended in its original order.
    for (old_id, slot) in old_to_new.iter_mut().enumerate() {
        if *slot == usize::MAX {
            *slot = new_to_old.len();
            new_to_old.push(old_id);
        }
    }

    let mut new_vectors_flat = Vec::with_capacity(self.vectors_flat.len());
    for &old_id in &new_to_old {
        let start = old_id * self.dim;
        new_vectors_flat.extend_from_slice(&self.vectors_flat[start..start + self.dim]);
    }
    // Norms exist only for the cosine metric; keep them aligned with the
    // vectors when present.
    if !self.norms.is_empty() {
        self.norms = new_to_old.iter().map(|&old_id| self.norms[old_id]).collect();
    }

    // Leaves of every tree must refer to the new positions.
    for id in self.leaf_indices.iter_mut() {
        *id = old_to_new[*id];
    }

    self.vectors_flat = new_vectors_flat;
    new_to_old
}
/// Finds approximate nearest neighbours of `query_vec`.
///
/// All trees are searched together through one priority queue of branches
/// still to explore. Each pop descends to a leaf on the query's side of
/// every hyperplane and queues the opposite branch, prioritised by how
/// close the query is to that hyperplane. Exact distances are computed for
/// every distinct item found in the visited leaves.
///
/// ### Params
///
/// * `query_vec` - Query vector; expected to have `dim` values.
/// * `k` - Maximum number of neighbours to return.
/// * `search_k` - Budget of leaf items to inspect before stopping
///   (duplicates across trees count). Defaults to `k * n_trees * 20`.
///
/// ### Returns
///
/// `(ids, distances)` sorted by ascending distance, with at most `k`
/// entries. Ids are row indices of the data the index was built from.
/// Both vectors are empty when `k` is 0.
#[inline]
pub fn query(
    &self,
    query_vec: &[T],
    k: usize,
    search_k: Option<usize>,
) -> (Vec<usize>, Vec<T>) {
    // Nothing requested. Returning here also keeps `k - 1` below from
    // underflowing when a caller combines `k == 0` with an explicit budget.
    if k == 0 {
        return (Vec::new(), Vec::new());
    }

    let limit =
        search_k.unwrap_or_else(|| k.saturating_mul(self.n_trees).saturating_mul(20));
    let mut visited_count = 0usize;
    let mut visited = FixedBitSet::with_capacity(self.n);

    let query_norm = if self.metric == Dist::Cosine {
        query_vec
            .iter()
            .map(|&x| x * x)
            .fold(T::zero(), |acc, x| acc + x)
            .sqrt()
    } else {
        T::one()
    };

    // Every item is pushed at most once (guarded by `visited`), so `n`
    // bounds the size no matter how large the requested budget is.
    let mut candidates: Vec<(T, usize)> = Vec::with_capacity(limit.min(self.n));
    let mut pq = BinaryHeap::with_capacity(self.n_trees * 2);
    for &root in &self.roots {
        // Roots get the highest priority so every tree is entered first.
        pq.push(BacktrackEntry {
            margin: f64::MAX,
            node_idx: root,
        });
    }

    while visited_count < limit {
        let Some(entry) = pq.pop() else { break };
        let mut current_idx = entry.node_idx;
        loop {
            // SAFETY: node indices only come from `roots` and from the
            // `child_a` / `child_b` fields of split nodes, all of which
            // `new` writes as positions of nodes it pushed into `nodes`.
            let node = unsafe { self.nodes.get_unchecked(current_idx as usize) };
            if node.n_descendants == 1 {
                let start = node.child_a as usize;
                let len = node.child_b as usize;
                visited_count += len;
                // SAFETY: for leaves `new` stores the start offset and
                // length of the run it appended to `leaf_indices`.
                let leaf_items =
                    unsafe { self.leaf_indices.get_unchecked(start..start + len) };
                for &item in leaf_items {
                    if visited.contains(item) {
                        continue;
                    }
                    visited.insert(item);
                    let dist = match self.metric {
                        Dist::Euclidean => self.euclidean_distance_to_query(item, query_vec),
                        Dist::Cosine => {
                            self.cosine_distance_to_query(item, query_vec, query_norm)
                        }
                    };
                    candidates.push((dist, item));
                }
                break;
            } else {
                let split_offset = node.split_idx as usize * (self.dim + 1);
                // SAFETY: `split_idx` counts the `dim + 1`-sized records
                // that `new` appended to `split_data` before this one.
                let plane = unsafe {
                    self.split_data
                        .get_unchecked(split_offset..split_offset + self.dim + 1)
                };
                let margin = Self::get_margin(query_vec, plane, self.dim)
                    .to_f64()
                    .unwrap();
                let (closer, farther) = if margin > 0.0 {
                    (node.child_a, node.child_b)
                } else {
                    (node.child_b, node.child_a)
                };
                // The heap pops the largest value first, so branches whose
                // hyperplane is nearest to the query are revisited first.
                pq.push(BacktrackEntry {
                    margin: -margin.abs(),
                    node_idx: farther,
                });
                current_idx = closer;
            }
        }
    }

    // Partition out the k best before sorting so only k entries are sorted.
    if candidates.len() > k {
        candidates.select_nth_unstable_by(k - 1, |a, b| {
            a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
        });
        candidates.truncate(k);
    }
    candidates
        .sort_unstable_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));

    // Translate stored positions back to the caller's row indices.
    candidates
        .into_iter()
        .map(|(dist, idx)| (self.original_ids[idx], dist))
        .unzip()
}
/// Runs `query` for a matrix row.
///
/// A row with unit column stride is borrowed directly as a slice; any other
/// layout is first copied into a temporary vector.
#[inline]
pub fn query_row(
    &self,
    query_row: RowRef<T>,
    k: usize,
    search_k: Option<usize>,
) -> (Vec<usize>, Vec<T>) {
    if query_row.col_stride() != 1 {
        let gathered: Vec<T> = query_row.iter().cloned().collect();
        return self.query(&gathered, k, search_k);
    }
    // SAFETY: relies on a column stride of 1 meaning that the `ncols()`
    // elements of the row are adjacent in memory starting at `as_ptr()`;
    // the slice is only used while `query_row` is alive.
    let contiguous =
        unsafe { std::slice::from_raw_parts(query_row.as_ptr(), query_row.ncols()) };
    self.query(contiguous, k, search_k)
}
/// Queries the index with every stored vector, in parallel, to build a kNN
/// table for the indexed data.
///
/// ### Params
///
/// * `k` - Number of neighbours per vector, passed to `query`.
/// * `search_k` - Search budget, passed to `query`.
/// * `return_dist` - Whether to also return the distances.
/// * `verbose` - Print a progress line every 100 000 processed vectors.
///
/// ### Returns
///
/// `(indices, distances)` where entry `i` belongs to row `i` of the data
/// the index was built from; `distances` is `None` unless `return_dist` is
/// set. The queried vector itself is not excluded from its results.
pub fn generate_knn(
    &self,
    k: usize,
    search_k: Option<usize>,
    return_dist: bool,
    verbose: bool,
) -> (Vec<Vec<usize>>, Option<Vec<Vec<T>>>) {
    let progress = Arc::new(AtomicUsize::new(0));

    // Rows are processed in storage order; each result is tagged with the
    // original row id so it can be put back in input order afterwards.
    let per_row: Vec<(usize, Vec<usize>, Vec<T>)> = (0..self.n)
        .into_par_iter()
        .map(|pos| {
            let offset = pos * self.dim;
            let row = &self.vectors_flat[offset..offset + self.dim];
            let orig_id = self.original_ids[pos];
            if verbose {
                let done = progress.fetch_add(1, Ordering::Relaxed) + 1;
                if done.is_multiple_of(100_000) {
                    println!(
                        " Processed {} / {} samples.",
                        done.separate_with_underscores(),
                        self.n.separate_with_underscores()
                    );
                }
            }
            let (neighbour_ids, neighbour_dists) = self.query(row, k, search_k);
            (orig_id, neighbour_ids, neighbour_dists)
        })
        .collect();

    let mut final_indices: Vec<Vec<usize>> = vec![Vec::new(); self.n];
    let mut final_dists: Option<Vec<Vec<T>>> = return_dist.then(|| vec![Vec::new(); self.n]);
    for (orig_id, neighbour_ids, neighbour_dists) in per_row {
        final_indices[orig_id] = neighbour_ids;
        if let Some(table) = final_dists.as_mut() {
            table[orig_id] = neighbour_dists;
        }
    }
    (final_indices, final_dists)
}
/// Estimates the memory held by the index in bytes: the struct itself plus
/// the allocated capacity of each buffer it owns.
pub fn memory_usage_bytes(&self) -> usize {
    let float_size = std::mem::size_of::<T>();
    let id_size = std::mem::size_of::<usize>();
    let heap_bytes = [
        self.vectors_flat.capacity() * float_size,
        self.norms.capacity() * float_size,
        self.nodes.capacity() * std::mem::size_of::<FlatNode>(),
        self.roots.capacity() * std::mem::size_of::<u32>(),
        self.split_data.capacity() * float_size,
        self.leaf_indices.capacity() * id_size,
        self.original_ids.capacity() * id_size,
    ];
    std::mem::size_of_val(self) + heap_bytes.iter().sum::<usize>()
}
}
/// Hooks the index into the shared kNN validation helpers.
impl<T: AnnSearchFloat> KnnValidation<T> for AnnoyIndex<T> {
    /// Queries with the default search budget.
    fn query_for_validation(&self, query_vec: &[T], k: usize) -> (Vec<usize>, Vec<T>) {
        let default_budget: Option<usize> = None;
        self.query(query_vec, k, default_budget)
    }

    /// Number of indexed vectors.
    fn n(&self) -> usize {
        self.n
    }

    /// Metric the index was built with.
    fn metric(&self) -> Dist {
        self.metric
    }

    /// Mapping from stored position to input row index.
    fn original_ids(&self) -> &[usize] {
        self.original_ids.as_slice()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;
    use faer::Mat;

    /// Query along the first axis; identical to row 0 of the simple matrix.
    const X_AXIS: [f32; 3] = [1.0, 0.0, 0.0];

    /// Five 3-d points: the three unit axes, then (1, 1, 0) and (1, 0, 1).
    fn create_simple_matrix() -> Mat<f32> {
        const POINTS: [[f32; 3]; 5] = [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [1.0, 1.0, 0.0],
            [1.0, 0.0, 1.0],
        ];
        Mat::from_fn(5, 3, |row, col| POINTS[row][col])
    }

    /// Index over `create_simple_matrix()` with the given settings.
    fn simple_index(n_trees: usize, metric: Dist, seed: usize) -> AnnoyIndex<f32> {
        let points = create_simple_matrix();
        AnnoyIndex::new(points.as_ref(), n_trees, metric, seed)
    }

    #[test]
    fn test_annoy_index_creation() {
        let _index = simple_index(4, Dist::Euclidean, 42);
    }

    #[test]
    fn test_annoy_query_finds_self() {
        let index = simple_index(4, Dist::Euclidean, 42);
        let (ids, dists) = index.query(&X_AXIS, 1, None);
        assert_eq!(ids.len(), 1);
        assert_eq!(ids[0], 0);
        assert_relative_eq!(dists[0], 0.0, epsilon = 1e-5);
    }

    #[test]
    fn test_annoy_query_euclidean() {
        let index = simple_index(8, Dist::Euclidean, 42);
        let (ids, dists) = index.query(&X_AXIS, 3, None);
        assert_eq!(ids[0], 0);
        assert_relative_eq!(dists[0], 0.0, epsilon = 1e-5);
        // Distances must come back in non-decreasing order.
        assert!(dists.windows(2).all(|pair| pair[1] >= pair[0]));
    }

    #[test]
    fn test_annoy_query_cosine() {
        let index = simple_index(8, Dist::Cosine, 42);
        let (ids, dists) = index.query(&X_AXIS, 3, None);
        assert_eq!(ids[0], 0);
        assert_relative_eq!(dists[0], 0.0, epsilon = 1e-5);
    }

    #[test]
    fn test_annoy_query_k_larger_than_dataset() {
        let index = simple_index(4, Dist::Euclidean, 42);
        let (ids, _) = index.query(&X_AXIS, 10, None);
        assert!(ids.len() <= 5);
    }

    #[test]
    fn test_annoy_query_search_k() {
        let index = simple_index(4, Dist::Euclidean, 42);
        for budget in [10, 50] {
            let (ids, _) = index.query(&X_AXIS, 3, Some(budget));
            assert_eq!(ids.len(), 3);
        }
    }

    #[test]
    fn test_annoy_multiple_trees() {
        let probe: [f32; 3] = [0.9, 0.1, 0.0];
        let sparse_forest = simple_index(2, Dist::Euclidean, 42);
        let dense_forest = simple_index(16, Dist::Euclidean, 42);
        let (ids_sparse, _) = sparse_forest.query(&probe, 3, None);
        let (ids_dense, _) = dense_forest.query(&probe, 3, None);
        assert_eq!(ids_sparse.len(), 3);
        assert_eq!(ids_dense.len(), 3);
    }

    #[test]
    fn test_annoy_query_row() {
        let points = create_simple_matrix();
        let index = AnnoyIndex::new(points.as_ref(), 8, Dist::Euclidean, 42);
        let (ids, dists) = index.query_row(points.row(0), 1, None);
        assert_eq!(ids[0], 0);
        assert_relative_eq!(dists[0], 0.0, epsilon = 1e-5);
    }

    #[test]
    fn test_annoy_reproducibility() {
        let probe: [f32; 3] = [0.5, 0.5, 0.0];
        let first = simple_index(8, Dist::Euclidean, 42);
        let second = simple_index(8, Dist::Euclidean, 42);
        let (ids_first, _) = first.query(&probe, 3, None);
        let (ids_second, _) = second.query(&probe, 3, None);
        assert_eq!(ids_first, ids_second);
    }

    #[test]
    fn test_annoy_different_seeds() {
        let probe: [f32; 3] = [0.5, 0.5, 0.0];
        let seeded_42 = simple_index(8, Dist::Euclidean, 42);
        let seeded_123 = simple_index(8, Dist::Euclidean, 123);
        let (ids_42, _) = seeded_42.query(&probe, 3, None);
        let (ids_123, _) = seeded_123.query(&probe, 3, None);
        assert_eq!(ids_42.len(), 3);
        assert_eq!(ids_123.len(), 3);
    }

    #[test]
    fn test_annoy_larger_dataset() {
        let n = 100;
        let dim = 10;
        // Row i is i * (0, 1, ..., dim - 1) / 10, so row 0 is the origin.
        let points = Mat::from_fn(n, dim, |i, j| (i * j) as f32 / 10.0);
        let index = AnnoyIndex::new(points.as_ref(), 16, Dist::Euclidean, 42);
        let origin = vec![0.0_f32; dim];
        let (ids, _) = index.query(&origin, 5, None);
        assert_eq!(ids.len(), 5);
        assert_eq!(ids[0], 0);
    }

    #[test]
    fn test_annoy_orthogonal_vectors() {
        // The float type is deliberately left to inference here.
        let points = Mat::from_fn(3, 3, |i, j| if i == j { 1.0 } else { 0.0 });
        let index = AnnoyIndex::new(points.as_ref(), 4, Dist::Cosine, 42);
        let probe = vec![1.0, 0.0, 0.0];
        let (ids, dists) = index.query(&probe, 3, None);
        assert_eq!(ids[0], 0);
        assert_relative_eq!(dists[0], 0.0, epsilon = 1e-5);
        assert_relative_eq!(dists[1], 1.0, epsilon = 1e-5);
        assert_relative_eq!(dists[2], 1.0, epsilon = 1e-5);
    }

    #[test]
    fn test_annoy_parallel_build() {
        let n = 50;
        let dim = 5;
        // Entries count up row by row: 0, 1, 2, ...
        let points = Mat::from_fn(n, dim, |i, j| (i * dim + j) as f32);
        let index = AnnoyIndex::new(points.as_ref(), 32, Dist::Euclidean, 42);
        let origin = vec![0.0_f32; dim];
        let (ids, _) = index.query(&origin, 3, None);
        assert_eq!(ids.len(), 3);
    }
}