ball_tree/
lib.rs

1use std::cmp::Ordering;
2use std::collections::BinaryHeap;
3
/// A value that lives in a metric space.
///
/// A `Point` can measure its distance to another `Point` of the same type and can
/// produce a new `Point` displaced by a given distance towards another one.
pub trait Point: Sized + PartialEq {
    /// Returns the distance between `self` and `other`.
    ///
    /// The result should be a non-negative, finite `f64`. It is undefined behavior
    /// to return a negative, infinite, or `NaN` result.
    ///
    /// Distances should satisfy the triangle inequality: `a.distance(c)` must be
    /// less than or equal to `a.distance(b) + b.distance(c)`.
    fn distance(&self, other: &Self) -> f64;

    /// Returns the point found by travelling a distance of `d` from `self` towards
    /// `other`.
    ///
    /// * If `d` is `0`, a point equal to `self` should be returned.
    /// * If `d` is equal to `self.distance(other)`, a point equal to `other` should
    ///   be returned.
    /// * Distances in between should be linearly interpolated, so a `d` of
    ///   `self.distance(other) / 2.0` should yield the midpoint.
    ///
    /// It is undefined behavior to pass a `d` that is negative, `NaN`, or greater
    /// than `self.distance(other)`.
    fn move_towards(&self, other: &Self, d: f64) -> Self;

    /// Returns the point that is equidistant from `a` and `b`.
    ///
    /// The default implementation is `a.move_towards(b, a.distance(b) / 2.0)`;
    /// implementors may override it with something cheaper as long as the result is
    /// equivalent.
    fn midpoint(a: &Self, b: &Self) -> Self {
        let half_way = a.distance(b) / 2.0;
        a.move_towards(b, half_way)
    }
}
32
33/// Implement `Point` in the normal `D` dimensional Euclidean way for all arrays of floats. For example, a 2D point
34/// would be a `[f64; 2]`.
35impl<const D: usize> Point for [f64; D] {
36    fn distance(&self, other: &Self) -> f64 {
37        self.iter()
38            .zip(other)
39            .map(|(a, b)| (*a - *b).powi(2))
40            .sum::<f64>()
41            .sqrt()
42    }
43
44    fn move_towards(&self, other: &Self, d: f64) -> Self {
45        let mut result = self.clone();
46
47        let distance = self.distance(other);
48
49        // Don't want to get a NaN in the division below
50        if distance == 0.0 {
51            return result;
52        }
53
54        let scale = d / self.distance(other);
55
56        for i in 0..D {
57            result[i] += scale * (other[i] - self[i]);
58        }
59
60        result
61    }
62
63    fn midpoint(a: &Self, b: &Self) -> Self {
64        let mut result = [0.0; D];
65        for i in 0..D {
66            result[i] = (a[i] + b[i]) / 2.0;
67        }
68        result
69    }
70}
71
// Newtype over `f64` that provides a total order, so floats can be used with
// comparison-based helpers such as `max_by_key`. The only `f64` value that cannot be
// ordered is `NaN`, and the constructor asserts that it is not present.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
struct OrdF64(f64);

impl OrdF64 {
    // Wraps `x`, panicking if it is `NaN`.
    fn new(x: f64) -> Self {
        assert!(!x.is_nan());
        Self(x)
    }
}

impl Eq for OrdF64 {}

impl Ord for OrdF64 {
    // The derived partial comparison only yields `None` when a `NaN` is involved,
    // which `new` rules out.
    fn cmp(&self, other: &Self) -> Ordering {
        PartialOrd::partial_cmp(self, other).unwrap()
    }
}
88
89#[derive(Debug, Copy, Clone, PartialEq)]
90struct Sphere<C> {
91    center: C,
92    radius: f64,
93}
94
95impl<C: Point> Sphere<C> {
96    fn nearest_distance(&self, p: &C) -> f64 {
97        let d = self.center.distance(p) - self.radius;
98        d.max(0.0)
99    }
100
101    fn farthest_distance(&self, p: &C) -> f64 {
102        self.center.distance(p) + self.radius
103    }
104}
105
106// Implementation of the "bouncing bubble" algorithm which essentially works like this:
107// * Pick a point `a` that is farthest from `points[0]`
108// * Pick a point `b` that is farthest from `a`
109// * Use these two points to create an initial sphere centered at their midpoint and with
110//   enough radius to encompass them
111// * While there is still a point outside of this sphere, move the sphere towards that
112//   point just enough to encompass that point, and grow the sphere radius by 1%
113//
114// This process will produce a non-optimal, but relatively snug fitting bounding sphere.
115
116fn bounding_sphere<P: Point>(points: &[P]) -> Sphere<P> {
117    assert!(points.len() >= 2);
118
119    let a = &points
120        .iter()
121        .max_by_key(|a| OrdF64::new(points[0].distance(a)))
122        .unwrap();
123    let b = &points
124        .iter()
125        .max_by_key(|b| OrdF64::new(a.distance(b)))
126        .unwrap();
127
128    let mut center: P = P::midpoint(a, b);
129    let mut radius = center.distance(b).max(std::f64::EPSILON);
130
131    loop {
132        match points.iter().filter(|p| center.distance(p) > radius).next() {
133            None => break Sphere { center, radius },
134            Some(p) => {
135                let c_to_p = center.distance(&p);
136                let d = c_to_p - radius;
137                center = center.move_towards(p, d);
138                radius = radius * 1.01;
139            }
140        }
141    }
142}
143
144// Produce a partition of the given points with the following process:
145// * Pick a point `a` that is farthest from `points[0]`
146// * Pick a point `b` that is farthest from `a`
147// * Partition the points into two groups: those closest to `a` and those closest to `b`
148//
149// This doesn't necessarily form the best partition, since `a` and `b` are not guaranteed
150// to be the most distant pair of points, but it's usually sufficient.
151fn partition<P: Point, V>(
152    mut points: Vec<P>,
153    mut values: Vec<V>,
154) -> ((Vec<P>, Vec<V>), (Vec<P>, Vec<V>)) {
155    assert!(points.len() >= 2);
156    assert_eq!(points.len(), values.len());
157
158    let a_i = points
159        .iter()
160        .enumerate()
161        .max_by_key(|(_, a)| OrdF64::new(points[0].distance(a)))
162        .unwrap()
163        .0;
164
165    let b_i = points
166        .iter()
167        .enumerate()
168        .max_by_key(|(_, b)| OrdF64::new(points[a_i].distance(b)))
169        .unwrap()
170        .0;
171
172    let (a_i, b_i) = (a_i.max(b_i), a_i.min(b_i));
173
174    let (mut aps, mut avs) = (vec![points.swap_remove(a_i)], vec![values.swap_remove(a_i)]);
175    let (mut bps, mut bvs) = (vec![points.swap_remove(b_i)], vec![values.swap_remove(b_i)]);
176
177    for (p, v) in points.into_iter().zip(values) {
178        if aps[0].distance(&p) < bps[0].distance(&p) {
179            aps.push(p);
180            avs.push(v);
181        } else {
182            bps.push(p);
183            bvs.push(v);
184        }
185    }
186
187    ((aps, avs), (bps, bvs))
188}
189
190#[derive(Debug, Clone)]
191enum BallTreeInner<P, V> {
192    Empty,
193    Leaf(P, Vec<V>),
194    // The sphere is a bounding sphere that encompasses this node (both children)
195    Branch {
196        sphere: Sphere<P>,
197        a: Box<BallTreeInner<P, V>>,
198        b: Box<BallTreeInner<P, V>>,
199        count: usize,
200    },
201}
202
203impl<P: Point, V> Default for BallTreeInner<P, V> {
204    fn default() -> Self {
205        BallTreeInner::Empty
206    }
207}
208
209impl<P: Point, V> BallTreeInner<P, V> {
210    fn new(mut points: Vec<P>, values: Vec<V>) -> Self {
211        assert_eq!(
212            points.len(),
213            values.len(),
214            "Given two vectors of differing lengths. points: {}, values: {}",
215            points.len(),
216            values.len()
217        );
218
219        if points.is_empty() {
220            BallTreeInner::Empty
221        } else if points.iter().all(|p| p == &points[0]) {
222            BallTreeInner::Leaf(points.pop().unwrap(), values)
223        } else {
224            let count = points.len();
225            let sphere = bounding_sphere(&points);
226            let ((aps, avs), (bps, bvs)) = partition(points, values);
227            let (a_tree, b_tree) = (BallTreeInner::new(aps, avs), BallTreeInner::new(bps, bvs));
228            BallTreeInner::Branch { sphere, a: Box::new(a_tree), b: Box::new(b_tree), count }
229        }
230    }
231
232    fn nearest_distance(&self, p: &P) -> f64 {
233        match self {
234            BallTreeInner::Empty => std::f64::INFINITY,
235            // The distance to a leaf is the distance to the single point inside of it
236            BallTreeInner::Leaf(p0, _) => p.distance(p0),
237            // The distance to a branch is the distance to the edge of the bounding sphere
238            BallTreeInner::Branch { sphere, .. } => sphere.nearest_distance(p),
239        }
240    }
241}
242
// Priority-queue entry: a distance paired with an arbitrary payload. A dedicated
// wrapper is needed for two reasons:
// * Rust's `BinaryHeap` is a max-heap while the queries need a min-heap, so the
//   ordering is inverted here.
// * Only the distance takes part in comparisons, which means the payload does not
//   have to be orderable (a derived ordering would require that).
#[derive(Debug, Copy, Clone)]
struct Item<T>(f64, T);

impl<T> PartialEq for Item<T> {
    // Two items are equal when their distances are equal; the payload is ignored.
    fn eq(&self, other: &Self) -> bool {
        other.0 == self.0
    }
}

impl<T> Eq for Item<T> {}

impl<T> PartialOrd for Item<T> {
    // Comparing with the operands swapped yields the reversed ordering.
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        other.0.partial_cmp(&self.0)
    }
}

impl<T> Ord for Item<T> {
    // Panics if the two distances cannot be ordered (i.e. one of them is `NaN`).
    fn cmp(&self, other: &Self) -> Ordering {
        PartialOrd::partial_cmp(self, other).unwrap()
    }
}
269
270/// Iterator over the nearest neighbors.
271// Maintain a priority queue of the nodes that are closest to the provided `point`. If we
272// pop a leaf from the queue, that leaf is necessarily the next closest point. If we
273// pop a branch from the queue, add its children. The priority of a node is its
274// `distance` as defined above.
275#[derive(Debug)]
276pub struct Iter<'tree, 'query, P, V> {
277    point: &'query P,
278    balls: &'query mut BinaryHeap<Item<&'tree BallTreeInner<P, V>>>,
279    i: usize,
280    max_radius: f64,
281}
282
283impl<'tree, 'query, P: Point, V> Iterator for Iter<'tree, 'query, P, V> {
284    type Item = (&'tree P, f64, &'tree V);
285
286    fn next(&mut self) -> Option<Self::Item> {
287        while self.balls.len() > 0 {
288            // Peek in the leaf case, because we might need to visit this leaf multiple
289            // times (if it has multiple values).
290            if let Item(d, BallTreeInner::Leaf(p, vs)) = self.balls.peek().unwrap() {
291                if self.i < vs.len() && *d <= self.max_radius {
292                    self.i += 1;
293                    return Some((p, *d, &vs[self.i - 1]));
294                }
295            }
296            // Reset index for the next leaf we encounter
297            self.i = 0;
298            // Expand branch nodes
299            if let Item(_, BallTreeInner::Branch { a, b, .. }) = self.balls.pop().unwrap() {
300                let d_a = a.nearest_distance(self.point);
301                let d_b = b.nearest_distance(self.point);
302                if d_a <= self.max_radius {
303                    self.balls.push(Item(d_a, a));
304                }
305                if d_b <= self.max_radius {
306                    self.balls.push(Item(d_b, b));
307                }
308            }
309        }
310        None
311    }
312}
313
314/// A `BallTree` is a space-partitioning data-structure that allows for finding
315/// nearest neighbors in logarithmic time.
316///
317/// It does this by partitioning data into a series of nested bounding spheres
318/// ("balls" in the literature). Spheres are used because it is trivial to
319/// compute the distance between a point and a sphere (distance to the sphere's
320/// center minus thte radius). The key observation is that a potential neighbor
321/// is necessarily closer than all neighbors that are located inside of a
322/// bounding sphere that is farther than the aforementioned neighbor.
323///
324/// Graphically:
325/// ```text
326///
327///    A -
328///    |  ----         distance(A, B) = 4
329///    |      - B      distance(A, S) = 6
330///     |
331///      |
332///      |    S
333///        --------
334///      /        G \
335///     /   C        \
336///    |           D |
337///    |       F     |
338///     \ E         /
339///      \_________/
340///```
341///
342/// In the diagram, `A` is closer to `B` than to `S`, and because `S` bounds
343/// `C`, `D`, `E`, `F`, and `G`, it can be determined that `A` it is necessarily
344/// closer to `B` than the other points without even computing exact distances
345/// to them.
346///
347/// Ball trees are most commonly used as a form of predictive model where the
348/// points are features and each point is associated with a value or label. Thus,
349/// This implementation allows the user to associate a value with each point. If
350/// this functionality is unneeded, `()` can be used as a value.
351///
352/// This implementation returns the nearest neighbors, their distances, and their
353/// associated values. Returning the distances allows the user to perform some
354/// sort of weighted interpolation of the neighbors for predictive purposes.
355#[derive(Debug, Clone)]
356pub struct BallTree<P, V>(BallTreeInner<P, V>);
357
358impl<P: Point, V> Default for BallTree<P, V> {
359    fn default() -> Self {
360        BallTree(BallTreeInner::default())
361    }
362}
363
364impl<P: Point, V> BallTree<P, V> {
365    /// Construct this `BallTree`. Construction is somewhat expensive, so `BallTree`s
366    /// are best constructed once and then used repeatedly.
367    ///
368    /// `panic` if `points.len() != values.len()`
369    pub fn new(points: Vec<P>, values: Vec<V>) -> Self {
370        BallTree(BallTreeInner::new(points, values))
371    }
372
373    /// Query this `BallTree`. The `Query` object provides a nearest-neighbor API and internally re-uses memory to avoid
374    /// allocations on repeated queries.
375    pub fn query(&self) -> Query<P, V> {
376        Query {
377            ball_tree: self,
378            balls: Default::default(),
379        }
380    }
381}
382
383/// A context for repeated nearest-neighbor queries that internally re-uses memory across queries.
384#[derive(Debug, Clone)]
385pub struct Query<'tree, P, V> {
386    ball_tree: &'tree BallTree<P, V>,
387    balls: BinaryHeap<Item<&'tree BallTreeInner<P, V>>>,
388}
389
390impl<'tree, P: Point, V> Query<'tree, P, V> {
391    /// Given a `point`, return an `Iterator` that yields neighbors from closest to
392    /// farthest. To get the K nearest neighbors, simply `take` K from the iterator.
393    ///
394    /// The neighbor, its distance, and associated value are returned.
395    pub fn nn<'query>(
396        &'query mut self,
397        point: &'query P,
398    ) -> Iter<'tree, 'query, P, V> {
399        self.nn_within(point, f64::INFINITY)
400    }
401
402    /// The same as `nn` but only consider neighbors whose distance is `<= max_radius`.
403    pub fn nn_within<'query>(
404        &'query mut self,
405        point: &'query P,
406        max_radius: f64,
407    ) -> Iter<'tree, 'query, P, V> {
408        let balls = &mut self.balls;
409        balls.clear();
410        balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
411        Iter {
412            point,
413            balls,
414            i: 0,
415            max_radius,
416        }
417    }
418
419    /// What is the minimum radius that encompasses `k` neighbors of `point`?
420    pub fn min_radius<'query>(&'query mut self, point: &'query P, k: usize) -> f64 {
421        let mut total_count = 0;
422        let balls = &mut self.balls;
423        balls.clear(); 
424        balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
425    
426        while let Some(Item(distance, node)) = balls.pop() {
427            match node {
428                BallTreeInner::Empty => {}
429                BallTreeInner::Leaf(_, vs) => {
430                    total_count += vs.len();
431                    if total_count >= k {
432                        return distance;
433                    }
434                }
435                BallTreeInner::Branch { sphere, a, b, count } => {
436                    let next_distance = balls.peek().map(|Item(d, _)| *d).unwrap_or(f64::INFINITY);
437                    if total_count + count < k && sphere.farthest_distance(point) < next_distance {
438                        total_count += count;
439                    } else {
440                        balls.push(Item(a.nearest_distance(point), &a));
441                        balls.push(Item(b.nearest_distance(point), &b));
442                    }
443                }
444            }
445        }
446    
447        f64::INFINITY
448    }
449
450    /// How many neighbors are `<= max_radius` of `point`?
451    pub fn count<'query>(&'query mut self, point: &'query P, max_radius: f64) -> usize {
452        let mut total = 0;
453        let balls = &mut self.balls;
454        balls.clear();
455        balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
456
457        while let Some(Item(nearest_distance, node)) = balls.pop() {
458            if nearest_distance > max_radius {
459                break;
460            }
461            match node {
462                BallTreeInner::Empty => {}
463                BallTreeInner::Leaf(_, vs) => {
464                    total += vs.len();
465                }
466                BallTreeInner::Branch { a, b, count, sphere} => {
467                    let next_distance = balls.peek().map(|Item(d, _)| *d).unwrap_or(f64::INFINITY).min(max_radius);
468                    if sphere.farthest_distance(point) < next_distance {
469                        total += count;
470                    } else {
471                        balls.push(Item(a.nearest_distance(point), &a));
472                        balls.push(Item(b.nearest_distance(point), &b));
473                    }
474                }
475            }
476        }
477    
478        total
479    }
480
481    /// Return the size in bytes of the memory this `Query` is keeping internally to avoid allocation.
482    pub fn allocated_size(&self) -> usize {
483        self.balls.capacity() * std::mem::size_of::<Item<&'tree BallTreeInner<P, V>>>()
484    }
485
486    /// The `Query` object re-uses memory internally to avoid allocation. This method deallocates that memory.
487    pub fn deallocate_memory(&mut self) {
488        self.balls.clear();
489        self.balls.shrink_to_fit();
490    }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaChaRng;
    use std::collections::HashSet;

    // Draws a single coordinate from the inclusive range -100 to 100.
    fn random_small_f64<R: Rng>(rng: &mut R) -> f64 {
        rng.gen_range(-100.0..=100.0)
    }

    // Draws the three coordinates of a point, first to last.
    fn random_3d_point<R: Rng>(rng: &mut R) -> [f64; 3] {
        [
            random_small_f64(rng),
            random_small_f64(rng),
            random_small_f64(rng),
        ]
    }

    // Builds trees of random points and checks every query method against a brute
    // force scan over the same points.
    #[test]
    fn test_3d_points() {
        let mut rng: ChaChaRng = SeedableRng::seed_from_u64(0xcb42c94d23346e96);

        for round in 0..1000 {
            // Start with tiny trees and move on to larger ones in later rounds.
            let max_points = match round {
                0..=99 => 3,
                100..=499 => 10,
                _ => 100,
            };
            let point_count: usize = rng.gen_range(1..=max_points);

            let mut points = Vec::with_capacity(point_count);
            let mut values = Vec::with_capacity(point_count);
            for _ in 0..point_count {
                points.push(random_3d_point(&mut rng));
                values.push(rng.gen::<u64>());
            }

            let tree = BallTree::new(points.clone(), values.clone());
            let mut query = tree.query();

            for _ in 0..100 {
                let point = random_3d_point(&mut rng);
                let max_radius: f64 = rng.gen_range(0.0..=110.0);

                // Brute force reference result.
                let expected_values: HashSet<u64> = points
                    .iter()
                    .zip(&values)
                    .filter(|(p, _)| p.distance(&point) <= max_radius)
                    .map(|(_, v)| *v)
                    .collect();

                let mut found_values = HashSet::new();
                let mut previous_d = 0.0;
                for (p, d, v) in query.nn_within(&point, max_radius) {
                    // Reported distances are exact, non-decreasing, and within range.
                    assert_eq!(point.distance(p), d);
                    assert!(d >= previous_d);
                    assert!(d <= max_radius);
                    previous_d = d;
                    found_values.insert(*v);
                }

                assert_eq!(expected_values, found_values);

                assert_eq!(found_values.len(), query.count(&point, max_radius));

                // A slightly smaller radius than the reported minimum has to
                // encompass fewer neighbors.
                let radius = query.min_radius(&point, expected_values.len());
                let should_be_fewer = query.count(&point, radius * 0.99);
                assert!(
                    expected_values.is_empty() || should_be_fewer < expected_values.len(),
                    "{} < {}",
                    should_be_fewer,
                    expected_values.len()
                );
            }

            assert!(query.allocated_size() > 0);
            // Upper bound: 2 (branching factor) * 8 (pointer size) * the point count
            // rounded up to a power of two, with a floor of 4 for the minimum vec size.
            let size_limit = 2 * 8 * point_count.next_power_of_two().max(4);
            assert!(query.allocated_size() <= size_limit);

            query.deallocate_memory();
            assert_eq!(query.allocated_size(), 0);
        }
    }

    // Spot checks of the `Point` implementation for arrays in 1, 2 and 4 dimensions.
    #[test]
    fn test_point_array_impls() {
        // 1D
        assert_eq!([5.0].distance(&[7.0]), 2.0);
        assert_eq!([5.0].move_towards(&[3.0], 1.0), [4.0]);

        // 2D
        let sqrt_2 = 2f64.sqrt();
        assert_eq!([5.0, 3.0].distance(&[7.0, 5.0]), 2.0 * sqrt_2);
        assert_eq!([5.0, 3.0].move_towards(&[3.0, 1.0], sqrt_2), [4.0, 2.0]);

        // 4D
        let origin = [0.0, 0.0, 0.0, 0.0];
        let twos = [2.0, 2.0, 2.0, 2.0];
        assert_eq!(origin.distance(&twos), 4.0);
        assert_eq!(origin.move_towards(&twos, 8.0), [4.0, 4.0, 4.0, 4.0]);
    }
}