1use crate::{Vector, VectorIndex};
12use anyhow::Result;
13use oxirs_core::simd::SimdOps;
14use scirs2_core::random::{Random, Rng};
15use std::cmp::Ordering;
16use std::collections::BinaryHeap;
17
18#[derive(Debug, Clone)]
20pub struct TreeIndexConfig {
21 pub tree_type: TreeType,
23 pub max_leaf_size: usize,
25 pub random_seed: Option<u64>,
27 pub parallel_construction: bool,
29 pub distance_metric: DistanceMetric,
31}
32
33impl Default for TreeIndexConfig {
34 fn default() -> Self {
35 Self {
36 tree_type: TreeType::BallTree,
37 max_leaf_size: 16, random_seed: None,
39 parallel_construction: true,
40 distance_metric: DistanceMetric::Euclidean,
41 }
42 }
43}
44
/// Tree structure backing a `TreeIndex`.
#[derive(Clone, Copy, Debug)]
pub enum TreeType {
    /// Binary tree of nested bounding balls.
    BallTree,
    /// Axis-aligned k-d tree whose split coordinate cycles with depth.
    KdTree,
    /// Vantage-point tree partitioning by distance to a randomly chosen point.
    VpTree,
    /// Cover tree.
    /// NOTE(review): the `CoverTree` type in this file keeps a flat, single-level structure.
    CoverTree,
    /// Tree that splits on random projection directions.
    RandomProjectionTree,
}
54
/// Distance function used by the trees.
#[derive(Clone, Copy, Debug)]
pub enum DistanceMetric {
    /// Straight-line (L2) distance.
    Euclidean,
    /// Sum of absolute coordinate differences (L1).
    Manhattan,
    /// Cosine distance, as provided by the `SimdOps` implementation for `f32`.
    Cosine,
    /// Minkowski distance with exponent `p`: `(Σ |aᵢ − bᵢ|^p)^(1/p)`.
    Minkowski(f32),
}
63
64impl DistanceMetric {
65 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
66 match self {
67 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
68 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
69 DistanceMetric::Cosine => f32::cosine_distance(a, b),
70 DistanceMetric::Minkowski(p) => a
71 .iter()
72 .zip(b.iter())
73 .map(|(x, y)| (x - y).abs().powf(*p))
74 .sum::<f32>()
75 .powf(1.0 / p),
76 }
77 }
78}
79
/// A candidate neighbour found during a k-NN search: a position in the tree's
/// `data` plus its distance to the query.
///
/// Equality and ordering look at `distance` only, so a `BinaryHeap<SearchResult>`
/// is a max-heap whose top is the worst of the currently kept candidates.
#[derive(Debug, Clone)]
struct SearchResult {
    index: usize,
    distance: f32,
}

impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so `Eq` and `Ord` can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // Must compare the floats directly and must NOT delegate to
        // `partial_cmp`: that method delegates back here, and the cycle recurses
        // until the stack overflows on the first heap comparison.
        // `total_cmp` also gives a lawful total order when a distance is NaN.
        self.distance.total_cmp(&other.distance)
    }
}
106
107pub struct BallTree {
109 root: Option<Box<BallNode>>,
110 data: Vec<(String, Vector)>,
111 config: TreeIndexConfig,
112}
113
114struct BallNode {
115 center: Vec<f32>,
117 radius: f32,
119 left: Option<Box<BallNode>>,
121 right: Option<Box<BallNode>>,
123 indices: Vec<usize>,
125}
126
127impl BallTree {
128 pub fn new(config: TreeIndexConfig) -> Self {
129 Self {
130 root: None,
131 data: Vec::new(),
132 config,
133 }
134 }
135
136 pub fn build(&mut self) -> Result<()> {
138 if self.data.is_empty() {
139 return Ok(());
140 }
141
142 let indices: Vec<usize> = (0..self.data.len()).collect();
143 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
144
145 self.root = Some(Box::new(self.build_node_safe(&points, indices, 0)?));
146 Ok(())
147 }
148
149 fn build_node_safe(
150 &self,
151 points: &[Vec<f32>],
152 indices: Vec<usize>,
153 depth: usize,
154 ) -> Result<BallNode> {
155 if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= 3 {
157 let center = self.compute_centroid(points, &indices);
159 let radius = self.compute_radius(points, &indices, ¢er);
160
161 return Ok(BallNode {
162 center,
163 radius,
164 left: None,
165 right: None,
166 indices,
167 });
168 }
169
170 let split_dim = self.find_split_dimension(points, &indices);
172
173 let (left_indices, right_indices) = self.partition_indices(points, &indices, split_dim);
175
176 if left_indices.is_empty() || right_indices.is_empty() {
178 let center = self.compute_centroid(points, &indices);
180 let radius = self.compute_radius(points, &indices, ¢er);
181 return Ok(BallNode {
182 center,
183 radius,
184 left: None,
185 right: None,
186 indices,
187 });
188 }
189
190 let left_node = self.build_node_safe(points, left_indices, depth + 1)?;
192 let right_node = self.build_node_safe(points, right_indices, depth + 1)?;
193
194 let all_centers = vec![left_node.center.clone(), right_node.center.clone()];
196 let center = self.compute_centroid_of_centers(&all_centers);
197 let radius = left_node.radius.max(right_node.radius)
198 + self
199 .config
200 .distance_metric
201 .distance(¢er, &left_node.center);
202
203 Ok(BallNode {
204 center,
205 radius,
206 left: Some(Box::new(left_node)),
207 right: Some(Box::new(right_node)),
208 indices: Vec::new(),
209 })
210 }
211
212 fn compute_centroid(&self, points: &[Vec<f32>], indices: &[usize]) -> Vec<f32> {
213 let dim = points[0].len();
214 let mut centroid = vec![0.0; dim];
215
216 for &idx in indices {
217 for (i, &val) in points[idx].iter().enumerate() {
218 centroid[i] += val;
219 }
220 }
221
222 let n = indices.len() as f32;
223 for val in &mut centroid {
224 *val /= n;
225 }
226
227 centroid
228 }
229
230 fn compute_radius(&self, points: &[Vec<f32>], indices: &[usize], center: &[f32]) -> f32 {
231 indices
232 .iter()
233 .map(|&idx| self.config.distance_metric.distance(&points[idx], center))
234 .fold(0.0f32, f32::max)
235 }
236
237 fn find_split_dimension(&self, points: &[Vec<f32>], indices: &[usize]) -> usize {
238 let dim = points[0].len();
239 let mut max_spread = 0.0;
240 let mut split_dim = 0;
241
242 #[allow(clippy::needless_range_loop)]
244 for d in 0..dim {
245 let values: Vec<f32> = indices.iter().map(|&idx| points[idx][d]).collect();
246
247 let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
248 let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
249 let spread = max_val - min_val;
250
251 if spread > max_spread {
252 max_spread = spread;
253 split_dim = d;
254 }
255 }
256
257 split_dim
258 }
259
260 fn partition_indices(
261 &self,
262 points: &[Vec<f32>],
263 indices: &[usize],
264 dim: usize,
265 ) -> (Vec<usize>, Vec<usize>) {
266 let mut values: Vec<(f32, usize)> =
267 indices.iter().map(|&idx| (points[idx][dim], idx)).collect();
268
269 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
270
271 let mid = values.len() / 2;
272 let left_indices: Vec<usize> = values[..mid].iter().map(|(_, idx)| *idx).collect();
273 let right_indices: Vec<usize> = values[mid..].iter().map(|(_, idx)| *idx).collect();
274
275 (left_indices, right_indices)
276 }
277
278 fn compute_centroid_of_centers(&self, centers: &[Vec<f32>]) -> Vec<f32> {
279 let dim = centers[0].len();
280 let mut centroid = vec![0.0; dim];
281
282 for center in centers {
283 for (i, &val) in center.iter().enumerate() {
284 centroid[i] += val;
285 }
286 }
287
288 let n = centers.len() as f32;
289 for val in &mut centroid {
290 *val /= n;
291 }
292
293 centroid
294 }
295
296 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
298 if self.root.is_none() {
299 return Vec::new();
300 }
301
302 let mut heap = BinaryHeap::new();
303 self.search_node(self.root.as_ref().unwrap(), query, k, &mut heap);
304
305 let mut results: Vec<(usize, f32)> =
306 heap.into_iter().map(|r| (r.index, r.distance)).collect();
307
308 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
309 results
310 }
311
312 fn search_node(
313 &self,
314 node: &BallNode,
315 query: &[f32],
316 k: usize,
317 heap: &mut BinaryHeap<SearchResult>,
318 ) {
319 let dist_to_center = self.config.distance_metric.distance(query, &node.center);
321
322 if heap.len() >= k {
323 let worst_dist = heap.peek().unwrap().distance;
324 if dist_to_center - node.radius > worst_dist {
325 return; }
327 }
328
329 if node.indices.is_empty() {
330 if let (Some(left), Some(right)) = (&node.left, &node.right) {
332 let left_dist = self.config.distance_metric.distance(query, &left.center);
333 let right_dist = self.config.distance_metric.distance(query, &right.center);
334
335 if left_dist < right_dist {
336 self.search_node(left, query, k, heap);
337 self.search_node(right, query, k, heap);
338 } else {
339 self.search_node(right, query, k, heap);
340 self.search_node(left, query, k, heap);
341 }
342 }
343 } else {
344 for &idx in &node.indices {
346 let point = &self.data[idx].1.as_f32();
347 let dist = self.config.distance_metric.distance(query, point);
348
349 if heap.len() < k {
350 heap.push(SearchResult {
351 index: idx,
352 distance: dist,
353 });
354 } else if dist < heap.peek().unwrap().distance {
355 heap.pop();
356 heap.push(SearchResult {
357 index: idx,
358 distance: dist,
359 });
360 }
361 }
362 }
363 }
364}
365
366pub struct KdTree {
368 root: Option<Box<KdNode>>,
369 data: Vec<(String, Vector)>,
370 config: TreeIndexConfig,
371}
372
373struct KdNode {
374 split_dim: usize,
376 split_value: f32,
378 left: Option<Box<KdNode>>,
380 right: Option<Box<KdNode>>,
382 indices: Vec<usize>,
384}
385
386impl KdTree {
387 pub fn new(config: TreeIndexConfig) -> Self {
388 Self {
389 root: None,
390 data: Vec::new(),
391 config,
392 }
393 }
394
395 pub fn build(&mut self) -> Result<()> {
396 if self.data.is_empty() {
397 return Ok(());
398 }
399
400 let indices: Vec<usize> = (0..self.data.len()).collect();
401 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
402
403 self.root = Some(Box::new(self.build_node(&points, indices, 0)?));
404 Ok(())
405 }
406
407 fn build_node(&self, points: &[Vec<f32>], indices: Vec<usize>, depth: usize) -> Result<KdNode> {
408 if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= 3 {
410 return Ok(KdNode {
411 split_dim: 0,
412 split_value: 0.0,
413 left: None,
414 right: None,
415 indices,
416 });
417 }
418
419 let dimensions = points[0].len();
420 let split_dim = depth % dimensions;
421
422 let mut values: Vec<(f32, usize)> = indices
424 .iter()
425 .map(|&idx| (points[idx][split_dim], idx))
426 .collect();
427
428 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
429
430 let median_idx = values.len() / 2;
431 let split_value = values[median_idx].0;
432
433 let left_indices: Vec<usize> = values[..median_idx].iter().map(|(_, idx)| *idx).collect();
434
435 let right_indices: Vec<usize> = values[median_idx..].iter().map(|(_, idx)| *idx).collect();
436
437 if left_indices.is_empty() || right_indices.is_empty() {
439 return Ok(KdNode {
440 split_dim: 0,
441 split_value: 0.0,
442 left: None,
443 right: None,
444 indices,
445 });
446 }
447
448 let left = Some(Box::new(self.build_node(
449 points,
450 left_indices,
451 depth + 1,
452 )?));
453
454 let right = Some(Box::new(self.build_node(
455 points,
456 right_indices,
457 depth + 1,
458 )?));
459
460 Ok(KdNode {
461 split_dim,
462 split_value,
463 left,
464 right,
465 indices: Vec::new(),
466 })
467 }
468
469 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
470 if self.root.is_none() {
471 return Vec::new();
472 }
473
474 let mut heap = BinaryHeap::new();
475 self.search_node(self.root.as_ref().unwrap(), query, k, &mut heap);
476
477 let mut results: Vec<(usize, f32)> =
478 heap.into_iter().map(|r| (r.index, r.distance)).collect();
479
480 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
481 results
482 }
483
484 fn search_node(
485 &self,
486 node: &KdNode,
487 query: &[f32],
488 k: usize,
489 heap: &mut BinaryHeap<SearchResult>,
490 ) {
491 if !node.indices.is_empty() {
492 for &idx in &node.indices {
494 let point = &self.data[idx].1.as_f32();
495 let dist = self.config.distance_metric.distance(query, point);
496
497 if heap.len() < k {
498 heap.push(SearchResult {
499 index: idx,
500 distance: dist,
501 });
502 } else if dist < heap.peek().unwrap().distance {
503 heap.pop();
504 heap.push(SearchResult {
505 index: idx,
506 distance: dist,
507 });
508 }
509 }
510 return;
511 }
512
513 let go_left = query[node.split_dim] <= node.split_value;
515
516 let (first, second) = if go_left {
517 (&node.left, &node.right)
518 } else {
519 (&node.right, &node.left)
520 };
521
522 if let Some(child) = first {
524 self.search_node(child, query, k, heap);
525 }
526
527 if heap.len() < k || {
529 let split_dist = (query[node.split_dim] - node.split_value).abs();
530 split_dist < heap.peek().unwrap().distance
531 } {
532 if let Some(child) = second {
533 self.search_node(child, query, k, heap);
534 }
535 }
536 }
537}
538
539pub struct VpTree {
541 root: Option<Box<VpNode>>,
542 data: Vec<(String, Vector)>,
543 config: TreeIndexConfig,
544}
545
546struct VpNode {
547 vantage_point: usize,
549 median_distance: f32,
551 inside: Option<Box<VpNode>>,
553 outside: Option<Box<VpNode>>,
555 indices: Vec<usize>,
557}
558
559impl VpTree {
560 pub fn new(config: TreeIndexConfig) -> Self {
561 Self {
562 root: None,
563 data: Vec::new(),
564 config,
565 }
566 }
567
568 pub fn build(&mut self) -> Result<()> {
569 if self.data.is_empty() {
570 return Ok(());
571 }
572
573 let indices: Vec<usize> = (0..self.data.len()).collect();
574 let mut rng = if let Some(seed) = self.config.random_seed {
575 Random::seed(seed)
576 } else {
577 Random::seed(42)
578 };
579
580 self.root = Some(Box::new(self.build_node(indices, &mut rng)?));
581 Ok(())
582 }
583
584 fn build_node<R: Rng>(&self, indices: Vec<usize>, rng: &mut R) -> Result<VpNode> {
585 self.build_node_safe(indices, rng, 0)
586 }
587
588 #[allow(deprecated)]
589 fn build_node_safe<R: Rng>(
590 &self,
591 mut indices: Vec<usize>,
592 rng: &mut R,
593 depth: usize,
594 ) -> Result<VpNode> {
595 if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= 3 {
599 return Ok(VpNode {
600 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
601 median_distance: 0.0,
602 inside: None,
603 outside: None,
604 indices,
605 });
606 }
607
608 let vp_idx = indices.len() - 1;
610 for i in (1..indices.len()).rev() {
612 let j = rng.gen_range(0..=i);
613 indices.swap(i, j);
614 }
615 let vantage_point = indices[vp_idx];
616 indices.truncate(vp_idx);
617
618 let vp_data = &self.data[vantage_point].1.as_f32();
620 let mut distances: Vec<(f32, usize)> = indices
621 .iter()
622 .map(|&idx| {
623 let point = &self.data[idx].1.as_f32();
624 let dist = self.config.distance_metric.distance(vp_data, point);
625 (dist, idx)
626 })
627 .collect();
628
629 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
630
631 let median_idx = distances.len() / 2;
632 let median_distance = distances[median_idx].0;
633
634 let inside_indices: Vec<usize> = distances[..median_idx]
635 .iter()
636 .map(|(_, idx)| *idx)
637 .collect();
638
639 let outside_indices: Vec<usize> = distances[median_idx..]
640 .iter()
641 .map(|(_, idx)| *idx)
642 .collect();
643
644 if inside_indices.is_empty() || outside_indices.is_empty() {
646 return Ok(VpNode {
647 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
648 median_distance: 0.0,
649 inside: None,
650 outside: None,
651 indices,
652 });
653 }
654
655 let inside = Some(Box::new(self.build_node_safe(
656 inside_indices,
657 rng,
658 depth + 1,
659 )?));
660 let outside = Some(Box::new(self.build_node_safe(
661 outside_indices,
662 rng,
663 depth + 1,
664 )?));
665
666 Ok(VpNode {
667 vantage_point,
668 median_distance,
669 inside,
670 outside,
671 indices: Vec::new(),
672 })
673 }
674
675 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
676 if self.root.is_none() {
677 return Vec::new();
678 }
679
680 let mut heap = BinaryHeap::new();
681 self.search_node(
682 self.root.as_ref().unwrap(),
683 query,
684 k,
685 &mut heap,
686 f32::INFINITY,
687 );
688
689 let mut results: Vec<(usize, f32)> =
690 heap.into_iter().map(|r| (r.index, r.distance)).collect();
691
692 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
693 results
694 }
695
696 fn search_node(
697 &self,
698 node: &VpNode,
699 query: &[f32],
700 k: usize,
701 heap: &mut BinaryHeap<SearchResult>,
702 tau: f32,
703 ) -> f32 {
704 let mut tau = tau;
705
706 if !node.indices.is_empty() {
707 for &idx in &node.indices {
709 let point = &self.data[idx].1.as_f32();
710 let dist = self.config.distance_metric.distance(query, point);
711
712 if dist < tau {
713 if heap.len() < k {
714 heap.push(SearchResult {
715 index: idx,
716 distance: dist,
717 });
718 } else if dist < heap.peek().unwrap().distance {
719 heap.pop();
720 heap.push(SearchResult {
721 index: idx,
722 distance: dist,
723 });
724 }
725
726 if heap.len() >= k {
727 tau = heap.peek().unwrap().distance;
728 }
729 }
730 }
731 return tau;
732 }
733
734 let vp_data = &self.data[node.vantage_point].1.as_f32();
736 let dist_to_vp = self.config.distance_metric.distance(query, vp_data);
737
738 if dist_to_vp < tau {
740 if heap.len() < k {
741 heap.push(SearchResult {
742 index: node.vantage_point,
743 distance: dist_to_vp,
744 });
745 } else if dist_to_vp < heap.peek().unwrap().distance {
746 heap.pop();
747 heap.push(SearchResult {
748 index: node.vantage_point,
749 distance: dist_to_vp,
750 });
751 }
752
753 if heap.len() >= k {
754 tau = heap.peek().unwrap().distance;
755 }
756 }
757
758 if dist_to_vp < node.median_distance {
760 if let Some(inside) = &node.inside {
762 tau = self.search_node(inside, query, k, heap, tau);
763 }
764
765 if dist_to_vp + tau >= node.median_distance {
767 if let Some(outside) = &node.outside {
768 tau = self.search_node(outside, query, k, heap, tau);
769 }
770 }
771 } else {
772 if let Some(outside) = &node.outside {
774 tau = self.search_node(outside, query, k, heap, tau);
775 }
776
777 if dist_to_vp - tau <= node.median_distance {
779 if let Some(inside) = &node.inside {
780 tau = self.search_node(inside, query, k, heap, tau);
781 }
782 }
783 }
784
785 tau
786 }
787}
788
789pub struct CoverTree {
791 root: Option<Box<CoverNode>>,
792 data: Vec<(String, Vector)>,
793 config: TreeIndexConfig,
794 base: f32,
795}
796
797struct CoverNode {
798 point: usize,
800 level: i32,
802 #[allow(clippy::vec_box)] children: Vec<Box<CoverNode>>,
805}
806
807impl CoverTree {
808 pub fn new(config: TreeIndexConfig) -> Self {
809 Self {
810 root: None,
811 data: Vec::new(),
812 config,
813 base: 2.0, }
815 }
816
817 pub fn build(&mut self) -> Result<()> {
818 if self.data.is_empty() {
819 return Ok(());
820 }
821
822 self.root = Some(Box::new(CoverNode {
824 point: 0,
825 level: self.get_level(0),
826 children: Vec::new(),
827 }));
828
829 for idx in 1..self.data.len() {
831 self.insert(idx)?;
832 }
833
834 Ok(())
835 }
836
837 fn get_level(&self, _point_idx: usize) -> i32 {
838 ((self.data.len() as f32).log2() as i32).max(0)
840 }
841
842 fn insert(&mut self, point_idx: usize) -> Result<()> {
843 let level = self.get_level(point_idx);
846 if let Some(root) = &mut self.root {
847 root.children.push(Box::new(CoverNode {
848 point: point_idx,
849 level,
850 children: Vec::new(),
851 }));
852 }
853 Ok(())
854 }
855
856 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
857 if self.root.is_none() {
858 return Vec::new();
859 }
860
861 let mut results = Vec::new();
862 self.search_node(self.root.as_ref().unwrap(), query, k, &mut results);
863
864 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
865 results.truncate(k);
866 results
867 }
868
869 #[allow(clippy::only_used_in_recursion)]
870 fn search_node(
871 &self,
872 node: &CoverNode,
873 query: &[f32],
874 k: usize,
875 results: &mut Vec<(usize, f32)>,
876 ) {
877 if results.len() >= k * 10 {
879 return;
880 }
881
882 let point_data = &self.data[node.point].1.as_f32();
883 let dist = self.config.distance_metric.distance(query, point_data);
884
885 results.push((node.point, dist));
886
887 for child in &node.children {
889 self.search_node(child, query, k, results);
890 }
891 }
892}
893
894pub struct RandomProjectionTree {
896 root: Option<Box<RpNode>>,
897 data: Vec<(String, Vector)>,
898 config: TreeIndexConfig,
899}
900
901struct RpNode {
902 projection: Vec<f32>,
904 threshold: f32,
906 left: Option<Box<RpNode>>,
908 right: Option<Box<RpNode>>,
910 indices: Vec<usize>,
912}
913
914impl RandomProjectionTree {
915 pub fn new(config: TreeIndexConfig) -> Self {
916 Self {
917 root: None,
918 data: Vec::new(),
919 config,
920 }
921 }
922
923 pub fn build(&mut self) -> Result<()> {
924 if self.data.is_empty() {
925 return Ok(());
926 }
927
928 let indices: Vec<usize> = (0..self.data.len()).collect();
929 let dimensions = self.data[0].1.dimensions;
930
931 let mut rng = if let Some(seed) = self.config.random_seed {
932 Random::seed(seed)
933 } else {
934 Random::seed(42)
935 };
936
937 self.root = Some(Box::new(self.build_node(indices, dimensions, &mut rng)?));
938 Ok(())
939 }
940
941 fn build_node<R: Rng>(
942 &self,
943 indices: Vec<usize>,
944 dimensions: usize,
945 rng: &mut R,
946 ) -> Result<RpNode> {
947 self.build_node_safe(indices, dimensions, rng, 0)
948 }
949
950 #[allow(deprecated)]
951 fn build_node_safe<R: Rng>(
952 &self,
953 indices: Vec<usize>,
954 dimensions: usize,
955 rng: &mut R,
956 depth: usize,
957 ) -> Result<RpNode> {
958 if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= 5 {
960 return Ok(RpNode {
961 projection: Vec::new(),
962 threshold: 0.0,
963 left: None,
964 right: None,
965 indices,
966 });
967 }
968
969 let projection: Vec<f32> = (0..dimensions).map(|_| rng.gen_range(-1.0..1.0)).collect();
971
972 let norm = (projection.iter().map(|&x| x * x).sum::<f32>()).sqrt();
974 let projection: Vec<f32> = if norm > 0.0 {
975 projection.iter().map(|&x| x / norm).collect()
976 } else {
977 projection
978 };
979
980 let mut projections: Vec<(f32, usize)> = indices
982 .iter()
983 .map(|&idx| {
984 let point = &self.data[idx].1.as_f32();
985 let proj_val = f32::dot(point, &projection);
986 (proj_val, idx)
987 })
988 .collect();
989
990 projections.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
991
992 let median_idx = projections.len() / 2;
994 let threshold = projections[median_idx].0;
995
996 let left_indices: Vec<usize> = projections[..median_idx]
997 .iter()
998 .map(|(_, idx)| *idx)
999 .collect();
1000
1001 let right_indices: Vec<usize> = projections[median_idx..]
1002 .iter()
1003 .map(|(_, idx)| *idx)
1004 .collect();
1005
1006 if left_indices.is_empty() || right_indices.is_empty() {
1008 return Ok(RpNode {
1009 projection: Vec::new(),
1010 threshold: 0.0,
1011 left: None,
1012 right: None,
1013 indices,
1014 });
1015 }
1016
1017 let left = Some(Box::new(self.build_node_safe(
1018 left_indices,
1019 dimensions,
1020 rng,
1021 depth + 1,
1022 )?));
1023 let right = Some(Box::new(self.build_node_safe(
1024 right_indices,
1025 dimensions,
1026 rng,
1027 depth + 1,
1028 )?));
1029
1030 Ok(RpNode {
1031 projection,
1032 threshold,
1033 left,
1034 right,
1035 indices: Vec::new(),
1036 })
1037 }
1038
1039 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1040 if self.root.is_none() {
1041 return Vec::new();
1042 }
1043
1044 let mut heap = BinaryHeap::new();
1045 self.search_node(self.root.as_ref().unwrap(), query, k, &mut heap);
1046
1047 let mut results: Vec<(usize, f32)> =
1048 heap.into_iter().map(|r| (r.index, r.distance)).collect();
1049
1050 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
1051 results
1052 }
1053
1054 fn search_node(
1055 &self,
1056 node: &RpNode,
1057 query: &[f32],
1058 k: usize,
1059 heap: &mut BinaryHeap<SearchResult>,
1060 ) {
1061 if !node.indices.is_empty() {
1062 for &idx in &node.indices {
1064 let point = &self.data[idx].1.as_f32();
1065 let dist = self.config.distance_metric.distance(query, point);
1066
1067 if heap.len() < k {
1068 heap.push(SearchResult {
1069 index: idx,
1070 distance: dist,
1071 });
1072 } else if dist < heap.peek().unwrap().distance {
1073 heap.pop();
1074 heap.push(SearchResult {
1075 index: idx,
1076 distance: dist,
1077 });
1078 }
1079 }
1080 return;
1081 }
1082
1083 let query_projection = f32::dot(query, &node.projection);
1085
1086 let go_left = query_projection <= node.threshold;
1088
1089 let (first, second) = if go_left {
1090 (&node.left, &node.right)
1091 } else {
1092 (&node.right, &node.left)
1093 };
1094
1095 if let Some(child) = first {
1097 self.search_node(child, query, k, heap);
1098 }
1099
1100 if let Some(child) = second {
1101 self.search_node(child, query, k, heap);
1102 }
1103 }
1104}
1105
1106pub struct TreeIndex {
1108 tree_type: TreeType,
1109 ball_tree: Option<BallTree>,
1110 kd_tree: Option<KdTree>,
1111 vp_tree: Option<VpTree>,
1112 cover_tree: Option<CoverTree>,
1113 rp_tree: Option<RandomProjectionTree>,
1114}
1115
1116impl TreeIndex {
1117 pub fn new(config: TreeIndexConfig) -> Self {
1118 let tree_type = config.tree_type;
1119
1120 let (ball_tree, kd_tree, vp_tree, cover_tree, rp_tree) = match tree_type {
1121 TreeType::BallTree => (Some(BallTree::new(config)), None, None, None, None),
1122 TreeType::KdTree => (None, Some(KdTree::new(config)), None, None, None),
1123 TreeType::VpTree => (None, None, Some(VpTree::new(config)), None, None),
1124 TreeType::CoverTree => (None, None, None, Some(CoverTree::new(config)), None),
1125 TreeType::RandomProjectionTree => (
1126 None,
1127 None,
1128 None,
1129 None,
1130 Some(RandomProjectionTree::new(config)),
1131 ),
1132 };
1133
1134 Self {
1135 tree_type,
1136 ball_tree,
1137 kd_tree,
1138 vp_tree,
1139 cover_tree,
1140 rp_tree,
1141 }
1142 }
1143
1144 fn build(&mut self) -> Result<()> {
1145 match self.tree_type {
1146 TreeType::BallTree => self.ball_tree.as_mut().unwrap().build(),
1147 TreeType::KdTree => self.kd_tree.as_mut().unwrap().build(),
1148 TreeType::VpTree => self.vp_tree.as_mut().unwrap().build(),
1149 TreeType::CoverTree => self.cover_tree.as_mut().unwrap().build(),
1150 TreeType::RandomProjectionTree => self.rp_tree.as_mut().unwrap().build(),
1151 }
1152 }
1153
1154 fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1155 match self.tree_type {
1156 TreeType::BallTree => self.ball_tree.as_ref().unwrap().search(query, k),
1157 TreeType::KdTree => self.kd_tree.as_ref().unwrap().search(query, k),
1158 TreeType::VpTree => self.vp_tree.as_ref().unwrap().search(query, k),
1159 TreeType::CoverTree => self.cover_tree.as_ref().unwrap().search(query, k),
1160 TreeType::RandomProjectionTree => self.rp_tree.as_ref().unwrap().search(query, k),
1161 }
1162 }
1163}
1164
1165impl VectorIndex for TreeIndex {
1166 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1167 let data = match self.tree_type {
1168 TreeType::BallTree => &mut self.ball_tree.as_mut().unwrap().data,
1169 TreeType::KdTree => &mut self.kd_tree.as_mut().unwrap().data,
1170 TreeType::VpTree => &mut self.vp_tree.as_mut().unwrap().data,
1171 TreeType::CoverTree => &mut self.cover_tree.as_mut().unwrap().data,
1172 TreeType::RandomProjectionTree => &mut self.rp_tree.as_mut().unwrap().data,
1173 };
1174
1175 data.push((uri, vector));
1176 Ok(())
1177 }
1178
1179 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1180 let query_f32 = query.as_f32();
1181 let results = self.search_internal(&query_f32, k);
1182
1183 let data = match self.tree_type {
1184 TreeType::BallTree => &self.ball_tree.as_ref().unwrap().data,
1185 TreeType::KdTree => &self.kd_tree.as_ref().unwrap().data,
1186 TreeType::VpTree => &self.vp_tree.as_ref().unwrap().data,
1187 TreeType::CoverTree => &self.cover_tree.as_ref().unwrap().data,
1188 TreeType::RandomProjectionTree => &self.rp_tree.as_ref().unwrap().data,
1189 };
1190
1191 Ok(results
1192 .into_iter()
1193 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1194 .collect())
1195 }
1196
1197 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1198 let query_f32 = query.as_f32();
1199 let all_results = self.search_internal(&query_f32, 1000); let data = match self.tree_type {
1202 TreeType::BallTree => &self.ball_tree.as_ref().unwrap().data,
1203 TreeType::KdTree => &self.kd_tree.as_ref().unwrap().data,
1204 TreeType::VpTree => &self.vp_tree.as_ref().unwrap().data,
1205 TreeType::CoverTree => &self.cover_tree.as_ref().unwrap().data,
1206 TreeType::RandomProjectionTree => &self.rp_tree.as_ref().unwrap().data,
1207 };
1208
1209 Ok(all_results
1210 .into_iter()
1211 .filter(|(_, dist)| *dist <= threshold)
1212 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1213 .collect())
1214 }
1215
1216 fn get_vector(&self, uri: &str) -> Option<&Vector> {
1217 let data = match self.tree_type {
1218 TreeType::BallTree => &self.ball_tree.as_ref().unwrap().data,
1219 TreeType::KdTree => &self.kd_tree.as_ref().unwrap().data,
1220 TreeType::VpTree => &self.vp_tree.as_ref().unwrap().data,
1221 TreeType::CoverTree => &self.cover_tree.as_ref().unwrap().data,
1222 TreeType::RandomProjectionTree => &self.rp_tree.as_ref().unwrap().data,
1223 };
1224
1225 data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1226 }
1227}
1228
/// Runs `f` and resolves to its result.
///
/// NOTE(review): despite the name, nothing is spawned. `f` executes inline, on
/// whichever thread first polls the returned future, so blocking work inside
/// `f` blocks that thread.
async fn spawn_task<F: FnOnce() -> T + Send + 'static, T: Send + 'static>(f: F) -> T {
    f()
}
1241
#[cfg(test)]
mod tests {
    use super::*;

    /// Creates an index for `config`, inserts `points` as `vec_0`, `vec_1`, ...
    /// in order, builds it and returns it.
    fn build_index(config: TreeIndexConfig, points: Vec<Vec<f32>>) -> TreeIndex {
        let mut index = TreeIndex::new(config);
        for (i, coords) in points.into_iter().enumerate() {
            index.insert(format!("vec_{i}"), Vector::new(coords)).unwrap();
        }
        index.build().unwrap();
        index
    }

    #[test]
    #[ignore = "Stack overflow issue - being investigated"]
    fn test_ball_tree() {
        let config = TreeIndexConfig {
            tree_type: TreeType::BallTree,
            max_leaf_size: 50,
            ..Default::default()
        };
        // Points on the line y = 2x: (0, 0), (1, 2), (2, 4).
        let points = (0..3).map(|i| vec![i as f32, (i * 2) as f32]).collect();
        let index = build_index(config, points);

        let results = index.search_knn(&Vector::new(vec![1.0, 2.0]), 2).unwrap();

        assert_eq!(results.len(), 2);
        // The query coincides with vec_1.
        assert_eq!(results[0].0, "vec_1");
    }

    #[test]
    #[ignore = "Stack overflow issue - being investigated"]
    fn test_kd_tree() {
        let config = TreeIndexConfig {
            tree_type: TreeType::KdTree,
            max_leaf_size: 50,
            ..Default::default()
        };
        // Points (0, 3), (1, 2), (2, 1).
        let points = (0..3).map(|i| vec![i as f32, (3 - i) as f32]).collect();
        let index = build_index(config, points);

        let results = index.search_knn(&Vector::new(vec![1.0, 2.0]), 2).unwrap();

        assert_eq!(results.len(), 2);
    }

    #[test]
    #[ignore = "Stack overflow issue - being investigated"]
    fn test_vp_tree() {
        let config = TreeIndexConfig {
            tree_type: TreeType::VpTree,
            random_seed: Some(42),
            max_leaf_size: 50,
            ..Default::default()
        };
        // Unit vectors at 0, 45 and 90 degrees.
        let points = (0..3)
            .map(|i| {
                let angle = (i as f32) * std::f32::consts::PI / 4.0;
                vec![angle.cos(), angle.sin()]
            })
            .collect();
        let index = build_index(config, points);

        let results = index.search_knn(&Vector::new(vec![1.0, 0.0]), 2).unwrap();

        assert_eq!(results.len(), 2);
    }
}