1use crate::{Vector, VectorIndex};
26use anyhow::Result;
27use oxirs_core::simd::SimdOps;
28use scirs2_core::random::{Random, Rng};
29use std::cmp::Ordering;
30use std::collections::BinaryHeap;
31
/// Configuration shared by every tree-based index in this module.
///
/// Override individual fields with struct-update syntax on top of
/// [`TreeIndexConfig::default`].
#[derive(Debug, Clone)]
pub struct TreeIndexConfig {
    /// Which tree structure [`TreeIndex`] builds.
    pub tree_type: TreeType,
    /// Leaf-size threshold: nodes with at most this many points become leaves.
    pub max_leaf_size: usize,
    /// Seed for the randomized trees (VP-tree, random projection tree);
    /// `None` makes them fall back to a fixed seed.
    pub random_seed: Option<u64>,
    /// Request parallel construction.
    ///
    /// NOTE(review): no builder in this module reads this flag yet.
    pub parallel_construction: bool,
    /// Metric used both to build the trees and to rank search results.
    pub distance_metric: DistanceMetric,
}

impl Default for TreeIndexConfig {
    /// Ball tree, leaves of at most 16 points, Euclidean distance, no explicit
    /// seed, parallel construction requested.
    fn default() -> Self {
        let tree_type = TreeType::BallTree;
        let distance_metric = DistanceMetric::Euclidean;

        Self {
            tree_type,
            max_leaf_size: 16,
            random_seed: None,
            parallel_construction: true,
            distance_metric,
        }
    }
}

/// Tree structures that [`TreeIndex`] can be configured to build.
#[derive(Debug, Clone, Copy)]
pub enum TreeType {
    /// Ball tree split at the median of the widest coordinate.
    BallTree,
    /// KD-tree cycling through coordinates by depth.
    KdTree,
    /// Vantage-point tree with randomly chosen vantage points.
    VpTree,
    /// Simplified cover tree.
    CoverTree,
    /// Tree split on random unit-vector projections.
    RandomProjectionTree,
}

/// Distance functions available to the tree indices.
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
    /// L2 distance.
    Euclidean,
    /// L1 distance.
    Manhattan,
    /// Cosine distance.
    Cosine,
    /// Minkowski distance of the given order `p`.
    Minkowski(f32),
}
77
78impl DistanceMetric {
79 fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
80 match self {
81 DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
82 DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
83 DistanceMetric::Cosine => f32::cosine_distance(a, b),
84 DistanceMetric::Minkowski(p) => a
85 .iter()
86 .zip(b.iter())
87 .map(|(x, y)| (x - y).abs().powf(*p))
88 .sum::<f32>()
89 .powf(1.0 / p),
90 }
91 }
92}
93
/// A k-NN candidate: the position of a stored vector plus its distance to the
/// query.
///
/// Ordering compares `distance` only. A `BinaryHeap<SearchResult>` is
/// therefore a max-heap whose top is the current *worst* retained candidate,
/// which is the one bounded k-NN search evicts.
#[derive(Debug, Clone)]
struct SearchResult {
    index: usize,
    distance: f32,
}

// Equality is derived from `cmp` so that `PartialEq`, `Eq`, `PartialOrd` and
// `Ord` always agree, as the `Ord` contract requires.
impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    /// Total order on `distance` via `f32::total_cmp`.
    ///
    /// This must not delegate to `partial_cmp`, because `partial_cmp`
    /// delegates back here and would recurse until the stack overflows.
    /// `total_cmp` also gives NaN a fixed position instead of breaking the
    /// heap invariants.
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance.total_cmp(&other.distance)
    }
}
120
121pub struct BallTree {
123 root: Option<Box<BallNode>>,
124 data: Vec<(String, Vector)>,
125 config: TreeIndexConfig,
126}
127
128#[derive(Clone)]
129struct BallNode {
130 center: Vec<f32>,
132 radius: f32,
134 left: Option<Box<BallNode>>,
136 right: Option<Box<BallNode>>,
138 indices: Vec<usize>,
140}
141
142impl BallTree {
143 pub fn new(config: TreeIndexConfig) -> Self {
144 Self {
145 root: None,
146 data: Vec::new(),
147 config,
148 }
149 }
150
151 pub fn build(&mut self) -> Result<()> {
156 if self.data.is_empty() {
157 return Ok(());
158 }
159
160 let indices: Vec<usize> = (0..self.data.len()).collect();
161 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
162
163 self.root = Some(Box::new(self.build_node_safe(&points, indices, 0)?));
164 Ok(())
165 }
166
167 fn build_node_safe(
169 &self,
170 points: &[Vec<f32>],
171 indices: Vec<usize>,
172 depth: usize,
173 ) -> Result<BallNode> {
174 const MAX_DEPTH: usize = 20;
177
178 if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= MAX_DEPTH {
183 let center = self.compute_centroid(points, &indices);
184 let radius = self.compute_radius(points, &indices, ¢er);
185 return Ok(BallNode {
186 center,
187 radius,
188 left: None,
189 right: None,
190 indices,
191 });
192 }
193
194 let split_dim = self.find_split_dimension(points, &indices);
196 let (left_indices, right_indices) = self.partition_indices(points, &indices, split_dim);
197
198 if left_indices.is_empty() || right_indices.is_empty() {
200 let center = self.compute_centroid(points, &indices);
201 let radius = self.compute_radius(points, &indices, ¢er);
202 return Ok(BallNode {
203 center,
204 radius,
205 left: None,
206 right: None,
207 indices,
208 });
209 }
210
211 let left_node = self.build_node_safe(points, left_indices, depth + 1)?;
213 let right_node = self.build_node_safe(points, right_indices, depth + 1)?;
214
215 let all_centers = vec![left_node.center.clone(), right_node.center.clone()];
217 let center = self.compute_centroid_of_centers(&all_centers);
218 let radius = left_node.radius.max(right_node.radius)
219 + self
220 .config
221 .distance_metric
222 .distance(¢er, &left_node.center);
223
224 Ok(BallNode {
225 center,
226 radius,
227 left: Some(Box::new(left_node)),
228 right: Some(Box::new(right_node)),
229 indices: Vec::new(),
230 })
231 }
232
233 fn compute_centroid(&self, points: &[Vec<f32>], indices: &[usize]) -> Vec<f32> {
234 let dim = points[0].len();
235 let mut centroid = vec![0.0; dim];
236
237 for &idx in indices {
238 for (i, &val) in points[idx].iter().enumerate() {
239 centroid[i] += val;
240 }
241 }
242
243 let n = indices.len() as f32;
244 for val in &mut centroid {
245 *val /= n;
246 }
247
248 centroid
249 }
250
251 fn compute_radius(&self, points: &[Vec<f32>], indices: &[usize], center: &[f32]) -> f32 {
252 indices
253 .iter()
254 .map(|&idx| self.config.distance_metric.distance(&points[idx], center))
255 .fold(0.0f32, f32::max)
256 }
257
258 fn find_split_dimension(&self, points: &[Vec<f32>], indices: &[usize]) -> usize {
259 let dim = points[0].len();
260 let mut max_spread = 0.0;
261 let mut split_dim = 0;
262
263 #[allow(clippy::needless_range_loop)]
265 for d in 0..dim {
266 let values: Vec<f32> = indices.iter().map(|&idx| points[idx][d]).collect();
267
268 let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
269 let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
270 let spread = max_val - min_val;
271
272 if spread > max_spread {
273 max_spread = spread;
274 split_dim = d;
275 }
276 }
277
278 split_dim
279 }
280
281 fn partition_indices(
282 &self,
283 points: &[Vec<f32>],
284 indices: &[usize],
285 dim: usize,
286 ) -> (Vec<usize>, Vec<usize>) {
287 let mut values: Vec<(f32, usize)> =
288 indices.iter().map(|&idx| (points[idx][dim], idx)).collect();
289
290 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
291
292 let mid = values.len() / 2;
293 let left_indices: Vec<usize> = values[..mid].iter().map(|(_, idx)| *idx).collect();
294 let right_indices: Vec<usize> = values[mid..].iter().map(|(_, idx)| *idx).collect();
295
296 (left_indices, right_indices)
297 }
298
299 fn compute_centroid_of_centers(&self, centers: &[Vec<f32>]) -> Vec<f32> {
300 let dim = centers[0].len();
301 let mut centroid = vec![0.0; dim];
302
303 for center in centers {
304 for (i, &val) in center.iter().enumerate() {
305 centroid[i] += val;
306 }
307 }
308
309 let n = centers.len() as f32;
310 for val in &mut centroid {
311 *val /= n;
312 }
313
314 centroid
315 }
316
317 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
319 if self.root.is_none() {
320 return Vec::new();
321 }
322
323 let mut heap: BinaryHeap<SearchResult> = BinaryHeap::new();
324 let mut stack: Vec<&BallNode> = vec![self
325 .root
326 .as_ref()
327 .expect("tree should have root after build")];
328
329 while let Some(node) = stack.pop() {
330 let dist_to_center = self.config.distance_metric.distance(query, &node.center);
332
333 if heap.len() >= k {
334 let worst_dist = heap.peek().expect("heap should have k elements").distance;
335 if dist_to_center - node.radius > worst_dist {
336 continue; }
338 }
339
340 if node.indices.is_empty() {
341 if let (Some(left), Some(right)) = (&node.left, &node.right) {
343 let left_dist = self.config.distance_metric.distance(query, &left.center);
344 let right_dist = self.config.distance_metric.distance(query, &right.center);
345
346 if left_dist < right_dist {
348 stack.push(right);
349 stack.push(left);
350 } else {
351 stack.push(left);
352 stack.push(right);
353 }
354 }
355 } else {
356 for &idx in &node.indices {
358 let point = &self.data[idx].1.as_f32();
359 let dist = self.config.distance_metric.distance(query, point);
360
361 if heap.len() < k {
362 heap.push(SearchResult {
363 index: idx,
364 distance: dist,
365 });
366 } else if dist < heap.peek().expect("heap should have k elements").distance {
367 heap.pop();
368 heap.push(SearchResult {
369 index: idx,
370 distance: dist,
371 });
372 }
373 }
374 }
375 }
376
377 let mut results: Vec<(usize, f32)> =
378 heap.into_iter().map(|r| (r.index, r.distance)).collect();
379
380 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
381 results
382 }
383}
384
385pub struct KdTree {
387 root: Option<Box<KdNode>>,
388 data: Vec<(String, Vector)>,
389 config: TreeIndexConfig,
390}
391
392struct KdNode {
393 split_dim: usize,
395 split_value: f32,
397 left: Option<Box<KdNode>>,
399 right: Option<Box<KdNode>>,
401 indices: Vec<usize>,
403}
404
405impl KdTree {
406 pub fn new(config: TreeIndexConfig) -> Self {
407 Self {
408 root: None,
409 data: Vec::new(),
410 config,
411 }
412 }
413
414 pub fn build(&mut self) -> Result<()> {
415 if self.data.is_empty() {
416 return Ok(());
417 }
418
419 let indices: Vec<usize> = (0..self.data.len()).collect();
420 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
421
422 self.root = Some(Box::new(self.build_node(&points, indices, 0)?));
423 Ok(())
424 }
425
426 fn build_node(&self, points: &[Vec<f32>], indices: Vec<usize>, depth: usize) -> Result<KdNode> {
427 let max_depth = if !self.data.is_empty() {
429 ((self.data.len() as f32).log2() * 2.0) as usize + 10
430 } else {
431 50
432 };
433
434 if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= max_depth {
435 return Ok(KdNode {
436 split_dim: 0,
437 split_value: 0.0,
438 left: None,
439 right: None,
440 indices,
441 });
442 }
443
444 let dimensions = points[0].len();
445 let split_dim = depth % dimensions;
446
447 let mut values: Vec<(f32, usize)> = indices
449 .iter()
450 .map(|&idx| (points[idx][split_dim], idx))
451 .collect();
452
453 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
454
455 let median_idx = values.len() / 2;
456 let split_value = values[median_idx].0;
457
458 let left_indices: Vec<usize> = values[..median_idx].iter().map(|(_, idx)| *idx).collect();
459
460 let right_indices: Vec<usize> = values[median_idx..].iter().map(|(_, idx)| *idx).collect();
461
462 if left_indices.is_empty() || right_indices.is_empty() {
464 return Ok(KdNode {
465 split_dim: 0,
466 split_value: 0.0,
467 left: None,
468 right: None,
469 indices,
470 });
471 }
472
473 let left = Some(Box::new(self.build_node(
474 points,
475 left_indices,
476 depth + 1,
477 )?));
478
479 let right = Some(Box::new(self.build_node(
480 points,
481 right_indices,
482 depth + 1,
483 )?));
484
485 Ok(KdNode {
486 split_dim,
487 split_value,
488 left,
489 right,
490 indices: Vec::new(),
491 })
492 }
493
494 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
495 if self.root.is_none() {
496 return Vec::new();
497 }
498
499 let mut heap = BinaryHeap::new();
500 self.search_node(
501 self.root
502 .as_ref()
503 .expect("tree should have root after build"),
504 query,
505 k,
506 &mut heap,
507 );
508
509 let mut results: Vec<(usize, f32)> =
510 heap.into_iter().map(|r| (r.index, r.distance)).collect();
511
512 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
513 results
514 }
515
516 fn search_node(
517 &self,
518 node: &KdNode,
519 query: &[f32],
520 k: usize,
521 heap: &mut BinaryHeap<SearchResult>,
522 ) {
523 if !node.indices.is_empty() {
524 for &idx in &node.indices {
526 let point = &self.data[idx].1.as_f32();
527 let dist = self.config.distance_metric.distance(query, point);
528
529 if heap.len() < k {
530 heap.push(SearchResult {
531 index: idx,
532 distance: dist,
533 });
534 } else if dist < heap.peek().expect("heap should have k elements").distance {
535 heap.pop();
536 heap.push(SearchResult {
537 index: idx,
538 distance: dist,
539 });
540 }
541 }
542 return;
543 }
544
545 let go_left = query[node.split_dim] <= node.split_value;
547
548 let (first, second) = if go_left {
549 (&node.left, &node.right)
550 } else {
551 (&node.right, &node.left)
552 };
553
554 if let Some(child) = first {
556 self.search_node(child, query, k, heap);
557 }
558
559 if heap.len() < k || {
561 let split_dist = (query[node.split_dim] - node.split_value).abs();
562 split_dist < heap.peek().expect("heap should have k elements").distance
563 } {
564 if let Some(child) = second {
565 self.search_node(child, query, k, heap);
566 }
567 }
568 }
569}
570
571pub struct VpTree {
573 root: Option<Box<VpNode>>,
574 data: Vec<(String, Vector)>,
575 config: TreeIndexConfig,
576}
577
578struct VpNode {
579 vantage_point: usize,
581 median_distance: f32,
583 inside: Option<Box<VpNode>>,
585 outside: Option<Box<VpNode>>,
587 indices: Vec<usize>,
589}
590
591impl VpTree {
592 pub fn new(config: TreeIndexConfig) -> Self {
593 Self {
594 root: None,
595 data: Vec::new(),
596 config,
597 }
598 }
599
600 pub fn build(&mut self) -> Result<()> {
601 if self.data.is_empty() {
602 return Ok(());
603 }
604
605 let indices: Vec<usize> = (0..self.data.len()).collect();
606 let mut rng = if let Some(seed) = self.config.random_seed {
607 Random::seed(seed)
608 } else {
609 Random::seed(42)
610 };
611
612 self.root = Some(Box::new(self.build_node(indices, &mut rng)?));
613 Ok(())
614 }
615
616 fn build_node<R: Rng>(&self, indices: Vec<usize>, rng: &mut R) -> Result<VpNode> {
617 self.build_node_safe(indices, rng, 0)
618 }
619
620 #[allow(deprecated)]
621 fn build_node_safe<R: Rng>(
622 &self,
623 mut indices: Vec<usize>,
624 rng: &mut R,
625 depth: usize,
626 ) -> Result<VpNode> {
627 let max_depth = 30; if indices.len() <= self.config.max_leaf_size
635 || indices.len() <= 2 || depth >= max_depth
637 {
638 return Ok(VpNode {
639 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
640 median_distance: 0.0,
641 inside: None,
642 outside: None,
643 indices,
644 });
645 }
646
647 let vp_idx = if indices.len() > 1 {
649 rng.gen_range(0..indices.len())
650 } else {
651 0
652 };
653 let vantage_point = indices[vp_idx];
654 indices.remove(vp_idx);
655
656 let vp_data = &self.data[vantage_point].1.as_f32();
658 let mut distances: Vec<(f32, usize)> = indices
659 .iter()
660 .map(|&idx| {
661 let point = &self.data[idx].1.as_f32();
662 let dist = self.config.distance_metric.distance(vp_data, point);
663 (dist, idx)
664 })
665 .collect();
666
667 distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
668
669 let median_idx = distances.len() / 2;
670 let median_distance = distances[median_idx].0;
671
672 let inside_indices: Vec<usize> = distances[..median_idx]
673 .iter()
674 .map(|(_, idx)| *idx)
675 .collect();
676
677 let outside_indices: Vec<usize> = distances[median_idx..]
678 .iter()
679 .map(|(_, idx)| *idx)
680 .collect();
681
682 if inside_indices.is_empty() || outside_indices.is_empty() {
684 return Ok(VpNode {
685 vantage_point: if indices.is_empty() { 0 } else { indices[0] },
686 median_distance: 0.0,
687 inside: None,
688 outside: None,
689 indices,
690 });
691 }
692
693 let inside = Some(Box::new(self.build_node_safe(
694 inside_indices,
695 rng,
696 depth + 1,
697 )?));
698 let outside = Some(Box::new(self.build_node_safe(
699 outside_indices,
700 rng,
701 depth + 1,
702 )?));
703
704 Ok(VpNode {
705 vantage_point,
706 median_distance,
707 inside,
708 outside,
709 indices: Vec::new(),
710 })
711 }
712
713 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
714 if self.root.is_none() {
715 return Vec::new();
716 }
717
718 let mut heap = BinaryHeap::new();
719 self.search_node(
720 self.root
721 .as_ref()
722 .expect("tree should have root after build"),
723 query,
724 k,
725 &mut heap,
726 f32::INFINITY,
727 );
728
729 let mut results: Vec<(usize, f32)> =
730 heap.into_iter().map(|r| (r.index, r.distance)).collect();
731
732 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
733 results
734 }
735
736 fn search_node(
737 &self,
738 node: &VpNode,
739 query: &[f32],
740 k: usize,
741 heap: &mut BinaryHeap<SearchResult>,
742 tau: f32,
743 ) -> f32 {
744 let mut tau = tau;
745
746 if !node.indices.is_empty() {
747 for &idx in &node.indices {
749 let point = &self.data[idx].1.as_f32();
750 let dist = self.config.distance_metric.distance(query, point);
751
752 if dist < tau {
753 if heap.len() < k {
754 heap.push(SearchResult {
755 index: idx,
756 distance: dist,
757 });
758 } else if dist < heap.peek().expect("heap should have k elements").distance {
759 heap.pop();
760 heap.push(SearchResult {
761 index: idx,
762 distance: dist,
763 });
764 }
765
766 if heap.len() >= k {
767 tau = heap.peek().expect("heap should have k elements").distance;
768 }
769 }
770 }
771 return tau;
772 }
773
774 let vp_data = &self.data[node.vantage_point].1.as_f32();
776 let dist_to_vp = self.config.distance_metric.distance(query, vp_data);
777
778 if dist_to_vp < tau {
780 if heap.len() < k {
781 heap.push(SearchResult {
782 index: node.vantage_point,
783 distance: dist_to_vp,
784 });
785 } else if dist_to_vp < heap.peek().expect("heap should have k elements").distance {
786 heap.pop();
787 heap.push(SearchResult {
788 index: node.vantage_point,
789 distance: dist_to_vp,
790 });
791 }
792
793 if heap.len() >= k {
794 tau = heap.peek().expect("heap should have k elements").distance;
795 }
796 }
797
798 if dist_to_vp < node.median_distance {
800 if let Some(inside) = &node.inside {
802 tau = self.search_node(inside, query, k, heap, tau);
803 }
804
805 if dist_to_vp + tau >= node.median_distance {
807 if let Some(outside) = &node.outside {
808 tau = self.search_node(outside, query, k, heap, tau);
809 }
810 }
811 } else {
812 if let Some(outside) = &node.outside {
814 tau = self.search_node(outside, query, k, heap, tau);
815 }
816
817 if dist_to_vp - tau <= node.median_distance {
819 if let Some(inside) = &node.inside {
820 tau = self.search_node(inside, query, k, heap, tau);
821 }
822 }
823 }
824
825 tau
826 }
827}
828
829pub struct CoverTree {
831 root: Option<Box<CoverNode>>,
832 data: Vec<(String, Vector)>,
833 config: TreeIndexConfig,
834 base: f32,
835}
836
837struct CoverNode {
838 point: usize,
840 level: i32,
842 #[allow(clippy::vec_box)] children: Vec<Box<CoverNode>>,
845}
846
847impl CoverTree {
848 pub fn new(config: TreeIndexConfig) -> Self {
849 Self {
850 root: None,
851 data: Vec::new(),
852 config,
853 base: 2.0, }
855 }
856
857 pub fn build(&mut self) -> Result<()> {
858 if self.data.is_empty() {
859 return Ok(());
860 }
861
862 self.root = Some(Box::new(CoverNode {
864 point: 0,
865 level: self.get_level(0),
866 children: Vec::new(),
867 }));
868
869 for idx in 1..self.data.len() {
871 self.insert(idx)?;
872 }
873
874 Ok(())
875 }
876
877 fn get_level(&self, _point_idx: usize) -> i32 {
878 ((self.data.len() as f32).log2() as i32).max(0)
880 }
881
882 fn insert(&mut self, point_idx: usize) -> Result<()> {
883 let level = self.get_level(point_idx);
886 if let Some(root) = &mut self.root {
887 root.children.push(Box::new(CoverNode {
888 point: point_idx,
889 level,
890 children: Vec::new(),
891 }));
892 }
893 Ok(())
894 }
895
896 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
897 if self.root.is_none() {
898 return Vec::new();
899 }
900
901 let mut results = Vec::new();
902 self.search_node(
903 self.root
904 .as_ref()
905 .expect("tree should have root after build"),
906 query,
907 k,
908 &mut results,
909 );
910
911 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
912 results.truncate(k);
913 results
914 }
915
916 #[allow(clippy::only_used_in_recursion)]
917 fn search_node(
918 &self,
919 node: &CoverNode,
920 query: &[f32],
921 k: usize,
922 results: &mut Vec<(usize, f32)>,
923 ) {
924 if results.len() >= k * 10 {
926 return;
927 }
928
929 let point_data = &self.data[node.point].1.as_f32();
930 let dist = self.config.distance_metric.distance(query, point_data);
931
932 results.push((node.point, dist));
933
934 for child in &node.children {
936 self.search_node(child, query, k, results);
937 }
938 }
939}
940
941pub struct RandomProjectionTree {
943 root: Option<Box<RpNode>>,
944 data: Vec<(String, Vector)>,
945 config: TreeIndexConfig,
946}
947
948struct RpNode {
949 projection: Vec<f32>,
951 threshold: f32,
953 left: Option<Box<RpNode>>,
955 right: Option<Box<RpNode>>,
957 indices: Vec<usize>,
959}
960
961impl RandomProjectionTree {
962 pub fn new(config: TreeIndexConfig) -> Self {
963 Self {
964 root: None,
965 data: Vec::new(),
966 config,
967 }
968 }
969
970 pub fn build(&mut self) -> Result<()> {
971 if self.data.is_empty() {
972 return Ok(());
973 }
974
975 let indices: Vec<usize> = (0..self.data.len()).collect();
976 let dimensions = self.data[0].1.dimensions;
977
978 let mut rng = if let Some(seed) = self.config.random_seed {
979 Random::seed(seed)
980 } else {
981 Random::seed(42)
982 };
983
984 self.root = Some(Box::new(self.build_node(indices, dimensions, &mut rng)?));
985 Ok(())
986 }
987
988 fn build_node<R: Rng>(
989 &self,
990 indices: Vec<usize>,
991 dimensions: usize,
992 rng: &mut R,
993 ) -> Result<RpNode> {
994 self.build_node_safe(indices, dimensions, rng, 0)
995 }
996
997 #[allow(deprecated)]
998 fn build_node_safe<R: Rng>(
999 &self,
1000 indices: Vec<usize>,
1001 dimensions: usize,
1002 rng: &mut R,
1003 depth: usize,
1004 ) -> Result<RpNode> {
1005 if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= 5 {
1007 return Ok(RpNode {
1008 projection: Vec::new(),
1009 threshold: 0.0,
1010 left: None,
1011 right: None,
1012 indices,
1013 });
1014 }
1015
1016 let projection: Vec<f32> = (0..dimensions).map(|_| rng.gen_range(-1.0..1.0)).collect();
1018
1019 let norm = (projection.iter().map(|&x| x * x).sum::<f32>()).sqrt();
1021 let projection: Vec<f32> = if norm > 0.0 {
1022 projection.iter().map(|&x| x / norm).collect()
1023 } else {
1024 projection
1025 };
1026
1027 let mut projections: Vec<(f32, usize)> = indices
1029 .iter()
1030 .map(|&idx| {
1031 let point = &self.data[idx].1.as_f32();
1032 let proj_val = f32::dot(point, &projection);
1033 (proj_val, idx)
1034 })
1035 .collect();
1036
1037 projections.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
1038
1039 let median_idx = projections.len() / 2;
1041 let threshold = projections[median_idx].0;
1042
1043 let left_indices: Vec<usize> = projections[..median_idx]
1044 .iter()
1045 .map(|(_, idx)| *idx)
1046 .collect();
1047
1048 let right_indices: Vec<usize> = projections[median_idx..]
1049 .iter()
1050 .map(|(_, idx)| *idx)
1051 .collect();
1052
1053 if left_indices.is_empty() || right_indices.is_empty() {
1055 return Ok(RpNode {
1056 projection: Vec::new(),
1057 threshold: 0.0,
1058 left: None,
1059 right: None,
1060 indices,
1061 });
1062 }
1063
1064 let left = Some(Box::new(self.build_node_safe(
1065 left_indices,
1066 dimensions,
1067 rng,
1068 depth + 1,
1069 )?));
1070 let right = Some(Box::new(self.build_node_safe(
1071 right_indices,
1072 dimensions,
1073 rng,
1074 depth + 1,
1075 )?));
1076
1077 Ok(RpNode {
1078 projection,
1079 threshold,
1080 left,
1081 right,
1082 indices: Vec::new(),
1083 })
1084 }
1085
1086 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1087 if self.root.is_none() {
1088 return Vec::new();
1089 }
1090
1091 let mut heap = BinaryHeap::new();
1092 self.search_node(
1093 self.root
1094 .as_ref()
1095 .expect("tree should have root after build"),
1096 query,
1097 k,
1098 &mut heap,
1099 );
1100
1101 let mut results: Vec<(usize, f32)> =
1102 heap.into_iter().map(|r| (r.index, r.distance)).collect();
1103
1104 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
1105 results
1106 }
1107
1108 fn search_node(
1109 &self,
1110 node: &RpNode,
1111 query: &[f32],
1112 k: usize,
1113 heap: &mut BinaryHeap<SearchResult>,
1114 ) {
1115 if !node.indices.is_empty() {
1116 for &idx in &node.indices {
1118 let point = &self.data[idx].1.as_f32();
1119 let dist = self.config.distance_metric.distance(query, point);
1120
1121 if heap.len() < k {
1122 heap.push(SearchResult {
1123 index: idx,
1124 distance: dist,
1125 });
1126 } else if dist < heap.peek().expect("heap should have k elements").distance {
1127 heap.pop();
1128 heap.push(SearchResult {
1129 index: idx,
1130 distance: dist,
1131 });
1132 }
1133 }
1134 return;
1135 }
1136
1137 let query_projection = f32::dot(query, &node.projection);
1139
1140 let go_left = query_projection <= node.threshold;
1142
1143 let (first, second) = if go_left {
1144 (&node.left, &node.right)
1145 } else {
1146 (&node.right, &node.left)
1147 };
1148
1149 if let Some(child) = first {
1151 self.search_node(child, query, k, heap);
1152 }
1153
1154 if let Some(child) = second {
1155 self.search_node(child, query, k, heap);
1156 }
1157 }
1158}
1159
1160pub struct TreeIndex {
1162 tree_type: TreeType,
1163 ball_tree: Option<BallTree>,
1164 kd_tree: Option<KdTree>,
1165 vp_tree: Option<VpTree>,
1166 cover_tree: Option<CoverTree>,
1167 rp_tree: Option<RandomProjectionTree>,
1168}
1169
1170impl TreeIndex {
1171 pub fn new(config: TreeIndexConfig) -> Self {
1172 let tree_type = config.tree_type;
1173
1174 let (ball_tree, kd_tree, vp_tree, cover_tree, rp_tree) = match tree_type {
1175 TreeType::BallTree => (Some(BallTree::new(config)), None, None, None, None),
1176 TreeType::KdTree => (None, Some(KdTree::new(config)), None, None, None),
1177 TreeType::VpTree => (None, None, Some(VpTree::new(config)), None, None),
1178 TreeType::CoverTree => (None, None, None, Some(CoverTree::new(config)), None),
1179 TreeType::RandomProjectionTree => (
1180 None,
1181 None,
1182 None,
1183 None,
1184 Some(RandomProjectionTree::new(config)),
1185 ),
1186 };
1187
1188 Self {
1189 tree_type,
1190 ball_tree,
1191 kd_tree,
1192 vp_tree,
1193 cover_tree,
1194 rp_tree,
1195 }
1196 }
1197
1198 pub fn build(&mut self) -> Result<()> {
1199 match self.tree_type {
1200 TreeType::BallTree => self
1201 .ball_tree
1202 .as_mut()
1203 .expect("ball_tree should be initialized for BallTree type")
1204 .build(),
1205 TreeType::KdTree => self
1206 .kd_tree
1207 .as_mut()
1208 .expect("kd_tree should be initialized for KdTree type")
1209 .build(),
1210 TreeType::VpTree => self
1211 .vp_tree
1212 .as_mut()
1213 .expect("vp_tree should be initialized for VpTree type")
1214 .build(),
1215 TreeType::CoverTree => self
1216 .cover_tree
1217 .as_mut()
1218 .expect("cover_tree should be initialized for CoverTree type")
1219 .build(),
1220 TreeType::RandomProjectionTree => self
1221 .rp_tree
1222 .as_mut()
1223 .expect("rp_tree should be initialized for RandomProjectionTree type")
1224 .build(),
1225 }
1226 }
1227
1228 fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1229 match self.tree_type {
1230 TreeType::BallTree => self
1231 .ball_tree
1232 .as_ref()
1233 .expect("ball_tree should be initialized for BallTree type")
1234 .search(query, k),
1235 TreeType::KdTree => self
1236 .kd_tree
1237 .as_ref()
1238 .expect("kd_tree should be initialized for KdTree type")
1239 .search(query, k),
1240 TreeType::VpTree => self
1241 .vp_tree
1242 .as_ref()
1243 .expect("vp_tree should be initialized for VpTree type")
1244 .search(query, k),
1245 TreeType::CoverTree => self
1246 .cover_tree
1247 .as_ref()
1248 .expect("cover_tree should be initialized for CoverTree type")
1249 .search(query, k),
1250 TreeType::RandomProjectionTree => self
1251 .rp_tree
1252 .as_ref()
1253 .expect("rp_tree should be initialized for RandomProjectionTree type")
1254 .search(query, k),
1255 }
1256 }
1257}
1258
1259impl VectorIndex for TreeIndex {
1260 fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1261 let data = match self.tree_type {
1262 TreeType::BallTree => {
1263 &mut self
1264 .ball_tree
1265 .as_mut()
1266 .expect("ball_tree should be initialized for BallTree type")
1267 .data
1268 }
1269 TreeType::KdTree => {
1270 &mut self
1271 .kd_tree
1272 .as_mut()
1273 .expect("kd_tree should be initialized for KdTree type")
1274 .data
1275 }
1276 TreeType::VpTree => {
1277 &mut self
1278 .vp_tree
1279 .as_mut()
1280 .expect("vp_tree should be initialized for VpTree type")
1281 .data
1282 }
1283 TreeType::CoverTree => {
1284 &mut self
1285 .cover_tree
1286 .as_mut()
1287 .expect("cover_tree should be initialized for CoverTree type")
1288 .data
1289 }
1290 TreeType::RandomProjectionTree => {
1291 &mut self
1292 .rp_tree
1293 .as_mut()
1294 .expect("rp_tree should be initialized for RandomProjectionTree type")
1295 .data
1296 }
1297 };
1298
1299 data.push((uri, vector));
1300 Ok(())
1301 }
1302
1303 fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1304 let query_f32 = query.as_f32();
1305 let results = self.search_internal(&query_f32, k);
1306
1307 let data = match self.tree_type {
1308 TreeType::BallTree => {
1309 &self
1310 .ball_tree
1311 .as_ref()
1312 .expect("ball_tree should be initialized for BallTree type")
1313 .data
1314 }
1315 TreeType::KdTree => {
1316 &self
1317 .kd_tree
1318 .as_ref()
1319 .expect("kd_tree should be initialized for KdTree type")
1320 .data
1321 }
1322 TreeType::VpTree => {
1323 &self
1324 .vp_tree
1325 .as_ref()
1326 .expect("vp_tree should be initialized for VpTree type")
1327 .data
1328 }
1329 TreeType::CoverTree => {
1330 &self
1331 .cover_tree
1332 .as_ref()
1333 .expect("cover_tree should be initialized for CoverTree type")
1334 .data
1335 }
1336 TreeType::RandomProjectionTree => {
1337 &self
1338 .rp_tree
1339 .as_ref()
1340 .expect("rp_tree should be initialized for RandomProjectionTree type")
1341 .data
1342 }
1343 };
1344
1345 Ok(results
1346 .into_iter()
1347 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1348 .collect())
1349 }
1350
1351 fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1352 let query_f32 = query.as_f32();
1353 let all_results = self.search_internal(&query_f32, 1000); let data = match self.tree_type {
1356 TreeType::BallTree => {
1357 &self
1358 .ball_tree
1359 .as_ref()
1360 .expect("ball_tree should be initialized for BallTree type")
1361 .data
1362 }
1363 TreeType::KdTree => {
1364 &self
1365 .kd_tree
1366 .as_ref()
1367 .expect("kd_tree should be initialized for KdTree type")
1368 .data
1369 }
1370 TreeType::VpTree => {
1371 &self
1372 .vp_tree
1373 .as_ref()
1374 .expect("vp_tree should be initialized for VpTree type")
1375 .data
1376 }
1377 TreeType::CoverTree => {
1378 &self
1379 .cover_tree
1380 .as_ref()
1381 .expect("cover_tree should be initialized for CoverTree type")
1382 .data
1383 }
1384 TreeType::RandomProjectionTree => {
1385 &self
1386 .rp_tree
1387 .as_ref()
1388 .expect("rp_tree should be initialized for RandomProjectionTree type")
1389 .data
1390 }
1391 };
1392
1393 Ok(all_results
1394 .into_iter()
1395 .filter(|(_, dist)| *dist <= threshold)
1396 .map(|(idx, dist)| (data[idx].0.clone(), dist))
1397 .collect())
1398 }
1399
1400 fn get_vector(&self, uri: &str) -> Option<&Vector> {
1401 let data = match self.tree_type {
1402 TreeType::BallTree => {
1403 &self
1404 .ball_tree
1405 .as_ref()
1406 .expect("ball_tree should be initialized for BallTree type")
1407 .data
1408 }
1409 TreeType::KdTree => {
1410 &self
1411 .kd_tree
1412 .as_ref()
1413 .expect("kd_tree should be initialized for KdTree type")
1414 .data
1415 }
1416 TreeType::VpTree => {
1417 &self
1418 .vp_tree
1419 .as_ref()
1420 .expect("vp_tree should be initialized for VpTree type")
1421 .data
1422 }
1423 TreeType::CoverTree => {
1424 &self
1425 .cover_tree
1426 .as_ref()
1427 .expect("cover_tree should be initialized for CoverTree type")
1428 .data
1429 }
1430 TreeType::RandomProjectionTree => {
1431 &self
1432 .rp_tree
1433 .as_ref()
1434 .expect("rp_tree should be initialized for RandomProjectionTree type")
1435 .data
1436 }
1437 };
1438
1439 data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1440 }
1441}
1442
/// Runs `f` and resolves to its result.
///
/// NOTE: despite the name, nothing is spawned. `f` executes inline the first
/// time the returned future is polled, so a long-running `f` blocks whichever
/// thread polls it.
async fn spawn_task<F, T>(f: F) -> T
where
    T: Send + 'static,
    F: FnOnce() -> T + Send + 'static,
{
    f()
}
1455
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a ball tree over 100 points on the line y = 2x and checks that a
    /// 5-NN query returns a non-empty, bounded result.
    #[test]
    #[ignore = "Tree indices are experimental - see module documentation for alternatives"]
    fn test_ball_tree() {
        let mut ball_tree = BallTree::new(TreeIndexConfig {
            tree_type: TreeType::BallTree,
            max_leaf_size: 10,
            ..Default::default()
        });

        ball_tree.data.extend(
            (0..100).map(|i| (format!("vec_{i}"), Vector::new(vec![i as f32, (i * 2) as f32]))),
        );

        ball_tree.build().unwrap();
        assert!(ball_tree.root.is_some());

        let hits = ball_tree.search(&[50.0, 100.0], 5);
        assert!(hits.len() <= 5);
        assert!(!hits.is_empty());
    }

    /// Three points, all fitting in one leaf; a 2-NN query must return two hits.
    #[test]
    #[ignore = "Investigating stack overflow with recursive tree construction"]
    fn test_kd_tree() {
        let mut index = TreeIndex::new(TreeIndexConfig {
            tree_type: TreeType::KdTree,
            max_leaf_size: 50,
            ..Default::default()
        });

        for n in 0..3 {
            let point = Vector::new(vec![n as f32, (3 - n) as f32]);
            index.insert(format!("vec_{n}"), point).unwrap();
        }

        index.build().unwrap();

        let hits = index.search_knn(&Vector::new(vec![1.0, 2.0]), 2).unwrap();
        assert_eq!(hits.len(), 2);
    }

    /// Three unit vectors at 0°, 45° and 90°; a 2-NN query must return two hits.
    #[test]
    #[ignore = "Investigating stack overflow with recursive tree construction"]
    fn test_vp_tree() {
        let mut index = TreeIndex::new(TreeIndexConfig {
            tree_type: TreeType::VpTree,
            random_seed: Some(42),
            max_leaf_size: 50,
            ..Default::default()
        });

        for n in 0..3 {
            let angle = (n as f32) * std::f32::consts::PI / 4.0;
            let point = Vector::new(vec![angle.cos(), angle.sin()]);
            index.insert(format!("vec_{n}"), point).unwrap();
        }

        index.build().unwrap();

        let hits = index.search_knn(&Vector::new(vec![1.0, 0.0]), 2).unwrap();
        assert_eq!(hits.len(), 2);
    }
}