Skip to main content

oxirs_vec/
tree_indices.rs

1//! Tree-based indices for efficient nearest neighbor search
2//!
3//! **EXPERIMENTAL**: These tree implementations are currently experimental
4//! and have known limitations with large datasets or specific configurations.
5//! For production use, prefer HNSW, IVF, or LSH indices instead.
6//!
7//! ## Known Limitations
//! - Tree construction uses recursion with conservative, per-tree depth limits (e.g. 20 levels for the Ball Tree, 30 for the VP-Tree)
9//! - Best suited for moderate-sized datasets (< 100K vectors)
10//! - May encounter stack overflow on systems with very small stack sizes
11//! - Performance degrades in high-dimensional spaces (> 128 dimensions)
12//!
13//! ## Recommended Alternatives
14//! - For most use cases: Use `HnswIndex` (Hierarchical Navigable Small World)
15//! - For very large datasets: Use `IVFIndex` (Inverted File Index)
16//! - For approximate search: Use `LSHIndex` (Locality Sensitive Hashing)
17//!
18//! This module implements various tree data structures:
19//! - Ball Tree: Efficient for arbitrary metrics
20//! - KD-Tree: Classic space partitioning tree
21//! - VP-Tree: Vantage point tree for metric spaces
//! - Cover Tree: Navigating nets (currently a simplified, flat variant that does not yet enforce the cover-tree invariants or their bounds)
23//! - Random Projection Trees: Randomized space partitioning
24
25use crate::{Vector, VectorIndex};
26use anyhow::Result;
27use oxirs_core::simd::SimdOps;
28use scirs2_core::random::{Random, Rng, RngExt};
29use std::cmp::Ordering;
30use std::collections::BinaryHeap;
31
/// Settings shared by every tree-based index in this module.
#[derive(Debug, Clone)]
pub struct TreeIndexConfig {
    /// Which tree structure to build.
    pub tree_type: TreeType,
    /// Largest number of points a node may hold before it is split further.
    pub max_leaf_size: usize,
    /// Seed for the randomized builders. When `None`, the VP-Tree and Random
    /// Projection Tree builders fall back to a fixed seed of 42.
    pub random_seed: Option<u64>,
    /// Request parallel tree construction.
    /// NOTE(review): confirm which builders actually honour this flag.
    pub parallel_construction: bool,
    /// Distance function used both while building and while searching.
    pub distance_metric: DistanceMetric,
}

impl Default for TreeIndexConfig {
    /// Ball Tree with Euclidean distance, leaves of up to 16 points, parallel
    /// construction requested and no explicit seed.
    fn default() -> Self {
        let tree_type = TreeType::BallTree;
        // A generous leaf size keeps trees shallow, which bounds recursion depth.
        let max_leaf_size = 16;

        Self {
            tree_type,
            max_leaf_size,
            random_seed: None,
            parallel_construction: true,
            distance_metric: DistanceMetric::Euclidean,
        }
    }
}

/// The tree structures this module can build.
#[derive(Clone, Copy, Debug)]
pub enum TreeType {
    /// Hierarchy of nested bounding balls.
    BallTree,
    /// Median splits along one coordinate axis per level.
    KdTree,
    /// Splits by distance to a randomly chosen vantage point.
    VpTree,
    /// Cover tree (simplified variant).
    CoverTree,
    /// Splits along random projection directions.
    RandomProjectionTree,
}

/// Distance functions available to the tree indices.
#[derive(Clone, Copy, Debug)]
pub enum DistanceMetric {
    /// L2 (straight-line) distance.
    Euclidean,
    /// L1 (city-block) distance.
    Manhattan,
    /// Cosine distance.
    Cosine,
    /// Minkowski distance of order `p`: `(sum |a_i - b_i|^p)^(1/p)`.
    Minkowski(f32),
}
77
78impl DistanceMetric {
79    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
80        match self {
81            DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
82            DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
83            DistanceMetric::Cosine => f32::cosine_distance(a, b),
84            DistanceMetric::Minkowski(p) => a
85                .iter()
86                .zip(b.iter())
87                .map(|(x, y)| (x - y).abs().powf(*p))
88                .sum::<f32>()
89                .powf(1.0 / p),
90        }
91    }
92}
93
/// Candidate neighbour produced during a k-NN search.
///
/// Ordering looks at `distance` only (via `f32::total_cmp`), so a
/// `BinaryHeap<SearchResult>` is a max-heap whose top is the current *worst*
/// candidate, which is the one a bounded k-NN search evicts.
#[derive(Debug, Clone)]
struct SearchResult {
    /// Position of the candidate in the owning tree's `data` vector.
    index: usize,
    /// Distance from the query under the tree's configured metric.
    distance: f32,
}

impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so that `Eq` and `Ord` always agree (NaN included).
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // Compare the float field directly. Delegating to `partial_cmp` here
        // would recurse forever, because `partial_cmp` is defined via `cmp`.
        // `total_cmp` also provides the genuine total order `Ord` requires.
        self.distance.total_cmp(&other.distance)
    }
}
120
121/// Ball Tree implementation
122pub struct BallTree {
123    root: Option<Box<BallNode>>,
124    data: Vec<(String, Vector)>,
125    config: TreeIndexConfig,
126}
127
128#[derive(Clone)]
129struct BallNode {
130    /// Center of the ball
131    center: Vec<f32>,
132    /// Radius of the ball
133    radius: f32,
134    /// Left child
135    left: Option<Box<BallNode>>,
136    /// Right child
137    right: Option<Box<BallNode>>,
138    /// Indices of points in this node (for leaf nodes)
139    indices: Vec<usize>,
140}
141
142impl BallTree {
143    pub fn new(config: TreeIndexConfig) -> Self {
144        Self {
145            root: None,
146            data: Vec::new(),
147            config,
148        }
149    }
150
151    /// Build the tree from data with conservative depth limits to prevent stack overflow
152    ///
153    /// Note: Tree indices work best with moderate dataset sizes (< 100K points).
154    /// For larger datasets, consider using HNSW, IVF, or LSH indices instead.
155    pub fn build(&mut self) -> Result<()> {
156        if self.data.is_empty() {
157            return Ok(());
158        }
159
160        let indices: Vec<usize> = (0..self.data.len()).collect();
161        let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
162
163        self.root = Some(Box::new(self.build_node_safe(&points, indices, 0)?));
164        Ok(())
165    }
166
167    /// Conservative recursive construction with strict depth limits
168    fn build_node_safe(
169        &self,
170        points: &[Vec<f32>],
171        indices: Vec<usize>,
172        depth: usize,
173    ) -> Result<BallNode> {
174        // VERY conservative depth limit to prevent stack overflow
175        // Limit depth to 20 for safety (can handle ~1M points with leaf_size=10)
176        const MAX_DEPTH: usize = 20;
177
178        // Force leaf creation if:
179        // 1. At or below leaf size
180        // 2. Only 1 or 2 points left
181        // 3. Reached maximum safe depth
182        if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= MAX_DEPTH {
183            let center = self.compute_centroid(points, &indices);
184            let radius = self.compute_radius(points, &indices, &center);
185            return Ok(BallNode {
186                center,
187                radius,
188                left: None,
189                right: None,
190                indices,
191            });
192        }
193
194        // Find split dimension
195        let split_dim = self.find_split_dimension(points, &indices);
196        let (left_indices, right_indices) = self.partition_indices(points, &indices, split_dim);
197
198        // Prevent empty partitions - create leaf instead
199        if left_indices.is_empty() || right_indices.is_empty() {
200            let center = self.compute_centroid(points, &indices);
201            let radius = self.compute_radius(points, &indices, &center);
202            return Ok(BallNode {
203                center,
204                radius,
205                left: None,
206                right: None,
207                indices,
208            });
209        }
210
211        // Recursively build children (limited by MAX_DEPTH)
212        let left_node = self.build_node_safe(points, left_indices, depth + 1)?;
213        let right_node = self.build_node_safe(points, right_indices, depth + 1)?;
214
215        // Compute bounding ball
216        let all_centers = vec![left_node.center.clone(), right_node.center.clone()];
217        let center = self.compute_centroid_of_centers(&all_centers);
218        let radius = left_node.radius.max(right_node.radius)
219            + self
220                .config
221                .distance_metric
222                .distance(&center, &left_node.center);
223
224        Ok(BallNode {
225            center,
226            radius,
227            left: Some(Box::new(left_node)),
228            right: Some(Box::new(right_node)),
229            indices: Vec::new(),
230        })
231    }
232
233    fn compute_centroid(&self, points: &[Vec<f32>], indices: &[usize]) -> Vec<f32> {
234        let dim = points[0].len();
235        let mut centroid = vec![0.0; dim];
236
237        for &idx in indices {
238            for (i, &val) in points[idx].iter().enumerate() {
239                centroid[i] += val;
240            }
241        }
242
243        let n = indices.len() as f32;
244        for val in &mut centroid {
245            *val /= n;
246        }
247
248        centroid
249    }
250
251    fn compute_radius(&self, points: &[Vec<f32>], indices: &[usize], center: &[f32]) -> f32 {
252        indices
253            .iter()
254            .map(|&idx| self.config.distance_metric.distance(&points[idx], center))
255            .fold(0.0f32, f32::max)
256    }
257
258    fn find_split_dimension(&self, points: &[Vec<f32>], indices: &[usize]) -> usize {
259        let dim = points[0].len();
260        let mut max_spread = 0.0;
261        let mut split_dim = 0;
262
263        // We need the dimension index `d` to access the d-th component of each point
264        #[allow(clippy::needless_range_loop)]
265        for d in 0..dim {
266            let values: Vec<f32> = indices.iter().map(|&idx| points[idx][d]).collect();
267
268            let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
269            let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
270            let spread = max_val - min_val;
271
272            if spread > max_spread {
273                max_spread = spread;
274                split_dim = d;
275            }
276        }
277
278        split_dim
279    }
280
281    fn partition_indices(
282        &self,
283        points: &[Vec<f32>],
284        indices: &[usize],
285        dim: usize,
286    ) -> (Vec<usize>, Vec<usize>) {
287        let mut values: Vec<(f32, usize)> =
288            indices.iter().map(|&idx| (points[idx][dim], idx)).collect();
289
290        values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
291
292        let mid = values.len() / 2;
293        let left_indices: Vec<usize> = values[..mid].iter().map(|(_, idx)| *idx).collect();
294        let right_indices: Vec<usize> = values[mid..].iter().map(|(_, idx)| *idx).collect();
295
296        (left_indices, right_indices)
297    }
298
299    fn compute_centroid_of_centers(&self, centers: &[Vec<f32>]) -> Vec<f32> {
300        let dim = centers[0].len();
301        let mut centroid = vec![0.0; dim];
302
303        for center in centers {
304            for (i, &val) in center.iter().enumerate() {
305                centroid[i] += val;
306            }
307        }
308
309        let n = centers.len() as f32;
310        for val in &mut centroid {
311            *val /= n;
312        }
313
314        centroid
315    }
316
317    /// Search for k nearest neighbors using iterative algorithm
318    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
319        if self.root.is_none() {
320            return Vec::new();
321        }
322
323        let mut heap: BinaryHeap<SearchResult> = BinaryHeap::new();
324        let mut stack: Vec<&BallNode> = vec![self
325            .root
326            .as_ref()
327            .expect("tree should have root after build")];
328
329        while let Some(node) = stack.pop() {
330            // Check if we need to explore this node
331            let dist_to_center = self.config.distance_metric.distance(query, &node.center);
332
333            if heap.len() >= k {
334                let worst_dist = heap.peek().expect("heap should have k elements").distance;
335                if dist_to_center - node.radius > worst_dist {
336                    continue; // Prune this branch
337                }
338            }
339
340            if node.indices.is_empty() {
341                // Internal node - add children to stack
342                if let (Some(left), Some(right)) = (&node.left, &node.right) {
343                    let left_dist = self.config.distance_metric.distance(query, &left.center);
344                    let right_dist = self.config.distance_metric.distance(query, &right.center);
345
346                    // Add in order so closer one is processed first
347                    if left_dist < right_dist {
348                        stack.push(right);
349                        stack.push(left);
350                    } else {
351                        stack.push(left);
352                        stack.push(right);
353                    }
354                }
355            } else {
356                // Leaf node - check all points
357                for &idx in &node.indices {
358                    let point = &self.data[idx].1.as_f32();
359                    let dist = self.config.distance_metric.distance(query, point);
360
361                    if heap.len() < k {
362                        heap.push(SearchResult {
363                            index: idx,
364                            distance: dist,
365                        });
366                    } else if dist < heap.peek().expect("heap should have k elements").distance {
367                        heap.pop();
368                        heap.push(SearchResult {
369                            index: idx,
370                            distance: dist,
371                        });
372                    }
373                }
374            }
375        }
376
377        let mut results: Vec<(usize, f32)> =
378            heap.into_iter().map(|r| (r.index, r.distance)).collect();
379
380        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
381        results
382    }
383}
384
385/// KD-Tree implementation
386pub struct KdTree {
387    root: Option<Box<KdNode>>,
388    data: Vec<(String, Vector)>,
389    config: TreeIndexConfig,
390}
391
392struct KdNode {
393    /// Split dimension
394    split_dim: usize,
395    /// Split value
396    split_value: f32,
397    /// Left child (values <= split_value)
398    left: Option<Box<KdNode>>,
399    /// Right child (values > split_value)
400    right: Option<Box<KdNode>>,
401    /// Indices for leaf nodes
402    indices: Vec<usize>,
403}
404
405impl KdTree {
406    pub fn new(config: TreeIndexConfig) -> Self {
407        Self {
408            root: None,
409            data: Vec::new(),
410            config,
411        }
412    }
413
414    pub fn build(&mut self) -> Result<()> {
415        if self.data.is_empty() {
416            return Ok(());
417        }
418
419        let indices: Vec<usize> = (0..self.data.len()).collect();
420        let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
421
422        self.root = Some(Box::new(self.build_node(&points, indices, 0)?));
423        Ok(())
424    }
425
426    fn build_node(&self, points: &[Vec<f32>], indices: Vec<usize>, depth: usize) -> Result<KdNode> {
427        // Reasonable stack overflow prevention with proper depth limit
428        let max_depth = if !self.data.is_empty() {
429            ((self.data.len() as f32).log2() * 2.0) as usize + 10
430        } else {
431            50
432        };
433
434        if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= max_depth {
435            return Ok(KdNode {
436                split_dim: 0,
437                split_value: 0.0,
438                left: None,
439                right: None,
440                indices,
441            });
442        }
443
444        let dimensions = points[0].len();
445        let split_dim = depth % dimensions;
446
447        // Find median along split dimension
448        let mut values: Vec<(f32, usize)> = indices
449            .iter()
450            .map(|&idx| (points[idx][split_dim], idx))
451            .collect();
452
453        values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
454
455        let median_idx = values.len() / 2;
456        let split_value = values[median_idx].0;
457
458        let left_indices: Vec<usize> = values[..median_idx].iter().map(|(_, idx)| *idx).collect();
459
460        let right_indices: Vec<usize> = values[median_idx..].iter().map(|(_, idx)| *idx).collect();
461
462        // Prevent creating empty partitions - create leaf instead
463        if left_indices.is_empty() || right_indices.is_empty() {
464            return Ok(KdNode {
465                split_dim: 0,
466                split_value: 0.0,
467                left: None,
468                right: None,
469                indices,
470            });
471        }
472
473        let left = Some(Box::new(self.build_node(
474            points,
475            left_indices,
476            depth + 1,
477        )?));
478
479        let right = Some(Box::new(self.build_node(
480            points,
481            right_indices,
482            depth + 1,
483        )?));
484
485        Ok(KdNode {
486            split_dim,
487            split_value,
488            left,
489            right,
490            indices: Vec::new(),
491        })
492    }
493
494    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
495        if self.root.is_none() {
496            return Vec::new();
497        }
498
499        let mut heap = BinaryHeap::new();
500        self.search_node(
501            self.root
502                .as_ref()
503                .expect("tree should have root after build"),
504            query,
505            k,
506            &mut heap,
507        );
508
509        let mut results: Vec<(usize, f32)> =
510            heap.into_iter().map(|r| (r.index, r.distance)).collect();
511
512        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
513        results
514    }
515
516    fn search_node(
517        &self,
518        node: &KdNode,
519        query: &[f32],
520        k: usize,
521        heap: &mut BinaryHeap<SearchResult>,
522    ) {
523        if !node.indices.is_empty() {
524            // Leaf node
525            for &idx in &node.indices {
526                let point = &self.data[idx].1.as_f32();
527                let dist = self.config.distance_metric.distance(query, point);
528
529                if heap.len() < k {
530                    heap.push(SearchResult {
531                        index: idx,
532                        distance: dist,
533                    });
534                } else if dist < heap.peek().expect("heap should have k elements").distance {
535                    heap.pop();
536                    heap.push(SearchResult {
537                        index: idx,
538                        distance: dist,
539                    });
540                }
541            }
542            return;
543        }
544
545        // Determine which side to search first
546        let go_left = query[node.split_dim] <= node.split_value;
547
548        let (first, second) = if go_left {
549            (&node.left, &node.right)
550        } else {
551            (&node.right, &node.left)
552        };
553
554        // Search the nearer side first
555        if let Some(child) = first {
556            self.search_node(child, query, k, heap);
557        }
558
559        // Check if we need to search the other side
560        if heap.len() < k || {
561            let split_dist = (query[node.split_dim] - node.split_value).abs();
562            split_dist < heap.peek().expect("heap should have k elements").distance
563        } {
564            if let Some(child) = second {
565                self.search_node(child, query, k, heap);
566            }
567        }
568    }
569}
570
571/// VP-Tree (Vantage Point Tree) implementation
572pub struct VpTree {
573    root: Option<Box<VpNode>>,
574    data: Vec<(String, Vector)>,
575    config: TreeIndexConfig,
576}
577
578struct VpNode {
579    /// Vantage point index
580    vantage_point: usize,
581    /// Median distance from vantage point
582    median_distance: f32,
583    /// Points closer than median
584    inside: Option<Box<VpNode>>,
585    /// Points farther than median
586    outside: Option<Box<VpNode>>,
587    /// Indices for leaf nodes
588    indices: Vec<usize>,
589}
590
591impl VpTree {
592    pub fn new(config: TreeIndexConfig) -> Self {
593        Self {
594            root: None,
595            data: Vec::new(),
596            config,
597        }
598    }
599
600    pub fn build(&mut self) -> Result<()> {
601        if self.data.is_empty() {
602            return Ok(());
603        }
604
605        let indices: Vec<usize> = (0..self.data.len()).collect();
606        let mut rng = if let Some(seed) = self.config.random_seed {
607            Random::seed(seed)
608        } else {
609            Random::seed(42)
610        };
611
612        self.root = Some(Box::new(self.build_node(indices, &mut rng)?));
613        Ok(())
614    }
615
616    fn build_node<R: Rng>(&self, indices: Vec<usize>, rng: &mut R) -> Result<VpNode> {
617        self.build_node_safe(indices, rng, 0)
618    }
619
620    #[allow(deprecated)]
621    fn build_node_safe<R: Rng>(
622        &self,
623        mut indices: Vec<usize>,
624        rng: &mut R,
625        depth: usize,
626    ) -> Result<VpNode> {
627        // Note: Using manual random selection instead of SliceRandom
628
629        // CRITICAL: Extremely strict depth and size limits to prevent stack overflow
630        // For very small datasets or deep recursion, immediately create leaf nodes
631        let max_depth = 30; // Conservative depth limit
632
633        // Aggressive leaf node creation for small datasets
634        if indices.len() <= self.config.max_leaf_size
635            || indices.len() <= 2  // Changed from <= 1 to <= 2 for extra safety
636            || depth >= max_depth
637        {
638            return Ok(VpNode {
639                vantage_point: if indices.is_empty() { 0 } else { indices[0] },
640                median_distance: 0.0,
641                inside: None,
642                outside: None,
643                indices,
644            });
645        }
646
647        // Choose random vantage point - simplified to avoid potential issues
648        let vp_idx = if indices.len() > 1 {
649            rng.random_range(0..indices.len())
650        } else {
651            0
652        };
653        let vantage_point = indices[vp_idx];
654        indices.remove(vp_idx);
655
656        // Calculate distances from vantage point
657        let vp_data = &self.data[vantage_point].1.as_f32();
658        let mut distances: Vec<(f32, usize)> = indices
659            .iter()
660            .map(|&idx| {
661                let point = &self.data[idx].1.as_f32();
662                let dist = self.config.distance_metric.distance(vp_data, point);
663                (dist, idx)
664            })
665            .collect();
666
667        distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
668
669        let median_idx = distances.len() / 2;
670        let median_distance = distances[median_idx].0;
671
672        let inside_indices: Vec<usize> = distances[..median_idx]
673            .iter()
674            .map(|(_, idx)| *idx)
675            .collect();
676
677        let outside_indices: Vec<usize> = distances[median_idx..]
678            .iter()
679            .map(|(_, idx)| *idx)
680            .collect();
681
682        // Prevent creating empty partitions - create leaf instead
683        if inside_indices.is_empty() || outside_indices.is_empty() {
684            return Ok(VpNode {
685                vantage_point: if indices.is_empty() { 0 } else { indices[0] },
686                median_distance: 0.0,
687                inside: None,
688                outside: None,
689                indices,
690            });
691        }
692
693        let inside = Some(Box::new(self.build_node_safe(
694            inside_indices,
695            rng,
696            depth + 1,
697        )?));
698        let outside = Some(Box::new(self.build_node_safe(
699            outside_indices,
700            rng,
701            depth + 1,
702        )?));
703
704        Ok(VpNode {
705            vantage_point,
706            median_distance,
707            inside,
708            outside,
709            indices: Vec::new(),
710        })
711    }
712
713    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
714        if self.root.is_none() {
715            return Vec::new();
716        }
717
718        let mut heap = BinaryHeap::new();
719        self.search_node(
720            self.root
721                .as_ref()
722                .expect("tree should have root after build"),
723            query,
724            k,
725            &mut heap,
726            f32::INFINITY,
727        );
728
729        let mut results: Vec<(usize, f32)> =
730            heap.into_iter().map(|r| (r.index, r.distance)).collect();
731
732        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
733        results
734    }
735
736    fn search_node(
737        &self,
738        node: &VpNode,
739        query: &[f32],
740        k: usize,
741        heap: &mut BinaryHeap<SearchResult>,
742        tau: f32,
743    ) -> f32 {
744        let mut tau = tau;
745
746        if !node.indices.is_empty() {
747            // Leaf node
748            for &idx in &node.indices {
749                let point = &self.data[idx].1.as_f32();
750                let dist = self.config.distance_metric.distance(query, point);
751
752                if dist < tau {
753                    if heap.len() < k {
754                        heap.push(SearchResult {
755                            index: idx,
756                            distance: dist,
757                        });
758                    } else if dist < heap.peek().expect("heap should have k elements").distance {
759                        heap.pop();
760                        heap.push(SearchResult {
761                            index: idx,
762                            distance: dist,
763                        });
764                    }
765
766                    if heap.len() >= k {
767                        tau = heap.peek().expect("heap should have k elements").distance;
768                    }
769                }
770            }
771            return tau;
772        }
773
774        // Calculate distance to vantage point
775        let vp_data = &self.data[node.vantage_point].1.as_f32();
776        let dist_to_vp = self.config.distance_metric.distance(query, vp_data);
777
778        // Consider vantage point itself
779        if dist_to_vp < tau {
780            if heap.len() < k {
781                heap.push(SearchResult {
782                    index: node.vantage_point,
783                    distance: dist_to_vp,
784                });
785            } else if dist_to_vp < heap.peek().expect("heap should have k elements").distance {
786                heap.pop();
787                heap.push(SearchResult {
788                    index: node.vantage_point,
789                    distance: dist_to_vp,
790                });
791            }
792
793            if heap.len() >= k {
794                tau = heap.peek().expect("heap should have k elements").distance;
795            }
796        }
797
798        // Search children
799        if dist_to_vp < node.median_distance {
800            // Search inside first
801            if let Some(inside) = &node.inside {
802                tau = self.search_node(inside, query, k, heap, tau);
803            }
804
805            // Check if we need to search outside
806            if dist_to_vp + tau >= node.median_distance {
807                if let Some(outside) = &node.outside {
808                    tau = self.search_node(outside, query, k, heap, tau);
809                }
810            }
811        } else {
812            // Search outside first
813            if let Some(outside) = &node.outside {
814                tau = self.search_node(outside, query, k, heap, tau);
815            }
816
817            // Check if we need to search inside
818            if dist_to_vp - tau <= node.median_distance {
819                if let Some(inside) = &node.inside {
820                    tau = self.search_node(inside, query, k, heap, tau);
821                }
822            }
823        }
824
825        tau
826    }
827}
828
829/// Cover Tree implementation
830pub struct CoverTree {
831    root: Option<Box<CoverNode>>,
832    data: Vec<(String, Vector)>,
833    config: TreeIndexConfig,
834    base: f32,
835}
836
837struct CoverNode {
838    /// Point index
839    point: usize,
840    /// Level in the tree
841    level: i32,
842    /// Children at the same or lower level
843    #[allow(clippy::vec_box)] // Box is necessary for recursive structure
844    children: Vec<Box<CoverNode>>,
845}
846
847impl CoverTree {
848    pub fn new(config: TreeIndexConfig) -> Self {
849        Self {
850            root: None,
851            data: Vec::new(),
852            config,
853            base: 2.0, // Base for the covering constant
854        }
855    }
856
857    pub fn build(&mut self) -> Result<()> {
858        if self.data.is_empty() {
859            return Ok(());
860        }
861
862        // Initialize with first point
863        self.root = Some(Box::new(CoverNode {
864            point: 0,
865            level: self.get_level(0),
866            children: Vec::new(),
867        }));
868
869        // Insert remaining points
870        for idx in 1..self.data.len() {
871            self.insert(idx)?;
872        }
873
874        Ok(())
875    }
876
877    fn get_level(&self, _point_idx: usize) -> i32 {
878        // Simple heuristic for initial level
879        ((self.data.len() as f32).log2() as i32).max(0)
880    }
881
882    fn insert(&mut self, point_idx: usize) -> Result<()> {
883        // Simplified insert - in practice, this would be more complex
884        // to maintain the cover tree invariants
885        let level = self.get_level(point_idx);
886        if let Some(root) = &mut self.root {
887            root.children.push(Box::new(CoverNode {
888                point: point_idx,
889                level,
890                children: Vec::new(),
891            }));
892        }
893        Ok(())
894    }
895
896    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
897        if self.root.is_none() {
898            return Vec::new();
899        }
900
901        let mut results = Vec::new();
902        self.search_node(
903            self.root
904                .as_ref()
905                .expect("tree should have root after build"),
906            query,
907            k,
908            &mut results,
909        );
910
911        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
912        results.truncate(k);
913        results
914    }
915
916    #[allow(clippy::only_used_in_recursion)]
917    fn search_node(
918        &self,
919        node: &CoverNode,
920        query: &[f32],
921        k: usize,
922        results: &mut Vec<(usize, f32)>,
923    ) {
924        // Prevent excessive recursion depth
925        if results.len() >= k * 10 {
926            return;
927        }
928
929        let point_data = &self.data[node.point].1.as_f32();
930        let dist = self.config.distance_metric.distance(query, point_data);
931
932        results.push((node.point, dist));
933
934        // Search children
935        for child in &node.children {
936            self.search_node(child, query, k, results);
937        }
938    }
939}
940
941/// Random Projection Tree implementation
942pub struct RandomProjectionTree {
943    root: Option<Box<RpNode>>,
944    data: Vec<(String, Vector)>,
945    config: TreeIndexConfig,
946}
947
/// Node of a `RandomProjectionTree`: either a leaf that owns point indices or an
/// internal node that splits on a projection threshold.
///
/// Leaves have non-empty `indices` and no children. Internal nodes have empty
/// `indices`, and search relies on that to tell the two apart.
struct RpNode {
    /// Random projection vector, normalized to unit length unless it was all
    /// zeros (empty for leaves)
    projection: Vec<f32>,
    /// Projection threshold: the median projection value at build time (0.0 for leaves)
    threshold: f32,
    /// Left child: the lower half after sorting by projection (projection <= threshold)
    left: Option<Box<RpNode>>,
    /// Right child: the median and above (projection >= threshold); points tied
    /// with the threshold may sit on either side
    right: Option<Box<RpNode>>,
    /// Indices into the tree's `data` for leaf nodes; empty for internal nodes
    indices: Vec<usize>,
}
960
961impl RandomProjectionTree {
962    pub fn new(config: TreeIndexConfig) -> Self {
963        Self {
964            root: None,
965            data: Vec::new(),
966            config,
967        }
968    }
969
970    pub fn build(&mut self) -> Result<()> {
971        if self.data.is_empty() {
972            return Ok(());
973        }
974
975        let indices: Vec<usize> = (0..self.data.len()).collect();
976        let dimensions = self.data[0].1.dimensions;
977
978        let mut rng = if let Some(seed) = self.config.random_seed {
979            Random::seed(seed)
980        } else {
981            Random::seed(42)
982        };
983
984        self.root = Some(Box::new(self.build_node(indices, dimensions, &mut rng)?));
985        Ok(())
986    }
987
988    fn build_node<R: Rng>(
989        &self,
990        indices: Vec<usize>,
991        dimensions: usize,
992        rng: &mut R,
993    ) -> Result<RpNode> {
994        self.build_node_safe(indices, dimensions, rng, 0)
995    }
996
997    #[allow(deprecated)]
998    fn build_node_safe<R: Rng>(
999        &self,
1000        indices: Vec<usize>,
1001        dimensions: usize,
1002        rng: &mut R,
1003        depth: usize,
1004    ) -> Result<RpNode> {
1005        // Very strict stack overflow prevention - similar to BallTree approach
1006        if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= 5 {
1007            return Ok(RpNode {
1008                projection: Vec::new(),
1009                threshold: 0.0,
1010                left: None,
1011                right: None,
1012                indices,
1013            });
1014        }
1015
1016        // Generate random projection vector
1017        let projection: Vec<f32> = (0..dimensions)
1018            .map(|_| rng.random_range(-1.0..1.0))
1019            .collect();
1020
1021        // Normalize projection vector
1022        let norm = (projection.iter().map(|&x| x * x).sum::<f32>()).sqrt();
1023        let projection: Vec<f32> = if norm > 0.0 {
1024            projection.iter().map(|&x| x / norm).collect()
1025        } else {
1026            projection
1027        };
1028
1029        // Project all points
1030        let mut projections: Vec<(f32, usize)> = indices
1031            .iter()
1032            .map(|&idx| {
1033                let point = &self.data[idx].1.as_f32();
1034                let proj_val = f32::dot(point, &projection);
1035                (proj_val, idx)
1036            })
1037            .collect();
1038
1039        projections.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
1040
1041        // Choose median as threshold
1042        let median_idx = projections.len() / 2;
1043        let threshold = projections[median_idx].0;
1044
1045        let left_indices: Vec<usize> = projections[..median_idx]
1046            .iter()
1047            .map(|(_, idx)| *idx)
1048            .collect();
1049
1050        let right_indices: Vec<usize> = projections[median_idx..]
1051            .iter()
1052            .map(|(_, idx)| *idx)
1053            .collect();
1054
1055        // Prevent creating empty partitions - create leaf instead
1056        if left_indices.is_empty() || right_indices.is_empty() {
1057            return Ok(RpNode {
1058                projection: Vec::new(),
1059                threshold: 0.0,
1060                left: None,
1061                right: None,
1062                indices,
1063            });
1064        }
1065
1066        let left = Some(Box::new(self.build_node_safe(
1067            left_indices,
1068            dimensions,
1069            rng,
1070            depth + 1,
1071        )?));
1072        let right = Some(Box::new(self.build_node_safe(
1073            right_indices,
1074            dimensions,
1075            rng,
1076            depth + 1,
1077        )?));
1078
1079        Ok(RpNode {
1080            projection,
1081            threshold,
1082            left,
1083            right,
1084            indices: Vec::new(),
1085        })
1086    }
1087
1088    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1089        if self.root.is_none() {
1090            return Vec::new();
1091        }
1092
1093        let mut heap = BinaryHeap::new();
1094        self.search_node(
1095            self.root
1096                .as_ref()
1097                .expect("tree should have root after build"),
1098            query,
1099            k,
1100            &mut heap,
1101        );
1102
1103        let mut results: Vec<(usize, f32)> =
1104            heap.into_iter().map(|r| (r.index, r.distance)).collect();
1105
1106        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
1107        results
1108    }
1109
1110    fn search_node(
1111        &self,
1112        node: &RpNode,
1113        query: &[f32],
1114        k: usize,
1115        heap: &mut BinaryHeap<SearchResult>,
1116    ) {
1117        if !node.indices.is_empty() {
1118            // Leaf node
1119            for &idx in &node.indices {
1120                let point = &self.data[idx].1.as_f32();
1121                let dist = self.config.distance_metric.distance(query, point);
1122
1123                if heap.len() < k {
1124                    heap.push(SearchResult {
1125                        index: idx,
1126                        distance: dist,
1127                    });
1128                } else if dist < heap.peek().expect("heap should have k elements").distance {
1129                    heap.pop();
1130                    heap.push(SearchResult {
1131                        index: idx,
1132                        distance: dist,
1133                    });
1134                }
1135            }
1136            return;
1137        }
1138
1139        // Project query
1140        let query_projection = f32::dot(query, &node.projection);
1141
1142        // Determine which side to search first
1143        let go_left = query_projection <= node.threshold;
1144
1145        let (first, second) = if go_left {
1146            (&node.left, &node.right)
1147        } else {
1148            (&node.right, &node.left)
1149        };
1150
1151        // Search both sides (random projections don't provide distance bounds)
1152        if let Some(child) = first {
1153            self.search_node(child, query, k, heap);
1154        }
1155
1156        if let Some(child) = second {
1157            self.search_node(child, query, k, heap);
1158        }
1159    }
1160}
1161
/// Unified tree index interface
///
/// Wraps one concrete tree selected by `tree_type`. `TreeIndex::new` populates
/// only the field matching `tree_type` and leaves the others `None`.
pub struct TreeIndex {
    /// Which tree implementation is active.
    tree_type: TreeType,
    /// Populated when `tree_type` is `TreeType::BallTree`.
    ball_tree: Option<BallTree>,
    /// Populated when `tree_type` is `TreeType::KdTree`.
    kd_tree: Option<KdTree>,
    /// Populated when `tree_type` is `TreeType::VpTree`.
    vp_tree: Option<VpTree>,
    /// Populated when `tree_type` is `TreeType::CoverTree`.
    cover_tree: Option<CoverTree>,
    /// Populated when `tree_type` is `TreeType::RandomProjectionTree`.
    rp_tree: Option<RandomProjectionTree>,
}
1171
1172impl TreeIndex {
1173    pub fn new(config: TreeIndexConfig) -> Self {
1174        let tree_type = config.tree_type;
1175
1176        let (ball_tree, kd_tree, vp_tree, cover_tree, rp_tree) = match tree_type {
1177            TreeType::BallTree => (Some(BallTree::new(config)), None, None, None, None),
1178            TreeType::KdTree => (None, Some(KdTree::new(config)), None, None, None),
1179            TreeType::VpTree => (None, None, Some(VpTree::new(config)), None, None),
1180            TreeType::CoverTree => (None, None, None, Some(CoverTree::new(config)), None),
1181            TreeType::RandomProjectionTree => (
1182                None,
1183                None,
1184                None,
1185                None,
1186                Some(RandomProjectionTree::new(config)),
1187            ),
1188        };
1189
1190        Self {
1191            tree_type,
1192            ball_tree,
1193            kd_tree,
1194            vp_tree,
1195            cover_tree,
1196            rp_tree,
1197        }
1198    }
1199
1200    pub fn build(&mut self) -> Result<()> {
1201        match self.tree_type {
1202            TreeType::BallTree => self
1203                .ball_tree
1204                .as_mut()
1205                .expect("ball_tree should be initialized for BallTree type")
1206                .build(),
1207            TreeType::KdTree => self
1208                .kd_tree
1209                .as_mut()
1210                .expect("kd_tree should be initialized for KdTree type")
1211                .build(),
1212            TreeType::VpTree => self
1213                .vp_tree
1214                .as_mut()
1215                .expect("vp_tree should be initialized for VpTree type")
1216                .build(),
1217            TreeType::CoverTree => self
1218                .cover_tree
1219                .as_mut()
1220                .expect("cover_tree should be initialized for CoverTree type")
1221                .build(),
1222            TreeType::RandomProjectionTree => self
1223                .rp_tree
1224                .as_mut()
1225                .expect("rp_tree should be initialized for RandomProjectionTree type")
1226                .build(),
1227        }
1228    }
1229
1230    fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1231        match self.tree_type {
1232            TreeType::BallTree => self
1233                .ball_tree
1234                .as_ref()
1235                .expect("ball_tree should be initialized for BallTree type")
1236                .search(query, k),
1237            TreeType::KdTree => self
1238                .kd_tree
1239                .as_ref()
1240                .expect("kd_tree should be initialized for KdTree type")
1241                .search(query, k),
1242            TreeType::VpTree => self
1243                .vp_tree
1244                .as_ref()
1245                .expect("vp_tree should be initialized for VpTree type")
1246                .search(query, k),
1247            TreeType::CoverTree => self
1248                .cover_tree
1249                .as_ref()
1250                .expect("cover_tree should be initialized for CoverTree type")
1251                .search(query, k),
1252            TreeType::RandomProjectionTree => self
1253                .rp_tree
1254                .as_ref()
1255                .expect("rp_tree should be initialized for RandomProjectionTree type")
1256                .search(query, k),
1257        }
1258    }
1259}
1260
1261impl VectorIndex for TreeIndex {
1262    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1263        let data = match self.tree_type {
1264            TreeType::BallTree => {
1265                &mut self
1266                    .ball_tree
1267                    .as_mut()
1268                    .expect("ball_tree should be initialized for BallTree type")
1269                    .data
1270            }
1271            TreeType::KdTree => {
1272                &mut self
1273                    .kd_tree
1274                    .as_mut()
1275                    .expect("kd_tree should be initialized for KdTree type")
1276                    .data
1277            }
1278            TreeType::VpTree => {
1279                &mut self
1280                    .vp_tree
1281                    .as_mut()
1282                    .expect("vp_tree should be initialized for VpTree type")
1283                    .data
1284            }
1285            TreeType::CoverTree => {
1286                &mut self
1287                    .cover_tree
1288                    .as_mut()
1289                    .expect("cover_tree should be initialized for CoverTree type")
1290                    .data
1291            }
1292            TreeType::RandomProjectionTree => {
1293                &mut self
1294                    .rp_tree
1295                    .as_mut()
1296                    .expect("rp_tree should be initialized for RandomProjectionTree type")
1297                    .data
1298            }
1299        };
1300
1301        data.push((uri, vector));
1302        Ok(())
1303    }
1304
1305    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1306        let query_f32 = query.as_f32();
1307        let results = self.search_internal(&query_f32, k);
1308
1309        let data = match self.tree_type {
1310            TreeType::BallTree => {
1311                &self
1312                    .ball_tree
1313                    .as_ref()
1314                    .expect("ball_tree should be initialized for BallTree type")
1315                    .data
1316            }
1317            TreeType::KdTree => {
1318                &self
1319                    .kd_tree
1320                    .as_ref()
1321                    .expect("kd_tree should be initialized for KdTree type")
1322                    .data
1323            }
1324            TreeType::VpTree => {
1325                &self
1326                    .vp_tree
1327                    .as_ref()
1328                    .expect("vp_tree should be initialized for VpTree type")
1329                    .data
1330            }
1331            TreeType::CoverTree => {
1332                &self
1333                    .cover_tree
1334                    .as_ref()
1335                    .expect("cover_tree should be initialized for CoverTree type")
1336                    .data
1337            }
1338            TreeType::RandomProjectionTree => {
1339                &self
1340                    .rp_tree
1341                    .as_ref()
1342                    .expect("rp_tree should be initialized for RandomProjectionTree type")
1343                    .data
1344            }
1345        };
1346
1347        Ok(results
1348            .into_iter()
1349            .map(|(idx, dist)| (data[idx].0.clone(), dist))
1350            .collect())
1351    }
1352
1353    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1354        let query_f32 = query.as_f32();
1355        let all_results = self.search_internal(&query_f32, 1000); // Search more broadly
1356
1357        let data = match self.tree_type {
1358            TreeType::BallTree => {
1359                &self
1360                    .ball_tree
1361                    .as_ref()
1362                    .expect("ball_tree should be initialized for BallTree type")
1363                    .data
1364            }
1365            TreeType::KdTree => {
1366                &self
1367                    .kd_tree
1368                    .as_ref()
1369                    .expect("kd_tree should be initialized for KdTree type")
1370                    .data
1371            }
1372            TreeType::VpTree => {
1373                &self
1374                    .vp_tree
1375                    .as_ref()
1376                    .expect("vp_tree should be initialized for VpTree type")
1377                    .data
1378            }
1379            TreeType::CoverTree => {
1380                &self
1381                    .cover_tree
1382                    .as_ref()
1383                    .expect("cover_tree should be initialized for CoverTree type")
1384                    .data
1385            }
1386            TreeType::RandomProjectionTree => {
1387                &self
1388                    .rp_tree
1389                    .as_ref()
1390                    .expect("rp_tree should be initialized for RandomProjectionTree type")
1391                    .data
1392            }
1393        };
1394
1395        Ok(all_results
1396            .into_iter()
1397            .filter(|(_, dist)| *dist <= threshold)
1398            .map(|(idx, dist)| (data[idx].0.clone(), dist))
1399            .collect())
1400    }
1401
1402    fn get_vector(&self, uri: &str) -> Option<&Vector> {
1403        let data = match self.tree_type {
1404            TreeType::BallTree => {
1405                &self
1406                    .ball_tree
1407                    .as_ref()
1408                    .expect("ball_tree should be initialized for BallTree type")
1409                    .data
1410            }
1411            TreeType::KdTree => {
1412                &self
1413                    .kd_tree
1414                    .as_ref()
1415                    .expect("kd_tree should be initialized for KdTree type")
1416                    .data
1417            }
1418            TreeType::VpTree => {
1419                &self
1420                    .vp_tree
1421                    .as_ref()
1422                    .expect("vp_tree should be initialized for VpTree type")
1423                    .data
1424            }
1425            TreeType::CoverTree => {
1426                &self
1427                    .cover_tree
1428                    .as_ref()
1429                    .expect("cover_tree should be initialized for CoverTree type")
1430                    .data
1431            }
1432            TreeType::RandomProjectionTree => {
1433                &self
1434                    .rp_tree
1435                    .as_ref()
1436                    .expect("rp_tree should be initialized for RandomProjectionTree type")
1437                    .data
1438            }
1439        };
1440
1441        data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1442    }
1443}
1444
// VP-Tree and Random Projection Tree draw their randomness from `scirs2_core::random`,
// which replaced the earlier dependency on the `rand` crate.
1447
// Placeholder until task spawning is delegated to oxirs-core::parallel.
/// Runs `job` inline and resolves to its return value.
///
/// Nothing is actually spawned. `job` executes synchronously, to completion,
/// during the first poll of the returned future, on whichever thread polls it.
/// The `Send + 'static` bounds presumably mirror what a real spawner would
/// require; confirm when wiring up the real implementation.
async fn spawn_task<Job, Out>(job: Job) -> Out
where
    Job: FnOnce() -> Out + Send + 'static,
    Out: Send + 'static,
{
    job()
}
1457
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore = "Tree indices are experimental - see module documentation for alternatives"]
    fn test_ball_tree() -> Result<()> {
        let mut ball_tree = BallTree::new(TreeIndexConfig {
            tree_type: TreeType::BallTree,
            max_leaf_size: 10,
            ..Default::default()
        });

        // 100 points on the line y = 2x, loaded directly into the tree's storage.
        ball_tree.data.extend((0..100).map(|i| {
            (
                format!("vec_{i}"),
                Vector::new(vec![i as f32, (i * 2) as f32]),
            )
        }));

        ball_tree.build()?;
        assert!(ball_tree.root.is_some());

        let query = [50.0_f32, 100.0];
        let results = ball_tree.search(&query, 5);

        assert!(results.len() <= 5);
        assert!(!results.is_empty());
        Ok(())
    }

    #[test]
    #[ignore = "Investigating stack overflow with recursive tree construction"]
    fn test_kd_tree() -> Result<()> {
        // Leaf size far above the 3-point dataset, intended to keep everything in leaves.
        let mut index = TreeIndex::new(TreeIndexConfig {
            tree_type: TreeType::KdTree,
            max_leaf_size: 50,
            ..Default::default()
        });

        // Tiny dataset to prevent stack overflow
        (0..3).try_for_each(|i| {
            index.insert(
                format!("vec_{i}"),
                Vector::new(vec![i as f32, (3 - i) as f32]),
            )
        })?;

        index.build()?;

        let results = index.search_knn(&Vector::new(vec![1.0, 2.0]), 2)?;
        assert_eq!(results.len(), 2);
        Ok(())
    }

    #[test]
    #[ignore = "Investigating stack overflow with recursive tree construction"]
    fn test_vp_tree() -> Result<()> {
        // Leaf size far above the 3-point dataset, intended to keep everything in leaves.
        let mut index = TreeIndex::new(TreeIndexConfig {
            tree_type: TreeType::VpTree,
            random_seed: Some(42),
            max_leaf_size: 50,
            ..Default::default()
        });

        // Tiny dataset to prevent stack overflow: points on the unit circle.
        (0..3).try_for_each(|i| {
            let angle = (i as f32) * std::f32::consts::PI / 4.0;
            index.insert(
                format!("vec_{i}"),
                Vector::new(vec![angle.cos(), angle.sin()]),
            )
        })?;

        index.build()?;

        let results = index.search_knn(&Vector::new(vec![1.0, 0.0]), 2)?;
        assert_eq!(results.len(), 2);
        Ok(())
    }
}