mapping_algorithms/kd_tree/
mod.rs

1// SPDX-License-Identifier: MIT
2/*
3 * Copyright (c) [2023 - Present] Emily Matheys <emilymatt96@gmail.com>
4 *
5 * Permission is hereby granted, free of charge, to any person obtaining a copy
6 * of this software and associated documentation files (the "Software"), to deal
7 * in the Software without restriction, including without limitation the rights
8 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9 * copies of the Software, and to permit persons to whom the Software is
10 * furnished to do so, subject to the following conditions:
11 *
12 * The above copyright notice and this permission notice shall be included in all
13 * copies or substantial portions of the Software.
14 *
15 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21 * SOFTWARE.
22 */
23
24use nalgebra::{Point, Scalar};
25use num_traits::{NumOps, Zero};
26
27use crate::{utils::distance_squared, Box, Ordering};
28
29#[derive(Clone, Debug, Default)]
30struct KDNode<T, const N: usize>
31where
32    T: Copy + Default + NumOps + PartialOrd + Scalar + Zero,
33{
34    internal_data: Point<T, N>,
35    right: Option<Box<KDNode<T, N>>>,
36    left: Option<Box<KDNode<T, N>>>,
37}
38
39impl<T, const N: usize> KDNode<T, N>
40where
41    T: Copy + Default + NumOps + PartialOrd + Scalar + Zero,
42{
43    fn new(root: Point<T, N>) -> Self {
44        Self {
45            internal_data: root,
46            left: None,
47            right: None,
48        }
49    }
50
51    #[cfg_attr(
52        feature = "tracing",
53        tracing::instrument("Insert New Point", skip_all, level = "trace")
54    )]
55    fn insert(&mut self, data: Point<T, N>, depth: usize) -> bool {
56        let dimension_to_check = depth % N;
57
58        let (branch_to_use, verify_equals) =
59            // Note that this is a &mut Option, not an Option<&mut>!
60            match data.coords[dimension_to_check].partial_cmp(&self.internal_data.coords[dimension_to_check]).unwrap() {
61                Ordering::Less => (&mut self.left, false),
62                Ordering::Equal => (&mut self.right, true),
63                Ordering::Greater => (&mut self.right, false)
64            };
65
66        if let Some(branch_exists) = branch_to_use.as_mut() {
67            return branch_exists.insert(data, depth + 1);
68        } else if verify_equals && self.internal_data == data {
69            return false;
70        }
71
72        *branch_to_use = Some(Box::new(KDNode::new(data)));
73        true
74    }
75
76    #[cfg_attr(
77        feature = "tracing",
78        tracing::instrument("Branch Nearest Neighbour", skip_all, level = "trace")
79    )]
80    fn nearest(&self, target: &Point<T, N>, depth: usize) -> Option<Point<T, N>> {
81        let dimension_to_check = depth % N;
82        let (next_branch, opposite_branch) =
83            if target.coords[dimension_to_check] < self.internal_data.coords[dimension_to_check] {
84                (self.left.as_ref(), self.right.as_ref())
85            } else {
86                (self.right.as_ref(), self.left.as_ref())
87            };
88
89        // Start with the nearer branch, default to this branch's point
90        let mut best = next_branch
91            .and_then(|branch| branch.nearest(target, depth + 1))
92            .unwrap_or(self.internal_data);
93
94        let axis_distance =
95            target.coords[dimension_to_check] - self.internal_data.coords[dimension_to_check];
96
97        if distance_squared(&self.internal_data, target) < distance_squared(&best, target) {
98            best = self.internal_data;
99        }
100
101        if (axis_distance * axis_distance) < distance_squared(&best, target) {
102            if let Some(opposite_best) =
103                opposite_branch.and_then(|branch| branch.nearest(target, depth + 1))
104            {
105                if distance_squared(&opposite_best, target) < distance_squared(&best, target) {
106                    return Some(opposite_best);
107                }
108            }
109        }
110
111        Some(best)
112    }
113
114    #[cfg_attr(
115        feature = "tracing",
116        tracing::instrument("Traverse Branch With Function", skip_all, level = "debug")
117    )]
118    fn traverse_branch<F: FnMut(&Point<T, N>)>(&self, func: &mut F) {
119        if let Some(left) = self.left.as_ref() {
120            left.traverse_branch(func);
121        }
122        func(&self.internal_data);
123        if let Some(right) = self.right.as_ref() {
124            right.traverse_branch(func);
125        }
126    }
127
128    #[cfg_attr(
129        feature = "tracing",
130        tracing::instrument("Traverse Branch With Mutable Function)", skip_all, level = "debug")
131    )]
132    fn traverse_branch_mut<F: FnMut(&mut Point<T, N>)>(&mut self, func: &mut F) {
133        if let Some(left) = self.left.as_mut() {
134            left.traverse_branch_mut(func);
135        }
136        func(&mut self.internal_data);
137        if let Some(right) = self.right.as_mut() {
138            right.traverse_branch_mut(func);
139        }
140    }
141}
142
143/// The Actual K-Dimensional Tree struct, contains it's first node.
144///
145/// # Generics
146/// `T`: Either an [`f32`] or [`f64`]
147/// `N`: a const usize specifying how many dimensions should each point have.
148#[derive(Clone, Debug, Default)]
149pub struct KDTree<T, const N: usize>
150where
151    T: Copy + Default + NumOps + PartialOrd + Scalar + Zero,
152{
153    root: Option<KDNode<T, N>>,
154    element_count: usize,
155}
156
157impl<T, const N: usize> KDTree<T, N>
158where
159    T: Copy + Default + NumOps + PartialOrd + Scalar + Zero,
160{
161    /// Inserts a new data points into the tree, taking into consideration it's position.
162    ///
163    /// # Arguments
164    /// * `data`: a [`Point`], to be inserted into the tree.
165    #[cfg_attr(
166        feature = "tracing",
167        tracing::instrument("Insert To Tree", skip_all, level = "debug")
168    )]
169    pub fn insert(&mut self, data: Point<T, N>) {
170        if let Some(root) = self.root.as_mut() {
171            if root.insert(data, 0) {
172                self.element_count += 1;
173            }
174        } else {
175            self.root = Some(KDNode::new(data));
176            self.element_count = 1;
177        }
178    }
179
180    /// Returns the number of elements in the tree.
181    ///
182    /// # Returns
183    /// A [`usize`] representing the number of elements in the tree.
184    pub fn len(&self) -> usize {
185        self.element_count
186    }
187
188    /// Returns whether the tree is empty or not.
189    ///
190    /// # Returns
191    /// A [`bool`] representing whether the tree is empty or not.
192    pub fn is_empty(&self) -> bool {
193        self.element_count == 0
194    }
195
196    /// Attempts to find the nearest point in the tree for the specified target point.
197    /// # Arguments
198    /// * `target`: a [`Point`], to search the closest point for.
199    ///
200    /// # Returns
201    /// [`None`] if the tree is empty, otherwise returns the closest [`Point`].
202    #[cfg_attr(
203        feature = "tracing",
204        tracing::instrument("Find Nearest Neighbour", skip_all, level = "debug")
205    )]
206    pub fn nearest(&self, target: &Point<T, N>) -> Option<Point<T, N>> {
207        self.root.as_ref().and_then(|root| root.nearest(target, 0))
208    }
209
210    /// Allows traversal of the entire tree structure, calling the `func` closure on each branch's data.
211    ///
212    /// # Arguments
213    /// * `func`: a closure of type [`Fn`], it's only parameter is a reference of the branch's [`Point`].
214    #[cfg_attr(
215        feature = "tracing",
216        tracing::instrument("Traverse Tree With Function", skip_all, level = "info")
217    )]
218    pub fn traverse_tree<F: FnMut(&Point<T, N>)>(&self, mut func: F) {
219        if let Some(root) = self.root.as_ref() {
220            root.traverse_branch(&mut func);
221        }
222    }
223
224    /// Allows traversal of the entire tree structure, calling the `func` closure on each branch's data, possible mutating the data.
225    ///
226    /// # Arguments
227    /// * func: a closure of type [`FnMut`], it's only parameter is a reference of the branch's [`Point`].
228    #[cfg_attr(
229        feature = "tracing",
230        tracing::instrument("Traverse Tree With Mutable Function", skip_all, level = "info")
231    )]
232    pub fn traverse_tree_mut<F: FnMut(&mut Point<T, N>)>(&mut self, mut func: F) {
233        if let Some(root) = self.root.as_mut() {
234            root.traverse_branch_mut(&mut func);
235        }
236    }
237}
238
239impl<T, const N: usize> From<&[Point<T, N>]> for KDTree<T, N>
240where
241    T: Copy + Default + NumOps + PartialOrd + Scalar + Zero,
242{
243    #[cfg_attr(
244        feature = "tracing",
245        tracing::instrument("Generate Tree From Point Cloud", skip_all, level = "info")
246    )]
247    fn from(point_cloud: &[Point<T, N>]) -> Self {
248        point_cloud
249            .iter()
250            .copied()
251            .fold(Self::default(), |mut tree, current_point| {
252                tree.insert(current_point);
253                tree
254            })
255    }
256}
257
#[cfg(test)]
mod tests {
    use nalgebra::{Point2, Point3};

    use crate::{point_clouds::find_nearest_neighbour_naive, Vec};

    use super::*;

    /// Builds a small 3-D tree shared by several tests.
    fn generate_tree() -> KDTree<f32, 3> {
        let points = Vec::from([
            Point3::new(0.0, 2.0, 1.0),
            Point3::new(-1.0, 4.0, 2.5),
            Point3::new(1.3, 2.5, 0.5),
            Point3::new(-2.1, 0.2, -0.2),
        ]);
        KDTree::from(points.as_slice())
    }

    #[test]
    fn test_insert() {
        // Test an empty tree
        let mut tree = KDTree::default();
        tree.insert(Point2::new(0.0f32, 0.0f32));

        match tree.root.as_ref() {
            None => {
                panic!("Error, tree root should be Some()")
            }
            Some(root) => {
                assert_eq!(root.internal_data, Point2::new(0.0f32, 0.0f32));
            }
        }

        // Inserting new element
        // Since x is less than root's x, first divergence should be to the left branch.
        tree.insert(Point2::new(-1.0f32, 0.4f32));
        match tree.root.as_ref().unwrap().left.as_ref() {
            None => {
                panic!("Error, first left branch should be Some()");
            }
            Some(left_branch) => {
                assert_eq!(left_branch.internal_data, Point2::new(-1.0f32, 0.4f32));
            }
        }

        // Since second element's x is still less than root's x, right branch should be unchanged.
        tree.insert(Point2::new(-2.0f32, -3.0f32));
        assert!(tree.root.as_ref().unwrap().right.is_none());

        // Third element's x is larger than root's x, so it should be in right branch.
        tree.insert(Point2::new(1.4f32, 5.0f32));
        match tree.root.as_ref().unwrap().right.as_ref() {
            None => {
                panic!("Error, first right branch should be Some()");
            }
            Some(right_branch) => {
                assert_eq!(right_branch.internal_data, Point2::new(1.4f32, 5.0f32));
            }
        }
    }

    #[test]
    fn test_insert_duplicate() {
        let mut tree = KDTree::default();
        assert!(tree.is_empty());

        tree.insert(Point2::new(0.0f32, 0.0f32));
        assert_eq!(tree.len(), 1);
        assert!(!tree.is_empty());

        // Insert duplicate
        tree.insert(Point2::new(0.0f32, 0.0f32));
        assert_eq!(tree.len(), 1);
    }

    #[test]
    fn test_nearest() {
        // Test an empty tree
        {
            let tree = KDTree::<f32, 2>::default();
            assert!(tree.nearest(&Point2::new(0.0, 0.0)).is_none())
        }

        let tree = generate_tree();
        let nearest = tree.nearest(&Point3::new(1.32, 2.7, 0.2));
        assert!(nearest.is_some());
        assert_eq!(nearest.unwrap(), Point3::new(1.3, 2.5, 0.5));
    }

    #[test]
    fn compare_nearest_with_naive_version() {
        let points_a = [
            [8.037338, -10.512266, 5.3038273],
            [-13.573973, 5.2957783, -5.7758245],
            [5.399618, 14.216839, 13.042112],
            [10.134924, -3.9498444, 12.201418],
            [-3.7965546, -4.1447372, 3.7468758],
            [2.494978, -5.231186, 10.918207],
            [10.469978, 2.231762, 12.076345],
            [-11.764912, 14.629526, -14.80231],
            [-8.693936, 5.038475, -0.32558632],
            [7.616955, -3.7277327, 2.344328],
            [-11.924471, -11.668331, -1.2298765],
            [-14.369208, -7.1591473, -9.843174],
        ]
        .into_iter()
        .map(Point3::from)
        .collect::<Vec<_>>();

        let points_b = [
            [6.196747, -11.11811, 0.470586],
            [-13.9269495, 9.677899, 1.9754279],
            [13.07056, 12.289567, 9.591913],
            [12.668911, -6.104495, 5.763672],
            [-3.2386777, -2.61825, 5.1327395],
            [5.2409143, -5.826359, 8.294433],
            [14.281796, -0.12630486, 5.762767],
            [-2.7135608, 15.505872, 16.110285],
            [5.980031, -4.006213, -1.6124942],
            [-14.19904, -7.7923203, 4.401306],
            [-19.287233, -1.7146804, -1.7363598],
        ]
        .into_iter()
        .map(Point3::from)
        .collect::<Vec<_>>();

        let kd_tree = KDTree::from(points_b.as_slice());

        let closest_points_naive = points_a
            .iter()
            .map(|point_a| find_nearest_neighbour_naive(point_a, points_b.as_slice()))
            .collect::<Vec<_>>();
        let closest_point_kd = points_a
            .iter()
            .map(|point_a| kd_tree.nearest(point_a))
            .collect::<Vec<_>>();
        assert_eq!(closest_points_naive, closest_point_kd);
    }

    #[test]
    fn test_traverse_tree() {
        let tree = generate_tree();
        let mut sum = 0.0_f32;
        tree.traverse_tree(|point| {
            sum += point.x + point.y;
        });

        // 2.0 + 3.0 + 3.8 - 1.9 = 6.9. The f32 result depends on summation (traversal) order, so
        // compare with a tolerance rather than exact equality.
        let difference = sum - 6.9;
        assert!(
            -1e-5 < difference && difference < 1e-5,
            "unexpected sum {}",
            sum
        );
    }

    #[test]
    fn test_traverse_tree_mut() {
        let mut tree = generate_tree();
        tree.traverse_tree_mut(|point| {
            *point = Point3::new(1.0, 1.0, 1.0);
        });

        tree.traverse_tree(|point| {
            assert_eq!(point.x, 1.0);
            assert_eq!(point.y, 1.0);
            assert_eq!(point.z, 1.0);
        });
    }

    #[test]
    fn test_multiple_elements_structure() {
        let mut tree = KDTree::default();
        let points = Vec::from([
            Point2::new(3.0, 6.0),
            Point2::new(17.0, 15.0),
            Point2::new(13.0, 15.0),
            Point2::new(6.0, 12.0),
            Point2::new(9.0, 1.0),
            Point2::new(2.0, 7.0),
            Point2::new(10.0, 19.0),
        ]);

        for point in points.iter() {
            tree.insert(*point);
        }

        assert_eq!(tree.len(), 7);

        // Alternating x/y splits for this insertion order give:
        //   (3,6): left (2,7), right (17,15)
        //   (17,15): left (6,12) -> right (9,1); right (13,15) -> left (10,19)
        // An in-order traversal (left, node, right) pins that exact shape.
        let mut visited = Vec::new();
        tree.traverse_tree(|point| visited.push(*point));
        assert_eq!(
            visited,
            Vec::from([
                Point2::new(2.0, 7.0),
                Point2::new(3.0, 6.0),
                Point2::new(6.0, 12.0),
                Point2::new(9.0, 1.0),
                Point2::new(17.0, 15.0),
                Point2::new(10.0, 19.0),
                Point2::new(13.0, 15.0),
            ])
        );
    }
}