1use crate::{Vector, VectorIndex};
12use anyhow::Result;
13use oxirs_core::parallel::*;
14use oxirs_core::simd::SimdOps;
15use petgraph::graph::{Graph, NodeIndex};
16#[allow(unused_imports)]
17use scirs2_core::random::{Random, Rng};
18use std::cmp::Ordering;
19use std::collections::{BinaryHeap, HashMap, HashSet, VecDeque};
20
21#[derive(Debug, Clone)]
23pub struct GraphIndexConfig {
24 pub graph_type: GraphType,
26 pub num_neighbors: usize,
28 pub random_seed: Option<u64>,
30 pub parallel_construction: bool,
32 pub distance_metric: DistanceMetric,
34 pub enable_pruning: bool,
36 pub search_expansion: f32,
38}
39
40impl Default for GraphIndexConfig {
41 fn default() -> Self {
42 Self {
43 graph_type: GraphType::NSW,
44 num_neighbors: 32,
45 random_seed: None,
46 parallel_construction: true,
47 distance_metric: DistanceMetric::Euclidean,
48 enable_pruning: true,
49 search_expansion: 1.5,
50 }
51 }
52}
53
/// Graph structure built by `GraphIndex`.
///
/// The all-caps variant names are part of the public API and cannot be renamed without
/// breaking callers, hence the clippy allowance.
#[allow(clippy::upper_case_acronyms)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum GraphType {
    /// Navigable small-world graph (`NSWGraph`).
    NSW,
    /// k-NN graph merged with its reverse edges (`ONNGGraph`).
    ONNG,
    /// Over-provisioned k-NN graph with angle-based pruning (`PANNGGraph`).
    PANNG,
    /// Heuristic Delaunay-style neighborhood graph (`DelaunayGraph`).
    Delaunay,
    /// Relative neighborhood graph (`RNGGraph`).
    RNG,
}
63
/// Distance function used to compare vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DistanceMetric {
    /// Straight-line (L2) distance.
    Euclidean,
    /// Sum of absolute coordinate differences (L1).
    Manhattan,
    /// Cosine distance as provided by the SIMD backend.
    Cosine,
    /// Angle between the vectors, derived from the cosine distance and divided by pi.
    Angular,
}
72
73impl DistanceMetric {
74 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
75 match self {
76 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
77 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
78 DistanceMetric::Cosine => f32::cosine_distance(a, b),
79 DistanceMetric::Angular => {
80 let cos_sim: f32 = 1.0 - f32::cosine_distance(a, b);
82 cos_sim.clamp(-1.0, 1.0).acos() / std::f32::consts::PI
83 }
84 }
85 }
86}
87
/// A candidate produced during neighbor search: a position in the owning graph's `data`
/// vector together with its distance to the query.
///
/// Equality and ordering look only at `distance`, so `BinaryHeap<SearchResult>` is a
/// max-heap on distance and `BinaryHeap<Reverse<SearchResult>>` a min-heap.
#[derive(Debug, Clone)]
struct SearchResult {
    index: usize,
    distance: f32,
}

impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so `Eq` stays reflexive even for NaN distances and always
        // agrees with `Ord`.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // `total_cmp` is a lawful total order over every f32 (positive NaN sorts above +inf).
        // `partial_cmp(..).unwrap_or(Equal)` would make NaN "equal" to every value and break
        // the transitivity that `BinaryHeap` and the sort routines rely on.
        self.distance.total_cmp(&other.distance)
    }
}
116
117pub struct NSWGraph {
119 graph: Graph<usize, f32>,
121 node_map: HashMap<usize, NodeIndex>,
123 data: Vec<(String, Vector)>,
125 config: GraphIndexConfig,
127 entry_points: Vec<NodeIndex>,
129}
130
131impl NSWGraph {
132 pub fn new(config: GraphIndexConfig) -> Self {
133 Self {
134 graph: Graph::new(),
135 node_map: HashMap::new(),
136 data: Vec::new(),
137 config,
138 entry_points: Vec::new(),
139 }
140 }
141
142 pub fn build(&mut self) -> Result<()> {
144 if self.data.is_empty() {
145 return Ok(());
146 }
147
148 for (idx, _) in self.data.iter().enumerate() {
150 let node = self.graph.add_node(idx);
151 self.node_map.insert(idx, node);
152 }
153
154 let num_entry_points = (self.data.len() as f32).sqrt() as usize;
156 let mut rng = if let Some(seed) = self.config.random_seed {
157 Random::seed(seed)
158 } else {
159 Random::seed(42)
160 };
161
162 let mut indices: Vec<usize> = (0..self.data.len()).collect();
164 for i in (1..indices.len()).rev() {
166 let j = rng.random_range(0, i + 1);
167 indices.swap(i, j);
168 }
169
170 self.entry_points = indices[..num_entry_points.min(self.data.len())]
171 .iter()
172 .map(|&idx| self.node_map[&idx])
173 .collect();
174
175 if self.config.parallel_construction && self.data.len() > 1000 {
177 self.build_parallel()?;
178 } else {
179 self.build_sequential()?;
180 }
181
182 Ok(())
183 }
184
185 fn build_sequential(&mut self) -> Result<()> {
186 for idx in 0..self.data.len() {
187 let neighbors = self.find_neighbors(idx, self.config.num_neighbors)?;
188 let node = self.node_map[&idx];
189
190 for (neighbor_idx, distance) in neighbors {
191 let neighbor_node = self.node_map[&neighbor_idx];
192 if !self.graph.contains_edge(node, neighbor_node) {
193 self.graph.add_edge(node, neighbor_node, distance);
194 }
195 }
196 }
197
198 Ok(())
199 }
200
201 fn build_parallel(&mut self) -> Result<()> {
202 let _chunk_size = (self.data.len() / num_threads()).max(100);
203
204 let mut all_edges = Vec::new();
206 for idx in 0..self.data.len() {
207 let neighbors = self.find_neighbors(idx, self.config.num_neighbors)?;
208 let node = self.node_map[&idx];
209
210 for (neighbor_idx, distance) in neighbors {
211 let neighbor_node = self.node_map[&neighbor_idx];
212 all_edges.push((node, neighbor_node, distance));
213 }
214 }
215
216 for (from, to, weight) in all_edges {
218 if !self.graph.contains_edge(from, to) {
219 self.graph.add_edge(from, to, weight);
220 }
221 }
222
223 Ok(())
224 }
225
226 fn find_neighbors(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
227 let query = &self.data[idx].1.as_f32();
228 let mut heap = BinaryHeap::new();
229
230 for (other_idx, (_, vector)) in self.data.iter().enumerate() {
231 if other_idx == idx {
232 continue;
233 }
234
235 let other = vector.as_f32();
236 let distance = self.config.distance_metric.distance(query, &other);
237
238 if heap.len() < k {
239 heap.push(SearchResult {
240 index: other_idx,
241 distance,
242 });
243 } else if distance < heap.peek().unwrap().distance {
244 heap.pop();
245 heap.push(SearchResult {
246 index: other_idx,
247 distance,
248 });
249 }
250 }
251
252 Ok(heap.into_iter().map(|r| (r.index, r.distance)).collect())
253 }
254
255 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
257 if self.entry_points.is_empty() {
258 return Vec::new();
259 }
260
261 let mut visited = HashSet::new();
262 let mut candidates = BinaryHeap::new();
263 let mut results: BinaryHeap<SearchResult> = BinaryHeap::new();
264
265 for &entry in &self.entry_points {
267 let idx = self.graph[entry];
268 let distance = self
269 .config
270 .distance_metric
271 .distance(query, &self.data[idx].1.as_f32());
272 candidates.push(std::cmp::Reverse(SearchResult {
273 index: idx,
274 distance,
275 }));
276 visited.insert(idx);
277 }
278
279 let max_candidates = (k as f32 * self.config.search_expansion) as usize;
281
282 while let Some(std::cmp::Reverse(current)) = candidates.pop() {
283 if results.len() >= k && current.distance > results.peek().unwrap().distance {
285 break;
286 }
287
288 if results.len() < k {
290 results.push(current.clone());
291 } else if current.distance < results.peek().unwrap().distance {
292 results.pop();
293 results.push(current.clone());
294 }
295
296 let node = self.node_map[¤t.index];
298 for neighbor in self.graph.neighbors(node) {
299 let neighbor_idx = self.graph[neighbor];
300
301 if visited.contains(&neighbor_idx) {
302 continue;
303 }
304
305 visited.insert(neighbor_idx);
306 let distance = self
307 .config
308 .distance_metric
309 .distance(query, &self.data[neighbor_idx].1.as_f32());
310
311 if candidates.len() < max_candidates
312 || distance < candidates.peek().unwrap().0.distance
313 {
314 candidates.push(std::cmp::Reverse(SearchResult {
315 index: neighbor_idx,
316 distance,
317 }));
318 }
319 }
320 }
321
322 let mut results: Vec<(usize, f32)> =
323 results.into_iter().map(|r| (r.index, r.distance)).collect();
324
325 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
326 results
327 }
328}
329
330pub struct ONNGGraph {
332 adjacency: Vec<Vec<(usize, f32)>>,
334 data: Vec<(String, Vector)>,
336 config: GraphIndexConfig,
338}
339
340impl ONNGGraph {
341 pub fn new(config: GraphIndexConfig) -> Self {
342 Self {
343 adjacency: Vec::new(),
344 data: Vec::new(),
345 config,
346 }
347 }
348
349 pub fn build(&mut self) -> Result<()> {
350 if self.data.is_empty() {
351 return Ok(());
352 }
353
354 self.adjacency = vec![Vec::new(); self.data.len()];
356
357 self.build_knn_graph()?;
359
360 self.optimize_graph()?;
362
363 Ok(())
364 }
365
366 fn build_knn_graph(&mut self) -> Result<()> {
367 for idx in 0..self.data.len() {
368 let neighbors = self.find_k_nearest(idx, self.config.num_neighbors)?;
369 self.adjacency[idx] = neighbors;
370 }
371
372 Ok(())
373 }
374
375 fn find_k_nearest(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
376 let query = &self.data[idx].1.as_f32();
377 let mut neighbors = Vec::new();
378
379 for (other_idx, (_, vector)) in self.data.iter().enumerate() {
380 if other_idx == idx {
381 continue;
382 }
383
384 let distance = self
385 .config
386 .distance_metric
387 .distance(query, &vector.as_f32());
388 neighbors.push((other_idx, distance));
389 }
390
391 neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
392 neighbors.truncate(k);
393
394 Ok(neighbors)
395 }
396
397 fn optimize_graph(&mut self) -> Result<()> {
398 let mut reverse_edges = vec![Vec::new(); self.data.len()];
400
401 for (idx, neighbors) in self.adjacency.iter().enumerate() {
402 for &(neighbor_idx, distance) in neighbors {
403 reverse_edges[neighbor_idx].push((idx, distance));
404 }
405 }
406
407 for (idx, reverse) in reverse_edges.into_iter().enumerate() {
409 let mut all_neighbors = self.adjacency[idx].clone();
410 all_neighbors.extend(reverse);
411
412 all_neighbors.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
414 all_neighbors.dedup_by_key(|&mut (idx, _)| idx);
415 all_neighbors.truncate(self.config.num_neighbors);
416
417 self.adjacency[idx] = all_neighbors;
418 }
419
420 Ok(())
421 }
422
423 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
424 if self.data.is_empty() {
425 return Vec::new();
426 }
427
428 let start_points = self.select_start_points();
430 let mut visited = HashSet::new();
431 let mut heap = BinaryHeap::new();
432
433 for start in start_points {
435 let distance = self
436 .config
437 .distance_metric
438 .distance(query, &self.data[start].1.as_f32());
439 heap.push(std::cmp::Reverse(SearchResult {
440 index: start,
441 distance,
442 }));
443 visited.insert(start);
444 }
445
446 let mut results = Vec::new();
447
448 while let Some(std::cmp::Reverse(current)) = heap.pop() {
449 results.push((current.index, current.distance));
450
451 if results.len() >= k {
452 break;
453 }
454
455 for &(neighbor_idx, _) in &self.adjacency[current.index] {
457 if visited.contains(&neighbor_idx) {
458 continue;
459 }
460
461 visited.insert(neighbor_idx);
462 let distance = self
463 .config
464 .distance_metric
465 .distance(query, &self.data[neighbor_idx].1.as_f32());
466 heap.push(std::cmp::Reverse(SearchResult {
467 index: neighbor_idx,
468 distance,
469 }));
470 }
471 }
472
473 results.truncate(k);
474 results
475 }
476
477 fn select_start_points(&self) -> Vec<usize> {
478 let num_points = (self.data.len() as f32).sqrt() as usize;
480 let mut indices: Vec<usize> = (0..self.data.len()).collect();
481
482 let mut rng = if let Some(seed) = self.config.random_seed {
483 Random::seed(seed)
484 } else {
485 Random::seed(42)
486 };
487
488 for i in (1..indices.len()).rev() {
491 let j = rng.random_range(0, i + 1);
492 indices.swap(i, j);
493 }
494 indices.truncate(num_points.max(1));
495
496 indices
497 }
498}
499
500pub struct PANNGGraph {
502 adjacency: Vec<Vec<(usize, f32)>>,
504 data: Vec<(String, Vector)>,
506 config: GraphIndexConfig,
508 pruning_threshold: f32,
510}
511
512impl PANNGGraph {
513 pub fn new(config: GraphIndexConfig) -> Self {
514 Self {
515 adjacency: Vec::new(),
516 data: Vec::new(),
517 config,
518 pruning_threshold: 0.9, }
520 }
521
522 pub fn build(&mut self) -> Result<()> {
523 if self.data.is_empty() {
524 return Ok(());
525 }
526
527 self.adjacency = vec![Vec::new(); self.data.len()];
529 self.build_initial_graph()?;
530
531 if self.config.enable_pruning {
533 self.prune_graph()?;
534 }
535
536 Ok(())
537 }
538
539 fn build_initial_graph(&mut self) -> Result<()> {
540 let initial_neighbors = self.config.num_neighbors * 2;
542
543 for idx in 0..self.data.len() {
544 let neighbors = self.find_k_nearest(idx, initial_neighbors)?;
545 self.adjacency[idx] = neighbors;
546 }
547
548 Ok(())
549 }
550
551 fn find_k_nearest(&self, idx: usize, k: usize) -> Result<Vec<(usize, f32)>> {
552 let query = &self.data[idx].1.as_f32();
553 let mut heap = BinaryHeap::new();
554
555 for (other_idx, (_, vector)) in self.data.iter().enumerate() {
556 if other_idx == idx {
557 continue;
558 }
559
560 let distance = self
561 .config
562 .distance_metric
563 .distance(query, &vector.as_f32());
564
565 if heap.len() < k {
566 heap.push(SearchResult {
567 index: other_idx,
568 distance,
569 });
570 } else if distance < heap.peek().unwrap().distance {
571 heap.pop();
572 heap.push(SearchResult {
573 index: other_idx,
574 distance,
575 });
576 }
577 }
578
579 Ok(heap
580 .into_sorted_vec()
581 .into_iter()
582 .map(|r| (r.index, r.distance))
583 .collect())
584 }
585
586 fn prune_graph(&mut self) -> Result<()> {
587 for idx in 0..self.data.len() {
588 let pruned = self.prune_neighbors(idx)?;
589 self.adjacency[idx] = pruned;
590 }
591
592 Ok(())
593 }
594
595 fn prune_neighbors(&self, idx: usize) -> Result<Vec<(usize, f32)>> {
596 let neighbors = &self.adjacency[idx];
597 if neighbors.len() <= self.config.num_neighbors {
598 return Ok(neighbors.clone());
599 }
600
601 let mut pruned = Vec::new();
602 let (_, vector) = &self.data[idx];
603 let query = vector.as_f32();
604
605 for &(neighbor_idx, distance) in neighbors {
606 let (_, vector) = &self.data[neighbor_idx];
607 let neighbor = vector.as_f32();
608 let mut keep = true;
609
610 for &(selected_idx, _) in &pruned {
612 let (_id, vector): &(String, Vector) = &self.data[selected_idx];
613 let selected = vector.as_f32();
614
615 let angle = self.calculate_angle(&query, &neighbor, &selected);
617
618 if angle < self.pruning_threshold {
619 keep = false;
620 break;
621 }
622 }
623
624 if keep {
625 pruned.push((neighbor_idx, distance));
626
627 if pruned.len() >= self.config.num_neighbors {
628 break;
629 }
630 }
631 }
632
633 Ok(pruned)
634 }
635
636 fn calculate_angle(&self, origin: &[f32], a: &[f32], b: &[f32]) -> f32 {
637 let va: Vec<f32> = a
639 .iter()
640 .zip(origin.iter())
641 .map(|(ai, oi)| ai - oi)
642 .collect();
643 let vb: Vec<f32> = b
644 .iter()
645 .zip(origin.iter())
646 .map(|(bi, oi)| bi - oi)
647 .collect();
648
649 let dot = f32::dot(&va, &vb);
651 let norm_a = f32::norm(&va);
652 let norm_b = f32::norm(&vb);
653
654 if norm_a == 0.0 || norm_b == 0.0 {
655 return 0.0;
656 }
657
658 (dot / (norm_a * norm_b)).clamp(-1.0, 1.0).acos()
659 }
660
661 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
662 if self.data.is_empty() {
663 return Vec::new();
664 }
665
666 let mut visited = HashSet::new();
667 let mut candidates = VecDeque::new();
668 let mut results = Vec::new();
669
670 let start = self.find_closest_point(query);
672 candidates.push_back(start);
673 visited.insert(start);
674
675 while let Some(current) = candidates.pop_front() {
676 let distance = self
677 .config
678 .distance_metric
679 .distance(query, &self.data[current].1.as_f32());
680 results.push((current, distance));
681
682 for &(neighbor_idx, _) in &self.adjacency[current] {
684 if !visited.contains(&neighbor_idx) {
685 visited.insert(neighbor_idx);
686 candidates.push_back(neighbor_idx);
687 }
688 }
689
690 if results.len() >= k * 2 {
691 break;
692 }
693 }
694
695 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
696 results.truncate(k);
697 results
698 }
699
700 fn find_closest_point(&self, query: &[f32]) -> usize {
701 let mut min_dist = f32::INFINITY;
702 let mut closest = 0;
703
704 let sample_size = (self.data.len() as f32).sqrt() as usize;
706 let step = self.data.len() / sample_size.max(1);
707
708 for idx in (0..self.data.len()).step_by(step.max(1)) {
709 let distance = self
710 .config
711 .distance_metric
712 .distance(query, &self.data[idx].1.as_f32());
713 if distance < min_dist {
714 min_dist = distance;
715 closest = idx;
716 }
717 }
718
719 closest
720 }
721}
722
723pub struct DelaunayGraph {
725 edges: Vec<Vec<(usize, f32)>>,
727 data: Vec<(String, Vector)>,
729 config: GraphIndexConfig,
731}
732
733impl DelaunayGraph {
734 pub fn new(config: GraphIndexConfig) -> Self {
735 Self {
736 edges: Vec::new(),
737 data: Vec::new(),
738 config,
739 }
740 }
741
742 pub fn build(&mut self) -> Result<()> {
743 if self.data.is_empty() {
744 return Ok(());
745 }
746
747 self.edges = vec![Vec::new(); self.data.len()];
748
749 for idx in 0..self.data.len() {
751 let neighbors = self.find_delaunay_neighbors(idx)?;
752 self.edges[idx] = neighbors;
753 }
754
755 self.symmetrize_edges();
757
758 Ok(())
759 }
760
761 fn find_delaunay_neighbors(&self, idx: usize) -> Result<Vec<(usize, f32)>> {
762 let point = &self.data[idx].1.as_f32();
763 let mut candidates = Vec::new();
764
765 for (other_idx, (_, other_vec)) in self.data.iter().enumerate() {
767 if other_idx == idx {
768 continue;
769 }
770
771 let other = other_vec.as_f32();
772 let distance = self.config.distance_metric.distance(point, &other);
773 candidates.push((other_idx, distance));
774 }
775
776 candidates.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
777
778 let mut neighbors = Vec::new();
780
781 for &(candidate_idx, distance) in &candidates {
782 if neighbors.len() >= self.config.num_neighbors {
783 break;
784 }
785
786 let candidate = &self.data[candidate_idx].1.as_f32();
787 let mut is_neighbor = true;
788
789 for &(neighbor_idx, _) in &neighbors {
791 let (_id, vector): &(String, Vector) = &self.data[neighbor_idx];
792 let neighbor = vector.as_f32();
793
794 let dist_to_neighbor = self.config.distance_metric.distance(candidate, &neighbor);
796 if dist_to_neighbor < distance * 0.9 {
797 is_neighbor = false;
798 break;
799 }
800 }
801
802 if is_neighbor {
803 neighbors.push((candidate_idx, distance));
804 }
805 }
806
807 Ok(neighbors)
808 }
809
810 fn symmetrize_edges(&mut self) {
811 let mut symmetric_edges = vec![Vec::new(); self.data.len()];
812
813 for (idx, neighbors) in self.edges.iter().enumerate() {
815 for &(neighbor_idx, distance) in neighbors {
816 symmetric_edges[idx].push((neighbor_idx, distance));
817 symmetric_edges[neighbor_idx].push((idx, distance));
818 }
819 }
820
821 for edges in &mut symmetric_edges {
823 edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
824 edges.dedup_by_key(|&mut (idx, _)| idx);
825 edges.truncate(self.config.num_neighbors);
826 }
827
828 self.edges = symmetric_edges;
829 }
830
831 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
832 if self.data.is_empty() {
833 return Vec::new();
834 }
835
836 let mut visited = HashSet::new();
837 let mut heap = BinaryHeap::new();
838 let mut results = Vec::new();
839
840 let start = 0;
842 let distance = self
843 .config
844 .distance_metric
845 .distance(query, &self.data[start].1.as_f32());
846 heap.push(std::cmp::Reverse(SearchResult {
847 index: start,
848 distance,
849 }));
850 visited.insert(start);
851
852 while let Some(std::cmp::Reverse(current)) = heap.pop() {
853 results.push((current.index, current.distance));
854
855 if results.len() >= k {
856 break;
857 }
858
859 for &(neighbor_idx, _) in &self.edges[current.index] {
861 if !visited.contains(&neighbor_idx) {
862 visited.insert(neighbor_idx);
863 let distance = self
864 .config
865 .distance_metric
866 .distance(query, &self.data[neighbor_idx].1.as_f32());
867 heap.push(std::cmp::Reverse(SearchResult {
868 index: neighbor_idx,
869 distance,
870 }));
871 }
872 }
873 }
874
875 results
876 }
877}
878
879pub struct RNGGraph {
881 edges: Vec<Vec<(usize, f32)>>,
883 data: Vec<(String, Vector)>,
885 config: GraphIndexConfig,
887}
888
889impl RNGGraph {
890 pub fn new(config: GraphIndexConfig) -> Self {
891 Self {
892 edges: Vec::new(),
893 data: Vec::new(),
894 config,
895 }
896 }
897
898 pub fn build(&mut self) -> Result<()> {
899 if self.data.is_empty() {
900 return Ok(());
901 }
902
903 self.edges = vec![Vec::new(); self.data.len()];
904
905 for i in 0..self.data.len() {
907 for j in i + 1..self.data.len() {
908 if self.is_rng_edge(i, j)? {
909 let distance = self
910 .config
911 .distance_metric
912 .distance(&self.data[i].1.as_f32(), &self.data[j].1.as_f32());
913
914 self.edges[i].push((j, distance));
915 self.edges[j].push((i, distance));
916 }
917 }
918 }
919
920 for edges in &mut self.edges {
922 edges.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
923 }
924
925 Ok(())
926 }
927
928 fn is_rng_edge(&self, i: usize, j: usize) -> Result<bool> {
929 let pi = &self.data[i].1.as_f32();
930 let pj = &self.data[j].1.as_f32();
931 let dist_ij = self.config.distance_metric.distance(pi, pj);
932
933 for k in 0..self.data.len() {
936 if k == i || k == j {
937 continue;
938 }
939
940 let pk = &self.data[k].1.as_f32();
941 let dist_ik = self.config.distance_metric.distance(pi, pk);
942 let dist_jk = self.config.distance_metric.distance(pj, pk);
943
944 if dist_ik.max(dist_jk) < dist_ij {
945 return Ok(false);
946 }
947 }
948
949 Ok(true)
950 }
951
952 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
953 if self.data.is_empty() {
954 return Vec::new();
955 }
956
957 let mut visited = HashSet::new();
958 let mut candidates = BinaryHeap::new();
959 let mut results = Vec::new();
960
961 let start = self.find_start_point(query);
963 let distance = self
964 .config
965 .distance_metric
966 .distance(query, &self.data[start].1.as_f32());
967 candidates.push(std::cmp::Reverse(SearchResult {
968 index: start,
969 distance,
970 }));
971 visited.insert(start);
972
973 while let Some(std::cmp::Reverse(current)) = candidates.pop() {
974 results.push((current.index, current.distance));
975
976 if results.len() >= k {
977 break;
978 }
979
980 for &(neighbor_idx, _) in &self.edges[current.index] {
982 if !visited.contains(&neighbor_idx) {
983 visited.insert(neighbor_idx);
984 let distance = self
985 .config
986 .distance_metric
987 .distance(query, &self.data[neighbor_idx].1.as_f32());
988 candidates.push(std::cmp::Reverse(SearchResult {
989 index: neighbor_idx,
990 distance,
991 }));
992 }
993 }
994 }
995
996 results
997 }
998
999 fn find_start_point(&self, query: &[f32]) -> usize {
1000 let sample_size = (self.data.len() as f32).sqrt() as usize;
1002 let mut min_dist = f32::INFINITY;
1003 let mut best = 0;
1004
1005 for i in 0..sample_size.min(self.data.len()) {
1006 let idx = (i * self.data.len()) / sample_size;
1007 let distance = self
1008 .config
1009 .distance_metric
1010 .distance(query, &self.data[idx].1.as_f32());
1011
1012 if distance < min_dist {
1013 min_dist = distance;
1014 best = idx;
1015 }
1016 }
1017
1018 best
1019 }
1020}
1021
1022pub struct GraphIndex {
1024 graph_type: GraphType,
1025 nsw: Option<NSWGraph>,
1026 onng: Option<ONNGGraph>,
1027 panng: Option<PANNGGraph>,
1028 delaunay: Option<DelaunayGraph>,
1029 rng: Option<RNGGraph>,
1030}
1031
1032impl GraphIndex {
1033 pub fn new(config: GraphIndexConfig) -> Self {
1034 let graph_type = config.graph_type;
1035
1036 let (nsw, onng, panng, delaunay, rng) = match graph_type {
1037 GraphType::NSW => (Some(NSWGraph::new(config)), None, None, None, None),
1038 GraphType::ONNG => (None, Some(ONNGGraph::new(config)), None, None, None),
1039 GraphType::PANNG => (None, None, Some(PANNGGraph::new(config)), None, None),
1040 GraphType::Delaunay => (None, None, None, Some(DelaunayGraph::new(config)), None),
1041 GraphType::RNG => (None, None, None, None, Some(RNGGraph::new(config))),
1042 };
1043
1044 Self {
1045 graph_type,
1046 nsw,
1047 onng,
1048 panng,
1049 delaunay,
1050 rng,
1051 }
1052 }
1053
1054 fn build(&mut self) -> Result<()> {
1055 match self.graph_type {
1056 GraphType::NSW => self.nsw.as_mut().unwrap().build(),
1057 GraphType::ONNG => self.onng.as_mut().unwrap().build(),
1058 GraphType::PANNG => self.panng.as_mut().unwrap().build(),
1059 GraphType::Delaunay => self.delaunay.as_mut().unwrap().build(),
1060 GraphType::RNG => self.rng.as_mut().unwrap().build(),
1061 }
1062 }
1063
1064 fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1065 match self.graph_type {
1066 GraphType::NSW => self.nsw.as_ref().unwrap().search(query, k),
1067 GraphType::ONNG => self.onng.as_ref().unwrap().search(query, k),
1068 GraphType::PANNG => self.panng.as_ref().unwrap().search(query, k),
1069 GraphType::Delaunay => self.delaunay.as_ref().unwrap().search(query, k),
1070 GraphType::RNG => self.rng.as_ref().unwrap().search(query, k),
1071 }
1072 }
1073}
1074
1075impl VectorIndex for GraphIndex {
1076 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1077 let data = match self.graph_type {
1078 GraphType::NSW => &mut self.nsw.as_mut().unwrap().data,
1079 GraphType::ONNG => &mut self.onng.as_mut().unwrap().data,
1080 GraphType::PANNG => &mut self.panng.as_mut().unwrap().data,
1081 GraphType::Delaunay => &mut self.delaunay.as_mut().unwrap().data,
1082 GraphType::RNG => &mut self.rng.as_mut().unwrap().data,
1083 };
1084
1085 data.push((uri, vector));
1086 Ok(())
1087 }
1088
1089 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1090 let query_f32 = query.as_f32();
1091 let results = self.search_internal(&query_f32, k);
1092
1093 let data = match self.graph_type {
1094 GraphType::NSW => &self.nsw.as_ref().unwrap().data,
1095 GraphType::ONNG => &self.onng.as_ref().unwrap().data,
1096 GraphType::PANNG => &self.panng.as_ref().unwrap().data,
1097 GraphType::Delaunay => &self.delaunay.as_ref().unwrap().data,
1098 GraphType::RNG => &self.rng.as_ref().unwrap().data,
1099 };
1100
1101 Ok(results
1102 .into_iter()
1103 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1104 .collect())
1105 }
1106
1107 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1108 let query_f32 = query.as_f32();
1109 let all_results = self.search_internal(&query_f32, 1000);
1110
1111 let data = match self.graph_type {
1112 GraphType::NSW => &self.nsw.as_ref().unwrap().data,
1113 GraphType::ONNG => &self.onng.as_ref().unwrap().data,
1114 GraphType::PANNG => &self.panng.as_ref().unwrap().data,
1115 GraphType::Delaunay => &self.delaunay.as_ref().unwrap().data,
1116 GraphType::RNG => &self.rng.as_ref().unwrap().data,
1117 };
1118
1119 Ok(all_results
1120 .into_iter()
1121 .filter(|(_, dist)| *dist <= threshold)
1122 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1123 .collect())
1124 }
1125
1126 fn get_vector(&self, uri: &str) -> Option<&Vector> {
1127 let data = match self.graph_type {
1128 GraphType::NSW => &self.nsw.as_ref().unwrap().data,
1129 GraphType::ONNG => &self.onng.as_ref().unwrap().data,
1130 GraphType::PANNG => &self.panng.as_ref().unwrap().data,
1131 GraphType::Delaunay => &self.delaunay.as_ref().unwrap().data,
1132 GraphType::RNG => &self.rng.as_ref().unwrap().data,
1133 };
1134
1135 data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1136 }
1137}
1138
1139use petgraph;
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an index of `count` collinear vectors `(i, 2i, 3i)` named `vec_{i}`.
    fn line_index(graph_type: GraphType, num_neighbors: usize, count: usize) -> GraphIndex {
        let config = GraphIndexConfig {
            graph_type,
            num_neighbors,
            ..Default::default()
        };
        let mut index = GraphIndex::new(config);
        for i in 0..count {
            let vector = Vector::new(vec![i as f32, (i * 2) as f32, (i * 3) as f32]);
            index.insert(format!("vec_{i}"), vector).unwrap();
        }
        index.build().unwrap();
        index
    }

    #[test]
    fn test_nsw_graph() {
        let index = line_index(GraphType::NSW, 10, 50);

        let query = Vector::new(vec![25.0, 50.0, 75.0]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
        // The query coincides with vec_25, so it must be the closest hit.
        assert_eq!(results[0].0, "vec_25");
    }

    #[test]
    fn test_onng_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::ONNG,
            num_neighbors: 8,
            ..Default::default()
        };
        let mut index = GraphIndex::new(config);

        // Points evenly spaced on the unit circle.
        for i in 0..20 {
            let angle = (i as f32) * 2.0 * std::f32::consts::PI / 20.0;
            let vector = Vector::new(vec![angle.cos(), angle.sin()]);
            index.insert(format!("vec_{i}"), vector).unwrap();
        }
        index.build().unwrap();

        let query = Vector::new(vec![1.0, 0.0]);
        let results = index.search_knn(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_panng_graph() {
        let config = GraphIndexConfig {
            graph_type: GraphType::PANNG,
            num_neighbors: 5,
            enable_pruning: true,
            ..Default::default()
        };
        let mut index = GraphIndex::new(config);

        for i in 0..30 {
            let vector = Vector::new(vec![(i as f32).sin(), (i as f32).cos(), (i as f32) / 10.0]);
            index.insert(format!("vec_{i}"), vector).unwrap();
        }
        index.build().unwrap();

        let query = Vector::new(vec![0.0, 1.0, 0.0]);
        let results = index.search_knn(&query, 5).unwrap();

        assert_eq!(results.len(), 5);
    }

    #[test]
    fn test_delaunay_graph() {
        let index = line_index(GraphType::Delaunay, 4, 20);

        let query = Vector::new(vec![10.0, 20.0, 30.0]);
        let results = index.search_knn(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_rng_graph() {
        let index = line_index(GraphType::RNG, 4, 12);

        let query = Vector::new(vec![5.0, 10.0, 15.0]);
        let results = index.search_knn(&query, 3).unwrap();

        assert_eq!(results.len(), 3);
        assert!(results.iter().any(|(uri, _)| uri == "vec_5"));
    }

    #[test]
    fn test_threshold_search() {
        let index = line_index(GraphType::NSW, 10, 50);
        let threshold = 4.0;

        let query = Vector::new(vec![25.0, 50.0, 75.0]);
        let results = index.search_threshold(&query, threshold).unwrap();

        assert!(results.iter().any(|(uri, _)| uri == "vec_25"));
        assert!(results.iter().all(|(_, dist)| *dist <= threshold));
    }
}
1224}