#![warn(missing_docs)]

#![allow(clippy::deprecated_cfg_attr)]

#[cfg(any(test, feature = "dummy_point"))]
pub mod dummy_point;
mod heap;
mod infinite;
mod internal_neighbour;
mod internal_parameters;
mod node;

use heap::CandidateHeap;
use internal_neighbour::InternalNeighbour;
use internal_parameters::InternalParameters;
use node::Node;
use num_traits::{clamp_max, clamp_min, Bounded, FromPrimitive, Signed, Zero};
use ordered_float::Float;
pub use ordered_float::NotNan;
use std::cmp::{Ord, Ordering};
use std::fmt::Debug;
use std::{collections::BinaryHeap, ops::AddAssign};

/// Scalar type usable as the coordinate type of a [`Point`], e.g. `f32` or `f64`.
pub trait Scalar: Float + AddAssign + FromPrimitive + Debug {}
impl<T: Float + AddAssign + FromPrimitive + Debug> Scalar for T {}

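/// A point in `DIM`-dimensional space with `NotNan` coordinates.
///
/// A minimal sketch of a 2-D implementor (the `P2` name and its field are
/// illustrative only, not part of this crate):
///
/// ```ignore
/// #[derive(Default, Clone, Copy, Debug)]
/// struct P2 {
///     coords: [f32; 2],
/// }
///
/// impl Point<f32> for P2 {
///     const DIM: u32 = 2;
///     fn set(&mut self, i: u32, value: NotNan<f32>) {
///         self.coords[i as usize] = value.into_inner();
///     }
///     fn get(&self, i: u32) -> NotNan<f32> {
///         NotNan::new(self.coords[i as usize]).unwrap()
///     }
/// }
/// ```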
pub trait Point<T: Scalar>: Default + Clone + Debug + Copy {
    /// Sets the coordinate of dimension `i` to `value`.
    fn set(&mut self, i: u32, value: NotNan<T>);
    /// Returns the coordinate of dimension `i`.
    fn get(&self, i: u32) -> NotNan<T>;
    /// The number of dimensions of the point.
    const DIM: u32;
    /// The number of bits needed to encode a dimension index.
    const DIM_BIT_COUNT: u32 = 32 - Self::DIM.leading_zeros();
    /// Bit mask selecting the dimension bits of a packed node word.
    const DIM_MASK: u32 = (1 << Self::DIM_BIT_COUNT) - 1;
    /// The maximum number of nodes the kd-tree can hold for this point type.
    const MAX_NODE_COUNT: u32 = ((1u64 << (32 - Self::DIM_BIT_COUNT)) - 1) as u32;
}

/// Squared Euclidean distance between two coordinate slices of length `P::DIM`.
#[inline]
fn point_slice_dist2<T: Scalar, P: Point<T>>(lhs: &[NotNan<T>], rhs: &[NotNan<T>]) -> NotNan<T> {
    let mut dist2 = NotNan::<T>::zero();
    for index in 0..P::DIM {
        let index = index as usize;
        let diff = lhs[index] - rhs[index];
        dist2 += diff * diff;
    }
    dist2
}

/// Index of a point in the point cloud the tree was built from.
pub type Index = u32;

/// A neighbour returned by a query.
#[derive(Debug)]
pub struct Neighbour<T: Scalar, P: Point<T>> {
    /// The neighbouring point.
    pub point: P,
    /// The squared distance from the query to this point.
    pub dist2: NotNan<T>,
    /// The index of this point in the original point cloud.
    pub index: Index,
}

impl<T: Scalar, P: Point<T>> PartialOrd for Neighbour<T, P> {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        self.dist2.partial_cmp(&other.dist2)
    }
}

impl<T: Scalar, P: Point<T>> Ord for Neighbour<T, P>
where
    NotNan<T>: Eq + Ord,
{
    fn cmp(&self, other: &Self) -> Ordering {
        self.dist2.cmp(&other.dist2)
    }
}

impl<T: Scalar, P: Point<T>> PartialEq for Neighbour<T, P> {
    fn eq(&self, other: &Self) -> bool {
        self.dist2 == other.dist2
    }
}

impl<T: Scalar, P: Point<T>> Eq for Neighbour<T, P> {}

/// The container used to hold candidate neighbours during a query.
#[derive(Clone, Copy)]
pub enum CandidateContainer {
    /// A linear vector; fastest for small `k`.
    Linear,
    /// A binary heap; scales better for large `k`.
    BinaryHeap,
}

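/// Advanced query parameters.
///
/// A minimal sketch of overriding a subset of the defaults (the field values
/// here are illustrative):
///
/// ```ignore
/// let parameters = Parameters {
///     max_radius: 0.5_f32,
///     allow_self_match: false,
///     ..Parameters::default()
/// };
/// ```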
pub struct Parameters<T: Scalar> {
    /// Approximation factor: larger values prune the search more aggressively and
    /// may return approximate results; `0` gives an exact search.
    pub epsilon: T,
    /// Maximum search radius; only points within this distance of the query are returned.
    pub max_radius: T,
    /// Whether a point at (numerically) zero distance from the query may be returned.
    pub allow_self_match: bool,
    /// Whether the returned neighbours are sorted by increasing distance.
    pub sort_results: bool,
}

impl<T: Scalar> Default for Parameters<T> {
    fn default() -> Parameters<T> {
        Parameters {
            epsilon: T::zero(),
            max_radius: T::infinity(),
            allow_self_match: true,
            sort_results: true,
        }
    }
}

type Nodes<T, P> = Vec<Node<T, P>>;

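/// A kd-tree over a point cloud, supporting exact and approximate k-nearest-neighbour queries.
///
/// A minimal usage sketch, assuming a 2-D point type `P2` that implements
/// `Point<f32>` (such a type is only provided by the optional `dummy_point`
/// module; the values below are illustrative):
///
/// ```ignore
/// let cloud = vec![P2::new(0.0, 0.0), P2::new(1.0, 2.0), P2::new(-3.0, 0.5)];
/// let tree = KDTree::<f32, P2, 2>::new(&cloud);
/// let neighbours = tree.knn(2, &P2::new(0.9, 1.8));
/// assert_eq!(neighbours[0].index, 1); // the closest point is cloud[1]
/// ```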
#[derive(Debug)]
pub struct KDTree<T: Scalar, P: Point<T>, const K: usize> {
    /// Maximum number of points stored in a leaf node.
    bucket_size: u32,
    /// Tree nodes; a split node's left child is stored immediately after it.
    nodes: Nodes<T, P>,
    /// Point coordinates, flattened and grouped by leaf bucket.
    points: Vec<NotNan<T>>,
    /// For each stored point, its index in the original cloud.
    indices: Vec<Index>,
}

impl<T: Scalar + Signed, P: Point<T>, const K: usize> KDTree<T, P, K> {
    /// Builds a kd-tree over `cloud` with the default bucket size of 8.
    pub fn new(cloud: &[P]) -> Self {
        KDTree::new_with_bucket_size(cloud, 8)
    }

    /// Builds a kd-tree over `cloud`, storing at most `bucket_size` points per leaf.
    ///
    /// # Panics
    ///
    /// Panics if `bucket_size < 2`, if the cloud has more than `u32::MAX` points,
    /// or if the estimated node count exceeds what the node encoding allows.
    pub fn new_with_bucket_size(cloud: &[P], bucket_size: u32) -> Self {
        if bucket_size < 2 {
            panic!(
                "Bucket size must be at least 2, but {} was passed",
                bucket_size
            );
        }
        if cloud.len() > u32::MAX as usize {
            panic!(
                "Point cloud is larger than maximum possible size {}",
                u32::MAX
            );
        }
        let estimated_node_count = (cloud.len() / (bucket_size as usize / 2)) as u32;
        if estimated_node_count > P::MAX_NODE_COUNT {
            panic!("Point cloud risks needing more nodes ({}) than the kd-tree allows ({}). The kd-tree has {} bits for dimensions and {} bits for node indices", estimated_node_count, P::MAX_NODE_COUNT, P::DIM_BIT_COUNT, 32 - P::DIM_BIT_COUNT);
        }

        let mut build_points: Vec<_> = (0..cloud.len()).collect();

        let mut tree = KDTree {
            bucket_size,
            nodes: Vec::with_capacity(estimated_node_count as usize),
            points: Vec::with_capacity(cloud.len() * P::DIM as usize),
            indices: Vec::with_capacity(cloud.len()),
        };
        tree.build_nodes(cloud, &mut build_points);
        tree
    }

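    /// Finds the `k` nearest neighbours of `query` with the default [`Parameters`].
    ///
    /// A linear candidate container is used for `k <= 16` and a binary heap otherwise.
    ///
    /// A minimal usage sketch (`tree` built as in the [`KDTree`] example, `query`
    /// any point of the same type):
    ///
    /// ```ignore
    /// let neighbours = tree.knn(3, &query);
    /// for n in &neighbours {
    ///     println!("index {} at squared distance {}", n.index, n.dist2);
    /// }
    /// ```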
    pub fn knn(&self, k: u32, query: &P) -> Vec<Neighbour<T, P>> {
        let candidate_container = if k <= 16 {
            CandidateContainer::Linear
        } else {
            CandidateContainer::BinaryHeap
        };
        #[cfg_attr(rustfmt, rustfmt_skip)]
        self.knn_advanced(
            k, query,
            candidate_container,
            &Parameters::default(),
            None,
        )
    }

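    /// Finds the `k` nearest neighbours of `query` under periodic boundary conditions.
    ///
    /// Each component of `periodic` is the period of the corresponding dimension; the
    /// cloud and the query are assumed to lie in `[0, periodic[dim]]`. The query is
    /// repeated in every periodic image that could contain a closer neighbour.
    ///
    /// A minimal usage sketch, assuming a 2-D point type `P2` implementing
    /// `Point<f32>` and a unit box:
    ///
    /// ```ignore
    /// let period = [NotNan::new(1.0_f32).unwrap(); 2];
    /// let neighbours = tree.knn_periodic(3, &P2::new(0.05, 0.95), &period);
    /// ```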
    pub fn knn_periodic(&self, k: u32, query: &P, periodic: &[NotNan<T>; K]) -> Vec<Neighbour<T, P>> {
        let candidate_container = if k <= 16 {
            CandidateContainer::Linear
        } else {
            CandidateContainer::BinaryHeap
        };

        // k nearest neighbours of the query in the real (non-shifted) image.
        let mut real_image_knns: Vec<Neighbour<T, P>> = self.knn_advanced(
            k, query,
            candidate_container,
            &Parameters::default(),
            None,
        );

        // Largest squared distance among the current candidates; only periodic images
        // closer than this can improve the result.
        let max_dist2 = real_image_knns.iter().max().unwrap().dist2.into_inner();

        // Squared distance from the query to the nearer boundary of each dimension.
        let mut closest_side_dist2: [T; K] = [T::zero(); K];
        for side in 0..K {
            let query_component: NotNan<T> = query.get(side as u32);
            let upper = periodic[side] - query_component;

            debug_assert!(!upper.is_negative());
            debug_assert!(!query_component.is_negative());

            closest_side_dist2[side] = upper.min(query_component).powi(2);
        }

        // Enumerate the 2^K - 1 non-trivial periodic images; each bit of `image`
        // selects whether the query is shifted along the corresponding dimension.
        let mut images_to_check = Vec::with_capacity(2_usize.pow(K as u32) - 1);
        for image in 1..2_usize.pow(K as u32) {
            let closest_image = (0..K)
                .map(|idx| ((image / 2_usize.pow(idx as u32)) % 2) == 1);

            // Lower bound on the squared distance from the query to this image.
            let dist_to_side_edge_or_other: T = closest_image
                .clone()
                .enumerate()
                .flat_map(|(side, flag)| if flag {
                    Some(closest_side_dist2[side])
                } else { None })
                .fold(T::zero(), |acc, x| acc + x);

            if dist_to_side_edge_or_other < max_dist2 {
                let mut image_to_check = *query;

                // Shift the query by a full period along every selected dimension so
                // that wrap-around neighbours become directly reachable.
                for (idx, flag) in closest_image.enumerate() {
                    if flag {
                        let query_component: NotNan<T> = query.get(idx as u32);
                        let periodic_component = periodic[idx];

                        if query_component < periodic_component / T::from(2_u8).unwrap() {
                            image_to_check.set(idx as u32, query_component + periodic_component)
                        } else {
                            image_to_check.set(idx as u32, query_component - periodic_component)
                        }
                    }
                }

                images_to_check.push(image_to_check);
            }
        }

        // Query every candidate image and merge the results.
        for image in &images_to_check {
            real_image_knns.append(&mut self.knn_advanced(
                k, image,
                candidate_container,
                &Parameters::default(),
                None,
            ))
        }

        real_image_knns.sort();
        real_image_knns.dedup();
        real_image_knns.truncate(k as usize);

        real_image_knns
    }

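    /// Finds the `k` nearest neighbours of `query` with full control over the search.
    ///
    /// `candidate_container` selects the data structure holding candidates,
    /// `parameters` tunes approximation, radius, self-matching and sorting, and
    /// `touch_statistics`, if provided, receives the number of leaf points visited.
    ///
    /// A minimal usage sketch (`tree` built as in the [`KDTree`] example, `query`
    /// any point of the same type):
    ///
    /// ```ignore
    /// let mut touched = 0;
    /// let neighbours = tree.knn_advanced(
    ///     5,
    ///     &query,
    ///     CandidateContainer::Linear,
    ///     &Parameters { max_radius: 0.5, ..Parameters::default() },
    ///     Some(&mut touched),
    /// );
    /// ```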
    pub fn knn_advanced(
        &self,
        k: u32,
        query: &P,
        candidate_container: CandidateContainer,
        parameters: &Parameters<T>,
        touch_statistics: Option<&mut u32>,
    ) -> Vec<Neighbour<T, P>> {
        #[cfg_attr(rustfmt, rustfmt_skip)]
        (match candidate_container {
            CandidateContainer::Linear => Self::knn_generic_heap::<Vec<InternalNeighbour<T>>>,
            CandidateContainer::BinaryHeap => Self::knn_generic_heap::<BinaryHeap<InternalNeighbour<T>>>,
        })(
            self,
            k, query,
            parameters, touch_statistics
        )
    }

    fn knn_generic_heap<H: CandidateHeap<T>>(
        &self,
        k: u32,
        query: &P,
        parameters: &Parameters<T>,
        touch_statistics: Option<&mut u32>,
    ) -> Vec<Neighbour<T, P>> {
        let query_as_vec: Vec<_> = (0..P::DIM).map(|i| query.get(i)).collect();
        let Parameters {
            epsilon,
            max_radius,
            allow_self_match,
            sort_results,
        } = *parameters;
        // Pre-square the thresholds so the recursion only compares squared distances.
        let max_error = epsilon + T::one();
        let max_error2 = NotNan::new(max_error * max_error).unwrap();
        let max_radius2 = NotNan::new(max_radius * max_radius).unwrap();
        #[cfg_attr(rustfmt, rustfmt_skip)]
        self.knn_internal::<H>(
            k, &query_as_vec,
            &InternalParameters { max_error2, max_radius2, allow_self_match },
            sort_results, touch_statistics,
        )
        .into_iter()
        .map(|n| self.externalise_neighbour(n))
        .collect()
    }

    fn knn_internal<H: CandidateHeap<T>>(
        &self,
        k: u32,
        query: &[NotNan<T>],
        internal_parameters: &InternalParameters<T>,
        sort_results: bool,
        touch_statistics: Option<&mut u32>,
    ) -> Vec<InternalNeighbour<T>> {
        let mut off = [NotNan::<T>::zero(); K];
        let mut heap = H::new_with_k(k);
        #[cfg_attr(rustfmt, rustfmt_skip)]
        let leaf_touched_count = self.recurse_knn(
            k, query,
            0, NotNan::<T>::zero(),
            &mut heap, &mut off,
            internal_parameters,
        );
        if let Some(touch_statistics) = touch_statistics {
            *touch_statistics = leaf_touched_count;
        }
        if sort_results {
            heap.into_sorted_vec()
        } else {
            heap.into_vec()
        }
    }

    #[allow(clippy::too_many_arguments)]
    fn recurse_knn<H: CandidateHeap<T>>(
        &self,
        k: u32,
        query: &[NotNan<T>],
        node: usize,
        rd: NotNan<T>,
        heap: &mut H,
        off: &mut [NotNan<T>],
        internal_parameters: &InternalParameters<T>,
    ) -> u32 {
        self.nodes[node].dispatch_on_type(
            heap,
            // Split node: recurse into the near child first, then into the far child
            // only if the split plane is still within the current search bounds.
            |heap, split_dim, split_val, right_child| {
                let mut rd = rd;
                let split_dim = split_dim as usize;
                let old_off = off[split_dim];
                let new_off = query[split_dim] - split_val;
                let left_child = node + 1;
                let right_child = right_child as usize;
                let InternalParameters {
                    max_radius2,
                    max_error2,
                    ..
                } = *internal_parameters;
                if new_off > NotNan::<T>::zero() {
                    // The query lies on the right side of the split plane.
                    #[cfg_attr(rustfmt, rustfmt_skip)]
                    let mut leaf_visited_count = self.recurse_knn(
                        k, query,
                        right_child, rd,
                        heap, off,
                        internal_parameters,
                    );
                    rd += new_off * new_off - old_off * old_off;
                    if rd <= max_radius2 && rd * max_error2 <= heap.furthest_dist2() {
                        off[split_dim] = new_off;
                        #[cfg_attr(rustfmt, rustfmt_skip)]
                        let new_visits = self.recurse_knn(
                            k, query,
                            left_child, rd,
                            heap, off,
                            internal_parameters,
                        );
                        leaf_visited_count += new_visits;
                        off[split_dim] = old_off;
                    }
                    leaf_visited_count
                } else {
                    // The query lies on the left side of the split plane.
                    #[cfg_attr(rustfmt, rustfmt_skip)]
                    let mut leaf_visited_count = self.recurse_knn(
                        k, query,
                        left_child, rd,
                        heap, off,
                        internal_parameters,
                    );
                    rd += new_off * new_off - old_off * old_off;
                    if rd <= max_radius2 && rd * max_error2 <= heap.furthest_dist2() {
                        off[split_dim] = new_off;
                        #[cfg_attr(rustfmt, rustfmt_skip)]
                        let new_visits = self.recurse_knn(
                            k, query,
                            right_child, rd,
                            heap, off,
                            internal_parameters,
                        );
                        leaf_visited_count += new_visits;
                        off[split_dim] = old_off;
                    }
                    leaf_visited_count
                }
            },
            // Leaf node: test every point in the bucket against the current candidates.
            |heap, bucket_start_index, bucket_size| {
                let bucket_end_index = bucket_start_index + bucket_size;
                for bucket_index in bucket_start_index..bucket_end_index {
                    let point_index = (bucket_index * P::DIM) as usize;
                    let point = &self.points[point_index..point_index + (P::DIM as usize)];
                    let dist2 = point_slice_dist2::<T, P>(query, point);
                    let epsilon = NotNan::new(T::epsilon()).unwrap();
                    let InternalParameters {
                        max_radius2,
                        allow_self_match,
                        ..
                    } = *internal_parameters;
                    if dist2 <= max_radius2 && (allow_self_match || (dist2 > epsilon)) {
                        heap.add(dist2, bucket_index);
                    }
                }
                bucket_size
            },
        )
    }

    fn build_nodes(&mut self, cloud: &[P], build_points: &mut [usize]) -> usize {
        let count = build_points.len() as u32;
        let pos = self.nodes.len();

        // Small enough to fit in a single leaf: copy the points into the bucket storage.
        if count <= self.bucket_size {
            let bucket_start_index = self.indices.len() as u32;
            self.points.reserve(build_points.len() * P::DIM as usize);
            self.indices.reserve(build_points.len());
            for point_index in build_points {
                let point_index = *point_index;
                self.indices.push(point_index as u32);
                for i in 0..P::DIM {
                    self.points.push(cloud[point_index].get(i));
                }
            }
            self.nodes
                .push(Node::new_leaf_node(bucket_start_index, count));
            return pos;
        }

        let (min_bounds, max_bounds) = Self::get_build_points_bounds(cloud, build_points);

        // Split along the dimension with the largest extent, at the middle of that extent.
        let split_dim = Self::max_delta_index(&min_bounds, &max_bounds);
        let split_dim_u = split_dim as usize;

        let split_val = (max_bounds[split_dim_u] + min_bounds[split_dim_u]) * T::from(0.5).unwrap();
        let range = max_bounds[split_dim_u] - min_bounds[split_dim_u];
        let (left_points, right_points) = if range == T::from(0).unwrap() {
            // All points coincide along every dimension; split the slice arbitrarily in half.
            build_points.split_at_mut(build_points.len() / 2)
        } else {
            partition::partition(build_points, |index| {
                cloud[*index].get(split_dim) < split_val
            })
        };
        debug_assert_ne!(left_points.len(), 0);
        debug_assert_ne!(right_points.len(), 0);

        self.nodes.push(Node::new_split_node(split_dim, split_val));

        // The left child is built immediately after its parent; the right child's index
        // is patched into the parent once it is known.
        let left_child = self.build_nodes(cloud, left_points);
        debug_assert_eq!(left_child, pos + 1);
        let right_child = self.build_nodes(cloud, right_points);

        self.nodes[pos].set_child_index(right_child as u32);
        pos
    }

    fn get_build_points_bounds(
        cloud: &[P],
        build_points: &[usize],
    ) -> (Vec<NotNan<T>>, Vec<NotNan<T>>) {
        let mut min_bounds = vec![NotNan::<T>::max_value(); P::DIM as usize];
        let mut max_bounds = vec![NotNan::<T>::min_value(); P::DIM as usize];
        for p_index in build_points {
            let p = &cloud[*p_index];
            for index in 0..P::DIM {
                let index_u = index as usize;
                min_bounds[index_u] = clamp_max(p.get(index), min_bounds[index_u]);
                max_bounds[index_u] = clamp_min(p.get(index), max_bounds[index_u]);
            }
        }
        (min_bounds, max_bounds)
    }

    fn max_delta_index(lower_bound: &[NotNan<T>], upper_bound: &[NotNan<T>]) -> u32 {
        lower_bound
            .iter()
            .zip(upper_bound.iter())
            .enumerate()
            .max_by_key(|(_, (l, u))| *u - *l)
            .unwrap()
            .0 as u32
    }

    fn externalise_neighbour(&self, neighbour: InternalNeighbour<T>) -> Neighbour<T, P> {
        let mut point = P::default();
        let base_index = neighbour.index * P::DIM;
        for i in 0..P::DIM {
            point.set(i, self.points[(base_index + i) as usize]);
        }
        Neighbour {
            point,
            dist2: neighbour.dist2,
            index: self.indices[neighbour.index as usize],
        }
    }
}

#[cfg(test)]
mod tests {
    use crate::*;
    use dummy_point::{random_point, random_point_cloud, P2};
    use float_cmp::approx_eq;

    fn cloud3() -> Vec<P2> {
        vec![P2::new(0., 0.), P2::new(-1., 3.), P2::new(2., -4.)]
    }

    fn point_dist2<T: Scalar, P: Point<T>>(lhs: &P, rhs: &P) -> NotNan<T> {
        let mut dist2 = NotNan::<T>::zero();
        for index in 0..P::DIM {
            let diff = lhs.get(index) - rhs.get(index);
            dist2 += diff * diff;
        }
        dist2
    }

    fn brute_force_1nn(cloud: &[P2], query: &P2) -> Neighbour<f32, P2> {
        let mut best_dist2 = f32::infinity();
        let mut best_index = 0;
        for (index, point) in cloud.iter().enumerate() {
            let dist2 = point_dist2(point, query).into_inner();
            if dist2 < best_dist2 {
                best_dist2 = dist2;
                best_index = index;
            }
        }
        Neighbour {
            point: cloud[best_index],
            dist2: NotNan::new(best_dist2).unwrap(),
            index: best_index as u32,
        }
    }

    fn brute_force_knn<H: CandidateHeap<f32>>(
        cloud: &[P2],
        query: &P2,
        k: u32,
    ) -> Vec<Neighbour<f32, P2>> {
        let mut h = H::new_with_k(k);
        for (index, point) in cloud.iter().enumerate() {
            let dist2 = point_dist2(point, query);
            h.add(dist2, index as u32);
        }
        h.into_sorted_vec()
            .into_iter()
            .map(|n| {
                let index = n.index as usize;
                Neighbour {
                    point: cloud[index],
                    dist2: n.dist2,
                    index: n.index,
                }
            })
            .collect()
    }

    #[test]
    fn get_build_points_bounds() {
        const K: usize = 2;
        let cloud = cloud3();
        let indices = vec![0, 1, 2];
        let bounds = KDTree::<_, _, K>::get_build_points_bounds(&cloud, &indices);
        assert_eq!(bounds.0, vec![-1., -4.]);
        assert_eq!(bounds.1, vec![2., 3.]);
    }

    #[test]
    fn max_delta_index() {
        const K: usize = 2;
        let b = |x: f32, y: f32| {
            [
                NotNan::<f32>::new(x).unwrap(),
                NotNan::<f32>::new(y).unwrap(),
            ]
        };
        assert_eq!(
            KDTree::<f32, P2, K>::max_delta_index(&b(0., 0.), &b(0., 1.)),
            1
        );
        assert_eq!(
            KDTree::<f32, P2, K>::max_delta_index(&b(0., 0.), &b(-1., 1.)),
            1
        );
        assert_eq!(
            KDTree::<f32, P2, K>::max_delta_index(&b(0., 0.), &b(-1., -2.)),
            0
        );
    }

    #[test]
    fn new_tree() {
        const K: usize = 2;
        let cloud = cloud3();
        let tree = KDTree::<_, _, K>::new_with_bucket_size(&cloud, 2);
        dbg!(tree);
    }

    #[test]
    fn query_1nn_allow_self() {
        const K: usize = 2;
        let mut touch_sum = 0;
        const PASS_COUNT: u32 = 20;
        const QUERY_COUNT: u32 = 100;
        const CLOUD_SIZE: u32 = 1000;
        const PARAMETERS: Parameters<f32> = Parameters {
            epsilon: 0.0,
            max_radius: f32::INFINITY,
            allow_self_match: true,
            sort_results: true,
        };
        for _ in 0..PASS_COUNT {
            let cloud = random_point_cloud(CLOUD_SIZE);
            let tree = KDTree::<_, _, K>::new(&cloud);
            for _ in 0..QUERY_COUNT {
                let query = random_point();
                let mut touch_statistics = 0;

                let nns_lin = tree.knn_advanced(
                    1,
                    &query,
                    CandidateContainer::Linear,
                    &PARAMETERS,
                    Some(&mut touch_statistics),
                );
                assert_eq!(nns_lin.len(), 1);
                let nn_lin = &nns_lin[0];
                assert_eq!(nn_lin.point, cloud[nn_lin.index as usize]);
                touch_sum += touch_statistics;
                let nns_bin =
                    tree.knn_advanced(1, &query, CandidateContainer::BinaryHeap, &PARAMETERS, None);
                assert_eq!(nns_bin.len(), 1);
                let nn_bin = &nns_bin[0];
                assert_eq!(nn_bin.point, cloud[nn_bin.index as usize]);
                let nn_bf = brute_force_1nn(&cloud, &query);
                assert_eq!(nn_bf.point, cloud[nn_bf.index as usize]);
                assert_eq!(
                    nn_bin.index, nn_bf.index,
                    "KDTree binary heap: mismatched indices\nquery: {}\npoint {}, {}\nvs bf {}, {}",
                    query, nn_bin.dist2, nn_bin.point, nn_bf.dist2, nn_bf.point
                );
                assert_eq!(nn_lin.index, nn_bf.index, "\nKDTree linear heap: mismatched indices\nquery: {}\npoint {}, {}\nvs bf {}, {}\n", query, nn_lin.dist2, nn_lin.point, nn_bf.dist2, nn_bf.point);
                assert!(approx_eq!(f32, *nn_lin.dist2, *nn_bf.dist2, ulps = 2));
                assert!(approx_eq!(f32, *nn_bin.dist2, *nn_bf.dist2, ulps = 2));
            }
        }
        let touch_pct = (touch_sum * 100) as f32 / (PASS_COUNT * QUERY_COUNT * CLOUD_SIZE) as f32;
        println!("Average tree points touched: {} %", touch_pct);
    }

    #[test]
    fn query_knn_allow_self() {
        const K: usize = 2;
        const QUERY_COUNT: u32 = 100;
        const CLOUD_SIZE: u32 = 1000;
        const PARAMETERS: Parameters<f32> = Parameters {
            epsilon: 0.0,
            max_radius: f32::INFINITY,
            allow_self_match: true,
            sort_results: true,
        };
        let cloud = random_point_cloud(CLOUD_SIZE);
        let tree = KDTree::<_, _, K>::new(&cloud);
        for k in [1, 2, 3, 5, 7, 13] {
            for _ in 0..QUERY_COUNT {
                let query = random_point();
                let nns_bf_lin = brute_force_knn::<Vec<InternalNeighbour<f32>>>(&cloud, &query, k);
                assert_eq!(nns_bf_lin.len(), k as usize);
                let nns_bf_bin =
                    brute_force_knn::<BinaryHeap<InternalNeighbour<f32>>>(&cloud, &query, k);
                assert_eq!(nns_bf_bin.len(), k as usize);
                #[cfg_attr(rustfmt, rustfmt_skip)]
                let nns_bin = tree.knn_advanced(
                    k, &query,
                    CandidateContainer::BinaryHeap,
                    &PARAMETERS,
                    None,
                );
                assert_eq!(nns_bin.len(), k as usize);
                #[cfg_attr(rustfmt, rustfmt_skip)]
                let nns_lin = tree.knn_advanced(
                    k, &query,
                    CandidateContainer::Linear,
                    &PARAMETERS,
                    None,
                );
                assert_eq!(nns_lin.len(), k as usize);
                for i in 0..k as usize {
                    let nn_bf_lin = &nns_bf_lin[i];
                    let nn_bf_bin = &nns_bf_bin[i];
                    let nn_lin = &nns_lin[i];
                    let nn_bin = &nns_bin[i];
                    assert_eq!(nn_bf_lin.point, cloud[nn_bf_lin.index as usize]);
                    assert_eq!(nn_bf_bin.point, cloud[nn_bf_bin.index as usize]);
                    assert_eq!(nn_lin.point, cloud[nn_lin.index as usize]);
                    assert_eq!(nn_bin.point, cloud[nn_bin.index as usize]);
                    assert_eq!(nn_bf_bin.index, nn_bf_lin.index, "BF binary heap: mismatched indices at {} on {}\nquery: {}\n bf bin {}, {}\nvs bf lin {}, {}\n", i, k, query, nn_bf_bin.dist2, nn_bf_bin.point, nn_bf_lin.dist2, nn_bf_lin.point);
                    assert_eq!(nn_lin.index, nn_bf_lin.index, "\nKDTree linear heap: mismatched indices at {} on {}\nquery: {}\npoint {}, {}\nvs bf {}, {}\n", i, k, query, nn_lin.dist2, nn_lin.point, nn_bf_lin.dist2, nn_bf_lin.point);
                    assert_eq!(nn_bin.index, nn_bf_lin.index, "\nKDTree binary heap: mismatched indices at {} on {}\nquery: {}\npoint {}, {}\nvs bf {}, {}\n", i, k, query, nn_bin.dist2, nn_bin.point, nn_bf_lin.dist2, nn_bf_lin.point);
                    assert!(approx_eq!(
                        f32,
                        *nn_bf_bin.dist2,
                        *nn_bf_lin.dist2,
                        ulps = 2
                    ));
                    assert!(approx_eq!(f32, *nn_lin.dist2, *nn_bf_lin.dist2, ulps = 2));
                    assert!(approx_eq!(f32, *nn_bin.dist2, *nn_bf_lin.dist2, ulps = 2));
                }
            }
        }
    }

    #[test]
    fn small_clouds_can_lead_to_neighbours() {
        const K: usize = 2;
        let cloud = vec![P2::new(0.0, 0.0), P2::new(1.0, 0.0)];
        let tree = KDTree::<_, _, K>::new(&cloud);
        let query = P2::new(0.5, 0.0);
        for container in [CandidateContainer::Linear, CandidateContainer::BinaryHeap] {
            let nns = tree.knn_advanced(3, &query, container, &Parameters::default(), None);
            assert_eq!(nns.len(), 2);
        }
    }

    #[test]
    fn max_radius_can_lead_to_neighbours() {
        const K: usize = 2;
        let cloud = vec![P2::new(0.0, 0.0), P2::new(1.0, 0.0)];
        let tree = KDTree::<_, _, K>::new(&cloud);
        let query = P2::new(0.1, 0.0);
        let parameters = Parameters {
            epsilon: 0.0,
            max_radius: 0.5,
            allow_self_match: false,
            sort_results: false,
        };
        for container in [CandidateContainer::Linear, CandidateContainer::BinaryHeap] {
            let nns = tree.knn_advanced(2, &query, container, &parameters, None);
            assert_eq!(nns.len(), 1);
        }
    }
}