1use crate::{Vector, VectorIndex};
12use anyhow::Result;
13use oxirs_core::simd::SimdOps;
14use scirs2_core::random::{Random, Rng};
15use std::cmp::Ordering;
16use std::collections::BinaryHeap;
17
/// Configuration shared by every tree-based index in this module.
#[derive(Debug, Clone)]
pub struct TreeIndexConfig {
    /// Which tree structure [`TreeIndex`] instantiates.
    pub tree_type: TreeType,
    /// Node size at or below which construction stops splitting and emits a leaf.
    pub max_leaf_size: usize,
    /// Seed for the randomized trees (VP tree, random projection tree); their
    /// builders fall back to a fixed seed of 42 when this is `None`.
    pub random_seed: Option<u64>,
    /// Requests parallel construction.
    /// NOTE(review): no builder visible in this file reads this flag — confirm
    /// whether it is consumed elsewhere.
    pub parallel_construction: bool,
    /// Distance function used during both construction and search.
    pub distance_metric: DistanceMetric,
}
32
33impl Default for TreeIndexConfig {
34 fn default() -> Self {
35 Self {
36 tree_type: TreeType::BallTree,
37 max_leaf_size: 16, random_seed: None,
39 parallel_construction: true,
40 distance_metric: DistanceMetric::Euclidean,
41 }
42 }
43}
44
/// The tree structures [`TreeIndex`] can be backed by.
#[derive(Debug, Clone, Copy)]
pub enum TreeType {
    /// Nested bounding balls (center + radius per node); see [`BallTree`].
    BallTree,
    /// Axis-aligned median splits cycling through the dimensions; see [`KdTree`].
    KdTree,
    /// Splits by distance to a randomly chosen vantage point; see [`VpTree`].
    VpTree,
    /// See [`CoverTree`].
    CoverTree,
    /// Median splits along random unit directions; see [`RandomProjectionTree`].
    RandomProjectionTree,
}
54
/// Distance functions available to the tree indexes.
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
    /// Delegates to `f32::euclidean_distance`.
    Euclidean,
    /// Delegates to `f32::manhattan_distance`.
    Manhattan,
    /// Delegates to `f32::cosine_distance`.
    Cosine,
    /// Minkowski distance whose order `p` is the payload:
    /// `(sum |a_i - b_i|^p)^(1/p)`.
    Minkowski(f32),
}
63
64impl DistanceMetric {
65 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
66 match self {
67 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
68 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
69 DistanceMetric::Cosine => f32::cosine_distance(a, b),
70 DistanceMetric::Minkowski(p) => a
71 .iter()
72 .zip(b.iter())
73 .map(|(x, y)| (x - y).abs().powf(*p))
74 .sum::<f32>()
75 .powf(1.0 / p),
76 }
77 }
78}
79
/// A candidate neighbor kept in the bounded heap during k-NN search.
///
/// Ordering looks at `distance` only, so a `BinaryHeap<SearchResult>` is a
/// max-heap whose top is the *worst* (farthest) candidate collected so far.
#[derive(Debug, Clone)]
struct SearchResult {
    /// Position of the point in the owning tree's `data` vector.
    index: usize,
    /// Distance from the query to that point.
    distance: f32,
}

impl PartialEq for SearchResult {
    /// Equality is defined through [`Ord::cmp`] so that `==` always agrees with
    /// the ordering, as the `Eq`/`Ord` contracts require.
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    /// Orders candidates by distance.
    ///
    /// This is the single place where the comparison is actually performed:
    /// `partial_cmp` delegates here, so this method must not delegate back
    /// (doing so recurses until the stack overflows). `f32::total_cmp` gives a
    /// total order even if a metric yields NaN.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance.total_cmp(&other.distance)
    }
}
106
/// Ball tree: every node bounds its points with a center and a radius.
pub struct BallTree {
    /// Root of the built tree; `None` until `build` has run on non-empty data.
    root: Option<Box<BallNode>>,
    /// The indexed `(uri, vector)` entries; tree nodes refer to them by position.
    data: Vec<(String, Vector)>,
    /// Leaf size and distance metric used for construction and search.
    config: TreeIndexConfig,
}
113
/// One node of a [`BallTree`].
struct BallNode {
    /// Center of the bounding ball (centroid of the points for a leaf, mean of
    /// the two child centers for an internal node).
    center: Vec<f32>,
    /// Radius of the bounding ball around `center`.
    radius: f32,
    /// Left child; `None` for leaves.
    left: Option<Box<BallNode>>,
    /// Right child; `None` for leaves.
    right: Option<Box<BallNode>>,
    /// Positions (into the tree's `data`) of the points stored here; empty for
    /// internal nodes.
    indices: Vec<usize>,
}
126
127impl BallTree {
128 pub fn new(config: TreeIndexConfig) -> Self {
129 Self {
130 root: None,
131 data: Vec::new(),
132 config,
133 }
134 }
135
136 pub fn build(&mut self) -> Result<()> {
138 if self.data.is_empty() {
139 return Ok(());
140 }
141
142 let indices: Vec<usize> = (0..self.data.len()).collect();
143 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
144
145 self.root = Some(Box::new(self.build_node_safe(&points, indices, 0)?));
146 Ok(())
147 }
148
149 fn build_node_safe(
150 &self,
151 points: &[Vec<f32>],
152 indices: Vec<usize>,
153 depth: usize,
154 ) -> Result<BallNode> {
155 if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= 3 {
157 let center = self.compute_centroid(points, &indices);
159 let radius = self.compute_radius(points, &indices, ¢er);
160
161 return Ok(BallNode {
162 center,
163 radius,
164 left: None,
165 right: None,
166 indices,
167 });
168 }
169
170 let split_dim = self.find_split_dimension(points, &indices);
172
173 let (left_indices, right_indices) = self.partition_indices(points, &indices, split_dim);
175
176 if left_indices.is_empty() || right_indices.is_empty() {
178 let center = self.compute_centroid(points, &indices);
180 let radius = self.compute_radius(points, &indices, ¢er);
181 return Ok(BallNode {
182 center,
183 radius,
184 left: None,
185 right: None,
186 indices,
187 });
188 }
189
190 let left_node = self.build_node_safe(points, left_indices, depth + 1)?;
192 let right_node = self.build_node_safe(points, right_indices, depth + 1)?;
193
194 let all_centers = vec![left_node.center.clone(), right_node.center.clone()];
196 let center = self.compute_centroid_of_centers(&all_centers);
197 let radius = left_node.radius.max(right_node.radius)
198 + self
199 .config
200 .distance_metric
201 .distance(¢er, &left_node.center);
202
203 Ok(BallNode {
204 center,
205 radius,
206 left: Some(Box::new(left_node)),
207 right: Some(Box::new(right_node)),
208 indices: Vec::new(),
209 })
210 }
211
212 fn compute_centroid(&self, points: &[Vec<f32>], indices: &[usize]) -> Vec<f32> {
213 let dim = points[0].len();
214 let mut centroid = vec![0.0; dim];
215
216 for &idx in indices {
217 for (i, &val) in points[idx].iter().enumerate() {
218 centroid[i] += val;
219 }
220 }
221
222 let n = indices.len() as f32;
223 for val in &mut centroid {
224 *val /= n;
225 }
226
227 centroid
228 }
229
230 fn compute_radius(&self, points: &[Vec<f32>], indices: &[usize], center: &[f32]) -> f32 {
231 indices
232 .iter()
233 .map(|&idx| self.config.distance_metric.distance(&points[idx], center))
234 .fold(0.0f32, f32::max)
235 }
236
237 fn find_split_dimension(&self, points: &[Vec<f32>], indices: &[usize]) -> usize {
238 let dim = points[0].len();
239 let mut max_spread = 0.0;
240 let mut split_dim = 0;
241
242 for d in 0..dim {
243 let values: Vec<f32> = indices.iter().map(|&idx| points[idx][d]).collect();
244
245 let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
246 let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
247 let spread = max_val - min_val;
248
249 if spread > max_spread {
250 max_spread = spread;
251 split_dim = d;
252 }
253 }
254
255 split_dim
256 }
257
258 fn partition_indices(
259 &self,
260 points: &[Vec<f32>],
261 indices: &[usize],
262 dim: usize,
263 ) -> (Vec<usize>, Vec<usize>) {
264 let mut values: Vec<(f32, usize)> =
265 indices.iter().map(|&idx| (points[idx][dim], idx)).collect();
266
267 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
268
269 let mid = values.len() / 2;
270 let left_indices: Vec<usize> = values[..mid].iter().map(|(_, idx)| *idx).collect();
271 let right_indices: Vec<usize> = values[mid..].iter().map(|(_, idx)| *idx).collect();
272
273 (left_indices, right_indices)
274 }
275
276 fn compute_centroid_of_centers(&self, centers: &[Vec<f32>]) -> Vec<f32> {
277 let dim = centers[0].len();
278 let mut centroid = vec![0.0; dim];
279
280 for center in centers {
281 for (i, &val) in center.iter().enumerate() {
282 centroid[i] += val;
283 }
284 }
285
286 let n = centers.len() as f32;
287 for val in &mut centroid {
288 *val /= n;
289 }
290
291 centroid
292 }
293
294 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
296 if self.root.is_none() {
297 return Vec::new();
298 }
299
300 let mut heap = BinaryHeap::new();
301 self.search_node(self.root.as_ref().unwrap(), query, k, &mut heap);
302
303 let mut results: Vec<(usize, f32)> =
304 heap.into_iter().map(|r| (r.index, r.distance)).collect();
305
306 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
307 results
308 }
309
310 fn search_node(
311 &self,
312 node: &BallNode,
313 query: &[f32],
314 k: usize,
315 heap: &mut BinaryHeap<SearchResult>,
316 ) {
317 let dist_to_center = self.config.distance_metric.distance(query, &node.center);
319
320 if heap.len() >= k {
321 let worst_dist = heap.peek().unwrap().distance;
322 if dist_to_center - node.radius > worst_dist {
323 return; }
325 }
326
327 if node.indices.is_empty() {
328 if let (Some(left), Some(right)) = (&node.left, &node.right) {
330 let left_dist = self.config.distance_metric.distance(query, &left.center);
331 let right_dist = self.config.distance_metric.distance(query, &right.center);
332
333 if left_dist < right_dist {
334 self.search_node(left, query, k, heap);
335 self.search_node(right, query, k, heap);
336 } else {
337 self.search_node(right, query, k, heap);
338 self.search_node(left, query, k, heap);
339 }
340 }
341 } else {
342 for &idx in &node.indices {
344 let point = &self.data[idx].1.as_f32();
345 let dist = self.config.distance_metric.distance(query, point);
346
347 if heap.len() < k {
348 heap.push(SearchResult {
349 index: idx,
350 distance: dist,
351 });
352 } else if dist < heap.peek().unwrap().distance {
353 heap.pop();
354 heap.push(SearchResult {
355 index: idx,
356 distance: dist,
357 });
358 }
359 }
360 }
361 }
362}
363
/// KD tree: axis-aligned median splits, cycling through dimensions by depth.
pub struct KdTree {
    /// Root of the built tree; `None` until `build` has run on non-empty data.
    root: Option<Box<KdNode>>,
    /// The indexed `(uri, vector)` entries; tree nodes refer to them by position.
    data: Vec<(String, Vector)>,
    /// Leaf size and distance metric used for construction and search.
    config: TreeIndexConfig,
}
370
/// One node of a [`KdTree`].
struct KdNode {
    /// Dimension this node splits on (0 for leaves).
    split_dim: usize,
    /// Coordinate of the median point along `split_dim` (0.0 for leaves).
    split_value: f32,
    /// Subtree holding the points sorted before the median; `None` for leaves.
    left: Option<Box<KdNode>>,
    /// Subtree holding the median and the points sorted after it; `None` for leaves.
    right: Option<Box<KdNode>>,
    /// Positions (into the tree's `data`) of the points stored here; empty for
    /// internal nodes.
    indices: Vec<usize>,
}
383
384impl KdTree {
385 pub fn new(config: TreeIndexConfig) -> Self {
386 Self {
387 root: None,
388 data: Vec::new(),
389 config,
390 }
391 }
392
393 pub fn build(&mut self) -> Result<()> {
394 if self.data.is_empty() {
395 return Ok(());
396 }
397
398 let indices: Vec<usize> = (0..self.data.len()).collect();
399 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
400
401 self.root = Some(Box::new(self.build_node(&points, indices, 0)?));
402 Ok(())
403 }
404
405 fn build_node(&self, points: &[Vec<f32>], indices: Vec<usize>, depth: usize) -> Result<KdNode> {
406 if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= 3 {
408 return Ok(KdNode {
409 split_dim: 0,
410 split_value: 0.0,
411 left: None,
412 right: None,
413 indices,
414 });
415 }
416
417 let dimensions = points[0].len();
418 let split_dim = depth % dimensions;
419
420 let mut values: Vec<(f32, usize)> = indices
422 .iter()
423 .map(|&idx| (points[idx][split_dim], idx))
424 .collect();
425
426 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
427
428 let median_idx = values.len() / 2;
429 let split_value = values[median_idx].0;
430
431 let left_indices: Vec<usize> = values[..median_idx].iter().map(|(_, idx)| *idx).collect();
432
433 let right_indices: Vec<usize> = values[median_idx..].iter().map(|(_, idx)| *idx).collect();
434
435 if left_indices.is_empty() || right_indices.is_empty() {
437 return Ok(KdNode {
438 split_dim: 0,
439 split_value: 0.0,
440 left: None,
441 right: None,
442 indices,
443 });
444 }
445
446 let left = Some(Box::new(self.build_node(
447 points,
448 left_indices,
449 depth + 1,
450 )?));
451
452 let right = Some(Box::new(self.build_node(
453 points,
454 right_indices,
455 depth + 1,
456 )?));
457
458 Ok(KdNode {
459 split_dim,
460 split_value,
461 left,
462 right,
463 indices: Vec::new(),
464 })
465 }
466
467 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
468 if self.root.is_none() {
469 return Vec::new();
470 }
471
472 let mut heap = BinaryHeap::new();
473 self.search_node(self.root.as_ref().unwrap(), query, k, &mut heap);
474
475 let mut results: Vec<(usize, f32)> =
476 heap.into_iter().map(|r| (r.index, r.distance)).collect();
477
478 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
479 results
480 }
481
482 fn search_node(
483 &self,
484 node: &KdNode,
485 query: &[f32],
486 k: usize,
487 heap: &mut BinaryHeap<SearchResult>,
488 ) {
489 if !node.indices.is_empty() {
490 for &idx in &node.indices {
492 let point = &self.data[idx].1.as_f32();
493 let dist = self.config.distance_metric.distance(query, point);
494
495 if heap.len() < k {
496 heap.push(SearchResult {
497 index: idx,
498 distance: dist,
499 });
500 } else if dist < heap.peek().unwrap().distance {
501 heap.pop();
502 heap.push(SearchResult {
503 index: idx,
504 distance: dist,
505 });
506 }
507 }
508 return;
509 }
510
511 let go_left = query[node.split_dim] <= node.split_value;
513
514 let (first, second) = if go_left {
515 (&node.left, &node.right)
516 } else {
517 (&node.right, &node.left)
518 };
519
520 if let Some(child) = first {
522 self.search_node(child, query, k, heap);
523 }
524
525 if heap.len() < k || {
527 let split_dist = (query[node.split_dim] - node.split_value).abs();
528 split_dist < heap.peek().unwrap().distance
529 } {
530 if let Some(child) = second {
531 self.search_node(child, query, k, heap);
532 }
533 }
534 }
535}
536
/// Vantage-point tree: each internal node splits its points by their distance
/// to a randomly chosen vantage point.
pub struct VpTree {
    /// Root of the built tree; `None` until `build` has run on non-empty data.
    root: Option<Box<VpNode>>,
    /// The indexed `(uri, vector)` entries; tree nodes refer to them by position.
    data: Vec<(String, Vector)>,
    /// Leaf size, seed and distance metric used for construction and search.
    config: TreeIndexConfig,
}
543
/// One node of a [`VpTree`].
struct VpNode {
    /// Position (into the tree's `data`) of this node's vantage point. Internal
    /// nodes hold the vantage point themselves; it is not stored in either child.
    vantage_point: usize,
    /// Distance from the vantage point to the median point (0.0 for leaves).
    median_distance: f32,
    /// Subtree of the points sorted before the median distance; `None` for leaves.
    inside: Option<Box<VpNode>>,
    /// Subtree of the median point and those sorted after it; `None` for leaves.
    outside: Option<Box<VpNode>>,
    /// Positions of the points stored here; empty for internal nodes.
    indices: Vec<usize>,
}
556
557impl VpTree {
558 pub fn new(config: TreeIndexConfig) -> Self {
559 Self {
560 root: None,
561 data: Vec::new(),
562 config,
563 }
564 }
565
566 pub fn build(&mut self) -> Result<()> {
567 if self.data.is_empty() {
568 return Ok(());
569 }
570
571 let indices: Vec<usize> = (0..self.data.len()).collect();
572 let mut rng = if let Some(seed) = self.config.random_seed {
573 Random::seed(seed)
574 } else {
575 Random::seed(42)
576 };
577
578 self.root = Some(Box::new(self.build_node(indices, &mut rng)?));
579 Ok(())
580 }
581
582 fn build_node<R: Rng>(&self, indices: Vec<usize>, rng: &mut R) -> Result<VpNode> {
583 self.build_node_safe(indices, rng, 0)
584 }
585
586 fn build_node_safe<R: Rng>(
587 &self,
588 mut indices: Vec<usize>,
589 rng: &mut R,
590 depth: usize,
591 ) -> Result<VpNode> {
592 if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= 3 {
596 return Ok(VpNode {
597 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
598 median_distance: 0.0,
599 inside: None,
600 outside: None,
601 indices,
602 });
603 }
604
605 let vp_idx = indices.len() - 1;
607 for i in (1..indices.len()).rev() {
609 let j = rng.gen_range(0..=i);
610 indices.swap(i, j);
611 }
612 let vantage_point = indices[vp_idx];
613 indices.truncate(vp_idx);
614
615 let vp_data = &self.data[vantage_point].1.as_f32();
617 let mut distances: Vec<(f32, usize)> = indices
618 .iter()
619 .map(|&idx| {
620 let point = &self.data[idx].1.as_f32();
621 let dist = self.config.distance_metric.distance(vp_data, point);
622 (dist, idx)
623 })
624 .collect();
625
626 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
627
628 let median_idx = distances.len() / 2;
629 let median_distance = distances[median_idx].0;
630
631 let inside_indices: Vec<usize> = distances[..median_idx]
632 .iter()
633 .map(|(_, idx)| *idx)
634 .collect();
635
636 let outside_indices: Vec<usize> = distances[median_idx..]
637 .iter()
638 .map(|(_, idx)| *idx)
639 .collect();
640
641 if inside_indices.is_empty() || outside_indices.is_empty() {
643 return Ok(VpNode {
644 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
645 median_distance: 0.0,
646 inside: None,
647 outside: None,
648 indices,
649 });
650 }
651
652 let inside = Some(Box::new(self.build_node_safe(
653 inside_indices,
654 rng,
655 depth + 1,
656 )?));
657 let outside = Some(Box::new(self.build_node_safe(
658 outside_indices,
659 rng,
660 depth + 1,
661 )?));
662
663 Ok(VpNode {
664 vantage_point,
665 median_distance,
666 inside,
667 outside,
668 indices: Vec::new(),
669 })
670 }
671
672 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
673 if self.root.is_none() {
674 return Vec::new();
675 }
676
677 let mut heap = BinaryHeap::new();
678 self.search_node(
679 self.root.as_ref().unwrap(),
680 query,
681 k,
682 &mut heap,
683 f32::INFINITY,
684 );
685
686 let mut results: Vec<(usize, f32)> =
687 heap.into_iter().map(|r| (r.index, r.distance)).collect();
688
689 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
690 results
691 }
692
693 fn search_node(
694 &self,
695 node: &VpNode,
696 query: &[f32],
697 k: usize,
698 heap: &mut BinaryHeap<SearchResult>,
699 tau: f32,
700 ) -> f32 {
701 let mut tau = tau;
702
703 if !node.indices.is_empty() {
704 for &idx in &node.indices {
706 let point = &self.data[idx].1.as_f32();
707 let dist = self.config.distance_metric.distance(query, point);
708
709 if dist < tau {
710 if heap.len() < k {
711 heap.push(SearchResult {
712 index: idx,
713 distance: dist,
714 });
715 } else if dist < heap.peek().unwrap().distance {
716 heap.pop();
717 heap.push(SearchResult {
718 index: idx,
719 distance: dist,
720 });
721 }
722
723 if heap.len() >= k {
724 tau = heap.peek().unwrap().distance;
725 }
726 }
727 }
728 return tau;
729 }
730
731 let vp_data = &self.data[node.vantage_point].1.as_f32();
733 let dist_to_vp = self.config.distance_metric.distance(query, vp_data);
734
735 if dist_to_vp < tau {
737 if heap.len() < k {
738 heap.push(SearchResult {
739 index: node.vantage_point,
740 distance: dist_to_vp,
741 });
742 } else if dist_to_vp < heap.peek().unwrap().distance {
743 heap.pop();
744 heap.push(SearchResult {
745 index: node.vantage_point,
746 distance: dist_to_vp,
747 });
748 }
749
750 if heap.len() >= k {
751 tau = heap.peek().unwrap().distance;
752 }
753 }
754
755 if dist_to_vp < node.median_distance {
757 if let Some(inside) = &node.inside {
759 tau = self.search_node(inside, query, k, heap, tau);
760 }
761
762 if dist_to_vp + tau >= node.median_distance {
764 if let Some(outside) = &node.outside {
765 tau = self.search_node(outside, query, k, heap, tau);
766 }
767 }
768 } else {
769 if let Some(outside) = &node.outside {
771 tau = self.search_node(outside, query, k, heap, tau);
772 }
773
774 if dist_to_vp - tau <= node.median_distance {
776 if let Some(inside) = &node.inside {
777 tau = self.search_node(inside, query, k, heap, tau);
778 }
779 }
780 }
781
782 tau
783 }
784}
785
/// Cover tree index.
///
/// NOTE(review): as implemented in this file, every point after the first is
/// attached directly beneath the root, so the structure is effectively a flat list.
pub struct CoverTree {
    /// Root node (the first data point); `None` until `build` has run on
    /// non-empty data.
    root: Option<Box<CoverNode>>,
    /// The indexed `(uri, vector)` entries; tree nodes refer to them by position.
    data: Vec<(String, Vector)>,
    /// Distance metric used during search.
    config: TreeIndexConfig,
    /// Set to 2.0 by `new`.
    /// NOTE(review): not read by any method visible in this file.
    base: f32,
}
793
/// One node of a [`CoverTree`].
struct CoverNode {
    /// Position (into the tree's `data`) of the point this node represents.
    point: usize,
    /// Level assigned to the node at insertion time.
    level: i32,
    /// Child nodes.
    #[allow(clippy::vec_box)]
    children: Vec<Box<CoverNode>>,
}
803
804impl CoverTree {
805 pub fn new(config: TreeIndexConfig) -> Self {
806 Self {
807 root: None,
808 data: Vec::new(),
809 config,
810 base: 2.0, }
812 }
813
814 pub fn build(&mut self) -> Result<()> {
815 if self.data.is_empty() {
816 return Ok(());
817 }
818
819 self.root = Some(Box::new(CoverNode {
821 point: 0,
822 level: self.get_level(0),
823 children: Vec::new(),
824 }));
825
826 for idx in 1..self.data.len() {
828 self.insert(idx)?;
829 }
830
831 Ok(())
832 }
833
834 fn get_level(&self, _point_idx: usize) -> i32 {
835 ((self.data.len() as f32).log2() as i32).max(0)
837 }
838
839 fn insert(&mut self, point_idx: usize) -> Result<()> {
840 let level = self.get_level(point_idx);
843 if let Some(root) = &mut self.root {
844 root.children.push(Box::new(CoverNode {
845 point: point_idx,
846 level,
847 children: Vec::new(),
848 }));
849 }
850 Ok(())
851 }
852
853 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
854 if self.root.is_none() {
855 return Vec::new();
856 }
857
858 let mut results = Vec::new();
859 self.search_node(self.root.as_ref().unwrap(), query, k, &mut results);
860
861 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
862 results.truncate(k);
863 results
864 }
865
866 #[allow(clippy::only_used_in_recursion)]
867 fn search_node(
868 &self,
869 node: &CoverNode,
870 query: &[f32],
871 k: usize,
872 results: &mut Vec<(usize, f32)>,
873 ) {
874 if results.len() >= k * 10 {
876 return;
877 }
878
879 let point_data = &self.data[node.point].1.as_f32();
880 let dist = self.config.distance_metric.distance(query, point_data);
881
882 results.push((node.point, dist));
883
884 for child in &node.children {
886 self.search_node(child, query, k, results);
887 }
888 }
889}
890
/// Random projection tree: each internal node splits its points at the median
/// of their projection onto a random unit direction.
pub struct RandomProjectionTree {
    /// Root of the built tree; `None` until `build` has run on non-empty data.
    root: Option<Box<RpNode>>,
    /// The indexed `(uri, vector)` entries; tree nodes refer to them by position.
    data: Vec<(String, Vector)>,
    /// Leaf size, seed and distance metric used for construction and search.
    config: TreeIndexConfig,
}
897
/// One node of a [`RandomProjectionTree`].
struct RpNode {
    /// Direction the node's points are projected onto, normalized to unit
    /// length when its norm is positive (empty for leaves).
    projection: Vec<f32>,
    /// Projection value of the median point (0.0 for leaves).
    threshold: f32,
    /// Subtree of the points sorted before the median projection; `None` for leaves.
    left: Option<Box<RpNode>>,
    /// Subtree of the median point and those sorted after it; `None` for leaves.
    right: Option<Box<RpNode>>,
    /// Positions (into the tree's `data`) of the points stored here; empty for
    /// internal nodes.
    indices: Vec<usize>,
}
910
911impl RandomProjectionTree {
912 pub fn new(config: TreeIndexConfig) -> Self {
913 Self {
914 root: None,
915 data: Vec::new(),
916 config,
917 }
918 }
919
920 pub fn build(&mut self) -> Result<()> {
921 if self.data.is_empty() {
922 return Ok(());
923 }
924
925 let indices: Vec<usize> = (0..self.data.len()).collect();
926 let dimensions = self.data[0].1.dimensions;
927
928 let mut rng = if let Some(seed) = self.config.random_seed {
929 Random::seed(seed)
930 } else {
931 Random::seed(42)
932 };
933
934 self.root = Some(Box::new(self.build_node(indices, dimensions, &mut rng)?));
935 Ok(())
936 }
937
938 fn build_node<R: Rng>(
939 &self,
940 indices: Vec<usize>,
941 dimensions: usize,
942 rng: &mut R,
943 ) -> Result<RpNode> {
944 self.build_node_safe(indices, dimensions, rng, 0)
945 }
946
947 fn build_node_safe<R: Rng>(
948 &self,
949 indices: Vec<usize>,
950 dimensions: usize,
951 rng: &mut R,
952 depth: usize,
953 ) -> Result<RpNode> {
954 if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= 5 {
956 return Ok(RpNode {
957 projection: Vec::new(),
958 threshold: 0.0,
959 left: None,
960 right: None,
961 indices,
962 });
963 }
964
965 let projection: Vec<f32> = (0..dimensions).map(|_| rng.gen_range(-1.0..1.0)).collect();
967
968 let norm = (projection.iter().map(|&x| x * x).sum::<f32>()).sqrt();
970 let projection: Vec<f32> = if norm > 0.0 {
971 projection.iter().map(|&x| x / norm).collect()
972 } else {
973 projection
974 };
975
976 let mut projections: Vec<(f32, usize)> = indices
978 .iter()
979 .map(|&idx| {
980 let point = &self.data[idx].1.as_f32();
981 let proj_val = f32::dot(point, &projection);
982 (proj_val, idx)
983 })
984 .collect();
985
986 projections.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
987
988 let median_idx = projections.len() / 2;
990 let threshold = projections[median_idx].0;
991
992 let left_indices: Vec<usize> = projections[..median_idx]
993 .iter()
994 .map(|(_, idx)| *idx)
995 .collect();
996
997 let right_indices: Vec<usize> = projections[median_idx..]
998 .iter()
999 .map(|(_, idx)| *idx)
1000 .collect();
1001
1002 if left_indices.is_empty() || right_indices.is_empty() {
1004 return Ok(RpNode {
1005 projection: Vec::new(),
1006 threshold: 0.0,
1007 left: None,
1008 right: None,
1009 indices,
1010 });
1011 }
1012
1013 let left = Some(Box::new(self.build_node_safe(
1014 left_indices,
1015 dimensions,
1016 rng,
1017 depth + 1,
1018 )?));
1019 let right = Some(Box::new(self.build_node_safe(
1020 right_indices,
1021 dimensions,
1022 rng,
1023 depth + 1,
1024 )?));
1025
1026 Ok(RpNode {
1027 projection,
1028 threshold,
1029 left,
1030 right,
1031 indices: Vec::new(),
1032 })
1033 }
1034
1035 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1036 if self.root.is_none() {
1037 return Vec::new();
1038 }
1039
1040 let mut heap = BinaryHeap::new();
1041 self.search_node(self.root.as_ref().unwrap(), query, k, &mut heap);
1042
1043 let mut results: Vec<(usize, f32)> =
1044 heap.into_iter().map(|r| (r.index, r.distance)).collect();
1045
1046 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
1047 results
1048 }
1049
1050 fn search_node(
1051 &self,
1052 node: &RpNode,
1053 query: &[f32],
1054 k: usize,
1055 heap: &mut BinaryHeap<SearchResult>,
1056 ) {
1057 if !node.indices.is_empty() {
1058 for &idx in &node.indices {
1060 let point = &self.data[idx].1.as_f32();
1061 let dist = self.config.distance_metric.distance(query, point);
1062
1063 if heap.len() < k {
1064 heap.push(SearchResult {
1065 index: idx,
1066 distance: dist,
1067 });
1068 } else if dist < heap.peek().unwrap().distance {
1069 heap.pop();
1070 heap.push(SearchResult {
1071 index: idx,
1072 distance: dist,
1073 });
1074 }
1075 }
1076 return;
1077 }
1078
1079 let query_projection = f32::dot(query, &node.projection);
1081
1082 let go_left = query_projection <= node.threshold;
1084
1085 let (first, second) = if go_left {
1086 (&node.left, &node.right)
1087 } else {
1088 (&node.right, &node.left)
1089 };
1090
1091 if let Some(child) = first {
1093 self.search_node(child, query, k, heap);
1094 }
1095
1096 if let Some(child) = second {
1097 self.search_node(child, query, k, heap);
1098 }
1099 }
1100}
1101
/// Facade over the tree implementations in this module.
///
/// `new` populates exactly one of the tree fields — the one matching
/// `tree_type` — and leaves the others `None`.
pub struct TreeIndex {
    /// Selects which of the optional trees below is in use.
    tree_type: TreeType,
    /// Present when `tree_type` is `TreeType::BallTree`.
    ball_tree: Option<BallTree>,
    /// Present when `tree_type` is `TreeType::KdTree`.
    kd_tree: Option<KdTree>,
    /// Present when `tree_type` is `TreeType::VpTree`.
    vp_tree: Option<VpTree>,
    /// Present when `tree_type` is `TreeType::CoverTree`.
    cover_tree: Option<CoverTree>,
    /// Present when `tree_type` is `TreeType::RandomProjectionTree`.
    rp_tree: Option<RandomProjectionTree>,
}
1111
1112impl TreeIndex {
1113 pub fn new(config: TreeIndexConfig) -> Self {
1114 let tree_type = config.tree_type;
1115
1116 let (ball_tree, kd_tree, vp_tree, cover_tree, rp_tree) = match tree_type {
1117 TreeType::BallTree => (Some(BallTree::new(config)), None, None, None, None),
1118 TreeType::KdTree => (None, Some(KdTree::new(config)), None, None, None),
1119 TreeType::VpTree => (None, None, Some(VpTree::new(config)), None, None),
1120 TreeType::CoverTree => (None, None, None, Some(CoverTree::new(config)), None),
1121 TreeType::RandomProjectionTree => (
1122 None,
1123 None,
1124 None,
1125 None,
1126 Some(RandomProjectionTree::new(config)),
1127 ),
1128 };
1129
1130 Self {
1131 tree_type,
1132 ball_tree,
1133 kd_tree,
1134 vp_tree,
1135 cover_tree,
1136 rp_tree,
1137 }
1138 }
1139
1140 fn build(&mut self) -> Result<()> {
1141 match self.tree_type {
1142 TreeType::BallTree => self.ball_tree.as_mut().unwrap().build(),
1143 TreeType::KdTree => self.kd_tree.as_mut().unwrap().build(),
1144 TreeType::VpTree => self.vp_tree.as_mut().unwrap().build(),
1145 TreeType::CoverTree => self.cover_tree.as_mut().unwrap().build(),
1146 TreeType::RandomProjectionTree => self.rp_tree.as_mut().unwrap().build(),
1147 }
1148 }
1149
1150 fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1151 match self.tree_type {
1152 TreeType::BallTree => self.ball_tree.as_ref().unwrap().search(query, k),
1153 TreeType::KdTree => self.kd_tree.as_ref().unwrap().search(query, k),
1154 TreeType::VpTree => self.vp_tree.as_ref().unwrap().search(query, k),
1155 TreeType::CoverTree => self.cover_tree.as_ref().unwrap().search(query, k),
1156 TreeType::RandomProjectionTree => self.rp_tree.as_ref().unwrap().search(query, k),
1157 }
1158 }
1159}
1160
1161impl VectorIndex for TreeIndex {
1162 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1163 let data = match self.tree_type {
1164 TreeType::BallTree => &mut self.ball_tree.as_mut().unwrap().data,
1165 TreeType::KdTree => &mut self.kd_tree.as_mut().unwrap().data,
1166 TreeType::VpTree => &mut self.vp_tree.as_mut().unwrap().data,
1167 TreeType::CoverTree => &mut self.cover_tree.as_mut().unwrap().data,
1168 TreeType::RandomProjectionTree => &mut self.rp_tree.as_mut().unwrap().data,
1169 };
1170
1171 data.push((uri, vector));
1172 Ok(())
1173 }
1174
1175 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1176 let query_f32 = query.as_f32();
1177 let results = self.search_internal(&query_f32, k);
1178
1179 let data = match self.tree_type {
1180 TreeType::BallTree => &self.ball_tree.as_ref().unwrap().data,
1181 TreeType::KdTree => &self.kd_tree.as_ref().unwrap().data,
1182 TreeType::VpTree => &self.vp_tree.as_ref().unwrap().data,
1183 TreeType::CoverTree => &self.cover_tree.as_ref().unwrap().data,
1184 TreeType::RandomProjectionTree => &self.rp_tree.as_ref().unwrap().data,
1185 };
1186
1187 Ok(results
1188 .into_iter()
1189 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1190 .collect())
1191 }
1192
1193 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1194 let query_f32 = query.as_f32();
1195 let all_results = self.search_internal(&query_f32, 1000); let data = match self.tree_type {
1198 TreeType::BallTree => &self.ball_tree.as_ref().unwrap().data,
1199 TreeType::KdTree => &self.kd_tree.as_ref().unwrap().data,
1200 TreeType::VpTree => &self.vp_tree.as_ref().unwrap().data,
1201 TreeType::CoverTree => &self.cover_tree.as_ref().unwrap().data,
1202 TreeType::RandomProjectionTree => &self.rp_tree.as_ref().unwrap().data,
1203 };
1204
1205 Ok(all_results
1206 .into_iter()
1207 .filter(|(_, dist)| *dist <= threshold)
1208 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1209 .collect())
1210 }
1211
1212 fn get_vector(&self, uri: &str) -> Option<&Vector> {
1213 let data = match self.tree_type {
1214 TreeType::BallTree => &self.ball_tree.as_ref().unwrap().data,
1215 TreeType::KdTree => &self.kd_tree.as_ref().unwrap().data,
1216 TreeType::VpTree => &self.vp_tree.as_ref().unwrap().data,
1217 TreeType::CoverTree => &self.cover_tree.as_ref().unwrap().data,
1218 TreeType::RandomProjectionTree => &self.rp_tree.as_ref().unwrap().data,
1219 };
1220
1221 data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1222 }
1223}
1224
/// Runs `f` and resolves to its return value.
///
/// The closure is invoked inline when the returned future is polled; nothing
/// is handed to another thread or executor here.
async fn spawn_task<F: FnOnce() -> T + Send + 'static, T: Send + 'static>(f: F) -> T {
    f()
}
1237
// Smoke tests for the tree indexes reachable through `TreeIndex`.
//
// NOTE(review): all three tests are ignored for a stack overflow. Every k-NN
// search pushes `SearchResult`s into a `BinaryHeap`, which compares them via
// `Ord`; if `Ord::cmp` and `PartialOrd::partial_cmp` delegate to each other,
// the first comparison recurses without bound. Re-enable these tests once the
// ordering impls are confirmed not to be mutually recursive.
#[cfg(test)]
mod tests {
    use super::*;

    // Three collinear points; the query coincides with `vec_1`, which must
    // therefore be the closest hit.
    #[test]
    #[ignore = "Stack overflow issue - being investigated"]
    fn test_ball_tree() {
        let config = TreeIndexConfig {
            tree_type: TreeType::BallTree,
            // Larger than the data set, so the whole tree is a single leaf.
            max_leaf_size: 50,
            ..Default::default()
        };

        let mut index = TreeIndex::new(config);

        for i in 0..3 {
            let vector = Vector::new(vec![i as f32, (i * 2) as f32]);
            index.insert(format!("vec_{i}"), vector).unwrap();
        }

        index.build().unwrap();

        let query = Vector::new(vec![1.0, 2.0]);
        let results = index.search_knn(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, "vec_1");
    }

    // Only checks that a 2-NN query over three points yields two results.
    #[test]
    #[ignore = "Stack overflow issue - being investigated"]
    fn test_kd_tree() {
        let config = TreeIndexConfig {
            tree_type: TreeType::KdTree,
            // Larger than the data set, so the whole tree is a single leaf.
            max_leaf_size: 50,
            ..Default::default()
        };

        let mut index = TreeIndex::new(config);

        for i in 0..3 {
            let vector = Vector::new(vec![i as f32, (3 - i) as f32]);
            index.insert(format!("vec_{i}"), vector).unwrap();
        }

        index.build().unwrap();

        let query = Vector::new(vec![1.0, 2.0]);
        let results = index.search_knn(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
    }

    // Points on the unit circle at 0, 45 and 90 degrees; only checks that a
    // 2-NN query yields two results.
    #[test]
    #[ignore = "Stack overflow issue - being investigated"]
    fn test_vp_tree() {
        let config = TreeIndexConfig {
            tree_type: TreeType::VpTree,
            random_seed: Some(42),
            // Larger than the data set, so the whole tree is a single leaf.
            max_leaf_size: 50,
            ..Default::default()
        };

        let mut index = TreeIndex::new(config);

        for i in 0..3 {
            let angle = (i as f32) * std::f32::consts::PI / 4.0;
            let vector = Vector::new(vec![angle.cos(), angle.sin()]);
            index.insert(format!("vec_{i}"), vector).unwrap();
        }

        index.build().unwrap();

        let query = Vector::new(vec![1.0, 0.0]);
        let results = index.search_knn(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
    }
}