1use crate::{Vector, VectorIndex};
26use anyhow::Result;
27use oxirs_core::simd::SimdOps;
28use scirs2_core::random::{Random, Rng};
29use std::cmp::Ordering;
30use std::collections::BinaryHeap;
31
/// Configuration shared by every tree variant in this module.
#[derive(Debug, Clone)]
pub struct TreeIndexConfig {
    /// Which tree structure `TreeIndex::new` instantiates.
    pub tree_type: TreeType,
    /// Node size at or below which the ball, KD, VP and random-projection
    /// builders stop splitting and emit a leaf.
    pub max_leaf_size: usize,
    /// Seed for the VP-tree and random-projection builders; `None` falls back
    /// to a fixed seed of 42.
    pub random_seed: Option<u64>,
    /// NOTE(review): not read by any builder in this module — construction is
    /// currently sequential regardless of this flag.
    pub parallel_construction: bool,
    /// Distance used to rank search results (and, for the ball and VP trees,
    /// during construction as well).
    pub distance_metric: DistanceMetric,
}
46
47impl Default for TreeIndexConfig {
48 fn default() -> Self {
49 Self {
50 tree_type: TreeType::BallTree,
51 max_leaf_size: 16, random_seed: None,
53 parallel_construction: true,
54 distance_metric: DistanceMetric::Euclidean,
55 }
56 }
57}
58
/// Selects which tree structure a `TreeIndex` uses.
#[derive(Debug, Clone, Copy)]
pub enum TreeType {
    /// Binary tree of nested balls (centre + radius), split on the widest dimension.
    BallTree,
    /// Axis-aligned KD tree that cycles the split dimension with depth.
    KdTree,
    /// Vantage-point tree partitioning by distance to a randomly chosen point.
    VpTree,
    /// NOTE(review): currently a flat root-plus-children list rather than a
    /// true cover tree.
    CoverTree,
    /// Tree that splits at the median of projections onto random (normalised)
    /// directions.
    RandomProjectionTree,
}
68
/// Distance function used by the tree indices.
///
/// NOTE(review): the ball-, KD- and VP-tree pruning rules rely on
/// triangle-inequality style bounds; cosine distance (and Minkowski with
/// p < 1) does not generally satisfy them, so searches with those metrics may
/// miss true neighbours.
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
    /// Delegates to `f32::euclidean_distance` from `SimdOps`.
    Euclidean,
    /// Delegates to `f32::manhattan_distance` from `SimdOps`.
    Manhattan,
    /// Delegates to `f32::cosine_distance` from `SimdOps`.
    Cosine,
    /// `(Σ |aᵢ − bᵢ|^p)^(1/p)` with the given exponent `p`.
    Minkowski(f32),
}
77
78impl DistanceMetric {
79 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
80 match self {
81 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
82 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
83 DistanceMetric::Cosine => f32::cosine_distance(a, b),
84 DistanceMetric::Minkowski(p) => a
85 .iter()
86 .zip(b.iter())
87 .map(|(x, y)| (x - y).abs().powf(*p))
88 .sum::<f32>()
89 .powf(1.0 / p),
90 }
91 }
92}
93
/// A candidate neighbour: a position in a tree's `data` plus its distance to
/// the query.
///
/// Ordering and equality look only at `distance`. A `BinaryHeap<SearchResult>`
/// is therefore a max-heap whose top is the *worst* (farthest) of the current
/// candidates, which is exactly what the bounded k-NN searches need to evict.
#[derive(Debug, Clone)]
struct SearchResult {
    index: usize,
    distance: f32,
}

impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` and `Ord` always agree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // Compare the distances themselves. Delegating to `partial_cmp` here
        // would recurse forever, because `partial_cmp` delegates to `cmp`.
        // Incomparable (NaN) distances are treated as equal, matching the
        // sorting convention used throughout this module.
        self.distance
            .partial_cmp(&other.distance)
            .unwrap_or(Ordering::Equal)
    }
}
120
/// Ball tree: every node stores a centre and a radius enclosing all points
/// below it.
pub struct BallTree {
    /// Root node; `None` until `build` runs on non-empty data.
    root: Option<Box<BallNode>>,
    /// Indexed `(uri, vector)` pairs; node indices are positions in this vector.
    data: Vec<(String, Vector)>,
    /// Build and search settings.
    config: TreeIndexConfig,
}
127
/// Node of a ball tree.
#[derive(Clone)]
struct BallNode {
    /// Ball centre: centroid of a leaf's points, or the midpoint of the two
    /// child centres for an internal node.
    center: Vec<f32>,
    /// Upper bound on the distance from `center` to any point in the subtree.
    radius: f32,
    /// Left child; `Some` for internal nodes, `None` for leaves.
    left: Option<Box<BallNode>>,
    /// Right child; `Some` for internal nodes, `None` for leaves.
    right: Option<Box<BallNode>>,
    /// Positions in the tree's `data` held by a leaf; empty for internal nodes.
    indices: Vec<usize>,
}
141
142impl BallTree {
143 pub fn new(config: TreeIndexConfig) -> Self {
144 Self {
145 root: None,
146 data: Vec::new(),
147 config,
148 }
149 }
150
151 pub fn build(&mut self) -> Result<()> {
156 if self.data.is_empty() {
157 return Ok(());
158 }
159
160 let indices: Vec<usize> = (0..self.data.len()).collect();
161 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
162
163 self.root = Some(Box::new(self.build_node_safe(&points, indices, 0)?));
164 Ok(())
165 }
166
167 fn build_node_safe(
169 &self,
170 points: &[Vec<f32>],
171 indices: Vec<usize>,
172 depth: usize,
173 ) -> Result<BallNode> {
174 const MAX_DEPTH: usize = 20;
177
178 if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= MAX_DEPTH {
183 let center = self.compute_centroid(points, &indices);
184 let radius = self.compute_radius(points, &indices, ¢er);
185 return Ok(BallNode {
186 center,
187 radius,
188 left: None,
189 right: None,
190 indices,
191 });
192 }
193
194 let split_dim = self.find_split_dimension(points, &indices);
196 let (left_indices, right_indices) = self.partition_indices(points, &indices, split_dim);
197
198 if left_indices.is_empty() || right_indices.is_empty() {
200 let center = self.compute_centroid(points, &indices);
201 let radius = self.compute_radius(points, &indices, ¢er);
202 return Ok(BallNode {
203 center,
204 radius,
205 left: None,
206 right: None,
207 indices,
208 });
209 }
210
211 let left_node = self.build_node_safe(points, left_indices, depth + 1)?;
213 let right_node = self.build_node_safe(points, right_indices, depth + 1)?;
214
215 let all_centers = vec![left_node.center.clone(), right_node.center.clone()];
217 let center = self.compute_centroid_of_centers(&all_centers);
218 let radius = left_node.radius.max(right_node.radius)
219 + self
220 .config
221 .distance_metric
222 .distance(¢er, &left_node.center);
223
224 Ok(BallNode {
225 center,
226 radius,
227 left: Some(Box::new(left_node)),
228 right: Some(Box::new(right_node)),
229 indices: Vec::new(),
230 })
231 }
232
233 fn compute_centroid(&self, points: &[Vec<f32>], indices: &[usize]) -> Vec<f32> {
234 let dim = points[0].len();
235 let mut centroid = vec![0.0; dim];
236
237 for &idx in indices {
238 for (i, &val) in points[idx].iter().enumerate() {
239 centroid[i] += val;
240 }
241 }
242
243 let n = indices.len() as f32;
244 for val in &mut centroid {
245 *val /= n;
246 }
247
248 centroid
249 }
250
251 fn compute_radius(&self, points: &[Vec<f32>], indices: &[usize], center: &[f32]) -> f32 {
252 indices
253 .iter()
254 .map(|&idx| self.config.distance_metric.distance(&points[idx], center))
255 .fold(0.0f32, f32::max)
256 }
257
258 fn find_split_dimension(&self, points: &[Vec<f32>], indices: &[usize]) -> usize {
259 let dim = points[0].len();
260 let mut max_spread = 0.0;
261 let mut split_dim = 0;
262
263 #[allow(clippy::needless_range_loop)]
265 for d in 0..dim {
266 let values: Vec<f32> = indices.iter().map(|&idx| points[idx][d]).collect();
267
268 let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
269 let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
270 let spread = max_val - min_val;
271
272 if spread > max_spread {
273 max_spread = spread;
274 split_dim = d;
275 }
276 }
277
278 split_dim
279 }
280
281 fn partition_indices(
282 &self,
283 points: &[Vec<f32>],
284 indices: &[usize],
285 dim: usize,
286 ) -> (Vec<usize>, Vec<usize>) {
287 let mut values: Vec<(f32, usize)> =
288 indices.iter().map(|&idx| (points[idx][dim], idx)).collect();
289
290 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
291
292 let mid = values.len() / 2;
293 let left_indices: Vec<usize> = values[..mid].iter().map(|(_, idx)| *idx).collect();
294 let right_indices: Vec<usize> = values[mid..].iter().map(|(_, idx)| *idx).collect();
295
296 (left_indices, right_indices)
297 }
298
299 fn compute_centroid_of_centers(&self, centers: &[Vec<f32>]) -> Vec<f32> {
300 let dim = centers[0].len();
301 let mut centroid = vec![0.0; dim];
302
303 for center in centers {
304 for (i, &val) in center.iter().enumerate() {
305 centroid[i] += val;
306 }
307 }
308
309 let n = centers.len() as f32;
310 for val in &mut centroid {
311 *val /= n;
312 }
313
314 centroid
315 }
316
317 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
319 if self.root.is_none() {
320 return Vec::new();
321 }
322
323 let mut heap: BinaryHeap<SearchResult> = BinaryHeap::new();
324 let mut stack: Vec<&BallNode> = vec![self.root.as_ref().unwrap()];
325
326 while let Some(node) = stack.pop() {
327 let dist_to_center = self.config.distance_metric.distance(query, &node.center);
329
330 if heap.len() >= k {
331 let worst_dist = heap.peek().unwrap().distance;
332 if dist_to_center - node.radius > worst_dist {
333 continue; }
335 }
336
337 if node.indices.is_empty() {
338 if let (Some(left), Some(right)) = (&node.left, &node.right) {
340 let left_dist = self.config.distance_metric.distance(query, &left.center);
341 let right_dist = self.config.distance_metric.distance(query, &right.center);
342
343 if left_dist < right_dist {
345 stack.push(right);
346 stack.push(left);
347 } else {
348 stack.push(left);
349 stack.push(right);
350 }
351 }
352 } else {
353 for &idx in &node.indices {
355 let point = &self.data[idx].1.as_f32();
356 let dist = self.config.distance_metric.distance(query, point);
357
358 if heap.len() < k {
359 heap.push(SearchResult {
360 index: idx,
361 distance: dist,
362 });
363 } else if dist < heap.peek().unwrap().distance {
364 heap.pop();
365 heap.push(SearchResult {
366 index: idx,
367 distance: dist,
368 });
369 }
370 }
371 }
372 }
373
374 let mut results: Vec<(usize, f32)> =
375 heap.into_iter().map(|r| (r.index, r.distance)).collect();
376
377 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
378 results
379 }
380}
381
/// KD tree: binary space partition on one coordinate per level.
pub struct KdTree {
    /// Root node; `None` until `build` runs on non-empty data.
    root: Option<Box<KdNode>>,
    /// Indexed `(uri, vector)` pairs; node indices are positions in this vector.
    data: Vec<(String, Vector)>,
    /// Build and search settings.
    config: TreeIndexConfig,
}
388
/// Node of a KD tree.
struct KdNode {
    /// Coordinate compared at this node (0 for leaves).
    split_dim: usize,
    /// Median coordinate value. `left` holds the points sorted before the
    /// median (value <= split_value); `right` holds the median and everything
    /// after it (value >= split_value). 0.0 for leaves.
    split_value: f32,
    /// Subtree on the `<= split_value` side; `None` for leaves.
    left: Option<Box<KdNode>>,
    /// Subtree on the `>= split_value` side; `None` for leaves.
    right: Option<Box<KdNode>>,
    /// Positions in the tree's `data` held by a leaf; empty for internal nodes.
    indices: Vec<usize>,
}
401
402impl KdTree {
403 pub fn new(config: TreeIndexConfig) -> Self {
404 Self {
405 root: None,
406 data: Vec::new(),
407 config,
408 }
409 }
410
411 pub fn build(&mut self) -> Result<()> {
412 if self.data.is_empty() {
413 return Ok(());
414 }
415
416 let indices: Vec<usize> = (0..self.data.len()).collect();
417 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
418
419 self.root = Some(Box::new(self.build_node(&points, indices, 0)?));
420 Ok(())
421 }
422
423 fn build_node(&self, points: &[Vec<f32>], indices: Vec<usize>, depth: usize) -> Result<KdNode> {
424 let max_depth = if !self.data.is_empty() {
426 ((self.data.len() as f32).log2() * 2.0) as usize + 10
427 } else {
428 50
429 };
430
431 if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= max_depth {
432 return Ok(KdNode {
433 split_dim: 0,
434 split_value: 0.0,
435 left: None,
436 right: None,
437 indices,
438 });
439 }
440
441 let dimensions = points[0].len();
442 let split_dim = depth % dimensions;
443
444 let mut values: Vec<(f32, usize)> = indices
446 .iter()
447 .map(|&idx| (points[idx][split_dim], idx))
448 .collect();
449
450 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
451
452 let median_idx = values.len() / 2;
453 let split_value = values[median_idx].0;
454
455 let left_indices: Vec<usize> = values[..median_idx].iter().map(|(_, idx)| *idx).collect();
456
457 let right_indices: Vec<usize> = values[median_idx..].iter().map(|(_, idx)| *idx).collect();
458
459 if left_indices.is_empty() || right_indices.is_empty() {
461 return Ok(KdNode {
462 split_dim: 0,
463 split_value: 0.0,
464 left: None,
465 right: None,
466 indices,
467 });
468 }
469
470 let left = Some(Box::new(self.build_node(
471 points,
472 left_indices,
473 depth + 1,
474 )?));
475
476 let right = Some(Box::new(self.build_node(
477 points,
478 right_indices,
479 depth + 1,
480 )?));
481
482 Ok(KdNode {
483 split_dim,
484 split_value,
485 left,
486 right,
487 indices: Vec::new(),
488 })
489 }
490
491 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
492 if self.root.is_none() {
493 return Vec::new();
494 }
495
496 let mut heap = BinaryHeap::new();
497 self.search_node(self.root.as_ref().unwrap(), query, k, &mut heap);
498
499 let mut results: Vec<(usize, f32)> =
500 heap.into_iter().map(|r| (r.index, r.distance)).collect();
501
502 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
503 results
504 }
505
506 fn search_node(
507 &self,
508 node: &KdNode,
509 query: &[f32],
510 k: usize,
511 heap: &mut BinaryHeap<SearchResult>,
512 ) {
513 if !node.indices.is_empty() {
514 for &idx in &node.indices {
516 let point = &self.data[idx].1.as_f32();
517 let dist = self.config.distance_metric.distance(query, point);
518
519 if heap.len() < k {
520 heap.push(SearchResult {
521 index: idx,
522 distance: dist,
523 });
524 } else if dist < heap.peek().unwrap().distance {
525 heap.pop();
526 heap.push(SearchResult {
527 index: idx,
528 distance: dist,
529 });
530 }
531 }
532 return;
533 }
534
535 let go_left = query[node.split_dim] <= node.split_value;
537
538 let (first, second) = if go_left {
539 (&node.left, &node.right)
540 } else {
541 (&node.right, &node.left)
542 };
543
544 if let Some(child) = first {
546 self.search_node(child, query, k, heap);
547 }
548
549 if heap.len() < k || {
551 let split_dist = (query[node.split_dim] - node.split_value).abs();
552 split_dist < heap.peek().unwrap().distance
553 } {
554 if let Some(child) = second {
555 self.search_node(child, query, k, heap);
556 }
557 }
558 }
559}
560
/// Vantage-point tree: each internal node splits points by their distance to
/// a chosen vantage point.
pub struct VpTree {
    /// Root node; `None` until `build` runs on non-empty data.
    root: Option<Box<VpNode>>,
    /// Indexed `(uri, vector)` pairs; node indices are positions in this vector.
    data: Vec<(String, Vector)>,
    /// Build and search settings.
    config: TreeIndexConfig,
}
567
/// Node of a vantage-point tree.
struct VpNode {
    /// Position in the tree's `data` of the vantage point. For internal nodes
    /// this point is stored only here, not in either child. Leaves set it to
    /// their first index, and search does not use it there.
    vantage_point: usize,
    /// Median distance from the vantage point to the remaining points
    /// (0.0 for leaves).
    median_distance: f32,
    /// Points sorted before the median (distance <= median_distance).
    inside: Option<Box<VpNode>>,
    /// The median point and everything after it (distance >= median_distance).
    outside: Option<Box<VpNode>>,
    /// Positions in the tree's `data` held by a leaf; empty for internal nodes.
    indices: Vec<usize>,
}
580
581impl VpTree {
582 pub fn new(config: TreeIndexConfig) -> Self {
583 Self {
584 root: None,
585 data: Vec::new(),
586 config,
587 }
588 }
589
590 pub fn build(&mut self) -> Result<()> {
591 if self.data.is_empty() {
592 return Ok(());
593 }
594
595 let indices: Vec<usize> = (0..self.data.len()).collect();
596 let mut rng = if let Some(seed) = self.config.random_seed {
597 Random::seed(seed)
598 } else {
599 Random::seed(42)
600 };
601
602 self.root = Some(Box::new(self.build_node(indices, &mut rng)?));
603 Ok(())
604 }
605
606 fn build_node<R: Rng>(&self, indices: Vec<usize>, rng: &mut R) -> Result<VpNode> {
607 self.build_node_safe(indices, rng, 0)
608 }
609
610 #[allow(deprecated)]
611 fn build_node_safe<R: Rng>(
612 &self,
613 mut indices: Vec<usize>,
614 rng: &mut R,
615 depth: usize,
616 ) -> Result<VpNode> {
617 let max_depth = 30; if indices.len() <= self.config.max_leaf_size
625 || indices.len() <= 2 || depth >= max_depth
627 {
628 return Ok(VpNode {
629 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
630 median_distance: 0.0,
631 inside: None,
632 outside: None,
633 indices,
634 });
635 }
636
637 let vp_idx = if indices.len() > 1 {
639 rng.gen_range(0..indices.len())
640 } else {
641 0
642 };
643 let vantage_point = indices[vp_idx];
644 indices.remove(vp_idx);
645
646 let vp_data = &self.data[vantage_point].1.as_f32();
648 let mut distances: Vec<(f32, usize)> = indices
649 .iter()
650 .map(|&idx| {
651 let point = &self.data[idx].1.as_f32();
652 let dist = self.config.distance_metric.distance(vp_data, point);
653 (dist, idx)
654 })
655 .collect();
656
657 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
658
659 let median_idx = distances.len() / 2;
660 let median_distance = distances[median_idx].0;
661
662 let inside_indices: Vec<usize> = distances[..median_idx]
663 .iter()
664 .map(|(_, idx)| *idx)
665 .collect();
666
667 let outside_indices: Vec<usize> = distances[median_idx..]
668 .iter()
669 .map(|(_, idx)| *idx)
670 .collect();
671
672 if inside_indices.is_empty() || outside_indices.is_empty() {
674 return Ok(VpNode {
675 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
676 median_distance: 0.0,
677 inside: None,
678 outside: None,
679 indices,
680 });
681 }
682
683 let inside = Some(Box::new(self.build_node_safe(
684 inside_indices,
685 rng,
686 depth + 1,
687 )?));
688 let outside = Some(Box::new(self.build_node_safe(
689 outside_indices,
690 rng,
691 depth + 1,
692 )?));
693
694 Ok(VpNode {
695 vantage_point,
696 median_distance,
697 inside,
698 outside,
699 indices: Vec::new(),
700 })
701 }
702
703 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
704 if self.root.is_none() {
705 return Vec::new();
706 }
707
708 let mut heap = BinaryHeap::new();
709 self.search_node(
710 self.root.as_ref().unwrap(),
711 query,
712 k,
713 &mut heap,
714 f32::INFINITY,
715 );
716
717 let mut results: Vec<(usize, f32)> =
718 heap.into_iter().map(|r| (r.index, r.distance)).collect();
719
720 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
721 results
722 }
723
724 fn search_node(
725 &self,
726 node: &VpNode,
727 query: &[f32],
728 k: usize,
729 heap: &mut BinaryHeap<SearchResult>,
730 tau: f32,
731 ) -> f32 {
732 let mut tau = tau;
733
734 if !node.indices.is_empty() {
735 for &idx in &node.indices {
737 let point = &self.data[idx].1.as_f32();
738 let dist = self.config.distance_metric.distance(query, point);
739
740 if dist < tau {
741 if heap.len() < k {
742 heap.push(SearchResult {
743 index: idx,
744 distance: dist,
745 });
746 } else if dist < heap.peek().unwrap().distance {
747 heap.pop();
748 heap.push(SearchResult {
749 index: idx,
750 distance: dist,
751 });
752 }
753
754 if heap.len() >= k {
755 tau = heap.peek().unwrap().distance;
756 }
757 }
758 }
759 return tau;
760 }
761
762 let vp_data = &self.data[node.vantage_point].1.as_f32();
764 let dist_to_vp = self.config.distance_metric.distance(query, vp_data);
765
766 if dist_to_vp < tau {
768 if heap.len() < k {
769 heap.push(SearchResult {
770 index: node.vantage_point,
771 distance: dist_to_vp,
772 });
773 } else if dist_to_vp < heap.peek().unwrap().distance {
774 heap.pop();
775 heap.push(SearchResult {
776 index: node.vantage_point,
777 distance: dist_to_vp,
778 });
779 }
780
781 if heap.len() >= k {
782 tau = heap.peek().unwrap().distance;
783 }
784 }
785
786 if dist_to_vp < node.median_distance {
788 if let Some(inside) = &node.inside {
790 tau = self.search_node(inside, query, k, heap, tau);
791 }
792
793 if dist_to_vp + tau >= node.median_distance {
795 if let Some(outside) = &node.outside {
796 tau = self.search_node(outside, query, k, heap, tau);
797 }
798 }
799 } else {
800 if let Some(outside) = &node.outside {
802 tau = self.search_node(outside, query, k, heap, tau);
803 }
804
805 if dist_to_vp - tau <= node.median_distance {
807 if let Some(inside) = &node.inside {
808 tau = self.search_node(inside, query, k, heap, tau);
809 }
810 }
811 }
812
813 tau
814 }
815}
816
/// "Cover tree" index.
///
/// NOTE(review): the current implementation builds a flat structure. The
/// first point becomes the root and every other point becomes its direct
/// child (see `CoverTree::insert`).
pub struct CoverTree {
    /// Root node; `None` until `build` runs on non-empty data.
    root: Option<Box<CoverNode>>,
    /// Indexed `(uri, vector)` pairs; node points are positions in this vector.
    data: Vec<(String, Vector)>,
    /// Build and search settings.
    config: TreeIndexConfig,
    /// Cover-tree expansion base (set to 2.0); NOTE(review): not read in this file.
    base: f32,
}
824
/// Node of the cover-tree structure.
struct CoverNode {
    /// Position in the tree's `data` of this node's point.
    point: usize,
    /// Level assigned at insertion; NOTE(review): written but not read in this file.
    level: i32,
    /// Child nodes.
    #[allow(clippy::vec_box)]
    children: Vec<Box<CoverNode>>,
}
834
835impl CoverTree {
836 pub fn new(config: TreeIndexConfig) -> Self {
837 Self {
838 root: None,
839 data: Vec::new(),
840 config,
841 base: 2.0, }
843 }
844
845 pub fn build(&mut self) -> Result<()> {
846 if self.data.is_empty() {
847 return Ok(());
848 }
849
850 self.root = Some(Box::new(CoverNode {
852 point: 0,
853 level: self.get_level(0),
854 children: Vec::new(),
855 }));
856
857 for idx in 1..self.data.len() {
859 self.insert(idx)?;
860 }
861
862 Ok(())
863 }
864
865 fn get_level(&self, _point_idx: usize) -> i32 {
866 ((self.data.len() as f32).log2() as i32).max(0)
868 }
869
870 fn insert(&mut self, point_idx: usize) -> Result<()> {
871 let level = self.get_level(point_idx);
874 if let Some(root) = &mut self.root {
875 root.children.push(Box::new(CoverNode {
876 point: point_idx,
877 level,
878 children: Vec::new(),
879 }));
880 }
881 Ok(())
882 }
883
884 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
885 if self.root.is_none() {
886 return Vec::new();
887 }
888
889 let mut results = Vec::new();
890 self.search_node(self.root.as_ref().unwrap(), query, k, &mut results);
891
892 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
893 results.truncate(k);
894 results
895 }
896
897 #[allow(clippy::only_used_in_recursion)]
898 fn search_node(
899 &self,
900 node: &CoverNode,
901 query: &[f32],
902 k: usize,
903 results: &mut Vec<(usize, f32)>,
904 ) {
905 if results.len() >= k * 10 {
907 return;
908 }
909
910 let point_data = &self.data[node.point].1.as_f32();
911 let dist = self.config.distance_metric.distance(query, point_data);
912
913 results.push((node.point, dist));
914
915 for child in &node.children {
917 self.search_node(child, query, k, results);
918 }
919 }
920}
921
/// Random-projection tree: each internal node splits points at the median of
/// their projection onto a random direction.
pub struct RandomProjectionTree {
    /// Root node; `None` until `build` runs on non-empty data.
    root: Option<Box<RpNode>>,
    /// Indexed `(uri, vector)` pairs; node indices are positions in this vector.
    data: Vec<(String, Vector)>,
    /// Build and search settings.
    config: TreeIndexConfig,
}
928
/// Node of a random-projection tree.
struct RpNode {
    /// Projection direction, normalised to unit length unless its norm is 0.
    /// Empty for leaves.
    projection: Vec<f32>,
    /// Median projection value separating `left` from `right` (0.0 for leaves).
    threshold: f32,
    /// Points whose projection sorts before the median; `None` for leaves.
    left: Option<Box<RpNode>>,
    /// The median point and everything after it; `None` for leaves.
    right: Option<Box<RpNode>>,
    /// Positions in the tree's `data` held by a leaf; empty for internal nodes.
    indices: Vec<usize>,
}
941
942impl RandomProjectionTree {
943 pub fn new(config: TreeIndexConfig) -> Self {
944 Self {
945 root: None,
946 data: Vec::new(),
947 config,
948 }
949 }
950
951 pub fn build(&mut self) -> Result<()> {
952 if self.data.is_empty() {
953 return Ok(());
954 }
955
956 let indices: Vec<usize> = (0..self.data.len()).collect();
957 let dimensions = self.data[0].1.dimensions;
958
959 let mut rng = if let Some(seed) = self.config.random_seed {
960 Random::seed(seed)
961 } else {
962 Random::seed(42)
963 };
964
965 self.root = Some(Box::new(self.build_node(indices, dimensions, &mut rng)?));
966 Ok(())
967 }
968
969 fn build_node<R: Rng>(
970 &self,
971 indices: Vec<usize>,
972 dimensions: usize,
973 rng: &mut R,
974 ) -> Result<RpNode> {
975 self.build_node_safe(indices, dimensions, rng, 0)
976 }
977
978 #[allow(deprecated)]
979 fn build_node_safe<R: Rng>(
980 &self,
981 indices: Vec<usize>,
982 dimensions: usize,
983 rng: &mut R,
984 depth: usize,
985 ) -> Result<RpNode> {
986 if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= 5 {
988 return Ok(RpNode {
989 projection: Vec::new(),
990 threshold: 0.0,
991 left: None,
992 right: None,
993 indices,
994 });
995 }
996
997 let projection: Vec<f32> = (0..dimensions).map(|_| rng.gen_range(-1.0..1.0)).collect();
999
1000 let norm = (projection.iter().map(|&x| x * x).sum::<f32>()).sqrt();
1002 let projection: Vec<f32> = if norm > 0.0 {
1003 projection.iter().map(|&x| x / norm).collect()
1004 } else {
1005 projection
1006 };
1007
1008 let mut projections: Vec<(f32, usize)> = indices
1010 .iter()
1011 .map(|&idx| {
1012 let point = &self.data[idx].1.as_f32();
1013 let proj_val = f32::dot(point, &projection);
1014 (proj_val, idx)
1015 })
1016 .collect();
1017
1018 projections.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
1019
1020 let median_idx = projections.len() / 2;
1022 let threshold = projections[median_idx].0;
1023
1024 let left_indices: Vec<usize> = projections[..median_idx]
1025 .iter()
1026 .map(|(_, idx)| *idx)
1027 .collect();
1028
1029 let right_indices: Vec<usize> = projections[median_idx..]
1030 .iter()
1031 .map(|(_, idx)| *idx)
1032 .collect();
1033
1034 if left_indices.is_empty() || right_indices.is_empty() {
1036 return Ok(RpNode {
1037 projection: Vec::new(),
1038 threshold: 0.0,
1039 left: None,
1040 right: None,
1041 indices,
1042 });
1043 }
1044
1045 let left = Some(Box::new(self.build_node_safe(
1046 left_indices,
1047 dimensions,
1048 rng,
1049 depth + 1,
1050 )?));
1051 let right = Some(Box::new(self.build_node_safe(
1052 right_indices,
1053 dimensions,
1054 rng,
1055 depth + 1,
1056 )?));
1057
1058 Ok(RpNode {
1059 projection,
1060 threshold,
1061 left,
1062 right,
1063 indices: Vec::new(),
1064 })
1065 }
1066
1067 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1068 if self.root.is_none() {
1069 return Vec::new();
1070 }
1071
1072 let mut heap = BinaryHeap::new();
1073 self.search_node(self.root.as_ref().unwrap(), query, k, &mut heap);
1074
1075 let mut results: Vec<(usize, f32)> =
1076 heap.into_iter().map(|r| (r.index, r.distance)).collect();
1077
1078 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
1079 results
1080 }
1081
1082 fn search_node(
1083 &self,
1084 node: &RpNode,
1085 query: &[f32],
1086 k: usize,
1087 heap: &mut BinaryHeap<SearchResult>,
1088 ) {
1089 if !node.indices.is_empty() {
1090 for &idx in &node.indices {
1092 let point = &self.data[idx].1.as_f32();
1093 let dist = self.config.distance_metric.distance(query, point);
1094
1095 if heap.len() < k {
1096 heap.push(SearchResult {
1097 index: idx,
1098 distance: dist,
1099 });
1100 } else if dist < heap.peek().unwrap().distance {
1101 heap.pop();
1102 heap.push(SearchResult {
1103 index: idx,
1104 distance: dist,
1105 });
1106 }
1107 }
1108 return;
1109 }
1110
1111 let query_projection = f32::dot(query, &node.projection);
1113
1114 let go_left = query_projection <= node.threshold;
1116
1117 let (first, second) = if go_left {
1118 (&node.left, &node.right)
1119 } else {
1120 (&node.right, &node.left)
1121 };
1122
1123 if let Some(child) = first {
1125 self.search_node(child, query, k, heap);
1126 }
1127
1128 if let Some(child) = second {
1129 self.search_node(child, query, k, heap);
1130 }
1131 }
1132}
1133
/// Front-end over the individual tree types, implementing `VectorIndex`.
///
/// Only the field matching `tree_type` is populated by `TreeIndex::new`; the
/// others stay `None`.
pub struct TreeIndex {
    /// Selects which of the optional trees below is active.
    tree_type: TreeType,
    /// Populated when `tree_type` is `TreeType::BallTree`.
    ball_tree: Option<BallTree>,
    /// Populated when `tree_type` is `TreeType::KdTree`.
    kd_tree: Option<KdTree>,
    /// Populated when `tree_type` is `TreeType::VpTree`.
    vp_tree: Option<VpTree>,
    /// Populated when `tree_type` is `TreeType::CoverTree`.
    cover_tree: Option<CoverTree>,
    /// Populated when `tree_type` is `TreeType::RandomProjectionTree`.
    rp_tree: Option<RandomProjectionTree>,
}
1143
1144impl TreeIndex {
1145 pub fn new(config: TreeIndexConfig) -> Self {
1146 let tree_type = config.tree_type;
1147
1148 let (ball_tree, kd_tree, vp_tree, cover_tree, rp_tree) = match tree_type {
1149 TreeType::BallTree => (Some(BallTree::new(config)), None, None, None, None),
1150 TreeType::KdTree => (None, Some(KdTree::new(config)), None, None, None),
1151 TreeType::VpTree => (None, None, Some(VpTree::new(config)), None, None),
1152 TreeType::CoverTree => (None, None, None, Some(CoverTree::new(config)), None),
1153 TreeType::RandomProjectionTree => (
1154 None,
1155 None,
1156 None,
1157 None,
1158 Some(RandomProjectionTree::new(config)),
1159 ),
1160 };
1161
1162 Self {
1163 tree_type,
1164 ball_tree,
1165 kd_tree,
1166 vp_tree,
1167 cover_tree,
1168 rp_tree,
1169 }
1170 }
1171
1172 pub fn build(&mut self) -> Result<()> {
1173 match self.tree_type {
1174 TreeType::BallTree => self.ball_tree.as_mut().unwrap().build(),
1175 TreeType::KdTree => self.kd_tree.as_mut().unwrap().build(),
1176 TreeType::VpTree => self.vp_tree.as_mut().unwrap().build(),
1177 TreeType::CoverTree => self.cover_tree.as_mut().unwrap().build(),
1178 TreeType::RandomProjectionTree => self.rp_tree.as_mut().unwrap().build(),
1179 }
1180 }
1181
1182 fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1183 match self.tree_type {
1184 TreeType::BallTree => self.ball_tree.as_ref().unwrap().search(query, k),
1185 TreeType::KdTree => self.kd_tree.as_ref().unwrap().search(query, k),
1186 TreeType::VpTree => self.vp_tree.as_ref().unwrap().search(query, k),
1187 TreeType::CoverTree => self.cover_tree.as_ref().unwrap().search(query, k),
1188 TreeType::RandomProjectionTree => self.rp_tree.as_ref().unwrap().search(query, k),
1189 }
1190 }
1191}
1192
1193impl VectorIndex for TreeIndex {
1194 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1195 let data = match self.tree_type {
1196 TreeType::BallTree => &mut self.ball_tree.as_mut().unwrap().data,
1197 TreeType::KdTree => &mut self.kd_tree.as_mut().unwrap().data,
1198 TreeType::VpTree => &mut self.vp_tree.as_mut().unwrap().data,
1199 TreeType::CoverTree => &mut self.cover_tree.as_mut().unwrap().data,
1200 TreeType::RandomProjectionTree => &mut self.rp_tree.as_mut().unwrap().data,
1201 };
1202
1203 data.push((uri, vector));
1204 Ok(())
1205 }
1206
1207 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1208 let query_f32 = query.as_f32();
1209 let results = self.search_internal(&query_f32, k);
1210
1211 let data = match self.tree_type {
1212 TreeType::BallTree => &self.ball_tree.as_ref().unwrap().data,
1213 TreeType::KdTree => &self.kd_tree.as_ref().unwrap().data,
1214 TreeType::VpTree => &self.vp_tree.as_ref().unwrap().data,
1215 TreeType::CoverTree => &self.cover_tree.as_ref().unwrap().data,
1216 TreeType::RandomProjectionTree => &self.rp_tree.as_ref().unwrap().data,
1217 };
1218
1219 Ok(results
1220 .into_iter()
1221 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1222 .collect())
1223 }
1224
1225 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1226 let query_f32 = query.as_f32();
1227 let all_results = self.search_internal(&query_f32, 1000); let data = match self.tree_type {
1230 TreeType::BallTree => &self.ball_tree.as_ref().unwrap().data,
1231 TreeType::KdTree => &self.kd_tree.as_ref().unwrap().data,
1232 TreeType::VpTree => &self.vp_tree.as_ref().unwrap().data,
1233 TreeType::CoverTree => &self.cover_tree.as_ref().unwrap().data,
1234 TreeType::RandomProjectionTree => &self.rp_tree.as_ref().unwrap().data,
1235 };
1236
1237 Ok(all_results
1238 .into_iter()
1239 .filter(|(_, dist)| *dist <= threshold)
1240 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1241 .collect())
1242 }
1243
1244 fn get_vector(&self, uri: &str) -> Option<&Vector> {
1245 let data = match self.tree_type {
1246 TreeType::BallTree => &self.ball_tree.as_ref().unwrap().data,
1247 TreeType::KdTree => &self.kd_tree.as_ref().unwrap().data,
1248 TreeType::VpTree => &self.vp_tree.as_ref().unwrap().data,
1249 TreeType::CoverTree => &self.cover_tree.as_ref().unwrap().data,
1250 TreeType::RandomProjectionTree => &self.rp_tree.as_ref().unwrap().data,
1251 };
1252
1253 data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1254 }
1255}
1256
/// Runs `f` to completion and yields its result.
///
/// NOTE(review): despite the name, nothing is spawned. `f` runs inline on the
/// task that polls this future and blocks it until `f` returns. The
/// `Send + 'static` bounds presumably anticipate a real `spawn_blocking`-style
/// executor; confirm before relying on any concurrency.
async fn spawn_task<F, T>(f: F) -> T
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    let job = f;
    job()
}
1269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "Tree indices are experimental - see module documentation for alternatives"]
    fn test_ball_tree() {
        // 100 collinear points (i, 2i) with small leaves so internal nodes exist.
        let mut ball_tree = BallTree::new(TreeIndexConfig {
            tree_type: TreeType::BallTree,
            max_leaf_size: 10,
            ..Default::default()
        });
        ball_tree.data.extend(
            (0..100).map(|i| (format!("vec_{i}"), Vector::new(vec![i as f32, (i * 2) as f32]))),
        );

        ball_tree.build().unwrap();
        assert!(ball_tree.root.is_some());

        let results = ball_tree.search(&[50.0, 100.0], 5);
        assert!(results.len() <= 5);
        assert!(!results.is_empty());
    }

    #[test]
    #[ignore = "Investigating stack overflow with recursive tree construction"]
    fn test_kd_tree() {
        // NOTE(review): with 3 points and `max_leaf_size: 50` the build yields a
        // single leaf, so recursive *construction* cannot overflow here; if this
        // still fails, investigate the search path instead.
        let mut index = TreeIndex::new(TreeIndexConfig {
            tree_type: TreeType::KdTree,
            max_leaf_size: 50,
            ..Default::default()
        });

        (0..3)
            .try_for_each(|i| {
                index.insert(
                    format!("vec_{i}"),
                    Vector::new(vec![i as f32, (3 - i) as f32]),
                )
            })
            .unwrap();
        index.build().unwrap();

        let results = index.search_knn(&Vector::new(vec![1.0, 2.0]), 2).unwrap();
        assert_eq!(results.len(), 2);
    }

    #[test]
    #[ignore = "Investigating stack overflow with recursive tree construction"]
    fn test_vp_tree() {
        // Three unit vectors at 0, 45 and 90 degrees; a single leaf is expected.
        let mut index = TreeIndex::new(TreeIndexConfig {
            tree_type: TreeType::VpTree,
            random_seed: Some(42),
            max_leaf_size: 50,
            ..Default::default()
        });

        (0..3)
            .try_for_each(|i| {
                let angle = (i as f32) * std::f32::consts::PI / 4.0;
                index.insert(
                    format!("vec_{i}"),
                    Vector::new(vec![angle.cos(), angle.sin()]),
                )
            })
            .unwrap();
        index.build().unwrap();

        let results = index.search_knn(&Vector::new(vec![1.0, 0.0]), 2).unwrap();
        assert_eq!(results.len(), 2);
    }
}
1355}