1use nalgebra::{Point, Scalar};
25use num_traits::{NumOps, Zero};
26
27use crate::{utils::distance_squared, Box, Ordering};
28
/// A single node of the k-d tree: one stored point plus up to two child subtrees.
///
/// The axis a node splits on is not stored; it is derived from the node's depth
/// (`depth % N`) by the methods that walk the tree.
#[derive(Clone, Debug, Default)]
struct KDNode<T, const N: usize>
where
    T: Copy + Default + NumOps + PartialOrd + Scalar + Zero,
{
    /// The point stored at this node.
    internal_data: Point<T, N>,
    /// Subtree for points whose coordinate on this node's split axis compares
    /// greater than or equal to this node's coordinate.
    right: Option<Box<KDNode<T, N>>>,
    /// Subtree for points whose coordinate on this node's split axis compares
    /// strictly less than this node's coordinate.
    left: Option<Box<KDNode<T, N>>>,
}
38
39impl<T, const N: usize> KDNode<T, N>
40where
41 T: Copy + Default + NumOps + PartialOrd + Scalar + Zero,
42{
43 fn new(root: Point<T, N>) -> Self {
44 Self {
45 internal_data: root,
46 left: None,
47 right: None,
48 }
49 }
50
51 #[cfg_attr(
52 feature = "tracing",
53 tracing::instrument("Insert New Point", skip_all, level = "trace")
54 )]
55 fn insert(&mut self, data: Point<T, N>, depth: usize) -> bool {
56 let dimension_to_check = depth % N;
57
58 let (branch_to_use, verify_equals) =
59 match data.coords[dimension_to_check].partial_cmp(&self.internal_data.coords[dimension_to_check]).unwrap() {
61 Ordering::Less => (&mut self.left, false),
62 Ordering::Equal => (&mut self.right, true),
63 Ordering::Greater => (&mut self.right, false)
64 };
65
66 if let Some(branch_exists) = branch_to_use.as_mut() {
67 return branch_exists.insert(data, depth + 1);
68 } else if verify_equals && self.internal_data == data {
69 return false;
70 }
71
72 *branch_to_use = Some(Box::new(KDNode::new(data)));
73 true
74 }
75
76 #[cfg_attr(
77 feature = "tracing",
78 tracing::instrument("Branch Nearest Neighbour", skip_all, level = "trace")
79 )]
80 fn nearest(&self, target: &Point<T, N>, depth: usize) -> Option<Point<T, N>> {
81 let dimension_to_check = depth % N;
82 let (next_branch, opposite_branch) =
83 if target.coords[dimension_to_check] < self.internal_data.coords[dimension_to_check] {
84 (self.left.as_ref(), self.right.as_ref())
85 } else {
86 (self.right.as_ref(), self.left.as_ref())
87 };
88
89 let mut best = next_branch
91 .and_then(|branch| branch.nearest(target, depth + 1))
92 .unwrap_or(self.internal_data);
93
94 let axis_distance =
95 target.coords[dimension_to_check] - self.internal_data.coords[dimension_to_check];
96
97 if distance_squared(&self.internal_data, target) < distance_squared(&best, target) {
98 best = self.internal_data;
99 }
100
101 if (axis_distance * axis_distance) < distance_squared(&best, target) {
102 if let Some(opposite_best) =
103 opposite_branch.and_then(|branch| branch.nearest(target, depth + 1))
104 {
105 if distance_squared(&opposite_best, target) < distance_squared(&best, target) {
106 return Some(opposite_best);
107 }
108 }
109 }
110
111 Some(best)
112 }
113
114 #[cfg_attr(
115 feature = "tracing",
116 tracing::instrument("Traverse Branch With Function", skip_all, level = "debug")
117 )]
118 fn traverse_branch<F: FnMut(&Point<T, N>)>(&self, func: &mut F) {
119 if let Some(left) = self.left.as_ref() {
120 left.traverse_branch(func);
121 }
122 func(&self.internal_data);
123 if let Some(right) = self.right.as_ref() {
124 right.traverse_branch(func);
125 }
126 }
127
128 #[cfg_attr(
129 feature = "tracing",
130 tracing::instrument("Traverse Branch With Mutable Function)", skip_all, level = "debug")
131 )]
132 fn traverse_branch_mut<F: FnMut(&mut Point<T, N>)>(&mut self, func: &mut F) {
133 if let Some(left) = self.left.as_mut() {
134 left.traverse_branch_mut(func);
135 }
136 func(&mut self.internal_data);
137 if let Some(right) = self.right.as_mut() {
138 right.traverse_branch_mut(func);
139 }
140 }
141}
142
/// A k-d tree over `N`-dimensional points with coordinates of type `T`.
///
/// Points are added with [`KDTree::insert`] (or by converting from a slice of
/// points) and queried with [`KDTree::nearest`].
#[derive(Clone, Debug, Default)]
pub struct KDTree<T, const N: usize>
where
    T: Copy + Default + NumOps + PartialOrd + Scalar + Zero,
{
    /// Root node of the tree; `None` while the tree is empty.
    root: Option<KDNode<T, N>>,
    /// Number of points counted as stored, as reported by `len()`.
    element_count: usize,
}
156
157impl<T, const N: usize> KDTree<T, N>
158where
159 T: Copy + Default + NumOps + PartialOrd + Scalar + Zero,
160{
161 #[cfg_attr(
166 feature = "tracing",
167 tracing::instrument("Insert To Tree", skip_all, level = "debug")
168 )]
169 pub fn insert(&mut self, data: Point<T, N>) {
170 if let Some(root) = self.root.as_mut() {
171 if root.insert(data, 0) {
172 self.element_count += 1;
173 }
174 } else {
175 self.root = Some(KDNode::new(data));
176 self.element_count = 1;
177 }
178 }
179
180 pub fn len(&self) -> usize {
185 self.element_count
186 }
187
188 pub fn is_empty(&self) -> bool {
193 self.element_count == 0
194 }
195
196 #[cfg_attr(
203 feature = "tracing",
204 tracing::instrument("Find Nearest Neighbour", skip_all, level = "debug")
205 )]
206 pub fn nearest(&self, target: &Point<T, N>) -> Option<Point<T, N>> {
207 self.root.as_ref().and_then(|root| root.nearest(target, 0))
208 }
209
210 #[cfg_attr(
215 feature = "tracing",
216 tracing::instrument("Traverse Tree With Function", skip_all, level = "info")
217 )]
218 pub fn traverse_tree<F: FnMut(&Point<T, N>)>(&self, mut func: F) {
219 if let Some(root) = self.root.as_ref() {
220 root.traverse_branch(&mut func);
221 }
222 }
223
224 #[cfg_attr(
229 feature = "tracing",
230 tracing::instrument("Traverse Tree With Mutable Function", skip_all, level = "info")
231 )]
232 pub fn traverse_tree_mut<F: FnMut(&mut Point<T, N>)>(&mut self, mut func: F) {
233 if let Some(root) = self.root.as_mut() {
234 root.traverse_branch_mut(&mut func);
235 }
236 }
237}
238
239impl<T, const N: usize> From<&[Point<T, N>]> for KDTree<T, N>
240where
241 T: Copy + Default + NumOps + PartialOrd + Scalar + Zero,
242{
243 #[cfg_attr(
244 feature = "tracing",
245 tracing::instrument("Generate Tree From Point Cloud", skip_all, level = "info")
246 )]
247 fn from(point_cloud: &[Point<T, N>]) -> Self {
248 point_cloud
249 .iter()
250 .copied()
251 .fold(Self::default(), |mut tree, current_point| {
252 tree.insert(current_point);
253 tree
254 })
255 }
256}
257
#[cfg(test)]
mod tests {
    use nalgebra::{Point2, Point3};

    use crate::{point_clouds::find_nearest_neighbour_naive, Vec};

    use super::*;

    /// Builds the small 3-D tree shared by several tests below.
    fn generate_tree() -> KDTree<f32, 3> {
        let cloud = Vec::from([
            Point3::new(0.0, 2.0, 1.0),
            Point3::new(-1.0, 4.0, 2.5),
            Point3::new(1.3, 2.5, 0.5),
            Point3::new(-2.1, 0.2, -0.2),
        ]);
        KDTree::from(cloud.as_slice())
    }

    #[test]
    fn test_insert() {
        let mut tree = KDTree::default();

        // The first point becomes the root.
        let first = Point2::new(0.0f32, 0.0f32);
        tree.insert(first);
        let root = tree
            .root
            .as_ref()
            .expect("Error, tree root should be Some()");
        assert_eq!(root.internal_data, first);

        // A smaller x coordinate goes to the root's left branch.
        let smaller_x = Point2::new(-1.0f32, 0.4f32);
        tree.insert(smaller_x);
        let left_branch = tree
            .root
            .as_ref()
            .unwrap()
            .left
            .as_ref()
            .expect("Error, first left branch should be Some()");
        assert_eq!(left_branch.internal_data, smaller_x);

        // Another smaller x coordinate must not touch the right branch.
        tree.insert(Point2::new(-2.0f32, -3.0f32));
        assert!(tree.root.as_ref().unwrap().right.is_none());

        // A larger x coordinate goes to the root's right branch.
        let larger_x = Point2::new(1.4f32, 5.0f32);
        tree.insert(larger_x);
        let right_branch = tree
            .root
            .as_ref()
            .unwrap()
            .right
            .as_ref()
            .expect("Error, first right branch should be Some()");
        assert_eq!(right_branch.internal_data, larger_x);
    }

    #[test]
    fn test_insert_duplicate() {
        let mut tree = KDTree::default();
        assert!(tree.is_empty());

        let point = Point2::new(0.0f32, 0.0f32);

        tree.insert(point);
        assert_eq!(tree.len(), 1);
        assert!(!tree.is_empty());

        // Inserting the same point again must not change the element count.
        tree.insert(point);
        assert_eq!(tree.len(), 1);
    }

    #[test]
    fn test_nearest() {
        // An empty tree has no nearest neighbour.
        let empty_tree = KDTree::<f32, 2>::default();
        assert!(empty_tree.nearest(&Point2::new(0.0, 0.0)).is_none());

        let tree = generate_tree();
        let query = Point3::new(1.32, 2.7, 0.2);
        assert_eq!(tree.nearest(&query), Some(Point3::new(1.3, 2.5, 0.5)));
    }

    #[test]
    fn compare_nearest_with_naive_version() {
        let points_a = [
            [8.037338, -10.512266, 5.3038273],
            [-13.573973, 5.2957783, -5.7758245],
            [5.399618, 14.216839, 13.042112],
            [10.134924, -3.9498444, 12.201418],
            [-3.7965546, -4.1447372, 3.7468758],
            [2.494978, -5.231186, 10.918207],
            [10.469978, 2.231762, 12.076345],
            [-11.764912, 14.629526, -14.80231],
            [-8.693936, 5.038475, -0.32558632],
            [7.616955, -3.7277327, 2.344328],
            [-11.924471, -11.668331, -1.2298765],
            [-14.369208, -7.1591473, -9.843174],
        ]
        .into_iter()
        .map(Point3::from)
        .collect::<Vec<_>>();

        let points_b = [
            [6.196747, -11.11811, 0.470586],
            [-13.9269495, 9.677899, 1.9754279],
            [13.07056, 12.289567, 9.591913],
            [12.668911, -6.104495, 5.763672],
            [-3.2386777, -2.61825, 5.1327395],
            [5.2409143, -5.826359, 8.294433],
            [14.281796, -0.12630486, 5.762767],
            [-2.7135608, 15.505872, 16.110285],
            [5.980031, -4.006213, -1.6124942],
            [-14.19904, -7.7923203, 4.401306],
            [-19.287233, -1.7146804, -1.7363598],
        ]
        .into_iter()
        .map(Point3::from)
        .collect::<Vec<_>>();

        let kd_tree = KDTree::from(points_b.as_slice());

        // Every query must give the same answer from the tree and the naive scan.
        for point_a in points_a.iter() {
            let closest_point_naive = find_nearest_neighbour_naive(point_a, points_b.as_slice());
            let closest_point_kd = kd_tree.nearest(point_a);
            assert_eq!(closest_point_naive, closest_point_kd);
        }
    }

    #[test]
    fn test_traverse_tree() {
        let tree = generate_tree();

        // Accumulate x + y over every visited point.
        let mut sum = 0.0;
        tree.traverse_tree(|point| {
            sum += point.x + point.y;
        });

        assert_eq!(sum, 6.9);
    }

    #[test]
    fn test_traverse_tree_mut() {
        let mut tree = generate_tree();

        // Overwrite every point, then check that each one was changed.
        tree.traverse_tree_mut(|point| {
            *point = Point3::new(1.0, 1.0, 1.0);
        });

        tree.traverse_tree(|point| {
            assert_eq!(*point, Point3::new(1.0, 1.0, 1.0));
        });
    }

    #[test]
    fn test_multiple_elements_structure() {
        let points = Vec::from([
            Point2::new(3.0, 6.0),
            Point2::new(17.0, 15.0),
            Point2::new(13.0, 15.0),
            Point2::new(6.0, 12.0),
            Point2::new(9.0, 1.0),
            Point2::new(2.0, 7.0),
            Point2::new(10.0, 19.0),
        ]);

        let mut tree = KDTree::default();
        points.iter().copied().for_each(|point| tree.insert(point));

        assert_eq!(tree.len(), 7);
    }
}