oxirs_vec/
graph_indices.rs

1//! Graph-based indices for efficient nearest neighbor search
2//!
3//! This module implements various graph-based data structures optimized for
4//! nearest neighbor search:
5//! - NSW: Navigable Small World
6//! - ONNG: Optimized Nearest Neighbor Graph
7//! - PANNG: Pruned Approximate Nearest Neighbor Graph
8//! - Delaunay Graph: Approximation for high-dimensional space
9//! - RNG: Relative Neighborhood Graph
10
11use crate::{Vector, VectorIndex};
12use anyhow::Result;
13use oxirs_core::parallel::*;
14use oxirs_core::simd::SimdOps;
15use petgraph::graph::{Graph, NodeIndex};
16use scirs2_core::random::{Random, Rng};
17use std::cmp::Ordering;
18use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
19
20/// Configuration for graph-based indices
21#[derive(Debug, Clone)]
22pub struct GraphIndexConfig {
23    /// Type of graph to use
24    pub graph_type: GraphType,
25    /// Number of neighbors per node
26    pub num_neighbors: usize,
27    /// Random seed for reproducibility
28    pub random_seed: Option<u64>,
29    /// Enable parallel construction
30    pub parallel_construction: bool,
31    /// Distance metric
32    pub distance_metric: DistanceMetric,
33    /// Enable pruning for better quality
34    pub enable_pruning: bool,
35    /// Search depth multiplier
36    pub search_expansion: f32,
37}
38
39impl Default for GraphIndexConfig {
40    fn default() -> Self {
41        Self {
42            graph_type: GraphType::NSW,
43            num_neighbors: 32,
44            random_seed: None,
45            parallel_construction: true,
46            distance_metric: DistanceMetric::Euclidean,
47            enable_pruning: true,
48            search_expansion: 1.5,
49        }
50    }
51}
52
/// The graph structures this module knows how to build.
#[derive(Debug, Clone, Copy)]
pub enum GraphType {
    /// Navigable Small World.
    NSW,
    /// Optimized Nearest Neighbor Graph.
    ONNG,
    /// Pruned Approximate Nearest Neighbor Graph.
    PANNG,
    /// Delaunay graph approximation.
    Delaunay,
    /// Relative Neighborhood Graph.
    RNG,
}
62
/// Distance functions available to the graph indices.
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
    /// Euclidean distance.
    Euclidean,
    /// Manhattan distance.
    Manhattan,
    /// Cosine distance.
    Cosine,
    /// Angular distance: `arccos(cosine_similarity) / pi`.
    Angular,
}
71
72impl DistanceMetric {
73    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
74        match self {
75            DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
76            DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
77            DistanceMetric::Cosine => f32::cosine_distance(a, b),
78            DistanceMetric::Angular => {
79                // Angular distance = arccos(cosine_similarity) / pi
80                let cos_sim: f32 = 1.0 - f32::cosine_distance(a, b);
81                cos_sim.clamp(-1.0, 1.0).acos() / std::f32::consts::PI
82            }
83        }
84    }
85}
86
/// A candidate point (`index`) paired with its `distance` to the query.
///
/// Equality and ordering look at `distance` only; `index` is ignored. The
/// comparison uses [`f32::total_cmp`], so the order is total even when a
/// distance is NaN: a (positive) NaN sorts after every finite distance, which
/// makes it the first element evicted from a max-heap of "best so far"
/// results. `eq` is defined through `cmp`, keeping `PartialEq`, `Eq`,
/// `PartialOrd` and `Ord` consistent with one another as `BinaryHeap` and the
/// sorting routines require.
#[derive(Debug, Clone)]
struct SearchResult {
    index: usize,
    distance: f32,
}

impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance.total_cmp(&other.distance)
    }
}
115
116/// Navigable Small World (NSW) implementation
117pub struct NSWGraph {
118    /// Graph structure
119    graph: Graph<usize, f32>,
120    /// Node index mapping
121    node_map: HashMap<usize, NodeIndex>,
122    /// Data storage
123    data: Vec<(String, Vector)>,
124    /// Configuration
125    config: GraphIndexConfig,
126    /// Entry points for search
127    entry_points: Vec<NodeIndex>,
128}
129
130impl NSWGraph {
131    pub fn new(config: GraphIndexConfig) -> Self {
132        Self {
133            graph: Graph::new(),
134            node_map: HashMap::new(),
135            data: Vec::new(),
136            config,
137            entry_points: Vec::new(),
138        }
139    }
140
141    /// Build the graph from data
142    pub fn build(&mut self) -> Result<()> {
143        if self.data.is_empty() {
144            return Ok(());
145        }
146
147        // Create nodes
148        for (idx, _) in self.data.iter().enumerate() {
149            let node = self.graph.add_node(idx);
150            self.node_map.insert(idx, node);
151        }
152
153        // Select random entry points
154        let num_entry_points = (self.data.len() as f32).sqrt() as usize;
155        let mut rng = if let Some(seed) = self.config.random_seed {
156            Random::seed(seed)
157        } else {
158            Random::seed(42)
159        };
160
161        // Note: Using manual random selection instead of SliceRandom
162        let mut indices: Vec<usize> = (0..self.data.len()).collect();
163        // Manually shuffle using Fisher-Yates algorithm
164        for i in (1..indices.len()).rev() {
165            let j = rng.gen_range(0..=i);
166            indices.swap(i, j);
167        }
168
169        self.entry_points = indices[..num_entry_points.min(self.data.len())]
170            .iter()
171            .map(|&idx| self.node_map[&idx])
172            .collect();
173
174        // Build graph structure
175        if self.config.parallel_construction && self.data.len() > 1000 {
176            self.build_parallel()?;
177        } else {
178            self.build_sequential()?;
179        }
180
181        Ok(())
182    }
183
184    fn build_sequential(&mut self) -> Result<()> {
185        for idx in 0..self.data.len() {
186            let neighbors = self.find_neighbors(idx, self.config.num_neighbors)?;
187            let node = self.node_map[&idx];
188
189            for (neighbor_idx, distance) in neighbors {
190                let neighbor_node = self.node_map[&neighbor_idx];
191                if !self.graph.contains_edge(node, neighbor_node) {
192                    self.graph.add_edge(node, neighbor_node, distance);
193                }
194            }
195        }
196
197        Ok(())
198    }
199
200    fn build_parallel(&mut self) -> Result<()> {
201        let _chunk_size = (self.data.len() / num_threads()).max(100);
202
203        // Pre-compute all edges that need to be added
204        let mut all_edges = Vec::new();
205        for idx in 0..self.data.len() {
206            let neighbors = self.find_neighbors(idx, self.config.num_neighbors)?;
207            let node = self.node_map[&idx];
208
209            for (neighbor_idx, distance) in neighbors {
210                let neighbor_node = self.node_map[&neighbor_idx];
211                all_edges.push((node, neighbor_node, distance));
212            }
213        }
214
215        // Now add all edges to the graph
216        for (from, to, weight) in all_edges {
217            if !self.graph.contains_edge(from, to) {
218                self.graph.add_edge(from, to, weight);
219            }
220        }
221
222        Ok(())
223    }
224
225    fn find_neighbors(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
226        let query = &self.data[idx].1.as_f32();
227        let mut heap = BinaryHeap::new();
228
229        for (other_idx, (_, vector)) in self.data.iter().enumerate() {
230            if other_idx == idx {
231                continue;
232            }
233
234            let other = vector.as_f32();
235            let distance = self.config.distance_metric.distance(query, &other);
236
237            if heap.len() < k {
238                heap.push(SearchResult {
239                    index: other_idx,
240                    distance,
241                });
242            } else if distance < heap.peek().unwrap().distance {
243                heap.pop();
244                heap.push(SearchResult {
245                    index: other_idx,
246                    distance,
247                });
248            }
249        }
250
251        Ok(heap.into_iter().map(|r| (r.index, r.distance)).collect())
252    }
253
254    /// Search for k nearest neighbors
255    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
256        if self.entry_points.is_empty() {
257            return Vec::new();
258        }
259
260        let mut visited = HashSet::new();
261        let mut candidates = BinaryHeap::new();
262        let mut results: BinaryHeap<SearchResult> = BinaryHeap::new();
263
264        // Initialize with entry points
265        for &entry in &self.entry_points {
266            let idx = self.graph[entry];
267            let distance = self
268                .config
269                .distance_metric
270                .distance(query, &self.data[idx].1.as_f32());
271            candidates.push(std::cmp::Reverse(SearchResult {
272                index: idx,
273                distance,
274            }));
275            visited.insert(idx);
276        }
277
278        // Search expansion
279        let max_candidates = (k as f32 * self.config.search_expansion) as usize;
280
281        while let Some(std::cmp::Reverse(current)) = candidates.pop() {
282            // Only apply early termination if we have k results
283            if results.len() >= k && current.distance > results.peek().unwrap().distance {
284                break;
285            }
286
287            // Update results
288            if results.len() < k {
289                results.push(current.clone());
290            } else if current.distance < results.peek().unwrap().distance {
291                results.pop();
292                results.push(current.clone());
293            }
294
295            // Explore neighbors
296            let node = self.node_map[&current.index];
297            for neighbor in self.graph.neighbors(node) {
298                let neighbor_idx = self.graph[neighbor];
299
300                if visited.contains(&neighbor_idx) {
301                    continue;
302                }
303
304                visited.insert(neighbor_idx);
305                let distance = self
306                    .config
307                    .distance_metric
308                    .distance(query, &self.data[neighbor_idx].1.as_f32());
309
310                if candidates.len() < max_candidates
311                    || distance < candidates.peek().unwrap().0.distance
312                {
313                    candidates.push(std::cmp::Reverse(SearchResult {
314                        index: neighbor_idx,
315                        distance,
316                    }));
317                }
318            }
319        }
320
321        let mut results: Vec<(usize, f32)> =
322            results.into_iter().map(|r| (r.index, r.distance)).collect();
323
324        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
325        results
326    }
327}
328
329/// Optimized Nearest Neighbor Graph (ONNG) implementation
330pub struct ONNGGraph {
331    /// Adjacency list representation
332    adjacency: Vec<Vec<(usize, f32)>>,
333    /// Data storage
334    data: Vec<(String, Vector)>,
335    /// Configuration
336    config: GraphIndexConfig,
337}
338
339impl ONNGGraph {
340    pub fn new(config: GraphIndexConfig) -> Self {
341        Self {
342            adjacency: Vec::new(),
343            data: Vec::new(),
344            config,
345        }
346    }
347
348    pub fn build(&mut self) -> Result<()> {
349        if self.data.is_empty() {
350            return Ok(());
351        }
352
353        // Initialize adjacency lists
354        self.adjacency = vec![Vec::new(); self.data.len()];
355
356        // Build initial k-NN graph
357        self.build_knn_graph()?;
358
359        // Optimize graph structure
360        self.optimize_graph()?;
361
362        Ok(())
363    }
364
365    fn build_knn_graph(&mut self) -> Result<()> {
366        for idx in 0..self.data.len() {
367            let neighbors = self.find_k_nearest(idx, self.config.num_neighbors)?;
368            self.adjacency[idx] = neighbors;
369        }
370
371        Ok(())
372    }
373
374    fn find_k_nearest(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
375        let query = &self.data[idx].1.as_f32();
376        let mut neighbors = Vec::new();
377
378        for (other_idx, (_, vector)) in self.data.iter().enumerate() {
379            if other_idx == idx {
380                continue;
381            }
382
383            let distance = self
384                .config
385                .distance_metric
386                .distance(query, &vector.as_f32());
387            neighbors.push((other_idx, distance));
388        }
389
390        neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
391        neighbors.truncate(k);
392
393        Ok(neighbors)
394    }
395
396    fn optimize_graph(&mut self) -> Result<()> {
397        // Add reverse edges for better connectivity
398        let mut reverse_edges = vec![Vec::new(); self.data.len()];
399
400        for (idx, neighbors) in self.adjacency.iter().enumerate() {
401            for &(neighbor_idx, distance) in neighbors {
402                reverse_edges[neighbor_idx].push((idx, distance));
403            }
404        }
405
406        // Merge and optimize
407        for (idx, reverse) in reverse_edges.into_iter().enumerate() {
408            let mut all_neighbors = self.adjacency[idx].clone();
409            all_neighbors.extend(reverse);
410
411            // Remove duplicates and sort
412            all_neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
413            all_neighbors.dedup_by_key(|&mut (idx, _)| idx);
414            all_neighbors.truncate(self.config.num_neighbors);
415
416            self.adjacency[idx] = all_neighbors;
417        }
418
419        Ok(())
420    }
421
422    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
423        if self.data.is_empty() {
424            return Vec::new();
425        }
426
427        // Start from multiple random points
428        let start_points = self.select_start_points();
429        let mut visited = HashSet::new();
430        let mut heap = BinaryHeap::new();
431
432        // Initialize with start points
433        for start in start_points {
434            let distance = self
435                .config
436                .distance_metric
437                .distance(query, &self.data[start].1.as_f32());
438            heap.push(std::cmp::Reverse(SearchResult {
439                index: start,
440                distance,
441            }));
442            visited.insert(start);
443        }
444
445        let mut results = Vec::new();
446
447        while let Some(std::cmp::Reverse(current)) = heap.pop() {
448            results.push((current.index, current.distance));
449
450            if results.len() >= k {
451                break;
452            }
453
454            // Explore neighbors
455            for &(neighbor_idx, _) in &self.adjacency[current.index] {
456                if visited.contains(&neighbor_idx) {
457                    continue;
458                }
459
460                visited.insert(neighbor_idx);
461                let distance = self
462                    .config
463                    .distance_metric
464                    .distance(query, &self.data[neighbor_idx].1.as_f32());
465                heap.push(std::cmp::Reverse(SearchResult {
466                    index: neighbor_idx,
467                    distance,
468                }));
469            }
470        }
471
472        results.truncate(k);
473        results
474    }
475
476    fn select_start_points(&self) -> Vec<usize> {
477        // Simple strategy: select sqrt(n) random points
478        let num_points = (self.data.len() as f32).sqrt() as usize;
479        let mut indices: Vec<usize> = (0..self.data.len()).collect();
480
481        let mut rng = if let Some(seed) = self.config.random_seed {
482            Random::seed(seed)
483        } else {
484            Random::seed(42)
485        };
486
487        // Note: Using manual random selection instead of SliceRandom
488        // Manually shuffle using Fisher-Yates algorithm
489        for i in (1..indices.len()).rev() {
490            let j = rng.gen_range(0..=i);
491            indices.swap(i, j);
492        }
493        indices.truncate(num_points.max(1));
494
495        indices
496    }
497}
498
499/// Pruned Approximate Nearest Neighbor Graph (PANNG) implementation
500pub struct PANNGGraph {
501    /// Pruned adjacency list
502    adjacency: Vec<Vec<(usize, f32)>>,
503    /// Data storage
504    data: Vec<(String, Vector)>,
505    /// Configuration
506    config: GraphIndexConfig,
507    /// Pruning threshold
508    pruning_threshold: f32,
509}
510
511impl PANNGGraph {
512    pub fn new(config: GraphIndexConfig) -> Self {
513        Self {
514            adjacency: Vec::new(),
515            data: Vec::new(),
516            config,
517            pruning_threshold: 0.9, // Angle-based pruning threshold
518        }
519    }
520
521    pub fn build(&mut self) -> Result<()> {
522        if self.data.is_empty() {
523            return Ok(());
524        }
525
526        // Build initial k-NN graph
527        self.adjacency = vec![Vec::new(); self.data.len()];
528        self.build_initial_graph()?;
529
530        // Apply pruning
531        if self.config.enable_pruning {
532            self.prune_graph()?;
533        }
534
535        Ok(())
536    }
537
538    fn build_initial_graph(&mut self) -> Result<()> {
539        // Build with more neighbors initially for pruning
540        let initial_neighbors = self.config.num_neighbors * 2;
541
542        for idx in 0..self.data.len() {
543            let neighbors = self.find_k_nearest(idx, initial_neighbors)?;
544            self.adjacency[idx] = neighbors;
545        }
546
547        Ok(())
548    }
549
550    fn find_k_nearest(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
551        let query = &self.data[idx].1.as_f32();
552        let mut heap = BinaryHeap::new();
553
554        for (other_idx, (_, vector)) in self.data.iter().enumerate() {
555            if other_idx == idx {
556                continue;
557            }
558
559            let distance = self
560                .config
561                .distance_metric
562                .distance(query, &vector.as_f32());
563
564            if heap.len() < k {
565                heap.push(SearchResult {
566                    index: other_idx,
567                    distance,
568                });
569            } else if distance < heap.peek().unwrap().distance {
570                heap.pop();
571                heap.push(SearchResult {
572                    index: other_idx,
573                    distance,
574                });
575            }
576        }
577
578        Ok(heap
579            .into_sorted_vec()
580            .into_iter()
581            .map(|r| (r.index, r.distance))
582            .collect())
583    }
584
585    fn prune_graph(&mut self) -> Result<()> {
586        for idx in 0..self.data.len() {
587            let pruned = self.prune_neighbors(idx)?;
588            self.adjacency[idx] = pruned;
589        }
590
591        Ok(())
592    }
593
594    fn prune_neighbors(&self, idx: usize) -> Result<Vec<(usize, f32)>> {
595        let neighbors = &self.adjacency[idx];
596        if neighbors.len() <= self.config.num_neighbors {
597            return Ok(neighbors.clone());
598        }
599
600        let mut pruned = Vec::new();
601        let (_, vector) = &self.data[idx];
602        let query = vector.as_f32();
603
604        for &(neighbor_idx, distance) in neighbors {
605            let (_, vector) = &self.data[neighbor_idx];
606            let neighbor = vector.as_f32();
607            let mut keep = true;
608
609            // Check angle with already selected neighbors
610            for &(selected_idx, _) in &pruned {
611                let (_id, vector): &(String, Vector) = &self.data[selected_idx];
612                let selected = vector.as_f32();
613
614                // Calculate angle between neighbor and selected
615                let angle = self.calculate_angle(&query, &neighbor, &selected);
616
617                if angle < self.pruning_threshold {
618                    keep = false;
619                    break;
620                }
621            }
622
623            if keep {
624                pruned.push((neighbor_idx, distance));
625
626                if pruned.len() >= self.config.num_neighbors {
627                    break;
628                }
629            }
630        }
631
632        Ok(pruned)
633    }
634
635    fn calculate_angle(&self, origin: &[f32], a: &[f32], b: &[f32]) -> f32 {
636        // Calculate vectors from origin
637        let va: Vec<f32> = a
638            .iter()
639            .zip(origin.iter())
640            .map(|(ai, oi)| ai - oi)
641            .collect();
642        let vb: Vec<f32> = b
643            .iter()
644            .zip(origin.iter())
645            .map(|(bi, oi)| bi - oi)
646            .collect();
647
648        // Calculate cosine of angle
649        let dot = f32::dot(&va, &vb);
650        let norm_a = f32::norm(&va);
651        let norm_b = f32::norm(&vb);
652
653        if norm_a == 0.0 || norm_b == 0.0 {
654            return 0.0;
655        }
656
657        (dot / (norm_a * norm_b)).clamp(-1.0, 1.0).acos()
658    }
659
660    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
661        if self.data.is_empty() {
662            return Vec::new();
663        }
664
665        let mut visited = HashSet::new();
666        let mut candidates = VecDeque::new();
667        let mut results = Vec::new();
668
669        // Start from closest point
670        let start = self.find_closest_point(query);
671        candidates.push_back(start);
672        visited.insert(start);
673
674        while let Some(current) = candidates.pop_front() {
675            let distance = self
676                .config
677                .distance_metric
678                .distance(query, &self.data[current].1.as_f32());
679            results.push((current, distance));
680
681            // Explore neighbors
682            for &(neighbor_idx, _) in &self.adjacency[current] {
683                if !visited.contains(&neighbor_idx) {
684                    visited.insert(neighbor_idx);
685                    candidates.push_back(neighbor_idx);
686                }
687            }
688
689            if results.len() >= k * 2 {
690                break;
691            }
692        }
693
694        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
695        results.truncate(k);
696        results
697    }
698
699    fn find_closest_point(&self, query: &[f32]) -> usize {
700        let mut min_dist = f32::INFINITY;
701        let mut closest = 0;
702
703        // Sample a few random points
704        let sample_size = (self.data.len() as f32).sqrt() as usize;
705        let step = self.data.len() / sample_size.max(1);
706
707        for idx in (0..self.data.len()).step_by(step.max(1)) {
708            let distance = self
709                .config
710                .distance_metric
711                .distance(query, &self.data[idx].1.as_f32());
712            if distance < min_dist {
713                min_dist = distance;
714                closest = idx;
715            }
716        }
717
718        closest
719    }
720}
721
722/// Delaunay Graph approximation for high dimensions
723pub struct DelaunayGraph {
724    /// Approximate Delaunay edges
725    edges: Vec<Vec<(usize, f32)>>,
726    /// Data storage
727    data: Vec<(String, Vector)>,
728    /// Configuration
729    config: GraphIndexConfig,
730}
731
732impl DelaunayGraph {
733    pub fn new(config: GraphIndexConfig) -> Self {
734        Self {
735            edges: Vec::new(),
736            data: Vec::new(),
737            config,
738        }
739    }
740
741    pub fn build(&mut self) -> Result<()> {
742        if self.data.is_empty() {
743            return Ok(());
744        }
745
746        self.edges = vec![Vec::new(); self.data.len()];
747
748        // For high dimensions, we approximate Delaunay by local criteria
749        for idx in 0..self.data.len() {
750            let neighbors = self.find_delaunay_neighbors(idx)?;
751            self.edges[idx] = neighbors;
752        }
753
754        // Make edges bidirectional
755        self.symmetrize_edges();
756
757        Ok(())
758    }
759
760    fn find_delaunay_neighbors(&self, idx: usize) -> Result<Vec<(usize, f32)>> {
761        let point = &self.data[idx].1.as_f32();
762        let mut candidates = Vec::new();
763
764        // Find potential neighbors
765        for (other_idx, (_, other_vec)) in self.data.iter().enumerate() {
766            if other_idx == idx {
767                continue;
768            }
769
770            let other = other_vec.as_f32();
771            let distance = self.config.distance_metric.distance(point, &other);
772            candidates.push((other_idx, distance));
773        }
774
775        candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
776
777        // Apply Delaunay criterion approximation
778        let mut neighbors = Vec::new();
779
780        for &(candidate_idx, distance) in &candidates {
781            if neighbors.len() >= self.config.num_neighbors {
782                break;
783            }
784
785            let candidate = &self.data[candidate_idx].1.as_f32();
786            let mut is_neighbor = true;
787
788            // Check if any existing neighbor violates the empty circumsphere property
789            for &(neighbor_idx, _) in &neighbors {
790                let (_id, vector): &(String, Vector) = &self.data[neighbor_idx];
791                let neighbor = vector.as_f32();
792
793                // Approximate check: if candidate is closer to neighbor than to point
794                let dist_to_neighbor = self.config.distance_metric.distance(candidate, &neighbor);
795                if dist_to_neighbor < distance * 0.9 {
796                    is_neighbor = false;
797                    break;
798                }
799            }
800
801            if is_neighbor {
802                neighbors.push((candidate_idx, distance));
803            }
804        }
805
806        Ok(neighbors)
807    }
808
809    fn symmetrize_edges(&mut self) {
810        let mut symmetric_edges = vec![Vec::new(); self.data.len()];
811
812        // Collect all edges
813        for (idx, neighbors) in self.edges.iter().enumerate() {
814            for &(neighbor_idx, distance) in neighbors {
815                symmetric_edges[idx].push((neighbor_idx, distance));
816                symmetric_edges[neighbor_idx].push((idx, distance));
817            }
818        }
819
820        // Remove duplicates and sort
821        for edges in &mut symmetric_edges {
822            edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
823            edges.dedup_by_key(|&mut (idx, _)| idx);
824            edges.truncate(self.config.num_neighbors);
825        }
826
827        self.edges = symmetric_edges;
828    }
829
830    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
831        if self.data.is_empty() {
832            return Vec::new();
833        }
834
835        let mut visited = HashSet::new();
836        let mut heap = BinaryHeap::new();
837        let mut results = Vec::new();
838
839        // Start from a random point
840        let start = 0;
841        let distance = self
842            .config
843            .distance_metric
844            .distance(query, &self.data[start].1.as_f32());
845        heap.push(std::cmp::Reverse(SearchResult {
846            index: start,
847            distance,
848        }));
849        visited.insert(start);
850
851        while let Some(std::cmp::Reverse(current)) = heap.pop() {
852            results.push((current.index, current.distance));
853
854            if results.len() >= k {
855                break;
856            }
857
858            // Explore neighbors
859            for &(neighbor_idx, _) in &self.edges[current.index] {
860                if !visited.contains(&neighbor_idx) {
861                    visited.insert(neighbor_idx);
862                    let distance = self
863                        .config
864                        .distance_metric
865                        .distance(query, &self.data[neighbor_idx].1.as_f32());
866                    heap.push(std::cmp::Reverse(SearchResult {
867                        index: neighbor_idx,
868                        distance,
869                    }));
870                }
871            }
872        }
873
874        results
875    }
876}
877
878/// Relative Neighborhood Graph (RNG) implementation
879pub struct RNGGraph {
880    /// RNG edges
881    edges: Vec<Vec<(usize, f32)>>,
882    /// Data storage
883    data: Vec<(String, Vector)>,
884    /// Configuration
885    config: GraphIndexConfig,
886}
887
888impl RNGGraph {
889    pub fn new(config: GraphIndexConfig) -> Self {
890        Self {
891            edges: Vec::new(),
892            data: Vec::new(),
893            config,
894        }
895    }
896
897    pub fn build(&mut self) -> Result<()> {
898        if self.data.is_empty() {
899            return Ok(());
900        }
901
902        self.edges = vec![Vec::new(); self.data.len()];
903
904        // Build RNG by checking the RNG criterion for each pair
905        for i in 0..self.data.len() {
906            for j in i + 1..self.data.len() {
907                if self.is_rng_edge(i, j)? {
908                    let distance = self
909                        .config
910                        .distance_metric
911                        .distance(&self.data[i].1.as_f32(), &self.data[j].1.as_f32());
912
913                    self.edges[i].push((j, distance));
914                    self.edges[j].push((i, distance));
915                }
916            }
917        }
918
919        // Sort edges by distance
920        for edges in &mut self.edges {
921            edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
922        }
923
924        Ok(())
925    }
926
927    fn is_rng_edge(&self, i: usize, j: usize) -> Result<bool> {
928        let pi = &self.data[i].1.as_f32();
929        let pj = &self.data[j].1.as_f32();
930        let dist_ij = self.config.distance_metric.distance(pi, pj);
931
932        // Check RNG criterion: no other point k exists such that
933        // max(dist(i,k), dist(j,k)) < dist(i,j)
934        for k in 0..self.data.len() {
935            if k == i || k == j {
936                continue;
937            }
938
939            let pk = &self.data[k].1.as_f32();
940            let dist_ik = self.config.distance_metric.distance(pi, pk);
941            let dist_jk = self.config.distance_metric.distance(pj, pk);
942
943            if dist_ik.max(dist_jk) < dist_ij {
944                return Ok(false);
945            }
946        }
947
948        Ok(true)
949    }
950
951    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
952        if self.data.is_empty() {
953            return Vec::new();
954        }
955
956        let mut visited = HashSet::new();
957        let mut candidates = BinaryHeap::new();
958        let mut results = Vec::new();
959
960        // Start from the closest sampled point
961        let start = self.find_start_point(query);
962        let distance = self
963            .config
964            .distance_metric
965            .distance(query, &self.data[start].1.as_f32());
966        candidates.push(std::cmp::Reverse(SearchResult {
967            index: start,
968            distance,
969        }));
970        visited.insert(start);
971
972        while let Some(std::cmp::Reverse(current)) = candidates.pop() {
973            results.push((current.index, current.distance));
974
975            if results.len() >= k {
976                break;
977            }
978
979            // Explore neighbors
980            for &(neighbor_idx, _) in &self.edges[current.index] {
981                if !visited.contains(&neighbor_idx) {
982                    visited.insert(neighbor_idx);
983                    let distance = self
984                        .config
985                        .distance_metric
986                        .distance(query, &self.data[neighbor_idx].1.as_f32());
987                    candidates.push(std::cmp::Reverse(SearchResult {
988                        index: neighbor_idx,
989                        distance,
990                    }));
991                }
992            }
993        }
994
995        results
996    }
997
998    fn find_start_point(&self, query: &[f32]) -> usize {
999        // Sample a subset of points
1000        let sample_size = (self.data.len() as f32).sqrt() as usize;
1001        let mut min_dist = f32::INFINITY;
1002        let mut best = 0;
1003
1004        for i in 0..sample_size.min(self.data.len()) {
1005            let idx = (i * self.data.len()) / sample_size;
1006            let distance = self
1007                .config
1008                .distance_metric
1009                .distance(query, &self.data[idx].1.as_f32());
1010
1011            if distance < min_dist {
1012                min_dist = distance;
1013                best = idx;
1014            }
1015        }
1016
1017        best
1018    }
1019}
1020
1021/// Unified graph index interface
1022pub struct GraphIndex {
1023    graph_type: GraphType,
1024    nsw: Option<NSWGraph>,
1025    onng: Option<ONNGGraph>,
1026    panng: Option<PANNGGraph>,
1027    delaunay: Option<DelaunayGraph>,
1028    rng: Option<RNGGraph>,
1029}
1030
1031impl GraphIndex {
1032    pub fn new(config: GraphIndexConfig) -> Self {
1033        let graph_type = config.graph_type;
1034
1035        let (nsw, onng, panng, delaunay, rng) = match graph_type {
1036            GraphType::NSW => (Some(NSWGraph::new(config)), None, None, None, None),
1037            GraphType::ONNG => (None, Some(ONNGGraph::new(config)), None, None, None),
1038            GraphType::PANNG => (None, None, Some(PANNGGraph::new(config)), None, None),
1039            GraphType::Delaunay => (None, None, None, Some(DelaunayGraph::new(config)), None),
1040            GraphType::RNG => (None, None, None, None, Some(RNGGraph::new(config))),
1041        };
1042
1043        Self {
1044            graph_type,
1045            nsw,
1046            onng,
1047            panng,
1048            delaunay,
1049            rng,
1050        }
1051    }
1052
1053    fn build(&mut self) -> Result<()> {
1054        match self.graph_type {
1055            GraphType::NSW => self.nsw.as_mut().unwrap().build(),
1056            GraphType::ONNG => self.onng.as_mut().unwrap().build(),
1057            GraphType::PANNG => self.panng.as_mut().unwrap().build(),
1058            GraphType::Delaunay => self.delaunay.as_mut().unwrap().build(),
1059            GraphType::RNG => self.rng.as_mut().unwrap().build(),
1060        }
1061    }
1062
1063    fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1064        match self.graph_type {
1065            GraphType::NSW => self.nsw.as_ref().unwrap().search(query, k),
1066            GraphType::ONNG => self.onng.as_ref().unwrap().search(query, k),
1067            GraphType::PANNG => self.panng.as_ref().unwrap().search(query, k),
1068            GraphType::Delaunay => self.delaunay.as_ref().unwrap().search(query, k),
1069            GraphType::RNG => self.rng.as_ref().unwrap().search(query, k),
1070        }
1071    }
1072}
1073
1074impl VectorIndex for GraphIndex {
1075    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1076        let data = match self.graph_type {
1077            GraphType::NSW => &mut self.nsw.as_mut().unwrap().data,
1078            GraphType::ONNG => &mut self.onng.as_mut().unwrap().data,
1079            GraphType::PANNG => &mut self.panng.as_mut().unwrap().data,
1080            GraphType::Delaunay => &mut self.delaunay.as_mut().unwrap().data,
1081            GraphType::RNG => &mut self.rng.as_mut().unwrap().data,
1082        };
1083
1084        data.push((uri, vector));
1085        Ok(())
1086    }
1087
1088    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1089        let query_f32 = query.as_f32();
1090        let results = self.search_internal(&query_f32, k);
1091
1092        let data = match self.graph_type {
1093            GraphType::NSW => &self.nsw.as_ref().unwrap().data,
1094            GraphType::ONNG => &self.onng.as_ref().unwrap().data,
1095            GraphType::PANNG => &self.panng.as_ref().unwrap().data,
1096            GraphType::Delaunay => &self.delaunay.as_ref().unwrap().data,
1097            GraphType::RNG => &self.rng.as_ref().unwrap().data,
1098        };
1099
1100        Ok(results
1101            .into_iter()
1102            .map(|(idx, dist)| (data[idx].0.clone(), dist))
1103            .collect())
1104    }
1105
1106    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1107        let query_f32 = query.as_f32();
1108        let all_results = self.search_internal(&query_f32, 1000);
1109
1110        let data = match self.graph_type {
1111            GraphType::NSW => &self.nsw.as_ref().unwrap().data,
1112            GraphType::ONNG => &self.onng.as_ref().unwrap().data,
1113            GraphType::PANNG => &self.panng.as_ref().unwrap().data,
1114            GraphType::Delaunay => &self.delaunay.as_ref().unwrap().data,
1115            GraphType::RNG => &self.rng.as_ref().unwrap().data,
1116        };
1117
1118        Ok(all_results
1119            .into_iter()
1120            .filter(|(_, dist)| *dist <= threshold)
1121            .map(|(idx, dist)| (data[idx].0.clone(), dist))
1122            .collect())
1123    }
1124
1125    fn get_vector(&self, uri: &str) -> Option<&Vector> {
1126        let data = match self.graph_type {
1127            GraphType::NSW => &self.nsw.as_ref().unwrap().data,
1128            GraphType::ONNG => &self.onng.as_ref().unwrap().data,
1129            GraphType::PANNG => &self.panng.as_ref().unwrap().data,
1130            GraphType::Delaunay => &self.delaunay.as_ref().unwrap().data,
1131            GraphType::RNG => &self.rng.as_ref().unwrap().data,
1132        };
1133
1134        data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1135    }
1136}
1137
1138// Add dependencies
1139use petgraph;
1140// Note: Replaced with scirs2_core::random
1141
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an index from `config`, inserts every point under the URI
    /// `vec_{position}`, and builds the graph.
    fn build_index(
        config: GraphIndexConfig,
        points: impl IntoIterator<Item = Vec<f32>>,
    ) -> GraphIndex {
        let mut index = GraphIndex::new(config);

        for (position, coords) in points.into_iter().enumerate() {
            index
                .insert(format!("vec_{position}"), Vector::new(coords))
                .unwrap();
        }

        index.build().unwrap();
        index
    }

    /// NSW: points on a line; the query coincides with `vec_25`.
    #[test]
    fn test_nsw_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::NSW,
            num_neighbors: 10,
            ..Default::default()
        };
        let points = (0..50).map(|i| vec![i as f32, (i * 2) as f32, (i * 3) as f32]);
        let index = build_index(config, points);

        let query = Vector::new(vec![25.0, 50.0, 75.0]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
        assert_eq!(results[0].0, "vec_25"); // Exact match
    }

    /// ONNG: twenty points evenly spaced on the unit circle.
    #[test]
    fn test_onng_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::ONNG,
            num_neighbors: 8,
            ..Default::default()
        };
        let points = (0..20).map(|i| {
            let angle = (i as f32) * 2.0 * std::f32::consts::PI / 20.0;
            vec![angle.cos(), angle.sin()]
        });
        let index = build_index(config, points);

        let query = Vector::new(vec![1.0, 0.0]);
        let results = index.search_knn(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
    }

    /// PANNG with pruning enabled: thirty points along a helix-like curve.
    #[test]
    fn test_panng_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::PANNG,
            num_neighbors: 5,
            enable_pruning: true,
            ..Default::default()
        };
        let points = (0..30).map(|i| {
            let t = i as f32;
            vec![t.sin(), t.cos(), t / 10.0]
        });
        let index = build_index(config, points);

        let query = Vector::new(vec![0.0, 1.0, 0.0]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
    }
}