use ordered_float::OrderedFloat;
use scirs2_core::ndarray::Array2;
use scirs2_core::numeric::{Float, FromPrimitive};
use std::cmp::Ordering;
use std::fmt::Debug;
use std::marker::PhantomData;
use crate::error::{InterpolateError, InterpolateResult};
use scirs2_core::ndarray::ArrayView1;
/// One node of the ball tree: a bounding ball plus links to its children.
///
/// The search routines treat a node whose `left` and `right` are both `None`
/// as a leaf and scan its `indices` directly.
#[derive(Debug, Clone)]
struct BallNode<F: Float + ordered_float::FloatCore> {
/// Row indices (into the tree's point matrix) of the points covered by this node.
indices: Vec<usize>,
/// Centroid of the covered points, one coordinate per dimension.
center: Vec<F>,
/// Largest distance from `center` to any covered point.
radius: F,
/// Position of the left child in the tree's node vector, if any.
left: Option<usize>,
/// Position of the right child in the tree's node vector, if any.
right: Option<usize>,
}
/// Ball tree over a set of points stored as the rows of a 2-D array.
///
/// Nodes live in a flat vector and refer to each other by index.
#[derive(Debug, Clone)]
pub struct BallTree<F>
where
F: Float + FromPrimitive + Debug + std::cmp::PartialOrd + ordered_float::FloatCore,
{
/// The indexed points; each row is one point.
points: Array2<F>,
/// Flat storage for every node of the tree.
nodes: Vec<BallNode<F>>,
/// Index of the root node in `nodes`, or `None` if no node was built.
root: Option<usize>,
/// Number of coordinates per point (column count of `points`).
dim: usize,
/// Largest subset stored as a leaf; data sets no larger than this are
/// queried by linear scan instead of tree traversal.
leafsize: usize,
/// Marker for the element type; carries no data.
_phantom: PhantomData<F>,
}
impl<F> BallTree<F>
where
F: Float + FromPrimitive + Debug + std::cmp::PartialOrd + ordered_float::FloatCore,
{
/// Builds a ball tree over `points` (one point per row) with the default
/// leaf size of 10.
///
/// # Errors
///
/// Forwards any error reported by [`BallTree::with_leafsize`].
pub fn new(points: Array2<F>) -> InterpolateResult<Self> {
    const DEFAULT_LEAFSIZE: usize = 10;
    Self::with_leafsize(points, DEFAULT_LEAFSIZE)
}
/// Builds a ball tree over `points` with a caller-chosen leaf size.
///
/// Each row of `points` is one point; the column count is the dimension.
/// Subsets of at most `leafsize` points are stored as leaves. When the whole
/// data set fits in one leaf the tree consists of that single node.
///
/// # Errors
///
/// Returns [`InterpolateError::InvalidValue`] when `points` is empty or when
/// `leafsize` is zero.
pub fn with_leafsize(points: Array2<F>, leafsize: usize) -> InterpolateResult<Self> {
    if points.is_empty() {
        return Err(InterpolateError::InvalidValue(
            "Points array cannot be empty".to_string(),
        ));
    }
    // No non-empty subset can satisfy a zero leaf size, and the capacity
    // estimate below would divide by zero.
    if leafsize == 0 {
        return Err(InterpolateError::InvalidValue(
            "Leaf size must be at least 1".to_string(),
        ));
    }

    let n_points = points.shape()[0];
    let dim = points.shape()[1];

    // A single leaf needs exactly one node; otherwise reserve a rough
    // estimate so the node vector rarely reallocates during the build.
    let est_nodes = if n_points <= leafsize {
        1
    } else {
        (2 * n_points / leafsize).max(16)
    };

    let mut tree = Self {
        points,
        nodes: Vec::with_capacity(est_nodes),
        root: None,
        dim,
        leafsize,
        _phantom: PhantomData,
    };

    // `build_subtree` emits a lone leaf when `n_points <= leafsize`, so both
    // the small and the large case go through the same path.
    let indices: Vec<usize> = (0..n_points).collect();
    tree.root = Some(tree.build_subtree(&indices));
    Ok(tree)
}
/// Recursively builds the subtree covering `indices` and returns the index of
/// its root node in `self.nodes`.
///
/// Every node records the centroid of its points and the largest distance
/// from that centroid to any of them. A subset becomes a leaf when it holds
/// at most `leafsize` points, or when no two distinct seed points can be
/// found (all points coincide along the widest dimension). In the second
/// case splitting could not shrink the subset and would recurse forever.
fn build_subtree(&mut self, indices: &[usize]) -> usize {
    let center = compute_centroid(&self.points, indices);
    let radius = compute_radius(&self.points, indices, &center);

    // Decide how to split before pushing the node so the immutable borrows
    // of `self.points` are finished by then.
    let split = if indices.len() <= self.leafsize {
        None
    } else {
        let (split_dim, _) = find_max_spread_dimension(&self.points, indices);
        let (seed1, seed2) = find_distant_points(&self.points, indices, split_dim);
        if seed1 == seed2 {
            // Degenerate subset: keep it as an oversized leaf.
            None
        } else {
            Some(partition_by_seeds(&self.points, indices, seed1, seed2))
        }
    };

    // The parent is pushed before its children, so nodes are stored in
    // pre-order and `node_idx` stays valid while the children are built.
    let node_idx = self.nodes.len();
    self.nodes.push(BallNode {
        indices: indices.to_vec(),
        center,
        radius,
        left: None,
        right: None,
    });

    if let Some((left_indices, right_indices)) = split {
        let left_idx = self.build_subtree(&left_indices);
        let right_idx = self.build_subtree(&right_indices);
        self.nodes[node_idx].left = Some(left_idx);
        self.nodes[node_idx].right = Some(right_idx);
    }
    node_idx
}
/// Finds the stored point closest to `query`.
///
/// Returns `(row_index, distance)`. Data sets no larger than the leaf size
/// are scanned linearly; larger ones are searched through the tree.
///
/// # Errors
///
/// * [`InterpolateError::DimensionMismatch`] if `query` has the wrong length.
/// * [`InterpolateError::InvalidState`] if the tree has no root.
pub fn nearest_neighbor(&self, query: &[F]) -> InterpolateResult<(usize, F)> {
    if query.len() != self.dim {
        return Err(InterpolateError::DimensionMismatch(format!(
            "Query dimension {} doesn't match Ball Tree dimension {}",
            query.len(),
            self.dim
        )));
    }
    let root = match self.root {
        Some(root) => root,
        None => {
            return Err(InterpolateError::InvalidState(
                "Ball Tree is empty".to_string(),
            ));
        }
    };
    if self.points.shape()[0] <= self.leafsize {
        return self.linear_nearest_neighbor(query);
    }

    // (index, distance) of the best candidate so far; starts unbounded.
    let mut nearest = (0usize, <F as scirs2_core::numeric::Float>::infinity());
    self.search_nearest(root, query, &mut nearest.1, &mut nearest.0);
    Ok(nearest)
}
/// Finds the `k` stored points closest to `query`, nearest first.
///
/// `k` is clamped to the number of stored points; `k == 0` yields an empty
/// vector. Each entry is `(row_index, distance)`.
///
/// # Errors
///
/// * [`InterpolateError::DimensionMismatch`] if `query` has the wrong length.
/// * [`InterpolateError::InvalidState`] if the tree has no root.
pub fn k_nearest_neighbors(&self, query: &[F], k: usize) -> InterpolateResult<Vec<(usize, F)>> {
    if query.len() != self.dim {
        return Err(InterpolateError::DimensionMismatch(format!(
            "Query dimension {} doesn't match Ball Tree dimension {}",
            query.len(),
            self.dim
        )));
    }
    let root = self
        .root
        .ok_or_else(|| InterpolateError::InvalidState("Ball Tree is empty".to_string()))?;

    let n_points = self.points.shape()[0];
    let k = k.min(n_points);
    if k == 0 {
        return Ok(Vec::new());
    }
    if n_points <= self.leafsize {
        return self.linear_k_nearest_neighbors(query, k);
    }

    // Max-heap on distance: the worst of the current best `k` sits on top.
    let mut heap = std::collections::BinaryHeap::with_capacity(k + 1);
    self.search_k_nearest(root, query, k, &mut heap);

    let mut results: Vec<(usize, F)> = heap
        .into_iter()
        .map(|(dist, idx)| (idx, dist.into_inner()))
        .collect();
    results.sort_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap_or(Ordering::Equal));
    Ok(results)
}
/// Collects every stored point whose distance to `query` is at most `radius`.
///
/// The result holds `(row_index, distance)` pairs sorted by ascending distance.
///
/// # Errors
///
/// * [`InterpolateError::DimensionMismatch`] if `query` has the wrong length.
/// * [`InterpolateError::InvalidState`] if the tree has no root.
/// * [`InterpolateError::InvalidValue`] if `radius` is zero or negative.
pub fn points_within_radius(
    &self,
    query: &[F],
    radius: F,
) -> InterpolateResult<Vec<(usize, F)>> {
    if query.len() != self.dim {
        return Err(InterpolateError::DimensionMismatch(format!(
            "Query dimension {} doesn't match Ball Tree dimension {}",
            query.len(),
            self.dim
        )));
    }
    let root = self
        .root
        .ok_or_else(|| InterpolateError::InvalidState("Ball Tree is empty".to_string()))?;
    if radius <= F::zero() {
        return Err(InterpolateError::InvalidValue(
            "Radius must be positive".to_string(),
        ));
    }
    if self.points.shape()[0] <= self.leafsize {
        return self.linear_points_within_radius(query, radius);
    }

    let mut matches = Vec::new();
    self.search_radius(root, query, radius, &mut matches);
    matches.sort_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap_or(Ordering::Equal));
    Ok(matches)
}
/// Recursive nearest-neighbour search below `node_idx`.
///
/// `best_dist` / `best_idx` hold the best candidate so far and are updated in
/// place. A ball is skipped when the distance from `query` to its centre
/// exceeds `radius + best_dist`. The child whose centre is closer to `query`
/// is visited first.
fn search_nearest(
    &self,
    node_idx: usize,
    query: &[F],
    best_dist: &mut F,
    best_idx: &mut usize,
) {
    let node = &self.nodes[node_idx];
    if euclidean_distance(query, &node.center) > node.radius + *best_dist {
        return;
    }

    match (node.left, node.right) {
        // Leaf: compare against every stored point.
        (None, None) => {
            for &candidate in &node.indices {
                let dist = euclidean_distance(query, &self.points.row(candidate).to_vec());
                if dist < *best_dist {
                    *best_dist = dist;
                    *best_idx = candidate;
                }
            }
        }
        // Internal node: both children are expected to exist.
        (left, right) => {
            let left_idx = left.expect("Operation failed");
            let right_idx = right.expect("Operation failed");
            let left_dist = euclidean_distance(query, &self.nodes[left_idx].center);
            let right_dist = euclidean_distance(query, &self.nodes[right_idx].center);
            let (near, far) = if left_dist < right_dist {
                (left_idx, right_idx)
            } else {
                (right_idx, left_idx)
            };
            self.search_nearest(near, query, best_dist, best_idx);
            self.search_nearest(far, query, best_dist, best_idx);
        }
    }
}
/// Recursive k-nearest search below `node_idx`.
///
/// `heap` is a max-heap of at most `k` `(distance, row_index)` candidates.
/// While it holds fewer than `k` entries nothing is pruned; afterwards a ball
/// is skipped when its centre is farther from `query` than
/// `radius + current k-th distance`.
#[allow(clippy::type_complexity)]
fn search_k_nearest(
    &self,
    node_idx: usize,
    query: &[F],
    k: usize,
    heap: &mut std::collections::BinaryHeap<(OrderedFloat<F>, usize)>,
) {
    let node = &self.nodes[node_idx];
    let center_dist = euclidean_distance(query, &node.center);

    // Distance of the current k-th best candidate, or infinity while the
    // heap is not yet full.
    let kth_dist = match heap.peek() {
        Some(&(worst, _)) if heap.len() >= k => worst.into_inner(),
        _ => <F as scirs2_core::numeric::Float>::infinity(),
    };
    if center_dist > node.radius + kth_dist {
        return;
    }

    match (node.left, node.right) {
        (None, None) => {
            for &candidate in &node.indices {
                let dist = euclidean_distance(query, &self.points.row(candidate).to_vec());
                heap.push((OrderedFloat(dist), candidate));
                // Evict the farthest candidate once the heap overflows.
                if heap.len() > k {
                    heap.pop();
                }
            }
        }
        (left, right) => {
            let left_idx = left.expect("Operation failed");
            let right_idx = right.expect("Operation failed");
            let left_dist = euclidean_distance(query, &self.nodes[left_idx].center);
            let right_dist = euclidean_distance(query, &self.nodes[right_idx].center);
            let (near, far) = if left_dist < right_dist {
                (left_idx, right_idx)
            } else {
                (right_idx, left_idx)
            };
            self.search_k_nearest(near, query, k, heap);
            self.search_k_nearest(far, query, k, heap);
        }
    }
}
/// Recursive fixed-radius search below `node_idx`.
///
/// Appends `(row_index, distance)` to `results` for every point within
/// `radius` of `query`. A ball is skipped when its centre is farther from
/// `query` than `node.radius + radius`.
fn search_radius(
    &self,
    node_idx: usize,
    query: &[F],
    radius: F,
    results: &mut Vec<(usize, F)>,
) {
    let node = &self.nodes[node_idx];
    if euclidean_distance(query, &node.center) > node.radius + radius {
        return;
    }

    if node.left.is_none() && node.right.is_none() {
        // Leaf: keep the points that fall inside the query ball.
        results.extend(node.indices.iter().filter_map(|&candidate| {
            let dist = euclidean_distance(query, &self.points.row(candidate).to_vec());
            if dist <= radius {
                Some((candidate, dist))
            } else {
                None
            }
        }));
        return;
    }

    // Internal node: descend into whichever children exist, left first.
    for child in node.left.into_iter().chain(node.right) {
        self.search_radius(child, query, radius, results);
    }
}
/// Brute-force nearest neighbour: scans every stored point.
///
/// Returns `(row_index, distance)`; on ties the lowest row index wins because
/// only a strictly smaller distance replaces the current best.
fn linear_nearest_neighbor(&self, query: &[F]) -> InterpolateResult<(usize, F)> {
    let unbounded = <F as scirs2_core::numeric::Float>::infinity();
    let nearest = (0..self.points.shape()[0]).fold((0usize, unbounded), |best, row| {
        let dist = euclidean_distance(query, &self.points.row(row).to_vec());
        if dist < best.1 {
            (row, dist)
        } else {
            best
        }
    });
    Ok(nearest)
}
/// Brute-force k-nearest search: measures every stored point, sorts by
/// distance and keeps the first `k` (clamped to the point count).
fn linear_k_nearest_neighbors(
    &self,
    query: &[F],
    k: usize,
) -> InterpolateResult<Vec<(usize, F)>> {
    let total = self.points.shape()[0];
    let keep = k.min(total);

    let mut ranked: Vec<(usize, F)> = Vec::with_capacity(total);
    for row in 0..total {
        let dist = euclidean_distance(query, &self.points.row(row).to_vec());
        ranked.push((row, dist));
    }
    ranked.sort_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap_or(Ordering::Equal));
    ranked.truncate(keep);
    Ok(ranked)
}
/// Brute-force radius search: returns every point within `radius` of `query`
/// as `(row_index, distance)`, sorted by ascending distance.
fn linear_points_within_radius(
    &self,
    query: &[F],
    radius: F,
) -> InterpolateResult<Vec<(usize, F)>> {
    let mut inside: Vec<(usize, F)> = Vec::new();
    for row in 0..self.points.shape()[0] {
        let dist = euclidean_distance(query, &self.points.row(row).to_vec());
        if dist <= radius {
            inside.push((row, dist));
        }
    }
    inside.sort_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap_or(Ordering::Equal));
    Ok(inside)
}
/// Number of points stored in the tree (rows of the point matrix).
pub fn len(&self) -> usize {
    self.points.nrows()
}
/// Returns `true` when the tree stores no points.
pub fn is_empty(&self) -> bool {
    matches!(self.len(), 0)
}
/// Number of coordinates per point, as recorded when the tree was built.
pub fn dim(&self) -> usize {
self.dim
}
/// Borrows the point matrix the tree was built from (one point per row).
pub fn points(&self) -> &Array2<F> {
&self.points
}
/// Alias for [`BallTree::points_within_radius`]: forwards `query` and
/// `radius` unchanged and returns its result.
pub fn radius_neighbors(&self, query: &[F], radius: F) -> InterpolateResult<Vec<(usize, F)>> {
self.points_within_radius(query, radius)
}
/// Radius query taking an `ArrayView1` instead of a slice.
///
/// Contiguous views are used in place. Strided (non-contiguous) views are
/// copied into a temporary buffer rather than rejected, so any 1-D view is
/// accepted.
///
/// # Errors
///
/// Same as [`BallTree::points_within_radius`].
pub fn radius_neighbors_view(
    &self,
    query: &scirs2_core::ndarray::ArrayView1<F>,
    radius: F,
) -> InterpolateResult<Vec<(usize, F)>> {
    match query.as_slice() {
        Some(query_slice) => self.points_within_radius(query_slice, radius),
        // `as_slice` is `None` for strided views; a short copy lets the
        // query proceed.
        None => self.points_within_radius(&query.to_vec(), radius),
    }
}
/// k-nearest search with an optional distance cap.
///
/// Returns up to `k` `(row_index, distance)` pairs, nearest first. Points
/// farther than `max_distance` (when given) are never reported, so fewer than
/// `k` results may come back. `k` is clamped to the number of stored points.
///
/// # Errors
///
/// * [`InterpolateError::DimensionMismatch`] if `query` has the wrong length.
/// * [`InterpolateError::InvalidState`] if the tree has no root.
pub fn k_nearest_neighbors_optimized(
    &self,
    query: &[F],
    k: usize,
    max_distance: Option<F>,
) -> InterpolateResult<Vec<(usize, F)>> {
    if query.len() != self.dim {
        return Err(InterpolateError::DimensionMismatch(format!(
            "Query dimension {} doesn't match Ball Tree dimension {}",
            query.len(),
            self.dim
        )));
    }
    let root = self
        .root
        .ok_or_else(|| InterpolateError::InvalidState("Ball Tree is empty".to_string()))?;

    let n_points = self.points.shape()[0];
    let k = k.min(n_points);
    if k == 0 {
        return Ok(Vec::new());
    }
    if n_points <= self.leafsize {
        return self.linear_k_nearest_neighbors_optimized(query, k, max_distance);
    }

    let mut heap = std::collections::BinaryHeap::with_capacity(k + 1);
    // Shrinks during the search once `k` candidates have been found.
    let mut search_radius = match max_distance {
        Some(limit) => limit,
        None => <F as scirs2_core::numeric::Float>::infinity(),
    };
    self.search_k_nearest_optimized(root, query, k, &mut heap, &mut search_radius);

    let mut results: Vec<(usize, F)> = heap
        .into_iter()
        .map(|(dist, idx)| (idx, dist.into_inner()))
        .collect();
    results.sort_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap_or(Ordering::Equal));
    Ok(results)
}
/// Brute-force counterpart of the capped k-nearest search: keeps the points
/// within `max_distance` (unbounded when `None`), sorts them by distance and
/// returns the first `k`.
fn linear_k_nearest_neighbors_optimized(
    &self,
    query: &[F],
    k: usize,
    max_distance: Option<F>,
) -> InterpolateResult<Vec<(usize, F)>> {
    let total = self.points.shape()[0];
    let keep = k.min(total);
    let limit = match max_distance {
        Some(limit) => limit,
        None => <F as scirs2_core::numeric::Float>::infinity(),
    };

    let mut ranked: Vec<(usize, F)> = (0..total)
        .map(|row| (row, euclidean_distance(query, &self.points.row(row).to_vec())))
        .filter(|&(_, dist)| dist <= limit)
        .collect();
    ranked.sort_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap_or(Ordering::Equal));
    ranked.truncate(keep);
    Ok(ranked)
}
/// Recursive worker for the capped k-nearest search.
///
/// `heap` is a max-heap of at most `k` `(distance, row_index)` candidates and
/// `search_radius` is the current acceptance threshold: it starts at the
/// caller's cap and is tightened to the k-th best distance once the heap is
/// full. A ball is skipped when the smallest distance any of its points could
/// have to `query` exceeds the current bound. The child with the smaller lower
/// bound is searched first and the other is re-checked afterwards.
#[allow(clippy::type_complexity)]
fn search_k_nearest_optimized(
    &self,
    node_idx: usize,
    query: &[F],
    k: usize,
    heap: &mut std::collections::BinaryHeap<(OrderedFloat<F>, usize)>,
    search_radius: &mut F,
) {
    // Smallest distance a point inside a ball can have to the query.
    let ball_lower_bound = |center_dist: F, ball_radius: F| -> F {
        if center_dist > ball_radius {
            center_dist - ball_radius
        } else {
            F::zero()
        }
    };
    // Current pruning threshold: the k-th best distance once the heap is
    // full, otherwise the supplied fallback.
    let pruning_bound = |candidates: &std::collections::BinaryHeap<(OrderedFloat<F>, usize)>,
                         fallback: F|
     -> F {
        if candidates.len() < k {
            fallback
        } else {
            match candidates.peek() {
                Some(&(worst, _)) => worst.into_inner(),
                None => fallback,
            }
        }
    };

    let node = &self.nodes[node_idx];
    let min_possible_dist =
        ball_lower_bound(euclidean_distance(query, &node.center), node.radius);
    if min_possible_dist > pruning_bound(&*heap, *search_radius) {
        return;
    }

    if node.left.is_none() && node.right.is_none() {
        for &candidate in &node.indices {
            let dist = euclidean_distance(query, &self.points.row(candidate).to_vec());
            if dist <= *search_radius {
                heap.push((OrderedFloat(dist), candidate));
                if heap.len() > k {
                    heap.pop();
                }
                // With a full heap, nothing farther than its worst entry
                // can be accepted any more.
                if heap.len() == k {
                    if let Some(&(worst, _)) = heap.peek() {
                        *search_radius = worst.into_inner();
                    }
                }
            }
        }
        return;
    }

    let left_idx = node.left.expect("Operation failed");
    let right_idx = node.right.expect("Operation failed");
    let left_bound = {
        let child = &self.nodes[left_idx];
        ball_lower_bound(euclidean_distance(query, &child.center), child.radius)
    };
    let right_bound = {
        let child = &self.nodes[right_idx];
        ball_lower_bound(euclidean_distance(query, &child.center), child.radius)
    };

    let (first_idx, second_idx, second_bound) = if left_bound < right_bound {
        (left_idx, right_idx, right_bound)
    } else {
        (right_idx, left_idx, left_bound)
    };

    self.search_k_nearest_optimized(first_idx, query, k, heap, search_radius);
    // The first branch may have tightened the bound enough to skip the second.
    if second_bound <= pruning_bound(&*heap, *search_radius) {
        self.search_k_nearest_optimized(second_idx, query, k, heap, search_radius);
    }
}
pub fn approximate_k_nearest_neighbors(
&self,
query: &[F],
k: usize,
max_checks: usize,
) -> InterpolateResult<Vec<(usize, F)>> {
if query.len() != self.dim {
return Err(InterpolateError::DimensionMismatch(format!(
"Query dimension {} doesn't match Ball Tree dimension {}",
query.len(),
self.dim
)));
}
if self.root.is_none() {
return Err(InterpolateError::InvalidState(
"Ball Tree is empty".to_string(),
));
}
let k = k.min(self.points.shape()[0]);
if k == 0 {
return Ok(Vec::new());
}
if self.points.shape()[0] <= self.leafsize || max_checks >= self.points.shape()[0] {
return self.k_nearest_neighbors(query, k);
}
use std::collections::{BinaryHeap, VecDeque};
let mut heap = BinaryHeap::with_capacity(k + 1);
let mut checks_performed = 0;
let mut nodes_to_visit = VecDeque::new();
nodes_to_visit.push_back((self.root.expect("Operation failed"), F::zero()));
while let Some((node_idx, _min_dist)) = nodes_to_visit.pop_front() {
if checks_performed >= max_checks {
break;
}
let node = &self.nodes[node_idx];
let _center_dist = euclidean_distance(query, &node.center);
if node.left.is_none() && node.right.is_none() {
for &idx in &node.indices {
if checks_performed >= max_checks {
break;
}
let point = self.points.row(idx);
let dist = euclidean_distance(query, &point.to_vec());
checks_performed += 1;
heap.push((OrderedFloat(dist), idx));
if heap.len() > k {
heap.pop();
}
}
} else {
if let Some(left_idx) = node.left {
let left_node = &self.nodes[left_idx];
let left_center_dist = euclidean_distance(query, &left_node.center);
let left_min_dist = if left_center_dist > left_node.radius {
left_center_dist - left_node.radius
} else {
F::zero()
};
nodes_to_visit.push_back((left_idx, left_min_dist));
}
if let Some(right_idx) = node.right {
let right_node = &self.nodes[right_idx];
let right_center_dist = euclidean_distance(query, &right_node.center);
let right_min_dist = if right_center_dist > right_node.radius {
right_center_dist - right_node.radius
} else {
F::zero()
};
nodes_to_visit.push_back((right_idx, right_min_dist));
}
nodes_to_visit
.make_contiguous()
.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
}
}
let mut results: Vec<(usize, F)> = heap
.into_iter()
.map(|(dist, idx)| (idx, dist.into_inner()))
.collect();
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
Ok(results)
}
}
/// Arithmetic mean of the rows of `points` selected by `indices`.
///
/// Returns an all-zero vector (one entry per column) when `indices` is empty.
///
/// # Panics
///
/// Panics if the number of selected points cannot be converted to `F`.
#[allow(dead_code)]
fn compute_centroid<F: Float + FromPrimitive>(points: &Array2<F>, indices: &[usize]) -> Vec<F> {
    let n_dims = points.shape()[1];
    let mut centroid = vec![F::zero(); n_dims];
    if indices.is_empty() {
        return centroid;
    }

    // Accumulate coordinate sums, then divide by the point count.
    for &row in indices {
        let point = points.row(row);
        for (sum, &coord) in centroid.iter_mut().zip(point.iter()) {
            *sum = *sum + coord;
        }
    }
    let count = F::from_usize(indices.len()).expect("Operation failed");
    centroid.into_iter().map(|sum| sum / count).collect()
}
/// Largest Euclidean distance from `center` to any of the selected rows of
/// `points`; zero when `indices` is empty.
#[allow(dead_code)]
fn compute_radius<F: Float>(points: &Array2<F>, indices: &[usize], center: &[F]) -> F {
    indices.iter().fold(F::zero(), |farthest, &row| {
        let dist = euclidean_distance(&points.row(row).to_vec(), center);
        if dist > farthest {
            dist
        } else {
            farthest
        }
    })
}
/// Finds the coordinate axis along which the selected points vary the most.
///
/// Returns `(dimension, spread)` where `spread` is `max - min` along that
/// dimension. Ties go to the lowest dimension. With zero or one selected
/// point the result is `(0, 0)`.
#[allow(dead_code)]
fn find_max_spread_dimension<F: Float>(points: &Array2<F>, indices: &[usize]) -> (usize, F) {
    if indices.len() <= 1 {
        return (0, F::zero());
    }

    let n_dims = points.shape()[1];
    let mut widest = (0usize, F::neg_infinity());
    for d in 0..n_dims {
        // Running (min, max) of the selected points along dimension `d`.
        let (lo, hi) = indices.iter().fold(
            (F::infinity(), F::neg_infinity()),
            |(lo, hi), &row| {
                let value = points[[row, d]];
                (
                    if value < lo { value } else { lo },
                    if value > hi { value } else { hi },
                )
            },
        );
        let spread = hi - lo;
        if spread > widest.1 {
            widest = (d, spread);
        }
    }
    widest
}
/// Picks the two selected points with the smallest and largest coordinate
/// along dimension `dim`.
///
/// Returns `(min_row, max_row)`. Ties keep the earliest row in `indices`;
/// with a single selected point both entries are that point.
///
/// # Panics
///
/// Panics if `indices` is empty.
#[allow(dead_code)]
fn find_distant_points<F: Float>(
    points: &Array2<F>,
    indices: &[usize],
    dim: usize,
) -> (usize, usize) {
    let first = indices[0];
    if indices.len() <= 1 {
        return (first, first);
    }

    // Track (row, value) for both extremes, starting from the first point.
    let start = (first, points[[first, dim]]);
    let (lowest, highest) = indices[1..]
        .iter()
        .fold((start, start), |(mut lo, mut hi), &row| {
            let value = points[[row, dim]];
            if value < lo.1 {
                lo = (row, value);
            }
            if value > hi.1 {
                hi = (row, value);
            }
            (lo, hi)
        });
    (lowest.0, highest.0)
}
/// Splits `indices` into two groups around the seed points `seed1` and `seed2`.
///
/// Each seed opens its own group; every other point joins the group of the
/// nearer seed (ties go to `seed1`). With two distinct seeds both groups are
/// non-empty and strictly smaller than the input.
///
/// When both seeds are the same point, distances cannot separate anything, so
/// `indices` is cut positionally into two halves. No index is duplicated and,
/// for two or more points, both halves shrink.
#[allow(dead_code)]
fn partition_by_seeds<F: Float>(
    points: &Array2<F>,
    indices: &[usize],
    seed1: usize,
    seed2: usize,
) -> (Vec<usize>, Vec<usize>) {
    if seed1 == seed2 {
        // Degenerate seeds: the first half takes the extra element when the
        // count is odd.
        let mid = (indices.len() + 1) / 2;
        return (indices[..mid].to_vec(), indices[mid..].to_vec());
    }

    let seed1_point = points.row(seed1).to_vec();
    let seed2_point = points.row(seed2).to_vec();
    let mut left_indices = vec![seed1];
    let mut right_indices = vec![seed2];

    for &idx in indices {
        // The seeds are already placed.
        if idx == seed1 || idx == seed2 {
            continue;
        }
        let point = points.row(idx).to_vec();
        let dist1 = euclidean_distance(&point, &seed1_point);
        let dist2 = euclidean_distance(&point, &seed2_point);
        if dist1 <= dist2 {
            left_indices.push(idx);
        } else {
            right_indices.push(idx);
        }
    }
    (left_indices, right_indices)
}
/// Euclidean (L2) distance between two coordinate slices.
///
/// Both slices are expected to have the same length; this is checked in
/// debug builds only.
#[allow(dead_code)]
fn euclidean_distance<F: Float>(a: &[F], b: &[F]) -> F {
    debug_assert_eq!(a.len(), b.len());
    let squared_norm = (0..a.len()).fold(F::zero(), |acc, i| {
        let delta = a[i] - b[i];
        acc + delta * delta
    });
    squared_norm.sqrt()
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    /// Five 3-D points: the origin, the three unit axis points and the cube centre.
    fn sample_points() -> scirs2_core::ndarray::Array2<f64> {
        arr2(&[
            [0.0, 0.0, 0.0],
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
            [0.5, 0.5, 0.5],
        ])
    }

    /// Tree with leaf size 2: five points no longer fit in one leaf, so
    /// queries traverse real tree nodes instead of the linear-scan fallback.
    fn split_tree() -> BallTree<f64> {
        BallTree::with_leafsize(sample_points(), 2).expect("Operation failed")
    }

    #[test]
    fn test_balltree_creation() {
        let balltree = BallTree::new(sample_points()).expect("Operation failed");
        assert_eq!(balltree.len(), 5);
        assert_eq!(balltree.dim(), 3);
        assert!(!balltree.is_empty());
    }

    #[test]
    fn test_nearest_neighbor() {
        let balltree = BallTree::new(sample_points()).expect("Operation failed");
        // Every stored point is its own nearest neighbour.
        for i in 0..5 {
            let point = balltree.points().row(i).to_vec();
            let (idx, dist) = balltree.nearest_neighbor(&point).expect("Operation failed");
            assert_eq!(idx, i);
            assert!(dist < 1e-10);
        }
        let query = vec![0.6, 0.6, 0.6];
        let (idx, _) = balltree.nearest_neighbor(&query).expect("Operation failed");
        assert_eq!(idx, 4);
        let query = vec![0.9, 0.1, 0.1];
        let (idx, _) = balltree.nearest_neighbor(&query).expect("Operation failed");
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let balltree = BallTree::new(sample_points()).expect("Operation failed");
        let query = vec![0.6, 0.6, 0.6];
        let neighbors = balltree
            .k_nearest_neighbors(&query, 3)
            .expect("Operation failed");
        assert_eq!(neighbors.len(), 3);
        assert_eq!(neighbors[0].0, 4);
    }

    #[test]
    fn test_points_within_radius() {
        let balltree = BallTree::new(sample_points()).expect("Operation failed");
        let query = vec![0.0, 0.0, 0.0];
        let radius = 0.7;
        let results = balltree
            .points_within_radius(&query, radius)
            .expect("Operation failed");
        assert!(!results.is_empty());
        assert_eq!(results[0].0, 0);
    }

    #[test]
    fn test_tree_path_nearest_neighbor() {
        let balltree = split_tree();
        for i in 0..5 {
            let point = balltree.points().row(i).to_vec();
            let (idx, dist) = balltree.nearest_neighbor(&point).expect("Operation failed");
            assert_eq!(idx, i);
            assert!(dist < 1e-10);
        }
        let (idx, _) = balltree
            .nearest_neighbor(&[0.6, 0.6, 0.6])
            .expect("Operation failed");
        assert_eq!(idx, 4);
        let (idx, _) = balltree
            .nearest_neighbor(&[0.9, 0.1, 0.1])
            .expect("Operation failed");
        assert_eq!(idx, 1);
    }

    #[test]
    fn test_tree_path_k_nearest_neighbors() {
        let balltree = split_tree();
        let neighbors = balltree
            .k_nearest_neighbors(&[0.6, 0.6, 0.6], 3)
            .expect("Operation failed");
        assert_eq!(neighbors.len(), 3);
        assert_eq!(neighbors[0].0, 4);
        // Results come back nearest first.
        assert!(neighbors.windows(2).all(|pair| pair[0].1 <= pair[1].1));

        // Asking for more neighbours than points returns each point exactly once.
        let all = balltree
            .k_nearest_neighbors(&[0.6, 0.6, 0.6], 10)
            .expect("Operation failed");
        let mut indices: Vec<usize> = all.iter().map(|&(idx, _)| idx).collect();
        indices.sort_unstable();
        assert_eq!(indices, vec![0, 1, 2, 3, 4]);
    }

    #[test]
    fn test_tree_path_points_within_radius() {
        let balltree = split_tree();
        // The cube centre is ~0.866 from the origin; the axis points are at 1.0.
        let results = balltree
            .points_within_radius(&[0.0, 0.0, 0.0], 0.9)
            .expect("Operation failed");
        let indices: Vec<usize> = results.iter().map(|&(idx, _)| idx).collect();
        assert_eq!(indices, vec![0, 4]);
    }

    #[test]
    fn test_k_nearest_neighbors_optimized_respects_max_distance() {
        // Only the cube centre lies within 0.5 of the query, on both the
        // linear-scan path and the tree path.
        let linear = BallTree::new(sample_points()).expect("Operation failed");
        for balltree in [linear, split_tree()] {
            let neighbors = balltree
                .k_nearest_neighbors_optimized(&[0.6, 0.6, 0.6], 3, Some(0.5))
                .expect("Operation failed");
            assert_eq!(neighbors.len(), 1);
            assert_eq!(neighbors[0].0, 4);
        }
    }

    #[test]
    fn test_approximate_k_nearest_neighbors() {
        let balltree = split_tree();
        // A budget covering every point falls back to the exact search.
        let exact = balltree
            .approximate_k_nearest_neighbors(&[0.6, 0.6, 0.6], 2, 100)
            .expect("Operation failed");
        assert_eq!(exact.len(), 2);
        assert_eq!(exact[0].0, 4);

        // A tight budget still reaches the most promising leaf first.
        let approx = balltree
            .approximate_k_nearest_neighbors(&[0.6, 0.6, 0.6], 1, 2)
            .expect("Operation failed");
        assert_eq!(approx.len(), 1);
        assert_eq!(approx[0].0, 4);
    }

    #[test]
    fn test_invalid_inputs_are_rejected() {
        let empty = scirs2_core::ndarray::Array2::<f64>::zeros((0, 3));
        assert!(BallTree::new(empty).is_err());

        let balltree = split_tree();
        assert!(balltree.nearest_neighbor(&[0.0, 0.0]).is_err());
        assert!(balltree.k_nearest_neighbors(&[0.0, 0.0], 1).is_err());
        assert!(balltree.points_within_radius(&[0.0, 0.0, 0.0], 0.0).is_err());
        assert!(balltree.points_within_radius(&[0.0, 0.0, 0.0], -1.0).is_err());
    }
}