1use crate::{Vector, VectorIndex};
55use anyhow::Result;
56use oxirs_core::simd::SimdOps;
57use parking_lot::RwLock as ParkingLotRwLock;
58use scirs2_core::random::Random;
59use std::cmp::Ordering;
60use std::collections::{BinaryHeap, HashMap, HashSet};
61use std::sync::{Arc, RwLock};
62
63#[derive(Debug, Clone)]
65pub struct NsgConfig {
66 pub out_degree: usize,
68 pub candidate_pool_size: usize,
70 pub search_length: usize,
72 pub distance_metric: DistanceMetric,
74 pub random_seed: Option<u64>,
76 pub parallel_construction: bool,
78 pub num_threads: usize,
80 pub initial_knn_degree: usize,
82 pub pruning_threshold: f32,
84}
85
86impl Default for NsgConfig {
87 fn default() -> Self {
88 Self {
89 out_degree: 32,
90 candidate_pool_size: 100,
91 search_length: 50,
92 distance_metric: DistanceMetric::Euclidean,
93 random_seed: None,
94 parallel_construction: true,
95 num_threads: num_cpus::get(),
96 initial_knn_degree: 64,
97 pruning_threshold: 1.0,
98 }
99 }
100}
101
102#[derive(Debug, Clone, Copy, PartialEq, Eq)]
104pub enum DistanceMetric {
105 Euclidean,
106 Manhattan,
107 Cosine,
108 Angular,
109 InnerProduct,
110}
111
112impl DistanceMetric {
113 pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
115 match self {
116 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
117 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
118 DistanceMetric::Cosine => f32::cosine_distance(a, b),
119 DistanceMetric::Angular => {
120 let cos_sim = 1.0 - f32::cosine_distance(a, b);
121 cos_sim.clamp(-1.0, 1.0).acos() / std::f32::consts::PI
122 }
123 DistanceMetric::InnerProduct => {
124 -a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>()
126 }
127 }
128 }
129}
130
/// A node id paired with its distance to the current query.
///
/// The ordering is reversed on distance so that `BinaryHeap<Candidate>` pops
/// the *closest* candidate first; equal distances fall back to the node id.
/// Distances are compared with [`f32::total_cmp`], which is a genuine total
/// order (a positive NaN sorts after every other distance, i.e. as farthest).
/// This keeps heap and sort invariants intact even if a metric yields NaN.
#[derive(Debug, Clone)]
struct Candidate {
    /// Position of the node in the index's vector storage.
    id: usize,
    /// Distance from the query to this node (smaller is closer).
    distance: f32,
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Derived from `cmp` so `Eq` is reflexive and always agrees with `Ord`.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Reversed on distance: the smaller distance is the "greater" candidate.
        other
            .distance
            .total_cmp(&self.distance)
            .then_with(|| self.id.cmp(&other.id))
    }
}
162
163pub struct NsgIndex {
165 config: NsgConfig,
167 data: Vec<(String, Vector)>,
169 graph: Vec<Vec<usize>>,
171 entry_point: Option<usize>,
173 is_built: bool,
175 uri_to_idx: HashMap<String, usize>,
177 stats: Arc<RwLock<NsgStats>>,
179}
180
/// Snapshot of structural and usage statistics for an NSG index.
#[derive(Debug, Clone, Default)]
pub struct NsgStats {
    /// Number of vectors in the index.
    pub num_vectors: usize,
    /// Total number of directed edges in the graph.
    pub num_edges: usize,
    /// Mean number of out-edges per node.
    pub avg_out_degree: f64,
    /// Largest out-degree of any node.
    pub max_out_degree: usize,
    /// Number of searches recorded in these statistics.
    pub num_searches: usize,
    /// Mean search path length over the recorded searches.
    pub avg_search_path_length: f64,
    /// Number of distance evaluations recorded in these statistics.
    pub total_distance_computations: usize,
}
199
200impl NsgIndex {
201 pub fn new(config: NsgConfig) -> Result<Self> {
203 Ok(Self {
204 config,
205 data: Vec::new(),
206 graph: Vec::new(),
207 entry_point: None,
208 is_built: false,
209 uri_to_idx: HashMap::new(),
210 stats: Arc::new(RwLock::new(NsgStats::default())),
211 })
212 }
213
214 pub fn add(&mut self, uri: String, vector: Vector) -> Result<()> {
216 if self.is_built {
217 return Err(anyhow::anyhow!(
218 "Cannot add vectors after index is built. Call rebuild() or create a new index."
219 ));
220 }
221
222 let idx = self.data.len();
223 self.uri_to_idx.insert(uri.clone(), idx);
224 self.data.push((uri, vector));
225
226 Ok(())
227 }
228
229 pub fn build(&mut self) -> Result<()> {
235 if self.data.is_empty() {
236 return Err(anyhow::anyhow!("Cannot build index with no vectors"));
237 }
238
239 tracing::info!("Building NSG index with {} vectors", self.data.len());
240
241 tracing::debug!("Stage 1: Building initial kNN graph");
243 self.build_knn_graph()?;
244
245 tracing::debug!("Stage 2: Refining to navigable monotonic graph");
247 self.refine_to_nsg()?;
248
249 self.select_entry_point()?;
251
252 self.is_built = true;
253
254 self.update_stats();
256
257 tracing::info!(
258 "NSG index built successfully. {} vectors, {} edges, avg out-degree: {:.2}",
259 self.data.len(),
260 self.count_edges(),
261 self.avg_out_degree()
262 );
263
264 Ok(())
265 }
266
267 fn build_knn_graph(&mut self) -> Result<()> {
269 let n = self.data.len();
270 self.graph = vec![Vec::new(); n];
271
272 if self.config.parallel_construction && n > 1000 {
273 self.build_knn_graph_parallel()?;
274 } else {
275 self.build_knn_graph_sequential()?;
276 }
277
278 Ok(())
279 }
280
281 fn build_knn_graph_sequential(&mut self) -> Result<()> {
283 let n = self.data.len();
284 let k = self.config.initial_knn_degree.min(n - 1);
285
286 for i in 0..n {
287 let mut neighbors = Vec::new();
288
289 for j in 0..n {
291 if i == j {
292 continue;
293 }
294
295 let dist = self.calculate_distance(i, j);
296 neighbors.push(Candidate {
297 id: j,
298 distance: dist,
299 });
300 }
301
302 neighbors.sort_by(|a, b| {
304 a.distance
305 .partial_cmp(&b.distance)
306 .unwrap_or(Ordering::Equal)
307 });
308 neighbors.truncate(k);
309
310 self.graph[i] = neighbors.iter().map(|c| c.id).collect();
312 }
313
314 Ok(())
315 }
316
317 fn build_knn_graph_parallel(&mut self) -> Result<()> {
319 let n = self.data.len();
320 let k = self.config.initial_knn_degree.min(n - 1);
321
322 let graph = Arc::new(ParkingLotRwLock::new(vec![Vec::new(); n]));
324 let data = Arc::new(self.data.clone());
325 let config = self.config.clone();
326
327 let chunk_size = (n + self.config.num_threads - 1) / self.config.num_threads;
329 let mut handles = Vec::new();
330
331 for chunk_start in (0..n).step_by(chunk_size) {
332 let chunk_end = (chunk_start + chunk_size).min(n);
333 let graph_clone = Arc::clone(&graph);
334 let data_clone = Arc::clone(&data);
335 let config_clone = config.clone();
336
337 let handle = std::thread::spawn(move || {
338 for i in chunk_start..chunk_end {
339 let mut neighbors = Vec::new();
340
341 for j in 0..n {
342 if i == j {
343 continue;
344 }
345
346 let vec_i = &data_clone[i].1.as_f32();
347 let vec_j = &data_clone[j].1.as_f32();
348 let dist = config_clone.distance_metric.distance(vec_i, vec_j);
349
350 neighbors.push(Candidate {
351 id: j,
352 distance: dist,
353 });
354 }
355
356 neighbors.sort_by(|a, b| {
357 a.distance
358 .partial_cmp(&b.distance)
359 .unwrap_or(Ordering::Equal)
360 });
361 neighbors.truncate(k);
362
363 let mut graph_lock = graph_clone.write();
364 graph_lock[i] = neighbors.iter().map(|c| c.id).collect();
365 }
366 });
367
368 handles.push(handle);
369 }
370
371 for handle in handles {
373 handle
374 .join()
375 .map_err(|_| anyhow::anyhow!("Thread panicked"))?;
376 }
377
378 self.graph = Arc::try_unwrap(graph)
380 .map_err(|_| anyhow::anyhow!("Failed to unwrap graph"))?
381 .into_inner();
382
383 Ok(())
384 }
385
386 fn refine_to_nsg(&mut self) -> Result<()> {
388 let n = self.data.len();
389 let mut new_graph = vec![Vec::new(); n];
390
391 let temp_entry = self.select_temp_entry_point();
393
394 #[allow(clippy::needless_range_loop)]
395 for i in 0..n {
396 let candidates = self.search_for_neighbors(i, temp_entry)?;
398
399 let neighbors = self.prune_neighbors(i, candidates)?;
401
402 new_graph[i] = neighbors;
403 }
404
405 self.ensure_connectivity(&mut new_graph)?;
407
408 self.graph = new_graph;
409
410 Ok(())
411 }
412
413 fn search_for_neighbors(&self, query_id: usize, entry_id: usize) -> Result<Vec<Candidate>> {
415 let mut visited = HashSet::new();
416 let mut candidates = BinaryHeap::new();
417 let mut result = Vec::new();
418
419 let entry_dist = self.calculate_distance(query_id, entry_id);
421 candidates.push(Candidate {
422 id: entry_id,
423 distance: entry_dist,
424 });
425 visited.insert(entry_id);
426
427 while let Some(current) = candidates.pop() {
428 if result.len() >= self.config.candidate_pool_size {
429 break;
430 }
431
432 result.push(current.clone());
433
434 for &neighbor_id in &self.graph[current.id] {
436 if visited.contains(&neighbor_id) {
437 continue;
438 }
439
440 visited.insert(neighbor_id);
441
442 let dist = self.calculate_distance(query_id, neighbor_id);
443 candidates.push(Candidate {
444 id: neighbor_id,
445 distance: dist,
446 });
447
448 if visited.len() >= self.config.search_length {
449 break;
450 }
451 }
452 }
453
454 result.sort_by(|a, b| {
456 a.distance
457 .partial_cmp(&b.distance)
458 .unwrap_or(Ordering::Equal)
459 });
460
461 Ok(result)
462 }
463
464 fn prune_neighbors(
466 &self,
467 _query_id: usize,
468 mut candidates: Vec<Candidate>,
469 ) -> Result<Vec<usize>> {
470 if candidates.is_empty() {
471 return Ok(Vec::new());
472 }
473
474 let mut result = Vec::new();
475 let mut pruned = HashSet::new();
476
477 while !candidates.is_empty() && result.len() < self.config.out_degree {
478 let best_idx = candidates
480 .iter()
481 .position_min_by(|a, b| {
482 a.distance
483 .partial_cmp(&b.distance)
484 .unwrap_or(Ordering::Equal)
485 })
486 .unwrap();
487
488 let best = candidates.swap_remove(best_idx);
489
490 if pruned.contains(&best.id) {
491 continue;
492 }
493
494 result.push(best.id);
495 pruned.insert(best.id);
496
497 candidates.retain(|c| {
499 let dist_to_best = self.calculate_distance(c.id, best.id);
500 dist_to_best > best.distance * self.config.pruning_threshold
501 });
502 }
503
504 Ok(result)
505 }
506
507 fn ensure_connectivity(&self, graph: &mut [Vec<usize>]) -> Result<()> {
509 let n = graph.len();
510
511 let mut in_edges: Vec<HashSet<usize>> = vec![HashSet::new(); n];
513 for (i, neighbors) in graph.iter().enumerate() {
514 for &j in neighbors {
515 in_edges[j].insert(i);
516 }
517 }
518
519 for (i, edges) in in_edges.iter().enumerate() {
521 if edges.is_empty() && i != 0 {
522 let mut min_dist = f32::INFINITY;
524 let mut closest = 0;
525
526 for (j, neighbors) in graph.iter().enumerate() {
527 if i == j || neighbors.len() >= self.config.out_degree {
528 continue;
529 }
530
531 let dist = self.calculate_distance(i, j);
532 if dist < min_dist {
533 min_dist = dist;
534 closest = j;
535 }
536 }
537
538 if !graph[closest].contains(&i) {
540 graph[closest].push(i);
541 }
542 }
543 }
544
545 Ok(())
546 }
547
548 fn select_entry_point(&mut self) -> Result<()> {
550 if self.data.is_empty() {
551 return Ok(());
552 }
553
554 let mut max_degree = 0;
555 let mut entry = 0;
556
557 for i in 0..self.graph.len() {
558 if self.graph[i].len() > max_degree {
559 max_degree = self.graph[i].len();
560 entry = i;
561 }
562 }
563
564 self.entry_point = Some(entry);
565
566 Ok(())
567 }
568
569 fn select_temp_entry_point(&self) -> usize {
571 if let Some(seed) = self.config.random_seed {
572 let mut rng = Random::seed(seed);
573 rng.random_range(0..self.data.len())
574 } else {
575 self.find_centroid()
577 }
578 }
579
580 fn find_centroid(&self) -> usize {
582 if self.data.is_empty() {
583 return 0;
584 }
585
586 let dim = self.data[0].1.dimensions;
587 let mut centroid = vec![0.0f32; dim];
588
589 for (_, vec) in &self.data {
591 let vals = vec.as_f32();
592 for i in 0..dim {
593 centroid[i] += vals[i];
594 }
595 }
596
597 let n = self.data.len() as f32;
598 for val in &mut centroid {
599 *val /= n;
600 }
601
602 let mut min_dist = f32::INFINITY;
604 let mut closest = 0;
605
606 for i in 0..self.data.len() {
607 let dist = self
608 .config
609 .distance_metric
610 .distance(¢roid, &self.data[i].1.as_f32());
611 if dist < min_dist {
612 min_dist = dist;
613 closest = i;
614 }
615 }
616
617 closest
618 }
619
620 fn calculate_distance(&self, i: usize, j: usize) -> f32 {
622 let vec_i = self.data[i].1.as_f32();
623 let vec_j = self.data[j].1.as_f32();
624 self.config.distance_metric.distance(&vec_i, &vec_j)
625 }
626
627 fn greedy_search(&self, query: &[f32], k: usize, ef: usize) -> Result<Vec<Candidate>> {
629 if !self.is_built {
630 return Err(anyhow::anyhow!("Index not built. Call build() first."));
631 }
632
633 let entry = self
634 .entry_point
635 .ok_or_else(|| anyhow::anyhow!("No entry point set"))?;
636
637 let mut visited = HashSet::new();
638 let mut candidates = BinaryHeap::new();
639 let mut result_set = BinaryHeap::new();
640
641 let entry_dist = self
643 .config
644 .distance_metric
645 .distance(query, &self.data[entry].1.as_f32());
646 candidates.push(Candidate {
647 id: entry,
648 distance: entry_dist,
649 });
650 result_set.push(Candidate {
651 id: entry,
652 distance: entry_dist,
653 });
654 visited.insert(entry);
655
656 while let Some(current) = candidates.pop() {
657 if result_set.len() >= ef && current.distance > result_set.peek().unwrap().distance {
659 break;
660 }
661
662 for &neighbor_id in &self.graph[current.id] {
664 if visited.contains(&neighbor_id) {
665 continue;
666 }
667
668 visited.insert(neighbor_id);
669
670 let dist = self
671 .config
672 .distance_metric
673 .distance(query, &self.data[neighbor_id].1.as_f32());
674 let candidate = Candidate {
675 id: neighbor_id,
676 distance: dist,
677 };
678
679 if result_set.len() < ef || dist < result_set.peek().unwrap().distance {
680 candidates.push(candidate.clone());
681 result_set.push(candidate);
682
683 if result_set.len() > ef {
684 result_set.pop();
685 }
686 }
687 }
688 }
689
690 let mut results: Vec<_> = result_set.into_sorted_vec();
692 results.truncate(k);
693
694 Ok(results)
695 }
696
697 fn update_stats(&self) {
699 let mut stats = self.stats.write().unwrap();
700 stats.num_vectors = self.data.len();
701 stats.num_edges = self.count_edges();
702 stats.avg_out_degree = self.avg_out_degree();
703 stats.max_out_degree = self.max_out_degree();
704 }
705
706 fn count_edges(&self) -> usize {
708 self.graph.iter().map(|neighbors| neighbors.len()).sum()
709 }
710
711 fn avg_out_degree(&self) -> f64 {
713 if self.graph.is_empty() {
714 return 0.0;
715 }
716 self.count_edges() as f64 / self.graph.len() as f64
717 }
718
719 fn max_out_degree(&self) -> usize {
721 self.graph
722 .iter()
723 .map(|neighbors| neighbors.len())
724 .max()
725 .unwrap_or(0)
726 }
727
728 pub fn stats(&self) -> NsgStats {
730 self.stats.read().unwrap().clone()
731 }
732
733 pub fn len(&self) -> usize {
735 self.data.len()
736 }
737
738 pub fn is_empty(&self) -> bool {
740 self.data.is_empty()
741 }
742
743 pub fn is_built(&self) -> bool {
745 self.is_built
746 }
747}
748
749impl VectorIndex for NsgIndex {
750 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
751 self.add(uri, vector)
752 }
753
754 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
755 let query_vals = query.as_f32();
756 let ef = k.max(self.config.search_length);
757 let candidates = self.greedy_search(&query_vals, k, ef)?;
758
759 let mut results: Vec<_> = candidates
763 .into_iter()
764 .map(|c| {
765 let uri = self.data[c.id].0.clone();
766 let similarity = 1.0 / (1.0 + c.distance);
767 (uri, similarity)
768 })
769 .collect();
770
771 results.reverse();
773
774 Ok(results)
775 }
776
777 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
778 let k = self.data.len().min(1000);
780 let all_results = self.search_knn(query, k)?;
781
782 let filtered: Vec<_> = all_results
783 .into_iter()
784 .filter(|(_, similarity)| *similarity >= threshold)
785 .collect();
786
787 Ok(filtered)
788 }
789
790 fn get_vector(&self, uri: &str) -> Option<&Vector> {
791 self.uri_to_idx
792 .get(uri)
793 .and_then(|&idx| self.data.get(idx))
794 .map(|(_, vec)| vec)
795 }
796
797 fn remove_vector(&mut self, id: String) -> Result<()> {
798 if self.is_built {
799 return Err(anyhow::anyhow!(
800 "Cannot remove vectors from built index. Rebuild index instead."
801 ));
802 }
803
804 if let Some(&idx) = self.uri_to_idx.get(&id) {
805 self.data.remove(idx);
806 self.uri_to_idx.remove(&id);
807
808 self.uri_to_idx.clear();
810 for (i, (uri, _)) in self.data.iter().enumerate() {
811 self.uri_to_idx.insert(uri.clone(), i);
812 }
813
814 Ok(())
815 } else {
816 Err(anyhow::anyhow!("Vector with id '{}' not found", id))
817 }
818 }
819}
820
/// Extra iterator helpers used by the NSG index.
trait IteratorExt: Iterator {
    /// Zero-based position of the first element that is minimal under
    /// `compare`, or `None` when the iterator is empty.
    ///
    /// An element replaces the current minimum only when `compare(element,
    /// current_min)` returns `Ordering::Less`, so the earliest minimum wins.
    fn position_min_by<F>(self, compare: F) -> Option<usize>
    where
        F: FnMut(&Self::Item, &Self::Item) -> Ordering;
}

impl<I: Iterator> IteratorExt for I {
    fn position_min_by<F>(self, mut compare: F) -> Option<usize>
    where
        F: FnMut(&Self::Item, &Self::Item) -> Ordering,
    {
        let mut best: Option<(usize, Self::Item)> = None;

        for (position, item) in self.enumerate() {
            let replaces_best = match &best {
                None => true,
                Some((_, current)) => compare(&item, current) == Ordering::Less,
            };
            if replaces_best {
                best = Some((position, item));
            }
        }

        best.map(|(position, _)| position)
    }
}
847
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an unbuilt index holding `count` vectors named `vec_<i>`, where
    /// vector `i` has the coordinates returned by `coords(i)`.
    fn staged_index(config: NsgConfig, count: i32, coords: impl Fn(i32) -> Vec<f32>) -> NsgIndex {
        let mut index = NsgIndex::new(config).unwrap();
        (0..count).for_each(|i| {
            index
                .add(format!("vec_{}", i), Vector::new(coords(i)))
                .unwrap();
        });
        index
    }

    /// Point `i` on the 2-D line through the origin with direction (1, 2).
    fn line_2d(i: i32) -> Vec<f32> {
        vec![i as f32, (i * 2) as f32]
    }

    /// Point `i` on the 3-D line through the origin with direction (1, 2, 3).
    fn line_3d(i: i32) -> Vec<f32> {
        vec![i as f32, (i * 2) as f32, (i * 3) as f32]
    }

    #[test]
    fn test_nsg_creation() {
        let index = NsgIndex::new(NsgConfig::default()).unwrap();
        assert!(!index.is_built());
        assert_eq!(index.len(), 0);
    }

    #[test]
    fn test_nsg_add_vectors() {
        let index = staged_index(NsgConfig::default(), 10, line_3d);
        assert_eq!(index.len(), 10);
    }

    #[test]
    fn test_nsg_build_and_search() {
        let config = NsgConfig {
            out_degree: 32,
            candidate_pool_size: 100,
            search_length: 50,
            initial_knn_degree: 64,
            ..Default::default()
        };
        let mut index = staged_index(config, 100, line_3d);

        index.build().unwrap();
        assert!(index.is_built());

        let results = index
            .search_knn(&Vector::new(vec![10.1, 20.1, 30.1]), 10)
            .unwrap();
        assert!(!results.is_empty());
        assert_eq!(results.len(), 10);

        // Similarities must be non-increasing.
        for (i, pair) in results.windows(2).enumerate() {
            let (prev, next) = (pair[0].1, pair[1].1);
            assert!(
                prev >= next,
                "Results not sorted: {}@{} < {}@{}",
                prev,
                i,
                next,
                i + 1
            );
        }

        const NEARBY: [&str; 5] = ["10", "11", "9", "12", "8"];
        let nearby_found = results
            .iter()
            .take(10)
            .any(|(uri, _)| NEARBY.iter().any(|&needle| uri.contains(needle)));
        assert!(
            nearby_found,
            "Expected nearby vectors (8-12) in top 10 results"
        );
    }

    #[test]
    fn test_nsg_distance_metrics() {
        let metrics = [
            DistanceMetric::Euclidean,
            DistanceMetric::Manhattan,
            DistanceMetric::Cosine,
            DistanceMetric::Angular,
        ];

        for distance_metric in metrics {
            let config = NsgConfig {
                distance_metric,
                out_degree: 8,
                ..Default::default()
            };
            let mut index = staged_index(config, 20, line_2d);
            index.build().unwrap();

            let results = index
                .search_knn(&Vector::new(vec![10.0, 20.0]), 3)
                .unwrap();
            assert!(!results.is_empty());
        }
    }

    #[test]
    fn test_nsg_stats() {
        let mut index = staged_index(NsgConfig::default(), 50, line_2d);
        index.build().unwrap();

        let stats = index.stats();
        assert_eq!(stats.num_vectors, 50);
        assert!(stats.num_edges > 0);
        assert!(stats.avg_out_degree > 0.0);
    }

    #[test]
    fn test_nsg_threshold_search() {
        let mut index = staged_index(NsgConfig::default(), 30, line_2d);
        index.build().unwrap();

        let results = index
            .search_threshold(&Vector::new(vec![15.0, 30.0]), 0.5)
            .unwrap();
        assert!(!results.is_empty());
        assert!(results.iter().all(|(_, similarity)| *similarity >= 0.5));
    }
}