// align3d/kdtree.rs

1use nalgebra::Vector3;
2use ndarray::prelude::*;
3
/// A node of the 3D KdTree.
///
/// The split axis is not stored: it is derived from the node's depth (`depth % 3`),
/// both when the tree is built and when it is searched.
enum Node {
    /// Terminal node holding a small bucket of points that are scanned linearly.
    Leaf {
        /// Copies of the points that fall in this leaf.
        points: Array1<Vector3<f32>>,
        /// `indices[i]` is the position of `points[i]` in the array the tree was built from.
        indices: Vec<usize>,
    },
    /// Internal node that splits its points on one coordinate axis at the median.
    NonLeaf {
        /// Coordinate, on this node's split axis, of the median point.
        middle_value: f32,
        /// Subtree of points whose split-axis coordinate is <= `middle_value`.
        left: Box<Node>,
        /// Subtree of points whose split-axis coordinate is >= `middle_value`
        /// (it contains the median point itself).
        right: Box<Node>,
    },
}
15
/// KdTree for fast nearest neighbor search over 3D points (`Vector3<f32>`).
///
/// Built once from a set of points with [`R3dTree::new`] and queried with
/// [`R3dTree::nearest`]; query results refer to points by their position in the
/// array the tree was built from.
pub struct R3dTree {
    /// Root of the tree; for small inputs this is a single leaf.
    root: Box<Node>,
}
20
21impl R3dTree {
22    /// Create a new KdTree from a set of points.
23    /// The points are stored in a 2D array, where each row is a point.
24    ///
25    /// # Arguments
26    ///
27    /// * points - 2D array of points.
28    pub fn new(points: &ArrayView1<Vector3<f32>>) -> Self {
29        // Recursive creation.
30        fn rec(points: &ArrayView1<Vector3<f32>>, mut indices: Vec<usize>, depth: usize) -> Node {
31            // Stop recursion if this should be a leaf node.
32            if indices.len() <= 16 {
33                return Node::Leaf {
34                    points: points.select(ndarray::Axis(0), &indices),
35                    indices,
36                };
37            }
38
39            let k = depth % 3;
40            indices.sort_by(|idx1, idx2| {
41                let a = points[*idx1][k];
42                let b = points[*idx2][k];
43                a.partial_cmp(&b).unwrap()
44            });
45
46            let mid = indices.len() / 2;
47            Node::NonLeaf {
48                middle_value: points[indices[mid]][k],
49                left: Box::new(rec(points, indices[0..mid].to_vec(), depth + 1)),
50                right: Box::new(rec(points, indices[mid..].to_vec(), depth + 1)),
51            }
52        }
53
54        let indices = Vec::from_iter(0..points.shape()[0]);
55        Self {
56            root: Box::new(rec(points, indices, 0)),
57        }
58    }
59
60    /// Find the nearest neighbor to a query point. This version is for 3D points only.
61    ///
62    /// # Arguments
63    ///
64    /// * point - The query point.
65    ///
66    /// # Returns
67    ///
68    /// A tuple containing the index of the nearest neighbor and the distance to it.
69    pub fn nearest(&self, point: &Vector3<f32>) -> (usize, f32) {
70        let mut curr_node = &self.root;
71        let mut current_dim = 0;
72
73        loop {
74            match curr_node.as_ref() {
75                Node::NonLeaf {
76                    middle_value: mid,
77                    left,
78                    right,
79                } => {
80                    curr_node = if point[current_dim] < *mid {
81                        left
82                    } else {
83                        right
84                    };
85                    current_dim = (current_dim + 1) % 3;
86                }
87                Node::Leaf {
88                    points: leaf_points,
89                    indices,
90                } => {
91                    let mut min_dist = f32::MAX;
92                    let mut min_idx = 0;
93                    for (idx, leaf_point) in leaf_points.iter().enumerate() {
94                        let leaf_point = Vector3::new(leaf_point[0], leaf_point[1], leaf_point[2]);
95                        let dist = (point - leaf_point).norm_squared();
96                        if dist < min_dist {
97                            min_dist = dist;
98                            min_idx = idx;
99                        }
100                    }
101                    return (indices[min_idx], min_dist);
102                }
103            }
104        }
105    }
106}
107
#[cfg(test)]
mod tests {
    use std::hint::black_box;
    use std::time::Instant;

    use crate::kdtree::R3dTree;
    use crate::unit_test::access::UnflattenVector3;
    use nalgebra::Vector3;
    use ndarray::prelude::*;
    use rand::rngs::SmallRng;
    use rand::seq::SliceRandom;
    use rand::SeedableRng;

    /// Builds `n` distinct points (row `i` is `[3i, 3i + 1, 3i + 2]`) and a seeded shuffle
    /// of them.
    ///
    /// Returns `(ordered, permutation, shuffled)`, where row `i` of `ordered` is stored at
    /// row `permutation[i]` of `shuffled`.
    fn shuffled_points(n: usize) -> (Array2<f32>, Vec<usize>, Array2<f32>) {
        let ordered =
            Array::from_shape_vec((n, 3), (0..n * 3).map(|x| x as f32).collect()).unwrap();

        let mut permutation = (0..n).collect::<Vec<usize>>();
        let seed: [u8; 32] = [5; 32];
        permutation.shuffle(&mut SmallRng::from_seed(seed));

        let mut shuffled = ordered.clone();
        for (i, &dest) in permutation.iter().enumerate() {
            shuffled
                .slice_mut(s![dest, ..])
                .assign(&ordered.slice(s![i, ..]));
        }
        (ordered, permutation, shuffled)
    }

    #[test]
    fn should_find_nearest_points() {
        let points = array![[1., 2., 3.], [2., 3., 4.], [5., 6., 7.], [8., 9., 1.]]
            .unflatten_vector3()
            .unwrap();
        let tree = R3dTree::new(&points.view());

        let queries = array![
            [8., 9.1, 1.3],
            [5.1, 6.4, 7.],
            [1.5, 2.1, 3.3],
            [2.2, 3.1, 4.2]
        ];

        for (query, expected) in queries.outer_iter().zip(&[3, 2, 0, 1]) {
            let query = Vector3::new(query[0], query[1], query[2]);
            let (idx, _) = tree.nearest(&query);
            assert_eq!(idx, *expected);
        }
    }

    #[test]
    fn should_find_nearest_points_big() {
        let (ordered_points, random_indices, randomized_points) = shuffled_points(500);
        let randomized_points = randomized_points.unflatten_vector3().unwrap();

        let tree = R3dTree::new(&randomized_points.view());

        for (query, expected) in ordered_points.outer_iter().zip(random_indices.iter()) {
            let query = Vector3::new(query[0], query[1], query[2]);
            let (idx, _) = tree.nearest(&query);
            assert_eq!(idx, *expected);
        }
    }

    /// Timing harness rather than a correctness check. It issues 5M queries, which is very
    /// slow in debug builds, so it is opt-in:
    /// `cargo test --release -- --ignored bench_nearest --nocapture`.
    #[test]
    #[ignore]
    fn bench_nearest() {
        const N: usize = 500_000;
        let (ordered_points, _, randomized_points) = shuffled_points(N);
        // Index only a subset so most queries are not exact hits.
        let randomized_points = randomized_points
            .slice_move(s![0..5000, ..])
            .unflatten_vector3()
            .unwrap();

        let tree = R3dTree::new(&randomized_points.view());

        let mut sum_millis = 0;
        const M: usize = 10;
        for _ in 0..M {
            let start = Instant::now();
            for point in ordered_points.outer_iter() {
                let point = Vector3::new(point[0], point[1], point[2]);
                // Keep the optimizer from discarding the unused query result.
                black_box(tree.nearest(&point));
            }
            sum_millis += start.elapsed().as_millis();
        }

        println!("Mean time: {}", sum_millis as f64 / M as f64);
    }
}
210}