1use std::cmp::Reverse;
10use std::collections::BinaryHeap;
11
12use crate::distance::DistanceMetric;
13
/// A graph node paired with its distance to the current query.
///
/// Ordered by distance first and node id second. A `BinaryHeap<Neighbor>` is
/// therefore a max-heap on distance (farthest on top), and a
/// `BinaryHeap<Reverse<Neighbor>>` is a min-heap (nearest on top).
#[derive(Clone, Copy, Debug)]
struct Neighbor {
    dist: f32,
    id: usize,
}

impl PartialEq for Neighbor {
    /// Bitwise comparison of `dist` (so `0.0 != -0.0` and identical NaNs are
    /// equal), plus equal ids. This agrees exactly with `Ord::cmp` below.
    fn eq(&self, other: &Self) -> bool {
        self.dist.to_bits() == other.dist.to_bits() && self.id == other.id
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    /// Total order on `(dist, id)`.
    ///
    /// `f32::total_cmp` is used rather than `partial_cmp` so that NaN and
    /// signed zeros get a definite, transitive position. `sort` and
    /// `BinaryHeap` require a true total order and may misbehave or panic
    /// otherwise. It also reports `Equal` exactly when the bit patterns match,
    /// which keeps `Ord` consistent with `PartialEq`.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.id.cmp(&other.id))
    }
}
43
44pub struct HnswConfig {
46 pub m: usize,
48 pub m_max0: usize,
50 pub ef_construction: usize,
52 pub metric: DistanceMetric,
54}
55
56impl Default for HnswConfig {
57 fn default() -> Self {
58 Self {
59 m: 16,
60 m_max0: 32,
61 ef_construction: 200,
62 metric: DistanceMetric::L2,
63 }
64 }
65}
66
67pub struct HnswIndex {
69 vectors: Vec<Vec<f32>>,
70 neighbors: Vec<Vec<Vec<(usize, f32)>>>,
72 entry_point: Option<usize>,
73 max_layer: usize,
74 m: usize,
75 m_max0: usize,
76 ef_construction: usize,
77 dim: usize,
78 metric: DistanceMetric,
79 ml: f64, rng_state: u64,
81}
82
83impl HnswIndex {
84 pub fn new(dim: usize, config: HnswConfig) -> Self {
86 let ml = 1.0 / (config.m as f64).ln();
87 Self {
88 vectors: Vec::new(),
89 neighbors: Vec::new(),
90 entry_point: None,
91 max_layer: 0,
92 m: config.m,
93 m_max0: config.m_max0,
94 ef_construction: config.ef_construction,
95 dim,
96 metric: config.metric,
97 ml,
98 rng_state: 0x5DEECE66D, }
100 }
101
102 pub fn len(&self) -> usize {
104 self.vectors.len()
105 }
106
107 pub fn is_empty(&self) -> bool {
109 self.vectors.is_empty()
110 }
111
112 pub fn insert(&mut self, vector: &[f32]) -> usize {
114 assert_eq!(vector.len(), self.dim, "dimension mismatch");
115
116 let id = self.vectors.len();
117 self.vectors.push(vector.to_vec());
118
119 let level = self.random_level();
120
121 let mut layers = Vec::with_capacity(level + 1);
123 for _ in 0..=level {
124 layers.push(Vec::new());
125 }
126 self.neighbors.push(layers);
127
128 if self.entry_point.is_none() {
129 self.entry_point = Some(id);
131 self.max_layer = level;
132 return id;
133 }
134
135 let ep = self.entry_point.unwrap();
136 let mut current_ep = ep;
137
138 for layer in (level + 1..=self.max_layer).rev() {
140 current_ep = self.greedy_closest(vector, current_ep, layer);
141 }
142
143 let start_layer = level.min(self.max_layer);
146 for layer in (0..=start_layer).rev() {
147 let m_for_layer = if layer == 0 { self.m_max0 } else { self.m };
148
149 let candidates = self.search_layer(vector, current_ep, self.ef_construction, layer);
151
152 let selected: Vec<(usize, f32)> = candidates
154 .into_iter()
155 .take(m_for_layer)
156 .map(|n| (n.id, n.dist))
157 .collect();
158
159 self.neighbors[id][layer] = selected.clone();
161
162 for &(neighbor_id, dist) in &selected {
164 if neighbor_id < self.neighbors.len()
165 && layer < self.neighbors[neighbor_id].len()
166 {
167 self.neighbors[neighbor_id][layer].push((id, dist));
168 if self.neighbors[neighbor_id][layer].len() > m_for_layer {
170 self.neighbors[neighbor_id][layer]
171 .sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
172 self.neighbors[neighbor_id][layer].truncate(m_for_layer);
173 }
174 }
175 }
176
177 if !selected.is_empty() {
179 current_ep = selected[0].0;
180 }
181 }
182
183 if level > self.max_layer {
185 self.entry_point = Some(id);
186 self.max_layer = level;
187 }
188
189 id
190 }
191
192 pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<(usize, f32)> {
194 if self.entry_point.is_none() {
195 return Vec::new();
196 }
197
198 let mut ep = self.entry_point.unwrap();
199
200 for layer in (1..=self.max_layer).rev() {
202 ep = self.greedy_closest(query, ep, layer);
203 }
204
205 let ef = ef.max(k);
207 let candidates = self.search_layer(query, ep, ef, 0);
208
209 candidates
210 .into_iter()
211 .take(k)
212 .map(|n| (n.id, n.dist))
213 .collect()
214 }
215
216 fn greedy_closest(&self, query: &[f32], start: usize, layer: usize) -> usize {
218 let mut best = start;
219 let mut best_dist = self.distance(query, best);
220
221 loop {
222 let mut changed = false;
223 if layer < self.neighbors[best].len() {
224 for &(neighbor, _) in &self.neighbors[best][layer] {
225 let d = self.distance(query, neighbor);
226 if d < best_dist {
227 best_dist = d;
228 best = neighbor;
229 changed = true;
230 }
231 }
232 }
233 if !changed {
234 break;
235 }
236 }
237
238 best
239 }
240
241 fn search_layer(&self, query: &[f32], ep: usize, ef: usize, layer: usize) -> Vec<Neighbor> {
243 let ep_dist = self.distance(query, ep);
244
245 let mut candidates: BinaryHeap<Reverse<Neighbor>> = BinaryHeap::new();
247 let mut result: BinaryHeap<Neighbor> = BinaryHeap::new();
249 let mut visited = vec![false; self.vectors.len()];
250
251 let ep_neighbor = Neighbor { dist: ep_dist, id: ep };
252 candidates.push(Reverse(ep_neighbor));
253 result.push(ep_neighbor);
254 visited[ep] = true;
255
256 while let Some(Reverse(current)) = candidates.pop() {
257 if result.peek().is_some_and(|f| current.dist > f.dist) {
259 break;
260 }
261
262 if layer < self.neighbors[current.id].len() {
264 for &(neighbor_id, _) in &self.neighbors[current.id][layer] {
265 if visited[neighbor_id] {
266 continue;
267 }
268 visited[neighbor_id] = true;
269
270 let d = self.distance(query, neighbor_id);
271 let n = Neighbor { dist: d, id: neighbor_id };
272
273 let should_add = result.len() < ef
274 || result.peek().is_some_and(|f| d < f.dist);
275
276 if should_add {
277 candidates.push(Reverse(n));
278 result.push(n);
279 if result.len() > ef {
280 result.pop(); }
282 }
283 }
284 }
285 }
286
287 let mut sorted: Vec<Neighbor> = result.into_vec();
289 sorted.sort();
290 sorted
291 }
292
293 #[inline]
295 fn distance(&self, query: &[f32], id: usize) -> f32 {
296 self.metric.distance(query, &self.vectors[id])
297 }
298
299 fn random_level(&mut self) -> usize {
301 let mut x = self.rng_state;
303 x ^= x << 13;
304 x ^= x >> 7;
305 x ^= x << 17;
306 self.rng_state = x;
307
308 let r = (x as f64) / (u64::MAX as f64);
309 let level = (-r.ln() * self.ml) as usize;
310 level.min(16) }
312}
313
#[cfg(test)]
mod tests {
    use super::*;

    fn make_index(dim: usize) -> HnswIndex {
        HnswIndex::new(dim, HnswConfig::default())
    }

    /// Builds a default 2-D index holding the points `(0, 0), (1, 0), …, (n - 1, 0)`.
    fn line_index(n: usize) -> HnswIndex {
        let mut idx = make_index(2);
        for i in 0..n {
            idx.insert(&[i as f32, 0.0]);
        }
        idx
    }

    #[test]
    fn empty_search() {
        let idx = make_index(3);
        let results = idx.search(&[1.0, 0.0, 0.0], 5, 50);
        assert!(results.is_empty());
    }

    #[test]
    fn single_vector() {
        let mut idx = make_index(3);
        let id = idx.insert(&[1.0, 2.0, 3.0]);
        assert_eq!(id, 0);

        let results = idx.search(&[1.0, 2.0, 3.0], 1, 50);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, 0);
        assert!(results[0].1 < 1e-6);
    }

    #[test]
    fn exact_knn_small() {
        let idx = line_index(10);

        let results = idx.search(&[5.0, 0.0], 3, 50);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 5);
    }

    #[test]
    fn insert_returns_sequential_ids() {
        let mut idx = make_index(2);
        assert!(idx.is_empty());
        for expected in 0..5 {
            assert_eq!(idx.insert(&[expected as f32, 1.0]), expected);
        }
        assert_eq!(idx.len(), 5);
        assert!(!idx.is_empty());
    }

    #[test]
    fn results_sorted_and_capped_at_len() {
        // With 10 points and the default m_max0 = 32, layer 0 is fully
        // connected, so every point is reachable.
        let idx = line_index(10);
        let results = idx.search(&[3.2, 0.0], 20, 50);
        assert_eq!(results.len(), 10);
        assert_eq!(results[0].0, 3);
        assert!(results.windows(2).all(|w| w[0].1 <= w[1].1));
    }

    #[test]
    fn zero_k_returns_nothing() {
        let idx = line_index(3);
        assert!(idx.search(&[1.0, 0.0], 0, 50).is_empty());
    }

    #[test]
    fn every_point_finds_itself() {
        let idx = line_index(10);
        for i in 0..10 {
            let results = idx.search(&[i as f32, 0.0], 1, 50);
            assert_eq!(results.len(), 1);
            assert_eq!(results[0].0, i);
        }
    }

    #[test]
    fn recall_100_vectors() {
        let dim = 16;
        let n = 100;
        let mut idx = HnswIndex::new(dim, HnswConfig {
            m: 16,
            m_max0: 32,
            ef_construction: 100,
            metric: DistanceMetric::L2,
        });

        let vectors: Vec<Vec<f32>> = (0..n)
            .map(|i| (0..dim).map(|d| ((i * 7 + d * 13) % 100) as f32 / 100.0).collect())
            .collect();

        for v in &vectors {
            idx.insert(v);
        }

        let results = idx.search(&vectors[0], 10, 100);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0);

        let mut brute: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (i, DistanceMetric::L2.distance(&vectors[0], v)))
            .collect();
        brute.sort_by(|a, b| a.1.total_cmp(&b.1));
        let brute_top10: Vec<usize> = brute.iter().take(10).map(|r| r.0).collect();

        let hnsw_top10: Vec<usize> = results.iter().take(10).map(|r| r.0).collect();
        let hits: usize = hnsw_top10.iter().filter(|id| brute_top10.contains(id)).count();
        let recall = hits as f64 / 10.0;
        assert!(recall >= 0.7, "recall@10 = {recall}, expected >= 0.7");
    }

    #[test]
    fn cosine_metric() {
        let mut idx = HnswIndex::new(3, HnswConfig {
            metric: DistanceMetric::Cosine,
            ..HnswConfig::default()
        });

        idx.insert(&[1.0, 0.0, 0.0]);
        idx.insert(&[0.0, 1.0, 0.0]);
        idx.insert(&[0.9, 0.1, 0.0]);

        let results = idx.search(&[1.0, 0.0, 0.0], 3, 50);
        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, 0);
        assert!(results[0].1 < 1e-5);
    }

    #[test]
    fn insert_respects_dimension() {
        let mut idx = make_index(4);
        idx.insert(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(idx.len(), 1);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn dimension_mismatch_panics() {
        let mut idx = make_index(4);
        idx.insert(&[1.0, 2.0]);
    }

    #[test]
    fn level_distribution() {
        let mut idx = make_index(2);
        let mut max_level = 0;
        for i in 0..1000 {
            idx.insert(&[i as f32, 0.0]);
            max_level = max_level.max(idx.max_layer);
        }
        assert!(max_level <= 8, "max_level = {max_level}, unexpectedly high");
    }
}