ball_tree/
lib.rs

1use std::cmp::Ordering;
2use std::collections::BinaryHeap;
3
/// A location in some metric space.
///
/// A `Point` knows how far it is from another `Point` and can produce a new
/// `Point` that lies a given distance along the way towards another `Point`.
pub trait Point: Sized + PartialEq {
    /// Returns the distance between `self` and `other`.
    ///
    /// Distances should be positive, finite `f64`s. It is undefined behavior to
    /// return a negative, infinite, or `NaN` result.
    ///
    /// Distance should satisfy the triangle inequality. That is, `a.distance(c)`
    /// must be less than or equal to `a.distance(b) + b.distance(c)`.
    fn distance(&self, other: &Self) -> f64;

    /// Returns the point reached by travelling `d` from `self` towards `other`.
    ///
    /// If `d` is `0`, a point equal to `self` should be returned. If `d` is equal
    /// to `self.distance(other)`, a point equal to `other` should be returned.
    /// Distances in between should be linearly interpolated, so a `d` of
    /// `self.distance(other) / 2.0` yields the midpoint.
    ///
    /// It is undefined behavior to use a distance that is negative, `NaN`, or
    /// greater than `self.distance(other)`.
    fn move_towards(&self, other: &Self, d: f64) -> Self;

    /// Returns the point that is equidistant from `a` and `b`.
    ///
    /// The default implementation is `a.move_towards(b, a.distance(b) / 2.0)`;
    /// implementors may override it when a cheaper computation exists.
    fn midpoint(a: &Self, b: &Self) -> Self {
        a.move_towards(b, a.distance(b) / 2.0)
    }
}
32
33/// Implement `Point` in the normal `D` dimensional Euclidean way for all arrays of floats. For example, a 2D point
34/// would be a `[f64; 2]`.
35impl<const D: usize> Point for [f64; D] {
36    fn distance(&self, other: &Self) -> f64 {
37        self.iter()
38            .zip(other)
39            .map(|(a, b)| (*a - *b).powi(2))
40            .sum::<f64>()
41            .sqrt()
42    }
43
44    fn move_towards(&self, other: &Self, d: f64) -> Self {
45        let mut result = self.clone();
46
47        let distance = self.distance(other);
48
49        // Don't want to get a NaN in the division below
50        if distance == 0.0 {
51            return result;
52        }
53
54        let scale = d / self.distance(other);
55
56        for i in 0..D {
57            result[i] += scale * (other[i] - self[i]);
58        }
59
60        result
61    }
62
63    fn midpoint(a: &Self, b: &Self) -> Self {
64        let mut result = [0.0; D];
65        for i in 0..D {
66            result[i] = (a[i] + b[i]) / 2.0;
67        }
68        result
69    }
70}
71
72/// Implement `Point` in the normal `D` dimensional Euclidean way for all arrays of floats. For example, a 2D point
73/// would be a `[f32; 2]`.
74impl<const D: usize> Point for [f32; D] {
75    fn distance(&self, other: &Self) -> f64 {
76        self.iter()
77            .zip(other)
78            .map(|(a, b)| (*a - *b).powi(2))
79            .sum::<f32>()
80            .sqrt() as f64
81    }
82
83    fn move_towards(&self, other: &Self, d: f64) -> Self {
84        let mut result = self.clone();
85
86        let distance = self.distance(other);
87
88        // Don't want to get a NaN in the division below
89        if distance == 0.0 {
90            return result;
91        }
92
93        let scale = d / self.distance(other);
94        let scale = scale as f32;
95
96        for i in 0..D {
97            result[i] += scale * (other[i] - self[i]);
98        }
99
100        result
101    }
102
103    fn midpoint(a: &Self, b: &Self) -> Self {
104        let mut result = [0.0; D];
105        for i in 0..D {
106            result[i] = (a[i] + b[i]) / 2.0;
107        }
108        result
109    }
110}
111
// A little helper to allow us to use comparative functions on `f64`s by asserting that
// `NaN` isn't present.
//
// `Ord` is the single source of truth for ordering and `PartialOrd` delegates to it, so
// the two can never disagree.
#[derive(Debug, Clone, PartialEq)]
struct OrdF64(f64);

impl OrdF64 {
    // Wraps `x`, panicking if it is `NaN`.
    fn new(x: f64) -> Self {
        assert!(!x.is_nan());
        OrdF64(x)
    }
}

impl Eq for OrdF64 {}

impl PartialOrd for OrdF64 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF64 {
    fn cmp(&self, other: &Self) -> Ordering {
        // `new` rejects NaN, so the partial comparison always yields an ordering.
        self.0
            .partial_cmp(&other.0)
            .expect("OrdF64 must never hold NaN")
    }
}
128
129#[derive(Debug, Copy, Clone, PartialEq)]
130struct Sphere<C> {
131    center: C,
132    radius: f64,
133}
134
135impl<C: Point> Sphere<C> {
136    fn nearest_distance(&self, p: &C) -> f64 {
137        let d = self.center.distance(p) - self.radius;
138        d.max(0.0)
139    }
140
141    fn farthest_distance(&self, p: &C) -> f64 {
142        self.center.distance(p) + self.radius
143    }
144}
145
146// Implementation of the "bouncing bubble" algorithm which essentially works like this:
147// * Pick a point `a` that is farthest from `points[0]`
148// * Pick a point `b` that is farthest from `a`
149// * Use these two points to create an initial sphere centered at their midpoint and with
150//   enough radius to encompass them
151// * While there is still a point outside of this sphere, move the sphere towards that
152//   point just enough to encompass that point, and grow the sphere radius by 1%
153//
154// This process will produce a non-optimal, but relatively snug fitting bounding sphere.
155
156fn bounding_sphere<P: Point>(points: &[P]) -> Sphere<P> {
157    assert!(points.len() >= 2);
158
159    let a = &points
160        .iter()
161        .max_by_key(|a| OrdF64::new(points[0].distance(a)))
162        .unwrap();
163    let b = &points
164        .iter()
165        .max_by_key(|b| OrdF64::new(a.distance(b)))
166        .unwrap();
167
168    let mut center: P = P::midpoint(a, b);
169    let mut radius = center.distance(b).max(std::f64::EPSILON);
170
171    loop {
172        match points.iter().filter(|p| center.distance(p) > radius).next() {
173            None => break Sphere { center, radius },
174            Some(p) => {
175                let c_to_p = center.distance(&p);
176                let d = c_to_p - radius;
177                center = center.move_towards(p, d);
178                radius = radius * 1.01;
179            }
180        }
181    }
182}
183
184// Produce a partition of the given points with the following process:
185// * Pick a point `a` that is farthest from `points[0]`
186// * Pick a point `b` that is farthest from `a`
187// * Partition the points into two groups: those closest to `a` and those closest to `b`
188//
189// This doesn't necessarily form the best partition, since `a` and `b` are not guaranteed
190// to be the most distant pair of points, but it's usually sufficient.
191fn partition<P: Point, V>(
192    mut points: Vec<P>,
193    mut values: Vec<V>,
194) -> ((Vec<P>, Vec<V>), (Vec<P>, Vec<V>)) {
195    assert!(points.len() >= 2);
196    assert_eq!(points.len(), values.len());
197
198    let a_i = points
199        .iter()
200        .enumerate()
201        .max_by_key(|(_, a)| OrdF64::new(points[0].distance(a)))
202        .unwrap()
203        .0;
204
205    let b_i = points
206        .iter()
207        .enumerate()
208        .max_by_key(|(_, b)| OrdF64::new(points[a_i].distance(b)))
209        .unwrap()
210        .0;
211
212    let (a_i, b_i) = (a_i.max(b_i), a_i.min(b_i));
213
214    let (mut aps, mut avs) = (vec![points.swap_remove(a_i)], vec![values.swap_remove(a_i)]);
215    let (mut bps, mut bvs) = (vec![points.swap_remove(b_i)], vec![values.swap_remove(b_i)]);
216
217    for (p, v) in points.into_iter().zip(values) {
218        if aps[0].distance(&p) < bps[0].distance(&p) {
219            aps.push(p);
220            avs.push(v);
221        } else {
222            bps.push(p);
223            bvs.push(v);
224        }
225    }
226
227    ((aps, avs), (bps, bvs))
228}
229
230#[derive(Debug, Clone)]
231enum BallTreeInner<P, V> {
232    Empty,
233    Leaf(P, Vec<V>),
234    // The sphere is a bounding sphere that encompasses this node (both children)
235    Branch {
236        sphere: Sphere<P>,
237        a: Box<BallTreeInner<P, V>>,
238        b: Box<BallTreeInner<P, V>>,
239        count: usize,
240    },
241}
242
243impl<P: Point, V> Default for BallTreeInner<P, V> {
244    fn default() -> Self {
245        BallTreeInner::Empty
246    }
247}
248
249impl<P: Point, V> BallTreeInner<P, V> {
250    fn new(mut points: Vec<P>, values: Vec<V>) -> Self {
251        assert_eq!(
252            points.len(),
253            values.len(),
254            "Given two vectors of differing lengths. points: {}, values: {}",
255            points.len(),
256            values.len()
257        );
258
259        if points.is_empty() {
260            BallTreeInner::Empty
261        } else if points.iter().all(|p| p == &points[0]) {
262            BallTreeInner::Leaf(points.pop().unwrap(), values)
263        } else {
264            let count = points.len();
265            let sphere = bounding_sphere(&points);
266            let ((aps, avs), (bps, bvs)) = partition(points, values);
267            let (a_tree, b_tree) = (BallTreeInner::new(aps, avs), BallTreeInner::new(bps, bvs));
268            BallTreeInner::Branch { sphere, a: Box::new(a_tree), b: Box::new(b_tree), count }
269        }
270    }
271
272    fn nearest_distance(&self, p: &P) -> f64 {
273        match self {
274            BallTreeInner::Empty => std::f64::INFINITY,
275            // The distance to a leaf is the distance to the single point inside of it
276            BallTreeInner::Leaf(p0, _) => p.distance(p0),
277            // The distance to a branch is the distance to the edge of the bounding sphere
278            BallTreeInner::Branch { sphere, .. } => sphere.nearest_distance(p),
279        }
280    }
281}
282
// Priority-queue entry: a distance paired with an arbitrary payload.
//
// A dedicated wrapper is needed for two reasons:
// * Rust's BinaryHeap is a max-heap and we need a min-heap, so the ordering is
//   inverted: the entry with the smaller distance compares as the greater one.
// * Only the distance takes part in comparisons, so the payload does not have to be
//   comparable at all (which deriving the traits would require).
#[derive(Debug, Copy, Clone)]
struct Item<T>(f64, T);

impl<T> PartialEq for Item<T> {
    fn eq(&self, other: &Self) -> bool {
        let (Item(lhs, _), Item(rhs, _)) = (self, other);
        lhs == rhs
    }
}

impl<T> Eq for Item<T> {}

impl<T> PartialOrd for Item<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Comparing with the operands swapped yields the reversed ordering.
        other.0.partial_cmp(&self.0)
    }
}

impl<T> Ord for Item<T> {
    // Panics if either distance is `NaN`.
    fn cmp(&self, other: &Self) -> Ordering {
        self.partial_cmp(other).unwrap()
    }
}
309
310/// Iterator over the nearest neighbors.
311// Maintain a priority queue of the nodes that are closest to the provided `point`. If we
312// pop a leaf from the queue, that leaf is necessarily the next closest point. If we
313// pop a branch from the queue, add its children. The priority of a node is its
314// `distance` as defined above.
315#[derive(Debug)]
316pub struct Iter<'tree, 'query, P, V> {
317    point: &'query P,
318    balls: &'query mut BinaryHeap<Item<&'tree BallTreeInner<P, V>>>,
319    i: usize,
320    max_radius: f64,
321}
322
323impl<'tree, 'query, P: Point, V> Iterator for Iter<'tree, 'query, P, V> {
324    type Item = (&'tree P, f64, &'tree V);
325
326    fn next(&mut self) -> Option<Self::Item> {
327        while self.balls.len() > 0 {
328            // Peek in the leaf case, because we might need to visit this leaf multiple
329            // times (if it has multiple values).
330            if let Item(d, BallTreeInner::Leaf(p, vs)) = self.balls.peek().unwrap() {
331                if self.i < vs.len() && *d <= self.max_radius {
332                    self.i += 1;
333                    return Some((p, *d, &vs[self.i - 1]));
334                }
335            }
336            // Reset index for the next leaf we encounter
337            self.i = 0;
338            // Expand branch nodes
339            if let Item(_, BallTreeInner::Branch { a, b, .. }) = self.balls.pop().unwrap() {
340                let d_a = a.nearest_distance(self.point);
341                let d_b = b.nearest_distance(self.point);
342                if d_a <= self.max_radius {
343                    self.balls.push(Item(d_a, a));
344                }
345                if d_b <= self.max_radius {
346                    self.balls.push(Item(d_b, b));
347                }
348            }
349        }
350        None
351    }
352}
353
354/// A `BallTree` is a space-partitioning data-structure that allows for finding
355/// nearest neighbors in logarithmic time.
356///
357/// It does this by partitioning data into a series of nested bounding spheres
358/// ("balls" in the literature). Spheres are used because it is trivial to
359/// compute the distance between a point and a sphere (distance to the sphere's
360/// center minus thte radius). The key observation is that a potential neighbor
361/// is necessarily closer than all neighbors that are located inside of a
362/// bounding sphere that is farther than the aforementioned neighbor.
363///
364/// Graphically:
365/// ```text
366///
367///    A -
368///    |  ----         distance(A, B) = 4
369///    |      - B      distance(A, S) = 6
370///     |
371///      |
372///      |    S
373///        --------
374///      /        G \
375///     /   C        \
376///    |           D |
377///    |       F     |
378///     \ E         /
379///      \_________/
380///```
381///
382/// In the diagram, `A` is closer to `B` than to `S`, and because `S` bounds
383/// `C`, `D`, `E`, `F`, and `G`, it can be determined that `A` it is necessarily
384/// closer to `B` than the other points without even computing exact distances
385/// to them.
386///
387/// Ball trees are most commonly used as a form of predictive model where the
388/// points are features and each point is associated with a value or label. Thus,
389/// This implementation allows the user to associate a value with each point. If
390/// this functionality is unneeded, `()` can be used as a value.
391///
392/// This implementation returns the nearest neighbors, their distances, and their
393/// associated values. Returning the distances allows the user to perform some
394/// sort of weighted interpolation of the neighbors for predictive purposes.
395#[derive(Debug, Clone)]
396pub struct BallTree<P, V>(BallTreeInner<P, V>);
397
398impl<P: Point, V> Default for BallTree<P, V> {
399    fn default() -> Self {
400        BallTree(BallTreeInner::default())
401    }
402}
403
404impl<P: Point, V> BallTree<P, V> {
405    /// Construct this `BallTree`. Construction is somewhat expensive, so `BallTree`s
406    /// are best constructed once and then used repeatedly.
407    ///
408    /// `panic` if `points.len() != values.len()`
409    pub fn new(points: Vec<P>, values: Vec<V>) -> Self {
410        BallTree(BallTreeInner::new(points, values))
411    }
412
413    /// Query this `BallTree`. The `Query` object provides a nearest-neighbor API and internally re-uses memory to avoid
414    /// allocations on repeated queries.
415    pub fn query(&self) -> Query<'_, P, V> {
416        Query {
417            ball_tree: self,
418            balls: Default::default(),
419        }
420    }
421}
422
423/// A context for repeated nearest-neighbor queries that internally re-uses memory across queries.
424#[derive(Debug, Clone)]
425pub struct Query<'tree, P, V> {
426    ball_tree: &'tree BallTree<P, V>,
427    balls: BinaryHeap<Item<&'tree BallTreeInner<P, V>>>,
428}
429
430impl<'tree, P: Point, V> Query<'tree, P, V> {
431    /// Given a `point`, return an `Iterator` that yields neighbors from closest to
432    /// farthest. To get the K nearest neighbors, simply `take` K from the iterator.
433    ///
434    /// The neighbor, its distance, and associated value are returned.
435    pub fn nn<'query>(
436        &'query mut self,
437        point: &'query P,
438    ) -> Iter<'tree, 'query, P, V> {
439        self.nn_within(point, f64::INFINITY)
440    }
441
442    /// The same as `nn` but only consider neighbors whose distance is `<= max_radius`.
443    pub fn nn_within<'query>(
444        &'query mut self,
445        point: &'query P,
446        max_radius: f64,
447    ) -> Iter<'tree, 'query, P, V> {
448        let balls = &mut self.balls;
449        balls.clear();
450        balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
451        Iter {
452            point,
453            balls,
454            i: 0,
455            max_radius,
456        }
457    }
458
459    /// What is the minimum radius that encompasses `k` neighbors of `point`?
460    pub fn min_radius<'query>(&'query mut self, point: &'query P, k: usize) -> f64 {
461        let mut total_count = 0;
462        let balls = &mut self.balls;
463        balls.clear(); 
464        balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
465    
466        while let Some(Item(distance, node)) = balls.pop() {
467            match node {
468                BallTreeInner::Empty => {}
469                BallTreeInner::Leaf(_, vs) => {
470                    total_count += vs.len();
471                    if total_count >= k {
472                        return distance;
473                    }
474                }
475                BallTreeInner::Branch { sphere, a, b, count } => {
476                    let next_distance = balls.peek().map(|Item(d, _)| *d).unwrap_or(f64::INFINITY);
477                    if total_count + count < k && sphere.farthest_distance(point) < next_distance {
478                        total_count += count;
479                    } else {
480                        balls.push(Item(a.nearest_distance(point), &a));
481                        balls.push(Item(b.nearest_distance(point), &b));
482                    }
483                }
484            }
485        }
486    
487        f64::INFINITY
488    }
489
490    /// How many neighbors are `<= max_radius` of `point`?
491    pub fn count<'query>(&'query mut self, point: &'query P, max_radius: f64) -> usize {
492        let mut total = 0;
493        let balls = &mut self.balls;
494        balls.clear();
495        balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
496
497        while let Some(Item(nearest_distance, node)) = balls.pop() {
498            if nearest_distance > max_radius {
499                break;
500            }
501            match node {
502                BallTreeInner::Empty => {}
503                BallTreeInner::Leaf(_, vs) => {
504                    total += vs.len();
505                }
506                BallTreeInner::Branch { a, b, count, sphere} => {
507                    let next_distance = balls.peek().map(|Item(d, _)| *d).unwrap_or(f64::INFINITY).min(max_radius);
508                    if sphere.farthest_distance(point) < next_distance {
509                        total += count;
510                    } else {
511                        balls.push(Item(a.nearest_distance(point), &a));
512                        balls.push(Item(b.nearest_distance(point), &b));
513                    }
514                }
515            }
516        }
517    
518        total
519    }
520
521    /// Return the size in bytes of the memory this `Query` is keeping internally to avoid allocation.
522    pub fn allocated_size(&self) -> usize {
523        self.balls.capacity() * std::mem::size_of::<Item<&'tree BallTreeInner<P, V>>>()
524    }
525
526    /// The `Query` object re-uses memory internally to avoid allocation. This method deallocates that memory.
527    pub fn deallocate_memory(&mut self) {
528        self.balls.clear();
529        self.balls.shrink_to_fit();
530    }
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaChaRng;
    use std::collections::HashSet;

    // Builds many random trees of 3D points and checks every query method against a
    // brute-force scan over the same points.
    #[test]
    fn test_3d_points() {
        let mut rng: ChaChaRng = SeedableRng::seed_from_u64(0xcb42c94d23346e96);

        macro_rules! random_small_f64 {
            () => {
                rng.gen_range(-100.0..=100.0)
            };
        }

        macro_rules! random_3d_point {
            () => {
                [
                    random_small_f64!(),
                    random_small_f64!(),
                    random_small_f64!(),
                ]
            };
        }

        for round in 0..1000 {
            // Early rounds use tiny trees, later rounds progressively larger ones.
            let largest_tree = match round {
                0..=99 => 3,
                100..=499 => 10,
                _ => 100,
            };
            let point_count: usize = rng.gen_range(1..=largest_tree);

            let mut points: Vec<[f64; 3]> = Vec::with_capacity(point_count);
            let mut values: Vec<u64> = Vec::with_capacity(point_count);
            for _ in 0..point_count {
                points.push(random_3d_point!());
                values.push(rng.gen::<u64>());
            }

            let tree = BallTree::new(points.clone(), values.clone());
            let mut query = tree.query();

            for _ in 0..100 {
                let point: [f64; 3] = random_3d_point!();
                let max_radius: f64 = rng.gen_range(0.0..=110.0);

                // Brute-force reference: the values of every point within range.
                let expected_values: HashSet<u64> = points
                    .iter()
                    .zip(&values)
                    .filter(|(p, _)| p.distance(&point) <= max_radius)
                    .map(|(_, v)| *v)
                    .collect();

                let mut found_values = HashSet::new();
                let mut previous_d = 0.0;
                for (p, d, v) in query.nn_within(&point, max_radius) {
                    // Reported distances are exact, non-decreasing, and within range.
                    assert_eq!(point.distance(p), d);
                    assert!(d >= previous_d);
                    assert!(d <= max_radius);
                    previous_d = d;
                    found_values.insert(*v);
                }

                assert_eq!(expected_values, found_values);

                assert_eq!(found_values.len(), query.count(&point, max_radius));

                // Shrinking the minimal radius slightly must lose at least one neighbor.
                let radius = query.min_radius(&point, expected_values.len());
                let should_be_fewer = query.count(&point, radius * 0.99);

                assert!(expected_values.is_empty() || should_be_fewer < expected_values.len(), "{} < {}", should_be_fewer, expected_values.len());
            }

            assert!(query.allocated_size() > 0);
            // 2 (branching factor) * 8 (pointer size) * point count rounded up (max of 4 due to minimum vec sizing)
            let size_limit = 2 * 8 * point_count.next_power_of_two().max(4);
            assert!(query.allocated_size() <= size_limit);

            query.deallocate_memory();
            assert_eq!(query.allocated_size(), 0);
        }
    }

    // Spot checks of the array implementations of `Point` in 1, 2 and 4 dimensions.
    #[test]
    fn test_point_array_impls() {
        // One dimension
        assert_eq!([5.0_f64].distance(&[7.0]), 2.0);
        assert_eq!([5.0_f64].move_towards(&[3.0], 1.0), [4.0]);

        // Two dimensions
        assert_eq!([5.0_f64, 3.0].distance(&[7.0, 5.0]), 2.0 * 2f64.sqrt());
        assert_eq!(
            [5.0_f64, 3.0].move_towards(&[3.0, 1.0], 2f64.sqrt()),
            [4.0, 2.0]
        );

        // Four dimensions
        let origin = [0.0_f64, 0.0, 0.0, 0.0];
        let twos = [2.0_f64, 2.0, 2.0, 2.0];
        assert_eq!(origin.distance(&twos), 4.0);
        assert_eq!(origin.move_towards(&twos, 8.0), [4.0, 4.0, 4.0, 4.0]);
    }
}