use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashSet};
/// Collections with fewer vectors than this are searched exhaustively instead of via the graph.
const BRUTE_FORCE_THRESHOLD: usize = 1000;
/// Number of neighbours selected for a newly inserted node on each layer.
const M: usize = 16;
/// Candidate-list size used while inserting nodes into the graph.
const EF_CONSTRUCTION: usize = 200;
/// Minimum candidate-list size used when answering queries.
const EF_SEARCH: usize = 64;
/// Level-generation multiplier; numerically ≈ `1 / ln(16)`, i.e. `1 / ln(M)`.
const ML: f64 = 0.360_674;
/// Search entry ordered so that a [`BinaryHeap`] of them is a min-heap on `sim`.
///
/// `peek`/`pop` yield the *least* similar entry. That makes the heap a bounded
/// "best so far" set whose worst member is cheap to inspect and evict.
///
/// Ordering uses [`f32::total_cmp`], so it is a genuine total order even when a
/// similarity is NaN. Ties on `sim` are broken by `idx` (the larger index counts as
/// worse), and equality is defined through `cmp`, so `Ord` and `Eq` always agree.
#[derive(Clone, Debug)]
struct Candidate {
    idx: usize,
    sim: f32,
}
impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for Candidate {}
impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed on `sim`: a lower similarity compares as greater, so it sits on top.
        other
            .sim
            .total_cmp(&self.sim)
            .then_with(|| self.idx.cmp(&other.idx))
    }
}
/// Search entry ordered so that a [`BinaryHeap`] of them is a max-heap on `sim`.
///
/// `pop` yields the *most* similar entry first, which suits a frontier of nodes
/// still waiting to be expanded.
///
/// Ordering uses [`f32::total_cmp`], so it is a genuine total order even when a
/// similarity is NaN. Ties on `sim` are broken by `idx` (the smaller index comes out
/// first), and equality is defined through `cmp`, so `Ord` and `Eq` always agree.
#[derive(Clone, Debug)]
struct MaxCandidate {
    idx: usize,
    sim: f32,
}
impl PartialEq for MaxCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}
impl Eq for MaxCandidate {}
impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for MaxCandidate {
    fn cmp(&self, other: &Self) -> Ordering {
        self.sim
            .total_cmp(&other.sim)
            .then_with(|| other.idx.cmp(&self.idx))
    }
}
/// Graph adjacency for one indexed vector.
struct Node {
    /// `connections[layer]` lists the ids (positions in the index's vector list) of this
    /// node's neighbours on `layer`; its length is the node's top layer plus one.
    connections: Vec<Vec<usize>>,
}
/// Approximate nearest-neighbour index over `f32` vectors, ranked by cosine similarity.
///
/// Small collections keep no graph (`nodes` stays empty) and are searched exhaustively.
/// Larger ones maintain a multi-layer proximity graph.
pub struct AnnIndex {
    /// The indexed vectors; a vector's position is the id reported by searches.
    vectors: Vec<Vec<f32>>,
    /// Graph adjacency, parallel to `vectors`; empty when no graph was built.
    nodes: Vec<Node>,
    /// Node where graph searches start; it is a node on the highest layer.
    entry_point: usize,
    /// Highest layer present in the graph.
    max_level: usize,
}
impl AnnIndex {
pub fn build(vectors: Vec<Vec<f32>>) -> Self {
let n = vectors.len();
if n == 0 {
return Self {
vectors,
nodes: Vec::new(),
entry_point: 0,
max_level: 0,
};
}
if n < BRUTE_FORCE_THRESHOLD {
return Self {
vectors,
nodes: Vec::new(),
entry_point: 0,
max_level: 0,
};
}
let mut index = Self {
vectors: Vec::with_capacity(n),
nodes: Vec::with_capacity(n),
entry_point: 0,
max_level: 0,
};
for vec in vectors {
index.insert(vec);
}
index
}
fn insert(&mut self, vec: Vec<f32>) {
let level = Self::random_level();
let new_id = self.vectors.len();
self.vectors.push(vec);
self.nodes.push(Node {
connections: vec![Vec::new(); level + 1],
});
if self.nodes.len() == 1 {
self.entry_point = 0;
self.max_level = level;
return;
}
let mut ep = self.entry_point;
for lc in (level + 1..=self.max_level).rev() {
ep = self.search_layer_single(&self.vectors[new_id], ep, lc);
}
let insert_levels = level.min(self.max_level);
for lc in (0..=insert_levels).rev() {
let neighbors = self.search_layer(&self.vectors[new_id], ep, EF_CONSTRUCTION, lc);
let selected = Self::select_neighbors(&neighbors, M);
if lc < self.nodes[new_id].connections.len() {
self.nodes[new_id].connections[lc].clone_from(&selected);
}
for &neighbor in &selected {
if lc < self.nodes[neighbor].connections.len() {
self.nodes[neighbor].connections[lc].push(new_id);
if self.nodes[neighbor].connections[lc].len() > M * 2 {
let nv = &self.vectors[neighbor];
let mut scored: Vec<(usize, f32)> = self.nodes[neighbor].connections[lc]
.iter()
.map(|&n| (n, cosine_sim(nv, &self.vectors[n])))
.collect();
scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
scored.truncate(M);
self.nodes[neighbor].connections[lc] =
scored.into_iter().map(|(id, _)| id).collect();
}
}
}
if !neighbors.is_empty() {
ep = neighbors[0].0;
}
}
if level > self.max_level {
self.max_level = level;
self.entry_point = new_id;
}
}
fn search_layer_single(&self, query: &[f32], ep: usize, _layer: usize) -> usize {
let mut current = ep;
let mut best_sim = cosine_sim(query, &self.vectors[ep]);
loop {
let mut improved = false;
let conns = &self.nodes[current].connections;
let layer_conns = if _layer < conns.len() {
&conns[_layer]
} else {
break;
};
for &neighbor in layer_conns {
let sim = cosine_sim(query, &self.vectors[neighbor]);
if sim > best_sim {
best_sim = sim;
current = neighbor;
improved = true;
}
}
if !improved {
break;
}
}
current
}
fn search_layer(&self, query: &[f32], ep: usize, ef: usize, layer: usize) -> Vec<(usize, f32)> {
let mut visited = vec![false; self.vectors.len()];
let mut candidates = BinaryHeap::<MaxCandidate>::new();
let mut results = BinaryHeap::<Candidate>::new();
let sim = cosine_sim(query, &self.vectors[ep]);
visited[ep] = true;
candidates.push(MaxCandidate { idx: ep, sim });
results.push(Candidate { idx: ep, sim });
while let Some(MaxCandidate { idx: c, sim: _ }) = candidates.pop() {
let worst_result = results.peek().map_or(f32::MIN, |r| r.sim);
if cosine_sim(query, &self.vectors[c]) < worst_result && results.len() >= ef {
break;
}
let conns = &self.nodes[c].connections;
let layer_conns = if layer < conns.len() {
&conns[layer]
} else {
continue;
};
for &neighbor in layer_conns {
if visited[neighbor] {
continue;
}
visited[neighbor] = true;
let n_sim = cosine_sim(query, &self.vectors[neighbor]);
let worst = results.peek().map_or(f32::MIN, |r| r.sim);
if results.len() < ef || n_sim > worst {
candidates.push(MaxCandidate {
idx: neighbor,
sim: n_sim,
});
results.push(Candidate {
idx: neighbor,
sim: n_sim,
});
if results.len() > ef {
results.pop();
}
}
}
}
let mut out: Vec<(usize, f32)> = results.into_iter().map(|c| (c.idx, c.sim)).collect();
out.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
out
}
fn select_neighbors(candidates: &[(usize, f32)], max_count: usize) -> Vec<usize> {
candidates
.iter()
.take(max_count)
.map(|&(idx, _)| idx)
.collect()
}
fn random_level() -> usize {
let mut buf = [0u8; 4];
let _ = getrandom::fill(&mut buf);
let r = f64::from(u32::from_le_bytes(buf)) / f64::from(u32::MAX);
(-r.ln() * ML).floor() as usize
}
pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(usize, f32)> {
if self.vectors.is_empty() {
return Vec::new();
}
if self.nodes.is_empty() || self.vectors.len() < BRUTE_FORCE_THRESHOLD {
return brute_force_topk(&self.vectors, query, top_k);
}
let mut ep = self.entry_point;
for lc in (1..=self.max_level).rev() {
ep = self.search_layer_single(query, ep, lc);
}
let mut results = self.search_layer(query, ep, EF_SEARCH.max(top_k), 0);
results.truncate(top_k);
results
}
}
/// Returns the `top_k` vectors most cosine-similar to `query` as `(index, similarity)` pairs,
/// sorted from most to least similar.
///
/// Scans every vector, keeping a bounded min-heap of the best `top_k` seen so far. Fewer
/// than `top_k` pairs are returned when `vectors` is shorter; `top_k == 0` yields an empty
/// result.
pub fn brute_force_topk(vectors: &[Vec<f32>], query: &[f32], top_k: usize) -> Vec<(usize, f32)> {
    // Reserve only what can actually be kept: callers may pass a huge `top_k` (e.g.
    // `usize::MAX` to mean "everything"), for which `top_k + 1` would overflow.
    let mut heap = BinaryHeap::<Candidate>::with_capacity(top_k.min(vectors.len()));
    for (idx, vec) in vectors.iter().enumerate() {
        let sim = cosine_sim(query, vec);
        if heap.len() < top_k {
            heap.push(Candidate { idx, sim });
        } else if let Some(mut worst) = heap.peek_mut() {
            // Overwrite the current worst in place; the heap re-sifts when `worst` drops.
            if sim > worst.sim {
                *worst = Candidate { idx, sim };
            }
        }
    }
    let mut results: Vec<(usize, f32)> = heap.into_iter().map(|c| (c.idx, c.sim)).collect();
    results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
    results
}
/// Cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the lengths differ, when either vector has (near-)zero norm, or when
/// the inputs contain NaN or infinite values. A NaN similarity would otherwise corrupt
/// ranking comparisons downstream.
#[inline]
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    let (dot, norm_a, norm_b) = a
        .iter()
        .zip(b)
        .fold((0.0f32, 0.0f32, 0.0f32), |(dot, na, nb), (&x, &y)| {
            (dot + x * y, na + x * x, nb + y * y)
        });
    // Take each root before multiplying. `(norm_a * norm_b).sqrt()` overflows to infinity
    // for large-magnitude vectors (e.g. components around 1e15) and then yields 0.
    let denom = norm_a.sqrt() * norm_b.sqrt();
    if !denom.is_finite() || denom < 1e-10 {
        0.0
    } else {
        dot / denom
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in `[-1, 1]`, produced by a 64-bit
    /// linear congruential generator seeded with `seed`.
    fn random_vec(dim: usize, seed: u64) -> Vec<f32> {
        let mut v = Vec::with_capacity(dim);
        let mut s = seed;
        for _ in 0..dim {
            s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
            v.push((s as f32 / u64::MAX as f32) * 2.0 - 1.0);
        }
        v
    }

    #[test]
    fn brute_force_topk_correctness() {
        let vectors: Vec<Vec<f32>> = (0..100).map(|i| random_vec(16, i)).collect();
        let query = random_vec(16, 999);
        let results = brute_force_topk(&vectors, &query, 5);
        assert_eq!(results.len(), 5);
        for w in results.windows(2) {
            assert!(w[0].1 >= w[1].1);
        }
    }

    #[test]
    fn brute_force_topk_matches_exhaustive() {
        let vectors: Vec<Vec<f32>> = (0..50).map(|i| random_vec(8, i + 42)).collect();
        let query = random_vec(8, 123);
        let top5 = brute_force_topk(&vectors, &query, 5);
        let mut all: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (i, cosine_sim(&query, v)))
            .collect();
        all.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        all.truncate(5);
        for (heap_r, exact_r) in top5.iter().zip(all.iter()) {
            assert_eq!(heap_r.0, exact_r.0);
            assert!((heap_r.1 - exact_r.1).abs() < 1e-6);
        }
    }

    #[test]
    fn brute_force_topk_caps_at_input_length() {
        let vectors: Vec<Vec<f32>> = (0..4).map(|i| random_vec(3, i)).collect();
        let results = brute_force_topk(&vectors, &random_vec(3, 77), 10);
        assert_eq!(results.len(), 4);
    }

    #[test]
    fn empty_index_returns_empty() {
        let index = AnnIndex::build(Vec::new());
        assert!(index.search(&[1.0, 0.0], 5).is_empty());
    }

    #[test]
    fn small_index_uses_brute_force() {
        let vectors: Vec<Vec<f32>> = (0..50).map(|i| random_vec(4, i)).collect();
        let index = AnnIndex::build(vectors);
        assert!(index.nodes.is_empty());
        let results = index.search(&random_vec(4, 999), 3);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn zero_top_k_returns_empty() {
        let vectors: Vec<Vec<f32>> = (0..10).map(|i| random_vec(4, i)).collect();
        let index = AnnIndex::build(vectors);
        assert!(index.search(&random_vec(4, 7), 0).is_empty());
    }

    #[test]
    fn large_index_builds_graph_and_returns_sorted_unique_results() {
        let vectors: Vec<Vec<f32>> = (0..BRUTE_FORCE_THRESHOLD as u64)
            .map(|i| random_vec(8, i))
            .collect();
        let index = AnnIndex::build(vectors);
        assert!(!index.nodes.is_empty());
        let results = index.search(&random_vec(8, 4242), 10);
        assert_eq!(results.len(), 10);
        for w in results.windows(2) {
            assert!(w[0].1 >= w[1].1);
        }
        let mut ids: Vec<usize> = results.iter().map(|r| r.0).collect();
        ids.sort_unstable();
        ids.dedup();
        assert_eq!(ids.len(), 10);
        assert!(ids.iter().all(|&id| id < BRUTE_FORCE_THRESHOLD));
    }
}