1use crate::{Vector, VectorIndex};
12use anyhow::Result;
13use oxirs_core::parallel::*;
14use oxirs_core::simd::SimdOps;
15use petgraph::graph::{Graph, NodeIndex};
16use scirs2_core::random::{Random, Rng};
17use std::cmp::Ordering;
18use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
19
/// Configuration shared by every graph-based index defined in this module.
#[derive(Debug, Clone)]
pub struct GraphIndexConfig {
    /// Which graph structure `GraphIndex::new` instantiates.
    pub graph_type: GraphType,
    /// Target number of neighbors kept per node by the builders.
    pub num_neighbors: usize,
    /// Seed for the pseudo-random shuffles; when `None` the builders fall back to the fixed seed 42.
    pub random_seed: Option<u64>,
    /// Requests the parallel construction path (consulted by the NSW builder, and only when it
    /// holds more than 1000 vectors).
    pub parallel_construction: bool,
    /// Distance function used for both construction and search.
    pub distance_metric: DistanceMetric,
    /// Enables the pruning pass (consulted by the PANNG builder).
    pub enable_pruning: bool,
    /// Multiplier applied to `k` to size the candidate pool of the NSW search.
    pub search_expansion: f32,
}
38
39impl Default for GraphIndexConfig {
40 fn default() -> Self {
41 Self {
42 graph_type: GraphType::NSW,
43 num_neighbors: 32,
44 random_seed: None,
45 parallel_construction: true,
46 distance_metric: DistanceMetric::Euclidean,
47 enable_pruning: true,
48 search_expansion: 1.5,
49 }
50 }
51}
52
/// Selects which graph structure backs a `GraphIndex`.
///
/// The type is a plain tag, so it is cheap to copy and can be compared or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GraphType {
    /// Backed by `NSWGraph`.
    NSW,
    /// Backed by `ONNGGraph`.
    ONNG,
    /// Backed by `PANNGGraph`.
    PANNG,
    /// Backed by `DelaunayGraph`.
    Delaunay,
    /// Backed by `RNGGraph`.
    RNG,
}
62
/// Distance function used when building and searching a graph index.
///
/// The type is a plain tag, so it is cheap to copy and can be compared or used as a map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Evaluated with `f32::euclidean_distance`.
    Euclidean,
    /// Evaluated with `f32::manhattan_distance`.
    Manhattan,
    /// Evaluated with `f32::cosine_distance`.
    Cosine,
    /// Arc-cosine of the cosine similarity, divided by pi.
    Angular,
}
71
72impl DistanceMetric {
73 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
74 match self {
75 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
76 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
77 DistanceMetric::Cosine => f32::cosine_distance(a, b),
78 DistanceMetric::Angular => {
79 let cos_sim: f32 = 1.0 - f32::cosine_distance(a, b);
81 cos_sim.clamp(-1.0, 1.0).acos() / std::f32::consts::PI
82 }
83 }
84 }
85}
86
/// A scored candidate: the position of a vector in the index's data plus its distance to the
/// query (or to the node currently being linked).
///
/// Ordering and equality look at `distance` only, so a `BinaryHeap<SearchResult>` is a max-heap
/// on distance (farthest on top) and `Reverse<SearchResult>` gives a min-heap.
#[derive(Debug, Clone)]
struct SearchResult {
    index: usize,
    distance: f32,
}

impl PartialEq for SearchResult {
    /// Two results are equal exactly when [`Ord::cmp`] says so, which keeps `Eq` reflexive even
    /// for NaN distances.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    /// Total order by distance with every NaN sorted after all real distances.
    ///
    /// Treating NaN as "equal to everything" (the result of `partial_cmp(..).unwrap_or(Equal)`)
    /// is not transitive and leaves heap/sort behaviour unspecified; ranking NaN as the farthest
    /// possible value keeps the order total while leaving non-NaN comparisons unchanged.
    fn cmp(&self, other: &Self) -> Ordering {
        match (self.distance.is_nan(), other.distance.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Both operands are real numbers here, so `partial_cmp` always yields `Some`.
            (false, false) => self
                .distance
                .partial_cmp(&other.distance)
                .unwrap_or(Ordering::Equal),
        }
    }
}
115
/// Graph index built on a `petgraph` graph whose node weights are positions in `data` and whose
/// edge weights are distances between the connected vectors.
pub struct NSWGraph {
    /// The neighbor graph; node weight = position in `data`, edge weight = distance.
    graph: Graph<usize, f32>,
    /// Maps a position in `data` to its node in `graph`.
    node_map: HashMap<usize, NodeIndex>,
    /// The indexed `(identifier, vector)` pairs, in insertion order.
    data: Vec<(String, Vector)>,
    /// Construction and search parameters.
    config: GraphIndexConfig,
    /// Nodes from which every search starts; chosen by `build`.
    entry_points: Vec<NodeIndex>,
}
129
130impl NSWGraph {
131 pub fn new(config: GraphIndexConfig) -> Self {
132 Self {
133 graph: Graph::new(),
134 node_map: HashMap::new(),
135 data: Vec::new(),
136 config,
137 entry_points: Vec::new(),
138 }
139 }
140
141 pub fn build(&mut self) -> Result<()> {
143 if self.data.is_empty() {
144 return Ok(());
145 }
146
147 for (idx, _) in self.data.iter().enumerate() {
149 let node = self.graph.add_node(idx);
150 self.node_map.insert(idx, node);
151 }
152
153 let num_entry_points = (self.data.len() as f32).sqrt() as usize;
155 let mut rng = if let Some(seed) = self.config.random_seed {
156 Random::seed(seed)
157 } else {
158 Random::seed(42)
159 };
160
161 let mut indices: Vec<usize> = (0..self.data.len()).collect();
163 for i in (1..indices.len()).rev() {
165 let j = rng.gen_range(0..=i);
166 indices.swap(i, j);
167 }
168
169 self.entry_points = indices[..num_entry_points.min(self.data.len())]
170 .iter()
171 .map(|&idx| self.node_map[&idx])
172 .collect();
173
174 if self.config.parallel_construction && self.data.len() > 1000 {
176 self.build_parallel()?;
177 } else {
178 self.build_sequential()?;
179 }
180
181 Ok(())
182 }
183
184 fn build_sequential(&mut self) -> Result<()> {
185 for idx in 0..self.data.len() {
186 let neighbors = self.find_neighbors(idx, self.config.num_neighbors)?;
187 let node = self.node_map[&idx];
188
189 for (neighbor_idx, distance) in neighbors {
190 let neighbor_node = self.node_map[&neighbor_idx];
191 if !self.graph.contains_edge(node, neighbor_node) {
192 self.graph.add_edge(node, neighbor_node, distance);
193 }
194 }
195 }
196
197 Ok(())
198 }
199
200 fn build_parallel(&mut self) -> Result<()> {
201 let _chunk_size = (self.data.len() / num_threads()).max(100);
202
203 let mut all_edges = Vec::new();
205 for idx in 0..self.data.len() {
206 let neighbors = self.find_neighbors(idx, self.config.num_neighbors)?;
207 let node = self.node_map[&idx];
208
209 for (neighbor_idx, distance) in neighbors {
210 let neighbor_node = self.node_map[&neighbor_idx];
211 all_edges.push((node, neighbor_node, distance));
212 }
213 }
214
215 for (from, to, weight) in all_edges {
217 if !self.graph.contains_edge(from, to) {
218 self.graph.add_edge(from, to, weight);
219 }
220 }
221
222 Ok(())
223 }
224
225 fn find_neighbors(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
226 let query = &self.data[idx].1.as_f32();
227 let mut heap = BinaryHeap::new();
228
229 for (other_idx, (_, vector)) in self.data.iter().enumerate() {
230 if other_idx == idx {
231 continue;
232 }
233
234 let other = vector.as_f32();
235 let distance = self.config.distance_metric.distance(query, &other);
236
237 if heap.len() < k {
238 heap.push(SearchResult {
239 index: other_idx,
240 distance,
241 });
242 } else if distance < heap.peek().unwrap().distance {
243 heap.pop();
244 heap.push(SearchResult {
245 index: other_idx,
246 distance,
247 });
248 }
249 }
250
251 Ok(heap.into_iter().map(|r| (r.index, r.distance)).collect())
252 }
253
254 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
256 if self.entry_points.is_empty() {
257 return Vec::new();
258 }
259
260 let mut visited = HashSet::new();
261 let mut candidates = BinaryHeap::new();
262 let mut results: BinaryHeap<SearchResult> = BinaryHeap::new();
263
264 for &entry in &self.entry_points {
266 let idx = self.graph[entry];
267 let distance = self
268 .config
269 .distance_metric
270 .distance(query, &self.data[idx].1.as_f32());
271 candidates.push(std::cmp::Reverse(SearchResult {
272 index: idx,
273 distance,
274 }));
275 visited.insert(idx);
276 }
277
278 let max_candidates = (k as f32 * self.config.search_expansion) as usize;
280
281 while let Some(std::cmp::Reverse(current)) = candidates.pop() {
282 if results.len() >= k && current.distance > results.peek().unwrap().distance {
284 break;
285 }
286
287 if results.len() < k {
289 results.push(current.clone());
290 } else if current.distance < results.peek().unwrap().distance {
291 results.pop();
292 results.push(current.clone());
293 }
294
295 let node = self.node_map[¤t.index];
297 for neighbor in self.graph.neighbors(node) {
298 let neighbor_idx = self.graph[neighbor];
299
300 if visited.contains(&neighbor_idx) {
301 continue;
302 }
303
304 visited.insert(neighbor_idx);
305 let distance = self
306 .config
307 .distance_metric
308 .distance(query, &self.data[neighbor_idx].1.as_f32());
309
310 if candidates.len() < max_candidates
311 || distance < candidates.peek().unwrap().0.distance
312 {
313 candidates.push(std::cmp::Reverse(SearchResult {
314 index: neighbor_idx,
315 distance,
316 }));
317 }
318 }
319 }
320
321 let mut results: Vec<(usize, f32)> =
322 results.into_iter().map(|r| (r.index, r.distance)).collect();
323
324 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
325 results
326 }
327}
328
/// Graph index stored as plain adjacency lists: a k-nearest-neighbor graph that is afterwards
/// merged with its reverse edges.
pub struct ONNGGraph {
    /// `adjacency[i]` lists `(position, distance)` neighbors of the vector at position `i`.
    adjacency: Vec<Vec<(usize, f32)>>,
    /// The indexed `(identifier, vector)` pairs, in insertion order.
    data: Vec<(String, Vector)>,
    /// Construction and search parameters.
    config: GraphIndexConfig,
}
338
339impl ONNGGraph {
340 pub fn new(config: GraphIndexConfig) -> Self {
341 Self {
342 adjacency: Vec::new(),
343 data: Vec::new(),
344 config,
345 }
346 }
347
348 pub fn build(&mut self) -> Result<()> {
349 if self.data.is_empty() {
350 return Ok(());
351 }
352
353 self.adjacency = vec![Vec::new(); self.data.len()];
355
356 self.build_knn_graph()?;
358
359 self.optimize_graph()?;
361
362 Ok(())
363 }
364
365 fn build_knn_graph(&mut self) -> Result<()> {
366 for idx in 0..self.data.len() {
367 let neighbors = self.find_k_nearest(idx, self.config.num_neighbors)?;
368 self.adjacency[idx] = neighbors;
369 }
370
371 Ok(())
372 }
373
374 fn find_k_nearest(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
375 let query = &self.data[idx].1.as_f32();
376 let mut neighbors = Vec::new();
377
378 for (other_idx, (_, vector)) in self.data.iter().enumerate() {
379 if other_idx == idx {
380 continue;
381 }
382
383 let distance = self
384 .config
385 .distance_metric
386 .distance(query, &vector.as_f32());
387 neighbors.push((other_idx, distance));
388 }
389
390 neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
391 neighbors.truncate(k);
392
393 Ok(neighbors)
394 }
395
396 fn optimize_graph(&mut self) -> Result<()> {
397 let mut reverse_edges = vec![Vec::new(); self.data.len()];
399
400 for (idx, neighbors) in self.adjacency.iter().enumerate() {
401 for &(neighbor_idx, distance) in neighbors {
402 reverse_edges[neighbor_idx].push((idx, distance));
403 }
404 }
405
406 for (idx, reverse) in reverse_edges.into_iter().enumerate() {
408 let mut all_neighbors = self.adjacency[idx].clone();
409 all_neighbors.extend(reverse);
410
411 all_neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
413 all_neighbors.dedup_by_key(|&mut (idx, _)| idx);
414 all_neighbors.truncate(self.config.num_neighbors);
415
416 self.adjacency[idx] = all_neighbors;
417 }
418
419 Ok(())
420 }
421
422 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
423 if self.data.is_empty() {
424 return Vec::new();
425 }
426
427 let start_points = self.select_start_points();
429 let mut visited = HashSet::new();
430 let mut heap = BinaryHeap::new();
431
432 for start in start_points {
434 let distance = self
435 .config
436 .distance_metric
437 .distance(query, &self.data[start].1.as_f32());
438 heap.push(std::cmp::Reverse(SearchResult {
439 index: start,
440 distance,
441 }));
442 visited.insert(start);
443 }
444
445 let mut results = Vec::new();
446
447 while let Some(std::cmp::Reverse(current)) = heap.pop() {
448 results.push((current.index, current.distance));
449
450 if results.len() >= k {
451 break;
452 }
453
454 for &(neighbor_idx, _) in &self.adjacency[current.index] {
456 if visited.contains(&neighbor_idx) {
457 continue;
458 }
459
460 visited.insert(neighbor_idx);
461 let distance = self
462 .config
463 .distance_metric
464 .distance(query, &self.data[neighbor_idx].1.as_f32());
465 heap.push(std::cmp::Reverse(SearchResult {
466 index: neighbor_idx,
467 distance,
468 }));
469 }
470 }
471
472 results.truncate(k);
473 results
474 }
475
476 fn select_start_points(&self) -> Vec<usize> {
477 let num_points = (self.data.len() as f32).sqrt() as usize;
479 let mut indices: Vec<usize> = (0..self.data.len()).collect();
480
481 let mut rng = if let Some(seed) = self.config.random_seed {
482 Random::seed(seed)
483 } else {
484 Random::seed(42)
485 };
486
487 for i in (1..indices.len()).rev() {
490 let j = rng.gen_range(0..=i);
491 indices.swap(i, j);
492 }
493 indices.truncate(num_points.max(1));
494
495 indices
496 }
497}
498
/// Graph index stored as adjacency lists: an over-provisioned k-nearest-neighbor graph whose
/// lists are optionally thinned by an angle-based pruning pass.
pub struct PANNGGraph {
    /// `adjacency[i]` lists `(position, distance)` neighbors of the vector at position `i`.
    adjacency: Vec<Vec<(usize, f32)>>,
    /// The indexed `(identifier, vector)` pairs, in insertion order.
    data: Vec<(String, Vector)>,
    /// Construction and search parameters.
    config: GraphIndexConfig,
    /// Minimum angle (as returned by `calculate_angle`, i.e. the `acos` result) that a candidate
    /// must keep from every already selected neighbor to survive pruning.
    pruning_threshold: f32,
}
510
511impl PANNGGraph {
512 pub fn new(config: GraphIndexConfig) -> Self {
513 Self {
514 adjacency: Vec::new(),
515 data: Vec::new(),
516 config,
517 pruning_threshold: 0.9, }
519 }
520
521 pub fn build(&mut self) -> Result<()> {
522 if self.data.is_empty() {
523 return Ok(());
524 }
525
526 self.adjacency = vec![Vec::new(); self.data.len()];
528 self.build_initial_graph()?;
529
530 if self.config.enable_pruning {
532 self.prune_graph()?;
533 }
534
535 Ok(())
536 }
537
538 fn build_initial_graph(&mut self) -> Result<()> {
539 let initial_neighbors = self.config.num_neighbors * 2;
541
542 for idx in 0..self.data.len() {
543 let neighbors = self.find_k_nearest(idx, initial_neighbors)?;
544 self.adjacency[idx] = neighbors;
545 }
546
547 Ok(())
548 }
549
550 fn find_k_nearest(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
551 let query = &self.data[idx].1.as_f32();
552 let mut heap = BinaryHeap::new();
553
554 for (other_idx, (_, vector)) in self.data.iter().enumerate() {
555 if other_idx == idx {
556 continue;
557 }
558
559 let distance = self
560 .config
561 .distance_metric
562 .distance(query, &vector.as_f32());
563
564 if heap.len() < k {
565 heap.push(SearchResult {
566 index: other_idx,
567 distance,
568 });
569 } else if distance < heap.peek().unwrap().distance {
570 heap.pop();
571 heap.push(SearchResult {
572 index: other_idx,
573 distance,
574 });
575 }
576 }
577
578 Ok(heap
579 .into_sorted_vec()
580 .into_iter()
581 .map(|r| (r.index, r.distance))
582 .collect())
583 }
584
585 fn prune_graph(&mut self) -> Result<()> {
586 for idx in 0..self.data.len() {
587 let pruned = self.prune_neighbors(idx)?;
588 self.adjacency[idx] = pruned;
589 }
590
591 Ok(())
592 }
593
594 fn prune_neighbors(&self, idx: usize) -> Result<Vec<(usize, f32)>> {
595 let neighbors = &self.adjacency[idx];
596 if neighbors.len() <= self.config.num_neighbors {
597 return Ok(neighbors.clone());
598 }
599
600 let mut pruned = Vec::new();
601 let (_, vector) = &self.data[idx];
602 let query = vector.as_f32();
603
604 for &(neighbor_idx, distance) in neighbors {
605 let (_, vector) = &self.data[neighbor_idx];
606 let neighbor = vector.as_f32();
607 let mut keep = true;
608
609 for &(selected_idx, _) in &pruned {
611 let (_id, vector): &(String, Vector) = &self.data[selected_idx];
612 let selected = vector.as_f32();
613
614 let angle = self.calculate_angle(&query, &neighbor, &selected);
616
617 if angle < self.pruning_threshold {
618 keep = false;
619 break;
620 }
621 }
622
623 if keep {
624 pruned.push((neighbor_idx, distance));
625
626 if pruned.len() >= self.config.num_neighbors {
627 break;
628 }
629 }
630 }
631
632 Ok(pruned)
633 }
634
635 fn calculate_angle(&self, origin: &[f32], a: &[f32], b: &[f32]) -> f32 {
636 let va: Vec<f32> = a
638 .iter()
639 .zip(origin.iter())
640 .map(|(ai, oi)| ai - oi)
641 .collect();
642 let vb: Vec<f32> = b
643 .iter()
644 .zip(origin.iter())
645 .map(|(bi, oi)| bi - oi)
646 .collect();
647
648 let dot = f32::dot(&va, &vb);
650 let norm_a = f32::norm(&va);
651 let norm_b = f32::norm(&vb);
652
653 if norm_a == 0.0 || norm_b == 0.0 {
654 return 0.0;
655 }
656
657 (dot / (norm_a * norm_b)).clamp(-1.0, 1.0).acos()
658 }
659
660 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
661 if self.data.is_empty() {
662 return Vec::new();
663 }
664
665 let mut visited = HashSet::new();
666 let mut candidates = VecDeque::new();
667 let mut results = Vec::new();
668
669 let start = self.find_closest_point(query);
671 candidates.push_back(start);
672 visited.insert(start);
673
674 while let Some(current) = candidates.pop_front() {
675 let distance = self
676 .config
677 .distance_metric
678 .distance(query, &self.data[current].1.as_f32());
679 results.push((current, distance));
680
681 for &(neighbor_idx, _) in &self.adjacency[current] {
683 if !visited.contains(&neighbor_idx) {
684 visited.insert(neighbor_idx);
685 candidates.push_back(neighbor_idx);
686 }
687 }
688
689 if results.len() >= k * 2 {
690 break;
691 }
692 }
693
694 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
695 results.truncate(k);
696 results
697 }
698
699 fn find_closest_point(&self, query: &[f32]) -> usize {
700 let mut min_dist = f32::INFINITY;
701 let mut closest = 0;
702
703 let sample_size = (self.data.len() as f32).sqrt() as usize;
705 let step = self.data.len() / sample_size.max(1);
706
707 for idx in (0..self.data.len()).step_by(step.max(1)) {
708 let distance = self
709 .config
710 .distance_metric
711 .distance(query, &self.data[idx].1.as_f32());
712 if distance < min_dist {
713 min_dist = distance;
714 closest = idx;
715 }
716 }
717
718 closest
719 }
720}
721
/// Graph index stored as symmetric adjacency lists, built with a distance-based heuristic
/// (see `find_delaunay_neighbors`).
pub struct DelaunayGraph {
    /// `edges[i]` lists `(position, distance)` neighbors of the vector at position `i`.
    edges: Vec<Vec<(usize, f32)>>,
    /// The indexed `(identifier, vector)` pairs, in insertion order.
    data: Vec<(String, Vector)>,
    /// Construction and search parameters.
    config: GraphIndexConfig,
}
731
732impl DelaunayGraph {
733 pub fn new(config: GraphIndexConfig) -> Self {
734 Self {
735 edges: Vec::new(),
736 data: Vec::new(),
737 config,
738 }
739 }
740
741 pub fn build(&mut self) -> Result<()> {
742 if self.data.is_empty() {
743 return Ok(());
744 }
745
746 self.edges = vec![Vec::new(); self.data.len()];
747
748 for idx in 0..self.data.len() {
750 let neighbors = self.find_delaunay_neighbors(idx)?;
751 self.edges[idx] = neighbors;
752 }
753
754 self.symmetrize_edges();
756
757 Ok(())
758 }
759
760 fn find_delaunay_neighbors(&self, idx: usize) -> Result<Vec<(usize, f32)>> {
761 let point = &self.data[idx].1.as_f32();
762 let mut candidates = Vec::new();
763
764 for (other_idx, (_, other_vec)) in self.data.iter().enumerate() {
766 if other_idx == idx {
767 continue;
768 }
769
770 let other = other_vec.as_f32();
771 let distance = self.config.distance_metric.distance(point, &other);
772 candidates.push((other_idx, distance));
773 }
774
775 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
776
777 let mut neighbors = Vec::new();
779
780 for &(candidate_idx, distance) in &candidates {
781 if neighbors.len() >= self.config.num_neighbors {
782 break;
783 }
784
785 let candidate = &self.data[candidate_idx].1.as_f32();
786 let mut is_neighbor = true;
787
788 for &(neighbor_idx, _) in &neighbors {
790 let (_id, vector): &(String, Vector) = &self.data[neighbor_idx];
791 let neighbor = vector.as_f32();
792
793 let dist_to_neighbor = self.config.distance_metric.distance(candidate, &neighbor);
795 if dist_to_neighbor < distance * 0.9 {
796 is_neighbor = false;
797 break;
798 }
799 }
800
801 if is_neighbor {
802 neighbors.push((candidate_idx, distance));
803 }
804 }
805
806 Ok(neighbors)
807 }
808
809 fn symmetrize_edges(&mut self) {
810 let mut symmetric_edges = vec![Vec::new(); self.data.len()];
811
812 for (idx, neighbors) in self.edges.iter().enumerate() {
814 for &(neighbor_idx, distance) in neighbors {
815 symmetric_edges[idx].push((neighbor_idx, distance));
816 symmetric_edges[neighbor_idx].push((idx, distance));
817 }
818 }
819
820 for edges in &mut symmetric_edges {
822 edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
823 edges.dedup_by_key(|&mut (idx, _)| idx);
824 edges.truncate(self.config.num_neighbors);
825 }
826
827 self.edges = symmetric_edges;
828 }
829
830 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
831 if self.data.is_empty() {
832 return Vec::new();
833 }
834
835 let mut visited = HashSet::new();
836 let mut heap = BinaryHeap::new();
837 let mut results = Vec::new();
838
839 let start = 0;
841 let distance = self
842 .config
843 .distance_metric
844 .distance(query, &self.data[start].1.as_f32());
845 heap.push(std::cmp::Reverse(SearchResult {
846 index: start,
847 distance,
848 }));
849 visited.insert(start);
850
851 while let Some(std::cmp::Reverse(current)) = heap.pop() {
852 results.push((current.index, current.distance));
853
854 if results.len() >= k {
855 break;
856 }
857
858 for &(neighbor_idx, _) in &self.edges[current.index] {
860 if !visited.contains(&neighbor_idx) {
861 visited.insert(neighbor_idx);
862 let distance = self
863 .config
864 .distance_metric
865 .distance(query, &self.data[neighbor_idx].1.as_f32());
866 heap.push(std::cmp::Reverse(SearchResult {
867 index: neighbor_idx,
868 distance,
869 }));
870 }
871 }
872 }
873
874 results
875 }
876}
877
/// Graph index stored as symmetric adjacency lists whose edges satisfy the criterion checked by
/// `is_rng_edge`.
pub struct RNGGraph {
    /// `edges[i]` lists `(position, distance)` neighbors of the vector at position `i`,
    /// sorted by ascending distance after `build`.
    edges: Vec<Vec<(usize, f32)>>,
    /// The indexed `(identifier, vector)` pairs, in insertion order.
    data: Vec<(String, Vector)>,
    /// Construction and search parameters.
    config: GraphIndexConfig,
}
887
888impl RNGGraph {
889 pub fn new(config: GraphIndexConfig) -> Self {
890 Self {
891 edges: Vec::new(),
892 data: Vec::new(),
893 config,
894 }
895 }
896
897 pub fn build(&mut self) -> Result<()> {
898 if self.data.is_empty() {
899 return Ok(());
900 }
901
902 self.edges = vec![Vec::new(); self.data.len()];
903
904 for i in 0..self.data.len() {
906 for j in i + 1..self.data.len() {
907 if self.is_rng_edge(i, j)? {
908 let distance = self
909 .config
910 .distance_metric
911 .distance(&self.data[i].1.as_f32(), &self.data[j].1.as_f32());
912
913 self.edges[i].push((j, distance));
914 self.edges[j].push((i, distance));
915 }
916 }
917 }
918
919 for edges in &mut self.edges {
921 edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
922 }
923
924 Ok(())
925 }
926
927 fn is_rng_edge(&self, i: usize, j: usize) -> Result<bool> {
928 let pi = &self.data[i].1.as_f32();
929 let pj = &self.data[j].1.as_f32();
930 let dist_ij = self.config.distance_metric.distance(pi, pj);
931
932 for k in 0..self.data.len() {
935 if k == i || k == j {
936 continue;
937 }
938
939 let pk = &self.data[k].1.as_f32();
940 let dist_ik = self.config.distance_metric.distance(pi, pk);
941 let dist_jk = self.config.distance_metric.distance(pj, pk);
942
943 if dist_ik.max(dist_jk) < dist_ij {
944 return Ok(false);
945 }
946 }
947
948 Ok(true)
949 }
950
951 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
952 if self.data.is_empty() {
953 return Vec::new();
954 }
955
956 let mut visited = HashSet::new();
957 let mut candidates = BinaryHeap::new();
958 let mut results = Vec::new();
959
960 let start = self.find_start_point(query);
962 let distance = self
963 .config
964 .distance_metric
965 .distance(query, &self.data[start].1.as_f32());
966 candidates.push(std::cmp::Reverse(SearchResult {
967 index: start,
968 distance,
969 }));
970 visited.insert(start);
971
972 while let Some(std::cmp::Reverse(current)) = candidates.pop() {
973 results.push((current.index, current.distance));
974
975 if results.len() >= k {
976 break;
977 }
978
979 for &(neighbor_idx, _) in &self.edges[current.index] {
981 if !visited.contains(&neighbor_idx) {
982 visited.insert(neighbor_idx);
983 let distance = self
984 .config
985 .distance_metric
986 .distance(query, &self.data[neighbor_idx].1.as_f32());
987 candidates.push(std::cmp::Reverse(SearchResult {
988 index: neighbor_idx,
989 distance,
990 }));
991 }
992 }
993 }
994
995 results
996 }
997
998 fn find_start_point(&self, query: &[f32]) -> usize {
999 let sample_size = (self.data.len() as f32).sqrt() as usize;
1001 let mut min_dist = f32::INFINITY;
1002 let mut best = 0;
1003
1004 for i in 0..sample_size.min(self.data.len()) {
1005 let idx = (i * self.data.len()) / sample_size;
1006 let distance = self
1007 .config
1008 .distance_metric
1009 .distance(query, &self.data[idx].1.as_f32());
1010
1011 if distance < min_dist {
1012 min_dist = distance;
1013 best = idx;
1014 }
1015 }
1016
1017 best
1018 }
1019}
1020
/// Front-end that owns exactly one of the graph implementations and dispatches to it.
///
/// `GraphIndex::new` populates only the field that corresponds to `graph_type`; the other
/// four stay `None`.
pub struct GraphIndex {
    /// Which backend is active; decides which of the optional fields below is `Some`.
    graph_type: GraphType,
    nsw: Option<NSWGraph>,
    onng: Option<ONNGGraph>,
    panng: Option<PANNGGraph>,
    delaunay: Option<DelaunayGraph>,
    rng: Option<RNGGraph>,
}
1030
1031impl GraphIndex {
1032 pub fn new(config: GraphIndexConfig) -> Self {
1033 let graph_type = config.graph_type;
1034
1035 let (nsw, onng, panng, delaunay, rng) = match graph_type {
1036 GraphType::NSW => (Some(NSWGraph::new(config)), None, None, None, None),
1037 GraphType::ONNG => (None, Some(ONNGGraph::new(config)), None, None, None),
1038 GraphType::PANNG => (None, None, Some(PANNGGraph::new(config)), None, None),
1039 GraphType::Delaunay => (None, None, None, Some(DelaunayGraph::new(config)), None),
1040 GraphType::RNG => (None, None, None, None, Some(RNGGraph::new(config))),
1041 };
1042
1043 Self {
1044 graph_type,
1045 nsw,
1046 onng,
1047 panng,
1048 delaunay,
1049 rng,
1050 }
1051 }
1052
1053 fn build(&mut self) -> Result<()> {
1054 match self.graph_type {
1055 GraphType::NSW => self.nsw.as_mut().unwrap().build(),
1056 GraphType::ONNG => self.onng.as_mut().unwrap().build(),
1057 GraphType::PANNG => self.panng.as_mut().unwrap().build(),
1058 GraphType::Delaunay => self.delaunay.as_mut().unwrap().build(),
1059 GraphType::RNG => self.rng.as_mut().unwrap().build(),
1060 }
1061 }
1062
1063 fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1064 match self.graph_type {
1065 GraphType::NSW => self.nsw.as_ref().unwrap().search(query, k),
1066 GraphType::ONNG => self.onng.as_ref().unwrap().search(query, k),
1067 GraphType::PANNG => self.panng.as_ref().unwrap().search(query, k),
1068 GraphType::Delaunay => self.delaunay.as_ref().unwrap().search(query, k),
1069 GraphType::RNG => self.rng.as_ref().unwrap().search(query, k),
1070 }
1071 }
1072}
1073
1074impl VectorIndex for GraphIndex {
1075 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1076 let data = match self.graph_type {
1077 GraphType::NSW => &mut self.nsw.as_mut().unwrap().data,
1078 GraphType::ONNG => &mut self.onng.as_mut().unwrap().data,
1079 GraphType::PANNG => &mut self.panng.as_mut().unwrap().data,
1080 GraphType::Delaunay => &mut self.delaunay.as_mut().unwrap().data,
1081 GraphType::RNG => &mut self.rng.as_mut().unwrap().data,
1082 };
1083
1084 data.push((uri, vector));
1085 Ok(())
1086 }
1087
1088 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1089 let query_f32 = query.as_f32();
1090 let results = self.search_internal(&query_f32, k);
1091
1092 let data = match self.graph_type {
1093 GraphType::NSW => &self.nsw.as_ref().unwrap().data,
1094 GraphType::ONNG => &self.onng.as_ref().unwrap().data,
1095 GraphType::PANNG => &self.panng.as_ref().unwrap().data,
1096 GraphType::Delaunay => &self.delaunay.as_ref().unwrap().data,
1097 GraphType::RNG => &self.rng.as_ref().unwrap().data,
1098 };
1099
1100 Ok(results
1101 .into_iter()
1102 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1103 .collect())
1104 }
1105
1106 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1107 let query_f32 = query.as_f32();
1108 let all_results = self.search_internal(&query_f32, 1000);
1109
1110 let data = match self.graph_type {
1111 GraphType::NSW => &self.nsw.as_ref().unwrap().data,
1112 GraphType::ONNG => &self.onng.as_ref().unwrap().data,
1113 GraphType::PANNG => &self.panng.as_ref().unwrap().data,
1114 GraphType::Delaunay => &self.delaunay.as_ref().unwrap().data,
1115 GraphType::RNG => &self.rng.as_ref().unwrap().data,
1116 };
1117
1118 Ok(all_results
1119 .into_iter()
1120 .filter(|(_, dist)| *dist <= threshold)
1121 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1122 .collect())
1123 }
1124
1125 fn get_vector(&self, uri: &str) -> Option<&Vector> {
1126 let data = match self.graph_type {
1127 GraphType::NSW => &self.nsw.as_ref().unwrap().data,
1128 GraphType::ONNG => &self.onng.as_ref().unwrap().data,
1129 GraphType::PANNG => &self.panng.as_ref().unwrap().data,
1130 GraphType::Delaunay => &self.delaunay.as_ref().unwrap().data,
1131 GraphType::RNG => &self.rng.as_ref().unwrap().data,
1132 };
1133
1134 data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1135 }
1136}
1137
1138use petgraph;
#[cfg(test)]
mod tests {
    use super::*;

    /// Inserts `points` as `vec_0`, `vec_1`, ... into a fresh index and builds it.
    fn built_index(config: GraphIndexConfig, points: Vec<Vec<f32>>) -> GraphIndex {
        let mut index = GraphIndex::new(config);
        for (i, point) in points.into_iter().enumerate() {
            index.insert(format!("vec_{i}"), Vector::new(point)).unwrap();
        }
        index.build().unwrap();
        index
    }

    /// Points on a line through the origin: the exact match must come back first.
    #[test]
    fn test_nsw_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::NSW,
            num_neighbors: 10,
            ..Default::default()
        };
        let points = (0..50)
            .map(|i| vec![i as f32, (i * 2) as f32, (i * 3) as f32])
            .collect();
        let index = built_index(config, points);

        let query = Vector::new(vec![25.0, 50.0, 75.0]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
        assert_eq!(results[0].0, "vec_25");
    }

    /// Twenty points evenly spaced on the unit circle.
    #[test]
    fn test_onng_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::ONNG,
            num_neighbors: 8,
            ..Default::default()
        };
        let points = (0..20)
            .map(|i| {
                let angle = (i as f32) * 2.0 * std::f32::consts::PI / 20.0;
                vec![angle.cos(), angle.sin()]
            })
            .collect();
        let index = built_index(config, points);

        let query = Vector::new(vec![1.0, 0.0]);
        let results = index.search_knn(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
    }

    /// Thirty points on a helix, with pruning enabled.
    #[test]
    fn test_panng_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::PANNG,
            num_neighbors: 5,
            enable_pruning: true,
            ..Default::default()
        };
        let points = (0..30)
            .map(|i| vec![(i as f32).sin(), (i as f32).cos(), (i as f32) / 10.0])
            .collect();
        let index = built_index(config, points);

        let query = Vector::new(vec![0.0, 1.0, 0.0]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
    }
}