1use std::cmp::Reverse;
10use std::collections::BinaryHeap;
11
12use crate::distance::DistanceMetric;
13
/// A graph node paired with its distance to the current query.
///
/// The trait impls that follow order values by distance and then by id, which is what the
/// candidate/result heaps in `HnswIndex::search_layer` rely on.
#[derive(Copy, Clone)]
struct Neighbor {
    /// Distance from the query to this node, as reported by the index's metric.
    dist: f32,
    /// Node id (position of the vector in the index's storage).
    id: usize,
}
20
21impl PartialEq for Neighbor {
22 fn eq(&self, other: &Self) -> bool {
23 self.dist.to_bits() == other.dist.to_bits() && self.id == other.id
24 }
25}
26
27impl Eq for Neighbor {}
28
29impl PartialOrd for Neighbor {
30 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
31 Some(self.cmp(other))
32 }
33}
34
35impl Ord for Neighbor {
36 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
37 self.dist
38 .partial_cmp(&other.dist)
39 .unwrap_or(std::cmp::Ordering::Equal)
40 .then(self.id.cmp(&other.id))
41 }
42}
43
44pub struct HnswConfig {
46 pub m: usize,
48 pub m_max0: usize,
50 pub ef_construction: usize,
52 pub metric: DistanceMetric,
54}
55
56impl Default for HnswConfig {
57 fn default() -> Self {
58 Self {
59 m: 16,
60 m_max0: 32,
61 ef_construction: 200,
62 metric: DistanceMetric::L2,
63 }
64 }
65}
66
67pub struct HnswIndex {
69 vectors: Vec<Vec<f32>>,
70 neighbors: Vec<Vec<Vec<(usize, f32)>>>,
72 entry_point: Option<usize>,
73 max_layer: usize,
74 m: usize,
75 m_max0: usize,
76 ef_construction: usize,
77 dim: usize,
78 metric: DistanceMetric,
79 ml: f64, rng_state: u64,
81}
82
83impl HnswIndex {
84 pub fn new(dim: usize, config: HnswConfig) -> Self {
86 let ml = 1.0 / (config.m as f64).ln();
87 Self {
88 vectors: Vec::new(),
89 neighbors: Vec::new(),
90 entry_point: None,
91 max_layer: 0,
92 m: config.m,
93 m_max0: config.m_max0,
94 ef_construction: config.ef_construction,
95 dim,
96 metric: config.metric,
97 ml,
98 rng_state: 0x5DEECE66D, }
100 }
101
102 pub fn len(&self) -> usize {
104 self.vectors.len()
105 }
106
107 pub fn is_empty(&self) -> bool {
109 self.vectors.is_empty()
110 }
111
112 pub fn insert(&mut self, vector: &[f32]) -> usize {
114 assert_eq!(vector.len(), self.dim, "dimension mismatch");
115
116 let id = self.vectors.len();
117 self.vectors.push(vector.to_vec());
118
119 let level = self.random_level();
120
121 let mut layers = Vec::with_capacity(level + 1);
123 for _ in 0..=level {
124 layers.push(Vec::new());
125 }
126 self.neighbors.push(layers);
127
128 if self.entry_point.is_none() {
129 self.entry_point = Some(id);
131 self.max_layer = level;
132 return id;
133 }
134
135 let ep = self.entry_point.unwrap();
136 let mut current_ep = ep;
137
138 for layer in (level + 1..=self.max_layer).rev() {
140 current_ep = self.greedy_closest(vector, current_ep, layer);
141 }
142
143 let start_layer = level.min(self.max_layer);
146 for layer in (0..=start_layer).rev() {
147 let m_for_layer = if layer == 0 { self.m_max0 } else { self.m };
148
149 let candidates = self.search_layer(vector, current_ep, self.ef_construction, layer);
151
152 let selected: Vec<(usize, f32)> = candidates
154 .into_iter()
155 .take(m_for_layer)
156 .map(|n| (n.id, n.dist))
157 .collect();
158
159 self.neighbors[id][layer] = selected.clone();
161
162 for &(neighbor_id, dist) in &selected {
164 if neighbor_id < self.neighbors.len() && layer < self.neighbors[neighbor_id].len() {
165 self.neighbors[neighbor_id][layer].push((id, dist));
166 if self.neighbors[neighbor_id][layer].len() > m_for_layer {
168 self.neighbors[neighbor_id][layer].sort_by(|a, b| {
169 a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal)
170 });
171 self.neighbors[neighbor_id][layer].truncate(m_for_layer);
172 }
173 }
174 }
175
176 if !selected.is_empty() {
178 current_ep = selected[0].0;
179 }
180 }
181
182 if level > self.max_layer {
184 self.entry_point = Some(id);
185 self.max_layer = level;
186 }
187
188 id
189 }
190
191 pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<(usize, f32)> {
193 if self.entry_point.is_none() {
194 return Vec::new();
195 }
196
197 let mut ep = self.entry_point.unwrap();
198
199 for layer in (1..=self.max_layer).rev() {
201 ep = self.greedy_closest(query, ep, layer);
202 }
203
204 let ef = ef.max(k);
206 let candidates = self.search_layer(query, ep, ef, 0);
207
208 candidates
209 .into_iter()
210 .take(k)
211 .map(|n| (n.id, n.dist))
212 .collect()
213 }
214
215 fn greedy_closest(&self, query: &[f32], start: usize, layer: usize) -> usize {
217 let mut best = start;
218 let mut best_dist = self.distance(query, best);
219
220 loop {
221 let mut changed = false;
222 if layer < self.neighbors[best].len() {
223 for &(neighbor, _) in &self.neighbors[best][layer] {
224 let d = self.distance(query, neighbor);
225 if d < best_dist {
226 best_dist = d;
227 best = neighbor;
228 changed = true;
229 }
230 }
231 }
232 if !changed {
233 break;
234 }
235 }
236
237 best
238 }
239
240 fn search_layer(&self, query: &[f32], ep: usize, ef: usize, layer: usize) -> Vec<Neighbor> {
242 let ep_dist = self.distance(query, ep);
243
244 let mut candidates: BinaryHeap<Reverse<Neighbor>> = BinaryHeap::new();
246 let mut result: BinaryHeap<Neighbor> = BinaryHeap::new();
248 let mut visited = vec![false; self.vectors.len()];
249
250 let ep_neighbor = Neighbor {
251 dist: ep_dist,
252 id: ep,
253 };
254 candidates.push(Reverse(ep_neighbor));
255 result.push(ep_neighbor);
256 visited[ep] = true;
257
258 while let Some(Reverse(current)) = candidates.pop() {
259 if result.peek().is_some_and(|f| current.dist > f.dist) {
261 break;
262 }
263
264 if layer < self.neighbors[current.id].len() {
266 for &(neighbor_id, _) in &self.neighbors[current.id][layer] {
267 if visited[neighbor_id] {
268 continue;
269 }
270 visited[neighbor_id] = true;
271
272 let d = self.distance(query, neighbor_id);
273 let n = Neighbor {
274 dist: d,
275 id: neighbor_id,
276 };
277
278 let should_add = result.len() < ef || result.peek().is_some_and(|f| d < f.dist);
279
280 if should_add {
281 candidates.push(Reverse(n));
282 result.push(n);
283 if result.len() > ef {
284 result.pop(); }
286 }
287 }
288 }
289 }
290
291 let mut sorted: Vec<Neighbor> = result.into_vec();
293 sorted.sort();
294 sorted
295 }
296
297 #[inline]
299 fn distance(&self, query: &[f32], id: usize) -> f32 {
300 self.metric.distance(query, &self.vectors[id])
301 }
302
303 fn random_level(&mut self) -> usize {
305 let mut x = self.rng_state;
307 x ^= x << 13;
308 x ^= x >> 7;
309 x ^= x << 17;
310 self.rng_state = x;
311
312 let r = (x as f64) / (u64::MAX as f64);
313 let level = (-r.ln() * self.ml) as usize;
314 level.min(16) }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an index over `dim`-dimensional vectors with the default configuration.
    fn make_index(dim: usize) -> HnswIndex {
        HnswIndex::new(dim, HnswConfig::default())
    }

    #[test]
    fn empty_search() {
        let found = make_index(3).search(&[1.0, 0.0, 0.0], 5, 50);
        assert!(found.is_empty());
    }

    #[test]
    fn single_vector() {
        let mut index = make_index(3);
        assert_eq!(index.insert(&[1.0, 2.0, 3.0]), 0);

        let found = index.search(&[1.0, 2.0, 3.0], 1, 50);
        assert_eq!(found.len(), 1);
        let (id, dist) = found[0];
        assert_eq!(id, 0);
        assert!(dist < 1e-6);
    }

    #[test]
    fn exact_knn_small() {
        let mut index = make_index(2);
        // Ten points spaced one unit apart along the x-axis, inserted in order.
        for x in 0..10 {
            let point: [f32; 2] = [x as f32, 0.0];
            index.insert(&point);
        }

        let found = index.search(&[5.0, 0.0], 3, 50);
        assert!(!found.is_empty());
        assert_eq!(found[0].0, 5);
    }

    #[test]
    fn recall_100_vectors() {
        let dim = 16;
        let n = 100;
        let config = HnswConfig {
            m: 16,
            m_max0: 32,
            ef_construction: 100,
            metric: DistanceMetric::L2,
        };
        let mut index = HnswIndex::new(dim, config);

        // Deterministic, pairwise-distinct vectors with coordinates in [0, 1).
        let mut vectors: Vec<Vec<f32>> = Vec::with_capacity(n);
        for i in 0..n {
            let row: Vec<f32> = (0..dim)
                .map(|d| ((i * 7 + d * 13) % 100) as f32 / 100.0)
                .collect();
            vectors.push(row);
        }
        for v in &vectors {
            index.insert(v);
        }

        let query = &vectors[0];
        let results = index.search(query, 10, 100);
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0);

        // Exact top-10 by brute force, used as ground truth for recall.
        let mut brute: Vec<(usize, f32)> = vectors
            .iter()
            .enumerate()
            .map(|(i, v)| (i, DistanceMetric::L2.distance(query, v)))
            .collect();
        brute.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        let truth: Vec<usize> = brute.iter().take(10).map(|&(i, _)| i).collect();

        let hits = results
            .iter()
            .take(10)
            .filter(|hit| truth.contains(&hit.0))
            .count();
        let recall = hits as f64 / 10.0;
        assert!(recall >= 0.7, "recall@10 = {recall}, expected >= 0.7");
    }

    #[test]
    fn cosine_metric() {
        let config = HnswConfig {
            metric: DistanceMetric::Cosine,
            ..HnswConfig::default()
        };
        let mut index = HnswIndex::new(3, config);

        let points: [[f32; 3]; 3] = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.9, 0.1, 0.0]];
        for p in &points {
            index.insert(p);
        }

        let found = index.search(&[1.0, 0.0, 0.0], 3, 50);
        assert_eq!(found.len(), 3);
        assert_eq!(found[0].0, 0);
        assert!(found[0].1 < 1e-5);
    }

    #[test]
    fn insert_respects_dimension() {
        let mut index = make_index(4);
        index.insert(&[1.0, 2.0, 3.0, 4.0]);
        assert_eq!(index.len(), 1);
    }

    #[test]
    #[should_panic(expected = "dimension mismatch")]
    fn dimension_mismatch_panics() {
        make_index(4).insert(&[1.0, 2.0]);
    }

    #[test]
    fn level_distribution() {
        let mut index = make_index(2);
        // Track the highest layer ever reached while inserting 1000 points.
        let max_level = (0..1000).fold(0, |highest, i| {
            index.insert(&[i as f32, 0.0]);
            highest.max(index.max_layer)
        });
        assert!(max_level <= 8, "max_level = {max_level}, unexpectedly high");
    }
}