1use std::cmp::Ordering;
2use std::collections::BinaryHeap;
3
/// A point in a metric space that can be stored in a ball tree.
pub trait Point: Sized + PartialEq {
    /// Distance between `self` and `other` under this space's metric.
    fn distance(&self, other: &Self) -> f64;

    /// Returns the point reached by travelling a distance `d` from `self` in the
    /// direction of `other`.
    fn move_towards(&self, other: &Self, d: f64) -> Self;

    /// Point halfway between `a` and `b`.
    ///
    /// The default implementation walks from `a` towards `b` by half of their distance;
    /// implementors may override it with a cheaper direct computation.
    fn midpoint(a: &Self, b: &Self) -> Self {
        let half_way = a.distance(b) * 0.5;
        a.move_towards(b, half_way)
    }
}
32
33impl<const D: usize> Point for [f64; D] {
36 fn distance(&self, other: &Self) -> f64 {
37 self.iter()
38 .zip(other)
39 .map(|(a, b)| (*a - *b).powi(2))
40 .sum::<f64>()
41 .sqrt()
42 }
43
44 fn move_towards(&self, other: &Self, d: f64) -> Self {
45 let mut result = self.clone();
46
47 let distance = self.distance(other);
48
49 if distance == 0.0 {
51 return result;
52 }
53
54 let scale = d / self.distance(other);
55
56 for i in 0..D {
57 result[i] += scale * (other[i] - self[i]);
58 }
59
60 result
61 }
62
63 fn midpoint(a: &Self, b: &Self) -> Self {
64 let mut result = [0.0; D];
65 for i in 0..D {
66 result[i] = (a[i] + b[i]) / 2.0;
67 }
68 result
69 }
70}
71
/// An `f64` known not to be NaN, and therefore totally ordered.
///
/// Exists so distances can be used as `max_by_key` keys, which require `Ord`.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
struct OrdF64(f64);

impl OrdF64 {
    /// Wraps `x`.
    ///
    /// # Panics
    /// If `x` is NaN.
    fn new(x: f64) -> Self {
        assert!(!x.is_nan());
        Self(x)
    }
}

impl Eq for OrdF64 {}

impl Ord for OrdF64 {
    fn cmp(&self, other: &Self) -> Ordering {
        // Equivalent to the derived `PartialOrd`, which compares the single field; the
        // NaN check in `new` is what makes this comparison succeed.
        self.0.partial_cmp(&other.0).unwrap()
    }
}
88
89#[derive(Debug, Copy, Clone, PartialEq)]
90struct Sphere<C> {
91 center: C,
92 radius: f64,
93}
94
95impl<C: Point> Sphere<C> {
96 fn nearest_distance(&self, p: &C) -> f64 {
97 let d = self.center.distance(p) - self.radius;
98 d.max(0.0)
99 }
100
101 fn farthest_distance(&self, p: &C) -> f64 {
102 self.center.distance(p) + self.radius
103 }
104}
105
106fn bounding_sphere<P: Point>(points: &[P]) -> Sphere<P> {
117 assert!(points.len() >= 2);
118
119 let a = &points
120 .iter()
121 .max_by_key(|a| OrdF64::new(points[0].distance(a)))
122 .unwrap();
123 let b = &points
124 .iter()
125 .max_by_key(|b| OrdF64::new(a.distance(b)))
126 .unwrap();
127
128 let mut center: P = P::midpoint(a, b);
129 let mut radius = center.distance(b).max(std::f64::EPSILON);
130
131 loop {
132 match points.iter().filter(|p| center.distance(p) > radius).next() {
133 None => break Sphere { center, radius },
134 Some(p) => {
135 let c_to_p = center.distance(&p);
136 let d = c_to_p - radius;
137 center = center.move_towards(p, d);
138 radius = radius * 1.01;
139 }
140 }
141 }
142}
143
144fn partition<P: Point, V>(
152 mut points: Vec<P>,
153 mut values: Vec<V>,
154) -> ((Vec<P>, Vec<V>), (Vec<P>, Vec<V>)) {
155 assert!(points.len() >= 2);
156 assert_eq!(points.len(), values.len());
157
158 let a_i = points
159 .iter()
160 .enumerate()
161 .max_by_key(|(_, a)| OrdF64::new(points[0].distance(a)))
162 .unwrap()
163 .0;
164
165 let b_i = points
166 .iter()
167 .enumerate()
168 .max_by_key(|(_, b)| OrdF64::new(points[a_i].distance(b)))
169 .unwrap()
170 .0;
171
172 let (a_i, b_i) = (a_i.max(b_i), a_i.min(b_i));
173
174 let (mut aps, mut avs) = (vec![points.swap_remove(a_i)], vec![values.swap_remove(a_i)]);
175 let (mut bps, mut bvs) = (vec![points.swap_remove(b_i)], vec![values.swap_remove(b_i)]);
176
177 for (p, v) in points.into_iter().zip(values) {
178 if aps[0].distance(&p) < bps[0].distance(&p) {
179 aps.push(p);
180 avs.push(v);
181 } else {
182 bps.push(p);
183 bvs.push(v);
184 }
185 }
186
187 ((aps, avs), (bps, bvs))
188}
189
190#[derive(Debug, Clone)]
191enum BallTreeInner<P, V> {
192 Empty,
193 Leaf(P, Vec<V>),
194 Branch {
196 sphere: Sphere<P>,
197 a: Box<BallTreeInner<P, V>>,
198 b: Box<BallTreeInner<P, V>>,
199 count: usize,
200 },
201}
202
203impl<P: Point, V> Default for BallTreeInner<P, V> {
204 fn default() -> Self {
205 BallTreeInner::Empty
206 }
207}
208
209impl<P: Point, V> BallTreeInner<P, V> {
210 fn new(mut points: Vec<P>, values: Vec<V>) -> Self {
211 assert_eq!(
212 points.len(),
213 values.len(),
214 "Given two vectors of differing lengths. points: {}, values: {}",
215 points.len(),
216 values.len()
217 );
218
219 if points.is_empty() {
220 BallTreeInner::Empty
221 } else if points.iter().all(|p| p == &points[0]) {
222 BallTreeInner::Leaf(points.pop().unwrap(), values)
223 } else {
224 let count = points.len();
225 let sphere = bounding_sphere(&points);
226 let ((aps, avs), (bps, bvs)) = partition(points, values);
227 let (a_tree, b_tree) = (BallTreeInner::new(aps, avs), BallTreeInner::new(bps, bvs));
228 BallTreeInner::Branch { sphere, a: Box::new(a_tree), b: Box::new(b_tree), count }
229 }
230 }
231
232 fn nearest_distance(&self, p: &P) -> f64 {
233 match self {
234 BallTreeInner::Empty => std::f64::INFINITY,
235 BallTreeInner::Leaf(p0, _) => p.distance(p0),
237 BallTreeInner::Branch { sphere, .. } => sphere.nearest_distance(p),
239 }
240 }
241}
242
/// Priority-queue entry: a distance paired with a payload.
///
/// Only the distance takes part in comparisons, and the order is reversed so that
/// `BinaryHeap`, a max-heap, pops the entry with the smallest distance first.
#[derive(Debug, Copy, Clone)]
struct Item<T>(f64, T);

impl<T> PartialEq for Item<T> {
    fn eq(&self, other: &Self) -> bool {
        self.0.eq(&other.0)
    }
}

impl<T> Eq for Item<T> {}

impl<T> PartialOrd for Item<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Comparing with the operands swapped reverses the natural order of distances.
        other.0.partial_cmp(&self.0)
    }
}

impl<T> Ord for Item<T> {
    /// # Panics
    /// If either distance is NaN.
    fn cmp(&self, other: &Self) -> Ordering {
        self.partial_cmp(other).unwrap()
    }
}
269
270#[derive(Debug)]
276pub struct Iter<'tree, 'query, P, V> {
277 point: &'query P,
278 balls: &'query mut BinaryHeap<Item<&'tree BallTreeInner<P, V>>>,
279 i: usize,
280 max_radius: f64,
281}
282
283impl<'tree, 'query, P: Point, V> Iterator for Iter<'tree, 'query, P, V> {
284 type Item = (&'tree P, f64, &'tree V);
285
286 fn next(&mut self) -> Option<Self::Item> {
287 while self.balls.len() > 0 {
288 if let Item(d, BallTreeInner::Leaf(p, vs)) = self.balls.peek().unwrap() {
291 if self.i < vs.len() && *d <= self.max_radius {
292 self.i += 1;
293 return Some((p, *d, &vs[self.i - 1]));
294 }
295 }
296 self.i = 0;
298 if let Item(_, BallTreeInner::Branch { a, b, .. }) = self.balls.pop().unwrap() {
300 let d_a = a.nearest_distance(self.point);
301 let d_b = b.nearest_distance(self.point);
302 if d_a <= self.max_radius {
303 self.balls.push(Item(d_a, a));
304 }
305 if d_b <= self.max_radius {
306 self.balls.push(Item(d_b, b));
307 }
308 }
309 }
310 None
311 }
312}
313
314#[derive(Debug, Clone)]
356pub struct BallTree<P, V>(BallTreeInner<P, V>);
357
358impl<P: Point, V> Default for BallTree<P, V> {
359 fn default() -> Self {
360 BallTree(BallTreeInner::default())
361 }
362}
363
364impl<P: Point, V> BallTree<P, V> {
365 pub fn new(points: Vec<P>, values: Vec<V>) -> Self {
370 BallTree(BallTreeInner::new(points, values))
371 }
372
373 pub fn query(&self) -> Query<P, V> {
376 Query {
377 ball_tree: self,
378 balls: Default::default(),
379 }
380 }
381}
382
383#[derive(Debug, Clone)]
385pub struct Query<'tree, P, V> {
386 ball_tree: &'tree BallTree<P, V>,
387 balls: BinaryHeap<Item<&'tree BallTreeInner<P, V>>>,
388}
389
390impl<'tree, P: Point, V> Query<'tree, P, V> {
391 pub fn nn<'query>(
396 &'query mut self,
397 point: &'query P,
398 ) -> Iter<'tree, 'query, P, V> {
399 self.nn_within(point, f64::INFINITY)
400 }
401
402 pub fn nn_within<'query>(
404 &'query mut self,
405 point: &'query P,
406 max_radius: f64,
407 ) -> Iter<'tree, 'query, P, V> {
408 let balls = &mut self.balls;
409 balls.clear();
410 balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
411 Iter {
412 point,
413 balls,
414 i: 0,
415 max_radius,
416 }
417 }
418
419 pub fn min_radius<'query>(&'query mut self, point: &'query P, k: usize) -> f64 {
421 let mut total_count = 0;
422 let balls = &mut self.balls;
423 balls.clear();
424 balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
425
426 while let Some(Item(distance, node)) = balls.pop() {
427 match node {
428 BallTreeInner::Empty => {}
429 BallTreeInner::Leaf(_, vs) => {
430 total_count += vs.len();
431 if total_count >= k {
432 return distance;
433 }
434 }
435 BallTreeInner::Branch { sphere, a, b, count } => {
436 let next_distance = balls.peek().map(|Item(d, _)| *d).unwrap_or(f64::INFINITY);
437 if total_count + count < k && sphere.farthest_distance(point) < next_distance {
438 total_count += count;
439 } else {
440 balls.push(Item(a.nearest_distance(point), &a));
441 balls.push(Item(b.nearest_distance(point), &b));
442 }
443 }
444 }
445 }
446
447 f64::INFINITY
448 }
449
450 pub fn count<'query>(&'query mut self, point: &'query P, max_radius: f64) -> usize {
452 let mut total = 0;
453 let balls = &mut self.balls;
454 balls.clear();
455 balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
456
457 while let Some(Item(nearest_distance, node)) = balls.pop() {
458 if nearest_distance > max_radius {
459 break;
460 }
461 match node {
462 BallTreeInner::Empty => {}
463 BallTreeInner::Leaf(_, vs) => {
464 total += vs.len();
465 }
466 BallTreeInner::Branch { a, b, count, sphere} => {
467 let next_distance = balls.peek().map(|Item(d, _)| *d).unwrap_or(f64::INFINITY).min(max_radius);
468 if sphere.farthest_distance(point) < next_distance {
469 total += count;
470 } else {
471 balls.push(Item(a.nearest_distance(point), &a));
472 balls.push(Item(b.nearest_distance(point), &b));
473 }
474 }
475 }
476 }
477
478 total
479 }
480
481 pub fn allocated_size(&self) -> usize {
483 self.balls.capacity() * std::mem::size_of::<Item<&'tree BallTreeInner<P, V>>>()
484 }
485
486 pub fn deallocate_memory(&mut self) {
488 self.balls.clear();
489 self.balls.shrink_to_fit();
490 }
491}
492
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaChaRng;
    use std::collections::HashSet;

    /// Randomised cross-check of `nn_within`, `count` and `min_radius` against a
    /// brute-force scan over the same points.
    ///
    /// The RNG is seeded, so the generated data is reproducible. The order of the `rng`
    /// calls below determines that data, so do not reorder them.
    #[test]
    fn test_3d_points() {
        let mut rng: ChaChaRng = SeedableRng::seed_from_u64(0xcb42c94d23346e96);

        // One coordinate drawn uniformly from [-100, 100] using the `rng` above.
        macro_rules! random_small_f64 {
            () => {
                rng.gen_range(-100.0 ..= 100.0)
            };
        }

        // A 3-D point made of three coordinate draws.
        macro_rules! random_3d_point {
            () => {
                [
                    random_small_f64!(),
                    random_small_f64!(),
                    random_small_f64!(),
                ]
            };
        }

        for i in 0..1000 {
            // Tree size grows with the round: 1-3 points for the first 100 rounds, up to
            // 10 until round 500, then up to 100.
            let point_count: usize = if i < 100 {
                rng.gen_range(1..=3)
            } else if i < 500 {
                rng.gen_range(1..=10)
            } else {
                rng.gen_range(1..=100)
            };

            let mut points = vec![];
            let mut values = vec![];

            for _ in 0..point_count {
                let point = random_3d_point!();
                let value = rng.gen::<u64>();
                points.push(point);
                values.push(value);
            }

            let tree = BallTree::new(points.clone(), values.clone());

            // A single query handle is reused for every lookup against this tree.
            let mut query = tree.query();

            for _ in 0..100 {
                let point = random_3d_point!();
                let max_radius = rng.gen_range(0.0 ..= 110.0);

                // Brute-force reference: every value whose point lies within `max_radius`.
                let expected_values = points
                    .iter()
                    .zip(&values)
                    .filter(|(p, _)| p.distance(&point) <= max_radius)
                    .map(|(_, v)| v)
                    .cloned()
                    .collect::<HashSet<_>>();

                let mut found_values = HashSet::new();

                // `nn_within` must report exact distances, in non-decreasing order, and
                // never beyond `max_radius`.
                let mut previous_d = 0.0;
                for (p, d, v) in query.nn_within(&point, max_radius) {
                    assert_eq!(point.distance(p), d);
                    assert!(d >= previous_d);
                    assert!(d <= max_radius);
                    previous_d = d;
                    found_values.insert(*v);
                }

                assert_eq!(expected_values, found_values);

                assert_eq!(found_values.len(), query.count(&point, max_radius));

                // `radius` should be the smallest radius holding all expected values, so a
                // radius 1% smaller must hold fewer of them.
                let radius = query.min_radius(&point, expected_values.len());

                let should_be_fewer = query.count(&point, radius * 0.99);

                assert!(expected_values.is_empty() || should_be_fewer < expected_values.len(), "{} < {}", should_be_fewer, expected_values.len());
            }

            // The scratch heap must have been used, and its size must stay bounded
            // relative to the tree size. NOTE(review): the `2 * 8` factor presumably
            // reflects a 16-byte `Item` on 64-bit targets; TODO confirm.
            assert!(query.allocated_size() > 0);
            assert!(query.allocated_size() <= 2 * 8 * point_count.next_power_of_two().max(4));

            query.deallocate_memory();
            assert_eq!(query.allocated_size(), 0);
        }
    }

    /// Spot checks of `distance` and `move_towards` for the `[f64; D]` implementation.
    #[test]
    fn test_point_array_impls() {
        assert_eq!([5.0].distance(&[7.0]), 2.0);
        assert_eq!([5.0].move_towards(&[3.0], 1.0), [4.0]);

        assert_eq!([5.0, 3.0].distance(&[7.0, 5.0]), 2.0 * 2f64.sqrt());
        assert_eq!(
            [5.0, 3.0].move_towards(&[3.0, 1.0], 2f64.sqrt()),
            [4.0, 2.0]
        );

        // `move_towards` does not stop at `other`: moving 8.0 towards a point that is 4.0
        // away overshoots it.
        assert_eq!([0.0, 0.0, 0.0, 0.0].distance(&[2.0, 2.0, 2.0, 2.0]), 4.0);
        assert_eq!(
            [0.0, 0.0, 0.0, 0.0].move_towards(&[2.0, 2.0, 2.0, 2.0], 8.0),
            [4.0, 4.0, 4.0, 4.0]
        );
    }
}