kdtree_simd/
kdtree.rs

1use std::collections::BinaryHeap;
2
3use num_traits::{Float, One, Zero};
4
5use crate::heap_element::HeapElement;
6use crate::util;
7
/// A node of a k-d tree over points of type `U` (viewed as `[A]`) carrying
/// data of type `T`.
///
/// A node is either a leaf, which stores points and their data side by side,
/// or a stem, which stores a split plane and two child nodes.
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct KdTree<A, T, U>
where
    T: PartialEq,
    U: AsRef<[A]> + PartialEq,
{
    // Child nodes; both are `None` while the node is a leaf.
    left: Option<Box<KdTree<A, T, U>>>,
    right: Option<Box<KdTree<A, T, U>>>,
    // Fields used by both leaves and stems.
    /// Number of coordinates every point must have.
    dimensions: usize,
    /// Number of entries a leaf may hold before a split is attempted.
    capacity: usize,
    /// Number of entries stored in this node and everything below it.
    size: usize,
    /// Per-dimension lower bound of the points added through this node.
    min_bounds: Box<[A]>,
    /// Per-dimension upper bound of the points added through this node.
    max_bounds: Box<[A]>,
    // Split plane; both are `None` while the node is a leaf.
    split_value: Option<A>,
    split_dimension: Option<usize>,
    // Leaf storage: `points[i]` is paired with `bucket[i]`.
    points: Option<Vec<U>>,
    bucket: Option<Vec<T>>,
}
27
/// Errors reported by [`KdTree`] operations.
///
/// The enum carries no data, so it is cheap to copy and can be compared with
/// full equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ErrorKind {
    /// The supplied point does not have as many coordinates as the tree.
    WrongDimension,
    /// The supplied point contains a NaN or infinite coordinate.
    NonFiniteCoordinate,
    /// An insertion was attempted on a tree whose leaf capacity is zero.
    ZeroCapacity,
}
34
35impl<A: Float + Zero + One, T: std::cmp::PartialEq, U: AsRef<[A]> + std::cmp::PartialEq> KdTree<A, T, U> {
36    /// Create a new KD tree, specifying the dimension size of each point
37    pub fn new(dims: usize) -> Self {
38        KdTree::with_capacity(dims, 2_usize.pow(4))
39    }
40
41    /// Create a new KD tree, specifying the dimension size of each point and the capacity of leaf nodes
42    pub fn with_capacity(dimensions: usize, capacity: usize) -> Self {
43        let min_bounds = vec![A::infinity(); dimensions];
44        let max_bounds = vec![A::neg_infinity(); dimensions];
45        KdTree {
46            left: None,
47            right: None,
48            dimensions,
49            capacity,
50            size: 0,
51            min_bounds: min_bounds.into_boxed_slice(),
52            max_bounds: max_bounds.into_boxed_slice(),
53            split_value: None,
54            split_dimension: None,
55            points: Some(vec![]),
56            bucket: Some(vec![]),
57        }
58    }
59
60    pub fn size(&self) -> usize {
61        self.size
62    }
63
64    pub fn nearest<F>(
65        &self,
66        point: &[A],
67        num: usize,
68        distance: &F,
69    ) -> Result<Vec<(A, &T)>, ErrorKind>
70    where
71        F: Fn(&[A], &[A]) -> A,
72    {
73        if let Err(err) = self.check_point(point) {
74            return Err(err);
75        }
76        let num = std::cmp::min(num, self.size);
77        if num == 0 {
78            return Ok(vec![]);
79        }
80        let mut pending = BinaryHeap::new();
81        let mut evaluated = BinaryHeap::<HeapElement<A, &T>>::new();
82        pending.push(HeapElement {
83            distance: A::zero(),
84            element: self,
85        });
86        while !pending.is_empty()
87            && (evaluated.len() < num
88                || (-pending.peek().unwrap().distance <= evaluated.peek().unwrap().distance))
89        {
90            self.nearest_step(
91                point,
92                num,
93                A::infinity(),
94                distance,
95                &mut pending,
96                &mut evaluated,
97            );
98        }
99        Ok(evaluated
100            .into_sorted_vec()
101            .into_iter()
102            .take(num)
103            .map(Into::into)
104            .collect())
105    }
106
107    pub fn within<F>(&self, point: &[A], radius: A, distance: &F) -> Result<Vec<(A, &T)>, ErrorKind>
108    where
109        F: Fn(&[A], &[A]) -> A,
110    {
111        if let Err(err) = self.check_point(point) {
112            return Err(err);
113        }
114        if self.size == 0 {
115            return Ok(vec![]);
116        }
117        let mut pending = BinaryHeap::new();
118        let mut evaluated = BinaryHeap::<HeapElement<A, &T>>::new();
119        pending.push(HeapElement {
120            distance: A::zero(),
121            element: self,
122        });
123        while !pending.is_empty() && (-pending.peek().unwrap().distance <= radius) {
124            self.nearest_step(
125                point,
126                self.size,
127                radius,
128                distance,
129                &mut pending,
130                &mut evaluated,
131            );
132        }
133        Ok(evaluated
134            .into_sorted_vec()
135            .into_iter()
136            .map(Into::into)
137            .collect())
138    }
139
140    fn nearest_step<'b, F>(
141        &self,
142        point: &[A],
143        num: usize,
144        max_dist: A,
145        distance: &F,
146        pending: &mut BinaryHeap<HeapElement<A, &'b Self>>,
147        evaluated: &mut BinaryHeap<HeapElement<A, &'b T>>,
148    ) where
149        F: Fn(&[A], &[A]) -> A,
150    {
151        let mut curr = &*pending.pop().unwrap().element;
152        debug_assert!(evaluated.len() <= num);
153        let evaluated_dist = if evaluated.len() == num {
154            // We only care about the nearest `num` points, so if we already have `num` points,
155            // any more point we add to `evaluated` must be nearer then one of the point already in
156            // `evaluated`.
157            max_dist.min(evaluated.peek().unwrap().distance)
158        } else {
159            max_dist
160        };
161
162        while !curr.is_leaf() {
163            let candidate;
164            if curr.belongs_in_left(point) {
165                candidate = curr.right.as_ref().unwrap();
166                curr = curr.left.as_ref().unwrap();
167            } else {
168                candidate = curr.left.as_ref().unwrap();
169                curr = curr.right.as_ref().unwrap();
170            }
171            let candidate_to_space = util::distance_to_space(
172                point,
173                &*candidate.min_bounds,
174                &*candidate.max_bounds,
175                distance,
176            );
177            if candidate_to_space <= evaluated_dist {
178                pending.push(HeapElement {
179                    distance: candidate_to_space * -A::one(),
180                    element: &**candidate,
181                });
182            }
183        }
184
185        let points = curr.points.as_ref().unwrap().iter();
186        let bucket = curr.bucket.as_ref().unwrap().iter();
187        let iter = points.zip(bucket).map(|(p, d)| HeapElement {
188            distance: distance(point, p.as_ref()),
189            element: d,
190        });
191        for element in iter {
192            if element <= max_dist {
193                if evaluated.len() < num {
194                    evaluated.push(element);
195                } else if element < *evaluated.peek().unwrap() {
196                    evaluated.pop();
197                    evaluated.push(element);
198                }
199            }
200        }
201    }
202
203    pub fn iter_nearest<'a, 'b, F>(
204        &'b self,
205        point: &'a [A],
206        distance: &'a F,
207    ) -> Result<NearestIter<'a, 'b, A, T, U, F>, ErrorKind>
208    where
209        F: Fn(&[A], &[A]) -> A,
210    {
211        if let Err(err) = self.check_point(point) {
212            return Err(err);
213        }
214        let mut pending = BinaryHeap::new();
215        let evaluated = BinaryHeap::<HeapElement<A, &T>>::new();
216        pending.push(HeapElement {
217            distance: A::zero(),
218            element: self,
219        });
220        Ok(NearestIter {
221            point,
222            pending,
223            evaluated,
224            distance,
225        })
226    }
227
228    pub fn iter_nearest_mut<'a, 'b, F>(
229        &'b mut self,
230        point: &'a [A],
231        distance: &'a F,
232    ) -> Result<NearestIterMut<'a, 'b, A, T, U, F>, ErrorKind>
233    where
234        F: Fn(&[A], &[A]) -> A,
235    {
236        if let Err(err) = self.check_point(point) {
237            return Err(err);
238        }
239        let mut pending = BinaryHeap::new();
240        let evaluated = BinaryHeap::<HeapElement<A, &mut T>>::new();
241        pending.push(HeapElement {
242            distance: A::zero(),
243            element: self,
244        });
245        Ok(NearestIterMut {
246            point,
247            pending,
248            evaluated,
249            distance,
250        })
251    }
252
253    pub fn add(&mut self, point: U, data: T) -> Result<(), ErrorKind> {
254        if self.capacity == 0 {
255            return Err(ErrorKind::ZeroCapacity);
256        }
257        if let Err(err) = self.check_point(point.as_ref()) {
258            return Err(err);
259        }
260        self.add_unchecked(point, data)
261    }
262
263    fn add_unchecked(&mut self, point: U, data: T) -> Result<(), ErrorKind> {
264        if self.is_leaf() {
265            self.add_to_bucket(point, data);
266            return Ok(());
267        }
268        self.extend(point.as_ref());
269        self.size += 1;
270        let next = if self.belongs_in_left(point.as_ref()) {
271            self.left.as_mut()
272        } else {
273            self.right.as_mut()
274        };
275        next.unwrap().add_unchecked(point, data)
276    }
277
278    fn add_to_bucket(&mut self, point: U, data: T) {
279        self.extend(point.as_ref());
280        let mut points = self.points.take().unwrap();
281        let mut bucket = self.bucket.take().unwrap();
282        points.push(point);
283        bucket.push(data);
284        self.size += 1;
285        if self.size > self.capacity {
286            self.split(points, bucket);
287        } else {
288            self.points = Some(points);
289            self.bucket = Some(bucket);
290        }
291    }
292
293    pub fn remove(&mut self, point: &U, data: &T) -> Result<usize, ErrorKind> {
294        let mut removed = 0;
295        if let Err(err) = self.check_point(point.as_ref()) {
296            return Err(err);
297        }
298        if let (Some(mut points), Some(mut bucket)) = (self.points.take(), self.bucket.take()) {
299            while let Some(p_index) = points.iter().position(|x| x == point) {
300                if &bucket[p_index] == data {
301                    points.remove(p_index);
302                    bucket.remove(p_index);
303                    removed += 1;
304                    self.size -= 1;
305                }
306            }
307            self.points = Some(points);
308            self.bucket = Some(bucket);
309        } else {
310            if let Some(right) = self.right.as_mut() {
311                let right_removed = right.remove(point, data)?;
312                if right_removed > 0 {
313                    self.size -= right_removed;
314                    removed += right_removed;
315                }
316            }
317            if let Some(left) = self.left.as_mut() {
318                let left_removed = left.remove(point, data)?;
319                if left_removed > 0 {
320                    self.size -= left_removed;
321                    removed += left_removed;
322                }
323            }
324        }
325        Ok(removed)
326    }
327
328    fn split(&mut self, mut points: Vec<U>, mut bucket: Vec<T>) {
329        let mut max = A::zero();
330        for dim in 0..self.dimensions {
331            let diff = self.max_bounds[dim] - self.min_bounds[dim];
332            if !diff.is_nan() && diff > max {
333                max = diff;
334                self.split_dimension = Some(dim);
335            }
336        }
337        match self.split_dimension {
338            None => {
339                self.points = Some(points);
340                self.bucket = Some(bucket);
341                return;
342            }
343            Some(dim) => {
344                let min = self.min_bounds[dim];
345                let max = self.max_bounds[dim];
346                self.split_value = Some(min + (max - min) / A::from(2.0).unwrap());
347            }
348        };
349        let mut left = Box::new(KdTree::with_capacity(self.dimensions, self.capacity));
350        let mut right = Box::new(KdTree::with_capacity(self.dimensions, self.capacity));
351        while !points.is_empty() {
352            let point = points.swap_remove(0);
353            let data = bucket.swap_remove(0);
354            if self.belongs_in_left(point.as_ref()) {
355                left.add_to_bucket(point, data);
356            } else {
357                right.add_to_bucket(point, data);
358            }
359        }
360        self.left = Some(left);
361        self.right = Some(right);
362    }
363
364    fn belongs_in_left(&self, point: &[A]) -> bool {
365        point[self.split_dimension.unwrap()] < self.split_value.unwrap()
366    }
367
368    fn extend(&mut self, point: &[A]) {
369        let min = self.min_bounds.iter_mut();
370        let max = self.max_bounds.iter_mut();
371        for ((l, h), v) in min.zip(max).zip(point.iter()) {
372            if v < l {
373                *l = *v
374            }
375            if v > h {
376                *h = *v
377            }
378        }
379    }
380
381    fn is_leaf(&self) -> bool {
382        self.bucket.is_some()
383            && self.points.is_some()
384            && self.split_value.is_none()
385            && self.split_dimension.is_none()
386            && self.left.is_none()
387            && self.right.is_none()
388    }
389
390    fn check_point(&self, point: &[A]) -> Result<(), ErrorKind> {
391        if self.dimensions != point.len() {
392            return Err(ErrorKind::WrongDimension);
393        }
394        for n in point {
395            if !n.is_finite() {
396                return Err(ErrorKind::NonFiniteCoordinate);
397            }
398        }
399        Ok(())
400    }
401}
402
403pub struct NearestIter<
404    'a,
405    'b,
406    A: 'a + 'b + Float,
407    T: 'b + PartialEq,
408    U: 'b + AsRef<[A]> + std::cmp::PartialEq,
409    F: 'a + Fn(&[A], &[A]) -> A,
410> {
411    point: &'a [A],
412    pending: BinaryHeap<HeapElement<A, &'b KdTree<A, T, U>>>,
413    evaluated: BinaryHeap<HeapElement<A, &'b T>>,
414    distance: &'a F,
415}
416
417impl<'a, 'b, A: Float + Zero + One, T: 'b, U: 'b + AsRef<[A]>, F: 'a> Iterator
418    for NearestIter<'a, 'b, A, T, U, F>
419where
420    F: Fn(&[A], &[A]) -> A,
421    U: PartialEq,
422    T: PartialEq,
423{
424    type Item = (A, &'b T);
425    fn next(&mut self) -> Option<(A, &'b T)> {
426        use util::distance_to_space;
427
428        let distance = self.distance;
429        let point = self.point;
430        while !self.pending.is_empty()
431            && (self.evaluated.peek().map_or(A::infinity(), |x| -x.distance)
432                >= -self.pending.peek().unwrap().distance)
433        {
434            let mut curr = &*self.pending.pop().unwrap().element;
435            while !curr.is_leaf() {
436                let candidate;
437                if curr.belongs_in_left(point) {
438                    candidate = curr.right.as_ref().unwrap();
439                    curr = curr.left.as_ref().unwrap();
440                } else {
441                    candidate = curr.left.as_ref().unwrap();
442                    curr = curr.right.as_ref().unwrap();
443                }
444                self.pending.push(HeapElement {
445                    distance: -distance_to_space(
446                        point,
447                        &*candidate.min_bounds,
448                        &*candidate.max_bounds,
449                        distance,
450                    ),
451                    element: &**candidate,
452                });
453            }
454            let points = curr.points.as_ref().unwrap().iter();
455            let bucket = curr.bucket.as_ref().unwrap().iter();
456            self.evaluated
457                .extend(points.zip(bucket).map(|(p, d)| HeapElement {
458                    distance: -distance(point, p.as_ref()),
459                    element: d,
460                }));
461        }
462        self.evaluated.pop().map(|x| (-x.distance, x.element))
463    }
464}
465
466pub struct NearestIterMut<
467    'a,
468    'b,
469    A: 'a + 'b + Float,
470    T: 'b + PartialEq,
471    U: 'b + AsRef<[A]> + PartialEq,
472    F: 'a + Fn(&[A], &[A]) -> A,
473> {
474    point: &'a [A],
475    pending: BinaryHeap<HeapElement<A, &'b mut KdTree<A, T, U>>>,
476    evaluated: BinaryHeap<HeapElement<A, &'b mut T>>,
477    distance: &'a F,
478}
479
480impl<'a, 'b, A: Float + Zero + One, T: 'b, U: 'b + AsRef<[A]>, F: 'a> Iterator
481    for NearestIterMut<'a, 'b, A, T, U, F>
482where
483    F: Fn(&[A], &[A]) -> A,
484    U: PartialEq,
485    T: PartialEq,
486{
487    type Item = (A, &'b mut T);
488    fn next(&mut self) -> Option<(A, &'b mut T)> {
489        use util::distance_to_space;
490
491        let distance = self.distance;
492        let point = self.point;
493        while !self.pending.is_empty()
494            && (self.evaluated.peek().map_or(A::infinity(), |x| -x.distance)
495                >= -self.pending.peek().unwrap().distance)
496        {
497            let mut curr = &mut *self.pending.pop().unwrap().element;
498            while !curr.is_leaf() {
499                let candidate;
500                if curr.belongs_in_left(point) {
501                    candidate = curr.right.as_mut().unwrap();
502                    curr = curr.left.as_mut().unwrap();
503                } else {
504                    candidate = curr.left.as_mut().unwrap();
505                    curr = curr.right.as_mut().unwrap();
506                }
507                self.pending.push(HeapElement {
508                    distance: -distance_to_space(
509                        point,
510                        &*candidate.min_bounds,
511                        &*candidate.max_bounds,
512                        distance,
513                    ),
514                    element: &mut **candidate,
515                });
516            }
517            let points = curr.points.as_ref().unwrap().iter();
518            let bucket = curr.bucket.as_mut().unwrap().iter_mut();
519            self.evaluated
520                .extend(points.zip(bucket).map(|(p, d)| HeapElement {
521                    distance: -distance(point, p.as_ref()),
522                    element: d,
523                }));
524        }
525        self.evaluated.pop().map(|x| (-x.distance, x.element))
526    }
527}
528
529impl std::error::Error for ErrorKind {
530    fn description(&self) -> &str {
531        match *self {
532            ErrorKind::WrongDimension => "wrong dimension",
533            ErrorKind::NonFiniteCoordinate => "non-finite coordinate",
534            ErrorKind::ZeroCapacity => "zero capacity",
535        }
536    }
537}
538
539impl std::fmt::Display for ErrorKind {
540    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
541        use std::error::Error;
542        write!(f, "KdTree error: {}", self.description())
543    }
544}
545
#[cfg(test)]
mod tests {
    extern crate rand;
    use super::KdTree;

    /// Leaf capacity `KdTree::new` is expected to use.
    const DEFAULT_CAPACITY: usize = 16;

    /// Tree type exercised by every test in this module.
    type Tree = KdTree<f64, i32, [f64; 2]>;

    /// Produces a random 2-D position together with random data.
    fn random_point() -> ([f64; 2], i32) {
        rand::random()
    }

    /// Inserts one random entry and fails the test if the tree rejects it.
    fn add_random(tree: &mut Tree) {
        let (pos, data) = random_point();
        tree.add(pos, data).unwrap();
    }

    #[test]
    fn it_has_default_capacity() {
        let tree = Tree::new(2);
        assert_eq!(tree.capacity, DEFAULT_CAPACITY);
    }

    #[test]
    fn it_can_be_cloned() {
        let (pos, data) = random_point();
        let mut tree = Tree::new(2);
        tree.add(pos, data).unwrap();

        // Adding to the clone must not affect the original.
        let mut cloned_tree = tree.clone();
        cloned_tree.add(pos, data).unwrap();

        assert_eq!(tree.size(), 1);
        assert_eq!(cloned_tree.size(), 2);
    }

    #[test]
    fn it_holds_on_to_its_capacity_before_splitting() {
        let mut tree = Tree::new(2);

        // Filling the root up to its capacity keeps it a leaf.
        for _ in 0..DEFAULT_CAPACITY {
            add_random(&mut tree);
        }
        assert_eq!(tree.size, DEFAULT_CAPACITY);
        assert_eq!(tree.size(), DEFAULT_CAPACITY);
        assert!(tree.left.is_none() && tree.right.is_none());

        // One more entry is expected to split the root.
        add_random(&mut tree);
        assert_eq!(tree.size, DEFAULT_CAPACITY + 1);
        assert_eq!(tree.size(), DEFAULT_CAPACITY + 1);
        assert!(tree.left.is_some() && tree.right.is_some());
    }

    #[test]
    fn no_items_can_be_added_to_a_zero_capacity_kdtree() {
        let mut tree = Tree::with_capacity(2, 0);
        let (pos, data) = random_point();
        let res = tree.add(pos, data);
        assert!(res.is_err());
    }
}