fast_kd/
kdtree.rs

1use std::collections::{BinaryHeap};
2
3use num_traits::{Float, One, Zero};
4
5
6use crate::heap_element::HeapElement;
7use crate::util::distance_to_space;
8
9
10#[derive(Clone, Debug)]
11pub struct KdTree<A, const K: usize> {
12    size: usize,
13
14
15    min_bounds: [A; K],
16    max_bounds: [A; K],
17    content: Node<A, K>,
18}
19
20#[derive(Clone, Debug)]
21pub enum Node<A, const K: usize> {
22    Stem {
23        left: Box<KdTree<A, K>>,
24        right: Box<KdTree<A, K>>,
25        split_value: A,
26        split_dimension: usize,
27    },
28    Leaf {
29        //points: Vec<[A; K]>,
30        //bucket: Vec<T>,
31        bucket: Vec<[A; K]>,
32        capacity: usize,
33    },
34}
35
/// Errors reported by [`KdTree`] operations.
#[derive(Debug, PartialEq)]
pub enum ErrorKind {
    /// A supplied point has a coordinate that is NaN or infinite.
    NonFiniteCoordinate,
    /// A tree was requested with a per-node capacity of zero.
    ZeroCapacity,
    /// Not produced by the code in this file; presumably reserved for
    /// operations that need a non-empty tree (confirm against callers).
    Empty,
}
42
43impl<A: Float + Zero + One, const K: usize> KdTree<A, K> {
44    pub fn new() -> Self {
45        KdTree::with_per_node_capacity(16).unwrap()
46    }
47
48    pub fn with_per_node_capacity(capacity: usize) -> Result<Self, ErrorKind> {
49        if capacity == 0 {
50            return Err(ErrorKind::ZeroCapacity);
51        }
52
53        Ok(KdTree {
54            size: 0,
55            min_bounds: [A::infinity(); K],
56            max_bounds: [A::neg_infinity(); K],
57            content: Node::Leaf {
58                bucket: Vec::with_capacity(capacity + 1),
59                capacity,
60            },
61        })
62    }
63
64    pub fn size(&self) -> usize {
65        self.size
66    }
67
68    pub fn is_leaf(&self) -> bool {
69        match &self.content {
70            Node::Leaf { .. } => true,
71            Node::Stem { .. } => false,
72        }
73    }
74
75
76    pub fn best_n_within<F>(
77        &self,
78        point: &[A; K],
79        radius: A,
80        max_qty: usize,
81        distance: &F,
82    ) -> Result<Vec<&[A;K]>, ErrorKind>
83    where
84        F: Fn(&[A; K], &[A; K]) -> A,
85    {
86        if self.size == 0 {
87            return Ok(vec![]);
88        }
89
90        self.check_point(point)?;
91
92        let mut pending = BinaryHeap::new();
93        let mut evaluated = BinaryHeap::<HeapElement<A, &[A;K]>>::with_capacity(self.size().min(max_qty + 1));
94        let mut max_ev_dist = A::infinity();
95        pending.push(HeapElement {
96            distance: A::zero(),
97            element: self,
98        });
99
100        while !pending.is_empty() {
101            let curr = pending.pop().unwrap();
102            if evaluated.len() == max_qty && -curr.distance > max_ev_dist {
103                break;
104            }
105            let curr = curr.element;
106            match curr.content {
107                Node::Leaf {
108                    ref bucket,
109                    ..
110                } => {
111                    for p in bucket.iter() {
112                        let d : A = distance(point, p);
113                        let heap_elem = HeapElement {
114                            distance: d,
115                            element: p,
116                        };
117
118                        if evaluated.len() < max_qty {
119                            evaluated.push(heap_elem);
120                            max_ev_dist = evaluated.peek().unwrap().distance;
121                        } else if max_ev_dist > heap_elem.distance {
122                            evaluated.push(heap_elem);
123                            evaluated.pop();
124                            max_ev_dist = evaluated.peek().unwrap().distance;
125                        }
126                    }
127                }
128                Node::Stem {
129                    ref left,
130                    ref right,
131                    ..
132                } => {
133                    let d_left :A = distance_to_space(
134                        point,
135                        &left.min_bounds,
136                        &left.max_bounds,
137                        distance
138                    );
139                    if d_left < radius {
140                        pending.push(HeapElement {
141                            distance: -d_left,
142                            element: left,
143                        });
144                    }
145                    let d_right:A = distance_to_space(
146                        point,
147                        &right.min_bounds,
148                        &right.max_bounds,
149                        distance
150                    );
151                    if d_right < radius {
152                        pending.push(HeapElement {
153                            distance: -d_right,
154                            element: right,
155                        });
156                    }
157                }
158            }
159        }
160
161        Ok(evaluated.iter().map(|e| e.element).collect())
162    }
163
164
165    /// Add an element to the tree. The first argument specifies the location in kd space
166    /// at which the element is located. The second argument is the data associated with
167    /// that point in space.
168    ///
169    /// # Examples
170    ///
171    /// ```rust
172    /// use kiddo::KdTree;
173    ///
174    /// let mut tree: KdTree<f64, 3> = KdTree::new();
175    ///
176    /// tree.add(&[1.0, 2.0, 5.0])?;
177    /// tree.add(&[1.1, 2.1, 5.1])?;
178    ///
179    /// assert_eq!(tree.size(), 2);
180    /// ```
181    pub fn add(&mut self, point: &[A; K]) -> Result<(), ErrorKind> {
182        self.check_point(point)?;
183        self.add_unchecked(point);
184        Ok(())
185    }
186
187    fn add_unchecked(&mut self, point: &[A; K]) {
188        match &mut self.content {
189            Node::Leaf { .. } => {
190                self.add_to_bucket(point);
191            }
192
193            Node::Stem {
194                left,
195                right,
196                split_dimension,
197                split_value,
198            } => {
199                if point[*split_dimension] < *split_value {
200                    // belongs_in_left
201                    left.add_unchecked(point)
202                } else {
203                    right.add_unchecked(point)
204                }
205            }
206        };
207
208        self.extend(point);
209        self.size += 1;
210    }
211
212    fn add_to_bucket(&mut self, point: &[A; K]) {
213        self.extend(point);
214        let cap;
215        match &mut self.content {
216            Node::Leaf {
217                bucket,
218                capacity,
219            } => {
220                bucket.push(*point);
221                cap = *capacity;
222            }
223            Node::Stem { .. } => unreachable!(),
224        }
225
226        self.size += 1;
227        if self.size > cap {
228            self.split();
229        }
230    }
231
232    fn split(&mut self) {
233        match &mut self.content {
234            Node::Leaf {
235                bucket,
236                capacity,
237            } => {
238                let mut split_dimension:usize = 0;
239                let mut max = A::zero();
240                for dim in 0..K {
241                    let diff = self.max_bounds[dim] - self.min_bounds[dim];
242                    if !diff.is_nan() && diff > max {
243                        max = diff;
244                        split_dimension = dim;
245                    }
246                }
247
248                let split_value = self.min_bounds[split_dimension] + max / A::from(2.0).unwrap();
249                let mut left = Box::new(KdTree::with_per_node_capacity(*capacity).unwrap());
250                let mut right = Box::new(KdTree::with_per_node_capacity(*capacity).unwrap());
251
252                while !bucket.is_empty() {
253                    let point= bucket.pop().unwrap();
254                    if point[split_dimension] < split_value {
255                        // belongs_in_left
256                        left.add_to_bucket(&point);
257                    } else {
258                        right.add_to_bucket(&point);
259                    }
260                }
261
262                self.content = Node::Stem {
263                    left,
264                    right,
265                    split_value,
266                    split_dimension,
267                }
268            }
269            Node::Stem { .. } => unreachable!(),
270        }
271    }
272
273    fn extend(&mut self, point: &[A; K]) {
274        let min = self.min_bounds.iter_mut();
275        let max = self.max_bounds.iter_mut();
276        for ((l, h), v) in min.zip(max).zip(point.iter()) {
277            if v < l {
278                *l = *v
279            }
280            if v > h {
281                *h = *v
282            }
283        }
284    }
285
286    fn check_point(&self, point: &[A; K]) -> Result<(), ErrorKind> {
287        if point.iter().all(|n| n.is_finite()) {
288            Ok(())
289        } else {
290            Err(ErrorKind::NonFiniteCoordinate)
291        }
292    }
293}
294
295
296impl std::fmt::Display for ErrorKind {
297    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
298        write!(f, "KdTree error: {}", self)
299    }
300}