1use crate::tree_indices_types::{SearchResult, TreeIndexConfig};
7use crate::Vector;
8use anyhow::Result;
9use std::cmp::Ordering;
10use std::collections::BinaryHeap;
11
12pub struct BallTree {
14 pub(crate) root: Option<Box<BallNode>>,
15 pub(crate) data: Vec<(String, Vector)>,
16 pub(crate) config: TreeIndexConfig,
17}
18
/// One node of a ball tree: a bounding ball plus either two children
/// (internal node) or a list of point indices (leaf).
#[derive(Clone)]
pub(crate) struct BallNode {
    /// Centre of the bounding ball.
    center: Vec<f32>,
    /// Radius of the bounding ball around `center`.
    radius: f32,
    /// Left subtree; `None` on leaves.
    left: Option<Box<BallNode>>,
    /// Right subtree; `None` on leaves.
    right: Option<Box<BallNode>>,
    /// Indices of the points held by this node. Internal nodes are created
    /// with an empty list, so emptiness doubles as the "is internal" marker.
    indices: Vec<usize>,
}
32
33impl BallTree {
34 pub fn new(config: TreeIndexConfig) -> Self {
35 Self {
36 root: None,
37 data: Vec::new(),
38 config,
39 }
40 }
41
42 pub fn build(&mut self) -> Result<()> {
47 if self.data.is_empty() {
48 return Ok(());
49 }
50
51 let indices: Vec<usize> = (0..self.data.len()).collect();
52 let points: Vec<Vec<f32>> = self.data.iter().map(|(_, v)| v.as_f32()).collect();
53
54 self.root = Some(Box::new(self.build_node_safe(&points, indices, 0)?));
55 Ok(())
56 }
57
58 fn build_node_safe(
60 &self,
61 points: &[Vec<f32>],
62 indices: Vec<usize>,
63 depth: usize,
64 ) -> Result<BallNode> {
65 const MAX_DEPTH: usize = 20;
68
69 if indices.len() <= self.config.max_leaf_size || indices.len() <= 2 || depth >= MAX_DEPTH {
74 let center = self.compute_centroid(points, &indices);
75 let radius = self.compute_radius(points, &indices, ¢er);
76 return Ok(BallNode {
77 center,
78 radius,
79 left: None,
80 right: None,
81 indices,
82 });
83 }
84
85 let split_dim = self.find_split_dimension(points, &indices);
87 let (left_indices, right_indices) = self.partition_indices(points, &indices, split_dim);
88
89 if left_indices.is_empty() || right_indices.is_empty() {
91 let center = self.compute_centroid(points, &indices);
92 let radius = self.compute_radius(points, &indices, ¢er);
93 return Ok(BallNode {
94 center,
95 radius,
96 left: None,
97 right: None,
98 indices,
99 });
100 }
101
102 let left_node = self.build_node_safe(points, left_indices, depth + 1)?;
104 let right_node = self.build_node_safe(points, right_indices, depth + 1)?;
105
106 let all_centers = vec![left_node.center.clone(), right_node.center.clone()];
108 let center = self.compute_centroid_of_centers(&all_centers);
109 let radius = left_node.radius.max(right_node.radius)
110 + self
111 .config
112 .distance_metric
113 .distance(¢er, &left_node.center);
114
115 Ok(BallNode {
116 center,
117 radius,
118 left: Some(Box::new(left_node)),
119 right: Some(Box::new(right_node)),
120 indices: Vec::new(),
121 })
122 }
123
124 fn compute_centroid(&self, points: &[Vec<f32>], indices: &[usize]) -> Vec<f32> {
125 let dim = points[0].len();
126 let mut centroid = vec![0.0; dim];
127
128 for &idx in indices {
129 for (i, &val) in points[idx].iter().enumerate() {
130 centroid[i] += val;
131 }
132 }
133
134 let n = indices.len() as f32;
135 for val in &mut centroid {
136 *val /= n;
137 }
138
139 centroid
140 }
141
142 fn compute_radius(&self, points: &[Vec<f32>], indices: &[usize], center: &[f32]) -> f32 {
143 indices
144 .iter()
145 .map(|&idx| self.config.distance_metric.distance(&points[idx], center))
146 .fold(0.0f32, f32::max)
147 }
148
149 fn find_split_dimension(&self, points: &[Vec<f32>], indices: &[usize]) -> usize {
150 let dim = points[0].len();
151 let mut max_spread = 0.0;
152 let mut split_dim = 0;
153
154 #[allow(clippy::needless_range_loop)]
156 for d in 0..dim {
157 let values: Vec<f32> = indices.iter().map(|&idx| points[idx][d]).collect();
158
159 let min_val = values.iter().fold(f32::INFINITY, |a, &b| a.min(b));
160 let max_val = values.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
161 let spread = max_val - min_val;
162
163 if spread > max_spread {
164 max_spread = spread;
165 split_dim = d;
166 }
167 }
168
169 split_dim
170 }
171
172 fn partition_indices(
173 &self,
174 points: &[Vec<f32>],
175 indices: &[usize],
176 dim: usize,
177 ) -> (Vec<usize>, Vec<usize>) {
178 let mut values: Vec<(f32, usize)> =
179 indices.iter().map(|&idx| (points[idx][dim], idx)).collect();
180
181 values.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(Ordering::Equal));
182
183 let mid = values.len() / 2;
184 let left_indices: Vec<usize> = values[..mid].iter().map(|(_, idx)| *idx).collect();
185 let right_indices: Vec<usize> = values[mid..].iter().map(|(_, idx)| *idx).collect();
186
187 (left_indices, right_indices)
188 }
189
190 fn compute_centroid_of_centers(&self, centers: &[Vec<f32>]) -> Vec<f32> {
191 let dim = centers[0].len();
192 let mut centroid = vec![0.0; dim];
193
194 for center in centers {
195 for (i, &val) in center.iter().enumerate() {
196 centroid[i] += val;
197 }
198 }
199
200 let n = centers.len() as f32;
201 for val in &mut centroid {
202 *val /= n;
203 }
204
205 centroid
206 }
207
208 pub fn search(&self, query: &[f32], k: usize) -> Vec<(usize, f32)> {
210 if self.root.is_none() {
211 return Vec::new();
212 }
213
214 let mut heap: BinaryHeap<SearchResult> = BinaryHeap::new();
215 let mut stack: Vec<&BallNode> = vec![self
216 .root
217 .as_ref()
218 .expect("tree should have root after build")];
219
220 while let Some(node) = stack.pop() {
221 let dist_to_center = self.config.distance_metric.distance(query, &node.center);
223
224 if heap.len() >= k {
225 let worst_dist = heap.peek().expect("heap should have k elements").distance;
226 if dist_to_center - node.radius > worst_dist {
227 continue; }
229 }
230
231 if node.indices.is_empty() {
232 if let (Some(left), Some(right)) = (&node.left, &node.right) {
234 let left_dist = self.config.distance_metric.distance(query, &left.center);
235 let right_dist = self.config.distance_metric.distance(query, &right.center);
236
237 if left_dist < right_dist {
239 stack.push(right);
240 stack.push(left);
241 } else {
242 stack.push(left);
243 stack.push(right);
244 }
245 }
246 } else {
247 for &idx in &node.indices {
249 let point = &self.data[idx].1.as_f32();
250 let dist = self.config.distance_metric.distance(query, point);
251
252 if heap.len() < k {
253 heap.push(SearchResult {
254 index: idx,
255 distance: dist,
256 });
257 } else if dist < heap.peek().expect("heap should have k elements").distance {
258 heap.pop();
259 heap.push(SearchResult {
260 index: idx,
261 distance: dist,
262 });
263 }
264 }
265 }
266 }
267
268 let mut results: Vec<(usize, f32)> =
269 heap.into_iter().map(|r| (r.index, r.distance)).collect();
270
271 results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
272 results
273 }
274}