Skip to main content

oxirs_vec/
tree_indices_balltree.rs

1//! Ball Tree implementation for nearest-neighbor search.
2//!
3//! A ball tree recursively partitions points into nested hyperspheres ("balls"),
4//! making it efficient for arbitrary distance metrics.
5
6use crate::tree_indices_types::{SearchResult, TreeIndexConfig};
7use crate::Vector;
8use anyhow::Result;
9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
12/// Ball Tree implementation
13pub struct BallTree {
14    pub(crate) root: Option<Box<BallNode>>,
15    pub(crate) data: Vec<(String, Vector)>,
16    pub(crate) config: TreeIndexConfig,
17}
18
/// One ball of the tree: a bounding hypersphere that either splits into two
/// child balls (internal node) or directly lists the points it holds (leaf).
#[derive(Clone)]
pub(crate) struct BallNode {
    /// Positions into the tree's `data` held by this node; empty for internal nodes.
    indices: Vec<usize>,
    /// Centre of the bounding hypersphere.
    center: Vec<f32>,
    /// Radius of the bounding hypersphere.
    radius: f32,
    /// First child subtree; `None` for a leaf.
    left: Option<Box<BallNode>>,
    /// Second child subtree; `None` for a leaf.
    right: Option<Box<BallNode>>,
}
32
33impl BallTree {
34    pub fn new(config: TreeIndexConfig) -> Self {
35        Self {
36            root: None,
37            data: Vec::new(),
38            config,
39        }
40    }
41
42    /// Build the tree from data with conservative depth limits to prevent stack overflow
43    ///
44    /// Note: Tree indices work best with moderate dataset sizes (< 100K points).
45    /// For larger datasets, consider using HNSW, IVF, or LSH indices instead.
46    pub fn build(&mut self) -> Result<()> {
47        if self.data.is_empty() {
48            return Ok(());
49        }
50
51        let indices: Vec<usize> = (0..self.data.len()).collect();
52        let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
53
54        self.root = Some(Box::new(self.build_node_safe(&points, indices, 0)?));
55        Ok(())
56    }
57
58    /// Conservative recursive construction with strict depth limits
59    fn build_node_safe(
60        &self,
61        points: &[Vec<f32>],
62        indices: Vec<usize>,
63        depth: usize,
64    ) -> Result<BallNode> {
65        // VERY conservative depth limit to prevent stack overflow
66        // Limit depth to 20 for safety (can handle ~1M points with leaf_size=10)
67        const MAX_DEPTH: usize = 20;
68
69        // Force leaf creation if:
70        // 1. At or below leaf size
71        // 2. Only 1 or 2 points left
72        // 3. Reached maximum safe depth
73        if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= MAX_DEPTH {
74            let center = self.compute_centroid(points, &indices);
75            let radius = self.compute_radius(points, &indices, &center);
76            return Ok(BallNode {
77                center,
78                radius,
79                left: None,
80                right: None,
81                indices,
82            });
83        }
84
85        // Find split dimension
86        let split_dim = self.find_split_dimension(points, &indices);
87        let (left_indices, right_indices) = self.partition_indices(points, &indices, split_dim);
88
89        // Prevent empty partitions - create leaf instead
90        if left_indices.is_empty() || right_indices.is_empty() {
91            let center = self.compute_centroid(points, &indices);
92            let radius = self.compute_radius(points, &indices, &center);
93            return Ok(BallNode {
94                center,
95                radius,
96                left: None,
97                right: None,
98                indices,
99            });
100        }
101
102        // Recursively build children (limited by MAX_DEPTH)
103        let left_node = self.build_node_safe(points, left_indices, depth + 1)?;
104        let right_node = self.build_node_safe(points, right_indices, depth + 1)?;
105
106        // Compute bounding ball
107        let all_centers = vec![left_node.center.clone(), right_node.center.clone()];
108        let center = self.compute_centroid_of_centers(&all_centers);
109        let radius = left_node.radius.max(right_node.radius)
110            + self
111                .config
112                .distance_metric
113                .distance(&center, &left_node.center);
114
115        Ok(BallNode {
116            center,
117            radius,
118            left: Some(Box::new(left_node)),
119            right: Some(Box::new(right_node)),
120            indices: Vec::new(),
121        })
122    }
123
124    fn compute_centroid(&self, points: &[Vec<f32>], indices: &[usize]) -> Vec<f32> {
125        let dim = points[0].len();
126        let mut centroid = vec![0.0; dim];
127
128        for &idx in indices {
129            for (i, &val) in points[idx].iter().enumerate() {
130                centroid[i] += val;
131            }
132        }
133
134        let n = indices.len() as f32;
135        for val in &mut centroid {
136            *val /= n;
137        }
138
139        centroid
140    }
141
142    fn compute_radius(&self, points: &[Vec<f32>], indices: &[usize], center: &[f32]) -> f32 {
143        indices
144            .iter()
145            .map(|&idx| self.config.distance_metric.distance(&points[idx], center))
146            .fold(0.0f32, f32::max)
147    }
148
149    fn find_split_dimension(&self, points: &[Vec<f32>], indices: &[usize]) -> usize {
150        let dim = points[0].len();
151        let mut max_spread = 0.0;
152        let mut split_dim = 0;
153
154        // We need the dimension index `d` to access the d-th component of each point
155        #[allow(clippy::needless_range_loop)]
156        for d in 0..dim {
157            let values: Vec<f32> = indices.iter().map(|&idx| points[idx][d]).collect();
158
159            let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
160            let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
161            let spread = max_val - min_val;
162
163            if spread > max_spread {
164                max_spread = spread;
165                split_dim = d;
166            }
167        }
168
169        split_dim
170    }
171
172    fn partition_indices(
173        &self,
174        points: &[Vec<f32>],
175        indices: &[usize],
176        dim: usize,
177    ) -> (Vec<usize>, Vec<usize>) {
178        let mut values: Vec<(f32, usize)> =
179            indices.iter().map(|&idx| (points[idx][dim], idx)).collect();
180
181        values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
182
183        let mid = values.len() / 2;
184        let left_indices: Vec<usize> = values[..mid].iter().map(|(_, idx)| *idx).collect();
185        let right_indices: Vec<usize> = values[mid..].iter().map(|(_, idx)| *idx).collect();
186
187        (left_indices, right_indices)
188    }
189
190    fn compute_centroid_of_centers(&self, centers: &[Vec<f32>]) -> Vec<f32> {
191        let dim = centers[0].len();
192        let mut centroid = vec![0.0; dim];
193
194        for center in centers {
195            for (i, &val) in center.iter().enumerate() {
196                centroid[i] += val;
197            }
198        }
199
200        let n = centers.len() as f32;
201        for val in &mut centroid {
202            *val /= n;
203        }
204
205        centroid
206    }
207
208    /// Search for k nearest neighbors using iterative algorithm
209    pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
210        if self.root.is_none() {
211            return Vec::new();
212        }
213
214        let mut heap: BinaryHeap<SearchResult> = BinaryHeap::new();
215        let mut stack: Vec<&BallNode> = vec![self
216            .root
217            .as_ref()
218            .expect("tree should have root after build")];
219
220        while let Some(node) = stack.pop() {
221            // Check if we need to explore this node
222            let dist_to_center = self.config.distance_metric.distance(query, &node.center);
223
224            if heap.len() >= k {
225                let worst_dist = heap.peek().expect("heap should have k elements").distance;
226                if dist_to_center - node.radius > worst_dist {
227                    continue; // Prune this branch
228                }
229            }
230
231            if node.indices.is_empty() {
232                // Internal node - add children to stack
233                if let (Some(left), Some(right)) = (&node.left, &node.right) {
234                    let left_dist = self.config.distance_metric.distance(query, &left.center);
235                    let right_dist = self.config.distance_metric.distance(query, &right.center);
236
237                    // Add in order so closer one is processed first
238                    if left_dist < right_dist {
239                        stack.push(right);
240                        stack.push(left);
241                    } else {
242                        stack.push(left);
243                        stack.push(right);
244                    }
245                }
246            } else {
247                // Leaf node - check all points
248                for &idx in &node.indices {
249                    let point = &self.data[idx].1.as_f32();
250                    let dist = self.config.distance_metric.distance(query, point);
251
252                    if heap.len() < k {
253                        heap.push(SearchResult {
254                            index: idx,
255                            distance: dist,
256                        });
257                    } else if dist < heap.peek().expect("heap should have k elements").distance {
258                        heap.pop();
259                        heap.push(SearchResult {
260                            index: idx,
261                            distance: dist,
262                        });
263                    }
264                }
265            }
266        }
267
268        let mut results: Vec<(usize, f32)> =
269            heap.into_iter().map(|r| (r.index, r.distance)).collect();
270
271        results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
272        results
273    }
274}