use threecrate_core::{Point3f, Result, NearestNeighborSearch};
use std::collections::BinaryHeap;
use std::cmp::Ordering;
/// One node of the k-d tree: a stored point plus the two subtrees it splits.
#[derive(Debug)]
struct KdNode {
    /// Point held at this node.
    point: Point3f,
    /// Position of `point` in the slice passed to `KdTree::new`.
    original_index: usize,
    /// Child subtree on the low side of the splitting plane, if any.
    left: Option<Box<KdNode>>,
    /// Child subtree on the high side of the splitting plane, if any.
    right: Option<Box<KdNode>>,
    /// Splitting axis, used as an index into the coordinates: 0 = x, 1 = y, 2 = z.
    axis: usize,
}

impl KdNode {
    /// Creates a childless node holding `point` that splits along `axis`;
    /// the tree builder attaches the children afterwards.
    fn new(point: Point3f, original_index: usize, axis: usize) -> Self {
        KdNode {
            axis,
            left: None,
            right: None,
            original_index,
            point,
        }
    }
}
/// A 3-D k-d tree answering k-nearest-neighbour and radius queries.
///
/// Indices reported by the search methods refer to positions in the slice
/// originally passed to [`KdTree::new`].
pub struct KdTree {
    /// Root node; `None` when the tree was built from an empty point set.
    root: Option<Box<KdNode>>,
    /// Copy of the input points, kept in their original order.
    points: Vec<Point3f>,
}
impl KdTree {
    /// Builds a k-d tree over `points`.
    ///
    /// The tree splits on x, y, z cyclically by depth, storing the median point
    /// along the current axis at each node. Indices returned by the search
    /// methods refer to positions in `points`. An empty slice yields an empty
    /// tree whose queries return no results.
    ///
    /// # Errors
    ///
    /// The current implementation never returns `Err`.
    pub fn new(points: &[Point3f]) -> Result<Self> {
        // Remember each point's input position; construction reorders this buffer.
        let mut indexed: Vec<(Point3f, usize)> = points
            .iter()
            .enumerate()
            .map(|(i, &point)| (point, i))
            .collect();
        let root = Self::build_subtree(&mut indexed, 0);
        Ok(Self {
            root,
            points: points.to_vec(),
        })
    }

    /// Recursively builds the subtree for `points`, splitting on axis `depth % 3`.
    ///
    /// The median along that axis becomes the node. Elements ordered before it
    /// (coordinate <= median) form the left subtree, and elements after it
    /// (coordinate >= median) form the right subtree. Returns `None` for an
    /// empty slice. `points` is reordered in place.
    fn build_subtree(points: &mut [(Point3f, usize)], depth: usize) -> Option<Box<KdNode>> {
        if points.is_empty() {
            return None;
        }
        let axis = depth % 3;
        let median = points.len() / 2;
        // std's selection stays linear on axis-sorted or heavily duplicated input,
        // where a last-element-pivot quickselect degrades to O(n^2). `total_cmp` is
        // a total order, so NaN coordinates cannot make the comparator
        // inconsistent (which std documents this method may panic on).
        points.select_nth_unstable_by(median, |a, b| {
            a.0.coords[axis].total_cmp(&b.0.coords[axis])
        });
        let (point, index) = points[median];
        let mut node = KdNode::new(point, index, axis);
        let (below, rest) = points.split_at_mut(median);
        node.left = Self::build_subtree(below, depth + 1);
        // `rest[0]` is the median itself, already stored in `node`.
        node.right = Self::build_subtree(&mut rest[1..], depth + 1);
        Some(Box::new(node))
    }

    /// Squared Euclidean distance between `a` and `b`.
    fn distance_squared(a: &Point3f, b: &Point3f) -> f32 {
        let dx = a.x - b.x;
        let dy = a.y - b.y;
        let dz = a.z - b.z;
        dx * dx + dy * dy + dz * dz
    }
}
impl NearestNeighborSearch for KdTree {
    /// Returns up to `k` points nearest to `query` as `(index, distance)` pairs,
    /// ordered by ascending Euclidean distance.
    ///
    /// Returns an empty vector when `k == 0` or the tree is empty. When `k`
    /// exceeds the number of points, every point is returned.
    fn find_k_nearest(&self, query: &Point3f, k: usize) -> Vec<(usize, f32)> {
        if k == 0 || self.points.is_empty() {
            return Vec::new();
        }
        // Max-heap of the best candidates so far, keyed on *squared* distance.
        // Its top is the current k-th best, and the first one evicted.
        let mut heap: BinaryHeap<Neighbor> = BinaryHeap::with_capacity(k + 1);
        // Each pending subtree carries a lower bound on the squared distance from
        // `query` to any point inside it. The bound is checked when the entry is
        // popped, after nearer branches have tightened the heap, not when pushed.
        let mut stack: Vec<(&KdNode, f32)> = Vec::new();
        if let Some(root) = self.root.as_deref() {
            stack.push((root, 0.0));
        }
        while let Some((node, bound_sq)) = stack.pop() {
            // Squared distance of the current k-th best, once k candidates are held.
            let worst_sq = if heap.len() < k {
                None
            } else {
                heap.peek().map(|n| n.distance)
            };
            if matches!(worst_sq, Some(worst) if bound_sq >= worst) {
                // Nothing in this subtree can beat the current k-th best.
                continue;
            }

            let dist_sq = Self::distance_squared(&node.point, query);
            let candidate = Neighbor {
                distance: dist_sq,
                index: node.original_index,
            };
            match worst_sq {
                None => heap.push(candidate),
                Some(worst) if dist_sq < worst => {
                    heap.pop();
                    heap.push(candidate);
                }
                Some(_) => {}
            }

            let query_val = query.coords[node.axis];
            let node_val = node.point.coords[node.axis];
            let (near, far) = if query_val <= node_val {
                (&node.left, &node.right)
            } else {
                (&node.right, &node.left)
            };
            // Points on the far side of the splitting plane are at least the plane
            // distance away; keep the tighter of that and the inherited bound.
            let axis_dist = query_val - node_val;
            let far_bound_sq = bound_sq.max(axis_dist * axis_dist);
            if let Some(far_node) = far.as_deref() {
                stack.push((far_node, far_bound_sq));
            }
            // Pushed last so it is popped first: descending toward the query fills
            // and tightens the heap before the far side is reconsidered.
            if let Some(near_node) = near.as_deref() {
                stack.push((near_node, bound_sq));
            }
        }
        // `into_sorted_vec` is ascending, i.e. nearest first.
        heap.into_sorted_vec()
            .into_iter()
            .map(|n| (n.index, n.distance.sqrt()))
            .collect()
    }

    /// Returns every point within `radius` of `query` (inclusive) as
    /// `(index, distance)` pairs.
    ///
    /// Results are sorted by ascending distance, with ties broken by index.
    /// A non-positive radius or an empty tree yields an empty vector.
    fn find_radius_neighbors(&self, query: &Point3f, radius: f32) -> Vec<(usize, f32)> {
        if radius <= 0.0 || self.points.is_empty() {
            return Vec::new();
        }
        let radius_sq = radius * radius;
        let mut result: Vec<(usize, f32)> = Vec::new();
        let mut stack: Vec<&KdNode> = self.root.as_deref().into_iter().collect();
        while let Some(node) = stack.pop() {
            let dist_sq = Self::distance_squared(&node.point, query);
            if dist_sq <= radius_sq {
                result.push((node.original_index, dist_sq.sqrt()));
            }
            let query_val = query.coords[node.axis];
            let node_val = node.point.coords[node.axis];
            let (near, far) = if query_val <= node_val {
                (&node.left, &node.right)
            } else {
                (&node.right, &node.left)
            };
            // The far side can only hold matches if the ball crosses the splitting plane.
            let axis_dist = query_val - node_val;
            if axis_dist * axis_dist <= radius_sq {
                stack.extend(far.as_deref());
            }
            stack.extend(near.as_deref());
        }
        // `total_cmp` is a total order (unlike `partial_cmp` + fallback), so the
        // sort is well-defined even if non-finite coordinates produce NaN.
        result.sort_unstable_by(|a, b| a.1.total_cmp(&b.1).then(a.0.cmp(&b.0)));
        result
    }
}
/// Candidate entry in the k-nearest-neighbour max-heap.
///
/// Ordering compares `distance` with [`f32::total_cmp`] and then `index`. This is
/// a genuine total order (NaN included), and equality is defined from it, so the
/// `Eq`/`Ord` contracts that `BinaryHeap` relies on actually hold.
#[derive(Debug)]
struct Neighbor {
    /// Distance key the heap orders by.
    distance: f32,
    /// Index of the point in the original input.
    index: usize,
}

impl PartialEq for Neighbor {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so equality always agrees with the ordering.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for Neighbor {}

impl PartialOrd for Neighbor {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for Neighbor {
    fn cmp(&self, other: &Self) -> Ordering {
        self.distance
            .total_cmp(&other.distance)
            .then_with(|| self.index.cmp(&other.index))
    }
}
/// Exhaustive nearest-neighbour search that measures the distance to every point.
///
/// The tests use it as a reference to check [`KdTree`] results against.
pub struct BruteForceSearch {
    /// Owned copy of the searched points, in input order.
    points: Vec<Point3f>,
}

impl BruteForceSearch {
    /// Creates a searcher over a copy of `points`.
    pub fn new(points: &[Point3f]) -> Self {
        let points = Vec::from(points);
        Self { points }
    }
}
impl NearestNeighborSearch for BruteForceSearch {
fn find_k_nearest(&self, query: &Point3f, k: usize) -> Vec<(usize, f32)> {
if k == 0 || self.points.is_empty() {
return Vec::new();
}
let mut distances: Vec<(usize, f32)> = self.points
.iter()
.enumerate()
.map(|(idx, point)| {
let dx = point.x - query.x;
let dy = point.y - query.y;
let dz = point.z - query.z;
let distance = (dx * dx + dy * dy + dz * dz).sqrt();
(idx, distance)
})
.collect();
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(Ordering::Equal));
distances.truncate(k);
distances
}
fn find_radius_neighbors(&self, query: &Point3f, radius: f32) -> Vec<(usize, f32)> {
if radius <= 0.0 || self.points.is_empty() {
return Vec::new();
}
let radius_squared = radius * radius;
self.points
.iter()
.enumerate()
.filter_map(|(idx, point)| {
let dx = point.x - query.x;
let dy = point.y - query.y;
let dz = point.z - query.z;
let distance_squared = dx * dx + dy * dy + dz * dz;
if distance_squared <= radius_squared {
Some((idx, distance_squared.sqrt()))
} else {
None
}
})
.collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rand::Rng;
    use threecrate_core::Point3f;

    /// The eight corners of the unit cube.
    fn create_test_points() -> Vec<Point3f> {
        vec![
            Point3f::new(0.0, 0.0, 0.0),
            Point3f::new(1.0, 0.0, 0.0),
            Point3f::new(0.0, 1.0, 0.0),
            Point3f::new(0.0, 0.0, 1.0),
            Point3f::new(1.0, 1.0, 0.0),
            Point3f::new(1.0, 0.0, 1.0),
            Point3f::new(0.0, 1.0, 1.0),
            Point3f::new(1.0, 1.0, 1.0),
        ]
    }

    /// Sorts results by distance, then index, so the outputs of two search
    /// structures can be compared element-wise even when distances tie.
    fn sort_by_distance_then_index(results: &mut [(usize, f32)]) {
        results.sort_by(|a, b| a.1.total_cmp(&b.1).then(a.0.cmp(&b.0)));
    }

    /// Asserts the distances in `results` are non-decreasing.
    fn assert_ascending(results: &[(usize, f32)]) {
        for pair in results.windows(2) {
            assert!(pair[0].1 <= pair[1].1, "results not sorted by distance: {:?}", results);
        }
    }

    /// Sorts both result sets and asserts they hold the same distances. Indices
    /// are not compared, because equidistant points may legitimately differ.
    fn assert_same_distances(mut kd: Vec<(usize, f32)>, mut bf: Vec<(usize, f32)>) {
        sort_by_distance_then_index(&mut kd);
        sort_by_distance_then_index(&mut bf);
        assert_eq!(kd.len(), bf.len(), "result count mismatch");
        for (a, b) in kd.iter().zip(&bf) {
            assert!((a.1 - b.1).abs() < 1e-6, "distance mismatch: kd={}, bf={}", a.1, b.1);
        }
    }

    #[test]
    fn test_kd_tree_construction() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        assert_eq!(kdtree.points.len(), points.len());
        assert!(kdtree.root.is_some());
    }

    #[test]
    fn test_empty_kd_tree() {
        let kdtree = KdTree::new(&[]).unwrap();
        assert!(kdtree.root.is_none());
        assert!(kdtree.points.is_empty());
        let query = Point3f::new(0.0, 0.0, 0.0);
        assert!(kdtree.find_k_nearest(&query, 5).is_empty());
    }

    #[test]
    fn test_k_nearest_neighbors_consistency() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);
        let query = Point3f::new(0.5, 0.5, 0.5);
        let k = 3;
        let kdtree_result = kdtree.find_k_nearest(&query, k);
        let brute_force_result = brute_force.find_k_nearest(&query, k);
        assert_eq!(kdtree_result.len(), k);
        // Both implementations return their k-NN results nearest-first.
        assert_ascending(&kdtree_result);
        assert_ascending(&brute_force_result);
        assert_same_distances(kdtree_result, brute_force_result);
    }

    #[test]
    fn test_radius_neighbors_consistency() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);
        let query = Point3f::new(0.5, 0.5, 0.5);
        let radius = 1.5;
        let kdtree_result = kdtree.find_radius_neighbors(&query, radius);
        let brute_force_result = brute_force.find_radius_neighbors(&query, radius);
        assert_ascending(&kdtree_result);
        assert!(kdtree_result
            .iter()
            .chain(&brute_force_result)
            .all(|&(_, distance)| distance <= radius));
        assert_same_distances(kdtree_result, brute_force_result);
    }

    #[test]
    fn test_edge_cases() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let query = Point3f::new(0.0, 0.0, 0.0);
        assert!(kdtree.find_k_nearest(&query, 0).is_empty());
        assert_eq!(kdtree.find_k_nearest(&query, 20).len(), points.len());
        assert!(kdtree.find_radius_neighbors(&query, 0.0).is_empty());
        assert!(kdtree.find_radius_neighbors(&query, -1.0).is_empty());
    }

    #[test]
    fn test_random_points() {
        let mut rng = rand::thread_rng();
        let mut points = Vec::new();
        for _ in 0..100 {
            points.push(Point3f::new(
                rng.gen_range(-10.0..10.0),
                rng.gen_range(-10.0..10.0),
                rng.gen_range(-10.0..10.0),
            ));
        }
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);
        for _ in 0..10 {
            let query = Point3f::new(
                rng.gen_range(-5.0..5.0),
                rng.gen_range(-5.0..5.0),
                rng.gen_range(-5.0..5.0),
            );
            let k = rng.gen_range(1..=10);
            let radius = rng.gen_range(1.0..5.0);
            let kdtree_knn = kdtree.find_k_nearest(&query, k);
            assert_eq!(kdtree_knn.len(), k.min(points.len()));
            assert_same_distances(kdtree_knn, brute_force.find_k_nearest(&query, k));
            assert_same_distances(
                kdtree.find_radius_neighbors(&query, radius),
                brute_force.find_radius_neighbors(&query, radius),
            );
        }
    }

    #[test]
    fn test_axis_sorted_points_with_duplicates() {
        // Scanned or gridded clouds often arrive sorted along an axis and contain
        // repeated coordinates, which are adversarial inputs for pivot-based
        // median selection and for equal-value handling at split planes.
        let mut points = Vec::new();
        for i in 0..30 {
            for _ in 0..3 {
                points.push(Point3f::new(i as f32, 0.5, 0.0));
            }
        }
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);
        let queries = [
            Point3f::new(7.2, 0.0, 0.0),
            Point3f::new(-3.0, 1.0, 1.0),
            Point3f::new(29.0, 0.5, 0.0),
        ];
        for query in queries.iter() {
            let knn = kdtree.find_k_nearest(query, 7);
            assert_eq!(knn.len(), 7);
            assert_ascending(&knn);
            assert_same_distances(knn, brute_force.find_k_nearest(query, 7));
            assert_same_distances(
                kdtree.find_radius_neighbors(query, 2.5),
                brute_force.find_radius_neighbors(query, 2.5),
            );
        }
        // Every duplicate must be reported, not just one representative.
        let exact = kdtree.find_radius_neighbors(&Point3f::new(4.0, 0.5, 0.0), 0.1);
        assert_eq!(exact.len(), 3);
    }

    #[test]
    fn test_performance_comparison() {
        let mut rng = rand::thread_rng();
        let mut points = Vec::new();
        for _ in 0..1000 {
            points.push(Point3f::new(
                rng.gen_range(-10.0..10.0),
                rng.gen_range(-10.0..10.0),
                rng.gen_range(-10.0..10.0),
            ));
        }
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);
        let query = Point3f::new(0.0, 0.0, 0.0);
        let k = 10;
        let start = std::time::Instant::now();
        let kdtree_result = kdtree.find_k_nearest(&query, k);
        let kdtree_time = start.elapsed();
        let start = std::time::Instant::now();
        let brute_result = brute_force.find_k_nearest(&query, k);
        let brute_time = start.elapsed();
        println!("KD-tree time: {:?}", kdtree_time);
        println!("Brute force time: {:?}", brute_time);
        // Timings are informational only: clock resolution varies by platform,
        // so asserting on them would be flaky. Check the results instead.
        assert_eq!(kdtree_result.len(), k);
        assert_same_distances(kdtree_result, brute_result);
    }

    #[test]
    fn test_debug_k_nearest() {
        let points = create_test_points();
        let kdtree = KdTree::new(&points).unwrap();
        let brute_force = BruteForceSearch::new(&points);
        let query = Point3f::new(0.5, 0.5, 0.5);
        let k = 3;
        let kdtree_result = kdtree.find_k_nearest(&query, k);
        assert_eq!(kdtree_result.len(), k);
        assert_same_distances(kdtree_result, brute_force.find_k_nearest(&query, k));
    }
}