1use common::types::{DistanceMetric, VectorId};
15use parking_lot::RwLock;
16use rand::Rng;
17use std::cmp::Ordering;
18use std::collections::{BinaryHeap, HashMap, HashSet};
19
20use crate::distance::calculate_distance;
21
22#[inline]
26fn similarity_to_distance(similarity: f32, metric: DistanceMetric) -> f32 {
27 match metric {
28 DistanceMetric::Cosine => 1.0 - similarity,
30 DistanceMetric::Euclidean => -similarity,
32 DistanceMetric::DotProduct => -similarity,
34 }
35}
36
/// Tuning parameters for an `HnswIndex`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct HnswConfig {
    /// Neighbor budget on layers above 0: the number of neighbors picked for a
    /// new node and the cap enforced when a link list is pruned.
    pub m: usize,
    /// Same role as `m`, but for layer 0.
    pub m_max0: usize,
    /// Result-set size (`ef`) used by the per-layer searches during `insert`.
    pub ef_construction: usize,
    /// Default `ef` for `HnswIndex::search`; `search_with_ef` raises it to at
    /// least `k`.
    pub ef_search: usize,
    /// Factor applied to `-ln(u)` when drawing a new node's top layer.
    pub level_multiplier: f64,
    /// Metric handed to `calculate_distance` for every comparison.
    pub distance_metric: DistanceMetric,
}
53
54impl Default for HnswConfig {
55 fn default() -> Self {
56 let m = 16;
57 Self {
58 m,
59 m_max0: m * 2,
60 ef_construction: 200,
61 ef_search: 50,
62 level_multiplier: 1.0 / (m as f64).ln(),
63 distance_metric: DistanceMetric::Cosine,
64 }
65 }
66}
67
68impl HnswConfig {
69 pub fn new(m: usize, ef_construction: usize, ef_search: usize) -> Self {
70 Self {
71 m,
72 m_max0: m * 2,
73 ef_construction,
74 ef_search,
75 level_multiplier: 1.0 / (m as f64).ln(),
76 distance_metric: DistanceMetric::Cosine,
77 }
78 }
79
80 pub fn with_distance_metric(mut self, metric: DistanceMetric) -> Self {
81 self.distance_metric = metric;
82 self
83 }
84}
85
/// One vertex of the graph, stored by position in `HnswIndex::nodes`.
#[derive(Debug)]
struct HnswNode {
    /// External identifier returned by searches.
    id: VectorId,
    /// The stored vector; `delete` empties it to mark the node as removed.
    vector: Vec<f32>,
    /// `connections[layer]` holds the indices (into the node list) of this
    /// node's neighbors on that layer.
    connections: Vec<Vec<usize>>,
    /// Highest layer this node was created on.
    max_layer: usize,
}
98
/// A node index paired with its distance to the query.
///
/// The ordering is reversed on `distance`, so a `BinaryHeap<Candidate>` pops the
/// *closest* candidate first. Equality and ordering look at `distance` only;
/// incomparable distances (NaN) are ordered as equal.
#[derive(Debug, Clone)]
struct Candidate {
    node_idx: usize,
    distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, rhs: &Self) -> bool {
        matches!(
            self.distance.partial_cmp(&rhs.distance),
            Some(Ordering::Equal)
        )
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(self.cmp(rhs))
    }
}

impl Ord for Candidate {
    fn cmp(&self, rhs: &Self) -> Ordering {
        // Operands are swapped on purpose: smaller distance == greater candidate.
        match rhs.distance.partial_cmp(&self.distance) {
            Some(ordering) => ordering,
            None => Ordering::Equal,
        }
    }
}
129
/// A node index paired with its distance to the query, ordered naturally by
/// `distance`.
///
/// A `BinaryHeap<FurthestCandidate>` therefore keeps the *furthest* entry on
/// top, which is what the bounded result set evicts first. Equality and ordering
/// look at `distance` only; incomparable distances (NaN) are ordered as equal.
#[derive(Debug, Clone)]
struct FurthestCandidate {
    node_idx: usize,
    distance: f32,
}

impl PartialEq for FurthestCandidate {
    fn eq(&self, rhs: &Self) -> bool {
        matches!(
            self.distance.partial_cmp(&rhs.distance),
            Some(Ordering::Equal)
        )
    }
}

impl Eq for FurthestCandidate {}

impl PartialOrd for FurthestCandidate {
    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
        Some(self.cmp(rhs))
    }
}

impl Ord for FurthestCandidate {
    fn cmp(&self, rhs: &Self) -> Ordering {
        match self.distance.partial_cmp(&rhs.distance) {
            Some(ordering) => ordering,
            None => Ordering::Equal,
        }
    }
}
159
/// A layered nearest-neighbor graph index whose state is split across several
/// `parking_lot` read/write locks.
pub struct HnswIndex {
    /// Parameters fixed at construction time.
    config: HnswConfig,
    /// Every node ever inserted, addressed by position. Deleted nodes stay in
    /// place with an emptied vector so that indices remain stable.
    nodes: RwLock<Vec<HnswNode>>,
    /// Index of the node searches start from; `None` when there is none.
    entry_point: RwLock<Option<usize>>,
    /// Highest layer recorded so far.
    max_level: RwLock<usize>,
    /// Maps external ids to positions in `nodes`.
    id_to_idx: RwLock<HashMap<VectorId, usize>>,
    /// Vector length fixed by the first insert; `None` until then.
    dimension: RwLock<Option<usize>>,
}
174
175impl HnswIndex {
176 pub fn new() -> Self {
178 Self::with_config(HnswConfig::default())
179 }
180
181 pub fn with_config(config: HnswConfig) -> Self {
183 Self {
184 config,
185 nodes: RwLock::new(Vec::new()),
186 entry_point: RwLock::new(None),
187 max_level: RwLock::new(0),
188 id_to_idx: RwLock::new(HashMap::new()),
189 dimension: RwLock::new(None),
190 }
191 }
192
193 fn random_level(&self) -> usize {
195 let mut rng = rand::thread_rng();
196 let uniform: f64 = rng.gen();
197
198 (-uniform.ln() * self.config.level_multiplier).floor() as usize
199 }
200
201 fn distance(&self, query: &[f32], node_idx: usize, nodes: &[HnswNode]) -> f32 {
204 similarity_to_distance(
205 calculate_distance(query, &nodes[node_idx].vector, self.config.distance_metric),
206 self.config.distance_metric,
207 )
208 }
209
210 fn search_layer(
212 &self,
213 query: &[f32],
214 entry_points: Vec<usize>,
215 ef: usize,
216 layer: usize,
217 nodes: &[HnswNode],
218 ) -> Vec<Candidate> {
219 let mut visited: HashSet<usize> = HashSet::new();
220 let mut candidates: BinaryHeap<Candidate> = BinaryHeap::new();
221 let mut results: BinaryHeap<FurthestCandidate> = BinaryHeap::new();
222
223 for &ep in &entry_points {
225 visited.insert(ep);
226 let dist = self.distance(query, ep, nodes);
227 candidates.push(Candidate {
228 node_idx: ep,
229 distance: dist,
230 });
231 results.push(FurthestCandidate {
232 node_idx: ep,
233 distance: dist,
234 });
235 }
236
237 while let Some(candidate) = candidates.pop() {
238 let furthest_dist = results.peek().map(|r| r.distance).unwrap_or(f32::MAX);
240
241 if candidate.distance > furthest_dist && results.len() >= ef {
243 break;
244 }
245
246 let node = &nodes[candidate.node_idx];
248 if layer < node.connections.len() {
249 for &neighbor_idx in &node.connections[layer] {
250 if visited.insert(neighbor_idx) {
251 let dist = self.distance(query, neighbor_idx, nodes);
252
253 let should_add = results.len() < ef
254 || dist < results.peek().map(|r| r.distance).unwrap_or(f32::MAX);
255
256 if should_add {
257 candidates.push(Candidate {
258 node_idx: neighbor_idx,
259 distance: dist,
260 });
261 results.push(FurthestCandidate {
262 node_idx: neighbor_idx,
263 distance: dist,
264 });
265
266 while results.len() > ef {
268 results.pop();
269 }
270 }
271 }
272 }
273 }
274 }
275
276 let mut final_results: Vec<Candidate> = results
278 .into_iter()
279 .map(|fc| Candidate {
280 node_idx: fc.node_idx,
281 distance: fc.distance,
282 })
283 .collect();
284 final_results.sort_by(|a, b| {
285 a.distance
286 .partial_cmp(&b.distance)
287 .unwrap_or(Ordering::Equal)
288 });
289 final_results
290 }
291
292 fn select_neighbors_simple(&self, candidates: &[Candidate], m: usize) -> Vec<usize> {
294 candidates.iter().take(m).map(|c| c.node_idx).collect()
295 }
296
297 fn select_neighbors_heuristic(
299 &self,
300 query: &[f32],
301 candidates: &[Candidate],
302 m: usize,
303 nodes: &[HnswNode],
304 extend_candidates: bool,
305 ) -> Vec<usize> {
306 let mut working_candidates = candidates.to_vec();
307
308 if extend_candidates {
310 let mut extended: HashSet<usize> =
311 working_candidates.iter().map(|c| c.node_idx).collect();
312 for candidate in candidates.iter().take(m) {
313 let node = &nodes[candidate.node_idx];
314 for layer_connections in &node.connections {
315 for &neighbor in layer_connections {
316 if extended.insert(neighbor) {
317 let dist = self.distance(query, neighbor, nodes);
318 working_candidates.push(Candidate {
319 node_idx: neighbor,
320 distance: dist,
321 });
322 }
323 }
324 }
325 }
326 working_candidates.sort_by(|a, b| {
327 a.distance
328 .partial_cmp(&b.distance)
329 .unwrap_or(Ordering::Equal)
330 });
331 }
332
333 let mut selected: Vec<usize> = Vec::with_capacity(m);
335
336 for candidate in &working_candidates {
337 if selected.len() >= m {
338 break;
339 }
340
341 let mut is_good = true;
343 for &sel_idx in &selected {
344 let dist_to_selected = calculate_distance(
345 &nodes[candidate.node_idx].vector,
346 &nodes[sel_idx].vector,
347 self.config.distance_metric,
348 );
349 if dist_to_selected < candidate.distance {
350 is_good = false;
351 break;
352 }
353 }
354
355 if is_good {
356 selected.push(candidate.node_idx);
357 }
358 }
359
360 if selected.len() < m {
362 for candidate in &working_candidates {
363 if selected.len() >= m {
364 break;
365 }
366 if !selected.contains(&candidate.node_idx) {
367 selected.push(candidate.node_idx);
368 }
369 }
370 }
371
372 selected
373 }
374
375 fn add_connection(&self, from_idx: usize, to_idx: usize, layer: usize, nodes: &mut [HnswNode]) {
377 let m_max = if layer == 0 {
378 self.config.m_max0
379 } else {
380 self.config.m
381 };
382
383 if layer < nodes[from_idx].connections.len()
385 && !nodes[from_idx].connections[layer].contains(&to_idx)
386 {
387 nodes[from_idx].connections[layer].push(to_idx);
388
389 if nodes[from_idx].connections[layer].len() > m_max {
391 let conn_indices: Vec<usize> = nodes[from_idx].connections[layer].clone();
392 let mut sorted_candidates: Vec<Candidate> = conn_indices
393 .iter()
394 .map(|&idx| Candidate {
395 node_idx: idx,
396 distance: self.distance(&nodes[from_idx].vector, idx, nodes),
397 })
398 .collect();
399 sorted_candidates.sort_by(|a, b| {
400 a.distance
401 .partial_cmp(&b.distance)
402 .unwrap_or(Ordering::Equal)
403 });
404 nodes[from_idx].connections[layer] =
405 self.select_neighbors_simple(&sorted_candidates, m_max);
406 }
407 }
408
409 if layer < nodes[to_idx].connections.len()
411 && !nodes[to_idx].connections[layer].contains(&from_idx)
412 {
413 nodes[to_idx].connections[layer].push(from_idx);
414
415 if nodes[to_idx].connections[layer].len() > m_max {
417 let conn_indices: Vec<usize> = nodes[to_idx].connections[layer].clone();
418 let mut sorted_candidates: Vec<Candidate> = conn_indices
419 .iter()
420 .map(|&idx| Candidate {
421 node_idx: idx,
422 distance: self.distance(&nodes[to_idx].vector, idx, nodes),
423 })
424 .collect();
425 sorted_candidates.sort_by(|a, b| {
426 a.distance
427 .partial_cmp(&b.distance)
428 .unwrap_or(Ordering::Equal)
429 });
430 nodes[to_idx].connections[layer] =
431 self.select_neighbors_simple(&sorted_candidates, m_max);
432 }
433 }
434 }
435
436 pub fn insert(&self, id: VectorId, vector: Vec<f32>) {
438 let vector_dim = vector.len();
439
440 {
442 let mut dim = self.dimension.write();
443 if let Some(d) = *dim {
444 if d != vector_dim {
445 tracing::error!("Dimension mismatch: expected {}, got {}", d, vector_dim);
446 return;
447 }
448 } else {
449 *dim = Some(vector_dim);
450 }
451 }
452
453 let new_level = self.random_level();
454
455 let new_node = HnswNode {
457 id: id.clone(),
458 vector: vector.clone(),
459 connections: (0..=new_level).map(|_| Vec::new()).collect(),
460 max_layer: new_level,
461 };
462
463 let mut nodes = self.nodes.write();
464 let new_idx = nodes.len();
465 nodes.push(new_node);
466
467 self.id_to_idx.write().insert(id, new_idx);
469
470 let entry = *self.entry_point.read();
472 let entry_idx = match entry {
473 None => {
474 *self.entry_point.write() = Some(new_idx);
475 *self.max_level.write() = new_level;
476 return;
477 }
478 Some(idx) => idx,
479 };
480 let current_max_level = *self.max_level.read();
481
482 let mut current_entry = vec![entry_idx];
484
485 for layer in (new_level + 1..=current_max_level).rev() {
487 let nearest = self.search_layer(&vector, current_entry.clone(), 1, layer, &nodes);
488 if !nearest.is_empty() {
489 current_entry = vec![nearest[0].node_idx];
490 }
491 }
492
493 for layer in (0..=new_level.min(current_max_level)).rev() {
495 let candidates = self.search_layer(
496 &vector,
497 current_entry.clone(),
498 self.config.ef_construction,
499 layer,
500 &nodes,
501 );
502
503 let m = if layer == 0 {
504 self.config.m_max0
505 } else {
506 self.config.m
507 };
508
509 let neighbors = self.select_neighbors_heuristic(&vector, &candidates, m, &nodes, false);
510
511 for &neighbor_idx in &neighbors {
513 self.add_connection(new_idx, neighbor_idx, layer, &mut nodes);
514 }
515
516 if !candidates.is_empty() {
518 current_entry = candidates.iter().take(1).map(|c| c.node_idx).collect();
519 }
520 }
521
522 if new_level > current_max_level {
524 *self.entry_point.write() = Some(new_idx);
525 *self.max_level.write() = new_level;
526 }
527 }
528
529 pub fn search(&self, query: &[f32], k: usize) -> Vec<(VectorId, f32)> {
531 self.search_with_ef(query, k, self.config.ef_search)
532 }
533
534 pub fn search_with_ef(&self, query: &[f32], k: usize, ef: usize) -> Vec<(VectorId, f32)> {
536 let nodes = self.nodes.read();
537
538 if nodes.is_empty() {
539 return Vec::new();
540 }
541
542 let entry = *self.entry_point.read();
543 let entry_idx = match entry {
544 None => return Vec::new(),
545 Some(idx) => idx,
546 };
547 let max_level = *self.max_level.read();
548
549 let mut current_entry = vec![entry_idx];
551
552 for layer in (1..=max_level).rev() {
554 let nearest = self.search_layer(query, current_entry.clone(), 1, layer, &nodes);
555 if !nearest.is_empty() {
556 current_entry = vec![nearest[0].node_idx];
557 }
558 }
559
560 let candidates = self.search_layer(query, current_entry, ef.max(k), 0, &nodes);
562
563 candidates
565 .into_iter()
566 .take(k)
567 .map(|c| (nodes[c.node_idx].id.clone(), c.distance))
568 .collect()
569 }
570
571 pub fn delete(&self, id: &VectorId) -> bool {
573 let idx = {
574 let id_map = self.id_to_idx.read();
575 match id_map.get(id) {
576 Some(&idx) => idx,
577 None => return false,
578 }
579 };
580
581 let mut nodes = self.nodes.write();
582 let mut id_map = self.id_to_idx.write();
583
584 for layer in 0..nodes[idx].connections.len() {
586 let neighbors: Vec<usize> = nodes[idx].connections[layer].clone();
587 for neighbor_idx in neighbors {
588 if neighbor_idx < nodes.len() && layer < nodes[neighbor_idx].connections.len() {
589 nodes[neighbor_idx].connections[layer].retain(|&n| n != idx);
590 }
591 }
592 }
593
594 nodes[idx].connections.clear();
596 nodes[idx].vector.clear();
597 id_map.remove(id);
598
599 let entry = *self.entry_point.read();
601 if entry == Some(idx) {
602 let new_entry = nodes
604 .iter()
605 .enumerate()
606 .filter(|(_, n)| !n.vector.is_empty())
607 .max_by_key(|(_, n)| n.max_layer)
608 .map(|(i, _)| i);
609 *self.entry_point.write() = new_entry;
610 }
611
612 true
613 }
614
615 pub fn len(&self) -> usize {
617 self.id_to_idx.read().len()
618 }
619
620 pub fn is_empty(&self) -> bool {
622 self.len() == 0
623 }
624
625 pub fn stats(&self) -> HnswStats {
627 let nodes = self.nodes.read();
628 let max_level = *self.max_level.read();
629
630 let mut level_counts = vec![0usize; max_level + 1];
631 let mut total_connections = 0usize;
632
633 for node in nodes.iter() {
634 if !node.vector.is_empty() {
635 for (layer, connections) in node.connections.iter().enumerate() {
636 if layer <= max_level {
637 level_counts[layer] += 1;
638 total_connections += connections.len();
639 }
640 }
641 }
642 }
643
644 HnswStats {
645 num_vectors: self.len(),
646 max_level,
647 level_counts,
648 total_connections,
649 avg_connections: if !self.is_empty() {
650 total_connections as f64 / self.len() as f64
651 } else {
652 0.0
653 },
654 }
655 }
656
657 pub fn config(&self) -> &HnswConfig {
659 &self.config
660 }
661
662 pub fn dimension(&self) -> Option<usize> {
664 *self.dimension.read()
665 }
666
667 pub fn entry_point(&self) -> Option<usize> {
669 *self.entry_point.read()
670 }
671
672 pub fn max_level(&self) -> usize {
674 *self.max_level.read()
675 }
676
677 pub(crate) fn nodes_read(&self) -> Vec<NodeSnapshot> {
680 self.nodes
681 .read()
682 .iter()
683 .map(|node| NodeSnapshot {
684 id: node.id.clone(),
685 vector: node.vector.clone(),
686 connections: node.connections.clone(),
687 max_layer: node.max_layer,
688 })
689 .collect()
690 }
691
692 pub fn from_snapshot(snapshot: crate::persistence::HnswFullSnapshot) -> Result<Self, String> {
694 use std::collections::HashMap;
695
696 let mut nodes = Vec::with_capacity(snapshot.nodes.len());
697 let mut id_to_idx = HashMap::with_capacity(snapshot.nodes.len());
698
699 for (idx, snode) in snapshot.nodes.into_iter().enumerate() {
700 id_to_idx.insert(snode.id.clone(), idx);
701 nodes.push(HnswNode {
702 id: snode.id,
703 vector: snode.vector,
704 connections: snode.connections,
705 max_layer: snode.max_layer,
706 });
707 }
708
709 let dimension = if nodes.is_empty() {
710 None
711 } else {
712 Some(snapshot.dimension)
713 };
714
715 Ok(Self {
716 config: snapshot.config,
717 nodes: RwLock::new(nodes),
718 entry_point: RwLock::new(snapshot.entry_point),
719 max_level: RwLock::new(snapshot.max_level),
720 id_to_idx: RwLock::new(id_to_idx),
721 dimension: RwLock::new(dimension),
722 })
723 }
724}
725
/// Owned copy of one graph node, as produced by `HnswIndex::nodes_read`.
#[derive(Debug, Clone)]
pub(crate) struct NodeSnapshot {
    /// External identifier of the node.
    pub id: String,
    /// The stored vector; empty for a node that has been deleted.
    pub vector: Vec<f32>,
    /// Per-layer neighbor indices, positions in the snapshot's node order.
    pub connections: Vec<Vec<usize>>,
    /// Highest layer the node was created on.
    pub max_layer: usize,
}
734
735impl Default for HnswIndex {
736 fn default() -> Self {
737 Self::new()
738 }
739}
740
/// Summary figures returned by `HnswIndex::stats`.
#[derive(Debug, Clone)]
pub struct HnswStats {
    /// Number of ids registered in the index.
    pub num_vectors: usize,
    /// Highest layer recorded for the graph.
    pub max_level: usize,
    /// `level_counts[l]` is the number of live nodes present on layer `l`.
    pub level_counts: Vec<usize>,
    /// Sum of all link-list lengths over live nodes and counted layers.
    pub total_connections: usize,
    /// `total_connections / num_vectors`, or `0.0` for an empty index.
    pub avg_connections: f64,
}
750
751#[cfg(test)]
752mod tests {
753 use super::*;
754
755 fn random_vector(dim: usize) -> Vec<f32> {
756 let mut rng = rand::thread_rng();
757 (0..dim).map(|_| rng.gen::<f32>()).collect()
758 }
759
/// Scales `v` in place to unit Euclidean length.
///
/// A zero-norm (or empty) input is left untouched to avoid dividing by zero.
/// Takes a slice rather than `&mut Vec<f32>` so arrays and sub-slices work too;
/// existing `normalize(&mut vec)` calls still compile through deref coercion.
fn normalize(v: &mut [f32]) {
    let norm = v.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm > 0.0 {
        v.iter_mut().for_each(|x| *x /= norm);
    }
}
768
/// Inserting 100 vectors yields a non-empty index whose top-10 search result
/// is complete and sorted nearest-first.
#[test]
fn test_hnsw_basic_operations() {
    const DIM: usize = 128;
    let index = HnswIndex::new();

    (0..100).for_each(|n| {
        let mut v = random_vector(DIM);
        normalize(&mut v);
        index.insert(format!("vec_{}", n), v);
    });

    assert_eq!(index.len(), 100);
    assert!(!index.is_empty());

    let mut probe = random_vector(DIM);
    normalize(&mut probe);
    let hits = index.search(&probe, 10);

    assert_eq!(hits.len(), 10);

    // Distances must be non-decreasing along the result list.
    for pair in hits.windows(2) {
        assert!(pair[0].1 <= pair[1].1);
    }
}
795
/// Deleting a known id succeeds and shrinks the index; an unknown id is refused.
#[test]
fn test_hnsw_delete() {
    let index = HnswIndex::new();

    (0..10).for_each(|n| {
        let mut v = random_vector(64);
        normalize(&mut v);
        index.insert(format!("vec_{}", n), v);
    });
    assert_eq!(index.len(), 10);

    let present = "vec_5".to_string();
    assert!(index.delete(&present));
    assert_eq!(index.len(), 9);

    let absent = "vec_999".to_string();
    assert!(!index.delete(&absent));
}
815
/// Compares top-10 index results with a brute-force ranking over 1000 random
/// unit vectors and requires an average recall of at least 80%.
#[test]
fn test_hnsw_recall() {
    let dim = 128;
    let n_vectors = 1000;
    let n_queries = 10;
    let k = 10;
    let index = HnswIndex::with_config(HnswConfig::new(16, 200, 100));

    // Keep a copy of everything inserted so the exact answer can be computed.
    let corpus: Vec<(VectorId, Vec<f32>)> = (0..n_vectors)
        .map(|i| {
            let mut v = random_vector(dim);
            normalize(&mut v);
            let id: VectorId = format!("vec_{}", i);
            index.insert(id.clone(), v.clone());
            (id, v)
        })
        .collect();

    let mut recall_sum = 0.0;

    for _ in 0..n_queries {
        let mut probe = random_vector(dim);
        normalize(&mut probe);

        let approx: HashSet<String> = index
            .search(&probe, k)
            .into_iter()
            .map(|(id, _)| id)
            .collect();

        // Brute-force ranking under the same distance definition.
        let mut ranked: Vec<(String, f32)> = corpus
            .iter()
            .map(|(id, v)| {
                let sim = calculate_distance(&probe, v, DistanceMetric::Cosine);
                (
                    id.clone(),
                    similarity_to_distance(sim, DistanceMetric::Cosine),
                )
            })
            .collect();
        ranked.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
        let truth: HashSet<String> = ranked.into_iter().take(k).map(|(id, _)| id).collect();

        let overlap = approx.intersection(&truth).count();
        recall_sum += overlap as f64 / k as f64;
    }

    let avg_recall = recall_sum / n_queries as f64;
    println!("Average recall@{}: {:.2}%", k, avg_recall * 100.0);

    assert!(
        avg_recall >= 0.80,
        "Recall too low: {:.2}%",
        avg_recall * 100.0
    );
}
878
/// After 50 inserts the stats report 50 vectors and a positive average link count.
#[test]
fn test_hnsw_stats() {
    let index = HnswIndex::new();

    (0..50).for_each(|n| {
        let mut v = random_vector(64);
        normalize(&mut v);
        index.insert(format!("vec_{}", n), v);
    });

    let stats = index.stats();
    assert_eq!(stats.num_vectors, 50);
    // The top layer is random; only check that the field is readable.
    let _ = stats.max_level;
    assert!(stats.avg_connections > 0.0);

    println!("HNSW Stats: {:?}", stats);
}
897
/// A search returns `k` results for both a small and a large explicit `ef`.
#[test]
fn test_hnsw_custom_ef() {
    let index = HnswIndex::new();

    (0..100).for_each(|n| {
        let mut v = random_vector(64);
        normalize(&mut v);
        index.insert(format!("vec_{}", n), v);
    });

    let mut probe = random_vector(64);
    normalize(&mut probe);

    let narrow = index.search_with_ef(&probe, 10, 10);
    let wide = index.search_with_ef(&probe, 10, 200);

    assert_eq!(narrow.len(), 10);
    assert_eq!(wide.len(), 10);
}
921
/// Searching an index with no vectors yields no results.
#[test]
fn test_hnsw_empty_search() {
    let index = HnswIndex::new();
    let probe = random_vector(64);

    let hits = index.search(&probe, 10);

    assert!(hits.is_empty());
}
929
/// With one stored vector, searching for that vector returns exactly it, at a
/// distance close to zero.
#[test]
fn test_hnsw_single_vector() {
    let index = HnswIndex::new();

    let mut only = random_vector(64);
    normalize(&mut only);
    index.insert("single".to_string(), only.clone());

    let hits = index.search(&only, 5);
    assert_eq!(hits.len(), 1);

    let (found_id, found_distance) = &hits[0];
    assert_eq!(*found_id, "single".to_string());
    assert!(
        found_distance.abs() < 0.1,
        "Distance to self was {}",
        found_distance
    );
}
948}