use std::f32;
/// Largest number of points a subtree may hold before it is split into an internal node.
const MAX_LEAF_SIZE: usize = 16;

/// A node of the 2-D k-d tree.
#[derive(Debug)]
enum KDNode {
    /// Splits space on axis `dimension` (0 = x, 1 = y) at `split_value`.
    ///
    /// For non-NaN coordinates, every point in `left` has a coordinate `<= split_value` on that
    /// axis and every point in `right` has a coordinate `>= split_value`. The median point that
    /// defines `split_value` is stored in `right`.
    Internal {
        dimension: usize,
        split_value: f32,
        left: Box<KDTree>,
        right: Box<KDTree>,
    },
    /// A bucket of at most `MAX_LEAF_SIZE` points; `indices[i]` is the position of
    /// `points[i]` in the slice given to [`KDTree::build`].
    Leaf {
        points: Vec<[f32; 2]>,
        indices: Vec<usize>,
    },
}

/// A static 2-D k-d tree answering nearest-neighbour and k-nearest-neighbour queries.
///
/// Query results are `(point, index, squared_distance)` tuples: `index` is the position of
/// `point` in the slice passed to [`KDTree::build`], and the distance is the *squared*
/// Euclidean distance to the query point.
#[derive(Debug)]
pub struct KDTree {
    root: KDNode,
    /// Number of points stored in this (sub)tree.
    size: usize,
}

impl KDTree {
    /// Builds a tree containing every point of `points`.
    ///
    /// Subtrees are split at the median of alternating axes (x at even depths, y at odd
    /// depths) until they hold at most `MAX_LEAF_SIZE` points. Coordinates are copied into
    /// the leaves, so the tree does not borrow `points`.
    pub fn build(points: &[[f32; 2]]) -> Self {
        let mut indices: Vec<usize> = (0..points.len()).collect();
        Self::build_subtree(points, &mut indices, 0)
    }

    /// Recursively builds the subtree holding the points referenced by `indices`.
    ///
    /// `indices` is reordered in place while partitioning.
    fn build_subtree(points: &[[f32; 2]], indices: &mut [usize], depth: usize) -> Self {
        let size = indices.len();
        let root = if size <= MAX_LEAF_SIZE {
            KDNode::Leaf {
                points: indices.iter().map(|&i| points[i]).collect(),
                indices: indices.to_vec(),
            }
        } else {
            let dimension = depth % 2;
            let median = size / 2;
            // Only a partition around the median is needed, not a full sort. `total_cmp` is a
            // total order even with NaN coordinates; an inconsistent comparator may make the
            // std slice ordering routines panic.
            indices.select_nth_unstable_by(median, |&a, &b| {
                points[a][dimension].total_cmp(&points[b][dimension])
            });
            let split_value = points[indices[median]][dimension];
            // The median stays in the right half so that no point is dropped from the tree.
            // Both halves are non-empty and strictly smaller because size > MAX_LEAF_SIZE.
            let (left, right) = indices.split_at_mut(median);
            KDNode::Internal {
                dimension,
                split_value,
                left: Box::new(Self::build_subtree(points, left, depth + 1)),
                right: Box::new(Self::build_subtree(points, right, depth + 1)),
            }
        };
        Self { root, size }
    }

    /// Returns the point nearest to `query` as `(point, index, squared_distance)`.
    ///
    /// Ties on distance go to the smallest index. Points whose squared distance to `query` is
    /// NaN (for example because a coordinate is NaN) are ignored. `None` is returned when the
    /// tree is empty or no point has a comparable distance.
    pub fn nearest(&self, query: &[f32; 2]) -> Option<([f32; 2], usize, f32)> {
        let mut best = None;
        Self::nearest_recursive(&self.root, query, &mut best);
        best
    }

    /// Depth-first search for the single nearest point below `node`, updating `best`.
    fn nearest_recursive(
        node: &KDNode,
        query: &[f32; 2],
        best: &mut Option<([f32; 2], usize, f32)>,
    ) {
        match node {
            KDNode::Internal {
                dimension,
                split_value,
                left,
                right,
            } => {
                let query_val = query[*dimension];
                // Descend first into the half that contains the query.
                let (near, far) = if query_val <= *split_value {
                    (left, right)
                } else {
                    (right, left)
                };
                Self::nearest_recursive(&near.root, query, best);
                // Every point in `far` is at least `|plane_dist|` away. Skip it only when that
                // bound is strictly worse than the current best. An equally distant point
                // could still win the index tie-break, and a NaN bound never prunes.
                let plane_dist = query_val - *split_value;
                let plane_dist_sq = plane_dist * plane_dist;
                let prunable =
                    matches!(*best, Some((_, _, best_dist)) if plane_dist_sq > best_dist);
                if !prunable {
                    Self::nearest_recursive(&far.root, query, best);
                }
            }
            KDNode::Leaf { points, indices } => {
                for (&point, &index) in points.iter().zip(indices) {
                    let dist_sq = Self::squared_distance(&point, query);
                    if dist_sq.is_nan() {
                        continue;
                    }
                    let improves = match *best {
                        None => true,
                        Some((_, best_index, best_dist)) => {
                            dist_sq < best_dist || (dist_sq == best_dist && index < best_index)
                        }
                    };
                    if improves {
                        *best = Some((point, index, dist_sq));
                    }
                }
            }
        }
    }

    /// Returns up to `k` points nearest to `query`, sorted by ascending squared distance with
    /// ties broken by ascending index.
    ///
    /// Fewer than `k` entries are returned only when the tree holds fewer than `k` points
    /// with a non-NaN distance to `query`. Any `k` is accepted; the buffer is capped at the
    /// tree size.
    pub fn nearest_k(&self, query: &[f32; 2], k: usize) -> Vec<([f32; 2], usize, f32)> {
        if self.size == 0 || k == 0 {
            return Vec::new();
        }
        // Never ask for more slots than there are points; avoids a capacity overflow on
        // huge `k`.
        let mut results = NearestK::new(k.min(self.size));
        Self::nearest_k_recursive(&self.root, query, &mut results);
        results.into_sorted()
    }

    /// Depth-first search offering every reachable point below `node` to `results`.
    fn nearest_k_recursive(node: &KDNode, query: &[f32; 2], results: &mut NearestK) {
        match node {
            KDNode::Internal {
                dimension,
                split_value,
                left,
                right,
            } => {
                let query_val = query[*dimension];
                let (near, far) = if query_val <= *split_value {
                    (left, right)
                } else {
                    (right, left)
                };
                Self::nearest_k_recursive(&near.root, query, results);
                // The far side may be skipped only once `k` candidates are held and all of
                // them beat the distance to the splitting line.
                let plane_dist = query_val - *split_value;
                if !results.can_prune(plane_dist * plane_dist) {
                    Self::nearest_k_recursive(&far.root, query, results);
                }
            }
            KDNode::Leaf { points, indices } => {
                for (&point, &index) in points.iter().zip(indices) {
                    results.insert(point, index, Self::squared_distance(&point, query));
                }
            }
        }
    }

    /// Squared Euclidean distance between `point` and `query`.
    fn squared_distance(point: &[f32; 2], query: &[f32; 2]) -> f32 {
        let dx = point[0] - query[0];
        let dy = point[1] - query[1];
        dx * dx + dy * dy
    }

    /// Number of points stored in the tree.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Returns `true` if the tree holds no points.
    pub fn is_empty(&self) -> bool {
        self.size == 0
    }
}

/// Bounded buffer of the best `k` candidates seen so far, kept sorted by ascending
/// `(squared distance, index)`.
struct NearestK {
    neighbors: Vec<([f32; 2], usize, f32)>,
    k: usize,
}

impl NearestK {
    /// Creates an empty buffer that keeps at most `k` candidates.
    fn new(k: usize) -> Self {
        Self {
            neighbors: Vec::with_capacity(k),
            k,
        }
    }

    /// Offers a candidate. It is kept if the buffer is not yet full or if it ranks before
    /// the current worst entry, which is then evicted. NaN distances cannot be ranked and
    /// are ignored.
    fn insert(&mut self, point: [f32; 2], index: usize, dist_sq: f32) {
        if self.k == 0 || dist_sq.is_nan() {
            return;
        }
        // Stored entries are sorted and NaN-free, so "entry ranks before the candidate" holds
        // for a prefix of the vector; `partition_point` finds the insertion slot.
        let pos = self
            .neighbors
            .partition_point(|&(_, i, d)| d < dist_sq || (d == dist_sq && i < index));
        if pos >= self.k {
            // Buffer full and the candidate ranks no better than the current worst.
            return;
        }
        if self.neighbors.len() == self.k {
            self.neighbors.pop();
        }
        self.neighbors.insert(pos, (point, index, dist_sq));
    }

    /// Squared distance of the worst kept candidate, or `None` while the buffer is empty.
    fn worst_distance(&self) -> Option<f32> {
        self.neighbors.last().map(|&(_, _, dist)| dist)
    }

    /// Returns `true` when a region whose points are all at squared distance `>= bound_sq`
    /// cannot change the result. This requires a full buffer and `bound_sq` strictly above
    /// the worst kept distance; a NaN bound never prunes.
    fn can_prune(&self, bound_sq: f32) -> bool {
        self.neighbors.len() >= self.k
            && matches!(self.worst_distance(), Some(worst) if bound_sq > worst)
    }

    /// Consumes the buffer, returning the kept candidates in ascending order.
    fn into_sorted(self) -> Vec<([f32; 2], usize, f32)> {
        self.neighbors
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Four points on the unit circle, shared by the small fixed-input tests.
    fn circle_points() -> Vec<[f32; 2]> {
        vec![[1.0, 0.0], [0.0, 1.0], [0.6, 0.8], [0.8, 0.6]]
    }

    #[test]
    fn test_kdtree_build() {
        let tree = KDTree::build(&circle_points());
        assert_eq!(tree.size(), 4);
        assert!(!tree.is_empty());
    }

    #[test]
    fn test_kdtree_nearest() {
        let tree = KDTree::build(&circle_points());
        let hit = tree.nearest(&[0.59, 0.81]).unwrap();
        assert_eq!(hit.0, [0.6, 0.8]);
        assert_eq!(hit.1, 2);
        // The query is offset by 0.01 on each axis, so the squared distance is about 2e-4.
        assert!(hit.2 < 0.0005);
    }

    #[test]
    fn test_kdtree_nearest_k() {
        let mut points = circle_points();
        points.extend_from_slice(&[[0.5, 0.5], [0.9, 0.9]]);
        let tree = KDTree::build(&points);
        let results = tree.nearest_k(&[0.55, 0.55], 3);
        assert_eq!(results.len(), 3);
        // Results must come back in non-decreasing distance order.
        for pair in results.windows(2) {
            assert!(pair[0].2 <= pair[1].2);
        }
    }

    #[test]
    fn test_kdtree_empty() {
        let tree = KDTree::build(&Vec::<[f32; 2]>::new());
        assert!(tree.is_empty());
        assert_eq!(tree.size(), 0);
        assert!(tree.nearest(&[0.5, 0.5]).is_none());
    }

    #[test]
    fn test_kdtree_single_point() {
        let tree = KDTree::build(&[[0.5, 0.5]]);
        let (point, index, dist_sq) = tree.nearest(&[0.51, 0.49]).unwrap();
        assert_eq!(point, [0.5, 0.5]);
        assert_eq!(index, 0);
        assert!(dist_sq < 0.0002);
    }

    #[test]
    fn test_kdtree_large_random() {
        use rand::Rng;
        let mut rng = rand::thread_rng();
        let mut points: Vec<[f32; 2]> = Vec::with_capacity(1000);
        for _ in 0..1000 {
            points.push([rng.gen::<f32>(), rng.gen::<f32>()]);
        }
        let tree = KDTree::build(&points);
        assert_eq!(tree.size(), 1000);
        let query = [rng.gen::<f32>(), rng.gen::<f32>()];
        assert!(tree.nearest(&query).is_some());
    }
}