1#![allow(dead_code)]
4
/// A point in 3-D space tagged with a caller-chosen identifier.
///
/// The `id` is copied through the tree unchanged, so query results can be mapped back
/// to whatever the caller used to label the point.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct KdPoint3 {
    /// Cartesian coordinates, in `[x, y, z]` order.
    pub pos: [f32; 3],
    /// Identifier reported back by queries.
    pub id: usize,
}

impl KdPoint3 {
    /// Creates a point located at `(x, y, z)` carrying the given `id`.
    pub fn new(x: f32, y: f32, z: f32, id: usize) -> Self {
        let pos = [x, y, z];
        Self { pos, id }
    }

    /// Squared Euclidean distance between this point and `other`.
    ///
    /// Components are accumulated in x, y, z order.
    fn dist_sq(&self, other: &[f32; 3]) -> f32 {
        self.pos
            .iter()
            .zip(other.iter())
            .map(|(mine, theirs)| (mine - theirs).powi(2))
            .sum()
    }
}
23
24#[derive(Debug)]
26struct KdNode {
27 point: KdPoint3,
28 left: Option<Box<KdNode>>,
29 right: Option<Box<KdNode>>,
30 axis: usize,
31}
32
33#[derive(Default)]
35pub struct KdTree3 {
36 root: Option<Box<KdNode>>,
37 count: usize,
38}
39
40fn build(points: &mut [KdPoint3], depth: usize) -> Option<Box<KdNode>> {
41 if points.is_empty() {
42 return None;
43 }
44 let axis = depth % 3;
45 points.sort_by(|a, b| {
46 a.pos[axis]
47 .partial_cmp(&b.pos[axis])
48 .unwrap_or(std::cmp::Ordering::Equal)
49 });
50 let mid = points.len() / 2;
51 Some(Box::new(KdNode {
52 point: points[mid],
53 axis,
54 left: build(&mut points[..mid], depth + 1),
55 right: build(&mut points[mid + 1..], depth + 1),
56 }))
57}
58
59fn nn_search<'a>(node: &'a KdNode, query: &[f32; 3], best: &mut Option<(f32, &'a KdPoint3)>) {
60 let d = node.point.dist_sq(query);
61 if best.is_none_or(|(bd, _)| d < bd) {
62 *best = Some((d, &node.point));
63 }
64 let axis = node.axis;
65 let diff = query[axis] - node.point.pos[axis];
66 let (near, far) = if diff <= 0.0 {
67 (node.left.as_deref(), node.right.as_deref())
68 } else {
69 (node.right.as_deref(), node.left.as_deref())
70 };
71 if let Some(n) = near {
72 nn_search(n, query, best);
73 }
74 if let Some(f) = far {
75 let best_d = best.map(|(bd, _)| bd).unwrap_or(f32::INFINITY);
76 if diff * diff <= best_d {
77 nn_search(f, query, best);
78 }
79 }
80}
81
82fn range_search(node: &KdNode, query: &[f32; 3], r_sq: f32, result: &mut Vec<KdPoint3>) {
83 if node.point.dist_sq(query) <= r_sq {
84 result.push(node.point);
85 }
86 let axis = node.axis;
87 let diff = query[axis] - node.point.pos[axis];
88 if diff - r_sq.sqrt() <= 0.0 {
89 if let Some(n) = &node.left {
90 range_search(n, query, r_sq, result);
91 }
92 }
93 if diff + r_sq.sqrt() >= 0.0 {
94 if let Some(n) = &node.right {
95 range_search(n, query, r_sq, result);
96 }
97 }
98}
99
100impl KdTree3 {
101 pub fn build(mut points: Vec<KdPoint3>) -> Self {
103 let count = points.len();
104 let root = build(&mut points, 0);
105 KdTree3 { root, count }
106 }
107
108 pub fn nearest(&self, query: &[f32; 3]) -> Option<(KdPoint3, f32)> {
110 let root = self.root.as_deref()?;
111 let mut best = None;
112 nn_search(root, query, &mut best);
113 best.map(|(d, p)| (*p, d.sqrt()))
114 }
115
116 pub fn range_query(&self, query: &[f32; 3], r: f32) -> Vec<KdPoint3> {
118 let mut result = Vec::new();
119 if let Some(root) = &self.root {
120 range_search(root, query, r * r, &mut result);
121 }
122 result
123 }
124
125 pub fn len(&self) -> usize {
127 self.count
128 }
129
130 pub fn is_empty(&self) -> bool {
132 self.count == 0
133 }
134}
135
136pub fn new_kd_tree(positions: &[[f32; 3]]) -> KdTree3 {
138 let points: Vec<KdPoint3> = positions
139 .iter()
140 .enumerate()
141 .map(|(i, p)| KdPoint3 { pos: *p, id: i })
142 .collect();
143 KdTree3::build(points)
144}
145
146pub fn new_kd_tree_2d(positions: &[[f32; 2]]) -> KdTree3 {
148 let points: Vec<KdPoint3> = positions
149 .iter()
150 .enumerate()
151 .map(|(i, p)| KdPoint3 {
152 pos: [p[0], p[1], 0.0],
153 id: i,
154 })
155 .collect();
156 KdTree3::build(points)
157}
158
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic pseudo-random points with coordinates in `[0, 10)` and ids `0..n`.
    ///
    /// A tiny LCG keeps the tests reproducible without an external `rand` dependency.
    fn lcg_points(n: usize, seed: u32) -> Vec<KdPoint3> {
        let mut state = seed;
        let mut next = || {
            state = state.wrapping_mul(1_664_525).wrapping_add(1_013_904_223);
            // The top 24 bits fit exactly in an f32 mantissa, giving a value in [0, 1).
            (state >> 8) as f32 / (1u32 << 24) as f32 * 10.0
        };
        (0..n)
            .map(|i| KdPoint3::new(next(), next(), next(), i))
            .collect()
    }

    #[test]
    fn test_nearest_basic() {
        let pts = vec![
            KdPoint3::new(0.0, 0.0, 0.0, 0),
            KdPoint3::new(1.0, 0.0, 0.0, 1),
            KdPoint3::new(5.0, 5.0, 5.0, 2),
        ];
        let tree = KdTree3::build(pts);
        let (p, d) = tree.nearest(&[0.1, 0.0, 0.0]).expect("should succeed");
        assert_eq!(p.id, 0);
        assert!(d < 0.2);
    }

    #[test]
    fn test_nearest_single_point() {
        let pts = vec![KdPoint3::new(3.0, 4.0, 0.0, 0)];
        let tree = KdTree3::build(pts);
        let (p, d) = tree.nearest(&[3.0, 4.0, 0.0]).expect("should succeed");
        assert_eq!(p.id, 0);
        assert!(d < 1e-5);
    }

    #[test]
    fn test_empty_tree() {
        let tree = KdTree3::build(vec![]);
        assert!(tree.nearest(&[0.0, 0.0, 0.0]).is_none());
        assert!(tree.range_query(&[0.0, 0.0, 0.0], 1.0).is_empty());
        assert!(tree.is_empty());
        assert_eq!(tree.len(), 0);
    }

    #[test]
    fn test_range_query() {
        let pts: Vec<KdPoint3> = (0..10)
            .map(|i| KdPoint3::new(i as f32, 0.0, 0.0, i))
            .collect();
        let tree = KdTree3::build(pts);
        let mut ids: Vec<usize> = tree
            .range_query(&[5.0, 0.0, 0.0], 2.5)
            .iter()
            .map(|p| p.id)
            .collect();
        ids.sort_unstable();
        // Exactly the points at x = 3..=7 lie within 2.5 of x = 5.
        assert_eq!(ids, vec![3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_range_query_boundary_is_inclusive() {
        let pts: Vec<KdPoint3> = (0..10)
            .map(|i| KdPoint3::new(i as f32, 0.0, 0.0, i))
            .collect();
        let tree = KdTree3::build(pts);
        let mut ids: Vec<usize> = tree
            .range_query(&[5.0, 0.0, 0.0], 2.0)
            .iter()
            .map(|p| p.id)
            .collect();
        ids.sort_unstable();
        // x = 3 and x = 7 sit exactly on the radius and must be included.
        assert_eq!(ids, vec![3, 4, 5, 6, 7]);
    }

    #[test]
    fn test_new_kd_tree() {
        let pos = vec![[0.0f32, 1.0, 2.0], [3.0, 4.0, 5.0]];
        let tree = new_kd_tree(&pos);
        assert_eq!(tree.len(), 2);
    }

    #[test]
    fn test_new_kd_tree_2d() {
        let pos = vec![[0.0f32, 0.0], [1.0, 1.0], [2.0, 2.0]];
        let tree = new_kd_tree_2d(&pos);
        let (p, _) = tree.nearest(&[0.9, 0.9, 0.0]).expect("should succeed");
        assert_eq!(p.id, 1);
    }

    #[test]
    fn test_len() {
        let pts: Vec<KdPoint3> = (0..5)
            .map(|i| KdPoint3::new(i as f32, 0.0, 0.0, i))
            .collect();
        let tree = KdTree3::build(pts);
        assert_eq!(tree.len(), 5);
    }

    #[test]
    fn test_many_points_nearest() {
        let pts: Vec<KdPoint3> = (0..100)
            .map(|i| KdPoint3::new(i as f32, 0.0, 0.0, i))
            .collect();
        let tree = KdTree3::build(pts);
        let (p, d) = tree.nearest(&[49.5, 0.0, 0.0]).expect("should succeed");
        // 49 and 50 are equidistant; either is a correct answer, at distance exactly 0.5.
        assert!(p.id == 49 || p.id == 50);
        assert_eq!(d, 0.5);
    }

    #[test]
    fn test_nearest_matches_brute_force() {
        let pts = lcg_points(200, 1);
        let tree = KdTree3::build(pts.clone());
        for q in lcg_points(25, 99) {
            let query = q.pos;
            let (_, d) = tree.nearest(&query).expect("tree is non-empty");
            let brute = pts
                .iter()
                .map(|p| p.dist_sq(&query))
                .fold(f32::INFINITY, f32::min);
            // Compare distances rather than ids so ties cannot make the test flaky.
            assert_eq!(d, brute.sqrt());
        }
    }

    #[test]
    fn test_range_query_matches_brute_force() {
        let pts = lcg_points(200, 7);
        let tree = KdTree3::build(pts.clone());
        let r = 2.0f32;
        for q in lcg_points(25, 3) {
            let mut got: Vec<usize> = tree.range_query(&q.pos, r).iter().map(|p| p.id).collect();
            got.sort_unstable();
            // `pts` is in id order, so the filtered ids are already sorted.
            let want: Vec<usize> = pts
                .iter()
                .filter(|p| p.dist_sq(&q.pos) <= r * r)
                .map(|p| p.id)
                .collect();
            assert_eq!(got, want);
        }
    }
}