1use crate::{Vector, VectorIndex};
55use anyhow::Result;
56use oxirs_core::simd::SimdOps;
57use parking_lot::RwLock as ParkingLotRwLock;
58use scirs2_core::random::Random;
59use std::cmp::Ordering;
60use std::collections::{BinaryHeap, HashMap, HashSet};
61use std::sync::{Arc, RwLock};
62
63#[derive(Debug, Clone)]
65pub struct NsgConfig {
66 pub out_degree: usize,
68 pub candidate_pool_size: usize,
70 pub search_length: usize,
72 pub distance_metric: DistanceMetric,
74 pub random_seed: Option<u64>,
76 pub parallel_construction: bool,
78 pub num_threads: usize,
80 pub initial_knn_degree: usize,
82 pub pruning_threshold: f32,
84}
85
86impl Default for NsgConfig {
87 fn default() -> Self {
88 Self {
89 out_degree: 32,
90 candidate_pool_size: 100,
91 search_length: 50,
92 distance_metric: DistanceMetric::Euclidean,
93 random_seed: None,
94 parallel_construction: true,
95 num_threads: num_cpus::get(),
96 initial_knn_degree: 64,
97 pruning_threshold: 1.0,
98 }
99 }
100}
101
/// Distance functions supported by the index.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Euclidean distance.
    Euclidean,
    /// Manhattan distance.
    Manhattan,
    /// Cosine distance.
    Cosine,
    /// Angle between the vectors, scaled by `1 / PI`.
    Angular,
    /// Negated dot product, so that a larger inner product is a smaller distance.
    InnerProduct,
}
111
112impl DistanceMetric {
113 pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
115 match self {
116 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
117 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
118 DistanceMetric::Cosine => f32::cosine_distance(a, b),
119 DistanceMetric::Angular => {
120 let cos_sim = 1.0 - f32::cosine_distance(a, b);
121 cos_sim.clamp(-1.0, 1.0).acos() / std::f32::consts::PI
122 }
123 DistanceMetric::InnerProduct => {
124 -a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>()
126 }
127 }
128 }
129}
130
/// A graph node paired with its distance to the current query.
///
/// The ordering is deliberately *reversed* on distance so that the standard
/// max-heap [`BinaryHeap`] pops the closest candidate first; ties are broken
/// by ascending `id`.
///
/// Distances are compared with [`f32::total_cmp`], which keeps `Eq`/`Ord`
/// consistent even for NaN distances (a NaN distance ranks as the furthest
/// possible candidate). Note that under this total order `-0.0` and `+0.0`
/// are distinct values.
#[derive(Debug, Clone)]
struct Candidate {
    /// Index of the node in the index's data vector.
    id: usize,
    /// Distance from the node to the query.
    distance: f32,
}

impl PartialEq for Candidate {
    /// Two candidates are equal exactly when `cmp` reports `Equal`, so
    /// equality can never disagree with the ordering.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    /// Smaller distance compares as *greater* (min-heap behavior inside a
    /// max-heap); equal distances fall back to ascending `id`.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .distance
            .total_cmp(&self.distance)
            .then_with(|| self.id.cmp(&other.id))
    }
}
162
163pub struct NsgIndex {
165 config: NsgConfig,
167 data: Vec<(String, Vector)>,
169 graph: Vec<Vec<usize>>,
171 entry_point: Option<usize>,
173 is_built: bool,
175 uri_to_idx: HashMap<String, usize>,
177 stats: Arc<RwLock<NsgStats>>,
179}
180
/// Statistics describing an index and its graph.
///
/// The structural fields (`num_vectors`, `num_edges`, `avg_out_degree`,
/// `max_out_degree`) are refreshed when the index is built. The
/// search-related fields are not updated by any code visible in this file.
#[derive(Debug, Clone, Default)]
pub struct NsgStats {
    /// Number of vectors stored in the index.
    pub num_vectors: usize,
    /// Total number of directed edges in the graph.
    pub num_edges: usize,
    /// Mean number of out-neighbors per node.
    pub avg_out_degree: f64,
    /// Largest number of out-neighbors of any node.
    pub max_out_degree: usize,
    /// Number of searches performed.
    pub num_searches: usize,
    /// Mean search path length.
    pub avg_search_path_length: f64,
    /// Total number of distance computations.
    pub total_distance_computations: usize,
}
199
200impl NsgIndex {
201 pub fn new(config: NsgConfig) -> Result<Self> {
203 Ok(Self {
204 config,
205 data: Vec::new(),
206 graph: Vec::new(),
207 entry_point: None,
208 is_built: false,
209 uri_to_idx: HashMap::new(),
210 stats: Arc::new(RwLock::new(NsgStats::default())),
211 })
212 }
213
214 pub fn add(&mut self, uri: String, vector: Vector) -> Result<()> {
216 if self.is_built {
217 return Err(anyhow::anyhow!(
218 "Cannot add vectors after index is built. Call rebuild() or create a new index."
219 ));
220 }
221
222 let idx = self.data.len();
223 self.uri_to_idx.insert(uri.clone(), idx);
224 self.data.push((uri, vector));
225
226 Ok(())
227 }
228
229 pub fn build(&mut self) -> Result<()> {
235 if self.data.is_empty() {
236 return Err(anyhow::anyhow!("Cannot build index with no vectors"));
237 }
238
239 tracing::info!("Building NSG index with {} vectors", self.data.len());
240
241 tracing::debug!("Stage 1: Building initial kNN graph");
243 self.build_knn_graph()?;
244
245 tracing::debug!("Stage 2: Refining to navigable monotonic graph");
247 self.refine_to_nsg()?;
248
249 self.select_entry_point()?;
251
252 self.is_built = true;
253
254 self.update_stats();
256
257 tracing::info!(
258 "NSG index built successfully. {} vectors, {} edges, avg out-degree: {:.2}",
259 self.data.len(),
260 self.count_edges(),
261 self.avg_out_degree()
262 );
263
264 Ok(())
265 }
266
267 fn build_knn_graph(&mut self) -> Result<()> {
269 let n = self.data.len();
270 self.graph = vec![Vec::new(); n];
271
272 if self.config.parallel_construction && n > 1000 {
273 self.build_knn_graph_parallel()?;
274 } else {
275 self.build_knn_graph_sequential()?;
276 }
277
278 Ok(())
279 }
280
281 fn build_knn_graph_sequential(&mut self) -> Result<()> {
283 let n = self.data.len();
284 let k = self.config.initial_knn_degree.min(n - 1);
285
286 for i in 0..n {
287 let mut neighbors = Vec::new();
288
289 for j in 0..n {
291 if i == j {
292 continue;
293 }
294
295 let dist = self.calculate_distance(i, j);
296 neighbors.push(Candidate {
297 id: j,
298 distance: dist,
299 });
300 }
301
302 neighbors.sort_by(|a, b| {
304 a.distance
305 .partial_cmp(&b.distance)
306 .unwrap_or(Ordering::Equal)
307 });
308 neighbors.truncate(k);
309
310 self.graph[i] = neighbors.iter().map(|c| c.id).collect();
312 }
313
314 Ok(())
315 }
316
317 fn build_knn_graph_parallel(&mut self) -> Result<()> {
319 let n = self.data.len();
320 let k = self.config.initial_knn_degree.min(n - 1);
321
322 let graph = Arc::new(ParkingLotRwLock::new(vec![Vec::new(); n]));
324 let data = Arc::new(self.data.clone());
325 let config = self.config.clone();
326
327 let chunk_size = (n + self.config.num_threads - 1) / self.config.num_threads;
329 let mut handles = Vec::new();
330
331 for chunk_start in (0..n).step_by(chunk_size) {
332 let chunk_end = (chunk_start + chunk_size).min(n);
333 let graph_clone = Arc::clone(&graph);
334 let data_clone = Arc::clone(&data);
335 let config_clone = config.clone();
336
337 let handle = std::thread::spawn(move || {
338 for i in chunk_start..chunk_end {
339 let mut neighbors = Vec::new();
340
341 for j in 0..n {
342 if i == j {
343 continue;
344 }
345
346 let vec_i = &data_clone[i].1.as_f32();
347 let vec_j = &data_clone[j].1.as_f32();
348 let dist = config_clone.distance_metric.distance(vec_i, vec_j);
349
350 neighbors.push(Candidate {
351 id: j,
352 distance: dist,
353 });
354 }
355
356 neighbors.sort_by(|a, b| {
357 a.distance
358 .partial_cmp(&b.distance)
359 .unwrap_or(Ordering::Equal)
360 });
361 neighbors.truncate(k);
362
363 let mut graph_lock = graph_clone.write();
364 graph_lock[i] = neighbors.iter().map(|c| c.id).collect();
365 }
366 });
367
368 handles.push(handle);
369 }
370
371 for handle in handles {
373 handle
374 .join()
375 .map_err(|_| anyhow::anyhow!("Thread panicked"))?;
376 }
377
378 self.graph = Arc::try_unwrap(graph)
380 .map_err(|_| anyhow::anyhow!("Failed to unwrap graph"))?
381 .into_inner();
382
383 Ok(())
384 }
385
386 fn refine_to_nsg(&mut self) -> Result<()> {
388 let n = self.data.len();
389 let mut new_graph = vec![Vec::new(); n];
390
391 let temp_entry = self.select_temp_entry_point();
393
394 #[allow(clippy::needless_range_loop)]
395 for i in 0..n {
396 let candidates = self.search_for_neighbors(i, temp_entry)?;
398
399 let neighbors = self.prune_neighbors(i, candidates)?;
401
402 new_graph[i] = neighbors;
403 }
404
405 self.ensure_connectivity(&mut new_graph)?;
407
408 self.graph = new_graph;
409
410 Ok(())
411 }
412
413 fn search_for_neighbors(&self, query_id: usize, entry_id: usize) -> Result<Vec<Candidate>> {
415 let mut visited = HashSet::new();
416 let mut candidates = BinaryHeap::new();
417 let mut result = Vec::new();
418
419 let entry_dist = self.calculate_distance(query_id, entry_id);
421 candidates.push(Candidate {
422 id: entry_id,
423 distance: entry_dist,
424 });
425 visited.insert(entry_id);
426
427 while let Some(current) = candidates.pop() {
428 if result.len() >= self.config.candidate_pool_size {
429 break;
430 }
431
432 result.push(current.clone());
433
434 for &neighbor_id in &self.graph[current.id] {
436 if visited.contains(&neighbor_id) {
437 continue;
438 }
439
440 visited.insert(neighbor_id);
441
442 let dist = self.calculate_distance(query_id, neighbor_id);
443 candidates.push(Candidate {
444 id: neighbor_id,
445 distance: dist,
446 });
447
448 if visited.len() >= self.config.search_length {
449 break;
450 }
451 }
452 }
453
454 result.sort_by(|a, b| {
456 a.distance
457 .partial_cmp(&b.distance)
458 .unwrap_or(Ordering::Equal)
459 });
460
461 Ok(result)
462 }
463
464 fn prune_neighbors(
466 &self,
467 _query_id: usize,
468 mut candidates: Vec<Candidate>,
469 ) -> Result<Vec<usize>> {
470 if candidates.is_empty() {
471 return Ok(Vec::new());
472 }
473
474 let mut result = Vec::new();
475 let mut pruned = HashSet::new();
476
477 while !candidates.is_empty() && result.len() < self.config.out_degree {
478 let best_idx = candidates
480 .iter()
481 .position_min_by(|a, b| {
482 a.distance
483 .partial_cmp(&b.distance)
484 .unwrap_or(Ordering::Equal)
485 })
486 .expect("candidates should not be empty during pruning");
487
488 let best = candidates.swap_remove(best_idx);
489
490 if pruned.contains(&best.id) {
491 continue;
492 }
493
494 result.push(best.id);
495 pruned.insert(best.id);
496
497 candidates.retain(|c| {
499 let dist_to_best = self.calculate_distance(c.id, best.id);
500 dist_to_best > best.distance * self.config.pruning_threshold
501 });
502 }
503
504 Ok(result)
505 }
506
507 fn ensure_connectivity(&self, graph: &mut [Vec<usize>]) -> Result<()> {
509 let n = graph.len();
510
511 let mut in_edges: Vec<HashSet<usize>> = vec![HashSet::new(); n];
513 for (i, neighbors) in graph.iter().enumerate() {
514 for &j in neighbors {
515 in_edges[j].insert(i);
516 }
517 }
518
519 for (i, edges) in in_edges.iter().enumerate() {
521 if edges.is_empty() && i != 0 {
522 let mut min_dist = f32::INFINITY;
524 let mut closest = 0;
525
526 for (j, neighbors) in graph.iter().enumerate() {
527 if i == j || neighbors.len() >= self.config.out_degree {
528 continue;
529 }
530
531 let dist = self.calculate_distance(i, j);
532 if dist < min_dist {
533 min_dist = dist;
534 closest = j;
535 }
536 }
537
538 if !graph[closest].contains(&i) {
540 graph[closest].push(i);
541 }
542 }
543 }
544
545 Ok(())
546 }
547
548 fn select_entry_point(&mut self) -> Result<()> {
550 if self.data.is_empty() {
551 return Ok(());
552 }
553
554 let mut max_degree = 0;
555 let mut entry = 0;
556
557 for i in 0..self.graph.len() {
558 if self.graph[i].len() > max_degree {
559 max_degree = self.graph[i].len();
560 entry = i;
561 }
562 }
563
564 self.entry_point = Some(entry);
565
566 Ok(())
567 }
568
569 fn select_temp_entry_point(&self) -> usize {
571 if let Some(seed) = self.config.random_seed {
572 let mut rng = Random::seed(seed);
573 rng.random_range(0..self.data.len())
574 } else {
575 self.find_centroid()
577 }
578 }
579
580 fn find_centroid(&self) -> usize {
582 if self.data.is_empty() {
583 return 0;
584 }
585
586 let dim = self.data[0].1.dimensions;
587 let mut centroid = vec![0.0f32; dim];
588
589 for (_, vec) in &self.data {
591 let vals = vec.as_f32();
592 for i in 0..dim {
593 centroid[i] += vals[i];
594 }
595 }
596
597 let n = self.data.len() as f32;
598 for val in &mut centroid {
599 *val /= n;
600 }
601
602 let mut min_dist = f32::INFINITY;
604 let mut closest = 0;
605
606 for i in 0..self.data.len() {
607 let dist = self
608 .config
609 .distance_metric
610 .distance(¢roid, &self.data[i].1.as_f32());
611 if dist < min_dist {
612 min_dist = dist;
613 closest = i;
614 }
615 }
616
617 closest
618 }
619
620 fn calculate_distance(&self, i: usize, j: usize) -> f32 {
622 let vec_i = self.data[i].1.as_f32();
623 let vec_j = self.data[j].1.as_f32();
624 self.config.distance_metric.distance(&vec_i, &vec_j)
625 }
626
627 fn greedy_search(&self, query: &[f32], k: usize, ef: usize) -> Result<Vec<Candidate>> {
629 if !self.is_built {
630 return Err(anyhow::anyhow!("Index not built. Call build() first."));
631 }
632
633 let entry = self
634 .entry_point
635 .ok_or_else(|| anyhow::anyhow!("No entry point set"))?;
636
637 let mut visited = HashSet::new();
638 let mut candidates = BinaryHeap::new();
639 let mut result_set = BinaryHeap::new();
640
641 let entry_dist = self
643 .config
644 .distance_metric
645 .distance(query, &self.data[entry].1.as_f32());
646 candidates.push(Candidate {
647 id: entry,
648 distance: entry_dist,
649 });
650 result_set.push(Candidate {
651 id: entry,
652 distance: entry_dist,
653 });
654 visited.insert(entry);
655
656 while let Some(current) = candidates.pop() {
657 if result_set.len() >= ef
659 && current.distance
660 > result_set
661 .peek()
662 .expect("result_set should not be empty during search")
663 .distance
664 {
665 break;
666 }
667
668 for &neighbor_id in &self.graph[current.id] {
670 if visited.contains(&neighbor_id) {
671 continue;
672 }
673
674 visited.insert(neighbor_id);
675
676 let dist = self
677 .config
678 .distance_metric
679 .distance(query, &self.data[neighbor_id].1.as_f32());
680 let candidate = Candidate {
681 id: neighbor_id,
682 distance: dist,
683 };
684
685 if result_set.len() < ef
686 || dist
687 < result_set
688 .peek()
689 .expect("result_set should not be empty during search")
690 .distance
691 {
692 candidates.push(candidate.clone());
693 result_set.push(candidate);
694
695 if result_set.len() > ef {
696 result_set.pop();
697 }
698 }
699 }
700 }
701
702 let mut results: Vec<_> = result_set.into_sorted_vec();
704 results.truncate(k);
705
706 Ok(results)
707 }
708
709 fn update_stats(&self) {
711 let mut stats = self
712 .stats
713 .write()
714 .expect("stats lock should not be poisoned");
715 stats.num_vectors = self.data.len();
716 stats.num_edges = self.count_edges();
717 stats.avg_out_degree = self.avg_out_degree();
718 stats.max_out_degree = self.max_out_degree();
719 }
720
721 fn count_edges(&self) -> usize {
723 self.graph.iter().map(|neighbors| neighbors.len()).sum()
724 }
725
726 fn avg_out_degree(&self) -> f64 {
728 if self.graph.is_empty() {
729 return 0.0;
730 }
731 self.count_edges() as f64 / self.graph.len() as f64
732 }
733
734 fn max_out_degree(&self) -> usize {
736 self.graph
737 .iter()
738 .map(|neighbors| neighbors.len())
739 .max()
740 .unwrap_or(0)
741 }
742
743 pub fn stats(&self) -> NsgStats {
745 self.stats
746 .read()
747 .expect("stats lock should not be poisoned")
748 .clone()
749 }
750
751 pub fn len(&self) -> usize {
753 self.data.len()
754 }
755
756 pub fn is_empty(&self) -> bool {
758 self.data.is_empty()
759 }
760
761 pub fn is_built(&self) -> bool {
763 self.is_built
764 }
765}
766
767impl VectorIndex for NsgIndex {
768 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
769 self.add(uri, vector)
770 }
771
772 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
773 let query_vals = query.as_f32();
774 let ef = k.max(self.config.search_length);
775 let candidates = self.greedy_search(&query_vals, k, ef)?;
776
777 let mut results: Vec<_> = candidates
781 .into_iter()
782 .map(|c| {
783 let uri = self.data[c.id].0.clone();
784 let similarity = 1.0 / (1.0 + c.distance);
785 (uri, similarity)
786 })
787 .collect();
788
789 results.reverse();
791
792 Ok(results)
793 }
794
795 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
796 let k = self.data.len().min(1000);
798 let all_results = self.search_knn(query, k)?;
799
800 let filtered: Vec<_> = all_results
801 .into_iter()
802 .filter(|(_, similarity)| *similarity >= threshold)
803 .collect();
804
805 Ok(filtered)
806 }
807
808 fn get_vector(&self, uri: &str) -> Option<&Vector> {
809 self.uri_to_idx
810 .get(uri)
811 .and_then(|&idx| self.data.get(idx))
812 .map(|(_, vec)| vec)
813 }
814
815 fn remove_vector(&mut self, id: String) -> Result<()> {
816 if self.is_built {
817 return Err(anyhow::anyhow!(
818 "Cannot remove vectors from built index. Rebuild index instead."
819 ));
820 }
821
822 if let Some(&idx) = self.uri_to_idx.get(&id) {
823 self.data.remove(idx);
824 self.uri_to_idx.remove(&id);
825
826 self.uri_to_idx.clear();
828 for (i, (uri, _)) in self.data.iter().enumerate() {
829 self.uri_to_idx.insert(uri.clone(), i);
830 }
831
832 Ok(())
833 } else {
834 Err(anyhow::anyhow!("Vector with id '{}' not found", id))
835 }
836 }
837}
838
/// Extension methods for iterators used by this module.
trait IteratorExt: Iterator {
    /// Returns the position of the minimum element according to `compare`,
    /// or `None` for an empty iterator. When several elements are equally
    /// minimal, the position of the first one is returned.
    fn position_min_by<F>(self, compare: F) -> Option<usize>
    where
        F: FnMut(&Self::Item, &Self::Item) -> Ordering;
}

impl<I: Iterator> IteratorExt for I {
    fn position_min_by<F>(self, mut compare: F) -> Option<usize>
    where
        F: FnMut(&Self::Item, &Self::Item) -> Ordering,
    {
        let mut indexed = self.enumerate();
        let (mut best_pos, mut best_item) = indexed.next()?;

        for (pos, item) in indexed {
            // Strictly-less keeps the earliest of equal minima.
            if compare(&item, &best_item) == Ordering::Less {
                best_pos = pos;
                best_item = item;
            }
        }

        Some(best_pos)
    }
}
865
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an unbuilt index holding `count` points, where the m-th
    /// coordinate (1-based) of point `i` is `i * m`.
    fn seeded_index(config: NsgConfig, count: usize, dims: usize) -> NsgIndex {
        let mut index = NsgIndex::new(config).unwrap();
        for i in 0..count {
            let coords: Vec<f32> = (1..=dims).map(|m| (i * m) as f32).collect();
            index.add(format!("vec_{}", i), Vector::new(coords)).unwrap();
        }
        index
    }

    #[test]
    fn test_nsg_creation() {
        let index = NsgIndex::new(NsgConfig::default()).unwrap();
        assert_eq!(index.len(), 0);
        assert!(!index.is_built());
    }

    #[test]
    fn test_nsg_add_vectors() {
        let index = seeded_index(NsgConfig::default(), 10, 3);
        assert_eq!(index.len(), 10);
    }

    #[test]
    fn test_nsg_build_and_search() {
        let config = NsgConfig {
            out_degree: 32,
            candidate_pool_size: 100,
            search_length: 50,
            initial_knn_degree: 64,
            ..Default::default()
        };
        let mut index = seeded_index(config, 100, 3);

        index.build().unwrap();
        assert!(index.is_built());

        let query = Vector::new(vec![10.1, 20.1, 30.1]);
        let results = index.search_knn(&query, 10).unwrap();

        assert!(!results.is_empty());
        assert_eq!(results.len(), 10);

        // Similarities must be non-increasing.
        for (i, pair) in results.windows(2).enumerate() {
            assert!(
                pair[0].1 >= pair[1].1,
                "Results not sorted: {}@{} < {}@{}",
                pair[0].1,
                i,
                pair[1].1,
                i + 1
            );
        }

        let needles = ["10", "11", "9", "12", "8"];
        let nearby_found = results
            .iter()
            .take(10)
            .any(|(uri, _)| needles.iter().any(|needle| uri.contains(*needle)));
        assert!(
            nearby_found,
            "Expected nearby vectors (8-12) in top 10 results"
        );
    }

    #[test]
    fn test_nsg_distance_metrics() {
        let metrics = [
            DistanceMetric::Euclidean,
            DistanceMetric::Manhattan,
            DistanceMetric::Cosine,
            DistanceMetric::Angular,
        ];

        for metric in metrics {
            let config = NsgConfig {
                distance_metric: metric,
                out_degree: 8,
                ..Default::default()
            };
            let mut index = seeded_index(config, 20, 2);

            index.build().unwrap();

            let query = Vector::new(vec![10.0, 20.0]);
            let results = index.search_knn(&query, 3).unwrap();

            assert!(!results.is_empty());
        }
    }

    #[test]
    fn test_nsg_stats() {
        let mut index = seeded_index(NsgConfig::default(), 50, 2);

        index.build().unwrap();

        let stats = index.stats();
        assert_eq!(stats.num_vectors, 50);
        assert!(stats.num_edges > 0);
        assert!(stats.avg_out_degree > 0.0);
    }

    #[test]
    fn test_nsg_threshold_search() {
        let mut index = seeded_index(NsgConfig::default(), 30, 2);

        index.build().unwrap();

        let query = Vector::new(vec![15.0, 30.0]);
        let results = index.search_threshold(&query, 0.5).unwrap();

        assert!(!results.is_empty());
        for (_, similarity) in results {
            assert!(similarity >= 0.5);
        }
    }
}