kdtree/kdtree/mod.rs

1pub mod test_common;
2pub mod distance;
3
4mod partition;
5mod bounds;
6
7use self::bounds::*;
8use self::distance::*;
9
10use std::cmp;
11
/// Contract for values that can be stored in a `Kdtree`.
///
/// Points are copied into tree nodes and compared with `==` when an incremental
/// insertion checks for duplicates.
pub trait KdtreePointTrait
where
    Self: Copy + PartialEq,
{
    /// Coordinates of the point, one `f64` per dimension.
    ///
    /// The tree indexes coordinates by dimension across different points, so presumably
    /// every point stored in one tree reports the same number of coordinates.
    fn dims(&self) -> &[f64];
}
15
16pub struct Kdtree<KdtreePoint> {
17    nodes: Vec<KdtreeNode<KdtreePoint>>,
18
19    node_adding_dimension: usize,
20    node_depth_during_last_rebuild: usize,
21    current_node_depth: usize,
22}
23
24impl<KdtreePoint: KdtreePointTrait> Kdtree<KdtreePoint> {
25    pub fn new(mut points: &mut [KdtreePoint]) -> Kdtree<KdtreePoint> {
26        if points.len() == 0 {
27            panic!("empty vector point not allowed");
28        }
29
30
31
32        let mut tree = Kdtree {
33            nodes: vec![],
34            node_adding_dimension: 0,
35            node_depth_during_last_rebuild: 0,
36            current_node_depth: 0,
37        };
38
39        tree.rebuild_tree(&mut points);
40
41        tree
42    }
43
44    pub fn rebuild_tree(&mut self, points : &mut [KdtreePoint]) {
45        self.nodes.clear();
46
47        self.node_depth_during_last_rebuild = 0;
48        self.current_node_depth = 0;
49
50        let rect = Bounds::new_from_points(points);
51        self.build_tree(points, &rect, 1);
52    }
53
54    /// Can be used if you are sure that the tree is degenerated or if you will never again insert the nodes into the tree.
55    pub fn gather_points_and_rebuild(&mut self) {
56        let mut points : Vec<KdtreePoint> = vec![];
57        self.gather_points(0,&mut points);
58
59        self.rebuild_tree(&mut points);
60    }
61
62    pub fn nearest_search(&self, node: &KdtreePoint) -> KdtreePoint
63    {
64        let mut nearest_neighbor = 0usize;
65        let mut best_distance = squared_euclidean(node.dims(), &self.nodes[0].point.dims());
66        self.nearest_search_impl(node, 0usize, &mut best_distance, &mut nearest_neighbor);
67
68        self.nodes[nearest_neighbor].point
69    }
70
71    pub fn has_neighbor_in_range(&self, node: &KdtreePoint, range: f64) -> bool {
72        let squared_range = range * range;
73
74        self.distance_squared_to_nearest(node) <= squared_range
75    }
76
77    pub fn distance_squared_to_nearest(&self, node: &KdtreePoint) -> f64 {
78        squared_euclidean(&self.nearest_search(node).dims(), node.dims())
79    }
80
81    pub fn insert_nodes_and_rebuild(&mut self, nodes_to_add : &mut [KdtreePoint]) {
82        let mut pts : Vec<KdtreePoint> = vec![];
83        self.gather_points(0, &mut pts);
84        pts.extend(nodes_to_add.iter());
85
86        self.rebuild_tree(&mut pts);
87    }
88
89    pub fn insert_node(&mut self, node_to_add : KdtreePoint) {
90
91        let mut current_index = 0;
92        let dimension = self.node_adding_dimension;
93        let index_of_new_node = self.add_node(node_to_add,dimension,node_to_add.dims()[dimension]);
94        self.node_adding_dimension = ( dimension + 1) % node_to_add.dims().len();
95        let mut should_pop_node = false;
96
97        let mut depth = 0;
98        loop {
99
100            depth +=1 ;
101            let current_node = &mut self.nodes[current_index];
102
103            if node_to_add.dims()[current_node.dimension] <= current_node.split_on {
104                if let Some(left_node_index) = current_node.left_node {
105                    current_index = left_node_index
106                } else {
107                    if current_node.point.eq(&node_to_add) {
108                        should_pop_node = true;
109                    } else {
110                        current_node.left_node = Some(index_of_new_node);
111                    }
112                    break;
113                }
114            } else {
115                if let Some(right_node_index) = current_node.right_node {
116                    current_index = right_node_index
117                } else {
118                    if current_node.point.eq(&node_to_add) {
119                        should_pop_node = true;
120                    } else {
121                        current_node.right_node = Some(index_of_new_node);
122                    }
123                    break;
124                }
125            }
126        }
127
128        if should_pop_node {
129            self.nodes.pop();
130        }
131
132        if self.node_depth_during_last_rebuild as f64 * 4.0 < depth as f64  {
133            self.gather_points_and_rebuild();
134        }
135    }
136
137    fn nearest_search_impl(&self, p: &KdtreePoint, searched_index: usize, best_distance_squared: &mut f64, best_leaf_found: &mut usize) {
138        let node = &self.nodes[searched_index];
139
140        let splitting_value = node.split_on;
141        let point_splitting_dim_value = p.dims()[node.dimension];
142
143        let (closer_node, farther_node) = if point_splitting_dim_value <= splitting_value {
144            (node.left_node, node.right_node)
145        } else {
146            (node.right_node, node.left_node)
147        };
148
149        if let Some(closer_node) = closer_node {
150            self.nearest_search_impl(p, closer_node, best_distance_squared, best_leaf_found);
151        }
152
153        let distance = squared_euclidean(p.dims(), node.point.dims());
154        if distance < *best_distance_squared {
155            *best_distance_squared = distance;
156            *best_leaf_found = searched_index;
157        }
158
159        if let Some(farther_node) = farther_node {
160            let distance_on_single_dimension = squared_euclidean(&[splitting_value], &[point_splitting_dim_value]);
161
162            if distance_on_single_dimension <= *best_distance_squared {
163                self.nearest_search_impl(p, farther_node, best_distance_squared, best_leaf_found);
164            }
165        }
166    }
167
168    fn add_node(&mut self, p: KdtreePoint, dimension: usize, split_on: f64) -> usize {
169        let node = KdtreeNode::new(p, dimension, split_on);
170
171        self.nodes.push(node);
172        self.nodes.len() - 1
173    }
174
175    fn build_tree(&mut self, nodes: &mut [KdtreePoint], bounds: &Bounds, depth : usize) -> usize {
176        let splitting_index = partition::partition_sliding_midpoint(nodes, bounds.get_midvalue_of_widest_dim(), bounds.get_widest_dim());
177        let pivot_value = nodes[splitting_index].dims()[bounds.get_widest_dim()];
178
179        let node_id = self.add_node(nodes[splitting_index], bounds.get_widest_dim(), pivot_value);
180        let nodes_len = nodes.len();
181
182        if splitting_index > 0 {
183            let left_rect = bounds.clone_moving_max(pivot_value, bounds.get_widest_dim());
184            let left_child_id = self.build_tree(&mut nodes[0..splitting_index], &left_rect, depth+1);
185            self.nodes[node_id].left_node = Some(left_child_id);
186        }
187
188        if splitting_index < nodes.len() - 1 {
189            let right_rect = bounds.clone_moving_min(pivot_value, bounds.get_widest_dim());
190
191            let right_child_id = self.build_tree(&mut nodes[splitting_index + 1..nodes_len], &right_rect, depth+1);
192            self.nodes[node_id].right_node = Some(right_child_id);
193        }
194
195        self.node_depth_during_last_rebuild =  cmp::max(self.node_depth_during_last_rebuild,depth);
196
197        node_id
198    }
199
200    fn gather_points(&self, current_index: usize, points : &mut Vec<KdtreePoint>){
201        points.push(self.nodes[current_index].point);
202        if let Some(left_index) = self.nodes[current_index].left_node {
203            self.gather_points(left_index, points);
204        }
205
206        if let Some(right_index) = self.nodes[current_index].right_node {
207            self.gather_points(right_index, points);
208        }
209    }
210}
211
/// One node of a `Kdtree`: a stored point plus the axis-aligned plane it splits on.
pub struct KdtreeNode<P> {
    /// Index of the child subtree visited first for queries whose coordinate on
    /// `dimension` is `<= split_on`.
    left_node: Option<usize>,
    /// Index of the child subtree visited first for queries whose coordinate on
    /// `dimension` is `> split_on`.
    right_node: Option<usize>,

    /// The point stored at this node.
    point: P,
    /// Coordinate axis this node splits on.
    dimension: usize,
    /// Position of the splitting plane along `dimension`.
    split_on: f64,
}
220
221impl<T: KdtreePointTrait> KdtreeNode<T> {
222    fn new(p: T, splitting_dimension: usize, split_on_value: f64) -> KdtreeNode<T> {
223        KdtreeNode {
224            left_node: None,
225            right_node: None,
226
227            point: p,
228            dimension: splitting_dimension,
229            split_on: split_on_value
230        }
231    }
232}
233
#[cfg(test)]
mod tests {
    use ::kdtree::test_common::Point2WithId;

    use super::*;

    #[test]
    #[should_panic(expected = "empty vector point not allowed")]
    fn should_panic_given_empty_vector() {
        let mut empty_vec: Vec<Point2WithId> = vec![];

        Kdtree::new(&mut empty_vec);
    }

    quickcheck! {
        fn tree_build_creates_tree_with_as_many_leafs_as_there_is_points(xs : Vec<f64>) -> bool {
            if xs.is_empty() {
                return true;
            }

            // The helper drops duplicated values, so compare against the points that were
            // actually inserted rather than against `xs.len()`.
            let mut points = qc_value_vec_to_2d_points_vec(&xs);
            let expected_len = points.len();
            let tree = Kdtree::new(&mut points);

            // Count every node reachable from the root.
            let mut reachable = 0;
            let mut to_visit = vec![0usize];
            while let Some(index) = to_visit.pop() {
                reachable += 1;
                let node = &tree.nodes[index];
                to_visit.extend(node.left_node);
                to_visit.extend(node.right_node);
            }

            tree.nodes.len() == expected_len && reachable == expected_len
        }
    }

    quickcheck! {
        fn nearest_neighbor_search_using_qc(xs : Vec<f64>) -> bool {
            if xs.is_empty() {
                return true;
            }

            let point_vec = qc_value_vec_to_2d_points_vec(&xs);
            let tree = Kdtree::new(&mut point_vec.clone());

            // Coordinates are unique, so every point must be its own nearest neighbour.
            for p in &point_vec {
                let found_nn = tree.nearest_search(p);

                assert_eq!(p.id, found_nn.id);
            }

            true
        }
    }

    #[test]
    fn has_neighbor_in_range() {
        let mut vec: Vec<Point2WithId> = vec![Point2WithId::new(0, 2., 0.)];

        let tree = Kdtree::new(&mut vec);

        assert!(!tree.has_neighbor_in_range(&Point2WithId::new(0, 0., 0.), 0.));
        assert!(!tree.has_neighbor_in_range(&Point2WithId::new(0, 0., 0.), 1.));
        assert!(tree.has_neighbor_in_range(&Point2WithId::new(0, 0., 0.), 2.));
        assert!(tree.has_neighbor_in_range(&Point2WithId::new(0, 0., 0.), 300.));
    }

    #[test]
    fn incremental_add_adds_as_expected() {
        // This test is tricky because it can be affected by the automatic tree rebuild.

        let mut vec = vec![Point2WithId::new(0, 0., 0.)];

        let mut tree = Kdtree::new(&mut vec);

        tree.insert_node(Point2WithId::new(0, 1., 0.));
        tree.insert_node(Point2WithId::new(0, -1., 0.));

        assert_eq!(tree.nodes.len(), 3);
        assert_eq!(tree.nodes[0].dimension, 0);

        assert!(tree.nodes[0].left_node.is_some());
        assert_eq!(tree.nodes[1].point.dims()[0], 1.);
        assert_eq!(tree.nodes[2].point.dims()[0], -1.);

        assert!(tree.nodes[0].right_node.is_some());
    }

    #[test]
    fn incremental_add_filters_duplicates() {
        let mut vec = vec![Point2WithId::new(0, 0., 0.)];

        let mut tree = Kdtree::new(&mut vec);

        let node = Point2WithId::new(0, 1., 0.);
        tree.insert_node(node);
        tree.insert_node(node);

        assert_eq!(tree.nodes.len(), 2);
    }

    /// Turns each value `x` into the point `(x, x)` tagged with its index in `xs`,
    /// skipping values equal to an earlier one so all coordinates are distinct.
    fn qc_value_vec_to_2d_points_vec(xs: &[f64]) -> Vec<Point2WithId> {
        xs.iter()
            .enumerate()
            .filter(|&(i, x)| !xs[..i].contains(x))
            .map(|(i, &x)| Point2WithId::new(i as i32, x, x))
            .collect()
    }
}