1use std::collections::{BinaryHeap, HashMap, HashSet};
19
20use ordered_float::OrderedFloat;
21
22use crate::distance::DistanceMetric;
23
24pub struct HnswIndex {
26 dim: usize,
27 metric: DistanceMetric,
28 m: usize,
29 m_max0: usize,
30 ef_construction: usize,
31 ml: f64,
32 nodes: HashMap<u64, Node>,
33 entry_point: Option<u64>,
34 max_layer: usize,
35 rng_state: u64,
36}
37
/// One vector stored in the graph together with its per-layer adjacency.
struct Node {
    /// Id the node is stored under in `HnswIndex::nodes`.
    id: u64,
    /// The stored vector (length `HnswIndex::dim`).
    vector: Vec<f32>,
    /// Highest layer this node lives on; `neighbors.len() == layer + 1`.
    layer: usize,
    /// `neighbors[l]` holds the ids this node links to on layer `l`.
    neighbors: Vec<Vec<u64>>,
    /// Lazy-deletion flag: excluded from results, still used for routing.
    deleted: bool,
}
46
/// A single hit returned by `HnswIndex::search`.
#[derive(Debug, Clone, PartialEq)]
pub struct SearchResult {
    /// Id the matching vector was inserted under.
    pub id: u64,
    /// Distance from the query under the index's metric; results are
    /// returned in ascending order of this value (smaller is closer).
    pub distance: f32,
}
55
56#[derive(PartialEq, Eq)]
57struct MinItem(OrderedFloat<f32>, u64);
58
59impl Ord for MinItem {
60 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
61 other.0.cmp(&self.0)
62 }
63}
64impl PartialOrd for MinItem {
65 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
66 Some(self.cmp(other))
67 }
68}
69
70#[derive(PartialEq, Eq)]
71struct MaxItem(OrderedFloat<f32>, u64);
72
73impl Ord for MaxItem {
74 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
75 self.0.cmp(&other.0)
76 }
77}
78impl PartialOrd for MaxItem {
79 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
80 Some(self.cmp(other))
81 }
82}
83
84impl HnswIndex {
85 pub fn new(dim: usize, metric: DistanceMetric, m: usize, ef_construction: usize) -> Self {
92 Self {
93 dim,
94 metric,
95 m,
96 m_max0: m * 2,
97 ef_construction,
98 ml: 1.0 / (m as f64).ln(),
99 nodes: HashMap::new(),
100 entry_point: None,
101 max_layer: 0,
102 rng_state: 42,
103 }
104 }
105
106 pub fn len(&self) -> usize {
108 self.nodes.values().filter(|n| !n.deleted).count()
109 }
110
111 pub fn is_empty(&self) -> bool {
113 self.len() == 0
114 }
115
116 pub fn dim(&self) -> usize {
118 self.dim
119 }
120
121 pub fn metric(&self) -> DistanceMetric {
123 self.metric
124 }
125
126 pub fn insert(&mut self, id: u64, vector: &[f32]) {
130 assert_eq!(
131 vector.len(),
132 self.dim,
133 "vector dimension mismatch: expected {}, got {}",
134 self.dim,
135 vector.len()
136 );
137
138 if self.nodes.contains_key(&id) {
139 self.delete(id);
140 }
141
142 let level = self.random_level();
143 let vector = vector.to_vec();
144
145 let mut neighbors = Vec::with_capacity(level + 1);
146 for _ in 0..=level {
147 neighbors.push(Vec::new());
148 }
149
150 let node = Node {
151 id,
152 vector,
153 layer: level,
154 neighbors,
155 deleted: false,
156 };
157
158 self.nodes.insert(id, node);
159
160 if self.entry_point.is_none() {
161 self.entry_point = Some(id);
162 self.max_layer = level;
163 return;
164 }
165
166 let ep = match self.entry_point {
167 Some(ep) => ep,
168 None => return,
169 };
170
171 let mut current_ep = ep;
172 let query = &self.nodes[&id].vector.clone();
173
174 for lc in (level + 1..=self.max_layer).rev() {
175 current_ep = self.greedy_closest(query, current_ep, lc);
176 }
177
178 let insert_top = level.min(self.max_layer);
179 let mut ep_set = vec![current_ep];
180
181 for lc in (0..=insert_top).rev() {
182 let m_max = if lc == 0 { self.m_max0 } else { self.m };
183
184 let candidates = self.search_layer(query, &ep_set, self.ef_construction, lc);
185
186 let selected: Vec<u64> = candidates.iter().take(m_max).map(|&(_, nid)| nid).collect();
187
188 if let Some(node) = self.nodes.get_mut(&id) {
189 node.neighbors[lc] = selected.clone();
190 }
191
192 for &neighbor_id in &selected {
193 let needs_prune = {
194 let Some(neighbor) = self.nodes.get_mut(&neighbor_id) else {
195 continue;
196 };
197 if lc < neighbor.neighbors.len() {
198 neighbor.neighbors[lc].push(id);
199 neighbor.neighbors[lc].len() > m_max
200 } else {
201 false
202 }
203 };
204
205 if needs_prune {
206 let nv = self.nodes[&neighbor_id].vector.clone();
207 let neighbor_ids: Vec<u64> = self.nodes[&neighbor_id].neighbors[lc].clone();
208 let mut scored: Vec<(f32, u64)> = neighbor_ids
209 .iter()
210 .map(|&nid| {
211 let dist = self.metric.distance(&nv, &self.nodes[&nid].vector);
212 (dist, nid)
213 })
214 .collect();
215 scored
216 .sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
217 scored.truncate(m_max);
218 if let Some(neighbor) = self.nodes.get_mut(&neighbor_id) {
219 neighbor.neighbors[lc] = scored.into_iter().map(|(_, nid)| nid).collect();
220 }
221 }
222 }
223
224 ep_set = candidates.iter().map(|&(_, nid)| nid).collect();
225 }
226
227 if level > self.max_layer {
228 self.entry_point = Some(id);
229 self.max_layer = level;
230 }
231 }
232
233 pub fn delete(&mut self, id: u64) {
235 if let Some(node) = self.nodes.get_mut(&id) {
236 node.deleted = true;
237 }
238
239 if self.entry_point == Some(id) {
240 self.entry_point = self
241 .nodes
242 .values()
243 .filter(|n| !n.deleted)
244 .max_by_key(|n| n.layer)
245 .map(|n| n.id);
246 if let Some(ep) = self.entry_point {
247 self.max_layer = self.nodes[&ep].layer;
248 } else {
249 self.max_layer = 0;
250 }
251 }
252 }
253
254 pub fn search(&self, query: &[f32], k: usize, ef: usize) -> Vec<SearchResult> {
260 assert_eq!(query.len(), self.dim);
261
262 let ep = match self.entry_point {
263 Some(ep) if !self.nodes[&ep].deleted || !self.is_empty() => ep,
264 _ => return vec![],
265 };
266
267 let mut current_ep = ep;
268 for lc in (1..=self.max_layer).rev() {
269 current_ep = self.greedy_closest(query, current_ep, lc);
270 }
271
272 let ef = ef.max(k);
273 let candidates = self.search_layer(query, &[current_ep], ef, 0);
274
275 candidates
276 .into_iter()
277 .filter(|&(_, id)| !self.nodes[&id].deleted)
278 .take(k)
279 .map(|(dist, id)| SearchResult { id, distance: dist })
280 .collect()
281 }
282
283 pub fn contains(&self, id: u64) -> bool {
285 self.nodes.get(&id).is_some_and(|n| !n.deleted)
286 }
287
288 fn random_level(&mut self) -> usize {
289 self.rng_state ^= self.rng_state << 13;
290 self.rng_state ^= self.rng_state >> 7;
291 self.rng_state ^= self.rng_state << 17;
292
293 let r = (self.rng_state as f64) / (u64::MAX as f64);
294 (-r.ln() * self.ml) as usize
295 }
296
297 fn greedy_closest(&self, query: &[f32], mut ep: u64, layer: usize) -> u64 {
298 let mut best_dist = self.metric.distance(query, &self.nodes[&ep].vector);
299
300 loop {
301 let mut changed = false;
302 let node = &self.nodes[&ep];
303 if layer < node.neighbors.len() {
304 for &neighbor_id in &node.neighbors[layer] {
305 if let Some(neighbor) = self.nodes.get(&neighbor_id) {
306 let dist = self.metric.distance(query, &neighbor.vector);
307 if dist < best_dist {
308 best_dist = dist;
309 ep = neighbor_id;
310 changed = true;
311 }
312 }
313 }
314 }
315 if !changed {
316 break;
317 }
318 }
319 ep
320 }
321
322 fn search_layer(
323 &self,
324 query: &[f32],
325 entry_points: &[u64],
326 ef: usize,
327 layer: usize,
328 ) -> Vec<(f32, u64)> {
329 let mut visited = HashSet::new();
330 let mut candidates: BinaryHeap<MinItem> = BinaryHeap::new();
331 let mut results: BinaryHeap<MaxItem> = BinaryHeap::new();
332
333 for &ep in entry_points {
334 if !self.nodes.contains_key(&ep) {
335 continue;
336 }
337 let dist = self.metric.distance(query, &self.nodes[&ep].vector);
338 visited.insert(ep);
339 candidates.push(MinItem(OrderedFloat(dist), ep));
340 results.push(MaxItem(OrderedFloat(dist), ep));
341 }
342
343 while let Some(MinItem(c_dist, c_id)) = candidates.pop() {
344 let f_dist = results
345 .peek()
346 .map(|r| r.0)
347 .unwrap_or(OrderedFloat(f32::MAX));
348 if c_dist > f_dist {
349 break;
350 }
351
352 let node = match self.nodes.get(&c_id) {
353 Some(n) => n,
354 None => continue,
355 };
356
357 if layer < node.neighbors.len() {
358 for &neighbor_id in &node.neighbors[layer] {
359 if !visited.insert(neighbor_id) {
360 continue;
361 }
362 if let Some(neighbor) = self.nodes.get(&neighbor_id) {
363 let dist = self.metric.distance(query, &neighbor.vector);
364 let f_dist = results
365 .peek()
366 .map(|r| r.0)
367 .unwrap_or(OrderedFloat(f32::MAX));
368
369 if dist < f_dist.0 || results.len() < ef {
370 candidates.push(MinItem(OrderedFloat(dist), neighbor_id));
371 results.push(MaxItem(OrderedFloat(dist), neighbor_id));
372 if results.len() > ef {
373 results.pop();
374 }
375 }
376 }
377 }
378 }
379 }
380
381 let mut result: Vec<(f32, u64)> = results
382 .into_iter()
383 .map(|MaxItem(d, id)| (d.0, id))
384 .collect();
385 result.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
386 result
387 }
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an L2 index (m = 16, ef_construction = 200) over `count`
    /// synthetic vectors of length `dim`. Vector `i` holds
    /// `(i * dim + d) * 0.01` at position `d` and is inserted under id `i`.
    /// Returns the index together with the vectors, in id order.
    fn make_index(count: usize, dim: usize) -> (HnswIndex, Vec<Vec<f32>>) {
        let vectors: Vec<Vec<f32>> = (0..count)
            .map(|i| (0..dim).map(|d| ((i * dim + d) as f32) * 0.01).collect())
            .collect();

        let mut index = HnswIndex::new(dim, DistanceMetric::L2, 16, 200);
        for (id, v) in (0u64..).zip(&vectors) {
            index.insert(id, v);
        }
        (index, vectors)
    }

    #[test]
    fn test_insert_and_search() {
        let (index, vectors) = make_index(100, 8);
        assert_eq!(index.len(), 100);

        // A stored vector used as the query should come back as its own top hit.
        let hits = index.search(&vectors[42], 5, 50);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].id, 42);
        assert!(hits[0].distance < 1e-6);
    }

    #[test]
    fn test_search_empty_index() {
        let index = HnswIndex::new(4, DistanceMetric::L2, 16, 200);
        let hits = index.search(&[1.0, 2.0, 3.0, 4.0], 5, 50);
        assert!(hits.is_empty());
    }

    #[test]
    fn test_single_vector() {
        let mut index = HnswIndex::new(3, DistanceMetric::L2, 16, 200);
        index.insert(1, &[1.0, 2.0, 3.0]);

        let hits = index.search(&[1.0, 2.0, 3.0], 5, 50);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, 1);
    }

    #[test]
    fn test_delete() {
        let (mut index, vectors) = make_index(50, 4);
        assert_eq!(index.len(), 50);

        index.delete(25);
        assert_eq!(index.len(), 49);
        assert!(!index.contains(25));

        // The deleted id must not be returned, even when querying with its own vector.
        let hits = index.search(&vectors[25], 5, 50);
        assert!(!hits.iter().any(|hit| hit.id == 25));
    }

    #[test]
    fn test_cosine_metric() {
        let mut index = HnswIndex::new(3, DistanceMetric::Cosine, 16, 200);
        index.insert(1, &[1.0, 0.0, 0.0]);
        index.insert(2, &[2.0, 0.0, 0.0]);
        index.insert(3, &[0.0, 1.0, 0.0]);

        let hits = index.search(&[3.0, 0.0, 0.0], 3, 50);
        assert!(hits.len() >= 2);
        // Ids 1 and 2 point in the query's direction; either may rank first.
        assert!(matches!(hits[0].id, 1 | 2));
    }

    #[test]
    fn test_inner_product() {
        let mut index = HnswIndex::new(2, DistanceMetric::InnerProduct, 16, 200);
        index.insert(1, &[1.0, 0.0]);
        index.insert(2, &[0.0, 1.0]);
        index.insert(3, &[10.0, 0.0]);

        // Id 3 has the largest dot product with the query and is expected first.
        let hits = index.search(&[1.0, 0.0], 3, 50);
        assert_eq!(hits[0].id, 3);
    }

    #[test]
    fn test_recall_quality() {
        let (n, dim, k) = (500, 16, 10);
        let (index, vectors) = make_index(n, dim);
        let query = &vectors[0];

        // Brute-force ground truth: the k ids with the smallest exact L2 distance.
        let mut exact: Vec<(f32, u64)> = (0u64..)
            .zip(&vectors)
            .map(|(id, v)| (DistanceMetric::L2.distance(query, v), id))
            .collect();
        exact.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
        let ground_truth: HashSet<u64> = exact.iter().take(k).map(|&(_, id)| id).collect();

        let found: HashSet<u64> = index
            .search(query, k, 100)
            .into_iter()
            .map(|hit| hit.id)
            .collect();

        let recall = ground_truth.intersection(&found).count() as f32 / k as f32;
        assert!(
            recall >= 0.8,
            "Recall too low: {:.2} (expected >= 0.80)",
            recall
        );
    }

    #[test]
    fn test_duplicate_insert() {
        let mut index = HnswIndex::new(3, DistanceMetric::L2, 16, 200);
        index.insert(1, &[1.0, 2.0, 3.0]);
        // Re-inserting an id replaces its vector instead of adding a second entry.
        index.insert(1, &[4.0, 5.0, 6.0]);
        assert_eq!(index.len(), 1);

        let hits = index.search(&[4.0, 5.0, 6.0], 1, 50);
        assert_eq!(hits[0].id, 1);
        assert!(hits[0].distance < 1e-6);
    }

    #[test]
    fn test_k_larger_than_index() {
        let (index, _) = make_index(5, 4);
        // Asking for more neighbours than exist returns every stored vector.
        let hits = index.search(&[0.0; 4], 100, 200);
        assert_eq!(hits.len(), 5);
    }

    #[test]
    fn test_contains() {
        let mut index = HnswIndex::new(3, DistanceMetric::L2, 16, 200);
        assert!(!index.contains(1));

        index.insert(1, &[1.0, 2.0, 3.0]);
        assert!(index.contains(1));

        index.delete(1);
        assert!(!index.contains(1));
    }
}