1use std::collections::{BinaryHeap};
2
3use num_traits::{Float, One, Zero};
4
5
6use crate::heap_element::HeapElement;
7use crate::util::distance_to_space;
8
9
/// A k-d tree over `K`-dimensional points with coordinates of type `A`.
///
/// Every node (root, stem or leaf) is itself a `KdTree`. It records how many
/// points live below it and the axis-aligned bounding box of those points.
#[derive(Clone, Debug)]
pub struct KdTree<A, const K: usize> {
    // Number of points stored in this subtree.
    size: usize,

    // Per-dimension lower/upper corner of the bounding box of the points in this
    // subtree. A freshly created node starts at +inf / -inf (an empty box), and
    // the box is grown point by point as points are added.
    min_bounds: [A; K],
    max_bounds: [A; K],
    // Either a leaf bucket of points or a stem with two child subtrees.
    content: Node<A, K>,
}
19
/// The payload of a [`KdTree`] node.
#[derive(Clone, Debug)]
pub enum Node<A, const K: usize> {
    /// An interior node that partitions its points along one dimension.
    Stem {
        // Subtree for points whose coordinate in `split_dimension` is < `split_value`.
        left: Box<KdTree<A, K>>,
        // Subtree for all remaining points (coordinate >= `split_value`).
        right: Box<KdTree<A, K>>,
        split_value: A,
        // Index (0..K) of the coordinate compared against `split_value`.
        split_dimension: usize,
    },
    /// A terminal node that stores its points directly.
    Leaf {
        bucket: Vec<[A; K]>,
        // Split threshold: once the leaf holds more than `capacity` points,
        // adding to it triggers a split.
        capacity: usize,
    },
}
35
/// Errors reported by [`KdTree`] operations.
#[derive(Debug, PartialEq)]
pub enum ErrorKind {
    /// A point had a NaN or infinite coordinate.
    NonFiniteCoordinate,
    /// A tree was requested with a per-leaf capacity of 0.
    ZeroCapacity,
    /// NOTE(review): not produced by any code in this file; presumably meant
    /// for operations that are invalid on an empty tree — confirm with callers.
    Empty,
}
42
43impl<A: Float + Zero + One, const K: usize> KdTree<A, K> {
44 pub fn new() -> Self {
45 KdTree::with_per_node_capacity(16).unwrap()
46 }
47
48 pub fn with_per_node_capacity(capacity: usize) -> Result<Self, ErrorKind> {
49 if capacity == 0 {
50 return Err(ErrorKind::ZeroCapacity);
51 }
52
53 Ok(KdTree {
54 size: 0,
55 min_bounds: [A::infinity(); K],
56 max_bounds: [A::neg_infinity(); K],
57 content: Node::Leaf {
58 bucket: Vec::with_capacity(capacity + 1),
59 capacity,
60 },
61 })
62 }
63
64 pub fn size(&self) -> usize {
65 self.size
66 }
67
68 pub fn is_leaf(&self) -> bool {
69 match &self.content {
70 Node::Leaf { .. } => true,
71 Node::Stem { .. } => false,
72 }
73 }
74
75
76 pub fn best_n_within<F>(
77 &self,
78 point: &[A; K],
79 radius: A,
80 max_qty: usize,
81 distance: &F,
82 ) -> Result<Vec<&[A;K]>, ErrorKind>
83 where
84 F: Fn(&[A; K], &[A; K]) -> A,
85 {
86 if self.size == 0 {
87 return Ok(vec![]);
88 }
89
90 self.check_point(point)?;
91
92 let mut pending = BinaryHeap::new();
93 let mut evaluated = BinaryHeap::<HeapElement<A, &[A;K]>>::with_capacity(self.size().min(max_qty + 1));
94 let mut max_ev_dist = A::infinity();
95 pending.push(HeapElement {
96 distance: A::zero(),
97 element: self,
98 });
99
100 while !pending.is_empty() {
101 let curr = pending.pop().unwrap();
102 if evaluated.len() == max_qty && -curr.distance > max_ev_dist {
103 break;
104 }
105 let curr = curr.element;
106 match curr.content {
107 Node::Leaf {
108 ref bucket,
109 ..
110 } => {
111 for p in bucket.iter() {
112 let d : A = distance(point, p);
113 let heap_elem = HeapElement {
114 distance: d,
115 element: p,
116 };
117
118 if evaluated.len() < max_qty {
119 evaluated.push(heap_elem);
120 max_ev_dist = evaluated.peek().unwrap().distance;
121 } else if max_ev_dist > heap_elem.distance {
122 evaluated.push(heap_elem);
123 evaluated.pop();
124 max_ev_dist = evaluated.peek().unwrap().distance;
125 }
126 }
127 }
128 Node::Stem {
129 ref left,
130 ref right,
131 ..
132 } => {
133 let d_left :A = distance_to_space(
134 point,
135 &left.min_bounds,
136 &left.max_bounds,
137 distance
138 );
139 if d_left < radius {
140 pending.push(HeapElement {
141 distance: -d_left,
142 element: left,
143 });
144 }
145 let d_right:A = distance_to_space(
146 point,
147 &right.min_bounds,
148 &right.max_bounds,
149 distance
150 );
151 if d_right < radius {
152 pending.push(HeapElement {
153 distance: -d_right,
154 element: right,
155 });
156 }
157 }
158 }
159 }
160
161 Ok(evaluated.iter().map(|e| e.element).collect())
162 }
163
164
165 pub fn add(&mut self, point: &[A; K]) -> Result<(), ErrorKind> {
182 self.check_point(point)?;
183 self.add_unchecked(point);
184 Ok(())
185 }
186
187 fn add_unchecked(&mut self, point: &[A; K]) {
188 match &mut self.content {
189 Node::Leaf { .. } => {
190 self.add_to_bucket(point);
191 }
192
193 Node::Stem {
194 left,
195 right,
196 split_dimension,
197 split_value,
198 } => {
199 if point[*split_dimension] < *split_value {
200 left.add_unchecked(point)
202 } else {
203 right.add_unchecked(point)
204 }
205 }
206 };
207
208 self.extend(point);
209 self.size += 1;
210 }
211
212 fn add_to_bucket(&mut self, point: &[A; K]) {
213 self.extend(point);
214 let cap;
215 match &mut self.content {
216 Node::Leaf {
217 bucket,
218 capacity,
219 } => {
220 bucket.push(*point);
221 cap = *capacity;
222 }
223 Node::Stem { .. } => unreachable!(),
224 }
225
226 self.size += 1;
227 if self.size > cap {
228 self.split();
229 }
230 }
231
232 fn split(&mut self) {
233 match &mut self.content {
234 Node::Leaf {
235 bucket,
236 capacity,
237 } => {
238 let mut split_dimension:usize = 0;
239 let mut max = A::zero();
240 for dim in 0..K {
241 let diff = self.max_bounds[dim] - self.min_bounds[dim];
242 if !diff.is_nan() && diff > max {
243 max = diff;
244 split_dimension = dim;
245 }
246 }
247
248 let split_value = self.min_bounds[split_dimension] + max / A::from(2.0).unwrap();
249 let mut left = Box::new(KdTree::with_per_node_capacity(*capacity).unwrap());
250 let mut right = Box::new(KdTree::with_per_node_capacity(*capacity).unwrap());
251
252 while !bucket.is_empty() {
253 let point= bucket.pop().unwrap();
254 if point[split_dimension] < split_value {
255 left.add_to_bucket(&point);
257 } else {
258 right.add_to_bucket(&point);
259 }
260 }
261
262 self.content = Node::Stem {
263 left,
264 right,
265 split_value,
266 split_dimension,
267 }
268 }
269 Node::Stem { .. } => unreachable!(),
270 }
271 }
272
273 fn extend(&mut self, point: &[A; K]) {
274 let min = self.min_bounds.iter_mut();
275 let max = self.max_bounds.iter_mut();
276 for ((l, h), v) in min.zip(max).zip(point.iter()) {
277 if v < l {
278 *l = *v
279 }
280 if v > h {
281 *h = *v
282 }
283 }
284 }
285
286 fn check_point(&self, point: &[A; K]) -> Result<(), ErrorKind> {
287 if point.iter().all(|n| n.is_finite()) {
288 Ok(())
289 } else {
290 Err(ErrorKind::NonFiniteCoordinate)
291 }
292 }
293}
294
295
296impl std::fmt::Display for ErrorKind {
297 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
298 write!(f, "KdTree error: {}", self)
299 }
300}