use ordered_float::OrderedFloat;
use scirs2_core::ndarray::{Array2, ArrayBase, ArrayView1, Data, Ix2};
use scirs2_core::numeric::{Float, FromPrimitive};
use std::cmp::Ordering;
use std::fmt::Debug;
use std::marker::PhantomData;
use crate::error::{InterpolateError, InterpolateResult};
/// One node of the KD-tree.
///
/// Nodes are stored in `KdTree::nodes` and refer to their children by
/// position in that vector rather than by pointer.
#[derive(Debug, Clone)]
struct KdNode<F>
where
    F: Float + ordered_float::FloatCore,
{
    /// Row of `KdTree::points` holding the point this node represents.
    idx: usize,
    /// Coordinate axis this node splits on (not consulted for childless nodes).
    dim: usize,
    /// Split coordinate, i.e. `points[[idx, dim]]` for splitting nodes.
    value: F,
    /// Position in `KdTree::nodes` of the subtree on the `<= value` side.
    left: Option<usize>,
    /// Position in `KdTree::nodes` of the subtree on the `>= value` side.
    right: Option<usize>,
}
/// KD-tree over the rows of a 2-D point array.
///
/// Supports nearest-neighbour, k-nearest-neighbour and fixed-radius queries
/// using Euclidean distance. Query results identify points by row index.
#[derive(Debug, Clone)]
pub struct KdTree<
    F: Float + FromPrimitive + Debug + std::cmp::PartialOrd + ordered_float::FloatCore,
> {
    /// Owned copy of the indexed points, one point per row.
    points: Array2<F>,
    /// Node arena; children are referenced by index into this vector.
    nodes: Vec<KdNode<F>>,
    /// Index of the root node in `nodes`, or `None` when nothing was built.
    root: Option<usize>,
    /// Number of coordinates per point (the column count of `points`).
    dim: usize,
    /// Point-count threshold: data sets this small are queried by linear scan.
    leaf_size: usize,
    /// Type marker for `F`.
    _phantom: PhantomData<F>,
}
impl<F> KdTree<F>
where
F: Float + FromPrimitive + Debug + std::cmp::PartialOrd + ordered_float::FloatCore,
{
/// Builds a KD-tree over the rows of `points` with the default leaf size of 10.
///
/// # Errors
/// Returns [`InterpolateError::InvalidValue`] if `points` contains no elements.
pub fn new<S>(points: ArrayBase<S, Ix2>) -> InterpolateResult<Self>
where
    S: Data<Elem = F>,
{
    const DEFAULT_LEAF_SIZE: usize = 10;
    Self::with_leaf_size(points, DEFAULT_LEAF_SIZE)
}
/// Builds a KD-tree over the rows of `points` (one point per row).
///
/// `leaf_size` is the brute-force threshold. If the data set has at most
/// `leaf_size` points, no tree is built and every query is answered by a
/// linear scan. Otherwise a balanced tree is built by recursive median
/// splits. A `leaf_size` of `0` is accepted and means "always use the tree".
///
/// # Errors
/// Returns [`InterpolateError::InvalidValue`] if `points` has no elements
/// (zero rows or zero columns).
pub fn with_leaf_size<S>(points: ArrayBase<S, Ix2>, leaf_size: usize) -> InterpolateResult<Self>
where
    S: Data<Elem = F>,
{
    let points = points.to_owned();
    if points.is_empty() {
        return Err(InterpolateError::InvalidValue(
            "Points array cannot be empty".to_string(),
        ));
    }
    let n_points = points.shape()[0];
    let dim = points.shape()[1];
    let brute_force = n_points <= leaf_size;

    // Each point occupies at most one node, so `n_points` bounds the node
    // count. This avoids dividing by `leaf_size`, which panicked for 0.
    // Brute-force trees only ever hold the single placeholder node.
    let capacity = if brute_force { 1 } else { n_points };

    let mut tree = Self {
        points,
        nodes: Vec::with_capacity(capacity),
        root: None,
        dim,
        leaf_size,
        _phantom: PhantomData,
    };

    if brute_force {
        // Queries never descend the tree in this mode. A placeholder root is
        // still stored so that `root.is_some()` marks the tree as non-empty.
        tree.nodes.push(KdNode {
            idx: 0,
            dim: 0,
            value: F::zero(),
            left: None,
            right: None,
        });
        tree.root = Some(0);
    } else {
        let mut indices: Vec<usize> = (0..n_points).collect();
        tree.root = tree.build_subtree(&mut indices, 0);
    }
    Ok(tree)
}
fn build_subtree(&mut self, indices: &mut [usize], depth: usize) -> Option<usize> {
let n_points = indices.len();
if n_points == 0 {
return None;
}
if n_points <= self.leaf_size {
let node_idx = self.nodes.len();
self.nodes.push(KdNode {
idx: indices[0], dim: 0, value: F::zero(), left: None,
right: None,
});
return Some(node_idx);
}
let dim = depth % self.dim;
self.find_median(indices, dim);
let median_idx = n_points / 2;
let split_point_idx = indices[median_idx];
let split_value = self.points[[split_point_idx, dim]];
let node_idx = self.nodes.len();
self.nodes.push(KdNode {
idx: split_point_idx,
dim,
value: split_value,
left: None,
right: None,
});
let (left_indices, right_indices) = indices.split_at_mut(median_idx);
let right_indices = &mut right_indices[1..];
let left_child = self.build_subtree(left_indices, depth + 1);
let right_child = self.build_subtree(right_indices, depth + 1);
self.nodes[node_idx].left = left_child;
self.nodes[node_idx].right = right_child;
Some(node_idx)
}
/// Reorders `indices` so the point with the median coordinate along axis
/// `dim` lands at position `indices.len() / 2`.
///
/// Points with smaller-or-equal coordinates end up before it and points with
/// greater-or-equal coordinates after it. Slices of length 0 or 1 are left
/// untouched.
fn find_median(&self, indices: &mut [usize], dim: usize) {
    let count = indices.len();
    if count > 1 {
        let points = &self.points;
        let middle = count / 2;
        quickselect_by_key(indices, middle, |&row| points[[row, dim]]);
    }
}
/// Finds the indexed point closest to `query` by Euclidean distance.
///
/// Returns `(row index, distance)`. Data sets with at most `leaf_size`
/// points are answered by a linear scan; larger ones by a tree search.
///
/// # Errors
/// * [`InterpolateError::DimensionMismatch`] if `query.len()` differs from
///   the tree dimension.
/// * [`InterpolateError::InvalidState`] if the tree has no root node.
pub fn nearest_neighbor(&self, query: &[F]) -> InterpolateResult<(usize, F)> {
    if query.len() != self.dim {
        let (got, want) = (query.len(), self.dim);
        return Err(InterpolateError::DimensionMismatch(format!(
            "Query dimension {got} doesn't match KD-tree dimension {want}"
        )));
    }
    let root = self
        .root
        .ok_or_else(|| InterpolateError::InvalidState("KD-tree is empty".to_string()))?;
    if self.points.shape()[0] <= self.leaf_size {
        return self.linear_nearest_neighbor(query);
    }
    let mut best_idx = 0;
    let mut best_dist = <F as scirs2_core::numeric::Float>::infinity();
    self.search_nearest(root, query, &mut best_dist, &mut best_idx);
    Ok((best_idx, best_dist))
}
pub fn k_nearest_neighbors(&self, query: &[F], k: usize) -> InterpolateResult<Vec<(usize, F)>> {
if query.len() != self.dim {
return Err(InterpolateError::DimensionMismatch(format!(
"Query dimension {} doesn't match KD-tree dimension {}",
query.len(),
self.dim
)));
}
if self.root.is_none() {
return Err(InterpolateError::InvalidState(
"KD-tree is empty".to_string(),
));
}
let k = k.min(self.points.shape()[0]);
if k == 0 {
return Ok(Vec::new());
}
if self.points.shape()[0] <= self.leaf_size {
return self.linear_k_nearest_neighbors(query, k);
}
use ordered_float::OrderedFloat;
use std::collections::BinaryHeap;
let mut heap: BinaryHeap<(OrderedFloat<F>, usize)> = BinaryHeap::with_capacity(k + 1);
self.search_k_nearest(self.root.expect("Operation failed"), query, k, &mut heap);
let mut results: Vec<(usize, F)> = heap
.into_iter()
.map(|(dist, idx)| (idx, dist.into_inner()))
.collect();
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
Ok(results)
}
/// Returns every indexed point within `radius` (inclusive) of `query`.
///
/// Results are `(row index, distance)` pairs sorted by increasing distance.
///
/// # Errors
/// * [`InterpolateError::DimensionMismatch`] if `query.len()` differs from
///   the tree dimension.
/// * [`InterpolateError::InvalidState`] if the tree has no root node.
/// * [`InterpolateError::InvalidValue`] if `radius` is not strictly positive,
///   including NaN.
pub fn points_within_radius(
    &self,
    query: &[F],
    radius: F,
) -> InterpolateResult<Vec<(usize, F)>> {
    if query.len() != self.dim {
        return Err(InterpolateError::DimensionMismatch(format!(
            "Query dimension {} doesn't match KD-tree dimension {}",
            query.len(),
            self.dim
        )));
    }
    let root = self
        .root
        .ok_or_else(|| InterpolateError::InvalidState("KD-tree is empty".to_string()))?;
    // Every comparison with NaN is false, so `radius <= 0` would accept a NaN
    // radius and silently return nothing. Require an explicit `Greater` instead.
    if radius.partial_cmp(&F::zero()) != Some(Ordering::Greater) {
        return Err(InterpolateError::InvalidValue(
            "Radius must be positive".to_string(),
        ));
    }
    if self.points.shape()[0] <= self.leaf_size {
        return self.linear_points_within_radius(query, radius);
    }
    let mut results = Vec::new();
    self.search_radius(root, query, radius, &mut results);
    results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
    Ok(results)
}
/// Depth-first nearest-neighbour search of the subtree rooted at `node_idx`.
///
/// `best_dist` and `best_idx` are updated in place whenever a strictly closer
/// point is found. The child on the query's side of the splitting plane is
/// visited first. The other child is visited only if the plane is strictly
/// closer than the best distance found so far.
fn search_nearest(
    &self,
    node_idx: usize,
    query: &[F],
    best_dist: &mut F,
    best_idx: &mut usize,
) {
    let node = &self.nodes[node_idx];
    let candidate_dist = self.distance(&self.points.row(node.idx).to_vec(), query);
    if candidate_dist < *best_dist {
        *best_dist = candidate_dist;
        *best_idx = node.idx;
    }

    let (near, far) = match (node.left, node.right) {
        (None, None) => return,
        (left, right) => {
            if query[node.dim] < node.value {
                (left, right)
            } else {
                (right, left)
            }
        }
    };

    if let Some(child) = near {
        self.search_nearest(child, query, best_dist, best_idx);
    }
    let plane_dist = scirs2_core::numeric::Float::abs(query[node.dim] - node.value);
    if plane_dist < *best_dist {
        if let Some(child) = far {
            self.search_nearest(child, query, best_dist, best_idx);
        }
    }
}
/// Depth-first k-nearest-neighbour search of the subtree rooted at `node_idx`.
///
/// `heap` is a max-heap of `(distance, row index)` holding the best `k`
/// candidates seen so far; its top is the current k-th (worst) distance.
///
/// The far child of a split is searched when the heap is not yet full, or
/// when the splitting plane is strictly closer than that k-th distance. The
/// bound is read *after* the near child has been explored. Reading it before
/// (while the heap may still be partly empty) yields a bound that is too
/// small and wrongly prunes subtrees containing true neighbours.
// The heap type is spelled out in full rather than through an alias.
#[allow(clippy::type_complexity)]
fn search_k_nearest(
    &self,
    node_idx: usize,
    query: &[F],
    k: usize,
    heap: &mut std::collections::BinaryHeap<(OrderedFloat<F>, usize)>,
) {
    let node = &self.nodes[node_idx];
    let point_idx = node.idx;
    let point = self.points.row(point_idx);
    let dist = self.distance(&point.to_vec(), query);
    heap.push((OrderedFloat(dist), point_idx));
    if heap.len() > k {
        // Evict the current worst candidate to keep exactly k.
        heap.pop();
    }
    if node.left.is_none() && node.right.is_none() {
        return;
    }

    let dim = node.dim;
    let query_val = query[dim];
    let node_val = node.value;
    let (near, far) = if query_val < node_val {
        (node.left, node.right)
    } else {
        (node.right, node.left)
    };
    if let Some(near_idx) = near {
        self.search_k_nearest(near_idx, query, k, heap);
    }

    let plane_dist = scirs2_core::numeric::Float::abs(query_val - node_val);
    let must_visit_far = heap.len() < k
        || match heap.peek() {
            Some(&(worst, _)) => plane_dist < worst.into_inner(),
            None => true,
        };
    if must_visit_far {
        if let Some(far_idx) = far {
            self.search_k_nearest(far_idx, query, k, heap);
        }
    }
}
/// Appends to `results` every point in the subtree rooted at `node_idx`
/// whose distance to `query` is `<= radius`, in traversal order (unsorted).
///
/// The far child of a split is skipped when the splitting plane lies
/// farther than `radius` from the query.
fn search_radius(
    &self,
    node_idx: usize,
    query: &[F],
    radius: F,
    results: &mut Vec<(usize, F)>,
) {
    let node = &self.nodes[node_idx];
    let dist = self.distance(&self.points.row(node.idx).to_vec(), query);
    if dist <= radius {
        results.push((node.idx, dist));
    }

    let is_leaf = node.left.is_none() && node.right.is_none();
    if is_leaf {
        return;
    }

    let (query_val, split) = (query[node.dim], node.value);
    let (near, far) = if query_val < split {
        (node.left, node.right)
    } else {
        (node.right, node.left)
    };
    if let Some(child) = near {
        self.search_radius(child, query, radius, results);
    }
    let plane_dist = scirs2_core::numeric::Float::abs(query_val - split);
    if let Some(child) = far.filter(|_| plane_dist <= radius) {
        self.search_radius(child, query, radius, results);
    }
}
/// Brute-force nearest neighbour.
///
/// Scans every point and keeps the first one with the strictly smallest
/// distance. If no distance compares below infinity, `(0, infinity)` is
/// returned.
fn linear_nearest_neighbor(&self, query: &[F]) -> InterpolateResult<(usize, F)> {
    let initial = (0usize, <F as scirs2_core::numeric::Float>::infinity());
    let best = (0..self.points.shape()[0]).fold(initial, |(best_idx, best_dist), row| {
        let dist = self.distance(&self.points.row(row).to_vec(), query);
        if dist < best_dist {
            (row, dist)
        } else {
            (best_idx, best_dist)
        }
    });
    Ok(best)
}
/// Brute-force k-nearest neighbours.
///
/// Computes the distance to every point, stable-sorts by distance
/// (incomparable distances are treated as equal), and keeps the first
/// `min(k, n_points)` entries.
fn linear_k_nearest_neighbors(
    &self,
    query: &[F],
    k: usize,
) -> InterpolateResult<Vec<(usize, F)>> {
    let total = self.points.shape()[0];
    let mut ranked: Vec<(usize, F)> = Vec::with_capacity(total);
    for row in 0..total {
        let dist = self.distance(&self.points.row(row).to_vec(), query);
        ranked.push((row, dist));
    }
    ranked.sort_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap_or(Ordering::Equal));
    ranked.truncate(k.min(total));
    Ok(ranked)
}
/// Brute-force radius query.
///
/// Returns every point whose distance to `query` is `<= radius`, stable-sorted
/// by increasing distance.
fn linear_points_within_radius(
    &self,
    query: &[F],
    radius: F,
) -> InterpolateResult<Vec<(usize, F)>> {
    let mut inside: Vec<(usize, F)> = Vec::new();
    for row in 0..self.points.shape()[0] {
        let dist = self.distance(&self.points.row(row).to_vec(), query);
        if dist <= radius {
            inside.push((row, dist));
        }
    }
    inside.sort_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap_or(Ordering::Equal));
    Ok(inside)
}
/// Euclidean distance between the first `self.dim` coordinates of `a` and `b`.
///
/// # Panics
/// Panics if either slice is shorter than `self.dim`.
fn distance(&self, a: &[F], b: &[F]) -> F {
    let (a, b) = (&a[..self.dim], &b[..self.dim]);
    let sum_sq = a.iter().zip(b).fold(F::zero(), |acc, (&x, &y)| {
        let diff = x - y;
        acc + diff * diff
    });
    sum_sq.sqrt()
}
/// Number of indexed points (rows of the point array).
pub fn len(&self) -> usize {
    let (rows, _cols) = self.points.dim();
    rows
}
/// `true` when the tree indexes no points.
pub fn is_empty(&self) -> bool {
    0 == self.len()
}
/// Number of coordinates per point.
pub fn dim(&self) -> usize {
    self.dim
}
/// Borrowed view of the indexed points, one point per row.
pub fn points(&self) -> &Array2<F> {
    &self.points
}
pub fn radius_neighbors(&self, query: &[F], radius: F) -> InterpolateResult<Vec<(usize, F)>> {
self.points_within_radius(query, radius)
}
pub fn radius_neighbors_view(
&self,
query: &ArrayView1<F>,
radius: F,
) -> InterpolateResult<Vec<(usize, F)>> {
let query_slice = query.as_slice().ok_or_else(|| {
InterpolateError::InvalidValue("Query must be contiguous".to_string())
})?;
self.points_within_radius(query_slice, radius)
}
pub fn k_nearest_neighbors_optimized(
&self,
query: &[F],
k: usize,
max_distance: Option<F>,
) -> InterpolateResult<Vec<(usize, F)>> {
if query.len() != self.dim {
return Err(InterpolateError::DimensionMismatch(format!(
"Query dimension {} doesn't match KD-tree dimension {}",
query.len(),
self.dim
)));
}
if self.root.is_none() {
return Err(InterpolateError::InvalidState(
"KD-tree is empty".to_string(),
));
}
let k = k.min(self.points.shape()[0]);
if k == 0 {
return Ok(Vec::new());
}
if self.points.shape()[0] <= self.leaf_size {
return self.linear_k_nearest_neighbors_optimized(query, k, max_distance);
}
use ordered_float::OrderedFloat;
use std::collections::BinaryHeap;
let mut heap: BinaryHeap<(OrderedFloat<F>, usize)> = BinaryHeap::with_capacity(k + 1);
let mut search_radius =
max_distance.unwrap_or(<F as scirs2_core::numeric::Float>::infinity());
self.search_k_nearest_optimized(
self.root.expect("Operation failed"),
query,
k,
&mut heap,
&mut search_radius,
);
let mut results: Vec<(usize, F)> = heap
.into_iter()
.map(|(dist, idx)| (idx, dist.into_inner()))
.collect();
results.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
Ok(results)
}
/// Brute-force k-nearest neighbours restricted to distance `<= max_distance`.
///
/// When `max_distance` is `None` there is no limit. The result is stable-sorted
/// by distance and truncated to `min(k, n_points)` entries.
fn linear_k_nearest_neighbors_optimized(
    &self,
    query: &[F],
    k: usize,
    max_distance: Option<F>,
) -> InterpolateResult<Vec<(usize, F)>> {
    let total = self.points.shape()[0];
    let limit = max_distance.unwrap_or_else(<F as scirs2_core::numeric::Float>::infinity);
    let mut hits: Vec<(usize, F)> = (0..total)
        .map(|row| (row, self.distance(&self.points.row(row).to_vec(), query)))
        .filter(|&(_, dist)| dist <= limit)
        .collect();
    hits.sort_by(|lhs, rhs| lhs.1.partial_cmp(&rhs.1).unwrap_or(Ordering::Equal));
    hits.truncate(k.min(total));
    Ok(hits)
}
/// Depth-first k-nearest-neighbour search with a shrinking distance cap.
///
/// Only points with distance `<= *search_radius` are offered to `heap`, a
/// max-heap of the best `k` `(distance, row index)` candidates. Once the heap
/// holds `k` entries, `*search_radius` is lowered to the heap's top, so it
/// always equals the current k-th distance or the caller's original cap.
///
/// The far child of a split is visited only if the splitting plane is within
/// `*search_radius`. The radius is read *after* the near child has been
/// explored, so pruning benefits from neighbours found there. Points in a
/// pruned subtree would have been rejected by the distance cap anyway, so
/// results are unaffected.
// The heap type is spelled out in full rather than through an alias.
#[allow(clippy::type_complexity)]
fn search_k_nearest_optimized(
    &self,
    node_idx: usize,
    query: &[F],
    k: usize,
    heap: &mut std::collections::BinaryHeap<(OrderedFloat<F>, usize)>,
    search_radius: &mut F,
) {
    let node = &self.nodes[node_idx];
    let point_idx = node.idx;
    let point = self.points.row(point_idx);
    let dist = self.distance(&point.to_vec(), query);
    if dist <= *search_radius {
        heap.push((OrderedFloat(dist), point_idx));
        if heap.len() > k {
            heap.pop();
        }
        if heap.len() == k {
            // Tighten the cap to the current k-th best distance.
            if let Some(&(worst, _)) = heap.peek() {
                *search_radius = worst.into_inner();
            }
        }
    }
    if node.left.is_none() && node.right.is_none() {
        return;
    }

    let dim = node.dim;
    let query_val = query[dim];
    let node_val = node.value;
    let (near, far) = if query_val < node_val {
        (node.left, node.right)
    } else {
        (node.right, node.left)
    };
    if let Some(near_idx) = near {
        self.search_k_nearest_optimized(near_idx, query, k, heap, search_radius);
    }
    let plane_dist = scirs2_core::numeric::Float::abs(query_val - node_val);
    if plane_dist <= *search_radius {
        if let Some(far_idx) = far {
            self.search_k_nearest_optimized(far_idx, query, k, heap, search_radius);
        }
    }
}
/// Returns the row indices of the `k` nearest points to `query`.
///
/// Indices are ordered by increasing distance, as computed by
/// [`Self::k_nearest_neighbors`].
///
/// Contiguous, standard-order views are used in place. Any other view (for
/// example a column of a row-major matrix, or a strided or reversed slice) is
/// copied into a temporary buffer in logical order instead of being rejected.
///
/// # Errors
/// Same as [`Self::k_nearest_neighbors`].
pub fn query_nearest(
    &self,
    query: &scirs2_core::ndarray::ArrayView1<F>,
    k: usize,
) -> InterpolateResult<scirs2_core::ndarray::Array1<usize>> {
    use scirs2_core::ndarray::Array1;
    let neighbors = match query.as_slice() {
        Some(query_slice) => self.k_nearest_neighbors(query_slice, k)?,
        None => {
            let owned = query.to_vec();
            self.k_nearest_neighbors(&owned, k)?
        }
    };
    let indices: Vec<usize> = neighbors.into_iter().map(|(idx, _)| idx).collect();
    Ok(Array1::from(indices))
}
}
/// Partially reorders `items` so the element with the `k`-th smallest key
/// (0-based) ends up at index `k`.
///
/// After the call, every element before index `k` has a key `<=` that key and
/// every element after it has a key `>=` that key.
///
/// The implementation is an iterative quickselect with a three-way
/// (less / equal / greater) partition around the middle element:
/// * Runs of equal keys, such as many points sharing a coordinate, are
///   settled in a single pass instead of degrading to O(n^2).
/// * The loop replaces recursion, so deep inputs cannot overflow the stack.
///
/// Keys incomparable with the pivot (e.g. NaN) are grouped with the pivot's
/// "equal" band; this never panics, but their final position is unspecified.
///
/// Does nothing when `items` has fewer than two elements or `k` is out of range.
fn quickselect_by_key<T, F, K>(items: &mut [T], k: usize, keyfn: F)
where
    F: Fn(&T) -> K,
    K: PartialOrd,
{
    if k >= items.len() {
        return;
    }
    // Active window is items[lo..hi]; it always contains index k.
    let mut lo = 0;
    let mut hi = items.len();
    while hi - lo > 1 {
        let pivot = keyfn(&items[lo + (hi - lo) / 2]);
        // Dutch-national-flag partition of items[lo..hi]:
        //   [lo, lt) < pivot, [lt, i) == pivot, [i, gt) unseen, [gt, hi) > pivot.
        // The pivot element always lands in the equal band, so each round
        // strictly shrinks the window.
        let mut lt = lo;
        let mut i = lo;
        let mut gt = hi;
        while i < gt {
            let key = keyfn(&items[i]);
            if key < pivot {
                items.swap(lt, i);
                lt += 1;
                i += 1;
            } else if key > pivot {
                gt -= 1;
                items.swap(i, gt);
            } else {
                i += 1;
            }
        }
        if k < lt {
            hi = lt;
        } else if k >= gt {
            lo = gt;
        } else {
            // k lies inside the band of keys equal to the pivot: done.
            return;
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::arr2;

    /// Tree over the four corners of the unit square followed by its centre (row 4).
    fn square_and_centre_tree() -> KdTree<f64> {
        let points = arr2(&[[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, 0.5]]);
        KdTree::new(points).expect("Operation failed")
    }

    #[test]
    fn test_kdtree_creation() {
        let kdtree = square_and_centre_tree();
        assert_eq!(kdtree.len(), 5);
        assert_eq!(kdtree.dim(), 2);
        assert!(!kdtree.is_empty());
    }

    #[test]
    fn test_nearest_neighbor() {
        let kdtree = square_and_centre_tree();
        // Every indexed point is its own nearest neighbour at distance ~0.
        for i in 0..5 {
            let point = kdtree.points().row(i).to_vec();
            let (idx, dist) = kdtree.nearest_neighbor(&point).expect("Operation failed");
            assert_eq!(idx, i);
            assert!(dist < 1e-10);
        }
        // Off-grid queries resolve to the closest data point.
        for (query, expected) in [(vec![0.6, 0.6], 4), (vec![0.9, 0.1], 1)] {
            let (idx, _) = kdtree.nearest_neighbor(&query).expect("Operation failed");
            assert_eq!(idx, expected);
        }
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let kdtree = square_and_centre_tree();
        let query = vec![0.6, 0.6];
        let neighbors = kdtree
            .k_nearest_neighbors(&query, 3)
            .expect("Operation failed");
        assert_eq!(neighbors.len(), 3);
        // The centre point is closest to (0.6, 0.6).
        assert_eq!(neighbors[0].0, 4);
    }

    #[test]
    fn test_points_within_radius() {
        let kdtree = square_and_centre_tree();
        let query = vec![0.0, 0.0];
        let radius = 0.7;
        let results = kdtree
            .points_within_radius(&query, radius)
            .expect("Operation failed");
        assert!(!results.is_empty());
        // Results are sorted, so the query point itself comes first.
        assert_eq!(results[0].0, 0);
        assert!(results[0].1 < 1e-10);
        println!("Points within radius:");
        for (idx, dist) in &results {
            println!("Point index: {idx}, distance: {dist}");
        }
    }
}