1use std::cmp::Ordering;
2use std::collections::BinaryHeap;
3
/// A location type that the ball tree in this file can index.
///
/// Implementors provide a distance function and a way to step from one point
/// toward another; `midpoint` has a default built from those two.
pub trait Point: Sized + PartialEq {
    /// Returns the distance between `self` and `other`.
    fn distance(&self, other: &Self) -> f64;

    /// Returns the point reached by travelling `d` units from `self` toward `other`.
    fn move_towards(&self, other: &Self, d: f64) -> Self;

    /// Returns the point halfway between `a` and `b`.
    ///
    /// The default starts at `a` and walks half of `a.distance(b)` toward `b`.
    fn midpoint(a: &Self, b: &Self) -> Self {
        let half_way = a.distance(b) / 2.0;
        a.move_towards(b, half_way)
    }
}
32
33impl<const D: usize> Point for [f64; D] {
36 fn distance(&self, other: &Self) -> f64 {
37 self.iter()
38 .zip(other)
39 .map(|(a, b)| (*a - *b).powi(2))
40 .sum::<f64>()
41 .sqrt()
42 }
43
44 fn move_towards(&self, other: &Self, d: f64) -> Self {
45 let mut result = self.clone();
46
47 let distance = self.distance(other);
48
49 if distance == 0.0 {
51 return result;
52 }
53
54 let scale = d / self.distance(other);
55
56 for i in 0..D {
57 result[i] += scale * (other[i] - self[i]);
58 }
59
60 result
61 }
62
63 fn midpoint(a: &Self, b: &Self) -> Self {
64 let mut result = [0.0; D];
65 for i in 0..D {
66 result[i] = (a[i] + b[i]) / 2.0;
67 }
68 result
69 }
70}
71
72impl<const D: usize> Point for [f32; D] {
75 fn distance(&self, other: &Self) -> f64 {
76 self.iter()
77 .zip(other)
78 .map(|(a, b)| (*a - *b).powi(2))
79 .sum::<f32>()
80 .sqrt() as f64
81 }
82
83 fn move_towards(&self, other: &Self, d: f64) -> Self {
84 let mut result = self.clone();
85
86 let distance = self.distance(other);
87
88 if distance == 0.0 {
90 return result;
91 }
92
93 let scale = d / self.distance(other);
94 let scale = scale as f32;
95
96 for i in 0..D {
97 result[i] += scale * (other[i] - self[i]);
98 }
99
100 result
101 }
102
103 fn midpoint(a: &Self, b: &Self) -> Self {
104 let mut result = [0.0; D];
105 for i in 0..D {
106 result[i] = (a[i] + b[i]) / 2.0;
107 }
108 result
109 }
110}
111
/// An `f64` wrapper with a total order, so distances can serve as
/// `max_by_key` keys.
///
/// [`OrdF64::new`] rejects NaN, the only value for which `f64` comparison is
/// undefined.
#[derive(Debug, Clone, PartialEq)]
struct OrdF64(f64);

impl OrdF64 {
    /// Wraps `x`.
    ///
    /// # Panics
    ///
    /// Panics if `x` is NaN.
    fn new(x: f64) -> Self {
        assert!(!x.is_nan());
        OrdF64(x)
    }
}

impl Eq for OrdF64 {}

// Written by hand and routed through `Ord` so the two orderings can never
// disagree (deriving one while hand-writing the other invites drift).
impl PartialOrd for OrdF64 {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for OrdF64 {
    fn cmp(&self, other: &Self) -> Ordering {
        self.0
            .partial_cmp(&other.0)
            .expect("OrdF64 never holds NaN")
    }
}
128
129#[derive(Debug, Copy, Clone, PartialEq)]
130struct Sphere<C> {
131 center: C,
132 radius: f64,
133}
134
135impl<C: Point> Sphere<C> {
136 fn nearest_distance(&self, p: &C) -> f64 {
137 let d = self.center.distance(p) - self.radius;
138 d.max(0.0)
139 }
140
141 fn farthest_distance(&self, p: &C) -> f64 {
142 self.center.distance(p) + self.radius
143 }
144}
145
146fn bounding_sphere<P: Point>(points: &[P]) -> Sphere<P> {
157 assert!(points.len() >= 2);
158
159 let a = &points
160 .iter()
161 .max_by_key(|a| OrdF64::new(points[0].distance(a)))
162 .unwrap();
163 let b = &points
164 .iter()
165 .max_by_key(|b| OrdF64::new(a.distance(b)))
166 .unwrap();
167
168 let mut center: P = P::midpoint(a, b);
169 let mut radius = center.distance(b).max(std::f64::EPSILON);
170
171 loop {
172 match points.iter().filter(|p| center.distance(p) > radius).next() {
173 None => break Sphere { center, radius },
174 Some(p) => {
175 let c_to_p = center.distance(&p);
176 let d = c_to_p - radius;
177 center = center.move_towards(p, d);
178 radius = radius * 1.01;
179 }
180 }
181 }
182}
183
184fn partition<P: Point, V>(
192 mut points: Vec<P>,
193 mut values: Vec<V>,
194) -> ((Vec<P>, Vec<V>), (Vec<P>, Vec<V>)) {
195 assert!(points.len() >= 2);
196 assert_eq!(points.len(), values.len());
197
198 let a_i = points
199 .iter()
200 .enumerate()
201 .max_by_key(|(_, a)| OrdF64::new(points[0].distance(a)))
202 .unwrap()
203 .0;
204
205 let b_i = points
206 .iter()
207 .enumerate()
208 .max_by_key(|(_, b)| OrdF64::new(points[a_i].distance(b)))
209 .unwrap()
210 .0;
211
212 let (a_i, b_i) = (a_i.max(b_i), a_i.min(b_i));
213
214 let (mut aps, mut avs) = (vec![points.swap_remove(a_i)], vec![values.swap_remove(a_i)]);
215 let (mut bps, mut bvs) = (vec![points.swap_remove(b_i)], vec![values.swap_remove(b_i)]);
216
217 for (p, v) in points.into_iter().zip(values) {
218 if aps[0].distance(&p) < bps[0].distance(&p) {
219 aps.push(p);
220 avs.push(v);
221 } else {
222 bps.push(p);
223 bvs.push(v);
224 }
225 }
226
227 ((aps, avs), (bps, bvs))
228}
229
230#[derive(Debug, Clone)]
231enum BallTreeInner<P, V> {
232 Empty,
233 Leaf(P, Vec<V>),
234 Branch {
236 sphere: Sphere<P>,
237 a: Box<BallTreeInner<P, V>>,
238 b: Box<BallTreeInner<P, V>>,
239 count: usize,
240 },
241}
242
243impl<P: Point, V> Default for BallTreeInner<P, V> {
244 fn default() -> Self {
245 BallTreeInner::Empty
246 }
247}
248
249impl<P: Point, V> BallTreeInner<P, V> {
250 fn new(mut points: Vec<P>, values: Vec<V>) -> Self {
251 assert_eq!(
252 points.len(),
253 values.len(),
254 "Given two vectors of differing lengths. points: {}, values: {}",
255 points.len(),
256 values.len()
257 );
258
259 if points.is_empty() {
260 BallTreeInner::Empty
261 } else if points.iter().all(|p| p == &points[0]) {
262 BallTreeInner::Leaf(points.pop().unwrap(), values)
263 } else {
264 let count = points.len();
265 let sphere = bounding_sphere(&points);
266 let ((aps, avs), (bps, bvs)) = partition(points, values);
267 let (a_tree, b_tree) = (BallTreeInner::new(aps, avs), BallTreeInner::new(bps, bvs));
268 BallTreeInner::Branch { sphere, a: Box::new(a_tree), b: Box::new(b_tree), count }
269 }
270 }
271
272 fn nearest_distance(&self, p: &P) -> f64 {
273 match self {
274 BallTreeInner::Empty => std::f64::INFINITY,
275 BallTreeInner::Leaf(p0, _) => p.distance(p0),
277 BallTreeInner::Branch { sphere, .. } => sphere.nearest_distance(p),
279 }
280 }
281}
282
/// A heap entry pairing a distance with a payload.
///
/// Equality and ordering look only at the distance, and the ordering is
/// reversed so that a max-heap such as `BinaryHeap` yields the smallest
/// distance first.
#[derive(Debug, Copy, Clone)]
struct Item<T>(f64, T);

impl<T> PartialEq for Item<T> {
    fn eq(&self, other: &Self) -> bool {
        let (Item(lhs, _), Item(rhs, _)) = (self, other);
        lhs == rhs
    }
}

impl<T> Eq for Item<T> {}

impl<T> PartialOrd for Item<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Operands are swapped on purpose: a smaller distance ranks as greater
        other.0.partial_cmp(&self.0)
    }
}

impl<T> Ord for Item<T> {
    /// # Panics
    ///
    /// Panics if either distance is NaN.
    fn cmp(&self, other: &Self) -> Ordering {
        let ordering = self.partial_cmp(other);
        ordering.unwrap()
    }
}
309
310#[derive(Debug)]
316pub struct Iter<'tree, 'query, P, V> {
317 point: &'query P,
318 balls: &'query mut BinaryHeap<Item<&'tree BallTreeInner<P, V>>>,
319 i: usize,
320 max_radius: f64,
321}
322
323impl<'tree, 'query, P: Point, V> Iterator for Iter<'tree, 'query, P, V> {
324 type Item = (&'tree P, f64, &'tree V);
325
326 fn next(&mut self) -> Option<Self::Item> {
327 while self.balls.len() > 0 {
328 if let Item(d, BallTreeInner::Leaf(p, vs)) = self.balls.peek().unwrap() {
331 if self.i < vs.len() && *d <= self.max_radius {
332 self.i += 1;
333 return Some((p, *d, &vs[self.i - 1]));
334 }
335 }
336 self.i = 0;
338 if let Item(_, BallTreeInner::Branch { a, b, .. }) = self.balls.pop().unwrap() {
340 let d_a = a.nearest_distance(self.point);
341 let d_b = b.nearest_distance(self.point);
342 if d_a <= self.max_radius {
343 self.balls.push(Item(d_a, a));
344 }
345 if d_b <= self.max_radius {
346 self.balls.push(Item(d_b, b));
347 }
348 }
349 }
350 None
351 }
352}
353
354#[derive(Debug, Clone)]
396pub struct BallTree<P, V>(BallTreeInner<P, V>);
397
398impl<P: Point, V> Default for BallTree<P, V> {
399 fn default() -> Self {
400 BallTree(BallTreeInner::default())
401 }
402}
403
404impl<P: Point, V> BallTree<P, V> {
405 pub fn new(points: Vec<P>, values: Vec<V>) -> Self {
410 BallTree(BallTreeInner::new(points, values))
411 }
412
413 pub fn query(&self) -> Query<'_, P, V> {
416 Query {
417 ball_tree: self,
418 balls: Default::default(),
419 }
420 }
421}
422
423#[derive(Debug, Clone)]
425pub struct Query<'tree, P, V> {
426 ball_tree: &'tree BallTree<P, V>,
427 balls: BinaryHeap<Item<&'tree BallTreeInner<P, V>>>,
428}
429
430impl<'tree, P: Point, V> Query<'tree, P, V> {
431 pub fn nn<'query>(
436 &'query mut self,
437 point: &'query P,
438 ) -> Iter<'tree, 'query, P, V> {
439 self.nn_within(point, f64::INFINITY)
440 }
441
442 pub fn nn_within<'query>(
444 &'query mut self,
445 point: &'query P,
446 max_radius: f64,
447 ) -> Iter<'tree, 'query, P, V> {
448 let balls = &mut self.balls;
449 balls.clear();
450 balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
451 Iter {
452 point,
453 balls,
454 i: 0,
455 max_radius,
456 }
457 }
458
459 pub fn min_radius<'query>(&'query mut self, point: &'query P, k: usize) -> f64 {
461 let mut total_count = 0;
462 let balls = &mut self.balls;
463 balls.clear();
464 balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
465
466 while let Some(Item(distance, node)) = balls.pop() {
467 match node {
468 BallTreeInner::Empty => {}
469 BallTreeInner::Leaf(_, vs) => {
470 total_count += vs.len();
471 if total_count >= k {
472 return distance;
473 }
474 }
475 BallTreeInner::Branch { sphere, a, b, count } => {
476 let next_distance = balls.peek().map(|Item(d, _)| *d).unwrap_or(f64::INFINITY);
477 if total_count + count < k && sphere.farthest_distance(point) < next_distance {
478 total_count += count;
479 } else {
480 balls.push(Item(a.nearest_distance(point), &a));
481 balls.push(Item(b.nearest_distance(point), &b));
482 }
483 }
484 }
485 }
486
487 f64::INFINITY
488 }
489
490 pub fn count<'query>(&'query mut self, point: &'query P, max_radius: f64) -> usize {
492 let mut total = 0;
493 let balls = &mut self.balls;
494 balls.clear();
495 balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
496
497 while let Some(Item(nearest_distance, node)) = balls.pop() {
498 if nearest_distance > max_radius {
499 break;
500 }
501 match node {
502 BallTreeInner::Empty => {}
503 BallTreeInner::Leaf(_, vs) => {
504 total += vs.len();
505 }
506 BallTreeInner::Branch { a, b, count, sphere} => {
507 let next_distance = balls.peek().map(|Item(d, _)| *d).unwrap_or(f64::INFINITY).min(max_radius);
508 if sphere.farthest_distance(point) < next_distance {
509 total += count;
510 } else {
511 balls.push(Item(a.nearest_distance(point), &a));
512 balls.push(Item(b.nearest_distance(point), &b));
513 }
514 }
515 }
516 }
517
518 total
519 }
520
521 pub fn allocated_size(&self) -> usize {
523 self.balls.capacity() * std::mem::size_of::<Item<&'tree BallTreeInner<P, V>>>()
524 }
525
526 pub fn deallocate_memory(&mut self) {
528 self.balls.clear();
529 self.balls.shrink_to_fit();
530 }
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaChaRng;
    use std::collections::HashSet;

    /// Draws one coordinate from `-100.0..=100.0`.
    fn random_small_f64(rng: &mut ChaChaRng) -> f64 {
        rng.gen_range(-100.0..=100.0)
    }

    /// Draws three coordinates, in x, y, z order.
    fn random_3d_point(rng: &mut ChaChaRng) -> [f64; 3] {
        [
            random_small_f64(rng),
            random_small_f64(rng),
            random_small_f64(rng),
        ]
    }

    /// Builds random trees of growing size and checks every query method
    /// against a brute-force scan of the same points.
    #[test]
    fn test_3d_points() {
        let mut rng: ChaChaRng = SeedableRng::seed_from_u64(0xcb42c94d23346e96);

        for round in 0..1000 {
            // Small trees first, larger ones later
            let max_points = match round {
                0..=99 => 3,
                100..=499 => 10,
                _ => 100,
            };
            let point_count: usize = rng.gen_range(1..=max_points);

            // Per entry the draw order is: three coordinates, then the value
            let (points, values): (Vec<[f64; 3]>, Vec<u64>) = (0..point_count)
                .map(|_| {
                    let point = random_3d_point(&mut rng);
                    let value = rng.gen::<u64>();
                    (point, value)
                })
                .unzip();

            let tree = BallTree::new(points.clone(), values.clone());
            let mut query = tree.query();

            for _ in 0..100 {
                let point = random_3d_point(&mut rng);
                let max_radius = rng.gen_range(0.0..=110.0);

                // Brute-force reference answer
                let expected_values: HashSet<u64> = points
                    .iter()
                    .zip(&values)
                    .filter(|(p, _)| p.distance(&point) <= max_radius)
                    .map(|(_, v)| *v)
                    .collect();

                let mut found_values = HashSet::new();
                let mut previous_d = 0.0;
                for (p, d, v) in query.nn_within(&point, max_radius) {
                    assert_eq!(point.distance(p), d);
                    assert!(d >= previous_d);
                    assert!(d <= max_radius);
                    previous_d = d;
                    found_values.insert(*v);
                }

                assert_eq!(expected_values, found_values);
                assert_eq!(found_values.len(), query.count(&point, max_radius));

                // Shrinking the minimal radius slightly must lose at least one value
                let radius = query.min_radius(&point, expected_values.len());
                let should_be_fewer = query.count(&point, radius * 0.99);
                assert!(expected_values.is_empty() || should_be_fewer < expected_values.len(), "{} < {}", should_be_fewer, expected_values.len());
            }

            assert!(query.allocated_size() > 0);
            assert!(query.allocated_size() <= 2 * 8 * point_count.next_power_of_two().max(4));

            query.deallocate_memory();
            assert_eq!(query.allocated_size(), 0);
        }
    }

    /// Spot-checks `distance` and `move_towards` on arrays of one, two, and
    /// four coordinates.
    #[test]
    fn test_point_array_impls() {
        // One dimension
        assert_eq!([5.0].distance(&[7.0]), 2.0);
        assert_eq!([5.0].move_towards(&[3.0], 1.0), [4.0]);

        // Two dimensions, along the diagonal
        assert_eq!([5.0, 3.0].distance(&[7.0, 5.0]), 2.0 * 2f64.sqrt());
        assert_eq!(
            [5.0, 3.0].move_towards(&[3.0, 1.0], 2f64.sqrt()),
            [4.0, 2.0]
        );

        // Four dimensions, moving past the target
        assert_eq!([0.0, 0.0, 0.0, 0.0].distance(&[2.0, 2.0, 2.0, 2.0]), 4.0);
        assert_eq!(
            [0.0, 0.0, 0.0, 0.0].move_towards(&[2.0, 2.0, 2.0, 2.0], 8.0),
            [4.0, 4.0, 4.0, 4.0]
        );
    }
}