oxirs_vec/
graph_indices.rs

1//! Graph-based indices for efficient nearest neighbor search
2//!
3//! This module implements various graph-based data structures optimized for
4//! nearest neighbor search:
5//! - NSW: Navigable Small World
6//! - ONNG: Optimized Nearest Neighbor Graph
7//! - PANNG: Pruned Approximate Nearest Neighbor Graph
8//! - Delaunay Graph: Approximation for high-dimensional space
9//! - RNG: Relative Neighborhood Graph
10
11use crate::{Vector, VectorIndex};
12use anyhow::Result;
13use oxirs_core::parallel::*;
14use oxirs_core::simd::SimdOps;
15use petgraph::graph::{Graph, NodeIndex};
16#[allow(unused_imports)]
17use scirs2_core::random::{Random, Rng};
18use std::cmp::Ordering;
19use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
20
/// Configuration for graph-based indices
///
/// One value of this type is stored by each graph structure in this file
/// (`NSWGraph`, `ONNGGraph`, `PANNGGraph`, `DelaunayGraph`, `RNGGraph`).
#[derive(Debug, Clone)]
pub struct GraphIndexConfig {
    /// Type of graph to use
    pub graph_type: GraphType,
    /// Number of neighbors per node; most builders use it as the `k` of their
    /// nearest-neighbor pass and as the cap when trimming adjacency lists
    pub num_neighbors: usize,
    /// Random seed for reproducibility; when `None`, the shuffles in this
    /// file fall back to the fixed seed 42
    pub random_seed: Option<u64>,
    /// Enable parallel construction (`NSWGraph::build` only takes this path
    /// when more than 1000 vectors are stored)
    pub parallel_construction: bool,
    /// Distance metric
    pub distance_metric: DistanceMetric,
    /// Enable pruning for better quality (checked by `PANNGGraph::build`)
    pub enable_pruning: bool,
    /// Search depth multiplier; `NSWGraph::search` multiplies `k` by it to
    /// size its candidate queue
    pub search_expansion: f32,
}
39
40impl Default for GraphIndexConfig {
41    fn default() -> Self {
42        Self {
43            graph_type: GraphType::NSW,
44            num_neighbors: 32,
45            random_seed: None,
46            parallel_construction: true,
47            distance_metric: DistanceMetric::Euclidean,
48            enable_pruning: true,
49            search_expansion: 1.5,
50        }
51    }
52}
53
/// Available graph types
#[derive(Debug, Clone, Copy)]
pub enum GraphType {
    /// Navigable Small World
    NSW,
    /// Optimized Nearest Neighbor Graph
    ONNG,
    /// Pruned Approximate Nearest Neighbor Graph
    PANNG,
    /// Delaunay Graph approximation
    Delaunay,
    /// Relative Neighborhood Graph
    RNG,
}
63
/// Distance metrics
///
/// Each variant is evaluated by `DistanceMetric::distance`.
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
    /// Computed with `f32::euclidean_distance`
    Euclidean,
    /// Computed with `f32::manhattan_distance`
    Manhattan,
    /// Computed with `f32::cosine_distance`
    Cosine,
    /// `arccos(1 - cosine_distance) / pi`
    Angular,
}
72
73impl DistanceMetric {
74    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
75        match self {
76            DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
77            DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
78            DistanceMetric::Cosine => f32::cosine_distance(a, b),
79            DistanceMetric::Angular => {
80                // Angular distance = arccos(cosine_similarity) / pi
81                let cos_sim: f32 = 1.0 - f32::cosine_distance(a, b);
82                cos_sim.clamp(-1.0, 1.0).acos() / std::f32::consts::PI
83            }
84        }
85    }
86}
87
/// Search result with distance
///
/// Results are compared by `distance` only; `index` does not take part in
/// equality or ordering. The ordering is total: a NaN distance compares equal
/// to another NaN and greater than every non-NaN distance, so NaN entries
/// count as the farthest candidates and heaps keep a consistent order.
#[derive(Debug, Clone)]
struct SearchResult {
    /// Position of the matched vector in the owning index's data storage.
    index: usize,
    /// Distance from the query to that vector.
    distance: f32,
}

impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so `Eq` and `Ord` can never disagree
        // (a plain `==` on floats would make NaN unequal to itself).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // `partial_cmp` only fails when at least one side is NaN; in that case
        // order by "is NaN" so NaN sorts after every real distance.
        self.distance
            .partial_cmp(&other.distance)
            .unwrap_or_else(|| self.distance.is_nan().cmp(&other.distance.is_nan()))
    }
}
116
/// Navigable Small World (NSW) implementation
pub struct NSWGraph {
    /// Graph structure; each node weight is an index into `data` and each
    /// edge weight is the distance between its two endpoints
    graph: Graph<usize, f32>,
    /// Node index mapping (index into `data` -> node handle in `graph`)
    node_map: HashMap<usize, NodeIndex>,
    /// Data storage
    data: Vec<(String, Vector)>,
    /// Configuration
    config: GraphIndexConfig,
    /// Entry points for search, chosen by `build`
    entry_points: Vec<NodeIndex>,
}
130
131impl NSWGraph {
132    pub fn new(config: GraphIndexConfig) -> Self {
133        Self {
134            graph: Graph::new(),
135            node_map: HashMap::new(),
136            data: Vec::new(),
137            config,
138            entry_points: Vec::new(),
139        }
140    }
141
142    /// Build the graph from data
143    pub fn build(&mut self) -> Result<()> {
144        if self.data.is_empty() {
145            return Ok(());
146        }
147
148        // Create nodes
149        for (idx, _) in self.data.iter().enumerate() {
150            let node = self.graph.add_node(idx);
151            self.node_map.insert(idx, node);
152        }
153
154        // Select random entry points
155        let num_entry_points = (self.data.len() as f32).sqrt() as usize;
156        let mut rng = if let Some(seed) = self.config.random_seed {
157            Random::seed(seed)
158        } else {
159            Random::seed(42)
160        };
161
162        // Note: Using manual random selection instead of SliceRandom
163        let mut indices: Vec<usize> = (0..self.data.len()).collect();
164        // Manually shuffle using Fisher-Yates algorithm
165        for i in (1..indices.len()).rev() {
166            let j = rng.random_range(0..i + 1);
167            indices.swap(i, j);
168        }
169
170        self.entry_points = indices[..num_entry_points.min(self.data.len())]
171            .iter()
172            .map(|&idx| self.node_map[&idx])
173            .collect();
174
175        // Build graph structure
176        if self.config.parallel_construction && self.data.len() > 1000 {
177            self.build_parallel()?;
178        } else {
179            self.build_sequential()?;
180        }
181
182        Ok(())
183    }
184
185    fn build_sequential(&mut self) -> Result<()> {
186        for idx in 0..self.data.len() {
187            let neighbors = self.find_neighbors(idx, self.config.num_neighbors)?;
188            let node = self.node_map[&idx];
189
190            for (neighbor_idx, distance) in neighbors {
191                let neighbor_node = self.node_map[&neighbor_idx];
192                if !self.graph.contains_edge(node, neighbor_node) {
193                    self.graph.add_edge(node, neighbor_node, distance);
194                }
195            }
196        }
197
198        Ok(())
199    }
200
201    fn build_parallel(&mut self) -> Result<()> {
202        let _chunk_size = (self.data.len() / num_threads()).max(100);
203
204        // Pre-compute all edges that need to be added
205        let mut all_edges = Vec::new();
206        for idx in 0..self.data.len() {
207            let neighbors = self.find_neighbors(idx, self.config.num_neighbors)?;
208            let node = self.node_map[&idx];
209
210            for (neighbor_idx, distance) in neighbors {
211                let neighbor_node = self.node_map[&neighbor_idx];
212                all_edges.push((node, neighbor_node, distance));
213            }
214        }
215
216        // Now add all edges to the graph
217        for (from, to, weight) in all_edges {
218            if !self.graph.contains_edge(from, to) {
219                self.graph.add_edge(from, to, weight);
220            }
221        }
222
223        Ok(())
224    }
225
226    fn find_neighbors(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
227        let query = &self.data[idx].1.as_f32();
228        let mut heap = BinaryHeap::new();
229
230        for (other_idx, (_, vector)) in self.data.iter().enumerate() {
231            if other_idx == idx {
232                continue;
233            }
234
235            let other = vector.as_f32();
236            let distance = self.config.distance_metric.distance(query, &other);
237
238            if heap.len() < k {
239                heap.push(SearchResult {
240                    index: other_idx,
241                    distance,
242                });
243            } else if distance < heap.peek().expect("heap should have k elements").distance {
244                heap.pop();
245                heap.push(SearchResult {
246                    index: other_idx,
247                    distance,
248                });
249            }
250        }
251
252        Ok(heap.into_iter().map(|r| (r.index, r.distance)).collect())
253    }
254
255    /// Search for k nearest neighbors
256    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
257        if self.entry_points.is_empty() {
258            return Vec::new();
259        }
260
261        let mut visited = HashSet::new();
262        let mut candidates = BinaryHeap::new();
263        let mut results: BinaryHeap<SearchResult> = BinaryHeap::new();
264
265        // Initialize with entry points
266        for &entry in &self.entry_points {
267            let idx = self.graph[entry];
268            let distance = self
269                .config
270                .distance_metric
271                .distance(query, &self.data[idx].1.as_f32());
272            candidates.push(std::cmp::Reverse(SearchResult {
273                index: idx,
274                distance,
275            }));
276            visited.insert(idx);
277        }
278
279        // Search expansion
280        let max_candidates = (k as f32 * self.config.search_expansion) as usize;
281
282        while let Some(std::cmp::Reverse(current)) = candidates.pop() {
283            // Only apply early termination if we have k results
284            if results.len() >= k
285                && current.distance
286                    > results
287                        .peek()
288                        .expect("results should have k elements")
289                        .distance
290            {
291                break;
292            }
293
294            // Update results
295            if results.len() < k {
296                results.push(current.clone());
297            } else if current.distance
298                < results
299                    .peek()
300                    .expect("results should have k elements")
301                    .distance
302            {
303                results.pop();
304                results.push(current.clone());
305            }
306
307            // Explore neighbors
308            let node = self.node_map[&current.index];
309            for neighbor in self.graph.neighbors(node) {
310                let neighbor_idx = self.graph[neighbor];
311
312                if visited.contains(&neighbor_idx) {
313                    continue;
314                }
315
316                visited.insert(neighbor_idx);
317                let distance = self
318                    .config
319                    .distance_metric
320                    .distance(query, &self.data[neighbor_idx].1.as_f32());
321
322                if candidates.len() < max_candidates
323                    || distance
324                        < candidates
325                            .peek()
326                            .expect("candidates should have elements")
327                            .0
328                            .distance
329                {
330                    candidates.push(std::cmp::Reverse(SearchResult {
331                        index: neighbor_idx,
332                        distance,
333                    }));
334                }
335            }
336        }
337
338        let mut results: Vec<(usize, f32)> =
339            results.into_iter().map(|r| (r.index, r.distance)).collect();
340
341        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
342        results
343    }
344}
345
/// Optimized Nearest Neighbor Graph (ONNG) implementation
pub struct ONNGGraph {
    /// Adjacency list representation: entry `i` holds `(neighbor index,
    /// distance)` pairs for the vector stored at `data[i]`
    adjacency: Vec<Vec<(usize, f32)>>,
    /// Data storage
    data: Vec<(String, Vector)>,
    /// Configuration
    config: GraphIndexConfig,
}
355
356impl ONNGGraph {
357    pub fn new(config: GraphIndexConfig) -> Self {
358        Self {
359            adjacency: Vec::new(),
360            data: Vec::new(),
361            config,
362        }
363    }
364
365    pub fn build(&mut self) -> Result<()> {
366        if self.data.is_empty() {
367            return Ok(());
368        }
369
370        // Initialize adjacency lists
371        self.adjacency = vec![Vec::new(); self.data.len()];
372
373        // Build initial k-NN graph
374        self.build_knn_graph()?;
375
376        // Optimize graph structure
377        self.optimize_graph()?;
378
379        Ok(())
380    }
381
382    fn build_knn_graph(&mut self) -> Result<()> {
383        for idx in 0..self.data.len() {
384            let neighbors = self.find_k_nearest(idx, self.config.num_neighbors)?;
385            self.adjacency[idx] = neighbors;
386        }
387
388        Ok(())
389    }
390
391    fn find_k_nearest(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
392        let query = &self.data[idx].1.as_f32();
393        let mut neighbors = Vec::new();
394
395        for (other_idx, (_, vector)) in self.data.iter().enumerate() {
396            if other_idx == idx {
397                continue;
398            }
399
400            let distance = self
401                .config
402                .distance_metric
403                .distance(query, &vector.as_f32());
404            neighbors.push((other_idx, distance));
405        }
406
407        neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
408        neighbors.truncate(k);
409
410        Ok(neighbors)
411    }
412
413    fn optimize_graph(&mut self) -> Result<()> {
414        // Add reverse edges for better connectivity
415        let mut reverse_edges = vec![Vec::new(); self.data.len()];
416
417        for (idx, neighbors) in self.adjacency.iter().enumerate() {
418            for &(neighbor_idx, distance) in neighbors {
419                reverse_edges[neighbor_idx].push((idx, distance));
420            }
421        }
422
423        // Merge and optimize
424        for (idx, reverse) in reverse_edges.into_iter().enumerate() {
425            let mut all_neighbors = self.adjacency[idx].clone();
426            all_neighbors.extend(reverse);
427
428            // Remove duplicates and sort
429            all_neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
430            all_neighbors.dedup_by_key(|&mut (idx, _)| idx);
431            all_neighbors.truncate(self.config.num_neighbors);
432
433            self.adjacency[idx] = all_neighbors;
434        }
435
436        Ok(())
437    }
438
439    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
440        if self.data.is_empty() {
441            return Vec::new();
442        }
443
444        // Start from multiple random points
445        let start_points = self.select_start_points();
446        let mut visited = HashSet::new();
447        let mut heap = BinaryHeap::new();
448
449        // Initialize with start points
450        for start in start_points {
451            let distance = self
452                .config
453                .distance_metric
454                .distance(query, &self.data[start].1.as_f32());
455            heap.push(std::cmp::Reverse(SearchResult {
456                index: start,
457                distance,
458            }));
459            visited.insert(start);
460        }
461
462        let mut results = Vec::new();
463
464        while let Some(std::cmp::Reverse(current)) = heap.pop() {
465            results.push((current.index, current.distance));
466
467            if results.len() >= k {
468                break;
469            }
470
471            // Explore neighbors
472            for &(neighbor_idx, _) in &self.adjacency[current.index] {
473                if visited.contains(&neighbor_idx) {
474                    continue;
475                }
476
477                visited.insert(neighbor_idx);
478                let distance = self
479                    .config
480                    .distance_metric
481                    .distance(query, &self.data[neighbor_idx].1.as_f32());
482                heap.push(std::cmp::Reverse(SearchResult {
483                    index: neighbor_idx,
484                    distance,
485                }));
486            }
487        }
488
489        results.truncate(k);
490        results
491    }
492
493    fn select_start_points(&self) -> Vec<usize> {
494        // Simple strategy: select sqrt(n) random points
495        let num_points = (self.data.len() as f32).sqrt() as usize;
496        let mut indices: Vec<usize> = (0..self.data.len()).collect();
497
498        let mut rng = if let Some(seed) = self.config.random_seed {
499            Random::seed(seed)
500        } else {
501            Random::seed(42)
502        };
503
504        // Note: Using manual random selection instead of SliceRandom
505        // Manually shuffle using Fisher-Yates algorithm
506        for i in (1..indices.len()).rev() {
507            let j = rng.random_range(0..i + 1);
508            indices.swap(i, j);
509        }
510        indices.truncate(num_points.max(1));
511
512        indices
513    }
514}
515
/// Pruned Approximate Nearest Neighbor Graph (PANNG) implementation
pub struct PANNGGraph {
    /// Pruned adjacency list: entry `i` holds `(neighbor index, distance)`
    /// pairs for the vector stored at `data[i]`
    adjacency: Vec<Vec<(usize, f32)>>,
    /// Data storage
    data: Vec<(String, Vector)>,
    /// Configuration
    config: GraphIndexConfig,
    /// Pruning threshold, compared against the angle (in radians, as returned
    /// by `acos`) that two neighbors form at the node being pruned
    pruning_threshold: f32,
}
527
528impl PANNGGraph {
529    pub fn new(config: GraphIndexConfig) -> Self {
530        Self {
531            adjacency: Vec::new(),
532            data: Vec::new(),
533            config,
534            pruning_threshold: 0.9, // Angle-based pruning threshold
535        }
536    }
537
538    pub fn build(&mut self) -> Result<()> {
539        if self.data.is_empty() {
540            return Ok(());
541        }
542
543        // Build initial k-NN graph
544        self.adjacency = vec![Vec::new(); self.data.len()];
545        self.build_initial_graph()?;
546
547        // Apply pruning
548        if self.config.enable_pruning {
549            self.prune_graph()?;
550        }
551
552        Ok(())
553    }
554
555    fn build_initial_graph(&mut self) -> Result<()> {
556        // Build with more neighbors initially for pruning
557        let initial_neighbors = self.config.num_neighbors * 2;
558
559        for idx in 0..self.data.len() {
560            let neighbors = self.find_k_nearest(idx, initial_neighbors)?;
561            self.adjacency[idx] = neighbors;
562        }
563
564        Ok(())
565    }
566
567    fn find_k_nearest(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
568        let query = &self.data[idx].1.as_f32();
569        let mut heap = BinaryHeap::new();
570
571        for (other_idx, (_, vector)) in self.data.iter().enumerate() {
572            if other_idx == idx {
573                continue;
574            }
575
576            let distance = self
577                .config
578                .distance_metric
579                .distance(query, &vector.as_f32());
580
581            if heap.len() < k {
582                heap.push(SearchResult {
583                    index: other_idx,
584                    distance,
585                });
586            } else if distance < heap.peek().expect("heap should have k elements").distance {
587                heap.pop();
588                heap.push(SearchResult {
589                    index: other_idx,
590                    distance,
591                });
592            }
593        }
594
595        Ok(heap
596            .into_sorted_vec()
597            .into_iter()
598            .map(|r| (r.index, r.distance))
599            .collect())
600    }
601
602    fn prune_graph(&mut self) -> Result<()> {
603        for idx in 0..self.data.len() {
604            let pruned = self.prune_neighbors(idx)?;
605            self.adjacency[idx] = pruned;
606        }
607
608        Ok(())
609    }
610
611    fn prune_neighbors(&self, idx: usize) -> Result<Vec<(usize, f32)>> {
612        let neighbors = &self.adjacency[idx];
613        if neighbors.len() <= self.config.num_neighbors {
614            return Ok(neighbors.clone());
615        }
616
617        let mut pruned = Vec::new();
618        let (_, vector) = &self.data[idx];
619        let query = vector.as_f32();
620
621        for &(neighbor_idx, distance) in neighbors {
622            let (_, vector) = &self.data[neighbor_idx];
623            let neighbor = vector.as_f32();
624            let mut keep = true;
625
626            // Check angle with already selected neighbors
627            for &(selected_idx, _) in &pruned {
628                let (_id, vector): &(String, Vector) = &self.data[selected_idx];
629                let selected = vector.as_f32();
630
631                // Calculate angle between neighbor and selected
632                let angle = self.calculate_angle(&query, &neighbor, &selected);
633
634                if angle < self.pruning_threshold {
635                    keep = false;
636                    break;
637                }
638            }
639
640            if keep {
641                pruned.push((neighbor_idx, distance));
642
643                if pruned.len() >= self.config.num_neighbors {
644                    break;
645                }
646            }
647        }
648
649        Ok(pruned)
650    }
651
652    fn calculate_angle(&self, origin: &[f32], a: &[f32], b: &[f32]) -> f32 {
653        // Calculate vectors from origin
654        let va: Vec<f32> = a
655            .iter()
656            .zip(origin.iter())
657            .map(|(ai, oi)| ai - oi)
658            .collect();
659        let vb: Vec<f32> = b
660            .iter()
661            .zip(origin.iter())
662            .map(|(bi, oi)| bi - oi)
663            .collect();
664
665        // Calculate cosine of angle
666        let dot = f32::dot(&va, &vb);
667        let norm_a = f32::norm(&va);
668        let norm_b = f32::norm(&vb);
669
670        if norm_a == 0.0 || norm_b == 0.0 {
671            return 0.0;
672        }
673
674        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0).acos()
675    }
676
677    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
678        if self.data.is_empty() {
679            return Vec::new();
680        }
681
682        let mut visited = HashSet::new();
683        let mut candidates = VecDeque::new();
684        let mut results = Vec::new();
685
686        // Start from closest point
687        let start = self.find_closest_point(query);
688        candidates.push_back(start);
689        visited.insert(start);
690
691        while let Some(current) = candidates.pop_front() {
692            let distance = self
693                .config
694                .distance_metric
695                .distance(query, &self.data[current].1.as_f32());
696            results.push((current, distance));
697
698            // Explore neighbors
699            for &(neighbor_idx, _) in &self.adjacency[current] {
700                if !visited.contains(&neighbor_idx) {
701                    visited.insert(neighbor_idx);
702                    candidates.push_back(neighbor_idx);
703                }
704            }
705
706            if results.len() >= k * 2 {
707                break;
708            }
709        }
710
711        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
712        results.truncate(k);
713        results
714    }
715
716    fn find_closest_point(&self, query: &[f32]) -> usize {
717        let mut min_dist = f32::INFINITY;
718        let mut closest = 0;
719
720        // Sample a few random points
721        let sample_size = (self.data.len() as f32).sqrt() as usize;
722        let step = self.data.len() / sample_size.max(1);
723
724        for idx in (0..self.data.len()).step_by(step.max(1)) {
725            let distance = self
726                .config
727                .distance_metric
728                .distance(query, &self.data[idx].1.as_f32());
729            if distance < min_dist {
730                min_dist = distance;
731                closest = idx;
732            }
733        }
734
735        closest
736    }
737}
738
/// Delaunay Graph approximation for high dimensions
pub struct DelaunayGraph {
    /// Approximate Delaunay edges: entry `i` holds `(neighbor index,
    /// distance)` pairs for the vector stored at `data[i]`
    edges: Vec<Vec<(usize, f32)>>,
    /// Data storage
    data: Vec<(String, Vector)>,
    /// Configuration
    config: GraphIndexConfig,
}
748
749impl DelaunayGraph {
750    pub fn new(config: GraphIndexConfig) -> Self {
751        Self {
752            edges: Vec::new(),
753            data: Vec::new(),
754            config,
755        }
756    }
757
758    pub fn build(&mut self) -> Result<()> {
759        if self.data.is_empty() {
760            return Ok(());
761        }
762
763        self.edges = vec![Vec::new(); self.data.len()];
764
765        // For high dimensions, we approximate Delaunay by local criteria
766        for idx in 0..self.data.len() {
767            let neighbors = self.find_delaunay_neighbors(idx)?;
768            self.edges[idx] = neighbors;
769        }
770
771        // Make edges bidirectional
772        self.symmetrize_edges();
773
774        Ok(())
775    }
776
777    fn find_delaunay_neighbors(&self, idx: usize) -> Result<Vec<(usize, f32)>> {
778        let point = &self.data[idx].1.as_f32();
779        let mut candidates = Vec::new();
780
781        // Find potential neighbors
782        for (other_idx, (_, other_vec)) in self.data.iter().enumerate() {
783            if other_idx == idx {
784                continue;
785            }
786
787            let other = other_vec.as_f32();
788            let distance = self.config.distance_metric.distance(point, &other);
789            candidates.push((other_idx, distance));
790        }
791
792        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
793
794        // Apply Delaunay criterion approximation
795        let mut neighbors = Vec::new();
796
797        for &(candidate_idx, distance) in &candidates {
798            if neighbors.len() >= self.config.num_neighbors {
799                break;
800            }
801
802            let candidate = &self.data[candidate_idx].1.as_f32();
803            let mut is_neighbor = true;
804
805            // Check if any existing neighbor violates the empty circumsphere property
806            for &(neighbor_idx, _) in &neighbors {
807                let (_id, vector): &(String, Vector) = &self.data[neighbor_idx];
808                let neighbor = vector.as_f32();
809
810                // Approximate check: if candidate is closer to neighbor than to point
811                let dist_to_neighbor = self.config.distance_metric.distance(candidate, &neighbor);
812                if dist_to_neighbor < distance * 0.9 {
813                    is_neighbor = false;
814                    break;
815                }
816            }
817
818            if is_neighbor {
819                neighbors.push((candidate_idx, distance));
820            }
821        }
822
823        Ok(neighbors)
824    }
825
826    fn symmetrize_edges(&mut self) {
827        let mut symmetric_edges = vec![Vec::new(); self.data.len()];
828
829        // Collect all edges
830        for (idx, neighbors) in self.edges.iter().enumerate() {
831            for &(neighbor_idx, distance) in neighbors {
832                symmetric_edges[idx].push((neighbor_idx, distance));
833                symmetric_edges[neighbor_idx].push((idx, distance));
834            }
835        }
836
837        // Remove duplicates and sort
838        for edges in &mut symmetric_edges {
839            edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
840            edges.dedup_by_key(|&mut (idx, _)| idx);
841            edges.truncate(self.config.num_neighbors);
842        }
843
844        self.edges = symmetric_edges;
845    }
846
847    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
848        if self.data.is_empty() {
849            return Vec::new();
850        }
851
852        let mut visited = HashSet::new();
853        let mut heap = BinaryHeap::new();
854        let mut results = Vec::new();
855
856        // Start from a random point
857        let start = 0;
858        let distance = self
859            .config
860            .distance_metric
861            .distance(query, &self.data[start].1.as_f32());
862        heap.push(std::cmp::Reverse(SearchResult {
863            index: start,
864            distance,
865        }));
866        visited.insert(start);
867
868        while let Some(std::cmp::Reverse(current)) = heap.pop() {
869            results.push((current.index, current.distance));
870
871            if results.len() >= k {
872                break;
873            }
874
875            // Explore neighbors
876            for &(neighbor_idx, _) in &self.edges[current.index] {
877                if !visited.contains(&neighbor_idx) {
878                    visited.insert(neighbor_idx);
879                    let distance = self
880                        .config
881                        .distance_metric
882                        .distance(query, &self.data[neighbor_idx].1.as_f32());
883                    heap.push(std::cmp::Reverse(SearchResult {
884                        index: neighbor_idx,
885                        distance,
886                    }));
887                }
888            }
889        }
890
891        results
892    }
893}
894
/// Relative Neighborhood Graph (RNG) implementation
pub struct RNGGraph {
    /// RNG edges: entry `i` holds `(neighbor index, distance)` pairs for the
    /// vector stored at `data[i]`
    edges: Vec<Vec<(usize, f32)>>,
    /// Data storage
    data: Vec<(String, Vector)>,
    /// Configuration
    config: GraphIndexConfig,
}
904
905impl RNGGraph {
906    pub fn new(config: GraphIndexConfig) -> Self {
907        Self {
908            edges: Vec::new(),
909            data: Vec::new(),
910            config,
911        }
912    }
913
914    pub fn build(&mut self) -> Result<()> {
915        if self.data.is_empty() {
916            return Ok(());
917        }
918
919        self.edges = vec![Vec::new(); self.data.len()];
920
921        // Build RNG by checking the RNG criterion for each pair
922        for i in 0..self.data.len() {
923            for j in i + 1..self.data.len() {
924                if self.is_rng_edge(i, j)? {
925                    let distance = self
926                        .config
927                        .distance_metric
928                        .distance(&self.data[i].1.as_f32(), &self.data[j].1.as_f32());
929
930                    self.edges[i].push((j, distance));
931                    self.edges[j].push((i, distance));
932                }
933            }
934        }
935
936        // Sort edges by distance
937        for edges in &mut self.edges {
938            edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
939        }
940
941        Ok(())
942    }
943
944    fn is_rng_edge(&self, i: usize, j: usize) -> Result<bool> {
945        let pi = &self.data[i].1.as_f32();
946        let pj = &self.data[j].1.as_f32();
947        let dist_ij = self.config.distance_metric.distance(pi, pj);
948
949        // Check RNG criterion: no other point k exists such that
950        // max(dist(i,k), dist(j,k)) < dist(i,j)
951        for k in 0..self.data.len() {
952            if k == i || k == j {
953                continue;
954            }
955
956            let pk = &self.data[k].1.as_f32();
957            let dist_ik = self.config.distance_metric.distance(pi, pk);
958            let dist_jk = self.config.distance_metric.distance(pj, pk);
959
960            if dist_ik.max(dist_jk) < dist_ij {
961                return Ok(false);
962            }
963        }
964
965        Ok(true)
966    }
967
968    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
969        if self.data.is_empty() {
970            return Vec::new();
971        }
972
973        let mut visited = HashSet::new();
974        let mut candidates = BinaryHeap::new();
975        let mut results = Vec::new();
976
977        // Start from the closest sampled point
978        let start = self.find_start_point(query);
979        let distance = self
980            .config
981            .distance_metric
982            .distance(query, &self.data[start].1.as_f32());
983        candidates.push(std::cmp::Reverse(SearchResult {
984            index: start,
985            distance,
986        }));
987        visited.insert(start);
988
989        while let Some(std::cmp::Reverse(current)) = candidates.pop() {
990            results.push((current.index, current.distance));
991
992            if results.len() >= k {
993                break;
994            }
995
996            // Explore neighbors
997            for &(neighbor_idx, _) in &self.edges[current.index] {
998                if !visited.contains(&neighbor_idx) {
999                    visited.insert(neighbor_idx);
1000                    let distance = self
1001                        .config
1002                        .distance_metric
1003                        .distance(query, &self.data[neighbor_idx].1.as_f32());
1004                    candidates.push(std::cmp::Reverse(SearchResult {
1005                        index: neighbor_idx,
1006                        distance,
1007                    }));
1008                }
1009            }
1010        }
1011
1012        results
1013    }
1014
1015    fn find_start_point(&self, query: &[f32]) -> usize {
1016        // Sample a subset of points
1017        let sample_size = (self.data.len() as f32).sqrt() as usize;
1018        let mut min_dist = f32::INFINITY;
1019        let mut best = 0;
1020
1021        for i in 0..sample_size.min(self.data.len()) {
1022            let idx = (i * self.data.len()) / sample_size;
1023            let distance = self
1024                .config
1025                .distance_metric
1026                .distance(query, &self.data[idx].1.as_f32());
1027
1028            if distance < min_dist {
1029                min_dist = distance;
1030                best = idx;
1031            }
1032        }
1033
1034        best
1035    }
1036}
1037
/// Unified graph index interface
///
/// Wraps exactly one of the concrete graph implementations. `GraphIndex::new`
/// populates only the field that corresponds to `graph_type` and leaves the
/// other four as `None`; every operation dispatches on `graph_type`.
pub struct GraphIndex {
    /// Which graph implementation this index was created with.
    graph_type: GraphType,
    /// Backing graph when `graph_type` is `GraphType::NSW`.
    nsw: Option<NSWGraph>,
    /// Backing graph when `graph_type` is `GraphType::ONNG`.
    onng: Option<ONNGGraph>,
    /// Backing graph when `graph_type` is `GraphType::PANNG`.
    panng: Option<PANNGGraph>,
    /// Backing graph when `graph_type` is `GraphType::Delaunay`.
    delaunay: Option<DelaunayGraph>,
    /// Backing graph when `graph_type` is `GraphType::RNG`.
    rng: Option<RNGGraph>,
}
1047
1048impl GraphIndex {
1049    pub fn new(config: GraphIndexConfig) -> Self {
1050        let graph_type = config.graph_type;
1051
1052        let (nsw, onng, panng, delaunay, rng) = match graph_type {
1053            GraphType::NSW => (Some(NSWGraph::new(config)), None, None, None, None),
1054            GraphType::ONNG => (None, Some(ONNGGraph::new(config)), None, None, None),
1055            GraphType::PANNG => (None, None, Some(PANNGGraph::new(config)), None, None),
1056            GraphType::Delaunay => (None, None, None, Some(DelaunayGraph::new(config)), None),
1057            GraphType::RNG => (None, None, None, None, Some(RNGGraph::new(config))),
1058        };
1059
1060        Self {
1061            graph_type,
1062            nsw,
1063            onng,
1064            panng,
1065            delaunay,
1066            rng,
1067        }
1068    }
1069
1070    fn build(&mut self) -> Result<()> {
1071        match self.graph_type {
1072            GraphType::NSW => self
1073                .nsw
1074                .as_mut()
1075                .expect("nsw should be initialized for NSW type")
1076                .build(),
1077            GraphType::ONNG => self
1078                .onng
1079                .as_mut()
1080                .expect("onng should be initialized for ONNG type")
1081                .build(),
1082            GraphType::PANNG => self
1083                .panng
1084                .as_mut()
1085                .expect("panng should be initialized for PANNG type")
1086                .build(),
1087            GraphType::Delaunay => self
1088                .delaunay
1089                .as_mut()
1090                .expect("delaunay should be initialized for Delaunay type")
1091                .build(),
1092            GraphType::RNG => self
1093                .rng
1094                .as_mut()
1095                .expect("rng should be initialized for RNG type")
1096                .build(),
1097        }
1098    }
1099
1100    fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1101        match self.graph_type {
1102            GraphType::NSW => self
1103                .nsw
1104                .as_ref()
1105                .expect("nsw should be initialized for NSW type")
1106                .search(query, k),
1107            GraphType::ONNG => self
1108                .onng
1109                .as_ref()
1110                .expect("onng should be initialized for ONNG type")
1111                .search(query, k),
1112            GraphType::PANNG => self
1113                .panng
1114                .as_ref()
1115                .expect("panng should be initialized for PANNG type")
1116                .search(query, k),
1117            GraphType::Delaunay => self
1118                .delaunay
1119                .as_ref()
1120                .expect("delaunay should be initialized for Delaunay type")
1121                .search(query, k),
1122            GraphType::RNG => self
1123                .rng
1124                .as_ref()
1125                .expect("rng should be initialized for RNG type")
1126                .search(query, k),
1127        }
1128    }
1129}
1130
1131impl VectorIndex for GraphIndex {
1132    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1133        let data = match self.graph_type {
1134            GraphType::NSW => {
1135                &mut self
1136                    .nsw
1137                    .as_mut()
1138                    .expect("nsw should be initialized for NSW type")
1139                    .data
1140            }
1141            GraphType::ONNG => {
1142                &mut self
1143                    .onng
1144                    .as_mut()
1145                    .expect("onng should be initialized for ONNG type")
1146                    .data
1147            }
1148            GraphType::PANNG => {
1149                &mut self
1150                    .panng
1151                    .as_mut()
1152                    .expect("panng should be initialized for PANNG type")
1153                    .data
1154            }
1155            GraphType::Delaunay => {
1156                &mut self
1157                    .delaunay
1158                    .as_mut()
1159                    .expect("delaunay should be initialized for Delaunay type")
1160                    .data
1161            }
1162            GraphType::RNG => {
1163                &mut self
1164                    .rng
1165                    .as_mut()
1166                    .expect("rng should be initialized for RNG type")
1167                    .data
1168            }
1169        };
1170
1171        data.push((uri, vector));
1172        Ok(())
1173    }
1174
1175    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1176        let query_f32 = query.as_f32();
1177        let results = self.search_internal(&query_f32, k);
1178
1179        let data = match self.graph_type {
1180            GraphType::NSW => {
1181                &self
1182                    .nsw
1183                    .as_ref()
1184                    .expect("nsw should be initialized for NSW type")
1185                    .data
1186            }
1187            GraphType::ONNG => {
1188                &self
1189                    .onng
1190                    .as_ref()
1191                    .expect("onng should be initialized for ONNG type")
1192                    .data
1193            }
1194            GraphType::PANNG => {
1195                &self
1196                    .panng
1197                    .as_ref()
1198                    .expect("panng should be initialized for PANNG type")
1199                    .data
1200            }
1201            GraphType::Delaunay => {
1202                &self
1203                    .delaunay
1204                    .as_ref()
1205                    .expect("delaunay should be initialized for Delaunay type")
1206                    .data
1207            }
1208            GraphType::RNG => {
1209                &self
1210                    .rng
1211                    .as_ref()
1212                    .expect("rng should be initialized for RNG type")
1213                    .data
1214            }
1215        };
1216
1217        Ok(results
1218            .into_iter()
1219            .map(|(idx, dist)| (data[idx].0.clone(), dist))
1220            .collect())
1221    }
1222
1223    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1224        let query_f32 = query.as_f32();
1225        let all_results = self.search_internal(&query_f32, 1000);
1226
1227        let data = match self.graph_type {
1228            GraphType::NSW => {
1229                &self
1230                    .nsw
1231                    .as_ref()
1232                    .expect("nsw should be initialized for NSW type")
1233                    .data
1234            }
1235            GraphType::ONNG => {
1236                &self
1237                    .onng
1238                    .as_ref()
1239                    .expect("onng should be initialized for ONNG type")
1240                    .data
1241            }
1242            GraphType::PANNG => {
1243                &self
1244                    .panng
1245                    .as_ref()
1246                    .expect("panng should be initialized for PANNG type")
1247                    .data
1248            }
1249            GraphType::Delaunay => {
1250                &self
1251                    .delaunay
1252                    .as_ref()
1253                    .expect("delaunay should be initialized for Delaunay type")
1254                    .data
1255            }
1256            GraphType::RNG => {
1257                &self
1258                    .rng
1259                    .as_ref()
1260                    .expect("rng should be initialized for RNG type")
1261                    .data
1262            }
1263        };
1264
1265        Ok(all_results
1266            .into_iter()
1267            .filter(|(_, dist)| *dist <= threshold)
1268            .map(|(idx, dist)| (data[idx].0.clone(), dist))
1269            .collect())
1270    }
1271
1272    fn get_vector(&self, uri: &str) -> Option<&Vector> {
1273        let data = match self.graph_type {
1274            GraphType::NSW => {
1275                &self
1276                    .nsw
1277                    .as_ref()
1278                    .expect("nsw should be initialized for NSW type")
1279                    .data
1280            }
1281            GraphType::ONNG => {
1282                &self
1283                    .onng
1284                    .as_ref()
1285                    .expect("onng should be initialized for ONNG type")
1286                    .data
1287            }
1288            GraphType::PANNG => {
1289                &self
1290                    .panng
1291                    .as_ref()
1292                    .expect("panng should be initialized for PANNG type")
1293                    .data
1294            }
1295            GraphType::Delaunay => {
1296                &self
1297                    .delaunay
1298                    .as_ref()
1299                    .expect("delaunay should be initialized for Delaunay type")
1300                    .data
1301            }
1302            GraphType::RNG => {
1303                &self
1304                    .rng
1305                    .as_ref()
1306                    .expect("rng should be initialized for RNG type")
1307                    .data
1308            }
1309        };
1310
1311        data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1312    }
1313}
1314
1315// Add dependencies
1316use petgraph;
1317// Note: Replaced with scirs2_core::random
1318
#[cfg(test)]
mod tests {
    use super::*;

    /// NSW over 50 collinear points; the query coincides with `vec_25`,
    /// which must therefore be ranked first.
    #[test]
    fn test_nsw_graph() {
        let mut index = GraphIndex::new(GraphIndexConfig {
            graph_type: GraphType::NSW,
            num_neighbors: 10,
            ..Default::default()
        });

        // Populate the index, then construct the graph.
        for i in 0..50 {
            let coords = vec![i as f32, (i * 2) as f32, (i * 3) as f32];
            index.insert(format!("vec_{i}"), Vector::new(coords)).unwrap();
        }
        index.build().unwrap();

        let target = Vector::new(vec![25.0, 50.0, 75.0]);
        let neighbours = index.search_knn(&target, 5).unwrap();

        assert_eq!(neighbours.len(), 5);
        assert_eq!(neighbours[0].0, "vec_25"); // Exact match
    }

    /// ONNG over 20 points evenly spaced on the unit circle.
    #[test]
    fn test_onng_graph() {
        let mut index = GraphIndex::new(GraphIndexConfig {
            graph_type: GraphType::ONNG,
            num_neighbors: 8,
            ..Default::default()
        });

        // Populate the index, then construct the graph.
        for i in 0..20 {
            let angle = (i as f32) * 2.0 * std::f32::consts::PI / 20.0;
            let point = Vector::new(vec![angle.cos(), angle.sin()]);
            index.insert(format!("vec_{i}"), point).unwrap();
        }
        index.build().unwrap();

        let target = Vector::new(vec![1.0, 0.0]);
        let neighbours = index.search_knn(&target, 3).unwrap();

        assert_eq!(neighbours.len(), 3);
    }

    /// PANNG with pruning enabled over 30 points on a helix-like curve.
    #[test]
    fn test_panng_graph() {
        let mut index = GraphIndex::new(GraphIndexConfig {
            graph_type: GraphType::PANNG,
            num_neighbors: 5,
            enable_pruning: true,
            ..Default::default()
        });

        // Populate the index, then construct the graph.
        for i in 0..30 {
            let coords = vec![(i as f32).sin(), (i as f32).cos(), (i as f32) / 10.0];
            index.insert(format!("vec_{i}"), Vector::new(coords)).unwrap();
        }
        index.build().unwrap();

        let target = Vector::new(vec![0.0, 1.0, 0.0]);
        let neighbours = index.search_knn(&target, 5).unwrap();

        assert_eq!(neighbours.len(), 5);
    }
}