1use nalgebra::Vector3;
2use ndarray::prelude::*;
3
4enum Node {
5 Leaf {
6 points: Array1<Vector3<f32>>,
7 indices: Vec<usize>,
8 },
9 NonLeaf {
10 middle_value: f32,
11 left: Box<Node>,
12 right: Box<Node>,
13 },
14}
15
16pub struct R3dTree {
18 root: Box<Node>,
19}
20
21impl R3dTree {
22 pub fn new(points: &ArrayView1<Vector3<f32>>) -> Self {
29 fn rec(points: &ArrayView1<Vector3<f32>>, mut indices: Vec<usize>, depth: usize) -> Node {
31 if indices.len() <= 16 {
33 return Node::Leaf {
34 points: points.select(ndarray::Axis(0), &indices),
35 indices,
36 };
37 }
38
39 let k = depth % 3;
40 indices.sort_by(|idx1, idx2| {
41 let a = points[*idx1][k];
42 let b = points[*idx2][k];
43 a.partial_cmp(&b).unwrap()
44 });
45
46 let mid = indices.len() / 2;
47 Node::NonLeaf {
48 middle_value: points[indices[mid]][k],
49 left: Box::new(rec(points, indices[0..mid].to_vec(), depth + 1)),
50 right: Box::new(rec(points, indices[mid..].to_vec(), depth + 1)),
51 }
52 }
53
54 let indices = Vec::from_iter(0..points.shape()[0]);
55 Self {
56 root: Box::new(rec(points, indices, 0)),
57 }
58 }
59
60 pub fn nearest(&self, point: &Vector3<f32>) -> (usize, f32) {
70 let mut curr_node = &self.root;
71 let mut current_dim = 0;
72
73 loop {
74 match curr_node.as_ref() {
75 Node::NonLeaf {
76 middle_value: mid,
77 left,
78 right,
79 } => {
80 curr_node = if point[current_dim] < *mid {
81 left
82 } else {
83 right
84 };
85 current_dim = (current_dim + 1) % 3;
86 }
87 Node::Leaf {
88 points: leaf_points,
89 indices,
90 } => {
91 let mut min_dist = f32::MAX;
92 let mut min_idx = 0;
93 for (idx, leaf_point) in leaf_points.iter().enumerate() {
94 let leaf_point = Vector3::new(leaf_point[0], leaf_point[1], leaf_point[2]);
95 let dist = (point - leaf_point).norm_squared();
96 if dist < min_dist {
97 min_dist = dist;
98 min_idx = idx;
99 }
100 }
101 return (indices[min_idx], min_dist);
102 }
103 }
104 }
105 }
106}
107
#[cfg(test)]
mod tests {
    use std::time::Instant;

    use crate::kdtree::R3dTree;
    use crate::unit_test::access::UnflattenVector3;
    use nalgebra::Vector3;
    use ndarray::prelude::*;
    use rand::rngs::SmallRng;
    use rand::seq::SliceRandom;
    use rand::SeedableRng;

    /// Builds the `rows x 3` matrix whose entries are 0, 1, 2, ... in
    /// row-major order.
    fn ordered_points(rows: usize) -> Array2<f32> {
        Array::from_shape_vec((rows, 3), (0..rows * 3).map(|x| x as f32).collect()).unwrap()
    }

    /// Returns a fixed-seed permutation of the row indices of `ordered` and a
    /// copy of `ordered` in which row `permutation[i]` holds row `i`.
    fn shuffle_rows(ordered: &Array2<f32>) -> (Vec<usize>, Array2<f32>) {
        let mut permutation: Vec<usize> = (0..ordered.shape()[0]).collect();
        let seed: [u8; 32] = [5; 32];
        permutation.shuffle(&mut SmallRng::from_seed(seed));

        let mut shuffled = ordered.clone();
        for (i, &target) in permutation.iter().enumerate() {
            shuffled
                .slice_mut(s![target, ..])
                .assign(&ordered.slice(s![i, ..]));
        }
        (permutation, shuffled)
    }

    #[test]
    fn should_find_nearest_points() {
        let points = array![[1., 2., 3.], [2., 3., 4.], [5., 6., 7.], [8., 9., 1.]]
            .unflatten_vector3()
            .unwrap();
        let tree = R3dTree::new(&points.view());

        let queries = array![
            [8., 9.1, 1.3],
            [5.1, 6.4, 7.],
            [1.5, 2.1, 3.3],
            [2.2, 3.1, 4.2]
        ];

        for (query, expected) in queries.outer_iter().zip(&[3, 2, 0, 1]) {
            let query = Vector3::new(query[0], query[1], query[2]);
            let (idx, _) = tree.nearest(&query);
            assert_eq!(idx, *expected);
        }
    }

    #[test]
    fn should_find_nearest_points_big() {
        let ordered = ordered_points(500);
        let (permutation, shuffled) = shuffle_rows(&ordered);
        let randomized_points = shuffled.unflatten_vector3().unwrap();
        let tree = R3dTree::new(&randomized_points.view());

        // Row `i` of `ordered` now lives at row `permutation[i]`, so an exact
        // query for it must return that index.
        for (query, expected) in ordered.outer_iter().zip(permutation.iter()) {
            let query = Vector3::new(query[0], query[1], query[2]);
            let (idx, _) = tree.nearest(&query);
            assert_eq!(idx, *expected);
        }
    }

    /// Timing harness rather than a correctness check: 500 000 queries x 10
    /// rounds against 5 000 points. Ignored by default because it is far too
    /// slow for a regular (debug) test run; use
    /// `cargo test --release -- --ignored --nocapture bench_nearest`.
    #[test]
    #[ignore = "benchmark; run explicitly with --release -- --ignored"]
    fn bench_nearest() {
        const N: usize = 500_000;
        const M: usize = 10;
        let ordered = ordered_points(N);
        let (_, shuffled) = shuffle_rows(&ordered);
        let randomized_points = shuffled
            .slice_move(s![0..5000, ..])
            .unflatten_vector3()
            .unwrap();

        let tree = R3dTree::new(&randomized_points.view());

        let mut sum_millis = 0;
        // Consume every result so the optimizer cannot discard the queries.
        let mut checksum = 0usize;
        for _ in 0..M {
            let start = Instant::now();
            for point in ordered.outer_iter() {
                let point = Vector3::new(point[0], point[1], point[2]);
                checksum = checksum.wrapping_add(tree.nearest(&point).0);
            }
            sum_millis += start.elapsed().as_millis();
        }

        println!("Mean time: {}", sum_millis as f64 / M as f64);
        println!("Checksum: {}", checksum);
    }
}