use std::cmp::Ordering;
use std::collections::BinaryHeap;
use crate::distance::euclidean_sq;
/// A static KD-tree over a fixed set of points, stored as a flat arena
/// of nodes (children referenced by index, not by pointer).
///
/// The tree stores only point *indices*; queries must be given the same
/// `points` slice that was passed to [`KdTree::build`].
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
#[non_exhaustive]
pub struct KdTree {
    // Node arena; the root is at index 0 (empty for a tree built from no points).
    nodes: Vec<KdNode>,
    // Dimensionality taken from the first point at build time (0 for an empty tree).
    n_dims: usize,
}
/// A node in the flat KD-tree arena.
#[derive(Clone, Debug)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
enum KdNode {
    /// Internal node: the lower half of its points (sorted along `dim`)
    /// lives under `left`, the upper half under `right`.
    Split {
        /// Dimension this node splits on.
        dim: usize,
        /// Split coordinate (the median point's value along `dim`).
        value: f64,
        /// Arena index of the left child.
        left: usize,
        /// Arena index of the right child.
        right: usize,
    },
    /// Terminal node holding exactly one point.
    Leaf {
        /// Index into the caller-supplied `points` slice.
        point_idx: usize,
    },
}
/// Max-heap entry pairing a squared distance with a point index.
///
/// The entire ordering family is derived from `f64::total_cmp`, which is a
/// genuine total order (it even orders NaN consistently). The previous
/// hand-rolled `partial_cmp(..).unwrap_or(Ordering::Equal)` made NaN compare
/// "equal" to every value, violating the `Ord` contract (transitivity /
/// antisymmetry) that `BinaryHeap` depends on, and the `Eq` impl used raw
/// `f64 ==`, which is not reflexive for NaN. For the finite, non-negative
/// squared distances produced by `euclidean_sq` the behavior is unchanged.
#[derive(Clone, Copy)]
struct HeapEntry {
    dist_sq: f64,
    idx: usize,
}

impl PartialEq for HeapEntry {
    fn eq(&self, other: &Self) -> bool {
        // Match Ord::cmp exactly so Eq/Ord stay consistent.
        self.dist_sq.total_cmp(&other.dist_sq) == Ordering::Equal
    }
}

impl Eq for HeapEntry {}

impl PartialOrd for HeapEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for HeapEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Entries with equal distance but different indices compare Equal;
        // BinaryHeap tolerates that (it never requires uniqueness).
        self.dist_sq.total_cmp(&other.dist_sq)
    }
}
impl KdTree {
    /// Builds a KD-tree over `points`, splitting on the widest-spread
    /// dimension at each level.
    ///
    /// Dimensionality is taken from the first point; all points are assumed
    /// to share that length (NOTE(review): not validated here — a shorter
    /// point would panic on indexing during build or queries; confirm at
    /// call sites). An empty slice yields an empty tree with `n_dims == 0`.
    pub fn build(points: &[Vec<f64>]) -> Self {
        if points.is_empty() {
            return Self {
                nodes: Vec::new(),
                n_dims: 0,
            };
        }
        let n_dims = points[0].len();
        // A binary tree with n leaves has at most 2n - 1 nodes.
        let mut nodes = Vec::with_capacity(2 * points.len());
        let indices: Vec<usize> = (0..points.len()).collect();
        Self::build_recursive(points, &indices, 0, n_dims, &mut nodes);
        Self { nodes, n_dims }
    }

    /// Recursively builds the subtree over `indices`, appending nodes to
    /// the arena and returning the arena index of the subtree's root.
    fn build_recursive(
        points: &[Vec<f64>],
        indices: &[usize],
        depth: usize,
        n_dims: usize,
        nodes: &mut Vec<KdNode>,
    ) -> usize {
        debug_assert!(!indices.is_empty());
        if indices.len() == 1 {
            let node_idx = nodes.len();
            nodes.push(KdNode::Leaf {
                point_idx: indices[0],
            });
            return node_idx;
        }
        let dim = Self::best_split_dim(points, indices, n_dims, depth);
        let mut sorted = indices.to_vec();
        // NaN coordinates compare as Equal, so the sort cannot panic.
        sorted.sort_by(|&a, &b| {
            points[a][dim]
                .partial_cmp(&points[b][dim])
                .unwrap_or(Ordering::Equal)
        });
        let median = sorted.len() / 2;
        let split_value = points[sorted[median]][dim];
        // Reserve the parent's slot with a placeholder so this node keeps a
        // lower arena index than its children; it is patched below once both
        // subtrees have been built.
        let this_idx = nodes.len();
        nodes.push(KdNode::Leaf { point_idx: 0 });
        // len >= 2 implies 1 <= median < len, so both halves are non-empty
        // and the recursion always terminates. (The previous version carried
        // unreachable empty-half fallbacks; the right-empty branch would have
        // emitted a duplicate leaf for a point already in the left subtree
        // had it ever run. Both are removed.)
        let left_idx = Self::build_recursive(points, &sorted[..median], depth + 1, n_dims, nodes);
        let right_idx = Self::build_recursive(points, &sorted[median..], depth + 1, n_dims, nodes);
        nodes[this_idx] = KdNode::Split {
            dim,
            value: split_value,
            left: left_idx,
            right: right_idx,
        };
        this_idx
    }

    /// Picks the split dimension with the largest coordinate spread among
    /// `indices`. The round-robin candidate `depth % n_dims` only survives
    /// if every spread is NaN (NaN fails the `> best_spread` test).
    // The index loop is kept because `d` names the candidate dimension and
    // indexes into every point, which an iterator chain would obscure.
    #[allow(clippy::needless_range_loop)]
    fn best_split_dim(
        points: &[Vec<f64>],
        indices: &[usize],
        n_dims: usize,
        depth: usize,
    ) -> usize {
        let mut best_dim = depth % n_dims;
        let mut best_spread = -1.0_f64;
        for d in 0..n_dims {
            let (min_v, max_v) =
                indices
                    .iter()
                    .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &idx| {
                        let v = points[idx][d];
                        (lo.min(v), hi.max(v))
                    });
            let spread = max_v - min_v;
            if spread > best_spread {
                best_spread = spread;
                best_dim = d;
            }
        }
        best_dim
    }

    /// Returns the `k` nearest neighbors of `query` as
    /// `(squared_distance, point_index)` pairs, sorted ascending by distance
    /// with ties broken by index.
    ///
    /// `points` must be the same slice the tree was built from. Returns
    /// fewer than `k` pairs when the tree holds fewer points, and an empty
    /// vector for an empty tree.
    ///
    /// # Panics
    /// Panics if `k == 0`, or if `query` is shorter than the tree's
    /// dimensionality.
    pub fn query_k_nearest(
        &self,
        query: &[f64],
        k: usize,
        points: &[Vec<f64>],
    ) -> Vec<(f64, usize)> {
        assert!(k > 0, "k must be at least 1");
        if self.nodes.is_empty() {
            return Vec::new();
        }
        let mut heap: BinaryHeap<HeapEntry> = BinaryHeap::with_capacity(k + 1);
        self.search(0, query, k, points, &mut heap);
        let mut result: Vec<(f64, usize)> = heap.into_iter().map(|e| (e.dist_sq, e.idx)).collect();
        result.sort_by(|a, b| {
            a.0.partial_cmp(&b.0)
                .unwrap_or(Ordering::Equal)
                .then(a.1.cmp(&b.1))
        });
        result
    }

    /// Branch-and-bound k-NN search. `heap` is a max-heap of the best
    /// candidates so far, so `peek()` is the worst of the current best.
    fn search(
        &self,
        node_idx: usize,
        query: &[f64],
        k: usize,
        points: &[Vec<f64>],
        heap: &mut BinaryHeap<HeapEntry>,
    ) {
        match &self.nodes[node_idx] {
            KdNode::Leaf { point_idx } => {
                let dist_sq = euclidean_sq(query, &points[*point_idx]);
                if heap.len() < k {
                    heap.push(HeapEntry {
                        dist_sq,
                        idx: *point_idx,
                    });
                } else if let Some(worst) = heap.peek() {
                    // Replace the current worst candidate only when strictly closer.
                    if dist_sq < worst.dist_sq {
                        heap.pop();
                        heap.push(HeapEntry {
                            dist_sq,
                            idx: *point_idx,
                        });
                    }
                }
            }
            KdNode::Split {
                dim,
                value,
                left,
                right,
            } => {
                let diff = query[*dim] - value;
                // Descend the side containing the query first so the heap
                // fills with good candidates before pruning is attempted.
                let (near, far) = if diff <= 0.0 {
                    (*left, *right)
                } else {
                    (*right, *left)
                };
                self.search(near, query, k, points, heap);
                // The far side can only improve the result if the splitting
                // hyperplane is closer than the current worst candidate, or
                // the heap is not yet full.
                let plane_dist_sq = diff * diff;
                let should_search_far = heap.len() < k
                    || heap
                        .peek()
                        .is_none_or(|worst| plane_dist_sq < worst.dist_sq);
                if should_search_far {
                    self.search(far, query, k, points, heap);
                }
            }
        }
    }

    /// Returns the indices of all points whose squared Euclidean distance to
    /// `query` is at most `radius_sq`, in tree-traversal order (unsorted).
    ///
    /// `points` must be the same slice the tree was built from.
    pub fn query_radius(&self, query: &[f64], radius_sq: f64, points: &[Vec<f64>]) -> Vec<usize> {
        let mut result = Vec::new();
        if !self.nodes.is_empty() {
            self.search_radius(0, query, radius_sq, points, &mut result);
        }
        result
    }

    /// Recursive range search: collects every in-radius leaf, pruning
    /// subtrees whose splitting hyperplane lies strictly outside the radius.
    fn search_radius(
        &self,
        node_idx: usize,
        query: &[f64],
        radius_sq: f64,
        points: &[Vec<f64>],
        result: &mut Vec<usize>,
    ) {
        match &self.nodes[node_idx] {
            KdNode::Leaf { point_idx } => {
                // Boundary-inclusive, matching the `<=` prune below.
                if euclidean_sq(query, &points[*point_idx]) <= radius_sq {
                    result.push(*point_idx);
                }
            }
            KdNode::Split {
                dim,
                value,
                left,
                right,
            } => {
                let diff = query[*dim] - value;
                let (near, far) = if diff <= 0.0 {
                    (*left, *right)
                } else {
                    (*right, *left)
                };
                self.search_radius(near, query, radius_sq, points, result);
                let plane_dist_sq = diff * diff;
                if plane_dist_sq <= radius_sq {
                    self.search_radius(far, query, radius_sq, points, result);
                }
            }
        }
    }

    /// `true` when the tree was built from an empty point set.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Dimensionality recorded at build time (0 for an empty tree).
    pub fn n_dims(&self) -> usize {
        self.n_dims
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // An empty tree reports empty and returns no neighbors.
    #[test]
    fn test_kdtree_empty() {
        let tree = KdTree::build(&[]);
        assert!(tree.is_empty());
        let result = tree.query_k_nearest(&[0.0, 0.0], 1, &[]);
        assert!(result.is_empty());
    }

    #[test]
    fn test_kdtree_single_point() {
        let points = vec![vec![1.0, 2.0]];
        let tree = KdTree::build(&points);
        let result = tree.query_k_nearest(&[0.0, 0.0], 1, &points);
        assert_eq!(result.len(), 1);
        // Squared distance from the origin to (1, 2) is 1 + 4 = 5.
        assert_eq!(result[0].1, 0);
        assert!((result[0].0 - 5.0).abs() < 1e-9);
    }

    // Two well-separated clusters: a query near either cluster center must
    // return only that cluster's indices.
    #[test]
    fn test_kdtree_two_clusters() {
        let points = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![10.0, 10.0],
            vec![11.0, 10.0],
            vec![10.0, 11.0],
        ];
        let tree = KdTree::build(&points);
        let result = tree.query_k_nearest(&[0.5, 0.5], 3, &points);
        assert_eq!(result.len(), 3);
        for (_, idx) in &result {
            assert!(*idx < 3, "Expected cluster A indices, got {idx}");
        }
        let result = tree.query_k_nearest(&[10.5, 10.5], 3, &points);
        assert_eq!(result.len(), 3);
        for (_, idx) in &result {
            assert!(*idx >= 3, "Expected cluster B indices, got {idx}");
        }
    }

    #[test]
    fn test_kdtree_k_larger_than_n() {
        let points = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let tree = KdTree::build(&points);
        let result = tree.query_k_nearest(&[0.0, 0.0], 5, &points);
        assert_eq!(result.len(), 2, "Should return all points when k > n");
    }

    // Results must come back sorted ascending by squared distance.
    #[test]
    fn test_kdtree_sorted_nearest_first() {
        let points = vec![vec![0.0, 0.0], vec![5.0, 0.0], vec![2.0, 0.0]];
        let tree = KdTree::build(&points);
        let result = tree.query_k_nearest(&[1.0, 0.0], 3, &points);
        assert!(result[0].0 <= result[1].0);
        assert!(result[1].0 <= result[2].0);
    }

    // Cross-check the tree against an exhaustive scan on 100 pseudo-random
    // 5-D points (deterministic LCG so the fixture is reproducible).
    #[test]
    fn test_kdtree_matches_brute_force() {
        let mut points = Vec::new();
        let mut seed = 12345u64;
        for _ in 0..100 {
            let mut p = Vec::new();
            for _ in 0..5 {
                seed = seed.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
                p.push((seed >> 33) as f64 / 1e9);
            }
            points.push(p);
        }
        let tree = KdTree::build(&points);
        for q_idx in [0, 25, 50, 75, 99] {
            let query = &points[q_idx];
            let k = 7;
            let mut dists: Vec<(f64, usize)> = points
                .iter()
                .enumerate()
                .map(|(i, p)| (euclidean_sq(query, p), i))
                .collect();
            // Stable sort on distance; exact ties are improbable with these
            // random coordinates, so index ordering is not exercised here.
            dists.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
            let brute: Vec<usize> = dists.iter().take(k).map(|(_, i)| *i).collect();
            let kd_result = tree.query_k_nearest(query, k, &points);
            let kd: Vec<usize> = kd_result.iter().map(|(_, i)| *i).collect();
            assert_eq!(
                brute, kd,
                "KD-tree and brute-force disagree for query point {q_idx}"
            );
        }
    }

    // Duplicate coordinates must not confuse the median split: exactly k of
    // the coincident points come back, each at (near-)zero distance.
    #[test]
    fn test_kdtree_duplicate_points() {
        let points = vec![
            vec![1.0, 1.0],
            vec![1.0, 1.0],
            vec![1.0, 1.0],
            vec![5.0, 5.0],
        ];
        let tree = KdTree::build(&points);
        let result = tree.query_k_nearest(&[1.0, 1.0], 3, &points);
        assert_eq!(result.len(), 3);
        for (dist, _) in &result {
            assert!(*dist < 1e-9);
        }
    }

    // 20-D data: querying with a stored point must return it first at
    // distance ~0.
    #[test]
    fn test_kdtree_high_dim() {
        let mut points = Vec::new();
        let mut seed = 42u64;
        for _ in 0..50 {
            let mut p = Vec::new();
            for _ in 0..20 {
                seed = seed.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
                p.push((seed >> 33) as f64 / 1e9);
            }
            points.push(p);
        }
        let tree = KdTree::build(&points);
        let result = tree.query_k_nearest(&points[0], 5, &points);
        assert_eq!(result.len(), 5);
        assert!((result[0].0).abs() < 1e-9);
        assert_eq!(result[0].1, 0);
    }

    // Radius query: a hand-checked small case, then a brute-force
    // cross-check on 200 pseudo-random 3-D points.
    #[test]
    fn test_kdtree_query_radius() {
        let points = vec![
            vec![0.0, 0.0],
            vec![1.0, 0.0],
            vec![0.0, 1.0],
            vec![1.0, 1.0],
            vec![10.0, 10.0],
            vec![11.0, 10.0],
        ];
        let tree = KdTree::build(&points);
        let mut result = tree.query_radius(&[0.5, 0.5], 2.0, &points);
        // query_radius returns hits in traversal order; sort before comparing.
        result.sort_unstable();
        assert_eq!(result, vec![0, 1, 2, 3], "Should find the 4 nearby points");
        let mut rng_points = Vec::new();
        let mut seed = 99u64;
        for _ in 0..200 {
            let mut p = Vec::new();
            for _ in 0..3 {
                seed = seed.wrapping_mul(6_364_136_223_846_793_005).wrapping_add(1);
                p.push((seed >> 33) as f64 / 1e9);
            }
            rng_points.push(p);
        }
        let tree2 = KdTree::build(&rng_points);
        let query = &[2.0, 2.0, 2.0];
        let radius_sq = 1.0;
        let mut kd_result = tree2.query_radius(query, radius_sq, &rng_points);
        kd_result.sort_unstable();
        let brute: Vec<usize> = rng_points
            .iter()
            .enumerate()
            .filter(|(_, p)| euclidean_sq(query, p) <= radius_sq)
            .map(|(i, _)| i)
            .collect();
        assert_eq!(
            kd_result, brute,
            "KD-tree radius and brute-force should agree"
        );
    }
}