use std::cmp::Ordering;
use std::collections::BinaryHeap;
/// A point in a metric space that can be indexed by a [`BallTree`].
///
/// The tree prunes its search with sphere bounds (`center distance ± radius`), so `distance` is
/// presumably expected to obey the triangle inequality for query results to be exact.
pub trait Point: Sized + PartialEq {
    /// Distance between `self` and `other`.
    fn distance(&self, other: &Self) -> f64;

    /// Returns the point reached by travelling `d` units from `self` towards `other`.
    fn move_towards(&self, other: &Self, d: f64) -> Self;

    /// The point halfway between `a` and `b`.
    ///
    /// The default travels half of the `a`–`b` distance from `a` towards `b`; implementors may
    /// override it with something cheaper or more exact.
    fn midpoint(a: &Self, b: &Self) -> Self {
        let half_way = 0.5 * a.distance(b);
        a.move_towards(b, half_way)
    }
}
impl<const D: usize> Point for [f64; D] {
    /// Euclidean distance.
    fn distance(&self, other: &Self) -> f64 {
        self.iter()
            .zip(other)
            .map(|(a, b)| (*a - *b).powi(2))
            .sum::<f64>()
            .sqrt()
    }

    /// Moves `d` units along the straight line from `self` towards `other`.
    ///
    /// `d` larger than the distance overshoots `other`; a negative `d` moves away from it.
    /// If the two points are at distance 0 the direction is undefined, so `self` is returned
    /// unchanged.
    fn move_towards(&self, other: &Self, d: f64) -> Self {
        // Computed once: it is both the zero-length guard and the normalisation factor.
        let distance = self.distance(other);
        let mut result = *self;
        if distance == 0.0 {
            return result;
        }
        let scale = d / distance;
        for (r, (s, o)) in result.iter_mut().zip(self.iter().zip(other)) {
            *r += scale * (o - s);
        }
        result
    }

    /// Coordinate-wise average of `a` and `b`.
    fn midpoint(a: &Self, b: &Self) -> Self {
        let mut result = [0.0; D];
        for (r, (x, y)) in result.iter_mut().zip(a.iter().zip(b)) {
            *r = (x + y) / 2.0;
        }
        result
    }
}
impl<const D: usize> Point for [f32; D] {
    /// Euclidean distance, accumulated in `f32` and widened (losslessly) to `f64` at the end.
    fn distance(&self, other: &Self) -> f64 {
        f64::from(
            self.iter()
                .zip(other)
                .map(|(a, b)| (*a - *b).powi(2))
                .sum::<f32>()
                .sqrt(),
        )
    }

    /// Moves `d` units along the straight line from `self` towards `other`.
    ///
    /// `d` larger than the distance overshoots `other`; a negative `d` moves away from it.
    /// If the two points are at distance 0 the direction is undefined, so `self` is returned
    /// unchanged.
    fn move_towards(&self, other: &Self, d: f64) -> Self {
        // Computed once: it is both the zero-length guard and the normalisation factor.
        let distance = self.distance(other);
        let mut result = *self;
        if distance == 0.0 {
            return result;
        }
        // Narrowed once so the per-coordinate update stays in `f32` arithmetic.
        let scale = (d / distance) as f32;
        for (r, (s, o)) in result.iter_mut().zip(self.iter().zip(other)) {
            *r += scale * (o - s);
        }
        result
    }

    /// Coordinate-wise average of `a` and `b`.
    fn midpoint(a: &Self, b: &Self) -> Self {
        let mut result = [0.0; D];
        for (r, (x, y)) in result.iter_mut().zip(a.iter().zip(b)) {
            *r = (x + y) / 2.0;
        }
        result
    }
}
/// An `f64` wrapper with a total order, so distances can be used as keys for `max_by_key`.
#[derive(Debug, Clone, PartialEq, PartialOrd)]
struct OrdF64(f64);

impl OrdF64 {
    /// Wraps `x`.
    ///
    /// # Panics
    ///
    /// Panics if `x` is NaN, the one `f64` value that has no place in a total order.
    fn new(x: f64) -> Self {
        assert!(!x.is_nan());
        Self(x)
    }
}

// Equality is reflexive for every non-NaN `f64`, which holds for values built through `new`.
impl Eq for OrdF64 {}

impl Ord for OrdF64 {
    fn cmp(&self, other: &Self) -> Ordering {
        // `partial_cmp` only yields `None` for NaN, which `new` rejects.
        self.0.partial_cmp(&other.0).unwrap()
    }
}
/// A ball given by its `center` and `radius`, used as the bounding volume of a tree branch.
#[derive(Debug, Copy, Clone, PartialEq)]
struct Sphere<C> {
    center: C,
    radius: f64,
}

impl<C: Point> Sphere<C> {
    /// Distance from `p` to the sphere's surface, clamped to `0.0` when `p` lies inside.
    ///
    /// Under the triangle inequality this is a lower bound on the distance from `p` to any
    /// point the sphere encloses.
    fn nearest_distance(&self, p: &C) -> f64 {
        let to_center = self.center.distance(p);
        (to_center - self.radius).max(0.0)
    }

    /// Distance from `p` to the centre plus the radius.
    ///
    /// Under the triangle inequality this is an upper bound on the distance from `p` to any
    /// point the sphere encloses.
    fn farthest_distance(&self, p: &C) -> f64 {
        let to_center = self.center.distance(p);
        to_center + self.radius
    }
}
/// Computes a sphere that encloses every point in `points`.
///
/// The initial sphere is built on two far-apart points: `a`, the point farthest from
/// `points[0]`, and `b`, the point farthest from `a`. Its centre is their midpoint and its
/// radius is the distance from that centre to `b`.
///
/// While some point lies outside, the loop shifts the centre towards that point until it sits
/// on the boundary, then grows the radius by 1%. The loop only exits once no point lies
/// outside, so the returned sphere encloses the input. It is not necessarily the minimal
/// enclosing sphere.
///
/// # Panics
///
/// Panics if `points` has fewer than two elements, or if a distance computed while choosing
/// `a` and `b` is NaN.
fn bounding_sphere<P: Point>(points: &[P]) -> Sphere<P> {
    assert!(points.len() >= 2);
    let a = points
        .iter()
        .max_by_key(|a| OrdF64::new(points[0].distance(a)))
        .unwrap();
    let b = points
        .iter()
        .max_by_key(|b| OrdF64::new(a.distance(b)))
        .unwrap();
    let mut center: P = P::midpoint(a, b);
    // Floor the radius: a zero radius could never grow through the 1% scaling below.
    let mut radius = center.distance(b).max(f64::EPSILON);
    loop {
        // Find the first point outside the current sphere, keeping its distance to the centre
        // so it does not have to be recomputed.
        let outside = points.iter().find_map(|p| {
            let c_to_p = center.distance(p);
            if c_to_p > radius {
                Some((p, c_to_p))
            } else {
                None
            }
        });
        match outside {
            None => break Sphere { center, radius },
            Some((p, c_to_p)) => {
                center = center.move_towards(p, c_to_p - radius);
                radius *= 1.01;
            }
        }
    }
}
/// Splits `points` and the parallel `values` into two non-empty groups around two pivots.
///
/// The first pivot is the point farthest from `points[0]`; the second is the point farthest
/// from the first pivot. The pivot with the higher index seeds group `a` and the other seeds
/// group `b`. Every remaining point goes to `a` if it is strictly closer to `a`'s pivot,
/// otherwise to `b`.
///
/// # Panics
///
/// Panics if fewer than two points are given, if `points` and `values` differ in length, or
/// if a distance computed during the pivot search is NaN.
fn partition<P: Point, V>(
    mut points: Vec<P>,
    mut values: Vec<V>,
) -> ((Vec<P>, Vec<V>), (Vec<P>, Vec<V>)) {
    assert!(points.len() >= 2);
    assert_eq!(points.len(), values.len());
    let first_pivot = points
        .iter()
        .enumerate()
        .max_by_key(|(_, a)| OrdF64::new(points[0].distance(a)))
        .unwrap()
        .0;
    let mut second_pivot = points
        .iter()
        .enumerate()
        .max_by_key(|(_, b)| OrdF64::new(points[first_pivot].distance(b)))
        .unwrap()
        .0;
    // Points that differ under `PartialEq` can still be at distance 0. For example, the squared
    // difference between `[1e-200]` and `[2e-200]` underflows to zero. When every distance
    // ties, both searches return the same (last) index, and the second `swap_remove` below
    // would then index out of bounds. Any other index works as the second pivot, and it keeps
    // both groups non-empty so the recursive build still shrinks its input.
    if second_pivot == first_pivot {
        second_pivot = if first_pivot == 0 { 1 } else { 0 };
    }
    // Remove the larger index first: `swap_remove` only moves the last element, so the smaller
    // index is unaffected.
    let (a_i, b_i) = (first_pivot.max(second_pivot), first_pivot.min(second_pivot));
    let (mut aps, mut avs) = (vec![points.swap_remove(a_i)], vec![values.swap_remove(a_i)]);
    let (mut bps, mut bvs) = (vec![points.swap_remove(b_i)], vec![values.swap_remove(b_i)]);
    for (p, v) in points.into_iter().zip(values) {
        if aps[0].distance(&p) < bps[0].distance(&p) {
            aps.push(p);
            avs.push(v);
        } else {
            bps.push(p);
            bvs.push(v);
        }
    }
    ((aps, avs), (bps, bvs))
}
/// A node of the ball tree.
#[derive(Debug, Clone)]
enum BallTreeInner<P, V> {
    /// Holds no points (built from empty input, or via `Default`).
    Empty,
    /// A single point together with every value stored at it. It is built when all input
    /// points compare equal under `PartialEq`.
    Leaf(P, Vec<V>),
    /// An internal node built from two or more points that are not all equal.
    Branch {
        /// Encloses every point stored below this node (built by `bounding_sphere`).
        sphere: Sphere<P>,
        /// First half produced by `partition`.
        a: Box<BallTreeInner<P, V>>,
        /// Second half produced by `partition`.
        b: Box<BallTreeInner<P, V>>,
        /// Number of points (equivalently, values) stored in this subtree.
        count: usize,
    },
}
impl<P: Point, V> Default for BallTreeInner<P, V> {
    /// A node holding no points.
    fn default() -> Self {
        Self::Empty
    }
}
impl<P: Point, V> BallTreeInner<P, V> {
    /// Recursively builds a (sub)tree from parallel vectors, where `values[i]` belongs to
    /// `points[i]`.
    ///
    /// * No points: `Empty`.
    /// * All points equal under `PartialEq`: one `Leaf` holding the last point and every value.
    /// * Otherwise: a `Branch` whose sphere encloses all the points. Its children are built
    ///   from the two groups returned by `partition`.
    ///
    /// # Panics
    ///
    /// Panics if `points` and `values` have different lengths.
    fn new(mut points: Vec<P>, values: Vec<V>) -> Self {
        assert_eq!(
            points.len(),
            values.len(),
            "Given two vectors of differing lengths. points: {}, values: {}",
            points.len(),
            values.len()
        );
        if points.is_empty() {
            return Self::Empty;
        }
        let first = &points[0];
        if points.iter().all(|p| p == first) {
            let representative = points.pop().unwrap();
            return Self::Leaf(representative, values);
        }
        let count = points.len();
        let sphere = bounding_sphere(&points);
        let ((a_points, a_values), (b_points, b_values)) = partition(points, values);
        Self::Branch {
            sphere,
            a: Box::new(Self::new(a_points, a_values)),
            b: Box::new(Self::new(b_points, b_values)),
            count,
        }
    }

    /// Lower bound on the distance from `p` to any point stored in this node.
    ///
    /// Infinity for `Empty`, the exact distance for a `Leaf`, and the enclosing-sphere bound
    /// for a `Branch`.
    fn nearest_distance(&self, p: &P) -> f64 {
        match self {
            Self::Empty => f64::INFINITY,
            Self::Leaf(stored, _) => p.distance(stored),
            Self::Branch { sphere, .. } => sphere.nearest_distance(p),
        }
    }
}
/// A heap entry pairing a distance key with a payload.
///
/// Equality and ordering look only at the distance, and the ordering is reversed. As a result,
/// a `BinaryHeap<Item<T>>` pops the entry with the *smallest* distance first.
#[derive(Debug, Copy, Clone)]
struct Item<T>(f64, T);

impl<T> PartialEq for Item<T> {
    fn eq(&self, other: &Self) -> bool {
        other.0 == self.0
    }
}

impl<T> Eq for Item<T> {}

impl<T> PartialOrd for Item<T> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        // Operands swapped on purpose: this turns `BinaryHeap`'s max-heap into a min-heap.
        other.0.partial_cmp(&self.0)
    }
}

impl<T> Ord for Item<T> {
    fn cmp(&self, other: &Self) -> Ordering {
        self.partial_cmp(other).unwrap()
    }
}
/// Iterator over `(point, distance, value)` triples in non-decreasing distance from a query
/// point, created by [`Query::nn`] and [`Query::nn_within`].
///
/// A point that carries several values is yielded once per value.
#[derive(Debug)]
pub struct Iter<'tree, 'query, P, V> {
    /// The query point that distances are measured from.
    point: &'query P,
    /// Nodes still to explore, keyed by their lower-bound distance. The reversed `Item`
    /// ordering makes this a min-heap. It is borrowed from the owning `Query`, so its
    /// allocation is reused across searches.
    balls: &'query mut BinaryHeap<Item<&'tree BallTreeInner<P, V>>>,
    /// How many values have already been yielded from the leaf currently on top of `balls`.
    i: usize,
    /// Nodes and leaves farther than this are never yielded or expanded.
    max_radius: f64,
}
impl<'tree, 'query, P: Point, V> Iterator for Iter<'tree, 'query, P, V> {
    type Item = (&'tree P, f64, &'tree V);

    /// Returns the next `(point, distance, value)` triple, nearest first. Returns `None` once
    /// no node within `max_radius` remains.
    fn next(&mut self) -> Option<Self::Item> {
        while !self.balls.is_empty() {
            // The heap pops the smallest key first. A leaf's key is its exact distance, while a
            // branch's key is only a lower bound, so when a leaf is on top nothing still queued
            // can be closer. Its values are handed out one per call, with `self.i` tracking
            // progress through them.
            if let Some(Item(distance, BallTreeInner::Leaf(leaf_point, values))) =
                self.balls.peek()
            {
                if self.i < values.len() && *distance <= self.max_radius {
                    let value = &values[self.i];
                    self.i += 1;
                    return Some((leaf_point, *distance, value));
                }
            }

            // The top entry is an exhausted or out-of-range leaf, an empty node, or a branch.
            // Discard it; if it is a branch, queue each child that may hold points within
            // range.
            self.i = 0;
            let Item(_, node) = self.balls.pop().unwrap();
            if let BallTreeInner::Branch { a, b, .. } = node {
                let near_a = a.nearest_distance(self.point);
                let near_b = b.nearest_distance(self.point);
                if near_a <= self.max_radius {
                    self.balls.push(Item(near_a, a));
                }
                if near_b <= self.max_radius {
                    self.balls.push(Item(near_b, b));
                }
            }
        }
        None
    }
}
#[derive(Debug, Clone)]
pub struct BallTree<P, V>(BallTreeInner<P, V>);
impl<P: Point, V> Default for BallTree<P, V> {
fn default() -> Self {
BallTree(BallTreeInner::default())
}
}
impl<P: Point, V> BallTree<P, V> {
pub fn new(points: Vec<P>, values: Vec<V>) -> Self {
BallTree(BallTreeInner::new(points, values))
}
pub fn query(&self) -> Query<'_, P, V> {
Query {
ball_tree: self,
balls: Default::default(),
}
}
}
/// Reusable search handle over a [`BallTree`], created by [`BallTree::query`].
///
/// Every search reuses the same scratch heap, so repeated queries avoid reallocating. See
/// [`Query::allocated_size`] and [`Query::deallocate_memory`] for managing that memory.
#[derive(Debug, Clone)]
pub struct Query<'tree, P, V> {
    /// The tree being searched.
    ball_tree: &'tree BallTree<P, V>,
    /// Scratch priority queue of nodes keyed by distance; cleared at the start of each search.
    balls: BinaryHeap<Item<&'tree BallTreeInner<P, V>>>,
}
impl<'tree, P: Point, V> Query<'tree, P, V> {
pub fn nn<'query>(
&'query mut self,
point: &'query P,
) -> Iter<'tree, 'query, P, V> {
self.nn_within(point, f64::INFINITY)
}
pub fn nn_within<'query>(
&'query mut self,
point: &'query P,
max_radius: f64,
) -> Iter<'tree, 'query, P, V> {
let balls = &mut self.balls;
balls.clear();
balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
Iter {
point,
balls,
i: 0,
max_radius,
}
}
pub fn min_radius<'query>(&'query mut self, point: &'query P, k: usize) -> f64 {
let mut total_count = 0;
let balls = &mut self.balls;
balls.clear();
balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
while let Some(Item(distance, node)) = balls.pop() {
match node {
BallTreeInner::Empty => {}
BallTreeInner::Leaf(_, vs) => {
total_count += vs.len();
if total_count >= k {
return distance;
}
}
BallTreeInner::Branch { sphere, a, b, count } => {
let next_distance = balls.peek().map(|Item(d, _)| *d).unwrap_or(f64::INFINITY);
if total_count + count < k && sphere.farthest_distance(point) < next_distance {
total_count += count;
} else {
balls.push(Item(a.nearest_distance(point), &a));
balls.push(Item(b.nearest_distance(point), &b));
}
}
}
}
f64::INFINITY
}
pub fn count<'query>(&'query mut self, point: &'query P, max_radius: f64) -> usize {
let mut total = 0;
let balls = &mut self.balls;
balls.clear();
balls.push(Item(self.ball_tree.0.nearest_distance(point), &self.ball_tree.0));
while let Some(Item(nearest_distance, node)) = balls.pop() {
if nearest_distance > max_radius {
break;
}
match node {
BallTreeInner::Empty => {}
BallTreeInner::Leaf(_, vs) => {
total += vs.len();
}
BallTreeInner::Branch { a, b, count, sphere} => {
let next_distance = balls.peek().map(|Item(d, _)| *d).unwrap_or(f64::INFINITY).min(max_radius);
if sphere.farthest_distance(point) < next_distance {
total += count;
} else {
balls.push(Item(a.nearest_distance(point), &a));
balls.push(Item(b.nearest_distance(point), &b));
}
}
}
}
total
}
pub fn allocated_size(&self) -> usize {
self.balls.capacity() * std::mem::size_of::<Item<&'tree BallTreeInner<P, V>>>()
}
pub fn deallocate_memory(&mut self) {
self.balls.clear();
self.balls.shrink_to_fit();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::{Rng, SeedableRng};
    use rand_chacha::ChaChaRng;
    use std::collections::HashSet;

    /// Randomized cross-check of `nn_within`, `count` and `min_radius` against a brute-force
    /// scan over 3-D `f64` points, plus bounds on the query's scratch-heap memory.
    #[test]
    fn test_3d_points() {
        // Fixed seed so any failure is reproducible.
        let mut rng: ChaChaRng = SeedableRng::seed_from_u64(0xcb42c94d23346e96);
        // Coordinates are drawn uniformly from [-100, 100].
        macro_rules! random_small_f64 {
            () => {
                rng.gen_range(-100.0 ..= 100.0)
            };
        }
        macro_rules! random_3d_point {
            () => {
                [
                    random_small_f64!(),
                    random_small_f64!(),
                    random_small_f64!(),
                ]
            };
        }
        // 1000 random trees. The first 100 have at most 3 points and the next 400 at most 10;
        // the rest have up to 100 points.
        for i in 0..1000 {
            let point_count: usize = if i < 100 {
                rng.gen_range(1..=3)
            } else if i < 500 {
                rng.gen_range(1..=10)
            } else {
                rng.gen_range(1..=100)
            };
            let mut points = vec![];
            let mut values = vec![];
            for _ in 0..point_count {
                let point = random_3d_point!();
                let value = rng.gen::<u64>();
                points.push(point);
                values.push(value);
            }
            let tree = BallTree::new(points.clone(), values.clone());
            // One `Query` is reused for all 100 searches on this tree.
            let mut query = tree.query();
            for _ in 0..100 {
                let point = random_3d_point!();
                let max_radius = rng.gen_range(0.0 ..= 110.0);
                // Brute-force reference. Comparing as sets assumes the random `u64` values are
                // distinct, which holds with overwhelming probability.
                let expected_values = points
                    .iter()
                    .zip(&values)
                    .filter(|(p, _)| p.distance(&point) <= max_radius)
                    .map(|(_, v)| v)
                    .cloned()
                    .collect::<HashSet<_>>();
                let mut found_values = HashSet::new();
                let mut previous_d = 0.0;
                // Every result must report its exact distance, arrive in non-decreasing
                // distance order, and lie within `max_radius`.
                for (p, d, v) in query.nn_within(&point, max_radius) {
                    assert_eq!(point.distance(p), d);
                    assert!(d >= previous_d);
                    assert!(d <= max_radius);
                    previous_d = d;
                    found_values.insert(*v);
                }
                assert_eq!(expected_values, found_values);
                // `count` must agree with the iterator.
                assert_eq!(found_values.len(), query.count(&point, max_radius));
                // `min_radius` must be tight: shrinking it by 1% has to leave fewer than the
                // requested number of values. The check is skipped when nothing was expected.
                let radius = query.min_radius(&point, expected_values.len());
                let should_be_fewer = query.count(&point, radius * 0.99);
                assert!(expected_values.is_empty() || should_be_fewer < expected_values.len(), "{} < {}", should_be_fewer, expected_values.len());
            }
            // The scratch heap must have been used, and its size must stay bounded.
            // NOTE(review): `2 * 8` presumably stands for `size_of::<Item<&_>>()` (an f64 plus a
            // reference, 16 bytes on 64-bit targets) — confirm.
            assert!(query.allocated_size() > 0);
            assert!(query.allocated_size() <= 2 * 8 * point_count.next_power_of_two().max(4));
            // After `deallocate_memory` the heap must own no allocation.
            query.deallocate_memory();
            assert_eq!(query.allocated_size(), 0);
        }
    }

    /// Hand-computed distances and moves for the array `Point` impls in 1, 2 and 4 dimensions.
    #[test]
    fn test_point_array_impls() {
        // NOTE(review): the float literals are untyped; presumably they fall back to `f64` and
        // so exercise the `[f64; D]` impl — confirm.
        assert_eq!([5.0].distance(&[7.0]), 2.0);
        assert_eq!([5.0].move_towards(&[3.0], 1.0), [4.0]);
        assert_eq!([5.0, 3.0].distance(&[7.0, 5.0]), 2.0 * 2f64.sqrt());
        assert_eq!(
            [5.0, 3.0].move_towards(&[3.0, 1.0], 2f64.sqrt()),
            [4.0, 2.0]
        );
        assert_eq!([0.0, 0.0, 0.0, 0.0].distance(&[2.0, 2.0, 2.0, 2.0]), 4.0);
        // Moving farther than the distance overshoots the target.
        assert_eq!(
            [0.0, 0.0, 0.0, 0.0].move_towards(&[2.0, 2.0, 2.0, 2.0], 8.0),
            [4.0, 4.0, 4.0, 4.0]
        );
    }
}