1use crate::{Vector, VectorIndex};
12use anyhow::Result;
13use oxirs_core::parallel::*;
14use oxirs_core::simd::SimdOps;
15use petgraph::graph::{Graph, NodeIndex};
16#[allow(unused_imports)]
17use scirs2_core::random::{Random, Rng};
18use std::cmp::Ordering;
19use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
20
/// Configuration shared by every graph-based index in this module.
#[derive(Debug, Clone)]
pub struct GraphIndexConfig {
    /// Which graph structure [`GraphIndex::new`] instantiates.
    pub graph_type: GraphType,
    /// Target number of neighbours kept per node. PANNG first collects twice
    /// this many before pruning back down.
    pub num_neighbors: usize,
    /// Seed for the random choice of search entry/start points (NSW, ONNG);
    /// when `None` a fixed seed of 42 is used instead.
    pub random_seed: Option<u64>,
    /// NSW only: selects the `build_parallel` construction path when the
    /// index holds more than 1000 vectors.
    pub parallel_construction: bool,
    /// Metric used for every distance computed during construction and search.
    pub distance_metric: DistanceMetric,
    /// PANNG only: whether to run the angle-based neighbour pruning pass.
    pub enable_pruning: bool,
    /// Multiplier on `k` that controls how much extra exploration a search
    /// performs (e.g. the size of the NSW candidate frontier).
    pub search_expansion: f32,
}
39
40impl Default for GraphIndexConfig {
41 fn default() -> Self {
42 Self {
43 graph_type: GraphType::NSW,
44 num_neighbors: 32,
45 random_seed: None,
46 parallel_construction: true,
47 distance_metric: DistanceMetric::Euclidean,
48 enable_pruning: true,
49 search_expansion: 1.5,
50 }
51 }
52}
53
/// Graph structure used by a [`GraphIndex`].
#[derive(Debug, Clone, Copy)]
pub enum GraphType {
    /// Navigable Small World graph: k-NN edges plus random entry points.
    NSW,
    /// ONNG-style graph: a k-NN graph merged with its reverse edges, then trimmed.
    ONNG,
    /// PANNG-style graph: an over-sized k-NN graph with optional angle-based pruning.
    PANNG,
    /// Delaunay-like graph built with a greedy distance-based neighbour filter
    /// (a heuristic, not an exact Delaunay triangulation).
    Delaunay,
    /// Relative Neighbourhood Graph, built by exhaustive pair/third-point checks.
    RNG,
}
63
/// Distance function used to compare vectors.
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
    /// Euclidean distance, via `SimdOps::euclidean_distance`.
    Euclidean,
    /// Manhattan distance, via `SimdOps::manhattan_distance`.
    Manhattan,
    /// Cosine distance, via `SimdOps::cosine_distance`.
    Cosine,
    /// Angle between the vectors divided by π, so it is scaled into `[0, 1]`.
    Angular,
}
72
73impl DistanceMetric {
74 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
75 match self {
76 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
77 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
78 DistanceMetric::Cosine => f32::cosine_distance(a, b),
79 DistanceMetric::Angular => {
80 let cos_sim: f32 = 1.0 - f32::cosine_distance(a, b);
82 cos_sim.clamp(-1.0, 1.0).acos() / std::f32::consts::PI
83 }
84 }
85 }
86}
87
/// A candidate node paired with its distance to the query.
///
/// Equality and ordering consider only `distance` (the `index` is ignored), so
/// a `BinaryHeap<SearchResult>` is a max-heap on distance and a
/// `BinaryHeap<Reverse<SearchResult>>` is a min-heap.
#[derive(Debug, Clone)]
struct SearchResult {
    /// Index of the vector in the owning graph's `data`.
    index: usize,
    /// Distance from the query to that vector.
    distance: f32,
}

impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so `Eq` stays reflexive even for NaN distances
        // (plain `f32 ==` reports NaN != NaN, violating the `Eq` contract).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a genuine total order (positive NaN sorts above every
        // other value). `partial_cmp(..).unwrap_or(Equal)` made NaN "equal" to
        // everything, which is non-transitive and corrupts `BinaryHeap`/sort.
        self.distance.total_cmp(&other.distance)
    }
}
116
117pub struct NSWGraph {
119 graph: Graph<usize, f32>,
121 node_map: HashMap<usize, NodeIndex>,
123 data: Vec<(String, Vector)>,
125 config: GraphIndexConfig,
127 entry_points: Vec<NodeIndex>,
129}
130
131impl NSWGraph {
132 pub fn new(config: GraphIndexConfig) -> Self {
133 Self {
134 graph: Graph::new(),
135 node_map: HashMap::new(),
136 data: Vec::new(),
137 config,
138 entry_points: Vec::new(),
139 }
140 }
141
142 pub fn build(&mut self) -> Result<()> {
144 if self.data.is_empty() {
145 return Ok(());
146 }
147
148 for (idx, _) in self.data.iter().enumerate() {
150 let node = self.graph.add_node(idx);
151 self.node_map.insert(idx, node);
152 }
153
154 let num_entry_points = (self.data.len() as f32).sqrt() as usize;
156 let mut rng = if let Some(seed) = self.config.random_seed {
157 Random::seed(seed)
158 } else {
159 Random::seed(42)
160 };
161
162 let mut indices: Vec<usize> = (0..self.data.len()).collect();
164 for i in (1..indices.len()).rev() {
166 let j = rng.random_range(0..i + 1);
167 indices.swap(i, j);
168 }
169
170 self.entry_points = indices[..num_entry_points.min(self.data.len())]
171 .iter()
172 .map(|&idx| self.node_map[&idx])
173 .collect();
174
175 if self.config.parallel_construction && self.data.len() > 1000 {
177 self.build_parallel()?;
178 } else {
179 self.build_sequential()?;
180 }
181
182 Ok(())
183 }
184
185 fn build_sequential(&mut self) -> Result<()> {
186 for idx in 0..self.data.len() {
187 let neighbors = self.find_neighbors(idx, self.config.num_neighbors)?;
188 let node = self.node_map[&idx];
189
190 for (neighbor_idx, distance) in neighbors {
191 let neighbor_node = self.node_map[&neighbor_idx];
192 if !self.graph.contains_edge(node, neighbor_node) {
193 self.graph.add_edge(node, neighbor_node, distance);
194 }
195 }
196 }
197
198 Ok(())
199 }
200
201 fn build_parallel(&mut self) -> Result<()> {
202 let _chunk_size = (self.data.len() / num_threads()).max(100);
203
204 let mut all_edges = Vec::new();
206 for idx in 0..self.data.len() {
207 let neighbors = self.find_neighbors(idx, self.config.num_neighbors)?;
208 let node = self.node_map[&idx];
209
210 for (neighbor_idx, distance) in neighbors {
211 let neighbor_node = self.node_map[&neighbor_idx];
212 all_edges.push((node, neighbor_node, distance));
213 }
214 }
215
216 for (from, to, weight) in all_edges {
218 if !self.graph.contains_edge(from, to) {
219 self.graph.add_edge(from, to, weight);
220 }
221 }
222
223 Ok(())
224 }
225
226 fn find_neighbors(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
227 let query = &self.data[idx].1.as_f32();
228 let mut heap = BinaryHeap::new();
229
230 for (other_idx, (_, vector)) in self.data.iter().enumerate() {
231 if other_idx == idx {
232 continue;
233 }
234
235 let other = vector.as_f32();
236 let distance = self.config.distance_metric.distance(query, &other);
237
238 if heap.len() < k {
239 heap.push(SearchResult {
240 index: other_idx,
241 distance,
242 });
243 } else if distance < heap.peek().expect("heap should have k elements").distance {
244 heap.pop();
245 heap.push(SearchResult {
246 index: other_idx,
247 distance,
248 });
249 }
250 }
251
252 Ok(heap.into_iter().map(|r| (r.index, r.distance)).collect())
253 }
254
255 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
257 if self.entry_points.is_empty() {
258 return Vec::new();
259 }
260
261 let mut visited = HashSet::new();
262 let mut candidates = BinaryHeap::new();
263 let mut results: BinaryHeap<SearchResult> = BinaryHeap::new();
264
265 for &entry in &self.entry_points {
267 let idx = self.graph[entry];
268 let distance = self
269 .config
270 .distance_metric
271 .distance(query, &self.data[idx].1.as_f32());
272 candidates.push(std::cmp::Reverse(SearchResult {
273 index: idx,
274 distance,
275 }));
276 visited.insert(idx);
277 }
278
279 let max_candidates = (k as f32 * self.config.search_expansion) as usize;
281
282 while let Some(std::cmp::Reverse(current)) = candidates.pop() {
283 if results.len() >= k
285 && current.distance
286 > results
287 .peek()
288 .expect("results should have k elements")
289 .distance
290 {
291 break;
292 }
293
294 if results.len() < k {
296 results.push(current.clone());
297 } else if current.distance
298 < results
299 .peek()
300 .expect("results should have k elements")
301 .distance
302 {
303 results.pop();
304 results.push(current.clone());
305 }
306
307 let node = self.node_map[¤t.index];
309 for neighbor in self.graph.neighbors(node) {
310 let neighbor_idx = self.graph[neighbor];
311
312 if visited.contains(&neighbor_idx) {
313 continue;
314 }
315
316 visited.insert(neighbor_idx);
317 let distance = self
318 .config
319 .distance_metric
320 .distance(query, &self.data[neighbor_idx].1.as_f32());
321
322 if candidates.len() < max_candidates
323 || distance
324 < candidates
325 .peek()
326 .expect("candidates should have elements")
327 .0
328 .distance
329 {
330 candidates.push(std::cmp::Reverse(SearchResult {
331 index: neighbor_idx,
332 distance,
333 }));
334 }
335 }
336 }
337
338 let mut results: Vec<(usize, f32)> =
339 results.into_iter().map(|r| (r.index, r.distance)).collect();
340
341 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
342 results
343 }
344}
345
346pub struct ONNGGraph {
348 adjacency: Vec<Vec<(usize, f32)>>,
350 data: Vec<(String, Vector)>,
352 config: GraphIndexConfig,
354}
355
356impl ONNGGraph {
357 pub fn new(config: GraphIndexConfig) -> Self {
358 Self {
359 adjacency: Vec::new(),
360 data: Vec::new(),
361 config,
362 }
363 }
364
365 pub fn build(&mut self) -> Result<()> {
366 if self.data.is_empty() {
367 return Ok(());
368 }
369
370 self.adjacency = vec![Vec::new(); self.data.len()];
372
373 self.build_knn_graph()?;
375
376 self.optimize_graph()?;
378
379 Ok(())
380 }
381
382 fn build_knn_graph(&mut self) -> Result<()> {
383 for idx in 0..self.data.len() {
384 let neighbors = self.find_k_nearest(idx, self.config.num_neighbors)?;
385 self.adjacency[idx] = neighbors;
386 }
387
388 Ok(())
389 }
390
391 fn find_k_nearest(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
392 let query = &self.data[idx].1.as_f32();
393 let mut neighbors = Vec::new();
394
395 for (other_idx, (_, vector)) in self.data.iter().enumerate() {
396 if other_idx == idx {
397 continue;
398 }
399
400 let distance = self
401 .config
402 .distance_metric
403 .distance(query, &vector.as_f32());
404 neighbors.push((other_idx, distance));
405 }
406
407 neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
408 neighbors.truncate(k);
409
410 Ok(neighbors)
411 }
412
413 fn optimize_graph(&mut self) -> Result<()> {
414 let mut reverse_edges = vec![Vec::new(); self.data.len()];
416
417 for (idx, neighbors) in self.adjacency.iter().enumerate() {
418 for &(neighbor_idx, distance) in neighbors {
419 reverse_edges[neighbor_idx].push((idx, distance));
420 }
421 }
422
423 for (idx, reverse) in reverse_edges.into_iter().enumerate() {
425 let mut all_neighbors = self.adjacency[idx].clone();
426 all_neighbors.extend(reverse);
427
428 all_neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
430 all_neighbors.dedup_by_key(|&mut (idx, _)| idx);
431 all_neighbors.truncate(self.config.num_neighbors);
432
433 self.adjacency[idx] = all_neighbors;
434 }
435
436 Ok(())
437 }
438
439 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
440 if self.data.is_empty() {
441 return Vec::new();
442 }
443
444 let start_points = self.select_start_points();
446 let mut visited = HashSet::new();
447 let mut heap = BinaryHeap::new();
448
449 for start in start_points {
451 let distance = self
452 .config
453 .distance_metric
454 .distance(query, &self.data[start].1.as_f32());
455 heap.push(std::cmp::Reverse(SearchResult {
456 index: start,
457 distance,
458 }));
459 visited.insert(start);
460 }
461
462 let mut results = Vec::new();
463
464 while let Some(std::cmp::Reverse(current)) = heap.pop() {
465 results.push((current.index, current.distance));
466
467 if results.len() >= k {
468 break;
469 }
470
471 for &(neighbor_idx, _) in &self.adjacency[current.index] {
473 if visited.contains(&neighbor_idx) {
474 continue;
475 }
476
477 visited.insert(neighbor_idx);
478 let distance = self
479 .config
480 .distance_metric
481 .distance(query, &self.data[neighbor_idx].1.as_f32());
482 heap.push(std::cmp::Reverse(SearchResult {
483 index: neighbor_idx,
484 distance,
485 }));
486 }
487 }
488
489 results.truncate(k);
490 results
491 }
492
493 fn select_start_points(&self) -> Vec<usize> {
494 let num_points = (self.data.len() as f32).sqrt() as usize;
496 let mut indices: Vec<usize> = (0..self.data.len()).collect();
497
498 let mut rng = if let Some(seed) = self.config.random_seed {
499 Random::seed(seed)
500 } else {
501 Random::seed(42)
502 };
503
504 for i in (1..indices.len()).rev() {
507 let j = rng.random_range(0..i + 1);
508 indices.swap(i, j);
509 }
510 indices.truncate(num_points.max(1));
511
512 indices
513 }
514}
515
516pub struct PANNGGraph {
518 adjacency: Vec<Vec<(usize, f32)>>,
520 data: Vec<(String, Vector)>,
522 config: GraphIndexConfig,
524 pruning_threshold: f32,
526}
527
528impl PANNGGraph {
529 pub fn new(config: GraphIndexConfig) -> Self {
530 Self {
531 adjacency: Vec::new(),
532 data: Vec::new(),
533 config,
534 pruning_threshold: 0.9, }
536 }
537
538 pub fn build(&mut self) -> Result<()> {
539 if self.data.is_empty() {
540 return Ok(());
541 }
542
543 self.adjacency = vec![Vec::new(); self.data.len()];
545 self.build_initial_graph()?;
546
547 if self.config.enable_pruning {
549 self.prune_graph()?;
550 }
551
552 Ok(())
553 }
554
555 fn build_initial_graph(&mut self) -> Result<()> {
556 let initial_neighbors = self.config.num_neighbors * 2;
558
559 for idx in 0..self.data.len() {
560 let neighbors = self.find_k_nearest(idx, initial_neighbors)?;
561 self.adjacency[idx] = neighbors;
562 }
563
564 Ok(())
565 }
566
567 fn find_k_nearest(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
568 let query = &self.data[idx].1.as_f32();
569 let mut heap = BinaryHeap::new();
570
571 for (other_idx, (_, vector)) in self.data.iter().enumerate() {
572 if other_idx == idx {
573 continue;
574 }
575
576 let distance = self
577 .config
578 .distance_metric
579 .distance(query, &vector.as_f32());
580
581 if heap.len() < k {
582 heap.push(SearchResult {
583 index: other_idx,
584 distance,
585 });
586 } else if distance < heap.peek().expect("heap should have k elements").distance {
587 heap.pop();
588 heap.push(SearchResult {
589 index: other_idx,
590 distance,
591 });
592 }
593 }
594
595 Ok(heap
596 .into_sorted_vec()
597 .into_iter()
598 .map(|r| (r.index, r.distance))
599 .collect())
600 }
601
602 fn prune_graph(&mut self) -> Result<()> {
603 for idx in 0..self.data.len() {
604 let pruned = self.prune_neighbors(idx)?;
605 self.adjacency[idx] = pruned;
606 }
607
608 Ok(())
609 }
610
611 fn prune_neighbors(&self, idx: usize) -> Result<Vec<(usize, f32)>> {
612 let neighbors = &self.adjacency[idx];
613 if neighbors.len() <= self.config.num_neighbors {
614 return Ok(neighbors.clone());
615 }
616
617 let mut pruned = Vec::new();
618 let (_, vector) = &self.data[idx];
619 let query = vector.as_f32();
620
621 for &(neighbor_idx, distance) in neighbors {
622 let (_, vector) = &self.data[neighbor_idx];
623 let neighbor = vector.as_f32();
624 let mut keep = true;
625
626 for &(selected_idx, _) in &pruned {
628 let (_id, vector): &(String, Vector) = &self.data[selected_idx];
629 let selected = vector.as_f32();
630
631 let angle = self.calculate_angle(&query, &neighbor, &selected);
633
634 if angle < self.pruning_threshold {
635 keep = false;
636 break;
637 }
638 }
639
640 if keep {
641 pruned.push((neighbor_idx, distance));
642
643 if pruned.len() >= self.config.num_neighbors {
644 break;
645 }
646 }
647 }
648
649 Ok(pruned)
650 }
651
652 fn calculate_angle(&self, origin: &[f32], a: &[f32], b: &[f32]) -> f32 {
653 let va: Vec<f32> = a
655 .iter()
656 .zip(origin.iter())
657 .map(|(ai, oi)| ai - oi)
658 .collect();
659 let vb: Vec<f32> = b
660 .iter()
661 .zip(origin.iter())
662 .map(|(bi, oi)| bi - oi)
663 .collect();
664
665 let dot = f32::dot(&va, &vb);
667 let norm_a = f32::norm(&va);
668 let norm_b = f32::norm(&vb);
669
670 if norm_a == 0.0 || norm_b == 0.0 {
671 return 0.0;
672 }
673
674 (dot / (norm_a * norm_b)).clamp(-1.0, 1.0).acos()
675 }
676
677 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
678 if self.data.is_empty() {
679 return Vec::new();
680 }
681
682 let mut visited = HashSet::new();
683 let mut candidates = VecDeque::new();
684 let mut results = Vec::new();
685
686 let start = self.find_closest_point(query);
688 candidates.push_back(start);
689 visited.insert(start);
690
691 while let Some(current) = candidates.pop_front() {
692 let distance = self
693 .config
694 .distance_metric
695 .distance(query, &self.data[current].1.as_f32());
696 results.push((current, distance));
697
698 for &(neighbor_idx, _) in &self.adjacency[current] {
700 if !visited.contains(&neighbor_idx) {
701 visited.insert(neighbor_idx);
702 candidates.push_back(neighbor_idx);
703 }
704 }
705
706 if results.len() >= k * 2 {
707 break;
708 }
709 }
710
711 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
712 results.truncate(k);
713 results
714 }
715
716 fn find_closest_point(&self, query: &[f32]) -> usize {
717 let mut min_dist = f32::INFINITY;
718 let mut closest = 0;
719
720 let sample_size = (self.data.len() as f32).sqrt() as usize;
722 let step = self.data.len() / sample_size.max(1);
723
724 for idx in (0..self.data.len()).step_by(step.max(1)) {
725 let distance = self
726 .config
727 .distance_metric
728 .distance(query, &self.data[idx].1.as_f32());
729 if distance < min_dist {
730 min_dist = distance;
731 closest = idx;
732 }
733 }
734
735 closest
736 }
737}
738
739pub struct DelaunayGraph {
741 edges: Vec<Vec<(usize, f32)>>,
743 data: Vec<(String, Vector)>,
745 config: GraphIndexConfig,
747}
748
749impl DelaunayGraph {
750 pub fn new(config: GraphIndexConfig) -> Self {
751 Self {
752 edges: Vec::new(),
753 data: Vec::new(),
754 config,
755 }
756 }
757
758 pub fn build(&mut self) -> Result<()> {
759 if self.data.is_empty() {
760 return Ok(());
761 }
762
763 self.edges = vec![Vec::new(); self.data.len()];
764
765 for idx in 0..self.data.len() {
767 let neighbors = self.find_delaunay_neighbors(idx)?;
768 self.edges[idx] = neighbors;
769 }
770
771 self.symmetrize_edges();
773
774 Ok(())
775 }
776
777 fn find_delaunay_neighbors(&self, idx: usize) -> Result<Vec<(usize, f32)>> {
778 let point = &self.data[idx].1.as_f32();
779 let mut candidates = Vec::new();
780
781 for (other_idx, (_, other_vec)) in self.data.iter().enumerate() {
783 if other_idx == idx {
784 continue;
785 }
786
787 let other = other_vec.as_f32();
788 let distance = self.config.distance_metric.distance(point, &other);
789 candidates.push((other_idx, distance));
790 }
791
792 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
793
794 let mut neighbors = Vec::new();
796
797 for &(candidate_idx, distance) in &candidates {
798 if neighbors.len() >= self.config.num_neighbors {
799 break;
800 }
801
802 let candidate = &self.data[candidate_idx].1.as_f32();
803 let mut is_neighbor = true;
804
805 for &(neighbor_idx, _) in &neighbors {
807 let (_id, vector): &(String, Vector) = &self.data[neighbor_idx];
808 let neighbor = vector.as_f32();
809
810 let dist_to_neighbor = self.config.distance_metric.distance(candidate, &neighbor);
812 if dist_to_neighbor < distance * 0.9 {
813 is_neighbor = false;
814 break;
815 }
816 }
817
818 if is_neighbor {
819 neighbors.push((candidate_idx, distance));
820 }
821 }
822
823 Ok(neighbors)
824 }
825
826 fn symmetrize_edges(&mut self) {
827 let mut symmetric_edges = vec![Vec::new(); self.data.len()];
828
829 for (idx, neighbors) in self.edges.iter().enumerate() {
831 for &(neighbor_idx, distance) in neighbors {
832 symmetric_edges[idx].push((neighbor_idx, distance));
833 symmetric_edges[neighbor_idx].push((idx, distance));
834 }
835 }
836
837 for edges in &mut symmetric_edges {
839 edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
840 edges.dedup_by_key(|&mut (idx, _)| idx);
841 edges.truncate(self.config.num_neighbors);
842 }
843
844 self.edges = symmetric_edges;
845 }
846
847 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
848 if self.data.is_empty() {
849 return Vec::new();
850 }
851
852 let mut visited = HashSet::new();
853 let mut heap = BinaryHeap::new();
854 let mut results = Vec::new();
855
856 let start = 0;
858 let distance = self
859 .config
860 .distance_metric
861 .distance(query, &self.data[start].1.as_f32());
862 heap.push(std::cmp::Reverse(SearchResult {
863 index: start,
864 distance,
865 }));
866 visited.insert(start);
867
868 while let Some(std::cmp::Reverse(current)) = heap.pop() {
869 results.push((current.index, current.distance));
870
871 if results.len() >= k {
872 break;
873 }
874
875 for &(neighbor_idx, _) in &self.edges[current.index] {
877 if !visited.contains(&neighbor_idx) {
878 visited.insert(neighbor_idx);
879 let distance = self
880 .config
881 .distance_metric
882 .distance(query, &self.data[neighbor_idx].1.as_f32());
883 heap.push(std::cmp::Reverse(SearchResult {
884 index: neighbor_idx,
885 distance,
886 }));
887 }
888 }
889 }
890
891 results
892 }
893}
894
895pub struct RNGGraph {
897 edges: Vec<Vec<(usize, f32)>>,
899 data: Vec<(String, Vector)>,
901 config: GraphIndexConfig,
903}
904
905impl RNGGraph {
906 pub fn new(config: GraphIndexConfig) -> Self {
907 Self {
908 edges: Vec::new(),
909 data: Vec::new(),
910 config,
911 }
912 }
913
914 pub fn build(&mut self) -> Result<()> {
915 if self.data.is_empty() {
916 return Ok(());
917 }
918
919 self.edges = vec![Vec::new(); self.data.len()];
920
921 for i in 0..self.data.len() {
923 for j in i + 1..self.data.len() {
924 if self.is_rng_edge(i, j)? {
925 let distance = self
926 .config
927 .distance_metric
928 .distance(&self.data[i].1.as_f32(), &self.data[j].1.as_f32());
929
930 self.edges[i].push((j, distance));
931 self.edges[j].push((i, distance));
932 }
933 }
934 }
935
936 for edges in &mut self.edges {
938 edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
939 }
940
941 Ok(())
942 }
943
944 fn is_rng_edge(&self, i: usize, j: usize) -> Result<bool> {
945 let pi = &self.data[i].1.as_f32();
946 let pj = &self.data[j].1.as_f32();
947 let dist_ij = self.config.distance_metric.distance(pi, pj);
948
949 for k in 0..self.data.len() {
952 if k == i || k == j {
953 continue;
954 }
955
956 let pk = &self.data[k].1.as_f32();
957 let dist_ik = self.config.distance_metric.distance(pi, pk);
958 let dist_jk = self.config.distance_metric.distance(pj, pk);
959
960 if dist_ik.max(dist_jk) < dist_ij {
961 return Ok(false);
962 }
963 }
964
965 Ok(true)
966 }
967
968 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
969 if self.data.is_empty() {
970 return Vec::new();
971 }
972
973 let mut visited = HashSet::new();
974 let mut candidates = BinaryHeap::new();
975 let mut results = Vec::new();
976
977 let start = self.find_start_point(query);
979 let distance = self
980 .config
981 .distance_metric
982 .distance(query, &self.data[start].1.as_f32());
983 candidates.push(std::cmp::Reverse(SearchResult {
984 index: start,
985 distance,
986 }));
987 visited.insert(start);
988
989 while let Some(std::cmp::Reverse(current)) = candidates.pop() {
990 results.push((current.index, current.distance));
991
992 if results.len() >= k {
993 break;
994 }
995
996 for &(neighbor_idx, _) in &self.edges[current.index] {
998 if !visited.contains(&neighbor_idx) {
999 visited.insert(neighbor_idx);
1000 let distance = self
1001 .config
1002 .distance_metric
1003 .distance(query, &self.data[neighbor_idx].1.as_f32());
1004 candidates.push(std::cmp::Reverse(SearchResult {
1005 index: neighbor_idx,
1006 distance,
1007 }));
1008 }
1009 }
1010 }
1011
1012 results
1013 }
1014
1015 fn find_start_point(&self, query: &[f32]) -> usize {
1016 let sample_size = (self.data.len() as f32).sqrt() as usize;
1018 let mut min_dist = f32::INFINITY;
1019 let mut best = 0;
1020
1021 for i in 0..sample_size.min(self.data.len()) {
1022 let idx = (i * self.data.len()) / sample_size;
1023 let distance = self
1024 .config
1025 .distance_metric
1026 .distance(query, &self.data[idx].1.as_f32());
1027
1028 if distance < min_dist {
1029 min_dist = distance;
1030 best = idx;
1031 }
1032 }
1033
1034 best
1035 }
1036}
1037
1038pub struct GraphIndex {
1040 graph_type: GraphType,
1041 nsw: Option<NSWGraph>,
1042 onng: Option<ONNGGraph>,
1043 panng: Option<PANNGGraph>,
1044 delaunay: Option<DelaunayGraph>,
1045 rng: Option<RNGGraph>,
1046}
1047
1048impl GraphIndex {
1049 pub fn new(config: GraphIndexConfig) -> Self {
1050 let graph_type = config.graph_type;
1051
1052 let (nsw, onng, panng, delaunay, rng) = match graph_type {
1053 GraphType::NSW => (Some(NSWGraph::new(config)), None, None, None, None),
1054 GraphType::ONNG => (None, Some(ONNGGraph::new(config)), None, None, None),
1055 GraphType::PANNG => (None, None, Some(PANNGGraph::new(config)), None, None),
1056 GraphType::Delaunay => (None, None, None, Some(DelaunayGraph::new(config)), None),
1057 GraphType::RNG => (None, None, None, None, Some(RNGGraph::new(config))),
1058 };
1059
1060 Self {
1061 graph_type,
1062 nsw,
1063 onng,
1064 panng,
1065 delaunay,
1066 rng,
1067 }
1068 }
1069
1070 fn build(&mut self) -> Result<()> {
1071 match self.graph_type {
1072 GraphType::NSW => self
1073 .nsw
1074 .as_mut()
1075 .expect("nsw should be initialized for NSW type")
1076 .build(),
1077 GraphType::ONNG => self
1078 .onng
1079 .as_mut()
1080 .expect("onng should be initialized for ONNG type")
1081 .build(),
1082 GraphType::PANNG => self
1083 .panng
1084 .as_mut()
1085 .expect("panng should be initialized for PANNG type")
1086 .build(),
1087 GraphType::Delaunay => self
1088 .delaunay
1089 .as_mut()
1090 .expect("delaunay should be initialized for Delaunay type")
1091 .build(),
1092 GraphType::RNG => self
1093 .rng
1094 .as_mut()
1095 .expect("rng should be initialized for RNG type")
1096 .build(),
1097 }
1098 }
1099
1100 fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1101 match self.graph_type {
1102 GraphType::NSW => self
1103 .nsw
1104 .as_ref()
1105 .expect("nsw should be initialized for NSW type")
1106 .search(query, k),
1107 GraphType::ONNG => self
1108 .onng
1109 .as_ref()
1110 .expect("onng should be initialized for ONNG type")
1111 .search(query, k),
1112 GraphType::PANNG => self
1113 .panng
1114 .as_ref()
1115 .expect("panng should be initialized for PANNG type")
1116 .search(query, k),
1117 GraphType::Delaunay => self
1118 .delaunay
1119 .as_ref()
1120 .expect("delaunay should be initialized for Delaunay type")
1121 .search(query, k),
1122 GraphType::RNG => self
1123 .rng
1124 .as_ref()
1125 .expect("rng should be initialized for RNG type")
1126 .search(query, k),
1127 }
1128 }
1129}
1130
1131impl VectorIndex for GraphIndex {
1132 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1133 let data = match self.graph_type {
1134 GraphType::NSW => {
1135 &mut self
1136 .nsw
1137 .as_mut()
1138 .expect("nsw should be initialized for NSW type")
1139 .data
1140 }
1141 GraphType::ONNG => {
1142 &mut self
1143 .onng
1144 .as_mut()
1145 .expect("onng should be initialized for ONNG type")
1146 .data
1147 }
1148 GraphType::PANNG => {
1149 &mut self
1150 .panng
1151 .as_mut()
1152 .expect("panng should be initialized for PANNG type")
1153 .data
1154 }
1155 GraphType::Delaunay => {
1156 &mut self
1157 .delaunay
1158 .as_mut()
1159 .expect("delaunay should be initialized for Delaunay type")
1160 .data
1161 }
1162 GraphType::RNG => {
1163 &mut self
1164 .rng
1165 .as_mut()
1166 .expect("rng should be initialized for RNG type")
1167 .data
1168 }
1169 };
1170
1171 data.push((uri, vector));
1172 Ok(())
1173 }
1174
1175 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1176 let query_f32 = query.as_f32();
1177 let results = self.search_internal(&query_f32, k);
1178
1179 let data = match self.graph_type {
1180 GraphType::NSW => {
1181 &self
1182 .nsw
1183 .as_ref()
1184 .expect("nsw should be initialized for NSW type")
1185 .data
1186 }
1187 GraphType::ONNG => {
1188 &self
1189 .onng
1190 .as_ref()
1191 .expect("onng should be initialized for ONNG type")
1192 .data
1193 }
1194 GraphType::PANNG => {
1195 &self
1196 .panng
1197 .as_ref()
1198 .expect("panng should be initialized for PANNG type")
1199 .data
1200 }
1201 GraphType::Delaunay => {
1202 &self
1203 .delaunay
1204 .as_ref()
1205 .expect("delaunay should be initialized for Delaunay type")
1206 .data
1207 }
1208 GraphType::RNG => {
1209 &self
1210 .rng
1211 .as_ref()
1212 .expect("rng should be initialized for RNG type")
1213 .data
1214 }
1215 };
1216
1217 Ok(results
1218 .into_iter()
1219 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1220 .collect())
1221 }
1222
1223 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1224 let query_f32 = query.as_f32();
1225 let all_results = self.search_internal(&query_f32, 1000);
1226
1227 let data = match self.graph_type {
1228 GraphType::NSW => {
1229 &self
1230 .nsw
1231 .as_ref()
1232 .expect("nsw should be initialized for NSW type")
1233 .data
1234 }
1235 GraphType::ONNG => {
1236 &self
1237 .onng
1238 .as_ref()
1239 .expect("onng should be initialized for ONNG type")
1240 .data
1241 }
1242 GraphType::PANNG => {
1243 &self
1244 .panng
1245 .as_ref()
1246 .expect("panng should be initialized for PANNG type")
1247 .data
1248 }
1249 GraphType::Delaunay => {
1250 &self
1251 .delaunay
1252 .as_ref()
1253 .expect("delaunay should be initialized for Delaunay type")
1254 .data
1255 }
1256 GraphType::RNG => {
1257 &self
1258 .rng
1259 .as_ref()
1260 .expect("rng should be initialized for RNG type")
1261 .data
1262 }
1263 };
1264
1265 Ok(all_results
1266 .into_iter()
1267 .filter(|(_, dist)| *dist <= threshold)
1268 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1269 .collect())
1270 }
1271
1272 fn get_vector(&self, uri: &str) -> Option<&Vector> {
1273 let data = match self.graph_type {
1274 GraphType::NSW => {
1275 &self
1276 .nsw
1277 .as_ref()
1278 .expect("nsw should be initialized for NSW type")
1279 .data
1280 }
1281 GraphType::ONNG => {
1282 &self
1283 .onng
1284 .as_ref()
1285 .expect("onng should be initialized for ONNG type")
1286 .data
1287 }
1288 GraphType::PANNG => {
1289 &self
1290 .panng
1291 .as_ref()
1292 .expect("panng should be initialized for PANNG type")
1293 .data
1294 }
1295 GraphType::Delaunay => {
1296 &self
1297 .delaunay
1298 .as_ref()
1299 .expect("delaunay should be initialized for Delaunay type")
1300 .data
1301 }
1302 GraphType::RNG => {
1303 &self
1304 .rng
1305 .as_ref()
1306 .expect("rng should be initialized for RNG type")
1307 .data
1308 }
1309 };
1310
1311 data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1312 }
1313}
1314
1315use petgraph;
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts `points` as `vec_0`, `vec_1`, ... into a fresh index and builds it.
    fn built_index(config: GraphIndexConfig, points: Vec<Vec<f32>>) -> GraphIndex {
        let mut index = GraphIndex::new(config);
        for (i, coords) in points.into_iter().enumerate() {
            index.insert(format!("vec_{i}"), Vector::new(coords)).unwrap();
        }
        index.build().unwrap();
        index
    }

    #[test]
    fn test_nsw_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::NSW,
            num_neighbors: 10,
            ..Default::default()
        };
        // Points along the ray through (1, 2, 3); the query coincides with vec_25.
        let points = (0..50)
            .map(|i| vec![i as f32, (i * 2) as f32, (i * 3) as f32])
            .collect();
        let index = built_index(config, points);

        let hits = index
            .search_knn(&Vector::new(vec![25.0, 50.0, 75.0]), 5)
            .unwrap();

        assert_eq!(hits.len(), 5);
        assert_eq!(hits[0].0, "vec_25");
    }

    #[test]
    fn test_onng_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::ONNG,
            num_neighbors: 8,
            ..Default::default()
        };
        // Twenty points evenly spaced on the unit circle.
        let points = (0..20)
            .map(|i| {
                let angle = (i as f32) * 2.0 * std::f32::consts::PI / 20.0;
                vec![angle.cos(), angle.sin()]
            })
            .collect();
        let index = built_index(config, points);

        let hits = index.search_knn(&Vector::new(vec![1.0, 0.0]), 3).unwrap();

        assert_eq!(hits.len(), 3);
    }

    #[test]
    fn test_panng_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::PANNG,
            num_neighbors: 5,
            enable_pruning: true,
            ..Default::default()
        };
        let points = (0..30)
            .map(|i| {
                let t = i as f32;
                vec![t.sin(), t.cos(), t / 10.0]
            })
            .collect();
        let index = built_index(config, points);

        let hits = index
            .search_knn(&Vector::new(vec![0.0, 1.0, 0.0]), 5)
            .unwrap();

        assert_eq!(hits.len(), 5);
    }
}
1400}