oxirs_vec/
tree_indices.rs

1//! Tree-based indices for efficient nearest neighbor search
2//!
3//! **EXPERIMENTAL**: These tree implementations are currently experimental
4//! and have known limitations with large datasets or specific configurations.
5//! For production use, prefer HNSW, IVF, or LSH indices instead.
6//!
7//! ## Known Limitations
//! - Tree construction uses recursion with conservative, per-tree depth limits (e.g. 20 levels for the Ball Tree)
9//! - Best suited for moderate-sized datasets (< 100K vectors)
10//! - May encounter stack overflow on systems with very small stack sizes
11//! - Performance degrades in high-dimensional spaces (> 128 dimensions)
12//!
13//! ## Recommended Alternatives
14//! - For most use cases: Use `HnswIndex` (Hierarchical Navigable Small World)
15//! - For very large datasets: Use `IVFIndex` (Inverted File Index)
16//! - For approximate search: Use `LSHIndex` (Locality Sensitive Hashing)
17//!
18//! This module implements various tree data structures:
19//! - Ball Tree: Efficient for arbitrary metrics
20//! - KD-Tree: Classic space partitioning tree
21//! - VP-Tree: Vantage point tree for metric spaces
//! - Cover Tree: Navigating nets with provable bounds (currently a simplified, single-level placeholder that does not provide these guarantees)
23//! - Random Projection Trees: Randomized space partitioning
24
25use crate::{Vector, VectorIndex};
26use anyhow::Result;
27use oxirs_core::simd::SimdOps;
28use scirs2_core::random::{Random, Rng};
29use std::cmp::Ordering;
30use std::collections::BinaryHeap;
31
/// Configuration for tree-based indices
///
/// Shared by every tree type in this module; each tree reads only the fields
/// it needs.
#[derive(Debug, Clone)]
pub struct TreeIndexConfig {
    /// Type of tree to use
    ///
    /// NOTE(review): not read by the tree code visible alongside this
    /// definition; presumably consulted by whatever code picks the tree to build.
    pub tree_type: TreeType,
    /// Maximum leaf size before splitting
    ///
    /// The splitting builders turn any node holding at most this many points
    /// into a leaf.
    pub max_leaf_size: usize,
    /// Random seed for reproducibility
    ///
    /// Used by the randomized builders (VP-tree, random projection tree). When
    /// `None` they fall back to the fixed seed `42`, so builds stay deterministic.
    pub random_seed: Option<u64>,
    /// Enable parallel construction
    ///
    /// NOTE(review): not read by any builder visible in this module — confirm
    /// whether construction is parallelised anywhere.
    pub parallel_construction: bool,
    /// Distance metric used for both construction and search.
    pub distance_metric: DistanceMetric,
}
46
47impl Default for TreeIndexConfig {
48    fn default() -> Self {
49        Self {
50            tree_type: TreeType::BallTree,
51            max_leaf_size: 16, // Larger leaf size to prevent deep recursion and stack overflow
52            random_seed: None,
53            parallel_construction: true,
54            distance_metric: DistanceMetric::Euclidean,
55        }
56    }
57}
58
/// Available tree types
///
/// One variant per index structure implemented in this module.
#[derive(Debug, Clone, Copy)]
pub enum TreeType {
    /// [`BallTree`]: median splits on the widest coordinate, with a bounding
    /// ball stored per node.
    BallTree,
    /// [`KdTree`]: median splits along one axis per level, cycling through the
    /// dimensions.
    KdTree,
    /// [`VpTree`]: splits by distance to a randomly chosen vantage point.
    VpTree,
    /// [`CoverTree`]: currently a simplified, single-level structure.
    CoverTree,
    /// [`RandomProjectionTree`]: splits along random projection directions.
    RandomProjectionTree,
}
68
/// Distance metrics
///
/// The tree searches in this module prune with bounds that rely on metric
/// properties (triangle inequality, per-axis lower bounds). With a distance
/// that lacks them, search results may be approximate.
#[derive(Debug, Clone, Copy)]
pub enum DistanceMetric {
    /// Euclidean (L2) distance, computed by `SimdOps::euclidean_distance`.
    Euclidean,
    /// Manhattan (L1) distance, computed by `SimdOps::manhattan_distance`.
    Manhattan,
    /// Cosine distance, computed by `SimdOps::cosine_distance`.
    /// NOTE(review): presumably `1 - cos(a, b)`, which does not satisfy the
    /// triangle inequality — confirm before relying on exact tree search.
    Cosine,
    /// Minkowski distance of order `p`: `(sum |a_i - b_i|^p)^(1/p)`.
    Minkowski(f32),
}
77
78impl DistanceMetric {
79    fn distance(&self, a: &[f32], b: &[f32]) -> f32 {
80        match self {
81            DistanceMetric::Euclidean => f32::euclidean_distance(a, b),
82            DistanceMetric::Manhattan => f32::manhattan_distance(a, b),
83            DistanceMetric::Cosine => f32::cosine_distance(a, b),
84            DistanceMetric::Minkowski(p) => a
85                .iter()
86                .zip(b.iter())
87                .map(|(x, y)| (x - y).abs().powf(*p))
88                .sum::<f32>()
89                .powf(1.0 / p),
90        }
91    }
92}
93
/// Search result with distance
///
/// Ordered by `distance`, so a `BinaryHeap<SearchResult>` (a max-heap) keeps
/// the *worst* of the current k candidates on top. There it can be inspected
/// with `peek` and evicted with `pop` or `peek_mut`.
#[derive(Debug, Clone)]
struct SearchResult {
    /// Position of the point in the owning tree's `data` vector.
    index: usize,
    /// Distance from the query to that point.
    distance: f32,
}

impl PartialEq for SearchResult {
    fn eq(&self, other: &Self) -> bool {
        // Must agree with `Ord::cmp`, which uses the IEEE 754 total order.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for SearchResult {}

impl PartialOrd for SearchResult {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for SearchResult {
    fn cmp(&self, other: &Self) -> Ordering {
        // Compare the distances directly. Delegating to `partial_cmp` here would
        // recurse forever, because `partial_cmp` delegates to `cmp`.
        // `total_cmp` is a genuine total order; a positive NaN sorts above
        // +inf, so it counts as the worst candidate.
        self.distance.total_cmp(&other.distance)
    }
}
120
/// Ball Tree implementation
///
/// Recursively splits the data at the median of the coordinate with the
/// largest spread and stores a bounding ball (centre + radius) per node.
/// Search uses these balls to prune subtrees.
pub struct BallTree {
    /// Root node; `None` until `build` runs on a non-empty dataset.
    root: Option<Box<BallNode>>,
    /// Stored entries; search results refer to points by their position here.
    /// The `String` is presumably the vector's identifier — confirm with the
    /// code that populates `data`.
    data: Vec<(String, Vector)>,
    /// Index configuration (leaf size, distance metric, ...).
    config: TreeIndexConfig,
}
127
/// A node of a [`BallTree`]: a ball enclosing every point stored beneath it.
#[derive(Clone)]
struct BallNode {
    /// Center of the ball. For a leaf this is the centroid of its points; for
    /// an internal node, the midpoint of the two child centres.
    center: Vec<f32>,
    /// Radius of the ball. For a leaf this is the largest distance from
    /// `center` to one of its points; for an internal node, a bound derived
    /// from the child balls.
    radius: f32,
    /// Left child
    left: Option<Box<BallNode>>,
    /// Right child
    right: Option<Box<BallNode>>,
    /// Indices (into the tree's `data`) of the points in this node. Non-empty
    /// only for leaves; search treats an empty list as "internal node".
    indices: Vec<usize>,
}
141
142impl BallTree {
143    pub fn new(config: TreeIndexConfig) -> Self {
144        Self {
145            root: None,
146            data: Vec::new(),
147            config,
148        }
149    }
150
151    /// Build the tree from data with conservative depth limits to prevent stack overflow
152    ///
153    /// Note: Tree indices work best with moderate dataset sizes (< 100K points).
154    /// For larger datasets, consider using HNSW, IVF, or LSH indices instead.
155    pub fn build(&mut self) -> Result<()> {
156        if self.data.is_empty() {
157            return Ok(());
158        }
159
160        let indices: Vec<usize> = (0..self.data.len()).collect();
161        let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
162
163        self.root = Some(Box::new(self.build_node_safe(&points, indices, 0)?));
164        Ok(())
165    }
166
167    /// Conservative recursive construction with strict depth limits
168    fn build_node_safe(
169        &self,
170        points: &[Vec<f32>],
171        indices: Vec<usize>,
172        depth: usize,
173    ) -> Result<BallNode> {
174        // VERY conservative depth limit to prevent stack overflow
175        // Limit depth to 20 for safety (can handle ~1M points with leaf_size=10)
176        const MAX_DEPTH: usize = 20;
177
178        // Force leaf creation if:
179        // 1. At or below leaf size
180        // 2. Only 1 or 2 points left
181        // 3. Reached maximum safe depth
182        if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= MAX_DEPTH {
183            let center = self.compute_centroid(points, &indices);
184            let radius = self.compute_radius(points, &indices, &center);
185            return Ok(BallNode {
186                center,
187                radius,
188                left: None,
189                right: None,
190                indices,
191            });
192        }
193
194        // Find split dimension
195        let split_dim = self.find_split_dimension(points, &indices);
196        let (left_indices, right_indices) = self.partition_indices(points, &indices, split_dim);
197
198        // Prevent empty partitions - create leaf instead
199        if left_indices.is_empty() || right_indices.is_empty() {
200            let center = self.compute_centroid(points, &indices);
201            let radius = self.compute_radius(points, &indices, &center);
202            return Ok(BallNode {
203                center,
204                radius,
205                left: None,
206                right: None,
207                indices,
208            });
209        }
210
211        // Recursively build children (limited by MAX_DEPTH)
212        let left_node = self.build_node_safe(points, left_indices, depth + 1)?;
213        let right_node = self.build_node_safe(points, right_indices, depth + 1)?;
214
215        // Compute bounding ball
216        let all_centers = vec![left_node.center.clone(), right_node.center.clone()];
217        let center = self.compute_centroid_of_centers(&all_centers);
218        let radius = left_node.radius.max(right_node.radius)
219            + self
220                .config
221                .distance_metric
222                .distance(&center, &left_node.center);
223
224        Ok(BallNode {
225            center,
226            radius,
227            left: Some(Box::new(left_node)),
228            right: Some(Box::new(right_node)),
229            indices: Vec::new(),
230        })
231    }
232
233    fn compute_centroid(&self, points: &[Vec<f32>], indices: &[usize]) -> Vec<f32> {
234        let dim = points[0].len();
235        let mut centroid = vec![0.0; dim];
236
237        for &idx in indices {
238            for (i, &val) in points[idx].iter().enumerate() {
239                centroid[i] += val;
240            }
241        }
242
243        let n = indices.len() as f32;
244        for val in &mut centroid {
245            *val /= n;
246        }
247
248        centroid
249    }
250
251    fn compute_radius(&self, points: &[Vec<f32>], indices: &[usize], center: &[f32]) -> f32 {
252        indices
253            .iter()
254            .map(|&idx| self.config.distance_metric.distance(&points[idx], center))
255            .fold(0.0f32, f32::max)
256    }
257
258    fn find_split_dimension(&self, points: &[Vec<f32>], indices: &[usize]) -> usize {
259        let dim = points[0].len();
260        let mut max_spread = 0.0;
261        let mut split_dim = 0;
262
263        // We need the dimension index `d` to access the d-th component of each point
264        #[allow(clippy::needless_range_loop)]
265        for d in 0..dim {
266            let values: Vec<f32> = indices.iter().map(|&idx| points[idx][d]).collect();
267
268            let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
269            let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
270            let spread = max_val - min_val;
271
272            if spread > max_spread {
273                max_spread = spread;
274                split_dim = d;
275            }
276        }
277
278        split_dim
279    }
280
281    fn partition_indices(
282        &self,
283        points: &[Vec<f32>],
284        indices: &[usize],
285        dim: usize,
286    ) -> (Vec<usize>, Vec<usize>) {
287        let mut values: Vec<(f32, usize)> =
288            indices.iter().map(|&idx| (points[idx][dim], idx)).collect();
289
290        values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
291
292        let mid = values.len() / 2;
293        let left_indices: Vec<usize> = values[..mid].iter().map(|(_, idx)| *idx).collect();
294        let right_indices: Vec<usize> = values[mid..].iter().map(|(_, idx)| *idx).collect();
295
296        (left_indices, right_indices)
297    }
298
299    fn compute_centroid_of_centers(&self, centers: &[Vec<f32>]) -> Vec<f32> {
300        let dim = centers[0].len();
301        let mut centroid = vec![0.0; dim];
302
303        for center in centers {
304            for (i, &val) in center.iter().enumerate() {
305                centroid[i] += val;
306            }
307        }
308
309        let n = centers.len() as f32;
310        for val in &mut centroid {
311            *val /= n;
312        }
313
314        centroid
315    }
316
317    /// Search for k nearest neighbors using iterative algorithm
318    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
319        if self.root.is_none() {
320            return Vec::new();
321        }
322
323        let mut heap: BinaryHeap<SearchResult> = BinaryHeap::new();
324        let mut stack: Vec<&BallNode> = vec![self
325            .root
326            .as_ref()
327            .expect("tree should have root after build")];
328
329        while let Some(node) = stack.pop() {
330            // Check if we need to explore this node
331            let dist_to_center = self.config.distance_metric.distance(query, &node.center);
332
333            if heap.len() >= k {
334                let worst_dist = heap.peek().expect("heap should have k elements").distance;
335                if dist_to_center - node.radius > worst_dist {
336                    continue; // Prune this branch
337                }
338            }
339
340            if node.indices.is_empty() {
341                // Internal node - add children to stack
342                if let (Some(left), Some(right)) = (&node.left, &node.right) {
343                    let left_dist = self.config.distance_metric.distance(query, &left.center);
344                    let right_dist = self.config.distance_metric.distance(query, &right.center);
345
346                    // Add in order so closer one is processed first
347                    if left_dist < right_dist {
348                        stack.push(right);
349                        stack.push(left);
350                    } else {
351                        stack.push(left);
352                        stack.push(right);
353                    }
354                }
355            } else {
356                // Leaf node - check all points
357                for &idx in &node.indices {
358                    let point = &self.data[idx].1.as_f32();
359                    let dist = self.config.distance_metric.distance(query, point);
360
361                    if heap.len() < k {
362                        heap.push(SearchResult {
363                            index: idx,
364                            distance: dist,
365                        });
366                    } else if dist < heap.peek().expect("heap should have k elements").distance {
367                        heap.pop();
368                        heap.push(SearchResult {
369                            index: idx,
370                            distance: dist,
371                        });
372                    }
373                }
374            }
375        }
376
377        let mut results: Vec<(usize, f32)> =
378            heap.into_iter().map(|r| (r.index, r.distance)).collect();
379
380        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
381        results
382    }
383}
384
/// KD-Tree implementation
///
/// Splits at the median coordinate along one axis per level, cycling through
/// the dimensions. Search is branch-and-bound on the splitting planes.
pub struct KdTree {
    /// Root node; `None` until `build` runs on a non-empty dataset.
    root: Option<Box<KdNode>>,
    /// Stored entries; search results refer to points by their position here.
    /// The `String` is presumably the vector's identifier — confirm with the
    /// code that populates `data`.
    data: Vec<(String, Vector)>,
    /// Index configuration (leaf size, distance metric, ...).
    config: TreeIndexConfig,
}
391
/// A node of a [`KdTree`].
///
/// Internal nodes split on one coordinate. Leaves store their points in
/// `indices`; their split fields are zeroed and unused.
struct KdNode {
    /// Split dimension
    split_dim: usize,
    /// Split value: the coordinate at position `len / 2` after sorting the
    /// node's points along `split_dim` (the upper median)
    split_value: f32,
    /// Left child (coordinates <= split_value)
    left: Option<Box<KdNode>>,
    /// Right child (coordinates >= split_value — the median point itself goes
    /// right, and points equal to it may land on either side)
    right: Option<Box<KdNode>>,
    /// Indices for leaf nodes; empty for internal nodes
    indices: Vec<usize>,
}
404
405impl KdTree {
406    pub fn new(config: TreeIndexConfig) -> Self {
407        Self {
408            root: None,
409            data: Vec::new(),
410            config,
411        }
412    }
413
414    pub fn build(&mut self) -> Result<()> {
415        if self.data.is_empty() {
416            return Ok(());
417        }
418
419        let indices: Vec<usize> = (0..self.data.len()).collect();
420        let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
421
422        self.root = Some(Box::new(self.build_node(&points, indices, 0)?));
423        Ok(())
424    }
425
426    fn build_node(&self, points: &[Vec<f32>], indices: Vec<usize>, depth: usize) -> Result<KdNode> {
427        // Reasonable stack overflow prevention with proper depth limit
428        let max_depth = if !self.data.is_empty() {
429            ((self.data.len() as f32).log2() * 2.0) as usize + 10
430        } else {
431            50
432        };
433
434        if indices.len() <= self.config.max_leaf_size || indices.len() <= 1 || depth >= max_depth {
435            return Ok(KdNode {
436                split_dim: 0,
437                split_value: 0.0,
438                left: None,
439                right: None,
440                indices,
441            });
442        }
443
444        let dimensions = points[0].len();
445        let split_dim = depth % dimensions;
446
447        // Find median along split dimension
448        let mut values: Vec<(f32, usize)> = indices
449            .iter()
450            .map(|&idx| (points[idx][split_dim], idx))
451            .collect();
452
453        values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
454
455        let median_idx = values.len() / 2;
456        let split_value = values[median_idx].0;
457
458        let left_indices: Vec<usize> = values[..median_idx].iter().map(|(_, idx)| *idx).collect();
459
460        let right_indices: Vec<usize> = values[median_idx..].iter().map(|(_, idx)| *idx).collect();
461
462        // Prevent creating empty partitions - create leaf instead
463        if left_indices.is_empty() || right_indices.is_empty() {
464            return Ok(KdNode {
465                split_dim: 0,
466                split_value: 0.0,
467                left: None,
468                right: None,
469                indices,
470            });
471        }
472
473        let left = Some(Box::new(self.build_node(
474            points,
475            left_indices,
476            depth + 1,
477        )?));
478
479        let right = Some(Box::new(self.build_node(
480            points,
481            right_indices,
482            depth + 1,
483        )?));
484
485        Ok(KdNode {
486            split_dim,
487            split_value,
488            left,
489            right,
490            indices: Vec::new(),
491        })
492    }
493
494    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
495        if self.root.is_none() {
496            return Vec::new();
497        }
498
499        let mut heap = BinaryHeap::new();
500        self.search_node(
501            self.root
502                .as_ref()
503                .expect("tree should have root after build"),
504            query,
505            k,
506            &mut heap,
507        );
508
509        let mut results: Vec<(usize, f32)> =
510            heap.into_iter().map(|r| (r.index, r.distance)).collect();
511
512        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
513        results
514    }
515
516    fn search_node(
517        &self,
518        node: &KdNode,
519        query: &[f32],
520        k: usize,
521        heap: &mut BinaryHeap<SearchResult>,
522    ) {
523        if !node.indices.is_empty() {
524            // Leaf node
525            for &idx in &node.indices {
526                let point = &self.data[idx].1.as_f32();
527                let dist = self.config.distance_metric.distance(query, point);
528
529                if heap.len() < k {
530                    heap.push(SearchResult {
531                        index: idx,
532                        distance: dist,
533                    });
534                } else if dist < heap.peek().expect("heap should have k elements").distance {
535                    heap.pop();
536                    heap.push(SearchResult {
537                        index: idx,
538                        distance: dist,
539                    });
540                }
541            }
542            return;
543        }
544
545        // Determine which side to search first
546        let go_left = query[node.split_dim] <= node.split_value;
547
548        let (first, second) = if go_left {
549            (&node.left, &node.right)
550        } else {
551            (&node.right, &node.left)
552        };
553
554        // Search the nearer side first
555        if let Some(child) = first {
556            self.search_node(child, query, k, heap);
557        }
558
559        // Check if we need to search the other side
560        if heap.len() < k || {
561            let split_dist = (query[node.split_dim] - node.split_value).abs();
562            split_dist < heap.peek().expect("heap should have k elements").distance
563        } {
564            if let Some(child) = second {
565                self.search_node(child, query, k, heap);
566            }
567        }
568    }
569}
570
/// VP-Tree (Vantage Point Tree) implementation
///
/// Each internal node picks a random vantage point and splits the remaining
/// points at the median of their distances to it. Search prunes with the
/// triangle inequality.
pub struct VpTree {
    /// Root node; `None` until `build` runs on a non-empty dataset.
    root: Option<Box<VpNode>>,
    /// Stored entries; search results refer to points by their position here.
    /// The `String` is presumably the vector's identifier — confirm with the
    /// code that populates `data`.
    data: Vec<(String, Vector)>,
    /// Index configuration (leaf size, seed, distance metric, ...).
    config: TreeIndexConfig,
}
577
/// A node of a [`VpTree`].
///
/// Internal nodes hold a vantage point and partition the remaining points by
/// their distance to it. Leaves store their points in `indices`.
struct VpNode {
    /// Vantage point index (position in the tree's `data`). For leaves this is
    /// just the first stored index and is not used by search.
    vantage_point: usize,
    /// Median distance from vantage point (0.0 for leaves)
    median_distance: f32,
    /// Points whose distance to the vantage point is <= `median_distance`
    inside: Option<Box<VpNode>>,
    /// Points whose distance to the vantage point is >= `median_distance`
    /// (this includes the median point itself)
    outside: Option<Box<VpNode>>,
    /// Indices for leaf nodes; empty for internal nodes
    indices: Vec<usize>,
}
590
591impl VpTree {
592    pub fn new(config: TreeIndexConfig) -> Self {
593        Self {
594            root: None,
595            data: Vec::new(),
596            config,
597        }
598    }
599
600    pub fn build(&mut self) -> Result<()> {
601        if self.data.is_empty() {
602            return Ok(());
603        }
604
605        let indices: Vec<usize> = (0..self.data.len()).collect();
606        let mut rng = if let Some(seed) = self.config.random_seed {
607            Random::seed(seed)
608        } else {
609            Random::seed(42)
610        };
611
612        self.root = Some(Box::new(self.build_node(indices, &mut rng)?));
613        Ok(())
614    }
615
616    fn build_node<R: Rng>(&self, indices: Vec<usize>, rng: &mut R) -> Result<VpNode> {
617        self.build_node_safe(indices, rng, 0)
618    }
619
620    #[allow(deprecated)]
621    fn build_node_safe<R: Rng>(
622        &self,
623        mut indices: Vec<usize>,
624        rng: &mut R,
625        depth: usize,
626    ) -> Result<VpNode> {
627        // Note: Using manual random selection instead of SliceRandom
628
629        // CRITICAL: Extremely strict depth and size limits to prevent stack overflow
630        // For very small datasets or deep recursion, immediately create leaf nodes
631        let max_depth = 30; // Conservative depth limit
632
633        // Aggressive leaf node creation for small datasets
634        if indices.len() <= self.config.max_leaf_size
635            || indices.len() <= 2  // Changed from <= 1 to <= 2 for extra safety
636            || depth >= max_depth
637        {
638            return Ok(VpNode {
639                vantage_point: if indices.is_empty() { 0 } else { indices[0] },
640                median_distance: 0.0,
641                inside: None,
642                outside: None,
643                indices,
644            });
645        }
646
647        // Choose random vantage point - simplified to avoid potential issues
648        let vp_idx = if indices.len() > 1 {
649            rng.gen_range(0..indices.len())
650        } else {
651            0
652        };
653        let vantage_point = indices[vp_idx];
654        indices.remove(vp_idx);
655
656        // Calculate distances from vantage point
657        let vp_data = &self.data[vantage_point].1.as_f32();
658        let mut distances: Vec<(f32, usize)> = indices
659            .iter()
660            .map(|&idx| {
661                let point = &self.data[idx].1.as_f32();
662                let dist = self.config.distance_metric.distance(vp_data, point);
663                (dist, idx)
664            })
665            .collect();
666
667        distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
668
669        let median_idx = distances.len() / 2;
670        let median_distance = distances[median_idx].0;
671
672        let inside_indices: Vec<usize> = distances[..median_idx]
673            .iter()
674            .map(|(_, idx)| *idx)
675            .collect();
676
677        let outside_indices: Vec<usize> = distances[median_idx..]
678            .iter()
679            .map(|(_, idx)| *idx)
680            .collect();
681
682        // Prevent creating empty partitions - create leaf instead
683        if inside_indices.is_empty() || outside_indices.is_empty() {
684            return Ok(VpNode {
685                vantage_point: if indices.is_empty() { 0 } else { indices[0] },
686                median_distance: 0.0,
687                inside: None,
688                outside: None,
689                indices,
690            });
691        }
692
693        let inside = Some(Box::new(self.build_node_safe(
694            inside_indices,
695            rng,
696            depth + 1,
697        )?));
698        let outside = Some(Box::new(self.build_node_safe(
699            outside_indices,
700            rng,
701            depth + 1,
702        )?));
703
704        Ok(VpNode {
705            vantage_point,
706            median_distance,
707            inside,
708            outside,
709            indices: Vec::new(),
710        })
711    }
712
713    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
714        if self.root.is_none() {
715            return Vec::new();
716        }
717
718        let mut heap = BinaryHeap::new();
719        self.search_node(
720            self.root
721                .as_ref()
722                .expect("tree should have root after build"),
723            query,
724            k,
725            &mut heap,
726            f32::INFINITY,
727        );
728
729        let mut results: Vec<(usize, f32)> =
730            heap.into_iter().map(|r| (r.index, r.distance)).collect();
731
732        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
733        results
734    }
735
736    fn search_node(
737        &self,
738        node: &VpNode,
739        query: &[f32],
740        k: usize,
741        heap: &mut BinaryHeap<SearchResult>,
742        tau: f32,
743    ) -> f32 {
744        let mut tau = tau;
745
746        if !node.indices.is_empty() {
747            // Leaf node
748            for &idx in &node.indices {
749                let point = &self.data[idx].1.as_f32();
750                let dist = self.config.distance_metric.distance(query, point);
751
752                if dist < tau {
753                    if heap.len() < k {
754                        heap.push(SearchResult {
755                            index: idx,
756                            distance: dist,
757                        });
758                    } else if dist < heap.peek().expect("heap should have k elements").distance {
759                        heap.pop();
760                        heap.push(SearchResult {
761                            index: idx,
762                            distance: dist,
763                        });
764                    }
765
766                    if heap.len() >= k {
767                        tau = heap.peek().expect("heap should have k elements").distance;
768                    }
769                }
770            }
771            return tau;
772        }
773
774        // Calculate distance to vantage point
775        let vp_data = &self.data[node.vantage_point].1.as_f32();
776        let dist_to_vp = self.config.distance_metric.distance(query, vp_data);
777
778        // Consider vantage point itself
779        if dist_to_vp < tau {
780            if heap.len() < k {
781                heap.push(SearchResult {
782                    index: node.vantage_point,
783                    distance: dist_to_vp,
784                });
785            } else if dist_to_vp < heap.peek().expect("heap should have k elements").distance {
786                heap.pop();
787                heap.push(SearchResult {
788                    index: node.vantage_point,
789                    distance: dist_to_vp,
790                });
791            }
792
793            if heap.len() >= k {
794                tau = heap.peek().expect("heap should have k elements").distance;
795            }
796        }
797
798        // Search children
799        if dist_to_vp < node.median_distance {
800            // Search inside first
801            if let Some(inside) = &node.inside {
802                tau = self.search_node(inside, query, k, heap, tau);
803            }
804
805            // Check if we need to search outside
806            if dist_to_vp + tau >= node.median_distance {
807                if let Some(outside) = &node.outside {
808                    tau = self.search_node(outside, query, k, heap, tau);
809                }
810            }
811        } else {
812            // Search outside first
813            if let Some(outside) = &node.outside {
814                tau = self.search_node(outside, query, k, heap, tau);
815            }
816
817            // Check if we need to search inside
818            if dist_to_vp - tau <= node.median_distance {
819                if let Some(inside) = &node.inside {
820                    tau = self.search_node(inside, query, k, heap, tau);
821                }
822            }
823        }
824
825        tau
826    }
827}
828
/// Cover Tree implementation
///
/// NOTE(review): construction is currently simplified. Every point after the
/// first is attached directly under the root, so the structure is a flat list
/// rather than a true cover tree.
pub struct CoverTree {
    /// Root node (the first stored point); `None` until built on non-empty data.
    root: Option<Box<CoverNode>>,
    /// Stored entries; search results refer to points by their position here.
    /// The `String` is presumably the vector's identifier — confirm with the
    /// code that populates `data`.
    data: Vec<(String, Vector)>,
    /// Index configuration (the distance metric is what construction and
    /// search read).
    config: TreeIndexConfig,
    /// Base for the covering constant (set to 2.0 in `new`).
    /// NOTE(review): not read by the construction or search code shown here.
    base: f32,
}
836
/// A node of a [`CoverTree`].
struct CoverNode {
    /// Point index (position in the tree's `data`)
    point: usize,
    /// Level in the tree
    level: i32,
    /// Children at the same or lower level
    // `clippy::vec_box` is silenced deliberately. NOTE(review): `Vec<CoverNode>`
    // would also type-check, since `Vec` already provides the indirection a
    // recursive type needs; the `Box` adds one allocation per child.
    #[allow(clippy::vec_box)]
    children: Vec<Box<CoverNode>>,
}
846
847impl CoverTree {
848    pub fn new(config: TreeIndexConfig) -> Self {
849        Self {
850            root: None,
851            data: Vec::new(),
852            config,
853            base: 2.0, // Base for the covering constant
854        }
855    }
856
857    pub fn build(&mut self) -> Result<()> {
858        if self.data.is_empty() {
859            return Ok(());
860        }
861
862        // Initialize with first point
863        self.root = Some(Box::new(CoverNode {
864            point: 0,
865            level: self.get_level(0),
866            children: Vec::new(),
867        }));
868
869        // Insert remaining points
870        for idx in 1..self.data.len() {
871            self.insert(idx)?;
872        }
873
874        Ok(())
875    }
876
877    fn get_level(&self, _point_idx: usize) -> i32 {
878        // Simple heuristic for initial level
879        ((self.data.len() as f32).log2() as i32).max(0)
880    }
881
882    fn insert(&mut self, point_idx: usize) -> Result<()> {
883        // Simplified insert - in practice, this would be more complex
884        // to maintain the cover tree invariants
885        let level = self.get_level(point_idx);
886        if let Some(root) = &mut self.root {
887            root.children.push(Box::new(CoverNode {
888                point: point_idx,
889                level,
890                children: Vec::new(),
891            }));
892        }
893        Ok(())
894    }
895
896    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
897        if self.root.is_none() {
898            return Vec::new();
899        }
900
901        let mut results = Vec::new();
902        self.search_node(
903            self.root
904                .as_ref()
905                .expect("tree should have root after build"),
906            query,
907            k,
908            &mut results,
909        );
910
911        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
912        results.truncate(k);
913        results
914    }
915
916    #[allow(clippy::only_used_in_recursion)]
917    fn search_node(
918        &self,
919        node: &CoverNode,
920        query: &[f32],
921        k: usize,
922        results: &mut Vec<(usize, f32)>,
923    ) {
924        // Prevent excessive recursion depth
925        if results.len() >= k * 10 {
926            return;
927        }
928
929        let point_data = &self.data[node.point].1.as_f32();
930        let dist = self.config.distance_metric.distance(query, point_data);
931
932        results.push((node.point, dist));
933
934        // Search children
935        for child in &node.children {
936            self.search_node(child, query, k, results);
937        }
938    }
939}
940
/// Random Projection Tree implementation
///
/// Splits nodes along randomly drawn projection directions.
pub struct RandomProjectionTree {
    /// Root node; `None` until `build` runs on a non-empty dataset.
    root: Option<Box<RpNode>>,
    /// Stored entries, addressed by their position in this vector.
    /// The `String` is presumably the vector's identifier — confirm with the
    /// code that populates `data`.
    data: Vec<(String, Vector)>,
    /// Index configuration (leaf size, seed, distance metric, ...).
    config: TreeIndexConfig,
}
947
/// A node of a [`RandomProjectionTree`].
///
/// Leaves store their points in `indices` and carry an empty `projection`
/// and a zero `threshold`.
struct RpNode {
    /// Random projection vector (empty for leaves)
    projection: Vec<f32>,
    /// Projection threshold
    threshold: f32,
    /// Left child (projection <= threshold)
    left: Option<Box<RpNode>>,
    /// Right child (projection > threshold)
    right: Option<Box<RpNode>>,
    /// Indices for leaf nodes
    indices: Vec<usize>,
}
960
961impl RandomProjectionTree {
962    pub fn new(config: TreeIndexConfig) -> Self {
963        Self {
964            root: None,
965            data: Vec::new(),
966            config,
967        }
968    }
969
970    pub fn build(&mut self) -> Result<()> {
971        if self.data.is_empty() {
972            return Ok(());
973        }
974
975        let indices: Vec<usize> = (0..self.data.len()).collect();
976        let dimensions = self.data[0].1.dimensions;
977
978        let mut rng = if let Some(seed) = self.config.random_seed {
979            Random::seed(seed)
980        } else {
981            Random::seed(42)
982        };
983
984        self.root = Some(Box::new(self.build_node(indices, dimensions, &mut rng)?));
985        Ok(())
986    }
987
988    fn build_node<R: Rng>(
989        &self,
990        indices: Vec<usize>,
991        dimensions: usize,
992        rng: &mut R,
993    ) -> Result<RpNode> {
994        self.build_node_safe(indices, dimensions, rng, 0)
995    }
996
997    #[allow(deprecated)]
998    fn build_node_safe<R: Rng>(
999        &self,
1000        indices: Vec<usize>,
1001        dimensions: usize,
1002        rng: &mut R,
1003        depth: usize,
1004    ) -> Result<RpNode> {
1005        // Very strict stack overflow prevention - similar to BallTree approach
1006        if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= 5 {
1007            return Ok(RpNode {
1008                projection: Vec::new(),
1009                threshold: 0.0,
1010                left: None,
1011                right: None,
1012                indices,
1013            });
1014        }
1015
1016        // Generate random projection vector
1017        let projection: Vec<f32> = (0..dimensions).map(|_| rng.gen_range(-1.0..1.0)).collect();
1018
1019        // Normalize projection vector
1020        let norm = (projection.iter().map(|&x| x * x).sum::<f32>()).sqrt();
1021        let projection: Vec<f32> = if norm > 0.0 {
1022            projection.iter().map(|&x| x / norm).collect()
1023        } else {
1024            projection
1025        };
1026
1027        // Project all points
1028        let mut projections: Vec<(f32, usize)> = indices
1029            .iter()
1030            .map(|&idx| {
1031                let point = &self.data[idx].1.as_f32();
1032                let proj_val = f32::dot(point, &projection);
1033                (proj_val, idx)
1034            })
1035            .collect();
1036
1037        projections.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
1038
1039        // Choose median as threshold
1040        let median_idx = projections.len() / 2;
1041        let threshold = projections[median_idx].0;
1042
1043        let left_indices: Vec<usize> = projections[..median_idx]
1044            .iter()
1045            .map(|(_, idx)| *idx)
1046            .collect();
1047
1048        let right_indices: Vec<usize> = projections[median_idx..]
1049            .iter()
1050            .map(|(_, idx)| *idx)
1051            .collect();
1052
1053        // Prevent creating empty partitions - create leaf instead
1054        if left_indices.is_empty() || right_indices.is_empty() {
1055            return Ok(RpNode {
1056                projection: Vec::new(),
1057                threshold: 0.0,
1058                left: None,
1059                right: None,
1060                indices,
1061            });
1062        }
1063
1064        let left = Some(Box::new(self.build_node_safe(
1065            left_indices,
1066            dimensions,
1067            rng,
1068            depth + 1,
1069        )?));
1070        let right = Some(Box::new(self.build_node_safe(
1071            right_indices,
1072            dimensions,
1073            rng,
1074            depth + 1,
1075        )?));
1076
1077        Ok(RpNode {
1078            projection,
1079            threshold,
1080            left,
1081            right,
1082            indices: Vec::new(),
1083        })
1084    }
1085
1086    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1087        if self.root.is_none() {
1088            return Vec::new();
1089        }
1090
1091        let mut heap = BinaryHeap::new();
1092        self.search_node(
1093            self.root
1094                .as_ref()
1095                .expect("tree should have root after build"),
1096            query,
1097            k,
1098            &mut heap,
1099        );
1100
1101        let mut results: Vec<(usize, f32)> =
1102            heap.into_iter().map(|r| (r.index, r.distance)).collect();
1103
1104        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
1105        results
1106    }
1107
1108    fn search_node(
1109        &self,
1110        node: &RpNode,
1111        query: &[f32],
1112        k: usize,
1113        heap: &mut BinaryHeap<SearchResult>,
1114    ) {
1115        if !node.indices.is_empty() {
1116            // Leaf node
1117            for &idx in &node.indices {
1118                let point = &self.data[idx].1.as_f32();
1119                let dist = self.config.distance_metric.distance(query, point);
1120
1121                if heap.len() < k {
1122                    heap.push(SearchResult {
1123                        index: idx,
1124                        distance: dist,
1125                    });
1126                } else if dist < heap.peek().expect("heap should have k elements").distance {
1127                    heap.pop();
1128                    heap.push(SearchResult {
1129                        index: idx,
1130                        distance: dist,
1131                    });
1132                }
1133            }
1134            return;
1135        }
1136
1137        // Project query
1138        let query_projection = f32::dot(query, &node.projection);
1139
1140        // Determine which side to search first
1141        let go_left = query_projection <= node.threshold;
1142
1143        let (first, second) = if go_left {
1144            (&node.left, &node.right)
1145        } else {
1146            (&node.right, &node.left)
1147        };
1148
1149        // Search both sides (random projections don't provide distance bounds)
1150        if let Some(child) = first {
1151            self.search_node(child, query, k, heap);
1152        }
1153
1154        if let Some(child) = second {
1155            self.search_node(child, query, k, heap);
1156        }
1157    }
1158}
1159
/// Unified tree index interface.
///
/// The index wraps exactly one backend tree, chosen from `config.tree_type` in
/// `TreeIndex::new`. The backend field matching `tree_type` is `Some`, and all
/// other backend fields are `None`.
pub struct TreeIndex {
    /// Active backend; every operation dispatches on this to pick a field below.
    tree_type: TreeType,
    /// Backend used when `tree_type` is `TreeType::BallTree`.
    ball_tree: Option<BallTree>,
    /// Backend used when `tree_type` is `TreeType::KdTree`.
    kd_tree: Option<KdTree>,
    /// Backend used when `tree_type` is `TreeType::VpTree`.
    vp_tree: Option<VpTree>,
    /// Backend used when `tree_type` is `TreeType::CoverTree`.
    cover_tree: Option<CoverTree>,
    /// Backend used when `tree_type` is `TreeType::RandomProjectionTree`.
    rp_tree: Option<RandomProjectionTree>,
}
1169
1170impl TreeIndex {
1171    pub fn new(config: TreeIndexConfig) -> Self {
1172        let tree_type = config.tree_type;
1173
1174        let (ball_tree, kd_tree, vp_tree, cover_tree, rp_tree) = match tree_type {
1175            TreeType::BallTree => (Some(BallTree::new(config)), None, None, None, None),
1176            TreeType::KdTree => (None, Some(KdTree::new(config)), None, None, None),
1177            TreeType::VpTree => (None, None, Some(VpTree::new(config)), None, None),
1178            TreeType::CoverTree => (None, None, None, Some(CoverTree::new(config)), None),
1179            TreeType::RandomProjectionTree => (
1180                None,
1181                None,
1182                None,
1183                None,
1184                Some(RandomProjectionTree::new(config)),
1185            ),
1186        };
1187
1188        Self {
1189            tree_type,
1190            ball_tree,
1191            kd_tree,
1192            vp_tree,
1193            cover_tree,
1194            rp_tree,
1195        }
1196    }
1197
1198    pub fn build(&mut self) -> Result<()> {
1199        match self.tree_type {
1200            TreeType::BallTree => self
1201                .ball_tree
1202                .as_mut()
1203                .expect("ball_tree should be initialized for BallTree type")
1204                .build(),
1205            TreeType::KdTree => self
1206                .kd_tree
1207                .as_mut()
1208                .expect("kd_tree should be initialized for KdTree type")
1209                .build(),
1210            TreeType::VpTree => self
1211                .vp_tree
1212                .as_mut()
1213                .expect("vp_tree should be initialized for VpTree type")
1214                .build(),
1215            TreeType::CoverTree => self
1216                .cover_tree
1217                .as_mut()
1218                .expect("cover_tree should be initialized for CoverTree type")
1219                .build(),
1220            TreeType::RandomProjectionTree => self
1221                .rp_tree
1222                .as_mut()
1223                .expect("rp_tree should be initialized for RandomProjectionTree type")
1224                .build(),
1225        }
1226    }
1227
1228    fn search_internal(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
1229        match self.tree_type {
1230            TreeType::BallTree => self
1231                .ball_tree
1232                .as_ref()
1233                .expect("ball_tree should be initialized for BallTree type")
1234                .search(query, k),
1235            TreeType::KdTree => self
1236                .kd_tree
1237                .as_ref()
1238                .expect("kd_tree should be initialized for KdTree type")
1239                .search(query, k),
1240            TreeType::VpTree => self
1241                .vp_tree
1242                .as_ref()
1243                .expect("vp_tree should be initialized for VpTree type")
1244                .search(query, k),
1245            TreeType::CoverTree => self
1246                .cover_tree
1247                .as_ref()
1248                .expect("cover_tree should be initialized for CoverTree type")
1249                .search(query, k),
1250            TreeType::RandomProjectionTree => self
1251                .rp_tree
1252                .as_ref()
1253                .expect("rp_tree should be initialized for RandomProjectionTree type")
1254                .search(query, k),
1255        }
1256    }
1257}
1258
1259impl VectorIndex for TreeIndex {
1260    fn insert(&mut self, uri: String, vector: Vector) -> Result<()> {
1261        let data = match self.tree_type {
1262            TreeType::BallTree => {
1263                &mut self
1264                    .ball_tree
1265                    .as_mut()
1266                    .expect("ball_tree should be initialized for BallTree type")
1267                    .data
1268            }
1269            TreeType::KdTree => {
1270                &mut self
1271                    .kd_tree
1272                    .as_mut()
1273                    .expect("kd_tree should be initialized for KdTree type")
1274                    .data
1275            }
1276            TreeType::VpTree => {
1277                &mut self
1278                    .vp_tree
1279                    .as_mut()
1280                    .expect("vp_tree should be initialized for VpTree type")
1281                    .data
1282            }
1283            TreeType::CoverTree => {
1284                &mut self
1285                    .cover_tree
1286                    .as_mut()
1287                    .expect("cover_tree should be initialized for CoverTree type")
1288                    .data
1289            }
1290            TreeType::RandomProjectionTree => {
1291                &mut self
1292                    .rp_tree
1293                    .as_mut()
1294                    .expect("rp_tree should be initialized for RandomProjectionTree type")
1295                    .data
1296            }
1297        };
1298
1299        data.push((uri, vector));
1300        Ok(())
1301    }
1302
1303    fn search_knn(&self, query: &Vector, k: usize) -> Result<Vec<(String, f32)>> {
1304        let query_f32 = query.as_f32();
1305        let results = self.search_internal(&query_f32, k);
1306
1307        let data = match self.tree_type {
1308            TreeType::BallTree => {
1309                &self
1310                    .ball_tree
1311                    .as_ref()
1312                    .expect("ball_tree should be initialized for BallTree type")
1313                    .data
1314            }
1315            TreeType::KdTree => {
1316                &self
1317                    .kd_tree
1318                    .as_ref()
1319                    .expect("kd_tree should be initialized for KdTree type")
1320                    .data
1321            }
1322            TreeType::VpTree => {
1323                &self
1324                    .vp_tree
1325                    .as_ref()
1326                    .expect("vp_tree should be initialized for VpTree type")
1327                    .data
1328            }
1329            TreeType::CoverTree => {
1330                &self
1331                    .cover_tree
1332                    .as_ref()
1333                    .expect("cover_tree should be initialized for CoverTree type")
1334                    .data
1335            }
1336            TreeType::RandomProjectionTree => {
1337                &self
1338                    .rp_tree
1339                    .as_ref()
1340                    .expect("rp_tree should be initialized for RandomProjectionTree type")
1341                    .data
1342            }
1343        };
1344
1345        Ok(results
1346            .into_iter()
1347            .map(|(idx, dist)| (data[idx].0.clone(), dist))
1348            .collect())
1349    }
1350
1351    fn search_threshold(&self, query: &Vector, threshold: f32) -> Result<Vec<(String, f32)>> {
1352        let query_f32 = query.as_f32();
1353        let all_results = self.search_internal(&query_f32, 1000); // Search more broadly
1354
1355        let data = match self.tree_type {
1356            TreeType::BallTree => {
1357                &self
1358                    .ball_tree
1359                    .as_ref()
1360                    .expect("ball_tree should be initialized for BallTree type")
1361                    .data
1362            }
1363            TreeType::KdTree => {
1364                &self
1365                    .kd_tree
1366                    .as_ref()
1367                    .expect("kd_tree should be initialized for KdTree type")
1368                    .data
1369            }
1370            TreeType::VpTree => {
1371                &self
1372                    .vp_tree
1373                    .as_ref()
1374                    .expect("vp_tree should be initialized for VpTree type")
1375                    .data
1376            }
1377            TreeType::CoverTree => {
1378                &self
1379                    .cover_tree
1380                    .as_ref()
1381                    .expect("cover_tree should be initialized for CoverTree type")
1382                    .data
1383            }
1384            TreeType::RandomProjectionTree => {
1385                &self
1386                    .rp_tree
1387                    .as_ref()
1388                    .expect("rp_tree should be initialized for RandomProjectionTree type")
1389                    .data
1390            }
1391        };
1392
1393        Ok(all_results
1394            .into_iter()
1395            .filter(|(_, dist)| *dist <= threshold)
1396            .map(|(idx, dist)| (data[idx].0.clone(), dist))
1397            .collect())
1398    }
1399
1400    fn get_vector(&self, uri: &str) -> Option<&Vector> {
1401        let data = match self.tree_type {
1402            TreeType::BallTree => {
1403                &self
1404                    .ball_tree
1405                    .as_ref()
1406                    .expect("ball_tree should be initialized for BallTree type")
1407                    .data
1408            }
1409            TreeType::KdTree => {
1410                &self
1411                    .kd_tree
1412                    .as_ref()
1413                    .expect("kd_tree should be initialized for KdTree type")
1414                    .data
1415            }
1416            TreeType::VpTree => {
1417                &self
1418                    .vp_tree
1419                    .as_ref()
1420                    .expect("vp_tree should be initialized for VpTree type")
1421                    .data
1422            }
1423            TreeType::CoverTree => {
1424                &self
1425                    .cover_tree
1426                    .as_ref()
1427                    .expect("cover_tree should be initialized for CoverTree type")
1428                    .data
1429            }
1430            TreeType::RandomProjectionTree => {
1431                &self
1432                    .rp_tree
1433                    .as_ref()
1434                    .expect("rp_tree should be initialized for RandomProjectionTree type")
1435                    .data
1436            }
1437        };
1438
1439        data.iter().find(|(u, _)| u == uri).map(|(_, v)| v)
1440    }
1441}
1442
// Randomness for the VP-Tree and Random Projection Tree is provided by
// `scirs2_core::random` (previously the `rand` crate).
1445
/// Runs `f` and returns its result.
///
/// This is a placeholder for async task spawning and is meant to be integrated
/// with `oxirs-core::parallel`. The closure is called directly inside this
/// future, so it runs on whichever task awaits `spawn_task`, not on a separate
/// worker. Long-running work blocks that task until it finishes.
async fn spawn_task<F, T>(f: F) -> T
where
    F: FnOnce() -> T + Send + 'static,
    T: Send + 'static,
{
    // No real spawning yet: call the closure inline. The `Send + 'static`
    // bounds presumably anticipate a real spawner's requirements.
    f()
}
1455
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a Ball Tree over 100 points on the line y = 2x and checks that a
    /// 5-NN query near the middle of the line yields between one and five hits.
    #[test]
    #[ignore = "Tree indices are experimental - see module documentation for alternatives"]
    fn test_ball_tree() {
        let mut tree = BallTree::new(TreeIndexConfig {
            tree_type: TreeType::BallTree,
            max_leaf_size: 10,
            ..Default::default()
        });

        tree.data.extend((0..100).map(|i| {
            let coords = vec![i as f32, (i * 2) as f32];
            (format!("vec_{i}"), Vector::new(coords))
        }));

        tree.build().unwrap();
        assert!(tree.root.is_some());

        let probe = vec![50.0, 100.0];
        let neighbours = tree.search(&probe, 5);

        assert!(neighbours.len() <= 5);
        assert!(!neighbours.is_empty());
    }

    /// Runs a 2-NN query through `TreeIndex` backed by a KD-Tree over three points.
    #[test]
    #[ignore = "Investigating stack overflow with recursive tree construction"]
    fn test_kd_tree() {
        // Leaf size far above the dataset size, presumably so the tree stays a
        // single leaf.
        let mut index = TreeIndex::new(TreeIndexConfig {
            tree_type: TreeType::KdTree,
            max_leaf_size: 50,
            ..Default::default()
        });

        // Three points on the anti-diagonal x + y = 3; kept tiny on purpose.
        for i in 0..3 {
            let (x, y) = (i as f32, (3 - i) as f32);
            index.insert(format!("vec_{i}"), Vector::new(vec![x, y])).unwrap();
        }

        index.build().unwrap();

        let probe = Vector::new(vec![1.0, 2.0]);
        let neighbours = index.search_knn(&probe, 2).unwrap();

        assert_eq!(neighbours.len(), 2);
    }

    /// Runs a 2-NN query through `TreeIndex` backed by a seeded VP-Tree over
    /// three points on the unit circle.
    #[test]
    #[ignore = "Investigating stack overflow with recursive tree construction"]
    fn test_vp_tree() {
        // Leaf size far above the dataset size, presumably so the tree stays a
        // single leaf.
        let mut index = TreeIndex::new(TreeIndexConfig {
            tree_type: TreeType::VpTree,
            random_seed: Some(42),
            max_leaf_size: 50,
            ..Default::default()
        });

        // Points at angles 0, pi/4 and pi/2 on the unit circle.
        for i in 0..3 {
            let theta = (i as f32) * std::f32::consts::PI / 4.0;
            let (x, y) = (theta.cos(), theta.sin());
            index.insert(format!("vec_{i}"), Vector::new(vec![x, y])).unwrap();
        }

        index.build().unwrap();

        let probe = Vector::new(vec![1.0, 0.0]);
        let neighbours = index.search_knn(&probe, 2).unwrap();

        assert_eq!(neighbours.len(), 2);
    }
}