use crate::distance::{self, Euclidean, Metric};
use crate::ArrayError;
use ndarray::{Array1, ArrayBase, ArrayView1, CowArray, Data, Ix1, Ix2};
use num_traits::{Float, FromPrimitive, Zero};
use ordered_float::OrderedFloat;
use std::cmp;
use std::collections::BinaryHeap;
use std::num::NonZeroUsize;
use std::ops::{AddAssign, DivAssign, Range};
/// A ball tree over the rows of a two-dimensional array of points.
///
/// The tree is stored implicitly: `nodes` is a complete binary tree laid out in an array, so the
/// children of node `n` are nodes `2 * n + 1` and `2 * n + 2`.
pub struct BallTree<'a, A, M>
where
A: Float,
M: Metric<A>,
{
/// The indexed points, one point per row; either borrowed or owned.
pub points: CowArray<'a, A, Ix2>,
/// A permutation of the row indices of `points`; each node covers one contiguous range of it.
pub idx: Vec<usize>,
/// The nodes of the tree, in implicit binary-tree (array) order with the root at index 0.
pub nodes: Vec<Node<A>>,
/// The metric used to measure distances between points.
pub metric: M,
}
impl<'a, A, M> BallTree<'a, A, M>
where
    A: Float + Zero + AddAssign + DivAssign + FromPrimitive,
    M: Metric<A>,
{
    /// Builds a ball tree over the rows of `points`, measuring distances with `metric`.
    ///
    /// The nodes form an implicit complete binary tree whose number of levels is the bit length
    /// of the number of points; the children of node `n` are nodes `2 * n + 1` and `2 * n + 2`.
    ///
    /// # Errors
    ///
    /// * `ArrayError::Empty` if `points` has no rows or no columns.
    /// * `ArrayError::NotContiguous` if a row of `points` is not contiguous in memory.
    pub fn new<T>(points: T, metric: M) -> Result<Self, ArrayError>
    where
        T: Into<CowArray<'a, A, Ix2>>,
    {
        let points = points.into();
        let n_points: usize = points.nrows();
        // A matrix without columns cannot be split along any axis; report it as an error
        // instead of panicking while building the tree.
        if n_points == 0 || points.ncols() == 0 {
            return Err(ArrayError::Empty);
        }
        if !points.row(0).is_standard_layout() {
            return Err(ArrayError::NotContiguous);
        }

        // `height` is the bit length of `n_points`, so the deepest level of the complete tree
        // has at most `n_points` nodes and every leaf receives at least one point.
        let height = usize::BITS - n_points.leading_zeros();
        let size = 1_usize.wrapping_shl(height) - 1;

        let mut idx: Vec<usize> = (0..n_points).collect();
        let mut nodes = vec![Node::default(); size];
        build_subtree(&mut nodes, &mut idx, &points, 0, 0..n_points, &metric);
        Ok(BallTree {
            points,
            idx,
            nodes,
            metric,
        })
    }

    /// Finds the point nearest to `point`.
    ///
    /// Returns the row index of that point in `points` together with its distance to `point`.
    pub fn query_nearest<S>(&self, point: &ArrayBase<S, Ix1>) -> (usize, A)
    where
        S: Data<Elem = A>,
    {
        self.nearest_neighbor_in_subtree(&point.view(), 0, A::infinity())
            .expect("0 is a valid index")
    }

    /// Finds up to `k` points nearest to `point`.
    ///
    /// Returns the row indices of the neighbors and their distances, both ordered from the
    /// nearest to the farthest. Both vectors are empty when `k` is zero.
    pub fn query<S>(&self, point: &ArrayBase<S, Ix1>, k: usize) -> (Vec<usize>, Vec<A>)
    where
        S: Data<Elem = A>,
    {
        let k = if let Some(k) = NonZeroUsize::new(k) {
            k
        } else {
            return (Vec::new(), Vec::new());
        };
        let mut neighbors = BinaryHeap::with_capacity(k.get());
        self.nearest_k_neighbors_in_subtree(&point.view(), 0, A::infinity(), k, &mut neighbors);
        let sorted = neighbors.into_sorted_vec();
        let indices = sorted.iter().map(|v| v.idx).collect();
        let distances = sorted.iter().map(|v| v.distance.into_inner()).collect();
        (indices, distances)
    }

    /// Finds all points whose distance to `point` is at most `distance`.
    ///
    /// Returns the row indices of those points in no particular order.
    pub fn query_radius<S>(&self, point: &ArrayBase<S, Ix1>, distance: A) -> Vec<usize>
    where
        S: Data<Elem = A>,
    {
        self.neighbors_within_radius_in_subtree(&point.view(), distance, 0)
    }

    /// Returns the two children of the internal node `root`, ordered so that the child whose
    /// ball has the smaller lower-bound distance to `point` comes first.
    fn children_by_lower_bound(&self, point: &ArrayView1<A>, root: usize) -> (usize, usize) {
        let child1 = root * 2 + 1;
        let child2 = child1 + 1;
        let lb1 = self.nodes[child1].distance_lower_bound(point, &self.metric);
        let lb2 = self.nodes[child2].distance_lower_bound(point, &self.metric);
        if lb1 < lb2 {
            (child1, child2)
        } else {
            (child2, child1)
        }
    }

    /// Finds the nearest neighbor of `point` within `radius` in the subtree rooted at `root`.
    ///
    /// Returns `None` if no point of the subtree lies within `radius`.
    fn nearest_neighbor_in_subtree(
        &self,
        point: &ArrayView1<A>,
        root: usize,
        radius: A,
    ) -> Option<(usize, A)> {
        let root_node = &self.nodes[root];
        // The whole ball is farther away than `radius`; nothing inside can qualify.
        if root_node.distance_lower_bound(point, &self.metric) > radius {
            return None;
        }

        if root_node.is_leaf {
            let (min_i, min_dist) = self.idx[root_node.range.clone()].iter().fold(
                (0, A::infinity()),
                |(min_i, min_dist), &i| {
                    let dist = self.metric.distance(point, &self.points.row(i));
                    if dist < min_dist {
                        (i, dist)
                    } else {
                        (min_i, min_dist)
                    }
                },
            );
            if min_dist <= radius {
                Some((min_i, min_dist))
            } else {
                None
            }
        } else {
            // Search the more promising child first so that the other one can be searched with
            // a smaller radius.
            let (near, far) = self.children_by_lower_bound(point, root);
            match self.nearest_neighbor_in_subtree(point, near, radius) {
                Some(neighbor) => self
                    .nearest_neighbor_in_subtree(point, far, neighbor.1)
                    .or(Some(neighbor)),
                None => self.nearest_neighbor_in_subtree(point, far, radius),
            }
        }
    }

    /// Collects into `neighbors` the (up to) `k` points nearest to `point` that lie strictly
    /// within `radius`, considering only the subtree rooted at `root`.
    ///
    /// `neighbors` is a max-heap on distance, so its top is the farthest candidate kept so far.
    fn nearest_k_neighbors_in_subtree(
        &self,
        point: &ArrayView1<A>,
        root: usize,
        radius: A,
        k: NonZeroUsize,
        neighbors: &mut BinaryHeap<Neighbor<A>>,
    ) {
        // Once `k` candidates are held, only a strictly closer point can replace one of them,
        // so the search radius shrinks to the current k-th distance. Without this, no subtree
        // would ever be pruned.
        let radius = match neighbors.peek() {
            Some(worst) if neighbors.len() >= k.get() => {
                A::min(radius, worst.distance.into_inner())
            }
            _ => radius,
        };

        let root_node = &self.nodes[root];
        if root_node.distance_lower_bound(point, &self.metric) > radius {
            return;
        }

        if root_node.is_leaf {
            self.idx[root_node.range.clone()]
                .iter()
                .filter_map(|&i| {
                    let dist = self.metric.distance(point, &self.points.row(i));
                    if dist < radius {
                        Some(Neighbor::new(i, dist))
                    } else {
                        None
                    }
                })
                .fold(neighbors, |neighbors, n| {
                    if neighbors.len() < k.get() {
                        neighbors.push(n);
                    } else if n < *neighbors.peek().expect("not empty") {
                        neighbors.pop();
                        neighbors.push(n);
                    }
                    neighbors
                });
        } else {
            // Visiting the nearer child first fills the heap sooner, which tightens the radius
            // used for the farther child.
            let (near, far) = self.children_by_lower_bound(point, root);
            self.nearest_k_neighbors_in_subtree(point, near, radius, k, neighbors);
            self.nearest_k_neighbors_in_subtree(point, far, radius, k, neighbors);
        }
    }

    /// Collects the row indices of all points in the subtree rooted at `root` whose distance to
    /// `point` is at most `radius`.
    fn neighbors_within_radius_in_subtree(
        &self,
        point: &ArrayView1<A>,
        radius: A,
        root: usize,
    ) -> Vec<usize> {
        let mut neighbors = Vec::new();
        let mut subtrees_to_visit = vec![root];

        while let Some(subroot) = subtrees_to_visit.pop() {
            let node = &self.nodes[subroot];
            let (lb, ub) = node.distance_bounds(point, &self.metric);
            if lb > radius {
                // The ball lies entirely outside the query radius.
                continue;
            }

            let members = &self.idx[node.range.clone()];
            if ub <= radius {
                // The ball lies entirely inside the query radius; take every point.
                neighbors.extend_from_slice(members);
            } else if node.is_leaf {
                // Inclusive comparison, matching the whole-ball test above, so that a point
                // exactly at `radius` is reported regardless of the shape of the tree.
                neighbors.extend(members.iter().copied().filter(|&i| {
                    self.metric.distance(point, &self.points.row(i)) <= radius
                }));
            } else {
                subtrees_to_visit.push(subroot * 2 + 1);
                subtrees_to_visit.push(subroot * 2 + 2);
            }
        }

        neighbors
    }

    /// Returns a lower bound on the distance between any point of node `n1` and any point of
    /// node `n2`: the distance between the centroids minus both radii, clamped at zero.
    ///
    /// # Panics
    ///
    /// Panics if `n1` or `n2` is not a valid node index.
    #[inline]
    pub fn node_distance_lower_bound(&self, n1: usize, n2: usize) -> A {
        assert!(n1 < self.nodes.len() && n2 < self.nodes.len());
        let n1 = &self.nodes[n1];
        let n2 = &self.nodes[n2];
        let lb = self
            .metric
            .distance(&n1.centroid.view(), &n2.centroid.view())
            - n1.radius
            - n2.radius;
        if lb < A::zero() {
            A::zero()
        } else {
            lb
        }
    }

    /// Returns the indices of the two children of node `n`, or `None` if `n` is a leaf.
    ///
    /// # Panics
    ///
    /// Panics if `n` is not a valid node index.
    #[inline]
    pub fn children_of(&self, n: usize) -> Option<(usize, usize)> {
        if self.nodes[n].is_leaf {
            None
        } else {
            let left = 2 * n + 1;
            let right = left + 1;
            Some((left, right))
        }
    }

    /// Returns the row indices of the points covered by node `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is not a valid node index.
    #[inline]
    pub fn points_of(&self, n: usize) -> &[usize] {
        &self.idx[self.nodes[n].range.clone()]
    }

    /// Returns the radius of node `n`.
    ///
    /// # Panics
    ///
    /// Panics if `n` is not a valid node index.
    #[inline]
    pub fn radius_of(&self, n: usize) -> A {
        self.nodes[n].radius
    }

    /// Compares the radii of nodes `x` and `y`; `None` if the radii cannot be ordered.
    ///
    /// # Panics
    ///
    /// Panics if `x` or `y` is not a valid node index.
    #[inline]
    pub fn compare_nodes(&self, x: usize, y: usize) -> Option<std::cmp::Ordering> {
        self.nodes[x].radius.partial_cmp(&self.nodes[y].radius)
    }

    /// Returns the number of nodes in the tree.
    #[inline]
    pub fn num_nodes(&self) -> usize {
        self.nodes.len()
    }

    /// Returns the number of points in the tree.
    #[inline]
    pub fn num_points(&self) -> usize {
        self.points.nrows()
    }
}
impl<'a, A> BallTree<'a, A, Euclidean>
where
A: Float + Zero + AddAssign + DivAssign + FromPrimitive,
{
pub fn euclidean<T>(points: T) -> Result<BallTree<'a, A, Euclidean>, ArrayError>
where
A: Float + Zero + AddAssign + DivAssign + FromPrimitive,
T: Into<CowArray<'a, A, Ix2>>,
{
BallTree::<'a, A, Euclidean>::new(points, distance::Euclidean::default())
}
}
/// A candidate neighbor: the index of a point together with its distance to the query point.
#[derive(Clone, Debug)]
struct Neighbor<A>
where
A: Float,
{
/// Index of the point.
pub idx: usize,
/// Distance to the query point, wrapped so that it has a total order.
pub distance: OrderedFloat<A>,
}
impl<A> Neighbor<A>
where
A: Float,
{
#[must_use]
pub fn new(idx: usize, distance: A) -> Self {
Self {
idx,
distance: distance.into(),
}
}
}
impl<A> Ord for Neighbor<A>
where
A: Float,
{
#[must_use]
fn cmp(&self, other: &Self) -> cmp::Ordering {
self.distance.cmp(&other.distance)
}
}
impl<A> PartialOrd for Neighbor<A>
where
A: Float,
{
#[must_use]
fn partial_cmp(&self, other: &Self) -> Option<cmp::Ordering> {
self.distance.partial_cmp(&other.distance)
}
}
impl<A> PartialEq for Neighbor<A>
where
A: Float,
{
#[must_use]
fn eq(&self, other: &Self) -> bool {
self.distance == other.distance
}
}
// Marker impl: equality compares only the `OrderedFloat` distance, which has a total order.
impl<A> Eq for Neighbor<A> where A: Float {}
/// A node of the ball tree: a ball that encloses a contiguous range of the tree's index array.
#[derive(Clone, Debug)]
pub struct Node<A> {
// Range of positions in the tree's index array covered by this node.
range: Range<usize>,
// Mean of the points covered by this node.
centroid: Array1<A>,
// Largest distance from the centroid to a point covered by this node.
radius: A,
// Whether this node has no children.
is_leaf: bool,
}
impl<A> Node<A>
where
    A: Float + Zero + AddAssign + DivAssign + FromPrimitive,
{
    /// Sets the centroid of this node to the mean of the rows of `points` selected by `idx`,
    /// and its radius to the largest distance from that centroid to any selected row.
    ///
    /// # Panics
    ///
    /// Panics if the number of selected rows cannot be converted to `A`.
    // NOTE(review): no numeric cast is visible in this body; the `allow` may be stale.
    #[allow(clippy::cast_precision_loss)]
    fn init(&mut self, points: &CowArray<A, Ix2>, idx: &[usize], metric: &dyn Metric<A>) {
        // Accumulate the coordinate-wise sum of the selected rows, in the order given by `idx`.
        let mut centroid = Array1::<A>::zeros(points.ncols());
        for &i in idx {
            for (acc, &coord) in centroid.iter_mut().zip(points.row(i)) {
                *acc += coord;
            }
        }

        // Turn the sum into a mean.
        let count = A::from_usize(idx.len()).expect("approximation");
        for coord in centroid.iter_mut() {
            *coord /= count;
        }
        self.centroid = centroid;

        // The radius is the distance to the farthest selected row.
        let mut radius = A::zero();
        for &i in idx {
            radius = A::max(metric.distance(&self.centroid.view(), &points.row(i)), radius);
        }
        self.radius = radius;
    }

    /// Returns a lower and an upper bound on the distance from `point` to any point in this
    /// node: the distance to the centroid minus (clamped at zero) and plus the radius.
    fn distance_bounds(&self, point: &ArrayView1<A>, metric: &dyn Metric<A>) -> (A, A) {
        let to_centroid = metric.distance(point, &self.centroid.view());
        let gap = to_centroid - self.radius;
        let lower = if gap < A::zero() { A::zero() } else { gap };
        let upper = to_centroid + self.radius;
        (lower, upper)
    }

    /// Returns a lower bound on the distance from `point` to any point in this node.
    fn distance_lower_bound(&self, point: &ArrayView1<A>, metric: &dyn Metric<A>) -> A {
        let (lower, _) = self.distance_bounds(point, metric);
        lower
    }
}
impl<A> Default for Node<A>
where
A: Float + Zero,
{
#[allow(clippy::reversed_empty_ranges)] fn default() -> Self {
Self {
range: (0..0),
centroid: Array1::<A>::zeros(0),
radius: A::zero(),
is_leaf: false,
}
}
}
/// Builds the subtree rooted at node `root` so that it covers `idx[range]`.
///
/// The root node is initialized from the rows of `points` selected by `idx[range]`. Unless the
/// node has no children in `nodes`, that part of `idx` is then split in halves around the median
/// of the column with the largest spread, and each half is built recursively as a child.
///
/// # Panics
///
/// Panics if `root` is not a valid index into `nodes` or `range` is not a valid range of `idx`.
fn build_subtree<A, M>(
    nodes: &mut [Node<A>],
    idx: &mut [usize],
    points: &CowArray<A, Ix2>,
    root: usize,
    range: Range<usize>,
    metric: &M,
) where
    A: Float + AddAssign + DivAssign + FromPrimitive,
    M: Metric<A>,
{
    let n_nodes = nodes.len();
    let root_node = nodes.get_mut(root).expect("root node index out of range");
    root_node.init(
        points,
        idx.get(range.clone()).expect("invalid subtree range"),
        metric,
    );
    root_node.range = range.clone();

    // A node whose children would fall outside the node array is a leaf.
    let left = root * 2 + 1;
    if left >= n_nodes {
        root_node.is_leaf = true;
        return;
    }

    // Split along the column in which the points of this node are spread the widest.
    let members = &mut idx[range.clone()];
    let col_idx = max_spread_column(points, members);
    debug_assert!(col_idx < points.ncols());
    halve_node_indices(members, &points.column(col_idx));

    let mid = (range.start + range.end) / 2;
    build_subtree(nodes, idx, points, left, range.start..mid, metric);
    build_subtree(nodes, idx, points, left + 1, mid..range.end, metric);
}
fn halve_node_indices<A>(idx: &mut [usize], col: &ArrayView1<A>)
where
A: Float,
{
let (mut first, mut last) = (0, idx.len() - 1);
let mid = idx.len() / 2;
loop {
let mut cur = first;
for i in first..last {
if col[idx[i]] < col[idx[last]] {
idx.swap(i, cur);
cur += 1;
}
}
idx.swap(cur, last);
if cur == mid {
break;
}
if cur < mid {
first = cur + 1;
} else {
last = cur - 1;
}
}
}
/// Returns the index of the column of `matrix` with the largest spread (maximum minus minimum)
/// over the rows selected by `idx`. Among equally spread columns the first one wins.
///
/// # Panics
///
/// Panics if `matrix` has no columns, if `idx` is empty, or if `idx` contains a row index that
/// is out of bounds.
fn max_spread_column<A, S>(matrix: &ArrayBase<S, Ix2>, idx: &[usize]) -> usize
where
    A: Float,
    S: Data<Elem = A>,
{
    let mut widest: Option<(usize, A)> = None;
    for (col_idx, col) in matrix.columns().into_iter().enumerate() {
        // Range of the values of this column over the selected rows.
        let seed = col[idx[0]];
        let (mut lo, mut hi) = (seed, seed);
        for &i in idx.iter().skip(1) {
            lo = A::min(lo, col[i]);
            hi = A::max(hi, col[i]);
        }
        let spread = hi - lo;

        widest = match widest {
            None => Some((col_idx, spread)),
            // Replace only on a strictly greater spread; incomparable values never win.
            Some((_, best)) if spread.partial_cmp(&best) == Some(cmp::Ordering::Greater) => {
                Some((col_idx, spread))
            }
            keep => keep,
        };
    }
    let (max_spread_col, _) = widest.expect("empty matrix");
    max_spread_col
}
#[cfg(test)]
mod test {
    use super::*;
    use crate::distance;
    use approx;
    use ndarray::{arr1, array, aview1, aview2, Array, Axis};
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    // Building a tree from an array without rows fails, so `expect` panics.
    #[test]
    #[should_panic]
    fn ball_tree_empty() {
        let data: [[f64; 0]; 0] = [];
        let tree = BallTree::euclidean(aview2(&data)).expect("`data` should not be empty");
        let point = aview1(&[0., 0.]);
        tree.query_nearest(&point);
    }

    // An array whose rows are not contiguous is rejected, so `expect` panics.
    #[test]
    #[should_panic]
    fn ball_tree_column_base() {
        let array = array![[1., 1.], [1., 1.1], [9., 9.]];
        let fortran = array.reversed_axes();
        let _ = BallTree::euclidean(fortran).expect("`array` should not be empty");
    }

    #[test]
    fn ball_tree_metric() {
        let array = array![[1., 1.], [1., 1.1], [9., 9.]];
        let tree = BallTree::new(array.clone(), Euclidean::default())
            .expect("`array` should not be empty");
        let tree1 = BallTree::euclidean(array).expect("`array` should not be empty");
        assert_eq!(tree.metric, tree1.metric);
    }

    #[test]
    fn ball_tree_3() {
        let array = array![[1., 1.], [1., 1.1], [9., 9.]];
        let tree = BallTree::euclidean(array).expect("`array` should not be empty");

        let point = aview1(&[0., 0.]);
        let neighbor = tree.query_nearest(&point);
        assert_eq!(neighbor.0, 0);
        assert!(approx::abs_diff_eq!(neighbor.1, 2_f64.sqrt()));
        let (indices, distances) = tree.query(&point, 0);
        assert!(indices.is_empty());
        assert!(distances.is_empty());
        let (indices, distances) = tree.query(&point, 1);
        assert_eq!(indices.len(), 1);
        assert_eq!(distances.len(), 1);
        assert_eq!(indices[0], neighbor.0);
        assert!(approx::abs_diff_eq!(distances[0], neighbor.1));
        let mut neighbors = tree.query_radius(&point, 2.);
        neighbors.sort_unstable();
        assert_eq!(neighbors, &[0, 1]);

        let neighbors = tree.nearest_neighbor_in_subtree(&aview1(&[20., 20.]), 0, 1.);
        assert_eq!(neighbors, None);
        let neighbors = tree.query_radius(&aview1(&[20., 20.]), 1.);
        assert_eq!(neighbors, &[]);

        let point = aview1(&[1.1, 1.2]);
        let neighbor = tree.query_nearest(&point);
        assert_eq!(neighbor.0, 1);
        assert!(approx::abs_diff_eq!(
            neighbor.1,
            (2f64 * 0.1_f64 * 0.1_f64).sqrt()
        ));
        let (indices, distances) = tree.query(&point, 1);
        assert_eq!(indices.len(), 1);
        assert_eq!(distances.len(), 1);
        assert_eq!(indices[0], neighbor.0);
        assert!(approx::abs_diff_eq!(distances[0], neighbor.1));

        let point = aview1(&[7., 7.]);
        let neighbor = tree.query_nearest(&point);
        assert_eq!(neighbor.0, 2);
        assert!(approx::abs_diff_eq!(neighbor.1, 8_f64.sqrt()));
        let (indices, distances) = tree.query(&point, 1);
        assert_eq!(indices.len(), 1);
        assert_eq!(distances.len(), 1);
        assert_eq!(indices[0], neighbor.0);
        assert!(approx::abs_diff_eq!(distances[0], neighbor.1));
    }

    #[test]
    fn ball_tree_6() {
        let array = array![
            [1.0, 2.0],
            [1.1, 2.2],
            [0.9, 1.9],
            [1.0, 2.1],
            [-2.0, 3.0],
            [-2.2, 3.1],
        ];
        let tree = BallTree::euclidean(array).expect("`array` should not be empty");
        let point = aview1(&[1., 2.]);
        let neighbor = tree.query_nearest(&point);
        assert_eq!(neighbor.0, 0);
        assert!(approx::abs_diff_eq!(neighbor.1, 0_f64.sqrt()));
    }

    #[test]
    fn ball_tree_identical_points() {
        let array = array![
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
            [1.0, 1.0],
        ];
        let tree = BallTree::new(array, distance::Euclidean::default())
            .expect("`array` should not be empty");

        let point = aview1(&[1., 2.]);
        let neighbor = tree.query_nearest(&point);
        assert!(approx::abs_diff_eq!(neighbor.1, 1_f64.sqrt()));

        let point = aview1(&[1., 1.]);
        let neighbor = tree.query_nearest(&point);
        assert!(approx::abs_diff_eq!(neighbor.1, 0_f64.sqrt()));
    }

    // Compares the tree's k-nearest-neighbor distances with a brute-force search.
    #[test]
    fn ball_tree_query() {
        const DIMENSION: usize = 3;

        let array = Array::random((40, DIMENSION), Uniform::new(0., 1.));
        let bt = BallTree::euclidean(array.view()).expect("`array` should not be empty");
        let euclidean = distance::Euclidean::default();
        for _ in 0..10 {
            let query = Array::random(DIMENSION, Uniform::new(0., 1.));
            let (_, bt_distances) = bt.query(&query, 5);
            let naive_neighbors = naive_k_nearest_neighbors(&array, &query.view(), 5, &euclidean);
            // `zip` stops at the shorter sequence, so check the lengths explicitly first.
            assert_eq!(bt_distances.len(), naive_neighbors.len());
            for (bt_dist, naive_neighbor) in bt_distances.iter().zip(naive_neighbors.iter()) {
                assert!(approx::abs_diff_eq!(
                    *bt_dist,
                    naive_neighbor.distance.into_inner()
                ));
            }
        }
    }

    #[test]
    fn ball_tree_query_radius() {
        let array = array![[0.], [2.], [3.], [4.], [6.], [8.], [10.]];
        let bt = BallTree::new(array, distance::Euclidean::default())
            .expect("`array` should not be empty");

        let neighbors = bt.query_radius(&aview1(&[0.1]), 1.);
        assert_eq!(neighbors, &[0]);

        let mut neighbors = bt.query_radius(&aview1(&[3.2]), 1.);
        neighbors.sort_unstable();
        assert_eq!(neighbors, &[2, 3]);

        let neighbors = bt.query_radius(&aview1(&[9.]), 0.9);
        assert!(neighbors.is_empty());
    }

    #[test]
    fn node_init() {
        let array = array![[0., 1.], [0., 9.], [0., 2.]];
        let idx: [usize; 3] = [0, 1, 2];
        let mut node = Node::default();
        let metric = distance::Euclidean::default();
        node.init(&array.view().into(), &idx, &metric);
        assert_eq!(node.centroid, arr1(&[0., 4.]));
        assert_eq!(node.radius, 5.);

        let idx: [usize; 2] = [0, 2];
        node.init(&array.into(), &idx, &metric);
        assert_eq!(node.centroid, arr1(&[0., 1.5]));
    }

    #[test]
    #[should_panic]
    fn halve_node_indices_empty() {
        let col: [f64; 0] = [];
        let mut idx: [usize; 0] = [];
        halve_node_indices(&mut idx, &aview1(&col));
    }

    #[test]
    fn halve_node_indices_one() {
        let col = [1.];
        let mut idx = [0];
        halve_node_indices(&mut idx, &aview1(&col));
        assert_eq!(idx, [0]);
    }

    // The invariant is about column values, so compare those rather than the raw indices.
    #[test]
    fn halve_node_indices_odd() {
        let col = [1., 2., 3., 4., 5.];
        let mut idx = [0, 1, 4, 3, 2];
        halve_node_indices(&mut idx, &aview1(&col));
        assert!(col[idx[0]] < col[idx[2]]);
        assert!(col[idx[1]] < col[idx[2]]);
        assert!(col[idx[2]] <= col[idx[3]]);
        assert!(col[idx[2]] <= col[idx[4]]);
    }

    // The invariant is about column values, so compare those rather than the raw indices.
    #[test]
    fn halve_node_indices_even() {
        let col = [1., 2., 3., 4.];
        let mut idx = [3, 2, 1, 0];
        halve_node_indices(&mut idx, &aview1(&col));
        assert!(col[idx[0]] < col[idx[2]]);
        assert!(col[idx[1]] < col[idx[2]]);
        assert!(col[idx[2]] <= col[idx[3]]);
    }

    #[test]
    #[should_panic]
    fn max_spread_column_empty_idx() {
        let data = [[0., 1.], [0., 9.], [0., 2.]];
        let idx: [usize; 0] = [];
        super::max_spread_column(&aview2(&data), &idx);
    }

    #[test]
    #[should_panic]
    fn max_spread_column_idx_out_of_bound() {
        let data = [[0., 1.], [0., 9.], [0., 2.]];
        let idx: [usize; 3] = [0, 4, 2];
        super::max_spread_column(&aview2(&data), &idx);
    }

    #[test]
    #[should_panic]
    fn max_spread_column_empty_matrix() {
        let data: [[f64; 0]; 0] = [];
        let idx: [usize; 3] = [0, 1, 2];
        super::max_spread_column(&aview2(&data), &idx);
    }

    #[test]
    fn max_spread_column() {
        let data = [[0., 1.], [0., 9.], [0., 2.]];
        let idx: [usize; 3] = [0, 1, 2];
        assert_eq!(super::max_spread_column(&aview2(&data), &idx), 1);
    }

    /// Finds the `k` nearest neighbors of `point` among the rows of `neighbors` by computing
    /// every distance and sorting; used as the reference for the tree-based search.
    fn naive_k_nearest_neighbors<A, S, M>(
        neighbors: &ArrayBase<S, Ix2>,
        point: &ArrayView1<A>,
        k: usize,
        metric: &M,
    ) -> Vec<Neighbor<A>>
    where
        A: Float,
        S: Data<Elem = A>,
        M: Metric<A>,
    {
        let mut knn = neighbors
            .axis_iter(Axis(0))
            .enumerate()
            .map(|(i, n)| Neighbor {
                idx: i,
                distance: metric.distance(&n, point).into(),
            })
            .collect::<Vec<Neighbor<A>>>();
        knn.sort();
        knn[0..k].to_vec()
    }
}