1use std::cmp::Ordering;
11use std::collections::BinaryHeap;
12
/// Indexes holding fewer vectors than this never build a graph and are
/// searched exhaustively instead.
const BRUTE_FORCE_THRESHOLD: usize = 1000;

/// Number of neighbours linked to a freshly inserted node on each layer.
/// A neighbour list is cut back to `M` entries once it grows beyond `2 * M`.
const M: usize = 16;

/// Size of the candidate list used while inserting a node.
const EF_CONSTRUCTION: usize = 200;

/// Minimum size of the candidate list used while answering a query.
const EF_SEARCH: usize = 64;

/// Multiplier applied to `-ln(r)` when drawing a node's top layer.
/// The value is approximately `1 / ln(16)`.
const ML: f64 = 0.360_674_0;
19
/// A scored vector id whose ordering is *reversed* on similarity, so that a
/// [`BinaryHeap`] of `Candidate`s keeps the least similar entry on top.
///
/// The ordering is a total order: similarities are compared with
/// [`f32::total_cmp`] (which also gives NaN a fixed position instead of
/// comparing it "equal" to everything) and ties are broken by `idx`.
/// Equality is defined through the same comparison, so `==` and `cmp` always
/// agree, as the `Ord` contract requires.
#[derive(Clone, Copy, Debug)]
struct Candidate {
    /// Position of the vector inside the index.
    idx: usize,
    /// Similarity between that vector and the query.
    sim: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    /// Lower similarity compares as *greater*. Among equal similarities the
    /// larger index compares as greater, so it is the first one to be evicted
    /// from a bounded heap.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .sim
            .total_cmp(&self.sim)
            .then_with(|| self.idx.cmp(&other.idx))
    }
}
41
/// A scored vector id ordered by similarity, so that a [`BinaryHeap`] of
/// `MaxCandidate`s keeps the most similar entry on top.
///
/// The ordering is a total order: similarities are compared with
/// [`f32::total_cmp`] (which also gives NaN a fixed position instead of
/// comparing it "equal" to everything) and ties are broken by `idx`.
/// Equality is defined through the same comparison, so `==` and `cmp` always
/// agree, as the `Ord` contract requires.
#[derive(Clone, Copy, Debug)]
struct MaxCandidate {
    /// Position of the vector inside the index.
    idx: usize,
    /// Similarity between that vector and the query.
    sim: f32,
}

impl PartialEq for MaxCandidate {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for MaxCandidate {}

impl PartialOrd for MaxCandidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for MaxCandidate {
    /// Higher similarity compares as greater. Among equal similarities the
    /// smaller index compares as greater, so it is popped first.
    fn cmp(&self, other: &Self) -> Ordering {
        self.sim
            .total_cmp(&other.sim)
            .then_with(|| other.idx.cmp(&self.idx))
    }
}
62
/// Adjacency lists of a single graph node.
#[derive(Debug, Clone, Default)]
struct Node {
    /// `connections[layer]` lists the ids of this node's neighbours on that
    /// layer; the outer vector has one entry per layer the node lives on.
    connections: Vec<Vec<usize>>,
}

/// Nearest-neighbour index over a set of `f32` vectors.
///
/// The default value is an empty index.
#[derive(Debug, Clone, Default)]
pub struct AnnIndex {
    /// The indexed vectors; result indices refer to positions in this list.
    vectors: Vec<Vec<f32>>,
    /// Graph nodes, one per vector, or empty when no graph was built and
    /// queries are answered exhaustively.
    nodes: Vec<Node>,
    /// Id of the node where graph searches start.
    entry_point: usize,
    /// Highest layer present in the graph.
    max_level: usize,
}
76
77impl AnnIndex {
78 pub fn build(vectors: Vec<Vec<f32>>) -> Self {
80 let n = vectors.len();
81 if n == 0 {
82 return Self {
83 vectors,
84 nodes: Vec::new(),
85 entry_point: 0,
86 max_level: 0,
87 };
88 }
89
90 if n < BRUTE_FORCE_THRESHOLD {
91 return Self {
92 vectors,
93 nodes: Vec::new(),
94 entry_point: 0,
95 max_level: 0,
96 };
97 }
98
99 let mut index = Self {
100 vectors: Vec::with_capacity(n),
101 nodes: Vec::with_capacity(n),
102 entry_point: 0,
103 max_level: 0,
104 };
105
106 for vec in vectors {
107 index.insert(vec);
108 }
109
110 index
111 }
112
113 fn insert(&mut self, vec: Vec<f32>) {
114 let level = Self::random_level();
115 let new_id = self.vectors.len();
116
117 self.vectors.push(vec);
118 self.nodes.push(Node {
119 connections: vec![Vec::new(); level + 1],
120 });
121
122 if self.nodes.len() == 1 {
123 self.entry_point = 0;
124 self.max_level = level;
125 return;
126 }
127
128 let mut ep = self.entry_point;
129
130 for lc in (level + 1..=self.max_level).rev() {
132 ep = self.search_layer_single(&self.vectors[new_id], ep, lc);
133 }
134
135 let insert_levels = level.min(self.max_level);
137 for lc in (0..=insert_levels).rev() {
138 let neighbors = self.search_layer(&self.vectors[new_id], ep, EF_CONSTRUCTION, lc);
139 let selected = Self::select_neighbors(&neighbors, M);
140
141 if lc < self.nodes[new_id].connections.len() {
142 self.nodes[new_id].connections[lc].clone_from(&selected);
143 }
144
145 for &neighbor in &selected {
146 if lc < self.nodes[neighbor].connections.len() {
147 self.nodes[neighbor].connections[lc].push(new_id);
148 if self.nodes[neighbor].connections[lc].len() > M * 2 {
149 let nv = &self.vectors[neighbor];
150 let mut scored: Vec<(usize, f32)> = self.nodes[neighbor].connections[lc]
151 .iter()
152 .map(|&n| (n, cosine_sim(nv, &self.vectors[n])))
153 .collect();
154 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
155 scored.truncate(M);
156 self.nodes[neighbor].connections[lc] =
157 scored.into_iter().map(|(id, _)| id).collect();
158 }
159 }
160 }
161
162 if !neighbors.is_empty() {
163 ep = neighbors[0].0;
164 }
165 }
166
167 if level > self.max_level {
168 self.max_level = level;
169 self.entry_point = new_id;
170 }
171 }
172
173 fn search_layer_single(&self, query: &[f32], ep: usize, _layer: usize) -> usize {
174 let mut current = ep;
175 let mut best_sim = cosine_sim(query, &self.vectors[ep]);
176
177 loop {
178 let mut improved = false;
179 let conns = &self.nodes[current].connections;
180 let layer_conns = if _layer < conns.len() {
181 &conns[_layer]
182 } else {
183 break;
184 };
185
186 for &neighbor in layer_conns {
187 let sim = cosine_sim(query, &self.vectors[neighbor]);
188 if sim > best_sim {
189 best_sim = sim;
190 current = neighbor;
191 improved = true;
192 }
193 }
194 if !improved {
195 break;
196 }
197 }
198 current
199 }
200
201 fn search_layer(&self, query: &[f32], ep: usize, ef: usize, layer: usize) -> Vec<(usize, f32)> {
202 let mut visited = vec![false; self.vectors.len()];
203 let mut candidates = BinaryHeap::<MaxCandidate>::new();
204 let mut results = BinaryHeap::<Candidate>::new();
205
206 let sim = cosine_sim(query, &self.vectors[ep]);
207 visited[ep] = true;
208 candidates.push(MaxCandidate { idx: ep, sim });
209 results.push(Candidate { idx: ep, sim });
210
211 while let Some(MaxCandidate { idx: c, sim: _ }) = candidates.pop() {
212 let worst_result = results.peek().map_or(f32::MIN, |r| r.sim);
213 if cosine_sim(query, &self.vectors[c]) < worst_result && results.len() >= ef {
214 break;
215 }
216
217 let conns = &self.nodes[c].connections;
218 let layer_conns = if layer < conns.len() {
219 &conns[layer]
220 } else {
221 continue;
222 };
223
224 for &neighbor in layer_conns {
225 if visited[neighbor] {
226 continue;
227 }
228 visited[neighbor] = true;
229
230 let n_sim = cosine_sim(query, &self.vectors[neighbor]);
231 let worst = results.peek().map_or(f32::MIN, |r| r.sim);
232
233 if results.len() < ef || n_sim > worst {
234 candidates.push(MaxCandidate {
235 idx: neighbor,
236 sim: n_sim,
237 });
238 results.push(Candidate {
239 idx: neighbor,
240 sim: n_sim,
241 });
242 if results.len() > ef {
243 results.pop();
244 }
245 }
246 }
247 }
248
249 let mut out: Vec<(usize, f32)> = results.into_iter().map(|c| (c.idx, c.sim)).collect();
250 out.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
251 out
252 }
253
/// Returns the ids of the first `max_count` entries of `candidates`, in the
/// order given (fewer if `candidates` is shorter).
fn select_neighbors(candidates: &[(usize, f32)], max_count: usize) -> Vec<usize> {
    let keep = candidates.len().min(max_count);
    let mut ids = Vec::with_capacity(keep);
    for &(id, _) in &candidates[..keep] {
        ids.push(id);
    }
    ids
}
261
262 fn random_level() -> usize {
263 let mut buf = [0u8; 4];
264 let _ = getrandom::fill(&mut buf);
265 let r = f64::from(u32::from_le_bytes(buf)) / f64::from(u32::MAX);
266 (-r.ln() * ML).floor() as usize
267 }
268
269 pub fn search(&self, query: &[f32], top_k: usize) -> Vec<(usize, f32)> {
272 if self.vectors.is_empty() {
273 return Vec::new();
274 }
275
276 if self.nodes.is_empty() || self.vectors.len() < BRUTE_FORCE_THRESHOLD {
278 return brute_force_topk(&self.vectors, query, top_k);
279 }
280
281 let mut ep = self.entry_point;
283 for lc in (1..=self.max_level).rev() {
284 ep = self.search_layer_single(query, ep, lc);
285 }
286
287 let mut results = self.search_layer(query, ep, EF_SEARCH.max(top_k), 0);
288 results.truncate(top_k);
289 results
290 }
291}
292
293pub fn brute_force_topk(vectors: &[Vec<f32>], query: &[f32], top_k: usize) -> Vec<(usize, f32)> {
295 let mut heap = BinaryHeap::<Candidate>::with_capacity(top_k + 1);
296
297 for (i, vec) in vectors.iter().enumerate() {
298 let sim = cosine_sim(query, vec);
299 if heap.len() < top_k {
300 heap.push(Candidate { idx: i, sim });
301 } else if let Some(worst) = heap.peek() {
302 if sim > worst.sim {
303 heap.pop();
304 heap.push(Candidate { idx: i, sim });
305 }
306 }
307 }
308
309 let mut results: Vec<(usize, f32)> = heap.into_iter().map(|c| (c.idx, c.sim)).collect();
310 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
311 results
312}
313
/// Cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices differ in length or when the product of
/// their norms is below `1e-10` (e.g. one of them is all zeros).
#[inline]
fn cosine_sim(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }

    let mut dot = 0.0f32;
    let mut norm_a = 0.0f32;
    let mut norm_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        norm_a += x * x;
        norm_b += y * y;
    }

    // The product of the two squared norms is formed in f64: it is exact
    // there and cannot overflow, whereas in f32 it becomes infinite for
    // large-magnitude vectors and would force the result to 0. The square
    // root never exceeds the larger squared norm, so it fits back into f32.
    let denom = (f64::from(norm_a) * f64::from(norm_b)).sqrt() as f32;
    if denom < 1e-10 {
        0.0
    } else {
        dot / denom
    }
}
334
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random vector with components in roughly
    /// `[-1, 1]`, generated from `seed` with a multiplicative congruential
    /// step.
    fn random_vec(dim: usize, seed: u64) -> Vec<f32> {
        let mut v = Vec::with_capacity(dim);
        let mut s = seed;
        for _ in 0..dim {
            s = s.wrapping_mul(6364136223846793005).wrapping_add(1);
            v.push((s as f32 / u64::MAX as f32) * 2.0 - 1.0);
        }
        v
    }

    #[test]
    fn brute_force_topk_correctness() {
        let vectors: Vec<Vec<f32>> = (0..100).map(|i| random_vec(16, i)).collect();
        let query = random_vec(16, 999);

        let results = brute_force_topk(&vectors, &query, 5);
        assert_eq!(results.len(), 5);

        // Results must come back in descending similarity order.
        for w in results.windows(2) {
            assert!(w[0].1 >= w[1].1);
        }
    }

    #[test]
    fn brute_force_topk_matches_exhaustive() {
        let vectors: Vec<Vec<f32>> = (0..50).map(|i| random_vec(8, i + 42)).collect();
        let query = random_vec(8, 123);

        let top5 = brute_force_topk(&vectors, &query, 5);

        let mut all: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (i, cosine_sim(&query, v)))
            .collect();
        all.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        all.truncate(5);

        // Without this check the zip below would pass vacuously on a short
        // or empty result.
        assert_eq!(top5.len(), all.len());
        for (heap_r, exact_r) in top5.iter().zip(all.iter()) {
            assert_eq!(heap_r.0, exact_r.0);
            assert!((heap_r.1 - exact_r.1).abs() < 1e-6);
        }
    }

    #[test]
    fn brute_force_topk_handles_boundary_k() {
        let vectors: Vec<Vec<f32>> = (0..3).map(|i| random_vec(4, i)).collect();
        let query = random_vec(4, 7);

        assert!(brute_force_topk(&vectors, &query, 0).is_empty());
        assert_eq!(brute_force_topk(&vectors, &query, 10).len(), 3);
        assert!(brute_force_topk(&[], &query, 5).is_empty());
    }

    #[test]
    fn empty_index_returns_empty() {
        let index = AnnIndex::build(Vec::new());
        assert!(index.search(&[1.0, 0.0], 5).is_empty());
    }

    #[test]
    fn small_index_uses_brute_force() {
        let vectors: Vec<Vec<f32>> = (0..50).map(|i| random_vec(4, i)).collect();
        let index = AnnIndex::build(vectors);
        assert!(index.nodes.is_empty());

        let results = index.search(&random_vec(4, 999), 3);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn large_index_builds_graph_and_searches() {
        let n = BRUTE_FORCE_THRESHOLD;
        let vectors: Vec<Vec<f32>> = (0..n as u64).map(|i| random_vec(8, i)).collect();
        let index = AnnIndex::build(vectors);

        // At the threshold the graph path is taken: one node per vector.
        assert_eq!(index.vectors.len(), n);
        assert_eq!(index.nodes.len(), n);

        let results = index.search(&random_vec(8, 4242), 5);
        assert!(!results.is_empty());
        assert!(results.len() <= 5);
        assert!(results.iter().all(|&(idx, _)| idx < n));
        for w in results.windows(2) {
            assert!(w[0].1 >= w[1].1);
        }
    }
}