nabo/
lib.rs

1#![no_std]
2#![warn(missing_docs)]
3
4//! A fast K Nearest Neighbour library for low-dimensional spaces.
5//!
//! This crate is a re-implementation in pure Rust of the [C++ library of the same name](https://github.com/ethz-asl/libnabo).
7//! This work has been sponsored by [Enlightware GmbH](https://enlightware.ch).
8//!
9//! # Example
10//! ```
11//! use nabo::simple_point::*;
12//! use nabo::KDTree;
13//! let cloud = random_point_cloud::<2>(10000);
14//! let tree = KDTree::new(&cloud);
15//! let query = random_point();
16//! let neighbour = tree.knn(3, &query);
17//! ```
18//!
//! If you want more control over the search, you can use the advanced API:
20//! ```
21//! use nabo::simple_point::*;
22//! use nabo::KDTree;
23//! use nabo::CandidateContainer;
24//! use nabo::Parameters;
25//! let cloud = random_point_cloud::<2>(10000);
26//! let tree = KDTree::new(&cloud);
27//! let query = random_point();
28//! let mut touch_count = 0;
29//! let neighbour = tree.knn_advanced(
30//!     3,
31//!     &query,
32//!     CandidateContainer::BinaryHeap,
33//!     &Parameters {
34//!         epsilon: 0.0,
35//!         max_radius: 10.0,
36//!         allow_self_match: true,
37//!         sort_results: false,
38//!     },
39//!     Some(&mut touch_count) // statistics
40//! );
41//! ```
42
// We allow (i.e. silence) this clippy lint because it suggests replacing
// `#[cfg_attr(rustfmt, rustfmt_skip)]` with `#[rustfmt::skip]`, which is experimental.
// See: https://github.com/rust-lang/rust/issues/88591
45#![allow(clippy::deprecated_cfg_attr)]
46
47extern crate alloc;
48
49mod heap;
50mod infinite;
51mod internal_neighbour;
52mod internal_parameters;
53mod node;
54pub mod simple_point;
55
56use alloc::{collections::BinaryHeap, vec, vec::Vec};
57use core::ops::AddAssign;
58use internal_parameters::InternalParameters;
59use node::Node;
60use num_traits::{clamp_max, clamp_min, Bounded, Zero};
61use ordered_float::FloatCore;
62pub use ordered_float::{FloatIsNan, NotNan};
63
64use heap::CandidateHeap;
65use internal_neighbour::InternalNeighbour;
66
67/// The scalar type for points in the space to be searched
68pub trait Scalar: FloatCore + AddAssign + core::fmt::Debug {}
69impl<T: FloatCore + AddAssign + core::fmt::Debug> Scalar for T {}
70
71/// A point in the space to be searched
72pub trait Point<T: Scalar>: Default {
73    /// Sets the value for the `i`-th component, `i` must be within `0..DIM`.
74    fn set(&mut self, i: u32, value: NotNan<T>);
75    /// Gets the value for the `i`-th component, `i` must be within `0..DIM`.
76    fn get(&self, i: u32) -> NotNan<T>;
77    /// The number of dimension of the space this point lies in.
78    const DIM: u32;
79    /// Derived from `DIM`, do not reimplement, use the default!
80    const DIM_BIT_COUNT: u32 = 32 - Self::DIM.leading_zeros();
81    /// Derived from `DIM`, do not reimplement, use the default!
82    const DIM_MASK: u32 = (1 << Self::DIM_BIT_COUNT) - 1;
83    /// Derived from `DIM`, do not reimplement, use the default!
84    const MAX_NODE_COUNT: u32 = ((1u64 << (32 - Self::DIM_BIT_COUNT)) - 1) as u32;
85
86    /// Construct from a slice of valid axis values.
87    ///
88    /// If the slice is too short, the point will be right-filled as `Point::default()`.
89    /// if it is too long, the extra elements will be ignored.
90    fn from_slice(values: &[NotNan<T>]) -> Self {
91        let mut p = Self::default();
92        for (idx, v) in values.iter().take(Self::DIM as usize).enumerate() {
93            p.set(idx as u32, *v);
94        }
95        p
96    }
97
98    /// Construct from a slice of raw axis values.
99    ///
100    /// If the slice is too short, the point will be right-filled as `Point::default()`.
101    /// if it is too long, the extra elements will be ignored.
102    fn from_raw(values: &[T]) -> Result<Self, FloatIsNan> {
103        let mut p = Self::default();
104        for (idx, v) in values.iter().take(Self::DIM as usize).enumerate() {
105            p.set(idx as u32, NotNan::new(*v)?);
106        }
107        Ok(p)
108    }
109}
110
111/// Helper function to compute the square distance between two points given as slice
112#[inline]
113fn point_slice_dist2<T: Scalar, P: Point<T>>(lhs: &[NotNan<T>], rhs: &[NotNan<T>]) -> NotNan<T> {
114    let mut dist2 = NotNan::<T>::zero();
115    for index in 0..P::DIM {
116        let index = index as usize;
117        let diff = lhs[index] - rhs[index];
118        dist2 += diff * diff;
119    }
120    dist2
121}
122
/// Position of a point within the cloud (slice) the tree was built from.
pub type Index = u32;
125
126/// A neighbour resulting from the search
127#[derive(Debug)]
128pub struct Neighbour<T: Scalar, P: Point<T>> {
129    /// the point itself
130    pub point: P,
131    /// the squared-distance to the point
132    pub dist2: NotNan<T>,
133    /// the index of the point in the original point cloud
134    pub index: Index,
135}
136
/// The type of container used to keep candidate neighbours during a search
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum CandidateContainer {
    /// Use a linear vector to keep candidates, good for small `k`
    Linear,
    /// Use a binary heap to keep candidates, good for large `k`
    BinaryHeap,
}
144
145/// Advanced search parameters
146pub struct Parameters<T: Scalar> {
147    /// maximal ratio of error for approximate search, 0 for exact search; has no effect if the number of neighbours found is smaller than the number requested
148    pub epsilon: T,
149    /// maximum radius in which to search, can be used to prune search, is not affected by `epsilon`
150    pub max_radius: T,
151    /// allows the return of the same point as the query, if this point is in the point cloud
152    pub allow_self_match: bool,
153    /// sort points by distances, when `k` > 1
154    pub sort_results: bool,
155}
156
157/// A dense vector of search nodes, provides better memory performances than many small objects
158type Nodes<T, P> = Vec<Node<T, P>>;
159
160/// A KD-Tree to perform NN-search queries
161///
162/// This implementation is inspired of the variant `KDTreeUnbalancedPtInLeavesImplicitBoundsStackOpt` in libnabo C++.
163/// Contrary to the latter, it does not keep a reference to the point cloud but copies the point.
164/// It retains their index though.
165#[derive(Clone, Debug)]
166pub struct KDTree<T: Scalar, P: Point<T>> {
167    /// size of a bucket
168    bucket_size: u32,
169    /// search nodes
170    nodes: Nodes<T, P>,
171    /// point data, size cloud.len() * P::DIM
172    points: Vec<NotNan<T>>,
173    /// indices in cloud , size cloud.len()
174    indices: Vec<Index>,
175}
176
177impl<T: Scalar, P: Point<T>> KDTree<T, P> {
178    /// Creates a new KD-Tree from a point cloud.
179    pub fn new(cloud: &[P]) -> Self {
180        KDTree::new_with_bucket_size(cloud, 8)
181    }
182    /// Creates a new KD-Tree from a point cloud.
183    ///
184    /// The `bucket_size` can be chosen freely, but must be at least 2.
185    pub fn new_with_bucket_size(cloud: &[P], bucket_size: u32) -> Self {
186        // validate input
187        if bucket_size < 2 {
188            panic!(
189                "Bucket size must be at least 2, but {} was passed",
190                bucket_size
191            );
192        }
193        if cloud.len() > u32::MAX as usize {
194            panic!(
195                "Point cloud is larger than maximum possible size {}",
196                u32::MAX
197            );
198        }
199        let estimated_node_count = (cloud.len() / (bucket_size as usize / 2)) as u32;
200        if estimated_node_count > P::MAX_NODE_COUNT {
201            panic!("Point cloud has a risk to have more nodes {} than the kd-tree allows {}. The kd-tree has {} bits for dimensions and {} bits for node indices", estimated_node_count, P::MAX_NODE_COUNT, P::DIM_BIT_COUNT, 32 - P::DIM_BIT_COUNT);
202        }
203
204        // build point vector and compute bounds
205        let mut build_points: Vec<_> = (0..cloud.len()).collect();
206
207        // create and populate tree
208        let mut tree = KDTree {
209            bucket_size,
210            nodes: Vec::with_capacity(estimated_node_count as usize),
211            points: Vec::with_capacity(cloud.len() * P::DIM as usize),
212            indices: Vec::with_capacity(cloud.len()),
213        };
214        tree.build_nodes(cloud, &mut build_points);
215        tree
216    }
217
218    /// Finds the `k` nearest neighbour of `query`, using reasonable default parameters.
219    ///
220    /// If there are less than `k` points in the point cloud, the returned vector will be smaller than `k`.
221    /// The default parameters are:
222    /// Exact search, no max. radius, allowing self matching, sorting results, and not collecting statistics.
223    /// If `k` <= 16, a linear vector is used to keep track of candidates, otherwise a binary heap is used.
224    pub fn knn(&self, k: u32, query: &P) -> Vec<Neighbour<T, P>> {
225        let candidate_container = if k <= 16 {
226            CandidateContainer::Linear
227        } else {
228            CandidateContainer::BinaryHeap
229        };
230        #[cfg_attr(rustfmt, rustfmt_skip)]
231        self.knn_advanced(
232            k, query,
233            candidate_container,
234            &Parameters {
235                epsilon: T::from(0.0).unwrap(),
236                max_radius: T::infinity(),
237                allow_self_match: true,
238                sort_results: true,
239            },
240            None,
241        )
242    }
243
244    /// Finds the `k` nearest neighbour of `query`, with user-provided parameters.
245    ///
246    /// If there are less than `k` points in the point cloud or in the ball around `query`
247    /// defined by `parameters.max_radius`, the returned vector will be smaller than `k`.
248    /// The parameters are:
249    /// * `candidate_container` which container to use to collect candidates,
250    /// * `parameters` the advanced search parameters,
251    /// * `touch_statistics`, if `Some(&mut u32)`, return the number of point touched in the provided `u32` reference.
252    pub fn knn_advanced(
253        &self,
254        k: u32,
255        query: &P,
256        candidate_container: CandidateContainer,
257        parameters: &Parameters<T>,
258        touch_statistics: Option<&mut u32>,
259    ) -> Vec<Neighbour<T, P>> {
260        #[cfg_attr(rustfmt, rustfmt_skip)]
261        (match candidate_container {
262            CandidateContainer::Linear => Self::knn_generic_heap::<Vec<InternalNeighbour<T>>>,
263            CandidateContainer::BinaryHeap => Self::knn_generic_heap::<BinaryHeap<InternalNeighbour<T>>>
264        })(
265            self,
266            k, query,
267            parameters, touch_statistics
268        )
269    }
270
271    fn knn_generic_heap<H: CandidateHeap<T>>(
272        &self,
273        k: u32,
274        query: &P,
275        parameters: &Parameters<T>,
276        touch_statistics: Option<&mut u32>,
277    ) -> Vec<Neighbour<T, P>> {
278        let query_as_vec: Vec<_> = (0..P::DIM).map(|i| query.get(i)).collect();
279        let Parameters {
280            epsilon,
281            max_radius,
282            allow_self_match,
283            sort_results,
284        } = *parameters;
285        let max_error = epsilon + T::from(1).unwrap();
286        let max_error2 = NotNan::new(max_error * max_error).unwrap();
287        let max_radius2 = NotNan::new(max_radius * max_radius).unwrap();
288        #[cfg_attr(rustfmt, rustfmt_skip)]
289        self.knn_internal::<H>(
290            k, &query_as_vec,
291            &InternalParameters { max_error2, max_radius2, allow_self_match },
292            sort_results, touch_statistics,
293        )
294            .into_iter()
295            .map(|n| self.externalise_neighbour(n))
296            .collect()
297    }
298
299    fn knn_internal<H: CandidateHeap<T>>(
300        &self,
301        k: u32,
302        query: &[NotNan<T>],
303        internal_parameters: &InternalParameters<T>,
304        sort_results: bool,
305        touch_statistics: Option<&mut u32>,
306    ) -> Vec<InternalNeighbour<T>> {
307        // TODO Const generics: once available, remove `vec!` below.
308        let mut off = vec![NotNan::<T>::zero(); P::DIM as usize];
309        let mut heap = H::new_with_k(k);
310        #[cfg_attr(rustfmt, rustfmt_skip)]
311        let leaf_touched_count = self.recurse_knn(
312            query,
313            0, NotNan::<T>::zero(),
314            &mut heap, &mut off,
315            internal_parameters,
316        );
317        if let Some(touch_statistics) = touch_statistics {
318            *touch_statistics = leaf_touched_count;
319        }
320        if sort_results {
321            heap.into_sorted_vec()
322        } else {
323            heap.into_vec()
324        }
325    }
326
327    #[allow(clippy::too_many_arguments)]
328    fn recurse_knn<H: CandidateHeap<T>>(
329        &self,
330        query: &[NotNan<T>],
331        node: usize,
332        rd: NotNan<T>,
333        heap: &mut H,
334        off: &mut [NotNan<T>],
335        internal_parameters: &InternalParameters<T>,
336    ) -> u32 {
337        self.nodes[node].dispatch_on_type(
338            heap,
339            |heap, split_dim, split_val, right_child| {
340                // split node, see whether we have to recurse
341                let mut rd = rd;
342                let split_dim = split_dim as usize;
343                let old_off = off[split_dim];
344                let new_off = query[split_dim] - split_val;
345                let left_child = node + 1;
346                let right_child = right_child as usize;
347                let InternalParameters {
348                    max_radius2,
349                    max_error2,
350                    ..
351                } = *internal_parameters;
352                if new_off > NotNan::<T>::zero() {
353                    #[cfg_attr(rustfmt, rustfmt_skip)]
354                    let mut leaf_visited_count = self.recurse_knn(
355                        query,
356                        right_child, rd,
357                        heap, off,
358                        internal_parameters,
359                    );
360                    rd += new_off * new_off - old_off * old_off;
361                    if rd <= max_radius2 && rd * max_error2 < heap.furthest_dist2() {
362                        off[split_dim] = new_off;
363                        #[cfg_attr(rustfmt, rustfmt_skip)]
364                        let new_visits= self.recurse_knn(
365                            query,
366                            left_child, rd,
367                            heap, off,
368                            internal_parameters,
369                        );
370                        leaf_visited_count += new_visits;
371                        off[split_dim] = old_off;
372                    }
373                    leaf_visited_count
374                } else {
375                    #[cfg_attr(rustfmt, rustfmt_skip)]
376                    let mut leaf_visited_count = self.recurse_knn(
377                        query,
378                        left_child, rd,
379                        heap, off,
380                        internal_parameters,
381                    );
382                    rd += new_off * new_off - old_off * old_off;
383                    if rd <= max_radius2 && rd * max_error2 < heap.furthest_dist2() {
384                        off[split_dim] = new_off;
385                        #[cfg_attr(rustfmt, rustfmt_skip)]
386                        let new_visits = self.recurse_knn(
387                            query,
388                            right_child, rd,
389                            heap, off,
390                            internal_parameters,
391                        );
392                        leaf_visited_count += new_visits;
393                        off[split_dim] = old_off;
394                    }
395                    leaf_visited_count
396                }
397            },
398            |heap, bucket_start_index, bucket_size| {
399                // leaf node, go through the buckets and check elements
400                let bucket_end_index = bucket_start_index + bucket_size;
401                for bucket_index in bucket_start_index..bucket_end_index {
402                    let point_index = (bucket_index * P::DIM) as usize;
403                    let point = &self.points[point_index..point_index + (P::DIM as usize)];
404                    let dist2 = point_slice_dist2::<T, P>(query, point);
405                    let epsilon = NotNan::new(T::epsilon()).unwrap();
406                    let InternalParameters {
407                        max_radius2,
408                        allow_self_match,
409                        ..
410                    } = *internal_parameters;
411                    if dist2 < max_radius2 && (allow_self_match || (dist2 > epsilon)) {
412                        heap.add(dist2, bucket_index);
413                    }
414                }
415                bucket_size
416            },
417        )
418    }
419
420    fn build_nodes(&mut self, cloud: &[P], build_points: &mut [usize]) -> usize {
421        let count = build_points.len() as u32;
422        let pos = self.nodes.len();
423
424        // if remaining points fit in a single bucket, add a node and this bucket
425        if count <= self.bucket_size {
426            let bucket_start_index = self.indices.len() as u32;
427            self.points.reserve(build_points.len() * P::DIM as usize);
428            self.indices.reserve(build_points.len());
429            for point_index in build_points {
430                let point_index = *point_index;
431                self.indices.push(point_index as u32);
432                for i in 0..P::DIM {
433                    self.points.push(cloud[point_index].get(i));
434                }
435            }
436            self.nodes
437                .push(Node::new_leaf_node(bucket_start_index, count));
438            return pos;
439        }
440
441        // compute bounds
442        let (min_bounds, max_bounds) = Self::get_build_points_bounds(cloud, build_points);
443
444        // find the largest dimension of the box
445        let split_dim = Self::max_delta_index(&min_bounds, &max_bounds);
446        let split_dim_u = split_dim as usize;
447
448        // split along this dimension
449        let split_val = (max_bounds[split_dim_u] + min_bounds[split_dim_u]) * T::from(0.5).unwrap();
450        let range = max_bounds[split_dim_u] - min_bounds[split_dim_u];
451        let (left_points, right_points) = if range == T::from(0).unwrap() {
452            // degenerate data, split in half and iterate
453            build_points.split_at_mut(build_points.len() / 2)
454        } else {
455            // partition data around split_val on split_dim
456            partition::partition(build_points, |index| {
457                cloud[*index].get(split_dim) < split_val
458            })
459        };
460        debug_assert_ne!(left_points.len(), 0);
461        debug_assert_ne!(right_points.len(), 0);
462
463        // add this split
464        self.nodes.push(Node::new_split_node(split_dim, split_val));
465
466        // recurse
467        let left_child = self.build_nodes(cloud, left_points);
468        debug_assert_eq!(left_child, pos + 1);
469        let right_child = self.build_nodes(cloud, right_points);
470
471        // write right child index and return
472        self.nodes[pos].set_child_index(right_child as u32);
473        pos
474    }
475
476    fn get_build_points_bounds(
477        cloud: &[P],
478        build_points: &[usize],
479    ) -> (Vec<NotNan<T>>, Vec<NotNan<T>>) {
480        let mut min_bounds = vec![NotNan::<T>::max_value(); P::DIM as usize];
481        let mut max_bounds = vec![NotNan::<T>::min_value(); P::DIM as usize];
482        for p_index in build_points {
483            let p = &cloud[*p_index];
484            for index in 0..P::DIM {
485                let index_u = index as usize;
486                min_bounds[index_u] = clamp_max(p.get(index), min_bounds[index_u]);
487                max_bounds[index_u] = clamp_min(p.get(index), max_bounds[index_u]);
488            }
489        }
490        (min_bounds, max_bounds)
491    }
492
493    fn max_delta_index(lower_bound: &[NotNan<T>], upper_bound: &[NotNan<T>]) -> u32 {
494        lower_bound
495            .iter()
496            .zip(upper_bound.iter())
497            .enumerate()
498            .max_by_key(|(_, (l, u))| *u - *l)
499            .unwrap()
500            .0 as u32
501    }
502
503    fn externalise_neighbour(&self, neighbour: InternalNeighbour<T>) -> Neighbour<T, P> {
504        let mut point = P::default();
505        let base_index = neighbour.index * P::DIM;
506        for i in 0..P::DIM {
507            point.set(i, self.points[(base_index + i) as usize]);
508        }
509        Neighbour {
510            point,
511            dist2: neighbour.dist2,
512            index: self.indices[neighbour.index as usize],
513        }
514    }
515
516    /// Iterate over the indices and points in this KDTree.
517    /// The order is arbitrary;
518    /// the indices are the point's location in the slice
519    /// from which the tree was built.
520    pub fn iter_idx_points(&self) -> impl Iterator<Item = (u32, P)> + '_ {
521        self.indices.iter().cloned().zip(self.iter_points())
522    }
523
524    /// Iterate over the points in this KDTree in arbitrary order.
525    pub fn iter_points(&self) -> impl Iterator<Item = P> + '_ {
526        self.points
527            .as_slice()
528            .chunks(P::DIM as usize)
529            .map(P::from_slice)
530    }
531}
532
#[cfg(test)]
mod tests {
    extern crate std;
    use crate::*;
    use float_cmp::approx_eq;
    use simple_point::{random_point, random_point_cloud, P2};
    use std::{dbg, println};

    // helpers to create cloud
    fn cloud3() -> Vec<P2> {
        vec![P2::new2d(0., 0.), P2::new2d(-1., 3.), P2::new2d(2., -4.)]
    }

    // helper to compute the square distance between two points
    fn point_dist2<T: Scalar, P: Point<T>>(lhs: &P, rhs: &P) -> NotNan<T> {
        let mut dist2 = NotNan::<T>::zero();
        for index in 0..P::DIM {
            let diff = lhs.get(index) - rhs.get(index);
            dist2 += diff * diff;
        }
        dist2
    }

    // brute force search implementations
    fn brute_force_1nn(cloud: &[P2], query: &P2) -> Neighbour<f32, P2> {
        let mut best_dist2 = f32::infinity();
        let mut best_index = 0;
        for (index, point) in cloud.iter().enumerate() {
            let dist2 = point_dist2(point, query).into_inner();
            if dist2 < best_dist2 {
                best_dist2 = dist2;
                best_index = index;
            }
        }
        Neighbour {
            point: cloud[best_index],
            dist2: NotNan::new(best_dist2).unwrap(),
            index: best_index as u32,
        }
    }

    fn brute_force_knn<H: CandidateHeap<f32>>(
        cloud: &[P2],
        query: &P2,
        k: u32,
    ) -> Vec<Neighbour<f32, P2>> {
        let mut h = H::new_with_k(k);
        for (index, point) in cloud.iter().enumerate() {
            let dist2 = point_dist2(point, query);
            h.add(dist2, index as u32);
        }
        h.into_sorted_vec()
            .into_iter()
            .map(|n| {
                let index = n.index as usize;
                Neighbour {
                    point: cloud[index],
                    dist2: n.dist2,
                    index: n.index,
                }
            })
            .collect()
    }

    // tests themselves

    #[test]
    fn get_build_points_bounds() {
        let cloud = cloud3();
        let indices = vec![0, 1, 2];
        let bounds = KDTree::get_build_points_bounds(&cloud, &indices);
        assert_eq!(bounds.0, vec![-1., -4.]);
        assert_eq!(bounds.1, vec![2., 3.]);
    }

    #[test]
    fn max_delta_index() {
        let b = |x: f32, y: f32| {
            [
                NotNan::<f32>::new(x).unwrap(),
                NotNan::<f32>::new(y).unwrap(),
            ]
        };
        assert_eq!(
            KDTree::<f32, P2>::max_delta_index(&b(0., 0.), &b(0., 1.)),
            1
        );
        assert_eq!(
            KDTree::<f32, P2>::max_delta_index(&b(0., 0.), &b(-1., 1.)),
            1
        );
        assert_eq!(
            KDTree::<f32, P2>::max_delta_index(&b(0., 0.), &b(-1., -2.)),
            0
        );
    }

    #[test]
    fn new_tree() {
        let cloud = cloud3();
        let tree = KDTree::new_with_bucket_size(&cloud, 2);
        dbg!(tree);
    }

    #[test]
    fn query_1nn_allow_self() {
        let mut touch_sum = 0;
        const PASS_COUNT: u32 = 20;
        const QUERY_COUNT: u32 = 100;
        const CLOUD_SIZE: u32 = 1000;
        const PARAMETERS: Parameters<f32> = Parameters {
            epsilon: 0.0,
            max_radius: f32::INFINITY,
            allow_self_match: true,
            sort_results: true,
        };
        for _ in 0..PASS_COUNT {
            let cloud = random_point_cloud(CLOUD_SIZE);
            let tree = KDTree::new(&cloud);
            for _ in 0..QUERY_COUNT {
                let query = random_point();
                let mut touch_statistics = 0;

                // linear search
                let nns_lin = tree.knn_advanced(
                    1,
                    &query,
                    CandidateContainer::Linear,
                    &PARAMETERS,
                    Some(&mut touch_statistics),
                );
                assert_eq!(nns_lin.len(), 1);
                let nn_lin = &nns_lin[0];
                assert_eq!(nn_lin.point, cloud[nn_lin.index as usize]);
                touch_sum += touch_statistics;
                // binary
                let nns_bin =
                    tree.knn_advanced(1, &query, CandidateContainer::BinaryHeap, &PARAMETERS, None);
                assert_eq!(nns_bin.len(), 1);
                let nn_bin = &nns_bin[0];
                assert_eq!(nn_bin.point, cloud[nn_bin.index as usize]);
                // brute force
                let nn_bf = brute_force_1nn(&cloud, &query);
                assert_eq!(nn_bf.point, cloud[nn_bf.index as usize]);
                // assertion
                assert_eq!(
                    nn_bin.index, nn_bf.index,
                    "KDTree binary heap: mismatch indexes\nquery: {}\npoint {}, {}\nvs bf {}, {}",
                    query, nn_bin.dist2, nn_bin.point, nn_bf.dist2, nn_bf.point
                );
                assert_eq!(nn_lin.index, nn_bf.index, "\nKDTree linear heap: mismatch indexes\nquery: {}\npoint {}, {}\nvs bf {}, {}\n", query, nn_lin.dist2, nn_lin.point, nn_bf.dist2, nn_bf.point);
                assert!(approx_eq!(f32, *nn_lin.dist2, *nn_bf.dist2, ulps = 2));
                assert!(approx_eq!(f32, *nn_bin.dist2, *nn_bf.dist2, ulps = 2));
            }
        }
        let touch_pct = (touch_sum * 100) as f32 / (PASS_COUNT * QUERY_COUNT * CLOUD_SIZE) as f32;
        println!("Average tree point touched: {} %", touch_pct);
    }

    #[test]
    fn query_knn_allow_self() {
        const QUERY_COUNT: u32 = 100;
        const CLOUD_SIZE: u32 = 1000;
        const PARAMETERS: Parameters<f32> = Parameters {
            epsilon: 0.0,
            max_radius: f32::INFINITY,
            allow_self_match: true,
            sort_results: true,
        };
        let cloud = random_point_cloud(CLOUD_SIZE);
        let tree = KDTree::new(&cloud);
        for k in [1, 2, 3, 5, 7, 13] {
            for _ in 0..QUERY_COUNT {
                let query = random_point();
                // brute force
                let nns_bf_lin = brute_force_knn::<Vec<InternalNeighbour<f32>>>(&cloud, &query, k);
                assert_eq!(nns_bf_lin.len(), k as usize);
                let nns_bf_bin =
                    brute_force_knn::<BinaryHeap<InternalNeighbour<f32>>>(&cloud, &query, k);
                assert_eq!(nns_bf_bin.len(), k as usize);
                // kd-tree
                let nns_bin = tree.knn_advanced(
                    k,
                    &query,
                    CandidateContainer::BinaryHeap,
                    &PARAMETERS,
                    None,
                );
                assert_eq!(nns_bin.len(), k as usize);
                let nns_lin = tree.knn_advanced(
                    k,
                    &query,
                    CandidateContainer::Linear,
                    &PARAMETERS,
                    None,
                );
                assert_eq!(nns_lin.len(), k as usize);
                // assertion
                for i in 0..k as usize {
                    // get neighbour
                    let nn_bf_lin = &nns_bf_lin[i];
                    let nn_bf_bin = &nns_bf_bin[i];
                    let nn_lin = &nns_lin[i];
                    let nn_bin = &nns_bin[i];
                    // ensure their point data are consistent with the cloud
                    assert_eq!(nn_bf_lin.point, cloud[nn_bf_lin.index as usize]);
                    assert_eq!(nn_bf_bin.point, cloud[nn_bf_bin.index as usize]);
                    assert_eq!(nn_lin.point, cloud[nn_lin.index as usize]);
                    assert_eq!(nn_bin.point, cloud[nn_bin.index as usize]);
                    // ensure their indices are consistent
                    assert_eq!(nn_bf_bin.index, nn_bf_lin.index, "BF binary heap: mismatch indexes at {} on {}\nquery: {}\n   bf bin {}, {}\nvs bf lin {}, {}\n", i, k, query, nn_bf_bin.dist2, nn_bf_bin.point, nn_bf_lin.dist2, nn_bf_lin.point);
                    assert_eq!(nn_lin.index, nn_bf_lin.index, "\nKDTree linear heap: mismatch indexes at {} on {}\nquery: {}\npoint {}, {}\nvs bf {}, {}\n", i, k, query, nn_lin.dist2, nn_lin.point, nn_bf_lin.dist2, nn_bf_lin.point);
                    assert_eq!(nn_bin.index, nn_bf_lin.index, "\nKDTree binary heap: mismatch indexes {} on {}\nquery: {}\npoint {}, {}\nvs bf {}, {}\n", i, k, query, nn_bin.dist2, nn_bin.point, nn_bf_lin.dist2, nn_bf_lin.point);
                    // ensure their dist2 are consistent
                    assert!(approx_eq!(
                        f32,
                        *nn_bf_bin.dist2,
                        *nn_bf_lin.dist2,
                        ulps = 2
                    ));
                    assert!(approx_eq!(f32, *nn_lin.dist2, *nn_bf_lin.dist2, ulps = 2));
                    assert!(approx_eq!(f32, *nn_bin.dist2, *nn_bf_lin.dist2, ulps = 2));
                }
            }
        }
    }

    #[test]
    fn small_clouds_can_lead_to_neighbours() {
        let cloud = vec![P2::new2d(0.0, 0.0), P2::new2d(1.0, 0.0)];
        let tree = KDTree::new(&cloud);
        let query = P2::new2d(0.5, 0.0);
        // default API
        assert_eq!(tree.knn(3, &query).len(), 2);
        // every candidate container explicitly
        let parameters = Parameters {
            epsilon: 0.0,
            max_radius: f32::INFINITY,
            allow_self_match: true,
            sort_results: true,
        };
        for container in [CandidateContainer::Linear, CandidateContainer::BinaryHeap] {
            let nns = tree.knn_advanced(3, &query, container, &parameters, None);
            assert_eq!(nns.len(), 2);
        }
    }

    #[test]
    fn max_radius_can_lead_to_neighbours() {
        let cloud = vec![P2::new2d(0.0, 0.0), P2::new2d(1.0, 0.0)];
        let tree = KDTree::new(&cloud);
        let query = P2::new2d(0.1, 0.0);
        let parameters = Parameters {
            epsilon: 0.0,
            max_radius: 0.5,
            allow_self_match: false,
            sort_results: false,
        };
        for container in [CandidateContainer::Linear, CandidateContainer::BinaryHeap] {
            let nns = tree.knn_advanced(2, &query, container, &parameters, None);
            assert_eq!(nns.len(), 1);
        }
    }
}