1#![allow(clippy::type_complexity)]
2use std::{cmp::Reverse, collections::BinaryHeap};
3
4use linfa::Float;
5use ndarray::{Array1, ArrayBase, Data, Ix2};
6use noisy_float::{checkers::FiniteChecker, NoisyFloat};
7#[cfg(feature = "serde")]
8use serde_crate::{Deserialize, Serialize};
9
10use crate::{
11 distance::Distance,
12 heap_elem::{MaxHeapElem, MinHeapElem},
13 BuildError, NearestNeighbour, NearestNeighbourBox, NearestNeighbourIndex, NnError, Point,
14};
15
16fn partition<F: Float>(
18 mut points: Vec<(Point<F>, usize)>,
19) -> (Vec<(Point<F>, usize)>, Point<F>, Vec<(Point<F>, usize)>) {
20 debug_assert!(points.len() >= 2);
21
22 let max_spread_dim = (0..points[0].0.len())
25 .map(|dim| {
26 let (max, min) = points
28 .iter()
29 .map(|p| p.0[dim])
30 .fold((F::neg_infinity(), F::infinity()), |(a, b), c| {
31 (F::max(a, c), F::min(b, c))
32 });
33
34 (dim, NoisyFloat::<_, FiniteChecker>::new(max - min))
35 })
36 .max_by_key(|&(_, range)| range)
37 .expect("vec has no dimensions")
38 .0;
39
40 let mid = points.len() / 2;
41 let median = order_stat::kth_by(&mut points, mid, |p1, p2| {
43 p1.0[max_spread_dim]
44 .partial_cmp(&p2.0[max_spread_dim])
45 .expect("NaN in data")
46 })
47 .0
48 .reborrow();
49
50 let (mut left, mut right): (Vec<_>, Vec<_>) = points
51 .into_iter()
52 .partition(|pt| pt.0[max_spread_dim] < median[max_spread_dim]);
53 if left.is_empty() {
57 left.push(right.pop().unwrap());
58 }
59 (left, median, right)
60}
61
62fn calc_radius<'a, F: Float, D: Distance<F>>(
64 points: impl Iterator<Item = Point<'a, F>>,
65 center: Point<F>,
66 dist_fn: &D,
67) -> F {
68 let r_rad = points
69 .map(|pt| NoisyFloat::<_, FiniteChecker>::new(dist_fn.rdistance(pt, center)))
70 .max()
71 .unwrap()
72 .raw();
73 dist_fn.rdist_to_dist(r_rad)
74}
75
/// Node of the ball tree. Every node carries a ball described by `center` and `radius`.
#[derive(Debug, PartialEq, Clone)]
enum BallTreeInner<'a, F: Float> {
    /// Terminal node that stores the points themselves.
    Leaf {
        // Owned array: `BallTreeInner::new` computes it as the mean of `points` instead of
        // borrowing one of them (it is zero-length for a leaf without points).
        center: Array1<F>,
        // Result of `calc_radius` for `points` around `center` (zero for a leaf without points).
        radius: F,
        // Views of the stored points, each paired with the `usize` tag supplied at construction
        // (`BallTreeIndex::new` uses the row number of the input batch).
        points: Vec<(Point<'a, F>, usize)>,
    },
    /// Inner node with two children.
    Branch {
        // View of the median point returned by `partition`.
        center: Point<'a, F>,
        // Result of `calc_radius` for all points of both children around `center`.
        radius: F,
        // Subtree built from the first group returned by `partition`.
        left: Box<BallTreeInner<'a, F>>,
        // Subtree built from the second group returned by `partition`.
        right: Box<BallTreeInner<'a, F>>,
    },
}
92
93impl<'a, F: Float> BallTreeInner<'a, F> {
94 fn new<D: Distance<F>>(
95 points: Vec<(Point<'a, F>, usize)>,
96 leaf_size: usize,
97 dist_fn: &D,
98 ) -> Self {
99 if points.len() <= leaf_size {
100 if let Some(dim) = points.first().map(|p| p.0.len()) {
102 let center = {
105 let mut c = Array1::zeros(dim);
106 points.iter().for_each(|p| c += &p.0);
107 c / F::from(points.len()).unwrap()
108 };
109 let radius = calc_radius(
110 points.iter().map(|p| p.0.reborrow()),
111 center.view(),
112 dist_fn,
113 );
114 BallTreeInner::Leaf {
115 center,
116 radius,
117 points,
118 }
119 } else {
120 BallTreeInner::Leaf {
122 center: Array1::zeros(0),
123 points,
124 radius: F::zero(),
125 }
126 }
127 } else {
128 let (aps, center, bps) = partition(points);
130 debug_assert!(!aps.is_empty() && !bps.is_empty());
131 let radius = calc_radius(
132 aps.iter().chain(bps.iter()).map(|p| p.0.reborrow()),
133 center,
134 dist_fn,
135 );
136 let a_tree = BallTreeInner::new(aps, leaf_size, dist_fn);
137 let b_tree = BallTreeInner::new(bps, leaf_size, dist_fn);
138 BallTreeInner::Branch {
139 center,
140 radius,
141 left: Box::new(a_tree),
142 right: Box::new(b_tree),
143 }
144 }
145 }
146
147 fn rdistance<D: Distance<F>>(&self, p: Point<F>, dist_fn: &D) -> F {
148 let (center, radius) = match self {
149 BallTreeInner::Leaf { center, radius, .. } => (center.view(), radius),
150 BallTreeInner::Branch { center, radius, .. } => (center.reborrow(), radius),
151 };
152
153 let border_dist = dist_fn.distance(p, center.reborrow()) - *radius;
157 dist_fn.dist_to_rdist(border_dist.max(F::zero()))
158 }
159}
160
/// Nearest-neighbour index backed by a ball tree; created with [`BallTreeIndex::new`].
#[derive(Debug, Clone, PartialEq)]
pub struct BallTreeIndex<'a, F: Float, D: Distance<F>> {
    // Root node of the tree; its point views borrow the input batch for `'a`.
    tree: BallTreeInner<'a, F>,
    // Distance function used both while building the tree and while answering queries.
    dist_fn: D,
    // Number of columns of the batch the index was built from; query points must match it.
    dim: usize,
    // Number of rows of the batch the index was built from.
    len: usize,
}
169
170impl<'a, F: Float, D: Distance<F>> BallTreeIndex<'a, F, D> {
171 pub fn new<DT: Data<Elem = F>>(
173 batch: &'a ArrayBase<DT, Ix2>,
174 leaf_size: usize,
175 dist_fn: D,
176 ) -> Result<Self, BuildError> {
177 let dim = batch.ncols();
178 let len = batch.nrows();
179 if leaf_size == 0 {
180 Err(BuildError::EmptyLeaf)
181 } else if dim == 0 {
182 Err(BuildError::ZeroDimension)
183 } else {
184 let points: Vec<_> = batch
185 .rows()
186 .into_iter()
187 .enumerate()
188 .map(|(i, pt)| (pt, i))
189 .collect();
190 Ok(BallTreeIndex {
191 tree: BallTreeInner::new(points, leaf_size, &dist_fn),
192 dist_fn,
193 dim,
194 len,
195 })
196 }
197 }
198
199 fn nn_helper(
200 &self,
201 point: Point<'_, F>,
202 k: usize,
203 max_radius: F,
204 ) -> Result<Vec<(Point<'_, F>, usize)>, NnError> {
205 if self.dim != point.len() {
206 Err(NnError::WrongDimension)
207 } else if self.len == 0 {
208 Ok(Vec::new())
209 } else {
210 let mut out: BinaryHeap<MaxHeapElem<_, _>> = BinaryHeap::new();
211 let mut queue = BinaryHeap::new();
212 queue.push(MinHeapElem::new(
213 self.tree.rdistance(point, &self.dist_fn),
214 &self.tree,
215 ));
216
217 while let Some(MinHeapElem {
218 dist: Reverse(dist),
219 elem,
220 }) = queue.pop()
221 {
222 if dist >= max_radius || (out.len() == k && dist >= out.peek().unwrap().dist) {
223 break;
224 }
225
226 match elem {
227 BallTreeInner::Leaf { points, .. } => {
228 for p in points {
229 let dist = self.dist_fn.rdistance(point, p.0.reborrow());
230 if dist < max_radius
231 && (out.len() < k || out.peek().unwrap().dist > dist)
232 {
233 out.push(MaxHeapElem::new(dist, p));
234 if out.len() > k {
235 out.pop();
236 }
237 }
238 }
239 }
240 BallTreeInner::Branch { left, right, .. } => {
241 let dl = left.rdistance(point, &self.dist_fn);
242 let dr = right.rdistance(point, &self.dist_fn);
243
244 if dl <= max_radius {
245 queue.push(MinHeapElem::new(dl, left));
246 }
247 if dr <= max_radius {
248 queue.push(MinHeapElem::new(dr, right));
249 }
250 }
251 }
252 }
253 Ok(out
254 .into_sorted_vec()
255 .into_iter()
256 .map(|e| e.elem)
257 .map(|(pt, i)| (pt.reborrow(), *i))
258 .collect())
259 }
260 }
261}
262
263impl<F: Float, D: Distance<F>> NearestNeighbourIndex<F> for BallTreeIndex<'_, F, D> {
264 fn k_nearest(
265 &self,
266 point: Point<'_, F>,
267 k: usize,
268 ) -> Result<Vec<(Point<'_, F>, usize)>, NnError> {
269 self.nn_helper(point, k, F::infinity())
270 }
271
272 fn within_range(
273 &self,
274 point: Point<'_, F>,
275 range: F,
276 ) -> Result<Vec<(Point<'_, F>, usize)>, NnError> {
277 let range = self.dist_fn.dist_to_rdist(range);
278 self.nn_helper(point, self.len, range)
279 }
280}
281
/// Stateless handle for the ball tree implementation.
///
/// Its `NearestNeighbour` implementation builds a [`BallTreeIndex`] from a batch of points.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct BallTree;

impl BallTree {
    /// Creates a new `BallTree`; equivalent to `BallTree::default()`.
    pub fn new() -> Self {
        Self::default()
    }
}
303
304impl NearestNeighbour for BallTree {
305 fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
306 &self,
307 batch: &'a ArrayBase<DT, Ix2>,
308 leaf_size: usize,
309 dist_fn: D,
310 ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
311 BallTreeIndex::new(batch, leaf_size, dist_fn).map(|v| Box::new(v) as NearestNeighbourBox<F>)
312 }
313}
314
#[cfg(test)]
mod test {
    use approx::assert_abs_diff_eq;
    use ndarray::{arr1, arr2, stack, Array1, Array2, Axis};

    use crate::distance::L2Dist;

    use super::*;

    // Compile-time check that the public and internal types keep their auto traits.
    #[test]
    fn autotraits() {
        fn has_autotraits<T>()
        where
            T: Send + Sync + Sized + Unpin,
        {
        }
        has_autotraits::<BallTree>();
        has_autotraits::<BallTreeIndex<f64, L2Dist>>();
        has_autotraits::<BallTreeInner<f64>>();
    }

    // Partitions the rows of `input` and compares both groups, the median point and the radius
    // of all rows around that median (under `L2Dist`) with the expected values.
    fn assert_partition(
        input: Array2<f64>,
        exp_left: Array2<f64>,
        exp_med: Array1<f64>,
        exp_right: Array2<f64>,
        exp_rad: f64,
    ) {
        let mut indexed = Vec::with_capacity(input.nrows());
        for (idx, row) in input.rows().into_iter().enumerate() {
            indexed.push((row, idx));
        }

        let (left, median, right) = partition(indexed.clone());
        let left_rows: Vec<_> = left.into_iter().map(|pair| pair.0).collect();
        let right_rows: Vec<_> = right.into_iter().map(|pair| pair.0).collect();

        let stacked_left = stack(Axis(0), &left_rows).unwrap();
        assert_abs_diff_eq!(stacked_left, exp_left);
        assert_abs_diff_eq!(median.to_owned(), exp_med);
        let stacked_right = stack(Axis(0), &right_rows).unwrap();
        assert_abs_diff_eq!(stacked_right, exp_right);

        let radius = calc_radius(
            indexed.iter().map(|(row, _)| row.reborrow()),
            median,
            &L2Dist,
        );
        assert_abs_diff_eq!(radius, exp_rad);
    }

    #[test]
    fn partition_test() {
        // (input, expected left, expected median, expected right, expected radius)
        let cases = vec![
            // Two points, already ordered.
            (
                arr2(&[[0.0, 1.0], [2.0, 3.0]]),
                arr2(&[[0.0, 1.0]]),
                arr1(&[2.0, 3.0]),
                arr2(&[[2.0, 3.0]]),
                8.0f64.sqrt(),
            ),
            // The same two points in reverse order.
            (
                arr2(&[[2.0, 3.0], [0.0, 1.0]]),
                arr2(&[[0.0, 1.0]]),
                arr1(&[2.0, 3.0]),
                arr2(&[[2.0, 3.0]]),
                8.0f64.sqrt(),
            ),
            // Three points.
            (
                arr2(&[[0.3, 5.0], [4.5, 7.0], [8.1, 1.5]]),
                arr2(&[[0.3, 5.0]]),
                arr1(&[4.5, 7.0]),
                arr2(&[[4.5, 7.0], [8.1, 1.5]]),
                43.21f64.sqrt(),
            ),
            // Four identical points: one is moved over so the left group is not empty.
            (
                arr2(&[[1.4, 4.3], [1.4, 4.3], [1.4, 4.3], [1.4, 4.3]]),
                arr2(&[[1.4, 4.3]]),
                arr1(&[1.4, 4.3]),
                arr2(&[[1.4, 4.3], [1.4, 4.3], [1.4, 4.3]]),
                0.0,
            ),
        ];

        for (input, exp_left, exp_med, exp_right, exp_rad) in cases {
            assert_partition(input, exp_left, exp_med, exp_right, exp_rad);
        }
    }
}