1use std::collections::BinaryHeap;
2
3use num_traits::{Float, One, Zero};
4
5use crate::heap_element::HeapElement;
6use crate::util;
7
/// A k-d tree node; the root node is also the handle to the whole tree.
///
/// A node is either a leaf, which keeps coordinates in `points` and the
/// matching payloads in `bucket`, or an internal node, which records a split
/// plane and owns a `left` and a `right` child.
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
#[derive(Clone, Debug)]
pub struct KdTree<A, T, U>
where
    T: std::cmp::PartialEq,
    U: AsRef<[A]> + std::cmp::PartialEq,
{
    /// Child for points whose split coordinate is below `split_value`.
    left: Option<Box<KdTree<A, T, U>>>,
    /// Child for all remaining points.
    right: Option<Box<KdTree<A, T, U>>>,
    /// Number of coordinates every point must have.
    dimensions: usize,
    /// Number of entries a leaf may hold before it is split.
    capacity: usize,
    /// Number of entries stored in this node and everything below it.
    size: usize,
    /// Per-dimension minimum over all points added through this node.
    min_bounds: Box<[A]>,
    /// Per-dimension maximum over all points added through this node.
    max_bounds: Box<[A]>,
    /// Position of the split plane along `split_dimension` (internal nodes only).
    split_value: Option<A>,
    /// Axis the node was split on (internal nodes only).
    split_dimension: Option<usize>,
    /// Coordinates stored in this leaf; `points[i]` pairs with `bucket[i]`.
    points: Option<Vec<U>>,
    /// Payloads stored in this leaf; `bucket[i]` pairs with `points[i]`.
    bucket: Option<Vec<T>>,
}
27
/// Reasons a `KdTree` operation can be rejected.
///
/// The enum is a plain, field-less value, so it also derives `Clone`, `Copy`
/// and `Eq`; callers can store, copy and compare errors freely.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ErrorKind {
    /// The supplied point does not have the tree's number of dimensions.
    WrongDimension,
    /// The supplied point contains a NaN or infinite coordinate.
    NonFiniteCoordinate,
    /// The tree was created with a capacity of zero and cannot store entries.
    ZeroCapacity,
}
34
35impl<A: Float + Zero + One, T: std::cmp::PartialEq, U: AsRef<[A]> + std::cmp::PartialEq> KdTree<A, T, U> {
36 pub fn new(dims: usize) -> Self {
38 KdTree::with_capacity(dims, 2_usize.pow(4))
39 }
40
41 pub fn with_capacity(dimensions: usize, capacity: usize) -> Self {
43 let min_bounds = vec![A::infinity(); dimensions];
44 let max_bounds = vec![A::neg_infinity(); dimensions];
45 KdTree {
46 left: None,
47 right: None,
48 dimensions,
49 capacity,
50 size: 0,
51 min_bounds: min_bounds.into_boxed_slice(),
52 max_bounds: max_bounds.into_boxed_slice(),
53 split_value: None,
54 split_dimension: None,
55 points: Some(vec![]),
56 bucket: Some(vec![]),
57 }
58 }
59
/// Returns the number of entries stored in this node and all of its
/// descendants (the `size` counter maintained by `add` and `remove`).
pub fn size(&self) -> usize {
    self.size
}
63
64 pub fn nearest<F>(
65 &self,
66 point: &[A],
67 num: usize,
68 distance: &F,
69 ) -> Result<Vec<(A, &T)>, ErrorKind>
70 where
71 F: Fn(&[A], &[A]) -> A,
72 {
73 if let Err(err) = self.check_point(point) {
74 return Err(err);
75 }
76 let num = std::cmp::min(num, self.size);
77 if num == 0 {
78 return Ok(vec![]);
79 }
80 let mut pending = BinaryHeap::new();
81 let mut evaluated = BinaryHeap::<HeapElement<A, &T>>::new();
82 pending.push(HeapElement {
83 distance: A::zero(),
84 element: self,
85 });
86 while !pending.is_empty()
87 && (evaluated.len() < num
88 || (-pending.peek().unwrap().distance <= evaluated.peek().unwrap().distance))
89 {
90 self.nearest_step(
91 point,
92 num,
93 A::infinity(),
94 distance,
95 &mut pending,
96 &mut evaluated,
97 );
98 }
99 Ok(evaluated
100 .into_sorted_vec()
101 .into_iter()
102 .take(num)
103 .map(Into::into)
104 .collect())
105 }
106
107 pub fn within<F>(&self, point: &[A], radius: A, distance: &F) -> Result<Vec<(A, &T)>, ErrorKind>
108 where
109 F: Fn(&[A], &[A]) -> A,
110 {
111 if let Err(err) = self.check_point(point) {
112 return Err(err);
113 }
114 if self.size == 0 {
115 return Ok(vec![]);
116 }
117 let mut pending = BinaryHeap::new();
118 let mut evaluated = BinaryHeap::<HeapElement<A, &T>>::new();
119 pending.push(HeapElement {
120 distance: A::zero(),
121 element: self,
122 });
123 while !pending.is_empty() && (-pending.peek().unwrap().distance <= radius) {
124 self.nearest_step(
125 point,
126 self.size,
127 radius,
128 distance,
129 &mut pending,
130 &mut evaluated,
131 );
132 }
133 Ok(evaluated
134 .into_sorted_vec()
135 .into_iter()
136 .map(Into::into)
137 .collect())
138 }
139
140 fn nearest_step<'b, F>(
141 &self,
142 point: &[A],
143 num: usize,
144 max_dist: A,
145 distance: &F,
146 pending: &mut BinaryHeap<HeapElement<A, &'b Self>>,
147 evaluated: &mut BinaryHeap<HeapElement<A, &'b T>>,
148 ) where
149 F: Fn(&[A], &[A]) -> A,
150 {
151 let mut curr = &*pending.pop().unwrap().element;
152 debug_assert!(evaluated.len() <= num);
153 let evaluated_dist = if evaluated.len() == num {
154 max_dist.min(evaluated.peek().unwrap().distance)
158 } else {
159 max_dist
160 };
161
162 while !curr.is_leaf() {
163 let candidate;
164 if curr.belongs_in_left(point) {
165 candidate = curr.right.as_ref().unwrap();
166 curr = curr.left.as_ref().unwrap();
167 } else {
168 candidate = curr.left.as_ref().unwrap();
169 curr = curr.right.as_ref().unwrap();
170 }
171 let candidate_to_space = util::distance_to_space(
172 point,
173 &*candidate.min_bounds,
174 &*candidate.max_bounds,
175 distance,
176 );
177 if candidate_to_space <= evaluated_dist {
178 pending.push(HeapElement {
179 distance: candidate_to_space * -A::one(),
180 element: &**candidate,
181 });
182 }
183 }
184
185 let points = curr.points.as_ref().unwrap().iter();
186 let bucket = curr.bucket.as_ref().unwrap().iter();
187 let iter = points.zip(bucket).map(|(p, d)| HeapElement {
188 distance: distance(point, p.as_ref()),
189 element: d,
190 });
191 for element in iter {
192 if element <= max_dist {
193 if evaluated.len() < num {
194 evaluated.push(element);
195 } else if element < *evaluated.peek().unwrap() {
196 evaluated.pop();
197 evaluated.push(element);
198 }
199 }
200 }
201 }
202
203 pub fn iter_nearest<'a, 'b, F>(
204 &'b self,
205 point: &'a [A],
206 distance: &'a F,
207 ) -> Result<NearestIter<'a, 'b, A, T, U, F>, ErrorKind>
208 where
209 F: Fn(&[A], &[A]) -> A,
210 {
211 if let Err(err) = self.check_point(point) {
212 return Err(err);
213 }
214 let mut pending = BinaryHeap::new();
215 let evaluated = BinaryHeap::<HeapElement<A, &T>>::new();
216 pending.push(HeapElement {
217 distance: A::zero(),
218 element: self,
219 });
220 Ok(NearestIter {
221 point,
222 pending,
223 evaluated,
224 distance,
225 })
226 }
227
228 pub fn iter_nearest_mut<'a, 'b, F>(
229 &'b mut self,
230 point: &'a [A],
231 distance: &'a F,
232 ) -> Result<NearestIterMut<'a, 'b, A, T, U, F>, ErrorKind>
233 where
234 F: Fn(&[A], &[A]) -> A,
235 {
236 if let Err(err) = self.check_point(point) {
237 return Err(err);
238 }
239 let mut pending = BinaryHeap::new();
240 let evaluated = BinaryHeap::<HeapElement<A, &mut T>>::new();
241 pending.push(HeapElement {
242 distance: A::zero(),
243 element: self,
244 });
245 Ok(NearestIterMut {
246 point,
247 pending,
248 evaluated,
249 distance,
250 })
251 }
252
253 pub fn add(&mut self, point: U, data: T) -> Result<(), ErrorKind> {
254 if self.capacity == 0 {
255 return Err(ErrorKind::ZeroCapacity);
256 }
257 if let Err(err) = self.check_point(point.as_ref()) {
258 return Err(err);
259 }
260 self.add_unchecked(point, data)
261 }
262
263 fn add_unchecked(&mut self, point: U, data: T) -> Result<(), ErrorKind> {
264 if self.is_leaf() {
265 self.add_to_bucket(point, data);
266 return Ok(());
267 }
268 self.extend(point.as_ref());
269 self.size += 1;
270 let next = if self.belongs_in_left(point.as_ref()) {
271 self.left.as_mut()
272 } else {
273 self.right.as_mut()
274 };
275 next.unwrap().add_unchecked(point, data)
276 }
277
278 fn add_to_bucket(&mut self, point: U, data: T) {
279 self.extend(point.as_ref());
280 let mut points = self.points.take().unwrap();
281 let mut bucket = self.bucket.take().unwrap();
282 points.push(point);
283 bucket.push(data);
284 self.size += 1;
285 if self.size > self.capacity {
286 self.split(points, bucket);
287 } else {
288 self.points = Some(points);
289 self.bucket = Some(bucket);
290 }
291 }
292
293 pub fn remove(&mut self, point: &U, data: &T) -> Result<usize, ErrorKind> {
294 let mut removed = 0;
295 if let Err(err) = self.check_point(point.as_ref()) {
296 return Err(err);
297 }
298 if let (Some(mut points), Some(mut bucket)) = (self.points.take(), self.bucket.take()) {
299 while let Some(p_index) = points.iter().position(|x| x == point) {
300 if &bucket[p_index] == data {
301 points.remove(p_index);
302 bucket.remove(p_index);
303 removed += 1;
304 self.size -= 1;
305 }
306 }
307 self.points = Some(points);
308 self.bucket = Some(bucket);
309 } else {
310 if let Some(right) = self.right.as_mut() {
311 let right_removed = right.remove(point, data)?;
312 if right_removed > 0 {
313 self.size -= right_removed;
314 removed += right_removed;
315 }
316 }
317 if let Some(left) = self.left.as_mut() {
318 let left_removed = left.remove(point, data)?;
319 if left_removed > 0 {
320 self.size -= left_removed;
321 removed += left_removed;
322 }
323 }
324 }
325 Ok(removed)
326 }
327
328 fn split(&mut self, mut points: Vec<U>, mut bucket: Vec<T>) {
329 let mut max = A::zero();
330 for dim in 0..self.dimensions {
331 let diff = self.max_bounds[dim] - self.min_bounds[dim];
332 if !diff.is_nan() && diff > max {
333 max = diff;
334 self.split_dimension = Some(dim);
335 }
336 }
337 match self.split_dimension {
338 None => {
339 self.points = Some(points);
340 self.bucket = Some(bucket);
341 return;
342 }
343 Some(dim) => {
344 let min = self.min_bounds[dim];
345 let max = self.max_bounds[dim];
346 self.split_value = Some(min + (max - min) / A::from(2.0).unwrap());
347 }
348 };
349 let mut left = Box::new(KdTree::with_capacity(self.dimensions, self.capacity));
350 let mut right = Box::new(KdTree::with_capacity(self.dimensions, self.capacity));
351 while !points.is_empty() {
352 let point = points.swap_remove(0);
353 let data = bucket.swap_remove(0);
354 if self.belongs_in_left(point.as_ref()) {
355 left.add_to_bucket(point, data);
356 } else {
357 right.add_to_bucket(point, data);
358 }
359 }
360 self.left = Some(left);
361 self.right = Some(right);
362 }
363
364 fn belongs_in_left(&self, point: &[A]) -> bool {
365 point[self.split_dimension.unwrap()] < self.split_value.unwrap()
366 }
367
368 fn extend(&mut self, point: &[A]) {
369 let min = self.min_bounds.iter_mut();
370 let max = self.max_bounds.iter_mut();
371 for ((l, h), v) in min.zip(max).zip(point.iter()) {
372 if v < l {
373 *l = *v
374 }
375 if v > h {
376 *h = *v
377 }
378 }
379 }
380
381 fn is_leaf(&self) -> bool {
382 self.bucket.is_some()
383 && self.points.is_some()
384 && self.split_value.is_none()
385 && self.split_dimension.is_none()
386 && self.left.is_none()
387 && self.right.is_none()
388 }
389
390 fn check_point(&self, point: &[A]) -> Result<(), ErrorKind> {
391 if self.dimensions != point.len() {
392 return Err(ErrorKind::WrongDimension);
393 }
394 for n in point {
395 if !n.is_finite() {
396 return Err(ErrorKind::NonFiniteCoordinate);
397 }
398 }
399 Ok(())
400 }
401}
402
403pub struct NearestIter<
404 'a,
405 'b,
406 A: 'a + 'b + Float,
407 T: 'b + PartialEq,
408 U: 'b + AsRef<[A]> + std::cmp::PartialEq,
409 F: 'a + Fn(&[A], &[A]) -> A,
410> {
411 point: &'a [A],
412 pending: BinaryHeap<HeapElement<A, &'b KdTree<A, T, U>>>,
413 evaluated: BinaryHeap<HeapElement<A, &'b T>>,
414 distance: &'a F,
415}
416
417impl<'a, 'b, A: Float + Zero + One, T: 'b, U: 'b + AsRef<[A]>, F: 'a> Iterator
418 for NearestIter<'a, 'b, A, T, U, F>
419where
420 F: Fn(&[A], &[A]) -> A,
421 U: PartialEq,
422 T: PartialEq,
423{
424 type Item = (A, &'b T);
425 fn next(&mut self) -> Option<(A, &'b T)> {
426 use util::distance_to_space;
427
428 let distance = self.distance;
429 let point = self.point;
430 while !self.pending.is_empty()
431 && (self.evaluated.peek().map_or(A::infinity(), |x| -x.distance)
432 >= -self.pending.peek().unwrap().distance)
433 {
434 let mut curr = &*self.pending.pop().unwrap().element;
435 while !curr.is_leaf() {
436 let candidate;
437 if curr.belongs_in_left(point) {
438 candidate = curr.right.as_ref().unwrap();
439 curr = curr.left.as_ref().unwrap();
440 } else {
441 candidate = curr.left.as_ref().unwrap();
442 curr = curr.right.as_ref().unwrap();
443 }
444 self.pending.push(HeapElement {
445 distance: -distance_to_space(
446 point,
447 &*candidate.min_bounds,
448 &*candidate.max_bounds,
449 distance,
450 ),
451 element: &**candidate,
452 });
453 }
454 let points = curr.points.as_ref().unwrap().iter();
455 let bucket = curr.bucket.as_ref().unwrap().iter();
456 self.evaluated
457 .extend(points.zip(bucket).map(|(p, d)| HeapElement {
458 distance: -distance(point, p.as_ref()),
459 element: d,
460 }));
461 }
462 self.evaluated.pop().map(|x| (-x.distance, x.element))
463 }
464}
465
466pub struct NearestIterMut<
467 'a,
468 'b,
469 A: 'a + 'b + Float,
470 T: 'b + PartialEq,
471 U: 'b + AsRef<[A]> + PartialEq,
472 F: 'a + Fn(&[A], &[A]) -> A,
473> {
474 point: &'a [A],
475 pending: BinaryHeap<HeapElement<A, &'b mut KdTree<A, T, U>>>,
476 evaluated: BinaryHeap<HeapElement<A, &'b mut T>>,
477 distance: &'a F,
478}
479
480impl<'a, 'b, A: Float + Zero + One, T: 'b, U: 'b + AsRef<[A]>, F: 'a> Iterator
481 for NearestIterMut<'a, 'b, A, T, U, F>
482where
483 F: Fn(&[A], &[A]) -> A,
484 U: PartialEq,
485 T: PartialEq,
486{
487 type Item = (A, &'b mut T);
488 fn next(&mut self) -> Option<(A, &'b mut T)> {
489 use util::distance_to_space;
490
491 let distance = self.distance;
492 let point = self.point;
493 while !self.pending.is_empty()
494 && (self.evaluated.peek().map_or(A::infinity(), |x| -x.distance)
495 >= -self.pending.peek().unwrap().distance)
496 {
497 let mut curr = &mut *self.pending.pop().unwrap().element;
498 while !curr.is_leaf() {
499 let candidate;
500 if curr.belongs_in_left(point) {
501 candidate = curr.right.as_mut().unwrap();
502 curr = curr.left.as_mut().unwrap();
503 } else {
504 candidate = curr.left.as_mut().unwrap();
505 curr = curr.right.as_mut().unwrap();
506 }
507 self.pending.push(HeapElement {
508 distance: -distance_to_space(
509 point,
510 &*candidate.min_bounds,
511 &*candidate.max_bounds,
512 distance,
513 ),
514 element: &mut **candidate,
515 });
516 }
517 let points = curr.points.as_ref().unwrap().iter();
518 let bucket = curr.bucket.as_mut().unwrap().iter_mut();
519 self.evaluated
520 .extend(points.zip(bucket).map(|(p, d)| HeapElement {
521 distance: -distance(point, p.as_ref()),
522 element: d,
523 }));
524 }
525 self.evaluated.pop().map(|x| (-x.distance, x.element))
526 }
527}
528
529impl std::error::Error for ErrorKind {
530 fn description(&self) -> &str {
531 match *self {
532 ErrorKind::WrongDimension => "wrong dimension",
533 ErrorKind::NonFiniteCoordinate => "non-finite coordinate",
534 ErrorKind::ZeroCapacity => "zero capacity",
535 }
536 }
537}
538
539impl std::fmt::Display for ErrorKind {
540 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
541 use std::error::Error;
542 write!(f, "KdTree error: {}", self.description())
543 }
544}
545
#[cfg(test)]
mod tests {
    extern crate rand;
    use super::KdTree;

    /// Tree type shared by all tests: 2-D `f64` points with `i32` payloads.
    type Tree = KdTree<f64, i32, [f64; 2]>;

    /// Leaf capacity `KdTree::new` is expected to use.
    const DEFAULT_CAPACITY: usize = 16;

    /// Produces a random 2-D point together with a random payload.
    fn random_point() -> ([f64; 2], i32) {
        rand::random()
    }

    /// Adds one freshly generated random entry to `tree`.
    fn add_random(tree: &mut Tree) {
        let (pos, data) = random_point();
        tree.add(pos, data).unwrap();
    }

    #[test]
    fn it_has_default_capacity() {
        let tree = Tree::new(2);
        assert_eq!(tree.capacity, DEFAULT_CAPACITY);
    }

    #[test]
    fn it_can_be_cloned() {
        let (pos, data) = random_point();
        let mut tree = Tree::new(2);
        tree.add(pos, data).unwrap();

        // Adding to the clone must not affect the original.
        let mut cloned_tree = tree.clone();
        cloned_tree.add(pos, data).unwrap();

        assert_eq!(tree.size(), 1);
        assert_eq!(cloned_tree.size(), 2);
    }

    #[test]
    fn it_holds_on_to_its_capacity_before_splitting() {
        let mut tree = Tree::new(2);
        for _ in 0..DEFAULT_CAPACITY {
            add_random(&mut tree);
        }
        // Exactly at capacity: still a single leaf.
        assert_eq!(tree.size, DEFAULT_CAPACITY);
        assert_eq!(tree.size(), DEFAULT_CAPACITY);
        assert!(tree.left.is_none() && tree.right.is_none());

        // One entry over capacity: the root must have been split.
        add_random(&mut tree);
        assert_eq!(tree.size, DEFAULT_CAPACITY + 1);
        assert_eq!(tree.size(), DEFAULT_CAPACITY + 1);
        assert!(tree.left.is_some() && tree.right.is_some());
    }

    #[test]
    fn no_items_can_be_added_to_a_zero_capacity_kdtree() {
        let mut tree = Tree::with_capacity(2, 0);
        let (pos, data) = random_point();
        assert!(tree.add(pos, data).is_err());
    }
}