oxirs_vec/
nsg.rs

//! NSG (Navigating Spreading-out Graph) index
2//!
3//! NSG is a graph-based approximate nearest neighbor search algorithm that builds
4//! a monotonic navigable graph structure. It provides:
5//!
6//! - **Monotonic Search Path**: Guarantees that search always moves closer to query
7//! - **Memory Efficiency**: Controlled out-degree for compact graph structure
8//! - **High Accuracy**: Better recall than NSW with similar or better performance
9//! - **Fast Search**: O(log n) expected search complexity
10//!
11//! # Algorithm Overview
12//!
13//! NSG construction has two stages:
14//! 1. Build initial kNN graph using any ANN algorithm
15//! 2. Refine graph to ensure navigability and monotonicity
16//!
17//! The key innovation is the monotonic search property: each hop in the search
18//! path gets closer to the query point, preventing cycles and dead-ends.
19//!
20//! # References
21//!
//! - Fu, Cong, et al. "Fast approximate nearest neighbor search with the
//!   navigating spreading-out graph." arXiv preprint arXiv:1707.00143 (2017).
24//!
25//! # Example
26//!
27//! ```rust
28//! use oxirs_vec::{Vector, VectorIndex};
29//! use oxirs_vec::nsg::{NsgConfig, NsgIndex};
30//!
31//! let config = NsgConfig {
32//!     out_degree: 32,
33//!     candidate_pool_size: 100,
34//!     search_length: 50,
35//!     ..Default::default()
36//! };
37//!
38//! let mut index = NsgIndex::new(config).unwrap();
39//!
40//! // Add vectors
41//! for i in 0..1000 {
42//!     let vector = Vector::new(vec![i as f32, (i * 2) as f32, (i * 3) as f32]);
43//!     index.insert(format!("vec_{}", i), vector).unwrap();
44//! }
45//!
46//! // Build the NSG structure
47//! index.build().unwrap();
48//!
49//! // Search
50//! let query = Vector::new(vec![100.0, 200.0, 300.0]);
51//! let results = index.search_knn(&query, 10).unwrap();
52//! ```
53
54use crate::{Vector, VectorIndex};
55use anyhow::Result;
56use oxirs_core::simd::SimdOps;
57use parking_lot::RwLock as ParkingLotRwLock;
58use scirs2_core::random::Random;
59use std::cmp::Ordering;
60use std::collections::{BinaryHeap, HashMap, HashSet};
61use std::sync::{Arc, RwLock};
62
63/// Configuration for NSG index
64#[derive(Debug, Clone)]
65pub struct NsgConfig {
66    /// Maximum out-degree (number of outgoing edges per node)
67    pub out_degree: usize,
68    /// Candidate pool size during graph construction
69    pub candidate_pool_size: usize,
70    /// Search length during graph refinement
71    pub search_length: usize,
72    /// Distance metric to use
73    pub distance_metric: DistanceMetric,
74    /// Random seed for reproducibility
75    pub random_seed: Option<u64>,
76    /// Enable parallel construction
77    pub parallel_construction: bool,
78    /// Number of threads for parallel construction
79    pub num_threads: usize,
80    /// Initial kNN graph degree
81    pub initial_knn_degree: usize,
82    /// Pruning threshold for edge quality
83    pub pruning_threshold: f32,
84}
85
86impl Default for NsgConfig {
87    fn default() -> Self {
88        Self {
89            out_degree: 32,
90            candidate_pool_size: 100,
91            search_length: 50,
92            distance_metric: DistanceMetric::Euclidean,
93            random_seed: None,
94            parallel_construction: true,
95            num_threads: num_cpus::get(),
96            initial_knn_degree: 64,
97            pruning_threshold: 1.0,
98        }
99    }
100}
101
/// Distance functions the NSG index can be configured with.
///
/// Smaller values always mean "closer"; see `DistanceMetric::distance` for
/// the exact formula behind each variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DistanceMetric {
    /// Euclidean distance.
    Euclidean,
    /// Manhattan distance.
    Manhattan,
    /// Cosine distance as computed by the SIMD helper.
    Cosine,
    /// Angle between the vectors, scaled by 1/pi.
    Angular,
    /// Negated dot product, so that a larger inner product ranks as closer.
    InnerProduct,
}
111
112impl DistanceMetric {
113    /// Calculate distance between two vectors
114    pub fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
115        match self {
116            DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
117            DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
118            DistanceMetric::Cosine => f32::cosine_distance(a, b),
119            DistanceMetric::Angular => {
120                let cos_sim = 1.0 - f32::cosine_distance(a, b);
121                cos_sim.clamp(-1.0, 1.0).acos() / std::f32::consts::PI
122            }
123            DistanceMetric::InnerProduct => {
124                // Negative inner product (to use as distance)
125                -a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>()
126            }
127        }
128    }
129}
130
/// A graph node paired with its distance to the current query.
///
/// The ordering is deliberately reversed with respect to distance: a
/// candidate with a *smaller* distance compares as *greater*, so that
/// `BinaryHeap<Candidate>` pops the closest candidate first. Ties on distance
/// are broken by `id`. NaN distances compare equal to each other and are
/// treated as farther than every real distance, which keeps the order total
/// and keeps `PartialEq` consistent with `Ord`.
#[derive(Debug, Clone)]
struct Candidate {
    /// Index of the node in the index's data vector.
    id: usize,
    /// Distance from the node to the query.
    distance: f32,
}

impl Candidate {
    /// Total order on distances in which NaN equals NaN and exceeds any
    /// non-NaN value.
    fn compare_distances(lhs: f32, rhs: f32) -> Ordering {
        match (lhs.is_nan(), rhs.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN, so `partial_cmp` always yields `Some`.
            (false, false) => lhs.partial_cmp(&rhs).unwrap_or(Ordering::Equal),
        }
    }
}

impl PartialEq for Candidate {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality can never disagree with ordering.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Candidate {}

impl PartialOrd for Candidate {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Candidate {
    fn cmp(&self, other: &Self) -> Ordering {
        // Operands are swapped on purpose: smaller distance => greater value,
        // so a max-heap yields the nearest candidate first.
        Self::compare_distances(other.distance, self.distance)
            .then_with(|| self.id.cmp(&other.id))
    }
}
162
163/// NSG index structure
164pub struct NsgIndex {
165    /// Configuration
166    config: NsgConfig,
167    /// Stored vectors with URIs
168    data: Vec<(String, Vector)>,
169    /// Forward adjacency list (outgoing edges)
170    graph: Vec<Vec<usize>>,
171    /// Entry point for search
172    entry_point: Option<usize>,
173    /// Whether the index is built
174    is_built: bool,
175    /// URI to index mapping
176    uri_to_idx: HashMap<String, usize>,
177    /// Statistics
178    stats: Arc<RwLock<NsgStats>>,
179}
180
/// Snapshot of size and usage figures for an NSG index.
///
/// NOTE(review): within this file only the vector, edge, and degree figures
/// are written (after a build); the search-related counters are never
/// updated here — confirm whether another module maintains them.
#[derive(Debug, Clone, Default)]
pub struct NsgStats {
    /// How many vectors the index holds.
    pub num_vectors: usize,
    /// Total number of directed edges in the graph.
    pub num_edges: usize,
    /// Mean number of outgoing edges per node.
    pub avg_out_degree: f64,
    /// Largest number of outgoing edges on any single node.
    pub max_out_degree: usize,
    /// Count of searches performed.
    pub num_searches: usize,
    /// Mean length of the path walked per search.
    pub avg_search_path_length: f64,
    /// Running total of distance evaluations.
    pub total_distance_computations: usize,
}
199
200impl NsgIndex {
201    /// Create a new NSG index with given configuration
202    pub fn new(config: NsgConfig) -> Result<Self> {
203        Ok(Self {
204            config,
205            data: Vec::new(),
206            graph: Vec::new(),
207            entry_point: None,
208            is_built: false,
209            uri_to_idx: HashMap::new(),
210            stats: Arc::new(RwLock::new(NsgStats::default())),
211        })
212    }
213
214    /// Add a vector to the index (must call build() after adding all vectors)
215    pub fn add(&mut self, uri: String, vector: Vector) -> Result<()> {
216        if self.is_built {
217            return Err(anyhow::anyhow!(
218                "Cannot add vectors after index is built. Call rebuild() or create a new index."
219            ));
220        }
221
222        let idx = self.data.len();
223        self.uri_to_idx.insert(uri.clone(), idx);
224        self.data.push((uri, vector));
225
226        Ok(())
227    }
228
229    /// Build the NSG structure
230    ///
231    /// This is a two-stage process:
232    /// 1. Build initial kNN graph
233    /// 2. Refine to create navigable monotonic graph
234    pub fn build(&mut self) -> Result<()> {
235        if self.data.is_empty() {
236            return Err(anyhow::anyhow!("Cannot build index with no vectors"));
237        }
238
239        tracing::info!("Building NSG index with {} vectors", self.data.len());
240
241        // Stage 1: Build initial kNN graph
242        tracing::debug!("Stage 1: Building initial kNN graph");
243        self.build_knn_graph()?;
244
245        // Stage 2: Refine to NSG
246        tracing::debug!("Stage 2: Refining to navigable monotonic graph");
247        self.refine_to_nsg()?;
248
249        // Select entry point
250        self.select_entry_point()?;
251
252        self.is_built = true;
253
254        // Update statistics
255        self.update_stats();
256
257        tracing::info!(
258            "NSG index built successfully. {} vectors, {} edges, avg out-degree: {:.2}",
259            self.data.len(),
260            self.count_edges(),
261            self.avg_out_degree()
262        );
263
264        Ok(())
265    }
266
267    /// Build initial kNN graph using brute-force search
268    fn build_knn_graph(&mut self) -> Result<()> {
269        let n = self.data.len();
270        self.graph = vec![Vec::new(); n];
271
272        if self.config.parallel_construction && n > 1000 {
273            self.build_knn_graph_parallel()?;
274        } else {
275            self.build_knn_graph_sequential()?;
276        }
277
278        Ok(())
279    }
280
281    /// Sequential kNN graph construction
282    fn build_knn_graph_sequential(&mut self) -> Result<()> {
283        let n = self.data.len();
284        let k = self.config.initial_knn_degree.min(n - 1);
285
286        for i in 0..n {
287            let mut neighbors = Vec::new();
288
289            // Find k nearest neighbors
290            for j in 0..n {
291                if i == j {
292                    continue;
293                }
294
295                let dist = self.calculate_distance(i, j);
296                neighbors.push(Candidate {
297                    id: j,
298                    distance: dist,
299                });
300            }
301
302            // Sort and keep top-k
303            neighbors.sort_by(|a, b| {
304                a.distance
305                    .partial_cmp(&b.distance)
306                    .unwrap_or(Ordering::Equal)
307            });
308            neighbors.truncate(k);
309
310            // Add bidirectional edges
311            self.graph[i] = neighbors.iter().map(|c| c.id).collect();
312        }
313
314        Ok(())
315    }
316
317    /// Parallel kNN graph construction
318    fn build_knn_graph_parallel(&mut self) -> Result<()> {
319        let n = self.data.len();
320        let k = self.config.initial_knn_degree.min(n - 1);
321
322        // Create thread-safe graph structure
323        let graph = Arc::new(ParkingLotRwLock::new(vec![Vec::new(); n]));
324        let data = Arc::new(self.data.clone());
325        let config = self.config.clone();
326
327        // Process in parallel chunks
328        let chunk_size = (n + self.config.num_threads - 1) / self.config.num_threads;
329        let mut handles = Vec::new();
330
331        for chunk_start in (0..n).step_by(chunk_size) {
332            let chunk_end = (chunk_start + chunk_size).min(n);
333            let graph_clone = Arc::clone(&graph);
334            let data_clone = Arc::clone(&data);
335            let config_clone = config.clone();
336
337            let handle = std::thread::spawn(move || {
338                for i in chunk_start..chunk_end {
339                    let mut neighbors = Vec::new();
340
341                    for j in 0..n {
342                        if i == j {
343                            continue;
344                        }
345
346                        let vec_i = &data_clone[i].1.as_f32();
347                        let vec_j = &data_clone[j].1.as_f32();
348                        let dist = config_clone.distance_metric.distance(vec_i, vec_j);
349
350                        neighbors.push(Candidate {
351                            id: j,
352                            distance: dist,
353                        });
354                    }
355
356                    neighbors.sort_by(|a, b| {
357                        a.distance
358                            .partial_cmp(&b.distance)
359                            .unwrap_or(Ordering::Equal)
360                    });
361                    neighbors.truncate(k);
362
363                    let mut graph_lock = graph_clone.write();
364                    graph_lock[i] = neighbors.iter().map(|c| c.id).collect();
365                }
366            });
367
368            handles.push(handle);
369        }
370
371        // Wait for all threads
372        for handle in handles {
373            handle
374                .join()
375                .map_err(|_| anyhow::anyhow!("Thread panicked"))?;
376        }
377
378        // Copy results back
379        self.graph = Arc::try_unwrap(graph)
380            .map_err(|_| anyhow::anyhow!("Failed to unwrap graph"))?
381            .into_inner();
382
383        Ok(())
384    }
385
386    /// Refine kNN graph to NSG with monotonic navigability
387    fn refine_to_nsg(&mut self) -> Result<()> {
388        let n = self.data.len();
389        let mut new_graph = vec![Vec::new(); n];
390
391        // Select a temporary entry point for refinement
392        let temp_entry = self.select_temp_entry_point();
393
394        #[allow(clippy::needless_range_loop)]
395        for i in 0..n {
396            // Find candidate neighbors through navigation
397            let candidates = self.search_for_neighbors(i, temp_entry)?;
398
399            // Prune to maintain out-degree constraint
400            let neighbors = self.prune_neighbors(i, candidates)?;
401
402            new_graph[i] = neighbors;
403        }
404
405        // Ensure connectivity by adding reverse edges where needed
406        self.ensure_connectivity(&mut new_graph)?;
407
408        self.graph = new_graph;
409
410        Ok(())
411    }
412
413    /// Search for candidate neighbors during graph refinement
414    fn search_for_neighbors(&self, query_id: usize, entry_id: usize) -> Result<Vec<Candidate>> {
415        let mut visited = HashSet::new();
416        let mut candidates = BinaryHeap::new();
417        let mut result = Vec::new();
418
419        // Start from entry point
420        let entry_dist = self.calculate_distance(query_id, entry_id);
421        candidates.push(Candidate {
422            id: entry_id,
423            distance: entry_dist,
424        });
425        visited.insert(entry_id);
426
427        while let Some(current) = candidates.pop() {
428            if result.len() >= self.config.candidate_pool_size {
429                break;
430            }
431
432            result.push(current.clone());
433
434            // Explore neighbors
435            for &neighbor_id in &self.graph[current.id] {
436                if visited.contains(&neighbor_id) {
437                    continue;
438                }
439
440                visited.insert(neighbor_id);
441
442                let dist = self.calculate_distance(query_id, neighbor_id);
443                candidates.push(Candidate {
444                    id: neighbor_id,
445                    distance: dist,
446                });
447
448                if visited.len() >= self.config.search_length {
449                    break;
450                }
451            }
452        }
453
454        // Sort by distance
455        result.sort_by(|a, b| {
456            a.distance
457                .partial_cmp(&b.distance)
458                .unwrap_or(Ordering::Equal)
459        });
460
461        Ok(result)
462    }
463
464    /// Prune neighbors to maintain graph quality and out-degree constraint
465    fn prune_neighbors(
466        &self,
467        _query_id: usize,
468        mut candidates: Vec<Candidate>,
469    ) -> Result<Vec<usize>> {
470        if candidates.is_empty() {
471            return Ok(Vec::new());
472        }
473
474        let mut result = Vec::new();
475        let mut pruned = HashSet::new();
476
477        while !candidates.is_empty() && result.len() < self.config.out_degree {
478            // Find best candidate (minimum distance)
479            let best_idx = candidates
480                .iter()
481                .position_min_by(|a, b| {
482                    a.distance
483                        .partial_cmp(&b.distance)
484                        .unwrap_or(Ordering::Equal)
485                })
486                .unwrap();
487
488            let best = candidates.swap_remove(best_idx);
489
490            if pruned.contains(&best.id) {
491                continue;
492            }
493
494            result.push(best.id);
495            pruned.insert(best.id);
496
497            // Prune candidates that are too close to the selected neighbor
498            candidates.retain(|c| {
499                let dist_to_best = self.calculate_distance(c.id, best.id);
500                dist_to_best > best.distance * self.config.pruning_threshold
501            });
502        }
503
504        Ok(result)
505    }
506
507    /// Ensure graph connectivity by adding reverse edges
508    fn ensure_connectivity(&self, graph: &mut [Vec<usize>]) -> Result<()> {
509        let n = graph.len();
510
511        // Build reverse index
512        let mut in_edges: Vec<HashSet<usize>> = vec![HashSet::new(); n];
513        for (i, neighbors) in graph.iter().enumerate() {
514            for &j in neighbors {
515                in_edges[j].insert(i);
516            }
517        }
518
519        // For each node, ensure it has at least one incoming edge
520        for (i, edges) in in_edges.iter().enumerate() {
521            if edges.is_empty() && i != 0 {
522                // Find closest node that has outgoing edges
523                let mut min_dist = f32::INFINITY;
524                let mut closest = 0;
525
526                for (j, neighbors) in graph.iter().enumerate() {
527                    if i == j || neighbors.len() >= self.config.out_degree {
528                        continue;
529                    }
530
531                    let dist = self.calculate_distance(i, j);
532                    if dist < min_dist {
533                        min_dist = dist;
534                        closest = j;
535                    }
536                }
537
538                // Add edge from closest to i
539                if !graph[closest].contains(&i) {
540                    graph[closest].push(i);
541                }
542            }
543        }
544
545        Ok(())
546    }
547
548    /// Select entry point for search (node with highest out-degree)
549    fn select_entry_point(&mut self) -> Result<()> {
550        if self.data.is_empty() {
551            return Ok(());
552        }
553
554        let mut max_degree = 0;
555        let mut entry = 0;
556
557        for i in 0..self.graph.len() {
558            if self.graph[i].len() > max_degree {
559                max_degree = self.graph[i].len();
560                entry = i;
561            }
562        }
563
564        self.entry_point = Some(entry);
565
566        Ok(())
567    }
568
569    /// Select temporary entry point for graph refinement
570    fn select_temp_entry_point(&self) -> usize {
571        if let Some(seed) = self.config.random_seed {
572            let mut rng = Random::seed(seed);
573            rng.random_range(0..self.data.len())
574        } else {
575            // Use centroid as entry point
576            self.find_centroid()
577        }
578    }
579
580    /// Find centroid of all vectors
581    fn find_centroid(&self) -> usize {
582        if self.data.is_empty() {
583            return 0;
584        }
585
586        let dim = self.data[0].1.dimensions;
587        let mut centroid = vec![0.0f32; dim];
588
589        // Calculate mean
590        for (_, vec) in &self.data {
591            let vals = vec.as_f32();
592            for i in 0..dim {
593                centroid[i] += vals[i];
594            }
595        }
596
597        let n = self.data.len() as f32;
598        for val in &mut centroid {
599            *val /= n;
600        }
601
602        // Find closest vector to centroid
603        let mut min_dist = f32::INFINITY;
604        let mut closest = 0;
605
606        for i in 0..self.data.len() {
607            let dist = self
608                .config
609                .distance_metric
610                .distance(&centroid, &self.data[i].1.as_f32());
611            if dist < min_dist {
612                min_dist = dist;
613                closest = i;
614            }
615        }
616
617        closest
618    }
619
620    /// Calculate distance between two vectors by index
621    fn calculate_distance(&self, i: usize, j: usize) -> f32 {
622        let vec_i = self.data[i].1.as_f32();
623        let vec_j = self.data[j].1.as_f32();
624        self.config.distance_metric.distance(&vec_i, &vec_j)
625    }
626
627    /// Perform greedy search on the graph
628    fn greedy_search(&self, query: &[f32], k: usize, ef: usize) -> Result<Vec<Candidate>> {
629        if !self.is_built {
630            return Err(anyhow::anyhow!("Index not built. Call build() first."));
631        }
632
633        let entry = self
634            .entry_point
635            .ok_or_else(|| anyhow::anyhow!("No entry point set"))?;
636
637        let mut visited = HashSet::new();
638        let mut candidates = BinaryHeap::new();
639        let mut result_set = BinaryHeap::new();
640
641        // Initialize with entry point
642        let entry_dist = self
643            .config
644            .distance_metric
645            .distance(query, &self.data[entry].1.as_f32());
646        candidates.push(Candidate {
647            id: entry,
648            distance: entry_dist,
649        });
650        result_set.push(Candidate {
651            id: entry,
652            distance: entry_dist,
653        });
654        visited.insert(entry);
655
656        while let Some(current) = candidates.pop() {
657            // Check if we've explored enough
658            if result_set.len() >= ef && current.distance > result_set.peek().unwrap().distance {
659                break;
660            }
661
662            // Explore neighbors
663            for &neighbor_id in &self.graph[current.id] {
664                if visited.contains(&neighbor_id) {
665                    continue;
666                }
667
668                visited.insert(neighbor_id);
669
670                let dist = self
671                    .config
672                    .distance_metric
673                    .distance(query, &self.data[neighbor_id].1.as_f32());
674                let candidate = Candidate {
675                    id: neighbor_id,
676                    distance: dist,
677                };
678
679                if result_set.len() < ef || dist < result_set.peek().unwrap().distance {
680                    candidates.push(candidate.clone());
681                    result_set.push(candidate);
682
683                    if result_set.len() > ef {
684                        result_set.pop();
685                    }
686                }
687            }
688        }
689
690        // Convert to sorted vector
691        let mut results: Vec<_> = result_set.into_sorted_vec();
692        results.truncate(k);
693
694        Ok(results)
695    }
696
697    /// Update index statistics
698    fn update_stats(&self) {
699        let mut stats = self.stats.write().unwrap();
700        stats.num_vectors = self.data.len();
701        stats.num_edges = self.count_edges();
702        stats.avg_out_degree = self.avg_out_degree();
703        stats.max_out_degree = self.max_out_degree();
704    }
705
706    /// Count total edges in graph
707    fn count_edges(&self) -> usize {
708        self.graph.iter().map(|neighbors| neighbors.len()).sum()
709    }
710
711    /// Calculate average out-degree
712    fn avg_out_degree(&self) -> f64 {
713        if self.graph.is_empty() {
714            return 0.0;
715        }
716        self.count_edges() as f64 / self.graph.len() as f64
717    }
718
719    /// Get maximum out-degree
720    fn max_out_degree(&self) -> usize {
721        self.graph
722            .iter()
723            .map(|neighbors| neighbors.len())
724            .max()
725            .unwrap_or(0)
726    }
727
728    /// Get index statistics
729    pub fn stats(&self) -> NsgStats {
730        self.stats.read().unwrap().clone()
731    }
732
733    /// Get number of vectors in the index
734    pub fn len(&self) -> usize {
735        self.data.len()
736    }
737
738    /// Check if index is empty
739    pub fn is_empty(&self) -> bool {
740        self.data.is_empty()
741    }
742
743    /// Check if index is built
744    pub fn is_built(&self) -> bool {
745        self.is_built
746    }
747}
748
749impl VectorIndex for NsgIndex {
750    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
751        self.add(uri, vector)
752    }
753
754    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
755        let query_vals = query.as_f32();
756        let ef = k.max(self.config.search_length);
757        let candidates = self.greedy_search(&query_vals, k, ef)?;
758
759        // Convert to (URI, similarity) format
760        // Note: NSG uses distance, so we convert to similarity (1 / (1 + distance))
761        // Candidates are sorted by distance (ascending), so we reverse to get descending similarity
762        let mut results: Vec<_> = candidates
763            .into_iter()
764            .map(|c| {
765                let uri = self.data[c.id].0.clone();
766                let similarity = 1.0 / (1.0 + c.distance);
767                (uri, similarity)
768            })
769            .collect();
770
771        // Reverse to get descending order of similarity
772        results.reverse();
773
774        Ok(results)
775    }
776
777    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
778        // Search with a large k and filter by threshold
779        let k = self.data.len().min(1000);
780        let all_results = self.search_knn(query, k)?;
781
782        let filtered: Vec<_> = all_results
783            .into_iter()
784            .filter(|(_, similarity)| *similarity >= threshold)
785            .collect();
786
787        Ok(filtered)
788    }
789
790    fn get_vector(&self, uri: &str) -> Option<&Vector> {
791        self.uri_to_idx
792            .get(uri)
793            .and_then(|&idx| self.data.get(idx))
794            .map(|(_, vec)| vec)
795    }
796
797    fn remove_vector(&mut self, id: String) -> Result<()> {
798        if self.is_built {
799            return Err(anyhow::anyhow!(
800                "Cannot remove vectors from built index. Rebuild index instead."
801            ));
802        }
803
804        if let Some(&idx) = self.uri_to_idx.get(&id) {
805            self.data.remove(idx);
806            self.uri_to_idx.remove(&id);
807
808            // Rebuild index mapping
809            self.uri_to_idx.clear();
810            for (i, (uri, _)) in self.data.iter().enumerate() {
811                self.uri_to_idx.insert(uri.clone(), i);
812            }
813
814            Ok(())
815        } else {
816            Err(anyhow::anyhow!("Vector with id '{}' not found", id))
817        }
818    }
819}
820
/// Extension trait adding a "position of the minimum" query to any iterator.
trait IteratorExt: Iterator {
    /// Returns the zero-based position of the smallest item according to
    /// `compare`, or `None` for an empty iterator. When several items are
    /// equally small, the position of the first one is returned.
    fn position_min_by<F>(self, compare: F) -> Option<usize>
    where
        F: FnMut(&Self::Item, &Self::Item) -> Ordering;
}

impl<I: Iterator> IteratorExt for I {
    fn position_min_by<F>(self, mut compare: F) -> Option<usize>
    where
        F: FnMut(&Self::Item, &Self::Item) -> Ordering,
    {
        self.enumerate()
            .reduce(|best, challenger| {
                // Only a strictly smaller item displaces the current minimum,
                // which keeps the earliest position on ties.
                if compare(&challenger.1, &best.1) == Ordering::Less {
                    challenger
                } else {
                    best
                }
            })
            .map(|(position, _)| position)
    }
}
847
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an index from `config` and stages `count` vectors named
    /// `vec_<i>` whose components are `[i, 2i, ..., dims * i]`. The index is
    /// returned unbuilt.
    fn populated_index(config: NsgConfig, count: usize, dims: usize) -> NsgIndex {
        let mut index = NsgIndex::new(config).unwrap();

        for i in 0..count {
            let components: Vec<f32> = (1..=dims).map(|factor| (i * factor) as f32).collect();
            index
                .add(format!("vec_{}", i), Vector::new(components))
                .unwrap();
        }

        index
    }

    /// A fresh index is empty and unbuilt.
    #[test]
    fn test_nsg_creation() {
        let index = NsgIndex::new(NsgConfig::default()).unwrap();

        assert_eq!(index.len(), 0);
        assert!(!index.is_built());
    }

    /// Staged vectors are counted by `len()`.
    #[test]
    fn test_nsg_add_vectors() {
        let index = populated_index(NsgConfig::default(), 10, 3);

        assert_eq!(index.len(), 10);
    }

    /// Building succeeds and a kNN query returns `k` hits in descending
    /// similarity order, including at least one URI near the query.
    #[test]
    fn test_nsg_build_and_search() {
        let config = NsgConfig {
            out_degree: 32,
            candidate_pool_size: 100,
            search_length: 50,
            initial_knn_degree: 64,
            ..Default::default()
        };

        // Vectors lie on a line, which keeps the graph well connected.
        let mut index = populated_index(config, 100, 3);

        index.build().unwrap();
        assert!(index.is_built());

        // The query sits right next to vec_10.
        let query = Vector::new(vec![10.1, 20.1, 30.1]);
        let results = index.search_knn(&query, 10).unwrap();

        assert!(!results.is_empty());
        assert_eq!(results.len(), 10);

        // Similarities must never increase from one hit to the next.
        for i in 1..results.len() {
            assert!(
                results[i - 1].1 >= results[i].1,
                "Results not sorted: {}@{} < {}@{}",
                results[i - 1].1,
                i - 1,
                results[i].1,
                i
            );
        }

        // Loose check: some URI in the top 10 must contain one of the ids
        // around the query (substring match, so e.g. "vec_19" also counts).
        let nearby_found = results.iter().take(10).any(|(uri, _)| {
            ["10", "11", "9", "12", "8"]
                .iter()
                .any(|needle| uri.contains(needle))
        });
        assert!(
            nearby_found,
            "Expected nearby vectors (8-12) in top 10 results"
        );
    }

    /// Build and search work for each of the listed metrics.
    #[test]
    fn test_nsg_distance_metrics() {
        let metrics = [
            DistanceMetric::Euclidean,
            DistanceMetric::Manhattan,
            DistanceMetric::Cosine,
            DistanceMetric::Angular,
        ];

        for metric in metrics {
            let config = NsgConfig {
                distance_metric: metric,
                out_degree: 8,
                ..Default::default()
            };
            let mut index = populated_index(config, 20, 2);

            index.build().unwrap();

            let query = Vector::new(vec![10.0, 20.0]);
            let results = index.search_knn(&query, 3).unwrap();

            assert!(!results.is_empty());
        }
    }

    /// Statistics reflect the built graph.
    #[test]
    fn test_nsg_stats() {
        let mut index = populated_index(NsgConfig::default(), 50, 2);

        index.build().unwrap();

        let stats = index.stats();
        assert_eq!(stats.num_vectors, 50);
        assert!(stats.num_edges > 0);
        assert!(stats.avg_out_degree > 0.0);
    }

    /// Threshold search only returns hits at or above the threshold.
    #[test]
    fn test_nsg_threshold_search() {
        let mut index = populated_index(NsgConfig::default(), 30, 2);

        index.build().unwrap();

        let query = Vector::new(vec![15.0, 30.0]);
        let results = index.search_threshold(&query, 0.5).unwrap();

        assert!(!results.is_empty());
        for (_, similarity) in results {
            assert!(similarity >= 0.5);
        }
    }
}