nabo_pbc/
lib.rs

1#![warn(missing_docs)]
2
3//! A fast K Nearest Neighbour library for low-dimensional spaces.
4//!
5//! This crate is a re-implementation in pure Rust of the [C++ library of the same name](https://github.com/ethz-asl/libnabo).
6//! This work has been sponsored by [Enlightware GmbH](https://enlightware.ch).
7//!
8//! # Example
9//! ```
10//! use nabo_pbc::dummy_point::*;
11//! use nabo_pbc::KDTree;
12//! const K: usize = 2;
13//! let cloud = random_point_cloud(10000);
14//! let tree = KDTree::<_,_,K>::new(&cloud);
15//! let query = random_point();
16//! let neighbour = tree.knn(3, &query);
17//! ```
18//!
19//! If you want to have more control on the search, you can use the advanced API:
20//! ```
21//! use nabo_pbc::dummy_point::*;
22//! use nabo_pbc::KDTree;
23//! use nabo_pbc::CandidateContainer;
24//! use nabo_pbc::Parameters;
25//! const K: usize = 2;
26//! let cloud = random_point_cloud(10000);
27//! let tree = KDTree::<_,_,K>::new(&cloud);
28//! let query = random_point();
29//! let mut touch_count = 0;
30//! let neighbour = tree.knn_advanced(
31//!     3,
32//!     &query,
33//!     CandidateContainer::BinaryHeap,
34//!     &Parameters {
35//!         epsilon: 0.0,
36//!         max_radius: 10.0,
37//!         allow_self_match: true,
38//!         sort_results: false,
39//!     },
40//!     Some(&mut touch_count) // statistics
41//! );
42//! ```
43
44// We allow (silence) this clippy lint here because it suggests using #[rustfmt::skip],
45// which is experimental. See: https://github.com/rust-lang/rust/issues/88591
46#![allow(clippy::deprecated_cfg_attr)]
47
48#[cfg(any(test, feature = "dummy_point"))]
49pub mod dummy_point;
50mod heap;
51mod infinite;
52mod internal_neighbour;
53mod internal_parameters;
54mod node;
55
56use internal_parameters::InternalParameters;
57use node::Node;
58use num_traits::{clamp_max, clamp_min, Bounded, Zero, Signed, FromPrimitive};
59use ordered_float::Float;
60pub use ordered_float::NotNan;
61use std::{collections::BinaryHeap, ops::AddAssign};
62use std::cmp::{Ordering, Ord};
63use std::fmt::Debug;
64use heap::CandidateHeap;
65use internal_neighbour::InternalNeighbour;
66
67/// The scalar type for points in the space to be searched
68pub trait Scalar: Float + AddAssign + FromPrimitive + std::fmt::Debug {}
69impl<T: Float + AddAssign + FromPrimitive + std::fmt::Debug> Scalar for T {}
70
71/// A point in the space to be searched
72pub trait Point<T: Scalar>: Default + Clone + Debug + Copy {
73    /// Sets the value for the `i`-th component, `i` must be within `0..DIM`.
74    fn set(&mut self, i: u32, value: NotNan<T>);
75    /// Gets the value for the `i`-th component, `i` must be within `0..DIM`.
76    fn get(&self, i: u32) -> NotNan<T>;
77    /// The number of dimension of the space this point lies in.
78    const DIM: u32;
79    /// Derived from `DIM`, do not reimplement, use the default!
80    const DIM_BIT_COUNT: u32 = 32 - Self::DIM.leading_zeros();
81    /// Derived from `DIM`, do not reimplement, use the default!
82    const DIM_MASK: u32 = (1 << Self::DIM_BIT_COUNT) - 1;
83    /// Derived from `DIM`, do not reimplement, use the default!
84    const MAX_NODE_COUNT: u32 = ((1u64 << (32 - Self::DIM_BIT_COUNT)) - 1) as u32;
85}
86
87/// Helper function to compute the square distance between two points given as slice
88#[inline]
89fn point_slice_dist2<T: Scalar, P: Point<T>>(lhs: &[NotNan<T>], rhs: &[NotNan<T>]) -> NotNan<T> {
90    let mut dist2 = NotNan::<T>::zero();
91    for index in 0..P::DIM {
92        let index = index as usize;
93        let diff = lhs[index] - rhs[index];
94        dist2 += diff * diff;
95    }
96    dist2
97}
98
99/// The index of a point in the original point cloud
100pub type Index = u32;
101
102/// A neighbour resulting from the search
103#[derive(Debug)]
104pub struct Neighbour<T: Scalar, P: Point<T>> {
105    /// the point itself
106    pub point: P,
107    /// the squared-distance to the point
108    pub dist2: NotNan<T>,
109    /// the index of the point in the original point cloud
110    pub index: Index,
111}
112
113impl<T: Scalar, P: Point<T>> PartialOrd for Neighbour<T,P> {
114    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
115        self.dist2.partial_cmp(&other.dist2)
116    }
117}
118
119impl<T: Scalar, P: Point<T>> Ord for Neighbour<T, P>
120where
121    NotNan<T>: Eq + Ord,
122{
123    fn cmp(&self, other: &Self) -> Ordering {
124        self.dist2.cmp(&other.dist2)
125    }
126}
127
128impl<T: Scalar, P: Point<T>> PartialEq for Neighbour<T,P> {
129    fn eq(&self, other: &Self) -> bool {
130        self.dist2 == other.dist2
131    }
132}
133
134impl<T: Scalar, P: Point<T>> Eq for Neighbour<T,P> { }
135
/// The type of container used to keep the candidates during a search.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum CandidateContainer {
    /// use a linear vector to keep candidates, good for small k
    Linear,
    /// use a binary heap to keep candidates, good for large k
    BinaryHeap,
}
144
145/// Advanced search parameters
146pub struct Parameters<T: Scalar> {
147    /// maximal ratio of error for approximate search, 0 for exact search; has no effect if the number of neighbours found is smaller than the number requested
148    pub epsilon: T,
149    /// maximum radius in which to search, can be used to prune search, is not affected by `epsilon`
150    pub max_radius: T,
151    /// allows the return of the same point as the query, if this point is in the point cloud
152    pub allow_self_match: bool,
153    /// sort points by distances, when `k` > 1
154    pub sort_results: bool,
155}
156
157impl<T: Scalar> Default for Parameters<T> {
158    fn default() -> Parameters<T> {
159        Parameters { 
160            epsilon: T::zero(), 
161            max_radius: T::infinity(),
162            allow_self_match: true,
163            sort_results: true,
164        }
165    }
166}
167
168/// A dense vector of search nodes, provides better memory performances than many small objects
169type Nodes<T, P> = Vec<Node<T, P>>;
170
171/// A KD-Tree to perform NN-search queries
172///
173/// This implementation is inspired of the variant `KDTreeUnbalancedPtInLeavesImplicitBoundsStackOpt` in libnabo C++.
174/// Contrary to the latter, it does not keep a reference to the point cloud but copies the point.
175/// It retains their index though.
176#[derive(Debug)]
177pub struct KDTree<T: Scalar, P: Point<T>, const K: usize> {
178    /// size of a bucket
179    bucket_size: u32,
180    /// Search nodes
181    nodes: Nodes<T, P>,
182    /// point data, size cloud.len() * P::DIM
183    points: Vec<NotNan<T>>,
184    /// indices in cloud , size cloud.len()
185    indices: Vec<Index>,
186}
187
188impl<T: Scalar + Signed, P: Point<T>, const K: usize> KDTree<T, P, K> {
189    /// Creates a new KD-Tree from a point cloud.
190    pub fn new(cloud: &[P]) -> Self {
191        KDTree::new_with_bucket_size(cloud, 8)
192    }
193    /// Creates a new KD-Tree from a point cloud.
194    ///
195    /// The `bucket_size` can be chosen freely, but must be at least 2.
196    pub fn new_with_bucket_size(cloud: &[P], bucket_size: u32) -> Self {
197        // validate input
198        if bucket_size < 2 {
199            panic!(
200                "Bucket size must be at least 2, but {} was passed",
201                bucket_size
202            );
203        }
204        if cloud.len() > u32::MAX as usize {
205            panic!(
206                "Point cloud is larger than maximum possible size {}",
207                u32::MAX
208            );
209        }
210        let estimated_node_count = (cloud.len() / (bucket_size as usize / 2)) as u32;
211        if estimated_node_count > P::MAX_NODE_COUNT {
212            panic!("Point cloud has a risk to have more nodes {} than the kd-tree allows {}. The kd-tree has {} bits for dimensions and {} bits for node indices", estimated_node_count, P::MAX_NODE_COUNT, P::DIM_BIT_COUNT, 32 - P::DIM_BIT_COUNT);
213        }
214
215        // build point vector and compute bounds
216        let mut build_points: Vec<_> = (0..cloud.len()).collect();
217
218        // create and populate tree
219        let mut tree = KDTree {
220            bucket_size,
221            nodes: Vec::with_capacity(estimated_node_count as usize),
222            points: Vec::with_capacity(cloud.len() * P::DIM as usize),
223            indices: Vec::with_capacity(cloud.len()),
224        };
225        tree.build_nodes(cloud, &mut build_points);
226        tree
227    }
228
229    /// Finds the `k` nearest neighbour of `query`, using reasonable default parameters.
230    ///
231    /// If there are less than `k` points in the point cloud, the returned vector will be smaller than `k`.
232    /// The default parameters are:
233    /// Exact search, no max. radius, allowing self matching, sorting results, and not collecting statistics.
234    /// If `k` <= 16, a linear vector is used to keep track of candidates, otherwise a binary heap is used.
235    pub fn knn(&self, k: u32, query: &P) -> Vec<Neighbour<T, P>> {
236        let candidate_container = if k <= 16 {
237            CandidateContainer::Linear
238        } else {
239            CandidateContainer::BinaryHeap
240        };
241        #[cfg_attr(rustfmt, rustfmt_skip)]
242        self.knn_advanced(
243            k, query,
244            candidate_container,
245            &Parameters::default(),
246            None,
247        )
248    }
249
250
251    /// Finds the `k` nearest neighbour of `query` with periodic boundary conditions, using reasonable
252    /// default parameters.
253    ///
254    /// If there are less than `k` points in the point cloud, the returned vector will be smaller than `k`.
255    /// The default parameters are:
256    /// Exact search, no max. radius, allowing self matching, sorting results, and not collecting statistics.
257    /// If `k` <= 16, a linear vector is used to keep track of candidates, otherwise a binary heap is used.
258    pub fn knn_periodic(&self, k: u32, query: &P, periodic: &[NotNan<T>; K]) -> Vec<Neighbour<T, P>> {
259        let candidate_container = if k <= 16 {
260            CandidateContainer::Linear
261        } else {
262            CandidateContainer::BinaryHeap
263        };
264
265        // First get real images
266        let mut real_image_knns: Vec<Neighbour<T, P>> = self.knn_advanced(
267            k, query,
268            candidate_container,
269            &Parameters::default(),
270            None,
271        );
272
273        // Find max dist2
274        let max_dist2 = real_image_knns.iter().max().unwrap().dist2.into_inner();
275
276        // Find closest dist2 to every side
277        let mut closest_side_dist2: [T; K] = [T::zero(); K];
278        for side in 0..K {
279
280            // Do a single index here. This is equal to distance to lower side
281            let query_component: NotNan<T> = query.get(side as u32);
282
283            // Get distance to upper half
284            let upper = periodic[side] - query_component;
285
286            // !negative includes zero
287            debug_assert!(!upper.is_negative()); 
288            debug_assert!(!query_component.is_negative());
289
290            // Choose lesser of two and then square
291            closest_side_dist2[side] = upper.min(query_component).powi(2);
292        }
293
294        // Find which images we need to check.
295        // Initialize vector with real image (which we will remove later)
296        let mut images_to_check = Vec::with_capacity(2_usize.pow(K as u32)-1);
297        for image in 1..2_usize.pow(K as u32) {
298            
299            // Closest image in the form of bool array
300            let closest_image = (0..K)
301                .map(|idx| ((image / 2_usize.pow(idx as u32)) % 2) == 1);
302
303            // Find distance to corresponding side, edge, vertex or other higher dimensional equivalent
304            let dist_to_side_edge_or_other: T = closest_image
305                .clone()
306                .enumerate()
307                .flat_map(|(side, flag)| if flag {
308                    
309                    // Get minimum of dist2 to lower and upper side
310                    Some(closest_side_dist2[side])
311                } else { None })
312                .fold(T::zero(), |acc, x| acc + x);
313
314            if dist_to_side_edge_or_other < max_dist2 {
315
316                let mut image_to_check = query.clone();
317                
318                for (idx, flag) in closest_image.enumerate() {
319
320                    // If moving image along this dimension
321                    if flag {
322                        // Do a single index here. This is equal to distance to lower side
323                        let query_component: NotNan<T> = query.get(idx as u32);
324                        // Single index here as well
325                        let periodic_component = periodic[idx];
326
327                        if query_component < periodic_component / T::from(2_u8).unwrap() {
328                            // Add if in lower half of box
329                            image_to_check.set(idx as u32, query_component + periodic_component)
330                        } else {
331                            // Subtract if in upper half of box
332                            image_to_check.set(idx as u32, query_component - periodic_component)
333                        }
334                        
335                    }
336                }
337
338                images_to_check.push(image_to_check);
339            }
340        }
341
342        // Then check all images
343        for image in &images_to_check {
344
345            // Append it to real images, we will clean up later.
346            real_image_knns.append(&mut self.knn_advanced(
347                k, image,
348                candidate_container,
349                &Parameters::default(),
350                None,
351            ))
352        }
353
354        // Perform cleanup
355        real_image_knns.sort();
356        real_image_knns.dedup();
357        real_image_knns.truncate(k as usize);
358
359        real_image_knns
360    }
361
362    /// Finds the `k` nearest neighbour of `query`, with user-provided parameters.
363    ///
364    /// If there are less than `k` points in the point cloud or in the ball around `query`
365    /// defined by `parameters.max_radius`, the returned vector will be smaller than `k`.
366    /// The parameters are:
367    /// * `candidate_container` which container to use to collect candidates,
368    /// * `parameters` the advanced search parameters,
369    /// * `touch_statistics`, if `Some(&mut u32)`, return the number of point touched in the provided `u32` reference.
370    pub fn knn_advanced(
371        &self,
372        k: u32,
373        query: &P,
374        candidate_container: CandidateContainer,
375        parameters: &Parameters<T>,
376        touch_statistics: Option<&mut u32>,
377    ) -> Vec<Neighbour<T, P>> {
378        #[cfg_attr(rustfmt, rustfmt_skip)]
379        (match candidate_container {
380            CandidateContainer::Linear => Self::knn_generic_heap::<Vec<InternalNeighbour<T>>>,
381            CandidateContainer::BinaryHeap => Self::knn_generic_heap::<BinaryHeap<InternalNeighbour<T>>>
382        })(
383            self,
384            k, query,
385            parameters, touch_statistics
386        )
387    }
388
389    fn knn_generic_heap<H: CandidateHeap<T>>(
390        &self,
391        k: u32,
392        query: &P,
393        parameters: &Parameters<T>,
394        touch_statistics: Option<&mut u32>,
395    ) -> Vec<Neighbour<T, P>> {
396        let query_as_vec: Vec<_> = (0..P::DIM).map(|i| query.get(i)).collect();
397        let Parameters {
398            epsilon,
399            max_radius,
400            allow_self_match,
401            sort_results,
402        } = *parameters;
403        let max_error = epsilon + T::one();
404        let max_error2 = NotNan::new(max_error * max_error).unwrap();
405        let max_radius2 = NotNan::new(max_radius * max_radius).unwrap();
406        #[cfg_attr(rustfmt, rustfmt_skip)]
407        self.knn_internal::<H>(
408            k, &query_as_vec,
409            &InternalParameters { max_error2, max_radius2, allow_self_match },
410            sort_results, touch_statistics,
411        )
412            .into_iter()
413            .map(|n| self.externalise_neighbour(n))
414            .collect()
415    }
416
417    fn knn_internal<H: CandidateHeap<T>>(
418        &self,
419        k: u32,
420        query: &[NotNan<T>],
421        internal_parameters: &InternalParameters<T>,
422        sort_results: bool,
423        touch_statistics: Option<&mut u32>,
424    ) -> Vec<InternalNeighbour<T>> {
425        // TODO Const generics: once available, remove `vec!` below. update: done but leaving note
426        let mut off = [NotNan::<T>::zero(); K];
427        let mut heap = H::new_with_k(k);
428        #[cfg_attr(rustfmt, rustfmt_skip)]
429        let leaf_touched_count = self.recurse_knn(
430            k, query,
431            0, NotNan::<T>::zero(),
432            &mut heap, &mut off,
433            internal_parameters,
434        );
435        if let Some(touch_statistics) = touch_statistics {
436            *touch_statistics = leaf_touched_count;
437        }
438        if sort_results {
439            heap.into_sorted_vec()
440        } else {
441            heap.into_vec()
442        }
443    }
444
445    #[allow(clippy::too_many_arguments)]
446    fn recurse_knn<H: CandidateHeap<T>>(
447        &self,
448        k: u32,
449        query: &[NotNan<T>],
450        node: usize,
451        rd: NotNan<T>,
452        heap: &mut H,
453        off: &mut [NotNan<T>],
454        internal_parameters: &InternalParameters<T>,
455    ) -> u32 {
456        self.nodes[node].dispatch_on_type(
457            heap,
458            |heap, split_dim, split_val, right_child| {
459                // split node, see whether we have to recurse
460                let mut rd = rd;
461                let split_dim = split_dim as usize;
462                let old_off = off[split_dim];
463                let new_off = query[split_dim] - split_val;
464                let left_child = node + 1;
465                let right_child = right_child as usize;
466                let InternalParameters {
467                    max_radius2,
468                    max_error2,
469                    ..
470                } = *internal_parameters;
471                if new_off > NotNan::<T>::zero() {
472                    #[cfg_attr(rustfmt, rustfmt_skip)]
473                    let mut leaf_visited_count = self.recurse_knn(
474                        k, query,
475                        right_child, rd,
476                        heap, off,
477                        internal_parameters,
478                    );
479                    rd += new_off * new_off - old_off * old_off;
480                    if rd <= max_radius2 && rd * max_error2 <= heap.furthest_dist2() {
481                        off[split_dim] = new_off;
482                        #[cfg_attr(rustfmt, rustfmt_skip)]
483                        let new_visits= self.recurse_knn(
484                            k, query,
485                            left_child, rd,
486                            heap, off,
487                            internal_parameters,
488                        );
489                        leaf_visited_count += new_visits;
490                        off[split_dim] = old_off;
491                    }
492                    leaf_visited_count
493                } else {
494                    #[cfg_attr(rustfmt, rustfmt_skip)]
495                    let mut leaf_visited_count = self.recurse_knn(
496                        k, query,
497                        left_child, rd,
498                        heap, off,
499                        internal_parameters,
500                    );
501                    rd += new_off * new_off - old_off * old_off;
502                    if rd <= max_radius2 && rd * max_error2 <= heap.furthest_dist2() {
503                        off[split_dim] = new_off;
504                        #[cfg_attr(rustfmt, rustfmt_skip)]
505                        let new_visits = self.recurse_knn(
506                            k, query,
507                            right_child, rd,
508                            heap, off,
509                            internal_parameters,
510                        );
511                        leaf_visited_count += new_visits;
512                        off[split_dim] = old_off;
513                    }
514                    leaf_visited_count
515                }
516            },
517            |heap, bucket_start_index, bucket_size| {
518                // leaf node, go through the buckets and check elements
519                let bucket_end_index = bucket_start_index + bucket_size;
520                for bucket_index in bucket_start_index..bucket_end_index {
521                    let point_index = (bucket_index * P::DIM) as usize;
522                    let point = &self.points[point_index..point_index + (P::DIM as usize)];
523                    let dist2 = point_slice_dist2::<T, P>(query, point);
524                    let epsilon = NotNan::new(T::epsilon()).unwrap();
525                    let InternalParameters {
526                        max_radius2,
527                        allow_self_match,
528                        ..
529                    } = *internal_parameters;
530                    if dist2 <= max_radius2 && (allow_self_match || (dist2 > epsilon)) {
531                        heap.add(dist2, bucket_index);
532                    }
533                }
534                bucket_size
535            },
536        )
537    }
538
539    fn build_nodes(&mut self, cloud: &[P], build_points: &mut [usize]) -> usize {
540        let count = build_points.len() as u32;
541        let pos = self.nodes.len();
542
543        // if remaining points fit in a single bucket, add a node and this bucket
544        if count <= self.bucket_size {
545            let bucket_start_index = self.indices.len() as u32;
546            self.points.reserve(build_points.len() * P::DIM as usize);
547            self.indices.reserve(build_points.len());
548            for point_index in build_points {
549                let point_index = *point_index;
550                self.indices.push(point_index as u32);
551                for i in 0..P::DIM {
552                    self.points.push(cloud[point_index].get(i));
553                }
554            }
555            self.nodes
556                .push(Node::new_leaf_node(bucket_start_index, count));
557            return pos;
558        }
559
560        // compute bounds
561        let (min_bounds, max_bounds) = Self::get_build_points_bounds(cloud, build_points);
562
563        // find the largest dimension of the box
564        let split_dim = Self::max_delta_index(&min_bounds, &max_bounds);
565        let split_dim_u = split_dim as usize;
566
567        // split along this dimension
568        let split_val = (max_bounds[split_dim_u] + min_bounds[split_dim_u]) * T::from(0.5).unwrap();
569        let range = max_bounds[split_dim_u] - min_bounds[split_dim_u];
570        let (left_points, right_points) = if range == T::from(0).unwrap() {
571            // degenerate data, split in half and iterate
572            build_points.split_at_mut(build_points.len() / 2)
573        } else {
574            // partition data around split_val on split_dim
575            partition::partition(build_points, |index| {
576                cloud[*index].get(split_dim) < split_val
577            })
578        };
579        debug_assert_ne!(left_points.len(), 0);
580        debug_assert_ne!(right_points.len(), 0);
581
582        // add this split
583        self.nodes.push(Node::new_split_node(split_dim, split_val));
584
585        // recurse
586        let left_child = self.build_nodes(cloud, left_points);
587        debug_assert_eq!(left_child, pos + 1);
588        let right_child = self.build_nodes(cloud, right_points);
589
590        // write right child index and return
591        self.nodes[pos].set_child_index(right_child as u32);
592        pos
593    }
594
595    fn get_build_points_bounds(
596        cloud: &[P],
597        build_points: &[usize],
598    ) -> (Vec<NotNan<T>>, Vec<NotNan<T>>) {
599        let mut min_bounds = vec![NotNan::<T>::max_value(); P::DIM as usize];
600        let mut max_bounds = vec![NotNan::<T>::min_value(); P::DIM as usize];
601        for p_index in build_points {
602            let p = &cloud[*p_index];
603            for index in 0..P::DIM {
604                let index_u = index as usize;
605                min_bounds[index_u] = clamp_max(p.get(index), min_bounds[index_u]);
606                max_bounds[index_u] = clamp_min(p.get(index), max_bounds[index_u]);
607            }
608        }
609        (min_bounds, max_bounds)
610    }
611
612    fn max_delta_index(lower_bound: &[NotNan<T>], upper_bound: &[NotNan<T>]) -> u32 {
613        lower_bound
614            .iter()
615            .zip(upper_bound.iter())
616            .enumerate()
617            .max_by_key(|(_, (l, u))| *u - *l)
618            .unwrap()
619            .0 as u32
620    }
621
622    fn externalise_neighbour(&self, neighbour: InternalNeighbour<T>) -> Neighbour<T, P> {
623        let mut point = P::default();
624        let base_index = neighbour.index * P::DIM;
625        for i in 0..P::DIM {
626            point.set(i, self.points[(base_index + i) as usize]);
627        }
628        Neighbour {
629            point,
630            dist2: neighbour.dist2,
631            index: self.indices[neighbour.index as usize],
632        }
633    }
634}
635
636#[cfg(test)]
637mod tests {
638    use crate::*;
639    use dummy_point::{random_point, random_point_cloud, P2};
640    use float_cmp::approx_eq;
641
642    // helpers to create cloud
643    fn cloud3() -> Vec<P2> {
644        vec![P2::new(0., 0.), P2::new(-1., 3.), P2::new(2., -4.)]
645    }
646
647    // helper to compute the square distance between two points
648    fn point_dist2<T: Scalar, P: Point<T>>(lhs: &P, rhs: &P) -> NotNan<T> {
649        let mut dist2 = NotNan::<T>::zero();
650        for index in 0..P::DIM {
651            let diff = lhs.get(index) - rhs.get(index);
652            dist2 += diff * diff;
653        }
654        dist2
655    }
656
657    // brute force search implementations
658    fn brute_force_1nn(cloud: &[P2], query: &P2) -> Neighbour<f32, P2> {
659        let mut best_dist2 = f32::infinity();
660        let mut best_index = 0;
661        for (index, point) in cloud.iter().enumerate() {
662            let dist2 = point_dist2(point, query).into_inner();
663            if dist2 < best_dist2 {
664                best_dist2 = dist2;
665                best_index = index;
666            }
667        }
668        Neighbour {
669            point: cloud[best_index],
670            dist2: NotNan::new(best_dist2).unwrap(),
671            index: best_index as u32,
672        }
673    }
674
675    fn brute_force_knn<H: CandidateHeap<f32>>(
676        cloud: &[P2],
677        query: &P2,
678        k: u32,
679    ) -> Vec<Neighbour<f32, P2>> {
680        let mut h = H::new_with_k(k);
681        for (index, point) in cloud.iter().enumerate() {
682            let dist2 = point_dist2(point, query);
683            h.add(dist2, index as u32);
684        }
685        h.into_sorted_vec()
686            .into_iter()
687            .map(|n| {
688                let index = n.index as usize;
689                Neighbour {
690                    point: cloud[index],
691                    dist2: n.dist2,
692                    index: n.index,
693                }
694            })
695            .collect()
696    }
697
698    // tests themselves
699
700    #[test]
701    fn get_build_points_bounds() {
702        const K: usize = 2;
703        let cloud = cloud3();
704        let indices = vec![0, 1, 2];
705        let bounds = KDTree::<_, _, K>::get_build_points_bounds(&cloud, &indices);
706        assert_eq!(bounds.0, vec![-1., -4.]);
707        assert_eq!(bounds.1, vec![2., 3.]);
708    }
709
710    #[test]
711    fn max_delta_index() {
712        const K: usize = 2;
713        let b = |x: f32, y: f32| {
714            [
715                NotNan::<f32>::new(x).unwrap(),
716                NotNan::<f32>::new(y).unwrap(),
717            ]
718        };
719        assert_eq!(
720            KDTree::<f32, P2, K>::max_delta_index(&b(0., 0.), &b(0., 1.)),
721            1
722        );
723        assert_eq!(
724            KDTree::<f32, P2, K>::max_delta_index(&b(0., 0.), &b(-1., 1.)),
725            1
726        );
727        assert_eq!(
728            KDTree::<f32, P2, K>::max_delta_index(&b(0., 0.), &b(-1., -2.)),
729            0
730        );
731    }
732
733    #[test]
734    fn new_tree() {
735        const K: usize = 2;
736        let cloud = cloud3();
737        let tree = KDTree::<_,_,K>::new_with_bucket_size(&cloud, 2);
738        dbg!(tree);
739    }
740
741    #[test]
742    fn query_1nn_allow_self() {
743        const K: usize = 2;
744        let mut touch_sum = 0;
745        const PASS_COUNT: u32 = 20;
746        const QUERY_COUNT: u32 = 100;
747        const CLOUD_SIZE: u32 = 1000;
748        const PARAMETERS: Parameters<f32> = Parameters {
749            epsilon: 0.0,
750            max_radius: f32::INFINITY,
751            allow_self_match: true,
752            sort_results: true,
753        };
754        for _ in 0..PASS_COUNT {
755            let cloud = random_point_cloud(CLOUD_SIZE);
756            let tree = KDTree::<_,_,K>::new(&cloud);
757            for _ in 0..QUERY_COUNT {
758                let query = random_point();
759                let mut touch_statistics = 0;
760
761                // linear search
762                let nns_lin = tree.knn_advanced(
763                    1,
764                    &query,
765                    CandidateContainer::Linear,
766                    &PARAMETERS,
767                    Some(&mut touch_statistics),
768                );
769                assert_eq!(nns_lin.len(), 1);
770                let nn_lin = &nns_lin[0];
771                assert_eq!(nn_lin.point, cloud[nn_lin.index as usize]);
772                touch_sum += touch_statistics;
773                // binary
774                let nns_bin =
775                    tree.knn_advanced(1, &query, CandidateContainer::BinaryHeap, &PARAMETERS, None);
776                assert_eq!(nns_bin.len(), 1);
777                let nn_bin = &nns_bin[0];
778                assert_eq!(nn_bin.point, cloud[nn_bin.index as usize]);
779                // brute force
780                let nn_bf = brute_force_1nn(&cloud, &query);
781                assert_eq!(nn_bf.point, cloud[nn_bf.index as usize]);
782                // assertion
783                assert_eq!(
784                    nn_bin.index, nn_bf.index,
785                    "KDTree binary heap: mismatch indexes\nquery: {}\npoint {}, {}\nvs bf {}, {}",
786                    query, nn_bin.dist2, nn_bin.point, nn_bf.dist2, nn_bf.point
787                );
788                assert_eq!(nn_lin.index, nn_bf.index, "\nKDTree linear heap: mismatch indexes\nquery: {}\npoint {}, {}\nvs bf {}, {}\n", query, nn_lin.dist2, nn_lin.point, nn_bf.dist2, nn_bf.point);
789                assert!(approx_eq!(f32, *nn_lin.dist2, *nn_bf.dist2, ulps = 2));
790                assert!(approx_eq!(f32, *nn_bin.dist2, *nn_bf.dist2, ulps = 2));
791            }
792        }
793        let touch_pct = (touch_sum * 100) as f32 / (PASS_COUNT * QUERY_COUNT * CLOUD_SIZE) as f32;
794        println!("Average tree point touched: {} %", touch_pct);
795    }
796
797    #[test]
798    fn query_knn_allow_self() {
799        const K: usize = 2;
800        const QUERY_COUNT: u32 = 100;
801        const CLOUD_SIZE: u32 = 1000;
802        const PARAMETERS: Parameters<f32> = Parameters {
803            epsilon: 0.0,
804            max_radius: f32::INFINITY,
805            allow_self_match: true,
806            sort_results: true,
807        };
808        let cloud = random_point_cloud(CLOUD_SIZE);
809        let tree = KDTree::<_,_,K>::new(&cloud);
810        for k in [1, 2, 3, 5, 7, 13] {
811            for _ in 0..QUERY_COUNT {
812                let query = random_point();
813                // brute force
814                let nns_bf_lin = brute_force_knn::<Vec<InternalNeighbour<f32>>>(&cloud, &query, k);
815                assert_eq!(nns_bf_lin.len(), k as usize);
816                let nns_bf_bin =
817                    brute_force_knn::<BinaryHeap<InternalNeighbour<f32>>>(&cloud, &query, k);
818                assert_eq!(nns_bf_bin.len(), k as usize);
819                // kd-tree
820                #[cfg_attr(rustfmt, rustfmt_skip)]
821                let nns_bin = tree.knn_advanced(
822                    k, &query,
823                    CandidateContainer::BinaryHeap,
824                    &PARAMETERS,
825                    None,
826                );
827                assert_eq!(nns_bin.len(), k as usize);
828                #[cfg_attr(rustfmt, rustfmt_skip)]
829                let nns_lin = tree.knn_advanced(
830                    k, &query,
831                    CandidateContainer::Linear,
832                    &PARAMETERS,
833                    None,
834                );
835                assert_eq!(nns_lin.len(), k as usize);
836                // assertion
837                for i in 0..k as usize {
838                    // get neighbour
839                    let nn_bf_lin = &nns_bf_lin[i];
840                    let nn_bf_bin = &nns_bf_bin[i];
841                    let nn_lin = &nns_lin[i];
842                    let nn_bin = &nns_bin[i];
843                    // ensure their point data are consistent with the cloud
844                    assert_eq!(nn_bf_lin.point, cloud[nn_bf_lin.index as usize]);
845                    assert_eq!(nn_bf_bin.point, cloud[nn_bf_bin.index as usize]);
846                    assert_eq!(nn_lin.point, cloud[nn_lin.index as usize]);
847                    assert_eq!(nn_bin.point, cloud[nn_bin.index as usize]);
848                    // ensure their indices are consistent
849                    assert_eq!(nn_bf_bin.index, nn_bf_lin.index, "BF binary heap: mismatch indexes at {} on {}\nquery: {}\n   bf bin {}, {}\nvs bf lin {}, {}\n", i, k, query, nn_bf_bin.dist2, nn_bf_bin.point, nn_bf_lin.dist2, nn_bf_lin.point);
850                    assert_eq!(nn_lin.index, nn_bf_lin.index, "\nKDTree linear heap: mismatch indexes at {} on {}\nquery: {}\npoint {}, {}\nvs bf {}, {}\n", i, k, query, nn_lin.dist2, nn_lin.point, nn_bf_lin.dist2, nn_bf_lin.point);
851                    assert_eq!(nn_bin.index, nn_bf_lin.index, "\nKDTree binary heap: mismatch indexes {} on {}\nquery: {}\npoint {}, {}\nvs bf {}, {}\n", i, k, query, nn_bin.dist2, nn_bin.point, nn_bf_lin.dist2, nn_bf_lin.point);
852                    // ensure their dist2 are consistent
853                    assert!(approx_eq!(
854                        f32,
855                        *nn_bf_bin.dist2,
856                        *nn_bf_lin.dist2,
857                        ulps = 2
858                    ));
859                    assert!(approx_eq!(f32, *nn_lin.dist2, *nn_bf_lin.dist2, ulps = 2));
860                    assert!(approx_eq!(f32, *nn_bin.dist2, *nn_bf_lin.dist2, ulps = 2));
861                }
862            }
863        }
864    }
865
866    #[test]
867    fn small_clouds_can_lead_to_neighbours() {
868        const K: usize = 2;
869        let cloud = vec![P2::new(0.0, 0.0), P2::new(1.0, 0.0)];
870        let tree = KDTree::<_,_,K>::new(&cloud);
871        let query = P2::new(0.5, 0.0);
872        for _ in [CandidateContainer::Linear, CandidateContainer::BinaryHeap] {
873            let nns = tree.knn(3, &query);
874            assert_eq!(nns.len(), 2);
875        }
876    }
877
878    #[test]
879    fn max_radius_can_lead_to_neighbours() {
880        const K: usize = 2;
881        let cloud = vec![P2::new(0.0, 0.0), P2::new(1.0, 0.0)];
882        let tree = KDTree::<_,_,K>::new(&cloud);
883        let query = P2::new(0.1, 0.0);
884        let parameters = Parameters {
885            epsilon: 0.0,
886            max_radius: 0.5,
887            allow_self_match: false,
888            sort_results: false,
889        };
890        for container in [CandidateContainer::Linear, CandidateContainer::BinaryHeap] {
891            let nns = tree.knn_advanced(2, &query, container, &parameters, None);
892            assert_eq!(nns.len(), 1);
893        }
894    }
895}