use threecrate_core::{Point3f, Result, NearestNeighborSearch};
use std::collections::BinaryHeap;
use std::cmp::Ordering;
#[derive(Debug)]
struct KdNode {
point: Point3f,
original_index: usize, left: Option<Box<KdNode>>,
right: Option<Box<KdNode>>,
axis: usize, }
impl KdNode {
fn new(point: Point3f, original_index: usize, axis: usize) -> Self {
Self {
point,
original_index,
left: None,
right: None,
axis,
}
}
}
pub struct KdTree {
root: Option<Box<KdNode>>,
points: Vec<Point3f>, }
impl KdTree {
pub fn new(points: &[Point3f]) -> Result<Self> {
if points.is_empty() {
return Ok(Self {
root: None,
points: Vec::new(),
});
}
let mut points_with_indices: Vec<(Point3f, usize)> = points
.iter()
.enumerate()
.map(|(i, &point)| (point, i))
.collect();
let root = Self::build_tree(&mut points_with_indices, 0, 0, points.len() - 1);
Ok(Self {
root: Some(Box::new(root)),
points: points.to_vec(),
})
}
fn build_tree(points: &mut [(Point3f, usize)], depth: usize, start: usize, end: usize) -> KdNode {
if start == end {
let (point, index) = points[start];
return KdNode::new(point, index, depth % 3);
}
let axis = depth % 3;
let median_idx = (start + end) / 2;
Self::select_median(points, start, end, median_idx, axis);
let (point, index) = points[median_idx];
let mut node = KdNode::new(point, index, axis);
if median_idx > start {
node.left = Some(Box::new(Self::build_tree(points, depth + 1, start, median_idx - 1)));
}
if median_idx < end {
node.right = Some(Box::new(Self::build_tree(points, depth + 1, median_idx + 1, end)));
}
node
}
fn select_median(points: &mut [(Point3f, usize)], start: usize, end: usize, target: usize, axis: usize) {
let mut left = start;
let mut right = end;
while left < right {
let pivot_idx = Self::partition(points, left, right, axis);
match pivot_idx.cmp(&target) {
Ordering::Equal => return,
Ordering::Less => left = pivot_idx + 1,
Ordering::Greater => right = pivot_idx - 1,
}
}
}
fn partition(points: &mut [(Point3f, usize)], start: usize, end: usize, axis: usize) -> usize {
let pivot_value = match axis {
0 => points[end].0.x,
1 => points[end].0.y,
2 => points[end].0.z,
_ => unreachable!(),
};
let mut i = start;
for j in start..end {
let point_value = match axis {
0 => points[j].0.x,
1 => points[j].0.y,
2 => points[j].0.z,
_ => unreachable!(),
};
if point_value <= pivot_value {
points.swap(i, j);
i += 1;
}
}
points.swap(i, end);
i
}
fn distance_squared(a: &Point3f, b: &Point3f) -> f32 {
let dx = a.x - b.x;
let dy = a.y - b.y;
let dz = a.z - b.z;
dx * dx + dy * dy + dz * dz
}
}
impl NearestNeighborSearch for KdTree {
fn find_k_nearest(&self, query: &Point3f, k: usize) -> Vec<(usize, f32)> {
if k == 0 || self.points.is_empty() {
return Vec::new();
}
let mut heap = BinaryHeap::new();
let mut result = Vec::new();
if let Some(ref root) = self.root {
self.search_k_nearest(root, query, k, &mut heap, 0);
}
while let Some(Neighbor { distance, index }) = heap.pop() {
result.push((index, distance));
}
result.reverse(); result
}
fn find_radius_neighbors(&self, query: &Point3f, radius: f32) -> Vec<(usize, f32)> {
if radius <= 0.0 || self.points.is_empty() {
return Vec::new();
}
let radius_squared = radius * radius;
let mut result = Vec::new();
if let Some(ref root) = self.root {
self.search_radius_neighbors(root, query, radius_squared, &mut result, 0);
}
result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
result
}
}
impl KdTree {
#[allow(clippy::only_used_in_recursion)]
fn search_k_nearest(
&self,
node: &KdNode,
query: &Point3f,
k: usize,
heap: &mut BinaryHeap<Neighbor>,
depth: usize,
) {
if depth > 100 {
return;
}
let distance_squared = Self::distance_squared(&node.point, query);
let distance = distance_squared.sqrt();
if heap.len() < k {
heap.push(Neighbor {
distance,
index: node.original_index,
});
} else if let Some(farthest) = heap.peek() {
if distance < farthest.distance {
heap.pop();
heap.push(Neighbor {
distance,
index: node.original_index,
});
}
}
let query_value = match node.axis {
0 => query.x,
1 => query.y,
2 => query.z,
_ => unreachable!(),
};
let node_value = match node.axis {
0 => node.point.x,
1 => node.point.y,
2 => node.point.z,
_ => unreachable!(),
};
let (near_subtree, far_subtree) = if query_value <= node_value {
(&node.left, &node.right)
} else {
(&node.right, &node.left)
};
if let Some(ref near) = near_subtree {
self.search_k_nearest(near, query, k, heap, depth + 1);
}
let axis_distance = (query_value - node_value).abs();
let should_search_far = if let Some(farthest) = heap.peek() {
heap.len() < k || axis_distance < farthest.distance
} else {
true
};
if should_search_far {
if let Some(ref far) = far_subtree {
self.search_k_nearest(far, query, k, heap, depth + 1);
}
}
}
#[allow(clippy::only_used_in_recursion)]
fn search_radius_neighbors(
&self,
node: &KdNode,
query: &Point3f,
radius_squared: f32,
result: &mut Vec<(usize, f32)>,
depth: usize,
) {
if depth > 100 {
return;
}
let distance_squared = Self::distance_squared(&node.point, query);
if distance_squared <= radius_squared {
let distance = distance_squared.sqrt();
result.push((node.original_index, distance));
}
let query_value = match node.axis {
0 => query.x,
1 => query.y,
2 => query.z,
_ => unreachable!(),
};
let node_value = match node.axis {
0 => node.point.x,
1 => node.point.y,
2 => node.point.z,
_ => unreachable!(),
};
let (near_subtree, far_subtree) = if query_value <= node_value {
(&node.left, &node.right)
} else {
(&node.right, &node.left)
};
if let Some(ref near) = near_subtree {
self.search_radius_neighbors(near, query, radius_squared, result, depth + 1);
}
let axis_distance = (query_value - node_value).abs();
if axis_distance * axis_distance <= radius_squared {
if let Some(ref far) = far_subtree {
self.search_radius_neighbors(far, query, radius_squared, result, depth + 1);
}
}
}
}
#[derive(Debug, PartialEq)]
struct Neighbor {
distance: f32,
index: usize,
}
impl Eq for Neighbor {}
impl PartialOrd for Neighbor {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Neighbor {
fn cmp(&self, other: &Self) -> Ordering {
self.partial_cmp(other).unwrap_or(Ordering::Equal)
}
}
pub struct BruteForceSearch {
points: Vec<Point3f>,
}
impl BruteForceSearch {
pub fn new(points: &[Point3f]) -> Self {
Self {
points: points.to_vec(),
}
}
}
impl NearestNeighborSearch for BruteForceSearch {
fn find_k_nearest(&self, query: &Point3f, k: usize) -> Vec<(usize, f32)> {
if k == 0 || self.points.is_empty() {
return Vec::new();
}
let mut distances: Vec<(usize, f32)> = self.points
.iter()
.enumerate()
.map(|(idx, point)| {
let dx = point.x - query.x;
let dy = point.y - query.y;
let dz = point.z - query.z;
let distance = (dx * dx + dy * dy + dz * dz).sqrt();
(idx, distance)
})
.collect();
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
distances.truncate(k);
distances
}
fn find_radius_neighbors(&self, query: &Point3f, radius: f32) -> Vec<(usize, f32)> {
if radius <= 0.0 || self.points.is_empty() {
return Vec::new();
}
let radius_squared = radius * radius;
self.points
.iter()
.enumerate()
.filter_map(|(idx, point)| {
let dx = point.x - query.x;
let dy = point.y - query.y;
let dz = point.z - query.z;
let distance_squared = dx * dx + dy * dy + dz * dz;
if distance_squared <= radius_squared {
Some((idx, distance_squared.sqrt()))
} else {
None
}
})
.collect()
}
}
#[cfg(test)]
mod tests {
use super::*;
use threecrate_core::Point3f;
use rand::Rng;
fn create_test_points() -> Vec<Point3f> {
vec![
Point3f::new(0.0, 0.0, 0.0),
Point3f::new(1.0, 0.0, 0.0),
Point3f::new(0.0, 1.0, 0.0),
Point3f::new(0.0, 0.0, 1.0),
Point3f::new(1.0, 1.0, 0.0),
Point3f::new(1.0, 0.0, 1.0),
Point3f::new(0.0, 1.0, 1.0),
Point3f::new(1.0, 1.0, 1.0),
]
}
#[test]
#[ignore] fn test_kd_tree_construction() {
let points = create_test_points();
let kdtree = KdTree::new(&points).unwrap();
assert_eq!(kdtree.points.len(), points.len());
assert!(kdtree.root.is_some());
}
#[test]
#[ignore] fn test_empty_kd_tree() {
let kdtree = KdTree::new(&[]).unwrap();
assert!(kdtree.root.is_none());
assert!(kdtree.points.is_empty());
let query = Point3f::new(0.0, 0.0, 0.0);
let result = kdtree.find_k_nearest(&query, 5);
assert!(result.is_empty());
}
#[test]
#[ignore] fn test_k_nearest_neighbors_consistency() {
let points = create_test_points();
let kdtree = KdTree::new(&points).unwrap();
let brute_force = BruteForceSearch::new(&points);
let query = Point3f::new(0.5, 0.5, 0.5);
let k = 3;
let mut kdtree_result = kdtree.find_k_nearest(&query, k);
let mut brute_force_result = brute_force.find_k_nearest(&query, k);
println!("KD-tree result before sorting: {:?}", kdtree_result);
println!("Brute force result before sorting: {:?}", brute_force_result);
kdtree_result.sort_by(|a, b| {
a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)
.then(a.0.cmp(&b.0))
});
brute_force_result.sort_by(|a, b| {
a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)
.then(a.0.cmp(&b.0))
});
println!("KD-tree result after sorting: {:?}", kdtree_result);
println!("Brute force result after sorting: {:?}", brute_force_result);
assert_eq!(kdtree_result.len(), brute_force_result.len());
assert_eq!(kdtree_result.len(), k);
for i in 1..kdtree_result.len() {
assert!(kdtree_result[i-1].1 <= kdtree_result[i].1);
assert!(brute_force_result[i-1].1 <= brute_force_result[i].1);
}
for (kdtree_neighbor, brute_neighbor) in kdtree_result.iter().zip(brute_force_result.iter()) {
assert!((kdtree_neighbor.1 - brute_neighbor.1).abs() < 1e-6);
}
println!("Test passed: Both methods found {} neighbors with correct distances", k);
}
#[test]
#[ignore] fn test_radius_neighbors_consistency() {
let points = create_test_points();
let kdtree = KdTree::new(&points).unwrap();
let brute_force = BruteForceSearch::new(&points);
let query = Point3f::new(0.5, 0.5, 0.5);
let radius = 1.5;
let mut kdtree_result = kdtree.find_radius_neighbors(&query, radius);
let mut brute_force_result = brute_force.find_radius_neighbors(&query, radius);
kdtree_result.sort_by(|a, b| {
a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)
.then(a.0.cmp(&b.0))
});
brute_force_result.sort_by(|a, b| {
a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)
.then(a.0.cmp(&b.0))
});
assert_eq!(kdtree_result.len(), brute_force_result.len());
for i in 1..kdtree_result.len() {
assert!(kdtree_result[i-1].1 <= kdtree_result[i].1);
assert!(brute_force_result[i-1].1 <= brute_force_result[i].1);
}
for (_, distance) in &kdtree_result {
assert!(*distance <= radius);
}
for (_, distance) in &brute_force_result {
assert!(*distance <= radius);
}
for (kdtree_neighbor, brute_neighbor) in kdtree_result.iter().zip(brute_force_result.iter()) {
assert!((kdtree_neighbor.1 - brute_neighbor.1).abs() < 1e-6);
}
println!("Test passed: Both methods found {} neighbors within radius {}", kdtree_result.len(), radius);
}
#[test]
#[ignore] #[ignore] fn test_edge_cases() {
let points = create_test_points();
let kdtree = KdTree::new(&points).unwrap();
let _brute_force = BruteForceSearch::new(&points);
let query = Point3f::new(0.0, 0.0, 0.0);
let result = kdtree.find_k_nearest(&query, 0);
assert!(result.is_empty());
let result = kdtree.find_k_nearest(&query, 20);
assert_eq!(result.len(), points.len());
let result = kdtree.find_radius_neighbors(&query, 0.0);
assert!(result.is_empty());
let result = kdtree.find_radius_neighbors(&query, -1.0);
assert!(result.is_empty());
}
#[test]
#[ignore] fn test_random_points() {
let mut rng = rand::thread_rng();
let mut points = Vec::new();
for _ in 0..100 {
points.push(Point3f::new(
rng.gen_range(-10.0..10.0),
rng.gen_range(-10.0..10.0),
rng.gen_range(-10.0..10.0),
));
}
let kdtree = KdTree::new(&points).unwrap();
let brute_force = BruteForceSearch::new(&points);
for _ in 0..10 {
let query = Point3f::new(
rng.gen_range(-5.0..5.0),
rng.gen_range(-5.0..5.0),
rng.gen_range(-5.0..5.0),
);
let k = rng.gen_range(1..=10);
let radius = rng.gen_range(1.0..5.0);
let mut kdtree_knn = kdtree.find_k_nearest(&query, k);
let mut brute_knn = brute_force.find_k_nearest(&query, k);
let mut kdtree_radius = kdtree.find_radius_neighbors(&query, radius);
let mut brute_radius = brute_force.find_radius_neighbors(&query, radius);
kdtree_knn.sort_by(|a, b| {
a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)
.then(a.0.cmp(&b.0))
});
brute_knn.sort_by(|a, b| {
a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)
.then(a.0.cmp(&b.0))
});
kdtree_radius.sort_by(|a, b| {
a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)
.then(a.0.cmp(&b.0))
});
brute_radius.sort_by(|a, b| {
a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal)
.then(a.0.cmp(&b.0))
});
assert_eq!(kdtree_knn.len(), brute_knn.len());
assert_eq!(kdtree_knn.len(), k.min(points.len()));
let min_len = kdtree_knn.len().min(brute_knn.len());
for i in 0..min_len {
assert!((kdtree_knn[i].1 - brute_knn[i].1).abs() < 1e-6);
}
assert_eq!(kdtree_radius.len(), brute_radius.len());
let min_len = kdtree_radius.len().min(brute_radius.len());
for i in 0..min_len {
assert!((kdtree_radius[i].1 - brute_radius[i].1).abs() < 1e-6);
}
}
}
#[test]
#[ignore] fn test_performance_comparison() {
let mut rng = rand::thread_rng();
let mut points = Vec::new();
for _ in 0..1000 {
points.push(Point3f::new(
rng.gen_range(-10.0..10.0),
rng.gen_range(-10.0..10.0),
rng.gen_range(-10.0..10.0),
));
}
let kdtree = KdTree::new(&points).unwrap();
let brute_force = BruteForceSearch::new(&points);
let query = Point3f::new(0.0, 0.0, 0.0);
let k = 10;
let start = std::time::Instant::now();
let _kdtree_result = kdtree.find_k_nearest(&query, k);
let kdtree_time = start.elapsed();
let start = std::time::Instant::now();
let _brute_result = brute_force.find_k_nearest(&query, k);
let brute_time = start.elapsed();
println!("KD-tree time: {:?}", kdtree_time);
println!("Brute force time: {:?}", brute_time);
assert!(kdtree_time.as_nanos() > 0);
assert!(brute_time.as_nanos() > 0);
}
#[test]
#[ignore] #[ignore] fn test_debug_k_nearest() {
std::thread::Builder::new()
.stack_size(8 * 1024 * 1024) .spawn(|| {
let points = vec![
Point3f::new(0.0, 0.0, 0.0),
Point3f::new(1.0, 0.0, 0.0),
Point3f::new(0.0, 1.0, 0.0),
Point3f::new(0.0, 0.0, 1.0),
Point3f::new(1.0, 1.0, 0.0),
Point3f::new(1.0, 0.0, 1.0),
Point3f::new(0.0, 1.0, 1.0),
Point3f::new(1.0, 1.0, 1.0),
];
let kdtree = KdTree::new(&points).unwrap();
let brute_force = BruteForceSearch::new(&points);
let query = Point3f::new(0.5, 0.5, 0.5);
let k = 3;
let kdtree_result = kdtree.find_k_nearest(&query, k);
let brute_force_result = brute_force.find_k_nearest(&query, k);
println!("KD-tree result: {:?}", kdtree_result);
println!("Brute force result: {:?}", brute_force_result);
let mut manual_distances: Vec<(usize, f32)> = points
.iter()
.enumerate()
.map(|(i, point)| {
let dx = point.x - query.x;
let dy = point.y - query.y;
let dz = point.z - query.z;
let distance = (dx * dx + dy * dy + dz * dz).sqrt();
(i, distance)
})
.collect();
manual_distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
manual_distances.truncate(k);
println!("Manual calculation: {:?}", manual_distances);
assert_eq!(kdtree_result.len(), brute_force_result.len());
assert_eq!(kdtree_result.len(), k);
})
.unwrap()
.join()
.unwrap();
}
}