kn0sys_nn/lib.rs
1//! `linfa-nn` provides Rust implementations of common spatial indexing algorithms, as well as a
2//! trait-based interface for performing nearest-neighbour and range queries using these
3//! algorithms.
4//!
5//! ## The big picture
6//!
7//! `linfa-nn` is a crate in the `linfa` ecosystem, a wider effort to
8//! bootstrap a toolkit for classical Machine Learning implemented in pure Rust,
9//! kin in spirit to Python's `scikit-learn`.
10//!
11//! You can find a roadmap (and a selection of good first issues)
12//! [here](https://github.com/LukeMathWalker/linfa/issues) - contributors are more than welcome!
13//!
14//! ## Current state
15//!
16//! Right now `linfa-nn` provides the following algorithms:
17//! * [Linear Scan](LinearSearch)
18//! * [KD Tree](KdTree)
19//! * [Ball Tree](BallTree)
20//!
//! The [`CommonNearestNeighbour`] enum should be used to dispatch
//! between all of the above algorithms flexibly.
23
24use distance::Distance;
25use linfa::Float;
26use ndarray::{ArrayBase, ArrayView1, Data, Ix2};
27#[cfg(feature = "serde")]
28use serde_crate::{Deserialize, Serialize};
29use thiserror::Error;
30
31mod balltree;
32mod heap_elem;
33mod kdtree;
34mod linear;
35
36pub mod distance;
37
38pub use crate::{balltree::*, kdtree::*, linear::*};
39
40pub(crate) type Point<'a, F> = ArrayView1<'a, F>;
41pub(crate) type NearestNeighbourBox<'a, F> = Box<dyn 'a + Send + Sync + NearestNeighbourIndex<F>>;
42
43/// Error returned when building nearest neighbour indices
44#[derive(Error, Debug)]
45pub enum BuildError {
46 #[error("points have dimension of 0")]
47 ZeroDimension,
48 #[error("leaf size is 0")]
49 EmptyLeaf,
50}
51
52/// Error returned when performing spatial queries on nearest neighbour indices
53#[derive(Error, Debug)]
54pub enum NnError {
55 #[error("dimensions of query point and stored points are different")]
56 WrongDimension,
57}
58
59/// Nearest neighbour algorithm builds a spatial index structure out of a batch of points. The
60/// distance between points is calculated using a provided distance function. The index implements
61/// the [`NearestNeighbourIndex`] trait and allows for efficient
62/// computing of nearest neighbour and range queries.
63pub trait NearestNeighbour: std::fmt::Debug + Send + Sync + Unpin {
64 /// Builds a spatial index using a MxN two-dimensional array representing M points with N
65 /// dimensions. Also takes `leaf_size`, which specifies the number of elements in the leaf
66 /// nodes of tree-like index structures.
67 ///
68 /// Returns an error if the points have dimensionality of 0 or if the leaf size is 0. If any
69 /// value in the batch is NaN or infinite, the behaviour is unspecified.
70 fn batch_with_leaf_size<'a, F: Float + ndarray::ScalarOperand, DT: Data<Elem = F>, D: 'a + Distance<F>>(
71 &self,
72 batch: &'a ArrayBase<DT, Ix2>,
73 leaf_size: usize,
74 dist_fn: D,
75 ) -> Result<NearestNeighbourBox<'a, F>, BuildError>;
76
77 /// Builds a spatial index using a default leaf size. See `from_batch_with_leaf_size` for more
78 /// information.
79 fn batch<'a, F: Float + ndarray::ScalarOperand, DT: Data<Elem = F>, D: 'a + Distance<F>>(
80 &self,
81 batch: &'a ArrayBase<DT, Ix2>,
82 dist_fn: D,
83 ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
84 self.batch_with_leaf_size(batch, 2usize.pow(4), dist_fn)
85 }
86}
87
88/// A spatial index structure over a set of points, created by `NearestNeighbour`. Allows efficient
89/// computation of nearest neighbour and range queries over the set of points. Individual points
90/// are represented as one-dimensional array views.
91pub trait NearestNeighbourIndex<F: Float>: Send + Sync + Unpin {
92 /// Returns the `k` points in the index that are the closest to the provided point, along with
93 /// their positions in the original dataset. Points are returned in ascending order of the
94 /// distance away from the provided points, and less than `k` points will be returned if the
95 /// index contains fewer than `k`.
96 ///
97 /// Returns an error if the provided point has different dimensionality than the index's
98 /// points.
99 fn k_nearest(
100 &self,
101 point: Point<'_, F>,
102 k: usize,
103 ) -> Result<Vec<(Point<F>, usize)>, NnError>;
104
105 /// Returns all the points in the index that are within the specified distance to the provided
106 /// point, along with their positions in the original dataset. The points are not guaranteed to
107 /// be in any order, though many algorithms return the points in order of distance.
108 ///
109 /// Returns an error if the provided point has different dimensionality than the index's
110 /// points.
111 fn within_range(
112 &self,
113 point: Point<'_, F>,
114 range: F,
115 ) -> Result<Vec<(Point<F>, usize)>, NnError>;
116}
117
/// Value-level selector for the [`NearestNeighbour`] implementations provided by this crate.
///
/// Each variant dispatches to the matching algorithm. Prefer this enum over using types like
/// `LinearSearch` and `KdTree` directly.
///
/// ## Example
///
/// ```rust
/// use rand_xoshiro::Xoshiro256Plus;
/// use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
/// use ndarray::{Array1, Array2};
/// use kn0sys_nn::{distance::*, CommonNearestNeighbour, NearestNeighbour};
///
/// // Seeded RNG so the generated points are reproducible
/// let mut rng = Xoshiro256Plus::seed_from_u64(40);
/// let n_features = 3;
/// let distr = Uniform::new(-500., 500.).unwrap();
/// // Random points from which the index is built
/// let points = Array2::random_using((5000, n_features), distr, &mut rng);
///
/// // K-D tree index using Euclidean distance
/// let nn = CommonNearestNeighbour::KdTree.batch(&points, L2Dist).unwrap();
///
/// let pt = Array1::random_using(n_features, distr, &mut rng);
/// // The 10 points of the index nearest to `pt`
/// let nearest = nn.k_nearest(pt.view(), 10).unwrap();
/// // Every point of the index within 100 units of `pt`
/// let range = nn.within_range(pt.view(), 100.0).unwrap();
/// ```
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(crate = "serde_crate"))]
pub enum CommonNearestNeighbour {
    /// Linear search
    LinearSearch,
    /// KD Tree
    KdTree,
    /// Ball Tree
    BallTree,
}
161
162impl NearestNeighbour for CommonNearestNeighbour {
163 fn batch_with_leaf_size<'a, F: Float + ndarray::ScalarOperand, DT: Data<Elem = F>, D: 'a + Distance<F>>(
164 &self,
165 batch: &'a ArrayBase<DT, Ix2>,
166 leaf_size: usize,
167 dist_fn: D,
168 ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
169 match self {
170 Self::LinearSearch => LinearSearch.batch_with_leaf_size(batch, leaf_size, dist_fn),
171 Self::KdTree => KdTree.batch_with_leaf_size(batch, leaf_size, dist_fn),
172 Self::BallTree => BallTree.batch_with_leaf_size(batch, leaf_size, dist_fn),
173 }
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the listed crate types are `Send`, `Sync`, `Sized` and `Unpin`.
    #[test]
    fn autotraits() {
        // The body is empty on purpose: instantiating the function is the whole assertion.
        fn assert_autotraits<T>()
        where
            T: Send + Sync + Sized + Unpin,
        {
        }

        assert_autotraits::<CommonNearestNeighbour>();
        assert_autotraits::<NearestNeighbourBox<'static, f64>>();
        assert_autotraits::<BuildError>();
        assert_autotraits::<NnError>();
    }
}