1use crate::{Vector, VectorIndex};
12use anyhow::Result;
13use oxirs_core::simd::SimdOps;
14use scirs2_core::random::{Random, Rng};
15use std::cmp::Ordering;
16use std::collections::BinaryHeap;
17
/// Settings shared by every tree-based index in this module.
#[derive(Debug, Clone)]
pub struct TreeIndexConfig {
    /// Which tree structure a `TreeIndex` builds and queries.
    pub tree_type: TreeType,
    /// Nodes holding at most this many points are kept as leaves instead of being split.
    pub max_leaf_size: usize,
    /// Seed for the randomised trees (VP tree and random-projection tree). When `None`
    /// those trees fall back to the fixed seed `42`, so builds stay reproducible.
    pub random_seed: Option<u64>,
    /// Request for parallel tree construction.
    /// NOTE(review): no build routine in this module reads this flag yet.
    pub parallel_construction: bool,
    /// Distance function used both while building and while searching.
    pub distance_metric: DistanceMetric,
}

impl Default for TreeIndexConfig {
    /// Ball tree with Euclidean distance, up to 16 points per leaf, no explicit seed,
    /// and `parallel_construction` enabled.
    fn default() -> Self {
        let tree_type = TreeType::BallTree;
        let distance_metric = DistanceMetric::Euclidean;
        Self {
            tree_type,
            distance_metric,
            max_leaf_size: 16,
            random_seed: None,
            parallel_construction: true,
        }
    }
}

/// The tree structures a `TreeIndex` can be backed by.
#[derive(Debug, Clone, Copy)]
pub enum TreeType {
    /// Nested bounding balls.
    BallTree,
    /// Axis-aligned median splits, cycling through dimensions by depth.
    KdTree,
    /// Vantage-point tree: splits on distance to a randomly chosen point.
    VpTree,
    /// Cover tree (simplified implementation in this module).
    CoverTree,
    /// Splits on projections onto random directions.
    RandomProjectionTree,
}

/// Distance functions available to the tree indices.
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
    /// Euclidean distance, delegated to `SimdOps::euclidean_distance`.
    Euclidean,
    /// Manhattan distance, delegated to `SimdOps::manhattan_distance`.
    Manhattan,
    /// Cosine distance, delegated to `SimdOps::cosine_distance`.
    Cosine,
    /// Lp distance `(Σ |aᵢ − bᵢ|^p)^(1/p)` with the given exponent `p`.
    Minkowski(f32),
}
63
64impl DistanceMetric {
65 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
66 match self {
67 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
68 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
69 DistanceMetric::Cosine => f32::cosine_distance(a, b),
70 DistanceMetric::Minkowski(p) => a
71 .iter()
72 .zip(b.iter())
73 .map(|(x, y)| (x - y).abs().powf(*p))
74 .sum::<f32>()
75 .powf(1.0 / p),
76 }
77 }
78}
79
/// A k-NN candidate: position in a tree's `data` plus its distance to the query.
///
/// Ordering is by `distance` only, so a `BinaryHeap<SearchResult>` is a max-heap whose
/// top is the *worst* candidate kept so far, which is what bounded k-NN search needs.
#[derive(Debug, Clone)]
struct SearchResult {
    index: usize,
    distance: f32,
}

impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` and `Ord` can never disagree.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // Compare the distances directly. The previous version delegated to
        // `partial_cmp`, which delegated straight back to `cmp`, recursing until the
        // stack overflowed. `total_cmp` is a genuine total order, even with NaN.
        self.distance.total_cmp(&other.distance)
    }
}
106
107pub struct BallTree {
109 root: Option<Box<BallNode>>,
110 data: Vec<(String, Vector)>,
111 config: TreeIndexConfig,
112}
113
114struct BallNode {
115 center: Vec<f32>,
117 radius: f32,
119 left: Option<Box<BallNode>>,
121 right: Option<Box<BallNode>>,
123 indices: Vec<usize>,
125}
126
127impl BallTree {
128 pub fn new(config: TreeIndexConfig) -> Self {
129 Self {
130 root: None,
131 data: Vec::new(),
132 config,
133 }
134 }
135
136 pub fn build(&mut self) -> Result<()> {
138 if self.data.is_empty() {
139 return Ok(());
140 }
141
142 let indices: Vec<usize> = (0..self.data.len()).collect();
143 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
144
145 self.root = Some(Box::new(self.build_node_safe(&points, indices, 0)?));
146 Ok(())
147 }
148
149 fn build_node_safe(
150 &self,
151 points: &[Vec<f32>],
152 indices: Vec<usize>,
153 depth: usize,
154 ) -> Result<BallNode> {
155 if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= 3 {
157 let center = self.compute_centroid(points, &indices);
159 let radius = self.compute_radius(points, &indices, ¢er);
160
161 return Ok(BallNode {
162 center,
163 radius,
164 left: None,
165 right: None,
166 indices,
167 });
168 }
169
170 let split_dim = self.find_split_dimension(points, &indices);
172
173 let (left_indices, right_indices) = self.partition_indices(points, &indices, split_dim);
175
176 if left_indices.is_empty() || right_indices.is_empty() {
178 let center = self.compute_centroid(points, &indices);
180 let radius = self.compute_radius(points, &indices, ¢er);
181 return Ok(BallNode {
182 center,
183 radius,
184 left: None,
185 right: None,
186 indices,
187 });
188 }
189
190 let left_node = self.build_node_safe(points, left_indices, depth + 1)?;
192 let right_node = self.build_node_safe(points, right_indices, depth + 1)?;
193
194 let all_centers = vec![left_node.center.clone(), right_node.center.clone()];
196 let center = self.compute_centroid_of_centers(&all_centers);
197 let radius = left_node.radius.max(right_node.radius)
198 + self
199 .config
200 .distance_metric
201 .distance(¢er, &left_node.center);
202
203 Ok(BallNode {
204 center,
205 radius,
206 left: Some(Box::new(left_node)),
207 right: Some(Box::new(right_node)),
208 indices: Vec::new(),
209 })
210 }
211
212 fn compute_centroid(&self, points: &[Vec<f32>], indices: &[usize]) -> Vec<f32> {
213 let dim = points[0].len();
214 let mut centroid = vec![0.0; dim];
215
216 for &idx in indices {
217 for (i, &val) in points[idx].iter().enumerate() {
218 centroid[i] += val;
219 }
220 }
221
222 let n = indices.len() as f32;
223 for val in &mut centroid {
224 *val /= n;
225 }
226
227 centroid
228 }
229
230 fn compute_radius(&self, points: &[Vec<f32>], indices: &[usize], center: &[f32]) -> f32 {
231 indices
232 .iter()
233 .map(|&idx| self.config.distance_metric.distance(&points[idx], center))
234 .fold(0.0f32, f32::max)
235 }
236
237 fn find_split_dimension(&self, points: &[Vec<f32>], indices: &[usize]) -> usize {
238 let dim = points[0].len();
239 let mut max_spread = 0.0;
240 let mut split_dim = 0;
241
242 for d in 0..dim {
243 let values: Vec<f32> = indices.iter().map(|&idx| points[idx][d]).collect();
244
245 let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
246 let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
247 let spread = max_val - min_val;
248
249 if spread > max_spread {
250 max_spread = spread;
251 split_dim = d;
252 }
253 }
254
255 split_dim
256 }
257
258 fn partition_indices(
259 &self,
260 points: &[Vec<f32>],
261 indices: &[usize],
262 dim: usize,
263 ) -> (Vec<usize>, Vec<usize>) {
264 let mut values: Vec<(f32, usize)> =
265 indices.iter().map(|&idx| (points[idx][dim], idx)).collect();
266
267 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
268
269 let mid = values.len() / 2;
270 let left_indices: Vec<usize> = values[..mid].iter().map(|(_, idx)| *idx).collect();
271 let right_indices: Vec<usize> = values[mid..].iter().map(|(_, idx)| *idx).collect();
272
273 (left_indices, right_indices)
274 }
275
276 fn compute_centroid_of_centers(&self, centers: &[Vec<f32>]) -> Vec<f32> {
277 let dim = centers[0].len();
278 let mut centroid = vec![0.0; dim];
279
280 for center in centers {
281 for (i, &val) in center.iter().enumerate() {
282 centroid[i] += val;
283 }
284 }
285
286 let n = centers.len() as f32;
287 for val in &mut centroid {
288 *val /= n;
289 }
290
291 centroid
292 }
293
294 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
296 if self.root.is_none() {
297 return Vec::new();
298 }
299
300 let mut heap = BinaryHeap::new();
301 self.search_node(self.root.as_ref().unwrap(), query, k, &mut heap);
302
303 let mut results: Vec<(usize, f32)> =
304 heap.into_iter().map(|r| (r.index, r.distance)).collect();
305
306 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
307 results
308 }
309
310 fn search_node(
311 &self,
312 node: &BallNode,
313 query: &[f32],
314 k: usize,
315 heap: &mut BinaryHeap<SearchResult>,
316 ) {
317 let dist_to_center = self.config.distance_metric.distance(query, &node.center);
319
320 if heap.len() >= k {
321 let worst_dist = heap.peek().unwrap().distance;
322 if dist_to_center - node.radius > worst_dist {
323 return; }
325 }
326
327 if node.indices.is_empty() {
328 if let (Some(left), Some(right)) = (&node.left, &node.right) {
330 let left_dist = self.config.distance_metric.distance(query, &left.center);
331 let right_dist = self.config.distance_metric.distance(query, &right.center);
332
333 if left_dist < right_dist {
334 self.search_node(left, query, k, heap);
335 self.search_node(right, query, k, heap);
336 } else {
337 self.search_node(right, query, k, heap);
338 self.search_node(left, query, k, heap);
339 }
340 }
341 } else {
342 for &idx in &node.indices {
344 let point = &self.data[idx].1.as_f32();
345 let dist = self.config.distance_metric.distance(query, point);
346
347 if heap.len() < k {
348 heap.push(SearchResult {
349 index: idx,
350 distance: dist,
351 });
352 } else if dist < heap.peek().unwrap().distance {
353 heap.pop();
354 heap.push(SearchResult {
355 index: idx,
356 distance: dist,
357 });
358 }
359 }
360 }
361 }
362}
363
364pub struct KdTree {
366 root: Option<Box<KdNode>>,
367 data: Vec<(String, Vector)>,
368 config: TreeIndexConfig,
369}
370
371struct KdNode {
372 split_dim: usize,
374 split_value: f32,
376 left: Option<Box<KdNode>>,
378 right: Option<Box<KdNode>>,
380 indices: Vec<usize>,
382}
383
384impl KdTree {
385 pub fn new(config: TreeIndexConfig) -> Self {
386 Self {
387 root: None,
388 data: Vec::new(),
389 config,
390 }
391 }
392
393 pub fn build(&mut self) -> Result<()> {
394 if self.data.is_empty() {
395 return Ok(());
396 }
397
398 let indices: Vec<usize> = (0..self.data.len()).collect();
399 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
400
401 self.root = Some(Box::new(self.build_node(&points, indices, 0)?));
402 Ok(())
403 }
404
405 fn build_node(&self, points: &[Vec<f32>], indices: Vec<usize>, depth: usize) -> Result<KdNode> {
406 if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= 3 {
408 return Ok(KdNode {
409 split_dim: 0,
410 split_value: 0.0,
411 left: None,
412 right: None,
413 indices,
414 });
415 }
416
417 let dimensions = points[0].len();
418 let split_dim = depth % dimensions;
419
420 let mut values: Vec<(f32, usize)> = indices
422 .iter()
423 .map(|&idx| (points[idx][split_dim], idx))
424 .collect();
425
426 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
427
428 let median_idx = values.len() / 2;
429 let split_value = values[median_idx].0;
430
431 let left_indices: Vec<usize> = values[..median_idx].iter().map(|(_, idx)| *idx).collect();
432
433 let right_indices: Vec<usize> = values[median_idx..].iter().map(|(_, idx)| *idx).collect();
434
435 if left_indices.is_empty() || right_indices.is_empty() {
437 return Ok(KdNode {
438 split_dim: 0,
439 split_value: 0.0,
440 left: None,
441 right: None,
442 indices,
443 });
444 }
445
446 let left = Some(Box::new(self.build_node(
447 points,
448 left_indices,
449 depth + 1,
450 )?));
451
452 let right = Some(Box::new(self.build_node(
453 points,
454 right_indices,
455 depth + 1,
456 )?));
457
458 Ok(KdNode {
459 split_dim,
460 split_value,
461 left,
462 right,
463 indices: Vec::new(),
464 })
465 }
466
467 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
468 if self.root.is_none() {
469 return Vec::new();
470 }
471
472 let mut heap = BinaryHeap::new();
473 self.search_node(self.root.as_ref().unwrap(), query, k, &mut heap);
474
475 let mut results: Vec<(usize, f32)> =
476 heap.into_iter().map(|r| (r.index, r.distance)).collect();
477
478 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
479 results
480 }
481
482 fn search_node(
483 &self,
484 node: &KdNode,
485 query: &[f32],
486 k: usize,
487 heap: &mut BinaryHeap<SearchResult>,
488 ) {
489 if !node.indices.is_empty() {
490 for &idx in &node.indices {
492 let point = &self.data[idx].1.as_f32();
493 let dist = self.config.distance_metric.distance(query, point);
494
495 if heap.len() < k {
496 heap.push(SearchResult {
497 index: idx,
498 distance: dist,
499 });
500 } else if dist < heap.peek().unwrap().distance {
501 heap.pop();
502 heap.push(SearchResult {
503 index: idx,
504 distance: dist,
505 });
506 }
507 }
508 return;
509 }
510
511 let go_left = query[node.split_dim] <= node.split_value;
513
514 let (first, second) = if go_left {
515 (&node.left, &node.right)
516 } else {
517 (&node.right, &node.left)
518 };
519
520 if let Some(child) = first {
522 self.search_node(child, query, k, heap);
523 }
524
525 if heap.len() < k || {
527 let split_dist = (query[node.split_dim] - node.split_value).abs();
528 split_dist < heap.peek().unwrap().distance
529 } {
530 if let Some(child) = second {
531 self.search_node(child, query, k, heap);
532 }
533 }
534 }
535}
536
537pub struct VpTree {
539 root: Option<Box<VpNode>>,
540 data: Vec<(String, Vector)>,
541 config: TreeIndexConfig,
542}
543
544struct VpNode {
545 vantage_point: usize,
547 median_distance: f32,
549 inside: Option<Box<VpNode>>,
551 outside: Option<Box<VpNode>>,
553 indices: Vec<usize>,
555}
556
557impl VpTree {
558 pub fn new(config: TreeIndexConfig) -> Self {
559 Self {
560 root: None,
561 data: Vec::new(),
562 config,
563 }
564 }
565
566 pub fn build(&mut self) -> Result<()> {
567 if self.data.is_empty() {
568 return Ok(());
569 }
570
571 let indices: Vec<usize> = (0..self.data.len()).collect();
572 let mut rng = if let Some(seed) = self.config.random_seed {
573 Random::seed(seed)
574 } else {
575 Random::seed(42)
576 };
577
578 self.root = Some(Box::new(self.build_node(indices, &mut rng)?));
579 Ok(())
580 }
581
582 fn build_node<R: Rng>(&self, indices: Vec<usize>, rng: &mut R) -> Result<VpNode> {
583 self.build_node_safe(indices, rng, 0)
584 }
585
586 #[allow(deprecated)]
587 fn build_node_safe<R: Rng>(
588 &self,
589 mut indices: Vec<usize>,
590 rng: &mut R,
591 depth: usize,
592 ) -> Result<VpNode> {
593 if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= 3 {
597 return Ok(VpNode {
598 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
599 median_distance: 0.0,
600 inside: None,
601 outside: None,
602 indices,
603 });
604 }
605
606 let vp_idx = indices.len() - 1;
608 for i in (1..indices.len()).rev() {
610 let j = rng.gen_range(0..=i);
611 indices.swap(i, j);
612 }
613 let vantage_point = indices[vp_idx];
614 indices.truncate(vp_idx);
615
616 let vp_data = &self.data[vantage_point].1.as_f32();
618 let mut distances: Vec<(f32, usize)> = indices
619 .iter()
620 .map(|&idx| {
621 let point = &self.data[idx].1.as_f32();
622 let dist = self.config.distance_metric.distance(vp_data, point);
623 (dist, idx)
624 })
625 .collect();
626
627 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
628
629 let median_idx = distances.len() / 2;
630 let median_distance = distances[median_idx].0;
631
632 let inside_indices: Vec<usize> = distances[..median_idx]
633 .iter()
634 .map(|(_, idx)| *idx)
635 .collect();
636
637 let outside_indices: Vec<usize> = distances[median_idx..]
638 .iter()
639 .map(|(_, idx)| *idx)
640 .collect();
641
642 if inside_indices.is_empty() || outside_indices.is_empty() {
644 return Ok(VpNode {
645 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
646 median_distance: 0.0,
647 inside: None,
648 outside: None,
649 indices,
650 });
651 }
652
653 let inside = Some(Box::new(self.build_node_safe(
654 inside_indices,
655 rng,
656 depth + 1,
657 )?));
658 let outside = Some(Box::new(self.build_node_safe(
659 outside_indices,
660 rng,
661 depth + 1,
662 )?));
663
664 Ok(VpNode {
665 vantage_point,
666 median_distance,
667 inside,
668 outside,
669 indices: Vec::new(),
670 })
671 }
672
673 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
674 if self.root.is_none() {
675 return Vec::new();
676 }
677
678 let mut heap = BinaryHeap::new();
679 self.search_node(
680 self.root.as_ref().unwrap(),
681 query,
682 k,
683 &mut heap,
684 f32::INFINITY,
685 );
686
687 let mut results: Vec<(usize, f32)> =
688 heap.into_iter().map(|r| (r.index, r.distance)).collect();
689
690 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
691 results
692 }
693
694 fn search_node(
695 &self,
696 node: &VpNode,
697 query: &[f32],
698 k: usize,
699 heap: &mut BinaryHeap<SearchResult>,
700 tau: f32,
701 ) -> f32 {
702 let mut tau = tau;
703
704 if !node.indices.is_empty() {
705 for &idx in &node.indices {
707 let point = &self.data[idx].1.as_f32();
708 let dist = self.config.distance_metric.distance(query, point);
709
710 if dist < tau {
711 if heap.len() < k {
712 heap.push(SearchResult {
713 index: idx,
714 distance: dist,
715 });
716 } else if dist < heap.peek().unwrap().distance {
717 heap.pop();
718 heap.push(SearchResult {
719 index: idx,
720 distance: dist,
721 });
722 }
723
724 if heap.len() >= k {
725 tau = heap.peek().unwrap().distance;
726 }
727 }
728 }
729 return tau;
730 }
731
732 let vp_data = &self.data[node.vantage_point].1.as_f32();
734 let dist_to_vp = self.config.distance_metric.distance(query, vp_data);
735
736 if dist_to_vp < tau {
738 if heap.len() < k {
739 heap.push(SearchResult {
740 index: node.vantage_point,
741 distance: dist_to_vp,
742 });
743 } else if dist_to_vp < heap.peek().unwrap().distance {
744 heap.pop();
745 heap.push(SearchResult {
746 index: node.vantage_point,
747 distance: dist_to_vp,
748 });
749 }
750
751 if heap.len() >= k {
752 tau = heap.peek().unwrap().distance;
753 }
754 }
755
756 if dist_to_vp < node.median_distance {
758 if let Some(inside) = &node.inside {
760 tau = self.search_node(inside, query, k, heap, tau);
761 }
762
763 if dist_to_vp + tau >= node.median_distance {
765 if let Some(outside) = &node.outside {
766 tau = self.search_node(outside, query, k, heap, tau);
767 }
768 }
769 } else {
770 if let Some(outside) = &node.outside {
772 tau = self.search_node(outside, query, k, heap, tau);
773 }
774
775 if dist_to_vp - tau <= node.median_distance {
777 if let Some(inside) = &node.inside {
778 tau = self.search_node(inside, query, k, heap, tau);
779 }
780 }
781 }
782
783 tau
784 }
785}
786
787pub struct CoverTree {
789 root: Option<Box<CoverNode>>,
790 data: Vec<(String, Vector)>,
791 config: TreeIndexConfig,
792 base: f32,
793}
794
795struct CoverNode {
796 point: usize,
798 level: i32,
800 #[allow(clippy::vec_box)] children: Vec<Box<CoverNode>>,
803}
804
805impl CoverTree {
806 pub fn new(config: TreeIndexConfig) -> Self {
807 Self {
808 root: None,
809 data: Vec::new(),
810 config,
811 base: 2.0, }
813 }
814
815 pub fn build(&mut self) -> Result<()> {
816 if self.data.is_empty() {
817 return Ok(());
818 }
819
820 self.root = Some(Box::new(CoverNode {
822 point: 0,
823 level: self.get_level(0),
824 children: Vec::new(),
825 }));
826
827 for idx in 1..self.data.len() {
829 self.insert(idx)?;
830 }
831
832 Ok(())
833 }
834
835 fn get_level(&self, _point_idx: usize) -> i32 {
836 ((self.data.len() as f32).log2() as i32).max(0)
838 }
839
840 fn insert(&mut self, point_idx: usize) -> Result<()> {
841 let level = self.get_level(point_idx);
844 if let Some(root) = &mut self.root {
845 root.children.push(Box::new(CoverNode {
846 point: point_idx,
847 level,
848 children: Vec::new(),
849 }));
850 }
851 Ok(())
852 }
853
854 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
855 if self.root.is_none() {
856 return Vec::new();
857 }
858
859 let mut results = Vec::new();
860 self.search_node(self.root.as_ref().unwrap(), query, k, &mut results);
861
862 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
863 results.truncate(k);
864 results
865 }
866
867 #[allow(clippy::only_used_in_recursion)]
868 fn search_node(
869 &self,
870 node: &CoverNode,
871 query: &[f32],
872 k: usize,
873 results: &mut Vec<(usize, f32)>,
874 ) {
875 if results.len() >= k * 10 {
877 return;
878 }
879
880 let point_data = &self.data[node.point].1.as_f32();
881 let dist = self.config.distance_metric.distance(query, point_data);
882
883 results.push((node.point, dist));
884
885 for child in &node.children {
887 self.search_node(child, query, k, results);
888 }
889 }
890}
891
892pub struct RandomProjectionTree {
894 root: Option<Box<RpNode>>,
895 data: Vec<(String, Vector)>,
896 config: TreeIndexConfig,
897}
898
899struct RpNode {
900 projection: Vec<f32>,
902 threshold: f32,
904 left: Option<Box<RpNode>>,
906 right: Option<Box<RpNode>>,
908 indices: Vec<usize>,
910}
911
912impl RandomProjectionTree {
913 pub fn new(config: TreeIndexConfig) -> Self {
914 Self {
915 root: None,
916 data: Vec::new(),
917 config,
918 }
919 }
920
921 pub fn build(&mut self) -> Result<()> {
922 if self.data.is_empty() {
923 return Ok(());
924 }
925
926 let indices: Vec<usize> = (0..self.data.len()).collect();
927 let dimensions = self.data[0].1.dimensions;
928
929 let mut rng = if let Some(seed) = self.config.random_seed {
930 Random::seed(seed)
931 } else {
932 Random::seed(42)
933 };
934
935 self.root = Some(Box::new(self.build_node(indices, dimensions, &mut rng)?));
936 Ok(())
937 }
938
939 fn build_node<R: Rng>(
940 &self,
941 indices: Vec<usize>,
942 dimensions: usize,
943 rng: &mut R,
944 ) -> Result<RpNode> {
945 self.build_node_safe(indices, dimensions, rng, 0)
946 }
947
948 #[allow(deprecated)]
949 fn build_node_safe<R: Rng>(
950 &self,
951 indices: Vec<usize>,
952 dimensions: usize,
953 rng: &mut R,
954 depth: usize,
955 ) -> Result<RpNode> {
956 if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= 5 {
958 return Ok(RpNode {
959 projection: Vec::new(),
960 threshold: 0.0,
961 left: None,
962 right: None,
963 indices,
964 });
965 }
966
967 let projection: Vec<f32> = (0..dimensions).map(|_| rng.gen_range(-1.0..1.0)).collect();
969
970 let norm = (projection.iter().map(|&x| x * x).sum::<f32>()).sqrt();
972 let projection: Vec<f32> = if norm > 0.0 {
973 projection.iter().map(|&x| x / norm).collect()
974 } else {
975 projection
976 };
977
978 let mut projections: Vec<(f32, usize)> = indices
980 .iter()
981 .map(|&idx| {
982 let point = &self.data[idx].1.as_f32();
983 let proj_val = f32::dot(point, &projection);
984 (proj_val, idx)
985 })
986 .collect();
987
988 projections.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
989
990 let median_idx = projections.len() / 2;
992 let threshold = projections[median_idx].0;
993
994 let left_indices: Vec<usize> = projections[..median_idx]
995 .iter()
996 .map(|(_, idx)| *idx)
997 .collect();
998
999 let right_indices: Vec<usize> = projections[median_idx..]
1000 .iter()
1001 .map(|(_, idx)| *idx)
1002 .collect();
1003
1004 if left_indices.is_empty() || right_indices.is_empty() {
1006 return Ok(RpNode {
1007 projection: Vec::new(),
1008 threshold: 0.0,
1009 left: None,
1010 right: None,
1011 indices,
1012 });
1013 }
1014
1015 let left = Some(Box::new(self.build_node_safe(
1016 left_indices,
1017 dimensions,
1018 rng,
1019 depth + 1,
1020 )?));
1021 let right = Some(Box::new(self.build_node_safe(
1022 right_indices,
1023 dimensions,
1024 rng,
1025 depth + 1,
1026 )?));
1027
1028 Ok(RpNode {
1029 projection,
1030 threshold,
1031 left,
1032 right,
1033 indices: Vec::new(),
1034 })
1035 }
1036
1037 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1038 if self.root.is_none() {
1039 return Vec::new();
1040 }
1041
1042 let mut heap = BinaryHeap::new();
1043 self.search_node(self.root.as_ref().unwrap(), query, k, &mut heap);
1044
1045 let mut results: Vec<(usize, f32)> =
1046 heap.into_iter().map(|r| (r.index, r.distance)).collect();
1047
1048 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
1049 results
1050 }
1051
1052 fn search_node(
1053 &self,
1054 node: &RpNode,
1055 query: &[f32],
1056 k: usize,
1057 heap: &mut BinaryHeap<SearchResult>,
1058 ) {
1059 if !node.indices.is_empty() {
1060 for &idx in &node.indices {
1062 let point = &self.data[idx].1.as_f32();
1063 let dist = self.config.distance_metric.distance(query, point);
1064
1065 if heap.len() < k {
1066 heap.push(SearchResult {
1067 index: idx,
1068 distance: dist,
1069 });
1070 } else if dist < heap.peek().unwrap().distance {
1071 heap.pop();
1072 heap.push(SearchResult {
1073 index: idx,
1074 distance: dist,
1075 });
1076 }
1077 }
1078 return;
1079 }
1080
1081 let query_projection = f32::dot(query, &node.projection);
1083
1084 let go_left = query_projection <= node.threshold;
1086
1087 let (first, second) = if go_left {
1088 (&node.left, &node.right)
1089 } else {
1090 (&node.right, &node.left)
1091 };
1092
1093 if let Some(child) = first {
1095 self.search_node(child, query, k, heap);
1096 }
1097
1098 if let Some(child) = second {
1099 self.search_node(child, query, k, heap);
1100 }
1101 }
1102}
1103
1104pub struct TreeIndex {
1106 tree_type: TreeType,
1107 ball_tree: Option<BallTree>,
1108 kd_tree: Option<KdTree>,
1109 vp_tree: Option<VpTree>,
1110 cover_tree: Option<CoverTree>,
1111 rp_tree: Option<RandomProjectionTree>,
1112}
1113
1114impl TreeIndex {
1115 pub fn new(config: TreeIndexConfig) -> Self {
1116 let tree_type = config.tree_type;
1117
1118 let (ball_tree, kd_tree, vp_tree, cover_tree, rp_tree) = match tree_type {
1119 TreeType::BallTree => (Some(BallTree::new(config)), None, None, None, None),
1120 TreeType::KdTree => (None, Some(KdTree::new(config)), None, None, None),
1121 TreeType::VpTree => (None, None, Some(VpTree::new(config)), None, None),
1122 TreeType::CoverTree => (None, None, None, Some(CoverTree::new(config)), None),
1123 TreeType::RandomProjectionTree => (
1124 None,
1125 None,
1126 None,
1127 None,
1128 Some(RandomProjectionTree::new(config)),
1129 ),
1130 };
1131
1132 Self {
1133 tree_type,
1134 ball_tree,
1135 kd_tree,
1136 vp_tree,
1137 cover_tree,
1138 rp_tree,
1139 }
1140 }
1141
1142 fn build(&mut self) -> Result<()> {
1143 match self.tree_type {
1144 TreeType::BallTree => self.ball_tree.as_mut().unwrap().build(),
1145 TreeType::KdTree => self.kd_tree.as_mut().unwrap().build(),
1146 TreeType::VpTree => self.vp_tree.as_mut().unwrap().build(),
1147 TreeType::CoverTree => self.cover_tree.as_mut().unwrap().build(),
1148 TreeType::RandomProjectionTree => self.rp_tree.as_mut().unwrap().build(),
1149 }
1150 }
1151
1152 fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1153 match self.tree_type {
1154 TreeType::BallTree => self.ball_tree.as_ref().unwrap().search(query, k),
1155 TreeType::KdTree => self.kd_tree.as_ref().unwrap().search(query, k),
1156 TreeType::VpTree => self.vp_tree.as_ref().unwrap().search(query, k),
1157 TreeType::CoverTree => self.cover_tree.as_ref().unwrap().search(query, k),
1158 TreeType::RandomProjectionTree => self.rp_tree.as_ref().unwrap().search(query, k),
1159 }
1160 }
1161}
1162
1163impl VectorIndex for TreeIndex {
1164 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1165 let data = match self.tree_type {
1166 TreeType::BallTree => &mut self.ball_tree.as_mut().unwrap().data,
1167 TreeType::KdTree => &mut self.kd_tree.as_mut().unwrap().data,
1168 TreeType::VpTree => &mut self.vp_tree.as_mut().unwrap().data,
1169 TreeType::CoverTree => &mut self.cover_tree.as_mut().unwrap().data,
1170 TreeType::RandomProjectionTree => &mut self.rp_tree.as_mut().unwrap().data,
1171 };
1172
1173 data.push((uri, vector));
1174 Ok(())
1175 }
1176
1177 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1178 let query_f32 = query.as_f32();
1179 let results = self.search_internal(&query_f32, k);
1180
1181 let data = match self.tree_type {
1182 TreeType::BallTree => &self.ball_tree.as_ref().unwrap().data,
1183 TreeType::KdTree => &self.kd_tree.as_ref().unwrap().data,
1184 TreeType::VpTree => &self.vp_tree.as_ref().unwrap().data,
1185 TreeType::CoverTree => &self.cover_tree.as_ref().unwrap().data,
1186 TreeType::RandomProjectionTree => &self.rp_tree.as_ref().unwrap().data,
1187 };
1188
1189 Ok(results
1190 .into_iter()
1191 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1192 .collect())
1193 }
1194
1195 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1196 let query_f32 = query.as_f32();
1197 let all_results = self.search_internal(&query_f32, 1000); let data = match self.tree_type {
1200 TreeType::BallTree => &self.ball_tree.as_ref().unwrap().data,
1201 TreeType::KdTree => &self.kd_tree.as_ref().unwrap().data,
1202 TreeType::VpTree => &self.vp_tree.as_ref().unwrap().data,
1203 TreeType::CoverTree => &self.cover_tree.as_ref().unwrap().data,
1204 TreeType::RandomProjectionTree => &self.rp_tree.as_ref().unwrap().data,
1205 };
1206
1207 Ok(all_results
1208 .into_iter()
1209 .filter(|(_, dist)| *dist <= threshold)
1210 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1211 .collect())
1212 }
1213
1214 fn get_vector(&self, uri: &str) -> Option<&Vector> {
1215 let data = match self.tree_type {
1216 TreeType::BallTree => &self.ball_tree.as_ref().unwrap().data,
1217 TreeType::KdTree => &self.kd_tree.as_ref().unwrap().data,
1218 TreeType::VpTree => &self.vp_tree.as_ref().unwrap().data,
1219 TreeType::CoverTree => &self.cover_tree.as_ref().unwrap().data,
1220 TreeType::RandomProjectionTree => &self.rp_tree.as_ref().unwrap().data,
1221 };
1222
1223 data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1224 }
1225}
1226
/// Runs `f` to completion on the current task and yields its output.
///
/// NOTE(review): despite the name, nothing is spawned. `f` executes inline when the
/// returned future is first polled, so a blocking closure blocks the polling thread.
async fn spawn_task<F: FnOnce() -> T + Send + 'static, T: Send + 'static>(f: F) -> T {
    f()
}
1239
#[cfg(test)]
mod tests {
    use super::*;

    // NOTE(review): the stack overflow behind these `#[ignore]`s comes from
    // `SearchResult`'s `cmp` and `partial_cmp` calling each other; once that is fixed
    // these tests can be un-ignored.

    /// Inserts `points` as `vec_0`, `vec_1`, ... into a fresh index and builds it.
    fn built_index<I>(config: TreeIndexConfig, points: I) -> TreeIndex
    where
        I: IntoIterator<Item = Vec<f32>>,
    {
        let mut index = TreeIndex::new(config);
        for (i, coords) in points.into_iter().enumerate() {
            index.insert(format!("vec_{i}"), Vector::new(coords)).unwrap();
        }
        index.build().unwrap();
        index
    }

    #[test]
    #[ignore = "Stack overflow issue - being investigated"]
    fn test_ball_tree() {
        let config = TreeIndexConfig {
            tree_type: TreeType::BallTree,
            max_leaf_size: 50,
            ..Default::default()
        };
        // Points on the line y = 2x: (0, 0), (1, 2), (2, 4).
        let index = built_index(config, (0..3).map(|i| vec![i as f32, (i * 2) as f32]));

        let hits = index.search_knn(&Vector::new(vec![1.0, 2.0]), 2).unwrap();
        assert_eq!(hits.len(), 2);
        // The query coincides with vec_1, so it must rank first.
        assert_eq!(hits[0].0, "vec_1");
    }

    #[test]
    #[ignore = "Stack overflow issue - being investigated"]
    fn test_kd_tree() {
        let config = TreeIndexConfig {
            tree_type: TreeType::KdTree,
            max_leaf_size: 50,
            ..Default::default()
        };
        // Points on the line x + y = 3: (0, 3), (1, 2), (2, 1).
        let index = built_index(config, (0..3).map(|i| vec![i as f32, (3 - i) as f32]));

        let hits = index.search_knn(&Vector::new(vec![1.0, 2.0]), 2).unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[test]
    #[ignore = "Stack overflow issue - being investigated"]
    fn test_vp_tree() {
        let config = TreeIndexConfig {
            tree_type: TreeType::VpTree,
            random_seed: Some(42),
            max_leaf_size: 50,
            ..Default::default()
        };
        // Unit vectors at 0, 45 and 90 degrees.
        let unit_circle = (0..3).map(|i| {
            let angle = (i as f32) * std::f32::consts::PI / 4.0;
            vec![angle.cos(), angle.sin()]
        });
        let index = built_index(config, unit_circle);

        let hits = index.search_knn(&Vector::new(vec![1.0, 0.0]), 2).unwrap();
        assert_eq!(hits.len(), 2);
    }
}