1use crate::{Vector, VectorIndex};
26use anyhow::Result;
27use oxirs_core::simd::SimdOps;
28use scirs2_core::random::{Random, Rng, RngExt};
29use std::cmp::Ordering;
30use std::collections::BinaryHeap;
31
/// Configuration shared by every tree-based index in this module.
#[derive(Debug, Clone)]
pub struct TreeIndexConfig {
    /// Which tree structure `TreeIndex` instantiates.
    pub tree_type: TreeType,
    /// Nodes holding at most this many points are not split further.
    pub max_leaf_size: usize,
    /// Seed for the randomized builders (VP tree, random projection tree);
    /// `None` makes them use the fixed seed 42.
    pub random_seed: Option<u64>,
    /// Requests parallel construction.
    /// NOTE(review): no builder in this file reads this flag; construction is sequential.
    pub parallel_construction: bool,
    /// Distance function used both while building and while searching.
    pub distance_metric: DistanceMetric,
}
46
47impl Default for TreeIndexConfig {
48 fn default() -> Self {
49 Self {
50 tree_type: TreeType::BallTree,
51 max_leaf_size: 16, random_seed: None,
53 parallel_construction: true,
54 distance_metric: DistanceMetric::Euclidean,
55 }
56 }
57}
58
/// Selects which tree structure a `TreeIndex` uses.
#[derive(Debug, Clone, Copy)]
pub enum TreeType {
    /// Ball tree: bounding balls, split at the median of the widest coordinate.
    BallTree,
    /// k-d tree: median splits on a coordinate chosen by `depth % dimensions`.
    KdTree,
    /// Vantage-point tree: splits by distance to a randomly chosen vantage point.
    VpTree,
    /// Simplified cover tree (points are attached directly below the first point).
    CoverTree,
    /// Random projection tree: median splits along random unit directions.
    RandomProjectionTree,
}
68
/// Distance functions available to the tree indices.
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
    /// L2 distance (`SimdOps::euclidean_distance`).
    Euclidean,
    /// L1 distance (`SimdOps::manhattan_distance`).
    Manhattan,
    /// Cosine distance (`SimdOps::cosine_distance`).
    /// NOTE(review): presumably `1 - cos(a, b)`, which is not a true metric, so
    /// triangle-inequality pruning in the trees is not guaranteed exact with it.
    Cosine,
    /// Minkowski distance of order `p` (the payload).
    Minkowski(f32),
}
77
78impl DistanceMetric {
79 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
80 match self {
81 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
82 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
83 DistanceMetric::Cosine => f32::cosine_distance(a, b),
84 DistanceMetric::Minkowski(p) => a
85 .iter()
86 .zip(b.iter())
87 .map(|(x, y)| (x - y).abs().powf(*p))
88 .sum::<f32>()
89 .powf(1.0 / p),
90 }
91 }
92}
93
/// A search candidate: index into a tree's `data` plus its distance to the query.
///
/// Ordered by `distance` (larger is greater), so a `BinaryHeap<SearchResult>`
/// is a max-heap whose top is the worst of the candidates it holds.
#[derive(Debug, Clone)]
struct SearchResult {
    index: usize,
    distance: f32,
}

impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` and `Ord` always agree (also for NaN).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // Must not delegate to `partial_cmp`: that delegates back here, which
        // recurses without bound and overflows the stack. `total_cmp` gives a
        // genuine total order over all f32 values, including NaN.
        self.distance.total_cmp(&other.distance)
    }
}
120
/// Ball tree over stored vectors. Points are appended to `data` and indexed by `build`.
pub struct BallTree {
    /// Root node; `None` until `build` runs on non-empty data.
    root: Option<Box<BallNode>>,
    /// Stored `(uri, vector)` pairs; nodes refer to points by index into this vector.
    data: Vec<(String, Vector)>,
    /// Build/search configuration (leaf size, distance metric).
    config: TreeIndexConfig,
}

/// Ball tree node: a ball (`center`, `radius`) meant to enclose every point below it.
#[derive(Clone)]
struct BallNode {
    /// Center of the ball (centroid for leaves, midpoint of child centers otherwise).
    center: Vec<f32>,
    /// Radius of the ball around `center`.
    radius: f32,
    /// Left child; `None` for leaves.
    left: Option<Box<BallNode>>,
    /// Right child; `None` for leaves.
    right: Option<Box<BallNode>>,
    /// Indices into `BallTree::data` held by a leaf; empty for internal nodes.
    indices: Vec<usize>,
}
141
142impl BallTree {
143 pub fn new(config: TreeIndexConfig) -> Self {
144 Self {
145 root: None,
146 data: Vec::new(),
147 config,
148 }
149 }
150
151 pub fn build(&mut self) -> Result<()> {
156 if self.data.is_empty() {
157 return Ok(());
158 }
159
160 let indices: Vec<usize> = (0..self.data.len()).collect();
161 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
162
163 self.root = Some(Box::new(self.build_node_safe(&points, indices, 0)?));
164 Ok(())
165 }
166
167 fn build_node_safe(
169 &self,
170 points: &[Vec<f32>],
171 indices: Vec<usize>,
172 depth: usize,
173 ) -> Result<BallNode> {
174 const MAX_DEPTH: usize = 20;
177
178 if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= MAX_DEPTH {
183 let center = self.compute_centroid(points, &indices);
184 let radius = self.compute_radius(points, &indices, ¢er);
185 return Ok(BallNode {
186 center,
187 radius,
188 left: None,
189 right: None,
190 indices,
191 });
192 }
193
194 let split_dim = self.find_split_dimension(points, &indices);
196 let (left_indices, right_indices) = self.partition_indices(points, &indices, split_dim);
197
198 if left_indices.is_empty() || right_indices.is_empty() {
200 let center = self.compute_centroid(points, &indices);
201 let radius = self.compute_radius(points, &indices, ¢er);
202 return Ok(BallNode {
203 center,
204 radius,
205 left: None,
206 right: None,
207 indices,
208 });
209 }
210
211 let left_node = self.build_node_safe(points, left_indices, depth + 1)?;
213 let right_node = self.build_node_safe(points, right_indices, depth + 1)?;
214
215 let all_centers = vec![left_node.center.clone(), right_node.center.clone()];
217 let center = self.compute_centroid_of_centers(&all_centers);
218 let radius = left_node.radius.max(right_node.radius)
219 + self
220 .config
221 .distance_metric
222 .distance(¢er, &left_node.center);
223
224 Ok(BallNode {
225 center,
226 radius,
227 left: Some(Box::new(left_node)),
228 right: Some(Box::new(right_node)),
229 indices: Vec::new(),
230 })
231 }
232
233 fn compute_centroid(&self, points: &[Vec<f32>], indices: &[usize]) -> Vec<f32> {
234 let dim = points[0].len();
235 let mut centroid = vec![0.0; dim];
236
237 for &idx in indices {
238 for (i, &val) in points[idx].iter().enumerate() {
239 centroid[i] += val;
240 }
241 }
242
243 let n = indices.len() as f32;
244 for val in &mut centroid {
245 *val /= n;
246 }
247
248 centroid
249 }
250
251 fn compute_radius(&self, points: &[Vec<f32>], indices: &[usize], center: &[f32]) -> f32 {
252 indices
253 .iter()
254 .map(|&idx| self.config.distance_metric.distance(&points[idx], center))
255 .fold(0.0f32, f32::max)
256 }
257
258 fn find_split_dimension(&self, points: &[Vec<f32>], indices: &[usize]) -> usize {
259 let dim = points[0].len();
260 let mut max_spread = 0.0;
261 let mut split_dim = 0;
262
263 #[allow(clippy::needless_range_loop)]
265 for d in 0..dim {
266 let values: Vec<f32> = indices.iter().map(|&idx| points[idx][d]).collect();
267
268 let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
269 let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
270 let spread = max_val - min_val;
271
272 if spread > max_spread {
273 max_spread = spread;
274 split_dim = d;
275 }
276 }
277
278 split_dim
279 }
280
281 fn partition_indices(
282 &self,
283 points: &[Vec<f32>],
284 indices: &[usize],
285 dim: usize,
286 ) -> (Vec<usize>, Vec<usize>) {
287 let mut values: Vec<(f32, usize)> =
288 indices.iter().map(|&idx| (points[idx][dim], idx)).collect();
289
290 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
291
292 let mid = values.len() / 2;
293 let left_indices: Vec<usize> = values[..mid].iter().map(|(_, idx)| *idx).collect();
294 let right_indices: Vec<usize> = values[mid..].iter().map(|(_, idx)| *idx).collect();
295
296 (left_indices, right_indices)
297 }
298
299 fn compute_centroid_of_centers(&self, centers: &[Vec<f32>]) -> Vec<f32> {
300 let dim = centers[0].len();
301 let mut centroid = vec![0.0; dim];
302
303 for center in centers {
304 for (i, &val) in center.iter().enumerate() {
305 centroid[i] += val;
306 }
307 }
308
309 let n = centers.len() as f32;
310 for val in &mut centroid {
311 *val /= n;
312 }
313
314 centroid
315 }
316
317 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
319 if self.root.is_none() {
320 return Vec::new();
321 }
322
323 let mut heap: BinaryHeap<SearchResult> = BinaryHeap::new();
324 let mut stack: Vec<&BallNode> = vec![self
325 .root
326 .as_ref()
327 .expect("tree should have root after build")];
328
329 while let Some(node) = stack.pop() {
330 let dist_to_center = self.config.distance_metric.distance(query, &node.center);
332
333 if heap.len() >= k {
334 let worst_dist = heap.peek().expect("heap should have k elements").distance;
335 if dist_to_center - node.radius > worst_dist {
336 continue; }
338 }
339
340 if node.indices.is_empty() {
341 if let (Some(left), Some(right)) = (&node.left, &node.right) {
343 let left_dist = self.config.distance_metric.distance(query, &left.center);
344 let right_dist = self.config.distance_metric.distance(query, &right.center);
345
346 if left_dist < right_dist {
348 stack.push(right);
349 stack.push(left);
350 } else {
351 stack.push(left);
352 stack.push(right);
353 }
354 }
355 } else {
356 for &idx in &node.indices {
358 let point = &self.data[idx].1.as_f32();
359 let dist = self.config.distance_metric.distance(query, point);
360
361 if heap.len() < k {
362 heap.push(SearchResult {
363 index: idx,
364 distance: dist,
365 });
366 } else if dist < heap.peek().expect("heap should have k elements").distance {
367 heap.pop();
368 heap.push(SearchResult {
369 index: idx,
370 distance: dist,
371 });
372 }
373 }
374 }
375 }
376
377 let mut results: Vec<(usize, f32)> =
378 heap.into_iter().map(|r| (r.index, r.distance)).collect();
379
380 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
381 results
382 }
383}
384
/// k-d tree over stored vectors. Points are appended to `data` and indexed by `build`.
pub struct KdTree {
    /// Root node; `None` until `build` runs on non-empty data.
    root: Option<Box<KdNode>>,
    /// Stored `(uri, vector)` pairs; nodes refer to points by index into this vector.
    data: Vec<(String, Vector)>,
    /// Build/search configuration (leaf size, distance metric).
    config: TreeIndexConfig,
}

/// k-d tree node. Internal nodes split on `split_dim` at `split_value`;
/// leaves carry their points in `indices`.
struct KdNode {
    /// Coordinate compared at this node (`0` in leaves, where it is unused).
    split_dim: usize,
    /// Median coordinate value: left points have values `<=` it, right points `>=` it.
    split_value: f32,
    /// Subtree explored first when the query coordinate is `<= split_value`.
    left: Option<Box<KdNode>>,
    /// Subtree explored first when the query coordinate is `> split_value`.
    right: Option<Box<KdNode>>,
    /// Indices into `KdTree::data` held by a leaf; empty for internal nodes.
    indices: Vec<usize>,
}
404
405impl KdTree {
406 pub fn new(config: TreeIndexConfig) -> Self {
407 Self {
408 root: None,
409 data: Vec::new(),
410 config,
411 }
412 }
413
414 pub fn build(&mut self) -> Result<()> {
415 if self.data.is_empty() {
416 return Ok(());
417 }
418
419 let indices: Vec<usize> = (0..self.data.len()).collect();
420 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
421
422 self.root = Some(Box::new(self.build_node(&points, indices, 0)?));
423 Ok(())
424 }
425
426 fn build_node(&self, points: &[Vec<f32>], indices: Vec<usize>, depth: usize) -> Result<KdNode> {
427 let max_depth = if !self.data.is_empty() {
429 ((self.data.len() as f32).log2() * 2.0) as usize + 10
430 } else {
431 50
432 };
433
434 if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= max_depth {
435 return Ok(KdNode {
436 split_dim: 0,
437 split_value: 0.0,
438 left: None,
439 right: None,
440 indices,
441 });
442 }
443
444 let dimensions = points[0].len();
445 let split_dim = depth % dimensions;
446
447 let mut values: Vec<(f32, usize)> = indices
449 .iter()
450 .map(|&idx| (points[idx][split_dim], idx))
451 .collect();
452
453 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
454
455 let median_idx = values.len() / 2;
456 let split_value = values[median_idx].0;
457
458 let left_indices: Vec<usize> = values[..median_idx].iter().map(|(_, idx)| *idx).collect();
459
460 let right_indices: Vec<usize> = values[median_idx..].iter().map(|(_, idx)| *idx).collect();
461
462 if left_indices.is_empty() || right_indices.is_empty() {
464 return Ok(KdNode {
465 split_dim: 0,
466 split_value: 0.0,
467 left: None,
468 right: None,
469 indices,
470 });
471 }
472
473 let left = Some(Box::new(self.build_node(
474 points,
475 left_indices,
476 depth + 1,
477 )?));
478
479 let right = Some(Box::new(self.build_node(
480 points,
481 right_indices,
482 depth + 1,
483 )?));
484
485 Ok(KdNode {
486 split_dim,
487 split_value,
488 left,
489 right,
490 indices: Vec::new(),
491 })
492 }
493
494 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
495 if self.root.is_none() {
496 return Vec::new();
497 }
498
499 let mut heap = BinaryHeap::new();
500 self.search_node(
501 self.root
502 .as_ref()
503 .expect("tree should have root after build"),
504 query,
505 k,
506 &mut heap,
507 );
508
509 let mut results: Vec<(usize, f32)> =
510 heap.into_iter().map(|r| (r.index, r.distance)).collect();
511
512 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
513 results
514 }
515
516 fn search_node(
517 &self,
518 node: &KdNode,
519 query: &[f32],
520 k: usize,
521 heap: &mut BinaryHeap<SearchResult>,
522 ) {
523 if !node.indices.is_empty() {
524 for &idx in &node.indices {
526 let point = &self.data[idx].1.as_f32();
527 let dist = self.config.distance_metric.distance(query, point);
528
529 if heap.len() < k {
530 heap.push(SearchResult {
531 index: idx,
532 distance: dist,
533 });
534 } else if dist < heap.peek().expect("heap should have k elements").distance {
535 heap.pop();
536 heap.push(SearchResult {
537 index: idx,
538 distance: dist,
539 });
540 }
541 }
542 return;
543 }
544
545 let go_left = query[node.split_dim] <= node.split_value;
547
548 let (first, second) = if go_left {
549 (&node.left, &node.right)
550 } else {
551 (&node.right, &node.left)
552 };
553
554 if let Some(child) = first {
556 self.search_node(child, query, k, heap);
557 }
558
559 if heap.len() < k || {
561 let split_dist = (query[node.split_dim] - node.split_value).abs();
562 split_dist < heap.peek().expect("heap should have k elements").distance
563 } {
564 if let Some(child) = second {
565 self.search_node(child, query, k, heap);
566 }
567 }
568 }
569}
570
/// Vantage-point tree over stored vectors. Points are appended to `data` and indexed by `build`.
pub struct VpTree {
    /// Root node; `None` until `build` runs on non-empty data.
    root: Option<Box<VpNode>>,
    /// Stored `(uri, vector)` pairs; nodes refer to points by index into this vector.
    data: Vec<(String, Vector)>,
    /// Build/search configuration (leaf size, seed, distance metric).
    config: TreeIndexConfig,
}

/// Vantage-point tree node.
struct VpNode {
    /// Index into `VpTree::data` of the vantage point. For internal nodes it is
    /// not contained in either child; leaves just store their first index (or 0).
    vantage_point: usize,
    /// Median distance from the vantage point: `inside` points are no farther,
    /// `outside` points no closer.
    median_distance: f32,
    /// Points at distance `<= median_distance` from the vantage point.
    inside: Option<Box<VpNode>>,
    /// Points at distance `>= median_distance` from the vantage point.
    outside: Option<Box<VpNode>>,
    /// Indices into `VpTree::data` held by a leaf; empty for internal nodes.
    indices: Vec<usize>,
}
590
591impl VpTree {
592 pub fn new(config: TreeIndexConfig) -> Self {
593 Self {
594 root: None,
595 data: Vec::new(),
596 config,
597 }
598 }
599
600 pub fn build(&mut self) -> Result<()> {
601 if self.data.is_empty() {
602 return Ok(());
603 }
604
605 let indices: Vec<usize> = (0..self.data.len()).collect();
606 let mut rng = if let Some(seed) = self.config.random_seed {
607 Random::seed(seed)
608 } else {
609 Random::seed(42)
610 };
611
612 self.root = Some(Box::new(self.build_node(indices, &mut rng)?));
613 Ok(())
614 }
615
616 fn build_node<R: Rng>(&self, indices: Vec<usize>, rng: &mut R) -> Result<VpNode> {
617 self.build_node_safe(indices, rng, 0)
618 }
619
620 #[allow(deprecated)]
621 fn build_node_safe<R: Rng>(
622 &self,
623 mut indices: Vec<usize>,
624 rng: &mut R,
625 depth: usize,
626 ) -> Result<VpNode> {
627 let max_depth = 30; if indices.len() <= self.config.max_leaf_size
635 || indices.len() <= 2 || depth >= max_depth
637 {
638 return Ok(VpNode {
639 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
640 median_distance: 0.0,
641 inside: None,
642 outside: None,
643 indices,
644 });
645 }
646
647 let vp_idx = if indices.len() > 1 {
649 rng.random_range(0..indices.len())
650 } else {
651 0
652 };
653 let vantage_point = indices[vp_idx];
654 indices.remove(vp_idx);
655
656 let vp_data = &self.data[vantage_point].1.as_f32();
658 let mut distances: Vec<(f32, usize)> = indices
659 .iter()
660 .map(|&idx| {
661 let point = &self.data[idx].1.as_f32();
662 let dist = self.config.distance_metric.distance(vp_data, point);
663 (dist, idx)
664 })
665 .collect();
666
667 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
668
669 let median_idx = distances.len() / 2;
670 let median_distance = distances[median_idx].0;
671
672 let inside_indices: Vec<usize> = distances[..median_idx]
673 .iter()
674 .map(|(_, idx)| *idx)
675 .collect();
676
677 let outside_indices: Vec<usize> = distances[median_idx..]
678 .iter()
679 .map(|(_, idx)| *idx)
680 .collect();
681
682 if inside_indices.is_empty() || outside_indices.is_empty() {
684 return Ok(VpNode {
685 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
686 median_distance: 0.0,
687 inside: None,
688 outside: None,
689 indices,
690 });
691 }
692
693 let inside = Some(Box::new(self.build_node_safe(
694 inside_indices,
695 rng,
696 depth + 1,
697 )?));
698 let outside = Some(Box::new(self.build_node_safe(
699 outside_indices,
700 rng,
701 depth + 1,
702 )?));
703
704 Ok(VpNode {
705 vantage_point,
706 median_distance,
707 inside,
708 outside,
709 indices: Vec::new(),
710 })
711 }
712
713 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
714 if self.root.is_none() {
715 return Vec::new();
716 }
717
718 let mut heap = BinaryHeap::new();
719 self.search_node(
720 self.root
721 .as_ref()
722 .expect("tree should have root after build"),
723 query,
724 k,
725 &mut heap,
726 f32::INFINITY,
727 );
728
729 let mut results: Vec<(usize, f32)> =
730 heap.into_iter().map(|r| (r.index, r.distance)).collect();
731
732 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
733 results
734 }
735
736 fn search_node(
737 &self,
738 node: &VpNode,
739 query: &[f32],
740 k: usize,
741 heap: &mut BinaryHeap<SearchResult>,
742 tau: f32,
743 ) -> f32 {
744 let mut tau = tau;
745
746 if !node.indices.is_empty() {
747 for &idx in &node.indices {
749 let point = &self.data[idx].1.as_f32();
750 let dist = self.config.distance_metric.distance(query, point);
751
752 if dist < tau {
753 if heap.len() < k {
754 heap.push(SearchResult {
755 index: idx,
756 distance: dist,
757 });
758 } else if dist < heap.peek().expect("heap should have k elements").distance {
759 heap.pop();
760 heap.push(SearchResult {
761 index: idx,
762 distance: dist,
763 });
764 }
765
766 if heap.len() >= k {
767 tau = heap.peek().expect("heap should have k elements").distance;
768 }
769 }
770 }
771 return tau;
772 }
773
774 let vp_data = &self.data[node.vantage_point].1.as_f32();
776 let dist_to_vp = self.config.distance_metric.distance(query, vp_data);
777
778 if dist_to_vp < tau {
780 if heap.len() < k {
781 heap.push(SearchResult {
782 index: node.vantage_point,
783 distance: dist_to_vp,
784 });
785 } else if dist_to_vp < heap.peek().expect("heap should have k elements").distance {
786 heap.pop();
787 heap.push(SearchResult {
788 index: node.vantage_point,
789 distance: dist_to_vp,
790 });
791 }
792
793 if heap.len() >= k {
794 tau = heap.peek().expect("heap should have k elements").distance;
795 }
796 }
797
798 if dist_to_vp < node.median_distance {
800 if let Some(inside) = &node.inside {
802 tau = self.search_node(inside, query, k, heap, tau);
803 }
804
805 if dist_to_vp + tau >= node.median_distance {
807 if let Some(outside) = &node.outside {
808 tau = self.search_node(outside, query, k, heap, tau);
809 }
810 }
811 } else {
812 if let Some(outside) = &node.outside {
814 tau = self.search_node(outside, query, k, heap, tau);
815 }
816
817 if dist_to_vp - tau <= node.median_distance {
819 if let Some(inside) = &node.inside {
820 tau = self.search_node(inside, query, k, heap, tau);
821 }
822 }
823 }
824
825 tau
826 }
827}
828
/// Simplified "cover tree" over stored vectors.
///
/// NOTE(review): `insert` attaches every point after the first directly to the
/// root and `base` is never read, so this is currently a flat list rather than a
/// real cover tree.
pub struct CoverTree {
    /// Root node (point 0); `None` until `build` runs on non-empty data.
    root: Option<Box<CoverNode>>,
    /// Stored `(uri, vector)` pairs; nodes refer to points by index into this vector.
    data: Vec<(String, Vector)>,
    /// Search configuration (distance metric).
    config: TreeIndexConfig,
    /// Level expansion base (set to 2.0 by `new`).
    base: f32,
}

/// Cover tree node: one stored point plus its children.
struct CoverNode {
    /// Index into `CoverTree::data` of this node's point.
    point: usize,
    /// Level assigned at insertion; currently derived from `data.len()` only,
    /// so all nodes share the same value.
    level: i32,
    /// Child nodes.
    #[allow(clippy::vec_box)] // NOTE(review): `Vec<CoverNode>` would avoid the extra indirection.
    children: Vec<Box<CoverNode>>,
}
846
847impl CoverTree {
848 pub fn new(config: TreeIndexConfig) -> Self {
849 Self {
850 root: None,
851 data: Vec::new(),
852 config,
853 base: 2.0, }
855 }
856
857 pub fn build(&mut self) -> Result<()> {
858 if self.data.is_empty() {
859 return Ok(());
860 }
861
862 self.root = Some(Box::new(CoverNode {
864 point: 0,
865 level: self.get_level(0),
866 children: Vec::new(),
867 }));
868
869 for idx in 1..self.data.len() {
871 self.insert(idx)?;
872 }
873
874 Ok(())
875 }
876
877 fn get_level(&self, _point_idx: usize) -> i32 {
878 ((self.data.len() as f32).log2() as i32).max(0)
880 }
881
882 fn insert(&mut self, point_idx: usize) -> Result<()> {
883 let level = self.get_level(point_idx);
886 if let Some(root) = &mut self.root {
887 root.children.push(Box::new(CoverNode {
888 point: point_idx,
889 level,
890 children: Vec::new(),
891 }));
892 }
893 Ok(())
894 }
895
896 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
897 if self.root.is_none() {
898 return Vec::new();
899 }
900
901 let mut results = Vec::new();
902 self.search_node(
903 self.root
904 .as_ref()
905 .expect("tree should have root after build"),
906 query,
907 k,
908 &mut results,
909 );
910
911 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
912 results.truncate(k);
913 results
914 }
915
916 #[allow(clippy::only_used_in_recursion)]
917 fn search_node(
918 &self,
919 node: &CoverNode,
920 query: &[f32],
921 k: usize,
922 results: &mut Vec<(usize, f32)>,
923 ) {
924 if results.len() >= k * 10 {
926 return;
927 }
928
929 let point_data = &self.data[node.point].1.as_f32();
930 let dist = self.config.distance_metric.distance(query, point_data);
931
932 results.push((node.point, dist));
933
934 for child in &node.children {
936 self.search_node(child, query, k, results);
937 }
938 }
939}
940
/// Random projection tree over stored vectors. Points are appended to `data` and indexed by `build`.
pub struct RandomProjectionTree {
    /// Root node; `None` until `build` runs on non-empty data.
    root: Option<Box<RpNode>>,
    /// Stored `(uri, vector)` pairs; nodes refer to points by index into this vector.
    data: Vec<(String, Vector)>,
    /// Build/search configuration (leaf size, seed, distance metric).
    config: TreeIndexConfig,
}

/// Random projection tree node.
struct RpNode {
    /// Random direction, normalised to unit length when its norm is non-zero;
    /// empty for leaves.
    projection: Vec<f32>,
    /// Median projected value: left points project `<=` it, right points `>=` it.
    threshold: f32,
    /// Subtree explored first when the query projects `<= threshold`.
    left: Option<Box<RpNode>>,
    /// Subtree explored first when the query projects `> threshold`.
    right: Option<Box<RpNode>>,
    /// Indices into `RandomProjectionTree::data` held by a leaf; empty for internal nodes.
    indices: Vec<usize>,
}
960
961impl RandomProjectionTree {
962 pub fn new(config: TreeIndexConfig) -> Self {
963 Self {
964 root: None,
965 data: Vec::new(),
966 config,
967 }
968 }
969
970 pub fn build(&mut self) -> Result<()> {
971 if self.data.is_empty() {
972 return Ok(());
973 }
974
975 let indices: Vec<usize> = (0..self.data.len()).collect();
976 let dimensions = self.data[0].1.dimensions;
977
978 let mut rng = if let Some(seed) = self.config.random_seed {
979 Random::seed(seed)
980 } else {
981 Random::seed(42)
982 };
983
984 self.root = Some(Box::new(self.build_node(indices, dimensions, &mut rng)?));
985 Ok(())
986 }
987
988 fn build_node<R: Rng>(
989 &self,
990 indices: Vec<usize>,
991 dimensions: usize,
992 rng: &mut R,
993 ) -> Result<RpNode> {
994 self.build_node_safe(indices, dimensions, rng, 0)
995 }
996
997 #[allow(deprecated)]
998 fn build_node_safe<R: Rng>(
999 &self,
1000 indices: Vec<usize>,
1001 dimensions: usize,
1002 rng: &mut R,
1003 depth: usize,
1004 ) -> Result<RpNode> {
1005 if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= 5 {
1007 return Ok(RpNode {
1008 projection: Vec::new(),
1009 threshold: 0.0,
1010 left: None,
1011 right: None,
1012 indices,
1013 });
1014 }
1015
1016 let projection: Vec<f32> = (0..dimensions)
1018 .map(|_| rng.random_range(-1.0..1.0))
1019 .collect();
1020
1021 let norm = (projection.iter().map(|&x| x * x).sum::<f32>()).sqrt();
1023 let projection: Vec<f32> = if norm > 0.0 {
1024 projection.iter().map(|&x| x / norm).collect()
1025 } else {
1026 projection
1027 };
1028
1029 let mut projections: Vec<(f32, usize)> = indices
1031 .iter()
1032 .map(|&idx| {
1033 let point = &self.data[idx].1.as_f32();
1034 let proj_val = f32::dot(point, &projection);
1035 (proj_val, idx)
1036 })
1037 .collect();
1038
1039 projections.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
1040
1041 let median_idx = projections.len() / 2;
1043 let threshold = projections[median_idx].0;
1044
1045 let left_indices: Vec<usize> = projections[..median_idx]
1046 .iter()
1047 .map(|(_, idx)| *idx)
1048 .collect();
1049
1050 let right_indices: Vec<usize> = projections[median_idx..]
1051 .iter()
1052 .map(|(_, idx)| *idx)
1053 .collect();
1054
1055 if left_indices.is_empty() || right_indices.is_empty() {
1057 return Ok(RpNode {
1058 projection: Vec::new(),
1059 threshold: 0.0,
1060 left: None,
1061 right: None,
1062 indices,
1063 });
1064 }
1065
1066 let left = Some(Box::new(self.build_node_safe(
1067 left_indices,
1068 dimensions,
1069 rng,
1070 depth + 1,
1071 )?));
1072 let right = Some(Box::new(self.build_node_safe(
1073 right_indices,
1074 dimensions,
1075 rng,
1076 depth + 1,
1077 )?));
1078
1079 Ok(RpNode {
1080 projection,
1081 threshold,
1082 left,
1083 right,
1084 indices: Vec::new(),
1085 })
1086 }
1087
1088 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1089 if self.root.is_none() {
1090 return Vec::new();
1091 }
1092
1093 let mut heap = BinaryHeap::new();
1094 self.search_node(
1095 self.root
1096 .as_ref()
1097 .expect("tree should have root after build"),
1098 query,
1099 k,
1100 &mut heap,
1101 );
1102
1103 let mut results: Vec<(usize, f32)> =
1104 heap.into_iter().map(|r| (r.index, r.distance)).collect();
1105
1106 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
1107 results
1108 }
1109
1110 fn search_node(
1111 &self,
1112 node: &RpNode,
1113 query: &[f32],
1114 k: usize,
1115 heap: &mut BinaryHeap<SearchResult>,
1116 ) {
1117 if !node.indices.is_empty() {
1118 for &idx in &node.indices {
1120 let point = &self.data[idx].1.as_f32();
1121 let dist = self.config.distance_metric.distance(query, point);
1122
1123 if heap.len() < k {
1124 heap.push(SearchResult {
1125 index: idx,
1126 distance: dist,
1127 });
1128 } else if dist < heap.peek().expect("heap should have k elements").distance {
1129 heap.pop();
1130 heap.push(SearchResult {
1131 index: idx,
1132 distance: dist,
1133 });
1134 }
1135 }
1136 return;
1137 }
1138
1139 let query_projection = f32::dot(query, &node.projection);
1141
1142 let go_left = query_projection <= node.threshold;
1144
1145 let (first, second) = if go_left {
1146 (&node.left, &node.right)
1147 } else {
1148 (&node.right, &node.left)
1149 };
1150
1151 if let Some(child) = first {
1153 self.search_node(child, query, k, heap);
1154 }
1155
1156 if let Some(child) = second {
1157 self.search_node(child, query, k, heap);
1158 }
1159 }
1160}
1161
/// `VectorIndex` front-end that dispatches to the tree backend selected by `tree_type`.
pub struct TreeIndex {
    /// Which backend is active; `new` populates only the matching slot below.
    tree_type: TreeType,
    /// Backend for `TreeType::BallTree`.
    ball_tree: Option<BallTree>,
    /// Backend for `TreeType::KdTree`.
    kd_tree: Option<KdTree>,
    /// Backend for `TreeType::VpTree`.
    vp_tree: Option<VpTree>,
    /// Backend for `TreeType::CoverTree`.
    cover_tree: Option<CoverTree>,
    /// Backend for `TreeType::RandomProjectionTree`.
    rp_tree: Option<RandomProjectionTree>,
}
1171
1172impl TreeIndex {
1173 pub fn new(config: TreeIndexConfig) -> Self {
1174 let tree_type = config.tree_type;
1175
1176 let (ball_tree, kd_tree, vp_tree, cover_tree, rp_tree) = match tree_type {
1177 TreeType::BallTree => (Some(BallTree::new(config)), None, None, None, None),
1178 TreeType::KdTree => (None, Some(KdTree::new(config)), None, None, None),
1179 TreeType::VpTree => (None, None, Some(VpTree::new(config)), None, None),
1180 TreeType::CoverTree => (None, None, None, Some(CoverTree::new(config)), None),
1181 TreeType::RandomProjectionTree => (
1182 None,
1183 None,
1184 None,
1185 None,
1186 Some(RandomProjectionTree::new(config)),
1187 ),
1188 };
1189
1190 Self {
1191 tree_type,
1192 ball_tree,
1193 kd_tree,
1194 vp_tree,
1195 cover_tree,
1196 rp_tree,
1197 }
1198 }
1199
1200 pub fn build(&mut self) -> Result<()> {
1201 match self.tree_type {
1202 TreeType::BallTree => self
1203 .ball_tree
1204 .as_mut()
1205 .expect("ball_tree should be initialized for BallTree type")
1206 .build(),
1207 TreeType::KdTree => self
1208 .kd_tree
1209 .as_mut()
1210 .expect("kd_tree should be initialized for KdTree type")
1211 .build(),
1212 TreeType::VpTree => self
1213 .vp_tree
1214 .as_mut()
1215 .expect("vp_tree should be initialized for VpTree type")
1216 .build(),
1217 TreeType::CoverTree => self
1218 .cover_tree
1219 .as_mut()
1220 .expect("cover_tree should be initialized for CoverTree type")
1221 .build(),
1222 TreeType::RandomProjectionTree => self
1223 .rp_tree
1224 .as_mut()
1225 .expect("rp_tree should be initialized for RandomProjectionTree type")
1226 .build(),
1227 }
1228 }
1229
1230 fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1231 match self.tree_type {
1232 TreeType::BallTree => self
1233 .ball_tree
1234 .as_ref()
1235 .expect("ball_tree should be initialized for BallTree type")
1236 .search(query, k),
1237 TreeType::KdTree => self
1238 .kd_tree
1239 .as_ref()
1240 .expect("kd_tree should be initialized for KdTree type")
1241 .search(query, k),
1242 TreeType::VpTree => self
1243 .vp_tree
1244 .as_ref()
1245 .expect("vp_tree should be initialized for VpTree type")
1246 .search(query, k),
1247 TreeType::CoverTree => self
1248 .cover_tree
1249 .as_ref()
1250 .expect("cover_tree should be initialized for CoverTree type")
1251 .search(query, k),
1252 TreeType::RandomProjectionTree => self
1253 .rp_tree
1254 .as_ref()
1255 .expect("rp_tree should be initialized for RandomProjectionTree type")
1256 .search(query, k),
1257 }
1258 }
1259}
1260
1261impl VectorIndex for TreeIndex {
1262 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1263 let data = match self.tree_type {
1264 TreeType::BallTree => {
1265 &mut self
1266 .ball_tree
1267 .as_mut()
1268 .expect("ball_tree should be initialized for BallTree type")
1269 .data
1270 }
1271 TreeType::KdTree => {
1272 &mut self
1273 .kd_tree
1274 .as_mut()
1275 .expect("kd_tree should be initialized for KdTree type")
1276 .data
1277 }
1278 TreeType::VpTree => {
1279 &mut self
1280 .vp_tree
1281 .as_mut()
1282 .expect("vp_tree should be initialized for VpTree type")
1283 .data
1284 }
1285 TreeType::CoverTree => {
1286 &mut self
1287 .cover_tree
1288 .as_mut()
1289 .expect("cover_tree should be initialized for CoverTree type")
1290 .data
1291 }
1292 TreeType::RandomProjectionTree => {
1293 &mut self
1294 .rp_tree
1295 .as_mut()
1296 .expect("rp_tree should be initialized for RandomProjectionTree type")
1297 .data
1298 }
1299 };
1300
1301 data.push((uri, vector));
1302 Ok(())
1303 }
1304
1305 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1306 let query_f32 = query.as_f32();
1307 let results = self.search_internal(&query_f32, k);
1308
1309 let data = match self.tree_type {
1310 TreeType::BallTree => {
1311 &self
1312 .ball_tree
1313 .as_ref()
1314 .expect("ball_tree should be initialized for BallTree type")
1315 .data
1316 }
1317 TreeType::KdTree => {
1318 &self
1319 .kd_tree
1320 .as_ref()
1321 .expect("kd_tree should be initialized for KdTree type")
1322 .data
1323 }
1324 TreeType::VpTree => {
1325 &self
1326 .vp_tree
1327 .as_ref()
1328 .expect("vp_tree should be initialized for VpTree type")
1329 .data
1330 }
1331 TreeType::CoverTree => {
1332 &self
1333 .cover_tree
1334 .as_ref()
1335 .expect("cover_tree should be initialized for CoverTree type")
1336 .data
1337 }
1338 TreeType::RandomProjectionTree => {
1339 &self
1340 .rp_tree
1341 .as_ref()
1342 .expect("rp_tree should be initialized for RandomProjectionTree type")
1343 .data
1344 }
1345 };
1346
1347 Ok(results
1348 .into_iter()
1349 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1350 .collect())
1351 }
1352
1353 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1354 let query_f32 = query.as_f32();
1355 let all_results = self.search_internal(&query_f32, 1000); let data = match self.tree_type {
1358 TreeType::BallTree => {
1359 &self
1360 .ball_tree
1361 .as_ref()
1362 .expect("ball_tree should be initialized for BallTree type")
1363 .data
1364 }
1365 TreeType::KdTree => {
1366 &self
1367 .kd_tree
1368 .as_ref()
1369 .expect("kd_tree should be initialized for KdTree type")
1370 .data
1371 }
1372 TreeType::VpTree => {
1373 &self
1374 .vp_tree
1375 .as_ref()
1376 .expect("vp_tree should be initialized for VpTree type")
1377 .data
1378 }
1379 TreeType::CoverTree => {
1380 &self
1381 .cover_tree
1382 .as_ref()
1383 .expect("cover_tree should be initialized for CoverTree type")
1384 .data
1385 }
1386 TreeType::RandomProjectionTree => {
1387 &self
1388 .rp_tree
1389 .as_ref()
1390 .expect("rp_tree should be initialized for RandomProjectionTree type")
1391 .data
1392 }
1393 };
1394
1395 Ok(all_results
1396 .into_iter()
1397 .filter(|(_, dist)| *dist <= threshold)
1398 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1399 .collect())
1400 }
1401
1402 fn get_vector(&self, uri: &str) -> Option<&Vector> {
1403 let data = match self.tree_type {
1404 TreeType::BallTree => {
1405 &self
1406 .ball_tree
1407 .as_ref()
1408 .expect("ball_tree should be initialized for BallTree type")
1409 .data
1410 }
1411 TreeType::KdTree => {
1412 &self
1413 .kd_tree
1414 .as_ref()
1415 .expect("kd_tree should be initialized for KdTree type")
1416 .data
1417 }
1418 TreeType::VpTree => {
1419 &self
1420 .vp_tree
1421 .as_ref()
1422 .expect("vp_tree should be initialized for VpTree type")
1423 .data
1424 }
1425 TreeType::CoverTree => {
1426 &self
1427 .cover_tree
1428 .as_ref()
1429 .expect("cover_tree should be initialized for CoverTree type")
1430 .data
1431 }
1432 TreeType::RandomProjectionTree => {
1433 &self
1434 .rp_tree
1435 .as_ref()
1436 .expect("rp_tree should be initialized for RandomProjectionTree type")
1437 .data
1438 }
1439 };
1440
1441 data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1442 }
1443}
1444
/// Runs `f` and yields its result.
///
/// Despite the name nothing is spawned: `f` executes inline, on whichever
/// thread first polls the returned future.
async fn spawn_task<F, T>(f: F) -> T
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    let task = f;
    task()
}
1457
#[cfg(test)]
mod tests {
    use super::*;

    /// 100 points on the line y = 2x; a 5-NN query at point 50 must return 1..=5 hits.
    #[test]
    #[ignore = "Tree indices are experimental - see module documentation for alternatives"]
    fn test_ball_tree() -> Result<()> {
        let mut ball_tree = BallTree::new(TreeIndexConfig {
            tree_type: TreeType::BallTree,
            max_leaf_size: 10,
            ..Default::default()
        });

        ball_tree.data.extend((0..100).map(|i| {
            let vector = Vector::new(vec![i as f32, (i * 2) as f32]);
            (format!("vec_{i}"), vector)
        }));

        ball_tree.build()?;
        assert!(ball_tree.root.is_some());

        let query = [50.0f32, 100.0];
        let results = ball_tree.search(&query, 5);

        assert!(results.len() <= 5);
        assert!(!results.is_empty());
        Ok(())
    }

    /// Three points in a single k-d leaf; a 2-NN query must return exactly two hits.
    #[test]
    #[ignore = "Investigating stack overflow with recursive tree construction"]
    fn test_kd_tree() -> Result<()> {
        let mut index = TreeIndex::new(TreeIndexConfig {
            tree_type: TreeType::KdTree,
            max_leaf_size: 50,
            ..Default::default()
        });

        for i in 0..3 {
            let vector = Vector::new(vec![i as f32, (3 - i) as f32]);
            index.insert(format!("vec_{i}"), vector)?;
        }
        index.build()?;

        let results = index.search_knn(&Vector::new(vec![1.0, 2.0]), 2)?;
        assert_eq!(results.len(), 2);
        Ok(())
    }

    /// Three unit vectors at 0, 45 and 90 degrees in a seeded VP tree; 2-NN of (1, 0).
    #[test]
    #[ignore = "Investigating stack overflow with recursive tree construction"]
    fn test_vp_tree() -> Result<()> {
        let mut index = TreeIndex::new(TreeIndexConfig {
            tree_type: TreeType::VpTree,
            random_seed: Some(42),
            max_leaf_size: 50,
            ..Default::default()
        });

        for i in 0..3 {
            let angle = (i as f32) * std::f32::consts::PI / 4.0;
            let vector = Vector::new(vec![angle.cos(), angle.sin()]);
            index.insert(format!("vec_{i}"), vector)?;
        }
        index.build()?;

        let results = index.search_knn(&Vector::new(vec![1.0, 0.0]), 2)?;
        assert_eq!(results.len(), 2);
        Ok(())
    }
}