oxirs_vec/
graph_indices.rs

1//! Graph-based indices for efficient nearest neighbor search
2//!
3//! This module implements various graph-based data structures optimized for
4//! nearest neighbor search:
5//! - NSW: Navigable Small World
6//! - ONNG: Optimized Nearest Neighbor Graph
7//! - PANNG: Pruned Approximate Nearest Neighbor Graph
8//! - Delaunay Graph: Approximation for high-dimensional space
9//! - RNG: Relative Neighborhood Graph
10
11use crate::{Vector, VectorIndex};
12use anyhow::Result;
13use oxirs_core::parallel::*;
14use oxirs_core::simd::SimdOps;
15use petgraph::graph::{Graph, NodeIndex};
16#[allow(unused_imports)]
17use scirs2_core::random::{Random, Rng};
18use std::cmp::Ordering;
19use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
20
21/// Configuration for graph-based indices
22#[derive(Debug, Clone)]
23pub struct GraphIndexConfig {
24    /// Type of graph to use
25    pub graph_type: GraphType,
26    /// Number of neighbors per node
27    pub num_neighbors: usize,
28    /// Random seed for reproducibility
29    pub random_seed: Option<u64>,
30    /// Enable parallel construction
31    pub parallel_construction: bool,
32    /// Distance metric
33    pub distance_metric: DistanceMetric,
34    /// Enable pruning for better quality
35    pub enable_pruning: bool,
36    /// Search depth multiplier
37    pub search_expansion: f32,
38}
39
40impl Default for GraphIndexConfig {
41    fn default() -> Self {
42        Self {
43            graph_type: GraphType::NSW,
44            num_neighbors: 32,
45            random_seed: None,
46            parallel_construction: true,
47            distance_metric: DistanceMetric::Euclidean,
48            enable_pruning: true,
49            search_expansion: 1.5,
50        }
51    }
52}
53
54/// Available graph types
55#[derive(Debug, Clone, Copy)]
56pub enum GraphType {
57    NSW,      // Navigable Small World
58    ONNG,     // Optimized Nearest Neighbor Graph
59    PANNG,    // Pruned Approximate Nearest Neighbor Graph
60    Delaunay, // Delaunay Graph approximation
61    RNG,      // Relative Neighborhood Graph
62}
63
64/// Distance metrics
65#[derive(Debug, Clone, Copy)]
66pub enum DistanceMetric {
67    Euclidean,
68    Manhattan,
69    Cosine,
70    Angular,
71}
72
73impl DistanceMetric {
74    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
75        match self {
76            DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
77            DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
78            DistanceMetric::Cosine => f32::cosine_distance(a, b),
79            DistanceMetric::Angular => {
80                // Angular distance = arccos(cosine_similarity) / pi
81                let cos_sim: f32 = 1.0 - f32::cosine_distance(a, b);
82                cos_sim.clamp(-1.0, 1.0).acos() / std::f32::consts::PI
83            }
84        }
85    }
86}
87
88/// Search result with distance
89#[derive(Debug, Clone)]
90struct SearchResult {
91    index: usize,
92    distance: f32,
93}
94
95impl PartialEq for SearchResult {
96    fn eq(&self, other: &Self) -> bool {
97        self.distance == other.distance
98    }
99}
100
101impl Eq for SearchResult {}
102
103impl PartialOrd for SearchResult {
104    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
105        Some(self.cmp(other))
106    }
107}
108
109impl Ord for SearchResult {
110    fn cmp(&self, other: &Self) -> Ordering {
111        self.distance
112            .partial_cmp(&other.distance)
113            .unwrap_or(Ordering::Equal)
114    }
115}
116
117/// Navigable Small World (NSW) implementation
118pub struct NSWGraph {
119    /// Graph structure
120    graph: Graph<usize, f32>,
121    /// Node index mapping
122    node_map: HashMap<usize, NodeIndex>,
123    /// Data storage
124    data: Vec<(String, Vector)>,
125    /// Configuration
126    config: GraphIndexConfig,
127    /// Entry points for search
128    entry_points: Vec<NodeIndex>,
129}
130
131impl NSWGraph {
132    pub fn new(config: GraphIndexConfig) -> Self {
133        Self {
134            graph: Graph::new(),
135            node_map: HashMap::new(),
136            data: Vec::new(),
137            config,
138            entry_points: Vec::new(),
139        }
140    }
141
142    /// Build the graph from data
143    pub fn build(&mut self) -> Result<()> {
144        if self.data.is_empty() {
145            return Ok(());
146        }
147
148        // Create nodes
149        for (idx, _) in self.data.iter().enumerate() {
150            let node = self.graph.add_node(idx);
151            self.node_map.insert(idx, node);
152        }
153
154        // Select random entry points
155        let num_entry_points = (self.data.len() as f32).sqrt() as usize;
156        let mut rng = if let Some(seed) = self.config.random_seed {
157            Random::seed(seed)
158        } else {
159            Random::seed(42)
160        };
161
162        // Note: Using manual random selection instead of SliceRandom
163        let mut indices: Vec<usize> = (0..self.data.len()).collect();
164        // Manually shuffle using Fisher-Yates algorithm
165        for i in (1..indices.len()).rev() {
166            let j = rng.random_range(0, i + 1);
167            indices.swap(i, j);
168        }
169
170        self.entry_points = indices[..num_entry_points.min(self.data.len())]
171            .iter()
172            .map(|&idx| self.node_map[&idx])
173            .collect();
174
175        // Build graph structure
176        if self.config.parallel_construction && self.data.len() > 1000 {
177            self.build_parallel()?;
178        } else {
179            self.build_sequential()?;
180        }
181
182        Ok(())
183    }
184
185    fn build_sequential(&mut self) -> Result<()> {
186        for idx in 0..self.data.len() {
187            let neighbors = self.find_neighbors(idx, self.config.num_neighbors)?;
188            let node = self.node_map[&idx];
189
190            for (neighbor_idx, distance) in neighbors {
191                let neighbor_node = self.node_map[&neighbor_idx];
192                if !self.graph.contains_edge(node, neighbor_node) {
193                    self.graph.add_edge(node, neighbor_node, distance);
194                }
195            }
196        }
197
198        Ok(())
199    }
200
201    fn build_parallel(&mut self) -> Result<()> {
202        let _chunk_size = (self.data.len() / num_threads()).max(100);
203
204        // Pre-compute all edges that need to be added
205        let mut all_edges = Vec::new();
206        for idx in 0..self.data.len() {
207            let neighbors = self.find_neighbors(idx, self.config.num_neighbors)?;
208            let node = self.node_map[&idx];
209
210            for (neighbor_idx, distance) in neighbors {
211                let neighbor_node = self.node_map[&neighbor_idx];
212                all_edges.push((node, neighbor_node, distance));
213            }
214        }
215
216        // Now add all edges to the graph
217        for (from, to, weight) in all_edges {
218            if !self.graph.contains_edge(from, to) {
219                self.graph.add_edge(from, to, weight);
220            }
221        }
222
223        Ok(())
224    }
225
226    fn find_neighbors(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
227        let query = &self.data[idx].1.as_f32();
228        let mut heap = BinaryHeap::new();
229
230        for (other_idx, (_, vector)) in self.data.iter().enumerate() {
231            if other_idx == idx {
232                continue;
233            }
234
235            let other = vector.as_f32();
236            let distance = self.config.distance_metric.distance(query, &other);
237
238            if heap.len() < k {
239                heap.push(SearchResult {
240                    index: other_idx,
241                    distance,
242                });
243            } else if distance < heap.peek().unwrap().distance {
244                heap.pop();
245                heap.push(SearchResult {
246                    index: other_idx,
247                    distance,
248                });
249            }
250        }
251
252        Ok(heap.into_iter().map(|r| (r.index, r.distance)).collect())
253    }
254
255    /// Search for k nearest neighbors
256    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
257        if self.entry_points.is_empty() {
258            return Vec::new();
259        }
260
261        let mut visited = HashSet::new();
262        let mut candidates = BinaryHeap::new();
263        let mut results: BinaryHeap<SearchResult> = BinaryHeap::new();
264
265        // Initialize with entry points
266        for &entry in &self.entry_points {
267            let idx = self.graph[entry];
268            let distance = self
269                .config
270                .distance_metric
271                .distance(query, &self.data[idx].1.as_f32());
272            candidates.push(std::cmp::Reverse(SearchResult {
273                index: idx,
274                distance,
275            }));
276            visited.insert(idx);
277        }
278
279        // Search expansion
280        let max_candidates = (k as f32 * self.config.search_expansion) as usize;
281
282        while let Some(std::cmp::Reverse(current)) = candidates.pop() {
283            // Only apply early termination if we have k results
284            if results.len() >= k && current.distance > results.peek().unwrap().distance {
285                break;
286            }
287
288            // Update results
289            if results.len() < k {
290                results.push(current.clone());
291            } else if current.distance < results.peek().unwrap().distance {
292                results.pop();
293                results.push(current.clone());
294            }
295
296            // Explore neighbors
297            let node = self.node_map[&current.index];
298            for neighbor in self.graph.neighbors(node) {
299                let neighbor_idx = self.graph[neighbor];
300
301                if visited.contains(&neighbor_idx) {
302                    continue;
303                }
304
305                visited.insert(neighbor_idx);
306                let distance = self
307                    .config
308                    .distance_metric
309                    .distance(query, &self.data[neighbor_idx].1.as_f32());
310
311                if candidates.len() < max_candidates
312                    || distance < candidates.peek().unwrap().0.distance
313                {
314                    candidates.push(std::cmp::Reverse(SearchResult {
315                        index: neighbor_idx,
316                        distance,
317                    }));
318                }
319            }
320        }
321
322        let mut results: Vec<(usize, f32)> =
323            results.into_iter().map(|r| (r.index, r.distance)).collect();
324
325        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
326        results
327    }
328}
329
330/// Optimized Nearest Neighbor Graph (ONNG) implementation
331pub struct ONNGGraph {
332    /// Adjacency list representation
333    adjacency: Vec<Vec<(usize, f32)>>,
334    /// Data storage
335    data: Vec<(String, Vector)>,
336    /// Configuration
337    config: GraphIndexConfig,
338}
339
340impl ONNGGraph {
341    pub fn new(config: GraphIndexConfig) -> Self {
342        Self {
343            adjacency: Vec::new(),
344            data: Vec::new(),
345            config,
346        }
347    }
348
349    pub fn build(&mut self) -> Result<()> {
350        if self.data.is_empty() {
351            return Ok(());
352        }
353
354        // Initialize adjacency lists
355        self.adjacency = vec![Vec::new(); self.data.len()];
356
357        // Build initial k-NN graph
358        self.build_knn_graph()?;
359
360        // Optimize graph structure
361        self.optimize_graph()?;
362
363        Ok(())
364    }
365
366    fn build_knn_graph(&mut self) -> Result<()> {
367        for idx in 0..self.data.len() {
368            let neighbors = self.find_k_nearest(idx, self.config.num_neighbors)?;
369            self.adjacency[idx] = neighbors;
370        }
371
372        Ok(())
373    }
374
375    fn find_k_nearest(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
376        let query = &self.data[idx].1.as_f32();
377        let mut neighbors = Vec::new();
378
379        for (other_idx, (_, vector)) in self.data.iter().enumerate() {
380            if other_idx == idx {
381                continue;
382            }
383
384            let distance = self
385                .config
386                .distance_metric
387                .distance(query, &vector.as_f32());
388            neighbors.push((other_idx, distance));
389        }
390
391        neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
392        neighbors.truncate(k);
393
394        Ok(neighbors)
395    }
396
397    fn optimize_graph(&mut self) -> Result<()> {
398        // Add reverse edges for better connectivity
399        let mut reverse_edges = vec![Vec::new(); self.data.len()];
400
401        for (idx, neighbors) in self.adjacency.iter().enumerate() {
402            for &(neighbor_idx, distance) in neighbors {
403                reverse_edges[neighbor_idx].push((idx, distance));
404            }
405        }
406
407        // Merge and optimize
408        for (idx, reverse) in reverse_edges.into_iter().enumerate() {
409            let mut all_neighbors = self.adjacency[idx].clone();
410            all_neighbors.extend(reverse);
411
412            // Remove duplicates and sort
413            all_neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
414            all_neighbors.dedup_by_key(|&mut (idx, _)| idx);
415            all_neighbors.truncate(self.config.num_neighbors);
416
417            self.adjacency[idx] = all_neighbors;
418        }
419
420        Ok(())
421    }
422
423    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
424        if self.data.is_empty() {
425            return Vec::new();
426        }
427
428        // Start from multiple random points
429        let start_points = self.select_start_points();
430        let mut visited = HashSet::new();
431        let mut heap = BinaryHeap::new();
432
433        // Initialize with start points
434        for start in start_points {
435            let distance = self
436                .config
437                .distance_metric
438                .distance(query, &self.data[start].1.as_f32());
439            heap.push(std::cmp::Reverse(SearchResult {
440                index: start,
441                distance,
442            }));
443            visited.insert(start);
444        }
445
446        let mut results = Vec::new();
447
448        while let Some(std::cmp::Reverse(current)) = heap.pop() {
449            results.push((current.index, current.distance));
450
451            if results.len() >= k {
452                break;
453            }
454
455            // Explore neighbors
456            for &(neighbor_idx, _) in &self.adjacency[current.index] {
457                if visited.contains(&neighbor_idx) {
458                    continue;
459                }
460
461                visited.insert(neighbor_idx);
462                let distance = self
463                    .config
464                    .distance_metric
465                    .distance(query, &self.data[neighbor_idx].1.as_f32());
466                heap.push(std::cmp::Reverse(SearchResult {
467                    index: neighbor_idx,
468                    distance,
469                }));
470            }
471        }
472
473        results.truncate(k);
474        results
475    }
476
477    fn select_start_points(&self) -> Vec<usize> {
478        // Simple strategy: select sqrt(n) random points
479        let num_points = (self.data.len() as f32).sqrt() as usize;
480        let mut indices: Vec<usize> = (0..self.data.len()).collect();
481
482        let mut rng = if let Some(seed) = self.config.random_seed {
483            Random::seed(seed)
484        } else {
485            Random::seed(42)
486        };
487
488        // Note: Using manual random selection instead of SliceRandom
489        // Manually shuffle using Fisher-Yates algorithm
490        for i in (1..indices.len()).rev() {
491            let j = rng.random_range(0, i + 1);
492            indices.swap(i, j);
493        }
494        indices.truncate(num_points.max(1));
495
496        indices
497    }
498}
499
500/// Pruned Approximate Nearest Neighbor Graph (PANNG) implementation
501pub struct PANNGGraph {
502    /// Pruned adjacency list
503    adjacency: Vec<Vec<(usize, f32)>>,
504    /// Data storage
505    data: Vec<(String, Vector)>,
506    /// Configuration
507    config: GraphIndexConfig,
508    /// Pruning threshold
509    pruning_threshold: f32,
510}
511
512impl PANNGGraph {
513    pub fn new(config: GraphIndexConfig) -> Self {
514        Self {
515            adjacency: Vec::new(),
516            data: Vec::new(),
517            config,
518            pruning_threshold: 0.9, // Angle-based pruning threshold
519        }
520    }
521
522    pub fn build(&mut self) -> Result<()> {
523        if self.data.is_empty() {
524            return Ok(());
525        }
526
527        // Build initial k-NN graph
528        self.adjacency = vec![Vec::new(); self.data.len()];
529        self.build_initial_graph()?;
530
531        // Apply pruning
532        if self.config.enable_pruning {
533            self.prune_graph()?;
534        }
535
536        Ok(())
537    }
538
539    fn build_initial_graph(&mut self) -> Result<()> {
540        // Build with more neighbors initially for pruning
541        let initial_neighbors = self.config.num_neighbors * 2;
542
543        for idx in 0..self.data.len() {
544            let neighbors = self.find_k_nearest(idx, initial_neighbors)?;
545            self.adjacency[idx] = neighbors;
546        }
547
548        Ok(())
549    }
550
551    fn find_k_nearest(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
552        let query = &self.data[idx].1.as_f32();
553        let mut heap = BinaryHeap::new();
554
555        for (other_idx, (_, vector)) in self.data.iter().enumerate() {
556            if other_idx == idx {
557                continue;
558            }
559
560            let distance = self
561                .config
562                .distance_metric
563                .distance(query, &vector.as_f32());
564
565            if heap.len() < k {
566                heap.push(SearchResult {
567                    index: other_idx,
568                    distance,
569                });
570            } else if distance < heap.peek().unwrap().distance {
571                heap.pop();
572                heap.push(SearchResult {
573                    index: other_idx,
574                    distance,
575                });
576            }
577        }
578
579        Ok(heap
580            .into_sorted_vec()
581            .into_iter()
582            .map(|r| (r.index, r.distance))
583            .collect())
584    }
585
586    fn prune_graph(&mut self) -> Result<()> {
587        for idx in 0..self.data.len() {
588            let pruned = self.prune_neighbors(idx)?;
589            self.adjacency[idx] = pruned;
590        }
591
592        Ok(())
593    }
594
595    fn prune_neighbors(&self, idx: usize) -> Result<Vec<(usize, f32)>> {
596        let neighbors = &self.adjacency[idx];
597        if neighbors.len() <= self.config.num_neighbors {
598            return Ok(neighbors.clone());
599        }
600
601        let mut pruned = Vec::new();
602        let (_, vector) = &self.data[idx];
603        let query = vector.as_f32();
604
605        for &(neighbor_idx, distance) in neighbors {
606            let (_, vector) = &self.data[neighbor_idx];
607            let neighbor = vector.as_f32();
608            let mut keep = true;
609
610            // Check angle with already selected neighbors
611            for &(selected_idx, _) in &pruned {
612                let (_id, vector): &(String, Vector) = &self.data[selected_idx];
613                let selected = vector.as_f32();
614
615                // Calculate angle between neighbor and selected
616                let angle = self.calculate_angle(&query, &neighbor, &selected);
617
618                if angle < self.pruning_threshold {
619                    keep = false;
620                    break;
621                }
622            }
623
624            if keep {
625                pruned.push((neighbor_idx, distance));
626
627                if pruned.len() >= self.config.num_neighbors {
628                    break;
629                }
630            }
631        }
632
633        Ok(pruned)
634    }
635
636    fn calculate_angle(&self, origin: &[f32], a: &[f32], b: &[f32]) -> f32 {
637        // Calculate vectors from origin
638        let va: Vec<f32> = a
639            .iter()
640            .zip(origin.iter())
641            .map(|(ai, oi)| ai - oi)
642            .collect();
643        let vb: Vec<f32> = b
644            .iter()
645            .zip(origin.iter())
646            .map(|(bi, oi)| bi - oi)
647            .collect();
648
649        // Calculate cosine of angle
650        let dot = f32::dot(&va, &vb);
651        let norm_a = f32::norm(&va);
652        let norm_b = f32::norm(&vb);
653
654        if norm_a == 0.0 || norm_b == 0.0 {
655            return 0.0;
656        }
657
658        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0).acos()
659    }
660
661    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
662        if self.data.is_empty() {
663            return Vec::new();
664        }
665
666        let mut visited = HashSet::new();
667        let mut candidates = VecDeque::new();
668        let mut results = Vec::new();
669
670        // Start from closest point
671        let start = self.find_closest_point(query);
672        candidates.push_back(start);
673        visited.insert(start);
674
675        while let Some(current) = candidates.pop_front() {
676            let distance = self
677                .config
678                .distance_metric
679                .distance(query, &self.data[current].1.as_f32());
680            results.push((current, distance));
681
682            // Explore neighbors
683            for &(neighbor_idx, _) in &self.adjacency[current] {
684                if !visited.contains(&neighbor_idx) {
685                    visited.insert(neighbor_idx);
686                    candidates.push_back(neighbor_idx);
687                }
688            }
689
690            if results.len() >= k * 2 {
691                break;
692            }
693        }
694
695        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
696        results.truncate(k);
697        results
698    }
699
700    fn find_closest_point(&self, query: &[f32]) -> usize {
701        let mut min_dist = f32::INFINITY;
702        let mut closest = 0;
703
704        // Sample a few random points
705        let sample_size = (self.data.len() as f32).sqrt() as usize;
706        let step = self.data.len() / sample_size.max(1);
707
708        for idx in (0..self.data.len()).step_by(step.max(1)) {
709            let distance = self
710                .config
711                .distance_metric
712                .distance(query, &self.data[idx].1.as_f32());
713            if distance < min_dist {
714                min_dist = distance;
715                closest = idx;
716            }
717        }
718
719        closest
720    }
721}
722
723/// Delaunay Graph approximation for high dimensions
724pub struct DelaunayGraph {
725    /// Approximate Delaunay edges
726    edges: Vec<Vec<(usize, f32)>>,
727    /// Data storage
728    data: Vec<(String, Vector)>,
729    /// Configuration
730    config: GraphIndexConfig,
731}
732
733impl DelaunayGraph {
734    pub fn new(config: GraphIndexConfig) -> Self {
735        Self {
736            edges: Vec::new(),
737            data: Vec::new(),
738            config,
739        }
740    }
741
742    pub fn build(&mut self) -> Result<()> {
743        if self.data.is_empty() {
744            return Ok(());
745        }
746
747        self.edges = vec![Vec::new(); self.data.len()];
748
749        // For high dimensions, we approximate Delaunay by local criteria
750        for idx in 0..self.data.len() {
751            let neighbors = self.find_delaunay_neighbors(idx)?;
752            self.edges[idx] = neighbors;
753        }
754
755        // Make edges bidirectional
756        self.symmetrize_edges();
757
758        Ok(())
759    }
760
761    fn find_delaunay_neighbors(&self, idx: usize) -> Result<Vec<(usize, f32)>> {
762        let point = &self.data[idx].1.as_f32();
763        let mut candidates = Vec::new();
764
765        // Find potential neighbors
766        for (other_idx, (_, other_vec)) in self.data.iter().enumerate() {
767            if other_idx == idx {
768                continue;
769            }
770
771            let other = other_vec.as_f32();
772            let distance = self.config.distance_metric.distance(point, &other);
773            candidates.push((other_idx, distance));
774        }
775
776        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
777
778        // Apply Delaunay criterion approximation
779        let mut neighbors = Vec::new();
780
781        for &(candidate_idx, distance) in &candidates {
782            if neighbors.len() >= self.config.num_neighbors {
783                break;
784            }
785
786            let candidate = &self.data[candidate_idx].1.as_f32();
787            let mut is_neighbor = true;
788
789            // Check if any existing neighbor violates the empty circumsphere property
790            for &(neighbor_idx, _) in &neighbors {
791                let (_id, vector): &(String, Vector) = &self.data[neighbor_idx];
792                let neighbor = vector.as_f32();
793
794                // Approximate check: if candidate is closer to neighbor than to point
795                let dist_to_neighbor = self.config.distance_metric.distance(candidate, &neighbor);
796                if dist_to_neighbor < distance * 0.9 {
797                    is_neighbor = false;
798                    break;
799                }
800            }
801
802            if is_neighbor {
803                neighbors.push((candidate_idx, distance));
804            }
805        }
806
807        Ok(neighbors)
808    }
809
810    fn symmetrize_edges(&mut self) {
811        let mut symmetric_edges = vec![Vec::new(); self.data.len()];
812
813        // Collect all edges
814        for (idx, neighbors) in self.edges.iter().enumerate() {
815            for &(neighbor_idx, distance) in neighbors {
816                symmetric_edges[idx].push((neighbor_idx, distance));
817                symmetric_edges[neighbor_idx].push((idx, distance));
818            }
819        }
820
821        // Remove duplicates and sort
822        for edges in &mut symmetric_edges {
823            edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
824            edges.dedup_by_key(|&mut (idx, _)| idx);
825            edges.truncate(self.config.num_neighbors);
826        }
827
828        self.edges = symmetric_edges;
829    }
830
831    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
832        if self.data.is_empty() {
833            return Vec::new();
834        }
835
836        let mut visited = HashSet::new();
837        let mut heap = BinaryHeap::new();
838        let mut results = Vec::new();
839
840        // Start from a random point
841        let start = 0;
842        let distance = self
843            .config
844            .distance_metric
845            .distance(query, &self.data[start].1.as_f32());
846        heap.push(std::cmp::Reverse(SearchResult {
847            index: start,
848            distance,
849        }));
850        visited.insert(start);
851
852        while let Some(std::cmp::Reverse(current)) = heap.pop() {
853            results.push((current.index, current.distance));
854
855            if results.len() >= k {
856                break;
857            }
858
859            // Explore neighbors
860            for &(neighbor_idx, _) in &self.edges[current.index] {
861                if !visited.contains(&neighbor_idx) {
862                    visited.insert(neighbor_idx);
863                    let distance = self
864                        .config
865                        .distance_metric
866                        .distance(query, &self.data[neighbor_idx].1.as_f32());
867                    heap.push(std::cmp::Reverse(SearchResult {
868                        index: neighbor_idx,
869                        distance,
870                    }));
871                }
872            }
873        }
874
875        results
876    }
877}
878
879/// Relative Neighborhood Graph (RNG) implementation
880pub struct RNGGraph {
881    /// RNG edges
882    edges: Vec<Vec<(usize, f32)>>,
883    /// Data storage
884    data: Vec<(String, Vector)>,
885    /// Configuration
886    config: GraphIndexConfig,
887}
888
889impl RNGGraph {
890    pub fn new(config: GraphIndexConfig) -> Self {
891        Self {
892            edges: Vec::new(),
893            data: Vec::new(),
894            config,
895        }
896    }
897
898    pub fn build(&mut self) -> Result<()> {
899        if self.data.is_empty() {
900            return Ok(());
901        }
902
903        self.edges = vec![Vec::new(); self.data.len()];
904
905        // Build RNG by checking the RNG criterion for each pair
906        for i in 0..self.data.len() {
907            for j in i + 1..self.data.len() {
908                if self.is_rng_edge(i, j)? {
909                    let distance = self
910                        .config
911                        .distance_metric
912                        .distance(&self.data[i].1.as_f32(), &self.data[j].1.as_f32());
913
914                    self.edges[i].push((j, distance));
915                    self.edges[j].push((i, distance));
916                }
917            }
918        }
919
920        // Sort edges by distance
921        for edges in &mut self.edges {
922            edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
923        }
924
925        Ok(())
926    }
927
928    fn is_rng_edge(&self, i: usize, j: usize) -> Result<bool> {
929        let pi = &self.data[i].1.as_f32();
930        let pj = &self.data[j].1.as_f32();
931        let dist_ij = self.config.distance_metric.distance(pi, pj);
932
933        // Check RNG criterion: no other point k exists such that
934        // max(dist(i,k), dist(j,k)) < dist(i,j)
935        for k in 0..self.data.len() {
936            if k == i || k == j {
937                continue;
938            }
939
940            let pk = &self.data[k].1.as_f32();
941            let dist_ik = self.config.distance_metric.distance(pi, pk);
942            let dist_jk = self.config.distance_metric.distance(pj, pk);
943
944            if dist_ik.max(dist_jk) < dist_ij {
945                return Ok(false);
946            }
947        }
948
949        Ok(true)
950    }
951
952    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
953        if self.data.is_empty() {
954            return Vec::new();
955        }
956
957        let mut visited = HashSet::new();
958        let mut candidates = BinaryHeap::new();
959        let mut results = Vec::new();
960
961        // Start from the closest sampled point
962        let start = self.find_start_point(query);
963        let distance = self
964            .config
965            .distance_metric
966            .distance(query, &self.data[start].1.as_f32());
967        candidates.push(std::cmp::Reverse(SearchResult {
968            index: start,
969            distance,
970        }));
971        visited.insert(start);
972
973        while let Some(std::cmp::Reverse(current)) = candidates.pop() {
974            results.push((current.index, current.distance));
975
976            if results.len() >= k {
977                break;
978            }
979
980            // Explore neighbors
981            for &(neighbor_idx, _) in &self.edges[current.index] {
982                if !visited.contains(&neighbor_idx) {
983                    visited.insert(neighbor_idx);
984                    let distance = self
985                        .config
986                        .distance_metric
987                        .distance(query, &self.data[neighbor_idx].1.as_f32());
988                    candidates.push(std::cmp::Reverse(SearchResult {
989                        index: neighbor_idx,
990                        distance,
991                    }));
992                }
993            }
994        }
995
996        results
997    }
998
999    fn find_start_point(&self, query: &[f32]) -> usize {
1000        // Sample a subset of points
1001        let sample_size = (self.data.len() as f32).sqrt() as usize;
1002        let mut min_dist = f32::INFINITY;
1003        let mut best = 0;
1004
1005        for i in 0..sample_size.min(self.data.len()) {
1006            let idx = (i * self.data.len()) / sample_size;
1007            let distance = self
1008                .config
1009                .distance_metric
1010                .distance(query, &self.data[idx].1.as_f32());
1011
1012            if distance < min_dist {
1013                min_dist = distance;
1014                best = idx;
1015            }
1016        }
1017
1018        best
1019    }
1020}
1021
/// Unified graph index interface
///
/// Front-end that owns a single graph backend, selected by `graph_type`.
/// `GraphIndex::new` populates only the field matching `graph_type` and leaves
/// every other backend field as `None`; all operations dispatch on
/// `graph_type` to reach that one backend.
pub struct GraphIndex {
    /// Selects which backend field below is populated and used.
    graph_type: GraphType,
    /// Backend used when `graph_type` is `GraphType::NSW`.
    nsw: Option<NSWGraph>,
    /// Backend used when `graph_type` is `GraphType::ONNG`.
    onng: Option<ONNGGraph>,
    /// Backend used when `graph_type` is `GraphType::PANNG`.
    panng: Option<PANNGGraph>,
    /// Backend used when `graph_type` is `GraphType::Delaunay`.
    delaunay: Option<DelaunayGraph>,
    /// Backend used when `graph_type` is `GraphType::RNG`.
    rng: Option<RNGGraph>,
}
1031
1032impl GraphIndex {
1033    pub fn new(config: GraphIndexConfig) -> Self {
1034        let graph_type = config.graph_type;
1035
1036        let (nsw, onng, panng, delaunay, rng) = match graph_type {
1037            GraphType::NSW => (Some(NSWGraph::new(config)), None, None, None, None),
1038            GraphType::ONNG => (None, Some(ONNGGraph::new(config)), None, None, None),
1039            GraphType::PANNG => (None, None, Some(PANNGGraph::new(config)), None, None),
1040            GraphType::Delaunay => (None, None, None, Some(DelaunayGraph::new(config)), None),
1041            GraphType::RNG => (None, None, None, None, Some(RNGGraph::new(config))),
1042        };
1043
1044        Self {
1045            graph_type,
1046            nsw,
1047            onng,
1048            panng,
1049            delaunay,
1050            rng,
1051        }
1052    }
1053
1054    fn build(&mut self) -> Result<()> {
1055        match self.graph_type {
1056            GraphType::NSW => self.nsw.as_mut().unwrap().build(),
1057            GraphType::ONNG => self.onng.as_mut().unwrap().build(),
1058            GraphType::PANNG => self.panng.as_mut().unwrap().build(),
1059            GraphType::Delaunay => self.delaunay.as_mut().unwrap().build(),
1060            GraphType::RNG => self.rng.as_mut().unwrap().build(),
1061        }
1062    }
1063
1064    fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1065        match self.graph_type {
1066            GraphType::NSW => self.nsw.as_ref().unwrap().search(query, k),
1067            GraphType::ONNG => self.onng.as_ref().unwrap().search(query, k),
1068            GraphType::PANNG => self.panng.as_ref().unwrap().search(query, k),
1069            GraphType::Delaunay => self.delaunay.as_ref().unwrap().search(query, k),
1070            GraphType::RNG => self.rng.as_ref().unwrap().search(query, k),
1071        }
1072    }
1073}
1074
1075impl VectorIndex for GraphIndex {
1076    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1077        let data = match self.graph_type {
1078            GraphType::NSW => &mut self.nsw.as_mut().unwrap().data,
1079            GraphType::ONNG => &mut self.onng.as_mut().unwrap().data,
1080            GraphType::PANNG => &mut self.panng.as_mut().unwrap().data,
1081            GraphType::Delaunay => &mut self.delaunay.as_mut().unwrap().data,
1082            GraphType::RNG => &mut self.rng.as_mut().unwrap().data,
1083        };
1084
1085        data.push((uri, vector));
1086        Ok(())
1087    }
1088
1089    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1090        let query_f32 = query.as_f32();
1091        let results = self.search_internal(&query_f32, k);
1092
1093        let data = match self.graph_type {
1094            GraphType::NSW => &self.nsw.as_ref().unwrap().data,
1095            GraphType::ONNG => &self.onng.as_ref().unwrap().data,
1096            GraphType::PANNG => &self.panng.as_ref().unwrap().data,
1097            GraphType::Delaunay => &self.delaunay.as_ref().unwrap().data,
1098            GraphType::RNG => &self.rng.as_ref().unwrap().data,
1099        };
1100
1101        Ok(results
1102            .into_iter()
1103            .map(|(idx, dist)| (data[idx].0.clone(), dist))
1104            .collect())
1105    }
1106
1107    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1108        let query_f32 = query.as_f32();
1109        let all_results = self.search_internal(&query_f32, 1000);
1110
1111        let data = match self.graph_type {
1112            GraphType::NSW => &self.nsw.as_ref().unwrap().data,
1113            GraphType::ONNG => &self.onng.as_ref().unwrap().data,
1114            GraphType::PANNG => &self.panng.as_ref().unwrap().data,
1115            GraphType::Delaunay => &self.delaunay.as_ref().unwrap().data,
1116            GraphType::RNG => &self.rng.as_ref().unwrap().data,
1117        };
1118
1119        Ok(all_results
1120            .into_iter()
1121            .filter(|(_, dist)| *dist <= threshold)
1122            .map(|(idx, dist)| (data[idx].0.clone(), dist))
1123            .collect())
1124    }
1125
1126    fn get_vector(&self, uri: &str) -> Option<&Vector> {
1127        let data = match self.graph_type {
1128            GraphType::NSW => &self.nsw.as_ref().unwrap().data,
1129            GraphType::ONNG => &self.onng.as_ref().unwrap().data,
1130            GraphType::PANNG => &self.panng.as_ref().unwrap().data,
1131            GraphType::Delaunay => &self.delaunay.as_ref().unwrap().data,
1132            GraphType::RNG => &self.rng.as_ref().unwrap().data,
1133        };
1134
1135        data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1136    }
1137}
1138
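// Explicitly bring the `petgraph` crate into scope (its `Graph` and `NodeIndex` items are imported at the top of the file).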
1140use petgraph;
// Note: random number generation was switched to `scirs2_core::random` (see the `use` at the top of the file).
1142
#[cfg(test)]
mod tests {
    use super::*;

    /// NSW: points on a line; an exact-match query must come back first.
    #[test]
    fn test_nsw_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::NSW,
            num_neighbors: 10,
            ..Default::default()
        };

        let mut index = GraphIndex::new(config);

        // Insert test vectors
        for i in 0..50 {
            let vector = Vector::new(vec![i as f32, (i * 2) as f32, (i * 3) as f32]);
            index.insert(format!("vec_{i}"), vector).unwrap();
        }

        index.build().unwrap();

        // Search for nearest neighbors
        let query = Vector::new(vec![25.0, 50.0, 75.0]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
        assert_eq!(results[0].0, "vec_25"); // Exact match
    }

    /// ONNG: points on the unit circle; checks the requested count is returned.
    #[test]
    fn test_onng_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::ONNG,
            num_neighbors: 8,
            ..Default::default()
        };

        let mut index = GraphIndex::new(config);

        // Insert test vectors in a circle
        for i in 0..20 {
            let angle = (i as f32) * 2.0 * std::f32::consts::PI / 20.0;
            let vector = Vector::new(vec![angle.cos(), angle.sin()]);
            index.insert(format!("vec_{i}"), vector).unwrap();
        }

        index.build().unwrap();

        // Search for nearest neighbors
        let query = Vector::new(vec![1.0, 0.0]);
        let results = index.search_knn(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
    }

    /// PANNG with pruning enabled; checks the requested count is returned.
    #[test]
    fn test_panng_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::PANNG,
            num_neighbors: 5,
            enable_pruning: true,
            ..Default::default()
        };

        let mut index = GraphIndex::new(config);

        // Insert test vectors
        for i in 0..30 {
            let vector = Vector::new(vec![(i as f32).sin(), (i as f32).cos(), (i as f32) / 10.0]);
            index.insert(format!("vec_{i}"), vector).unwrap();
        }

        index.build().unwrap();

        // Search for nearest neighbors
        let query = Vector::new(vec![0.0, 1.0, 0.0]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
    }

    /// RNG: evenly spaced collinear points form a path graph, so an
    /// exact-match query must be returned first together with its neighbours.
    #[test]
    fn test_rng_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::RNG,
            num_neighbors: 5,
            ..Default::default()
        };

        let mut index = GraphIndex::new(config);

        for i in 0..20 {
            let vector = Vector::new(vec![i as f32, (i * 2) as f32]);
            index.insert(format!("vec_{i}"), vector).unwrap();
        }

        index.build().unwrap();

        let query = Vector::new(vec![10.0, 20.0]);
        let results = index.search_knn(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, "vec_10"); // Exact match
    }
}