linfa_nn/lib.rs
1//! `linfa-nn` provides Rust implementations of common spatial indexing algorithms, as well as a
2//! trait-based interface for performing nearest-neighbour and range queries using these
3//! algorithms.
4//!
5//! ## The big picture
6//!
7//! `linfa-nn` is a crate in the `linfa` ecosystem, a wider effort to
8//! bootstrap a toolkit for classical Machine Learning implemented in pure Rust,
9//! kin in spirit to Python's `scikit-learn`.
10//!
11//! You can find a roadmap (and a selection of good first issues)
12//! [here](https://github.com/LukeMathWalker/linfa/issues) - contributors are more than welcome!
13//!
14//! ## Current state
15//!
16//! Right now `linfa-nn` provides the following algorithms:
17//! * [Linear Scan](LinearSearch)
18//! * [KD Tree](KdTree)
19//! * [Ball Tree](BallTree)
20//!
//! The [`CommonNearestNeighbour`] enum should be used to dispatch
//! between all of the above algorithms flexibly.
23
24use distance::Distance;
25use linfa::Float;
26use ndarray::{ArrayBase, ArrayView1, Data, Ix2};
27#[cfg(feature = "serde")]
28use serde_crate::{Deserialize, Serialize};
29use thiserror::Error;
30
31mod balltree;
32mod heap_elem;
33mod kdtree;
34mod linear;
35
36pub mod distance;
37
38pub use crate::{balltree::*, kdtree::*, linear::*};
39
40pub(crate) type Point<'a, F> = ArrayView1<'a, F>;
41pub(crate) type NearestNeighbourBox<'a, F> = Box<dyn 'a + Send + Sync + NearestNeighbourIndex<F>>;
42
43/// Error returned when building nearest neighbour indices
44#[derive(Error, Debug)]
45pub enum BuildError {
46 #[error("points have dimension of 0")]
47 ZeroDimension,
48 #[error("leaf size is 0")]
49 EmptyLeaf,
50}
51
52/// Error returned when performing spatial queries on nearest neighbour indices
53#[derive(Error, Debug)]
54pub enum NnError {
55 #[error("dimensions of query point and stored points are different")]
56 WrongDimension,
57}
58
59/// Nearest neighbour algorithm builds a spatial index structure out of a batch of points. The
60/// distance between points is calculated using a provided distance function. The index implements
61/// the [`NearestNeighbourIndex`] trait and allows for efficient
62/// computing of nearest neighbour and range queries.
63pub trait NearestNeighbour: std::fmt::Debug + Send + Sync + Unpin {
64 /// Builds a spatial index using a MxN two-dimensional array representing M points with N
65 /// dimensions. Also takes `leaf_size`, which specifies the number of elements in the leaf
66 /// nodes of tree-like index structures.
67 ///
68 /// Returns an error if the points have dimensionality of 0 or if the leaf size is 0. If any
69 /// value in the batch is NaN or infinite, the behaviour is unspecified.
70 #[allow(clippy::wrong_self_convention)]
71 fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
72 &self,
73 batch: &'a ArrayBase<DT, Ix2>,
74 leaf_size: usize,
75 dist_fn: D,
76 ) -> Result<NearestNeighbourBox<'a, F>, BuildError>;
77
78 /// Builds a spatial index using a default leaf size. See `from_batch_with_leaf_size` for more
79 /// information.
80 #[allow(clippy::wrong_self_convention)]
81 fn from_batch<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
82 &self,
83 batch: &'a ArrayBase<DT, Ix2>,
84 dist_fn: D,
85 ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
86 self.from_batch_with_leaf_size(batch, 2usize.pow(4), dist_fn)
87 }
88}
89
90/// A spatial index structure over a set of points, created by `NearestNeighbour`. Allows efficient
91/// computation of nearest neighbour and range queries over the set of points. Individual points
92/// are represented as one-dimensional array views.
93pub trait NearestNeighbourIndex<F: Float>: Send + Sync + Unpin {
94 /// Returns the `k` points in the index that are the closest to the provided point, along with
95 /// their positions in the original dataset. Points are returned in ascending order of the
96 /// distance away from the provided points, and less than `k` points will be returned if the
97 /// index contains fewer than `k`.
98 ///
99 /// Returns an error if the provided point has different dimensionality than the index's
100 /// points.
101 fn k_nearest(
102 &self,
103 point: Point<'_, F>,
104 k: usize,
105 ) -> Result<Vec<(Point<'_, F>, usize)>, NnError>;
106
107 /// Returns all the points in the index that are within the specified distance to the provided
108 /// point, along with their positions in the original dataset. The points are not guaranteed to
109 /// be in any order, though many algorithms return the points in order of distance.
110 ///
111 /// Returns an error if the provided point has different dimensionality than the index's
112 /// points.
113 fn within_range(
114 &self,
115 point: Point<'_, F>,
116 range: F,
117 ) -> Result<Vec<(Point<'_, F>, usize)>, NnError>;
118}
119
/// Value-level selector for the crate's [`NearestNeighbour`] implementations.
///
/// Each variant dispatches to the matching algorithm. This enum should be used instead of using
/// types like `LinearSearch` and `KdTree` directly.
///
/// ## Example
///
/// ```rust
/// use rand_xoshiro::Xoshiro256Plus;
/// use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
/// use ndarray::{Array1, Array2};
/// use linfa_nn::{distance::*, CommonNearestNeighbour, NearestNeighbour};
///
/// // Use seedable RNG for generating points
/// let mut rng = Xoshiro256Plus::seed_from_u64(40);
/// let n_features = 3;
/// let distr = Uniform::new(-500., 500.);
/// // Randomly generate points for building the index
/// let points = Array2::random_using((5000, n_features), distr, &mut rng);
///
/// // Build a K-D tree with Euclidean distance as the distance function
/// let nn = CommonNearestNeighbour::KdTree.from_batch(&points, L2Dist).unwrap();
///
/// let pt = Array1::random_using(n_features, distr, &mut rng);
/// // Compute the 10 nearest points to `pt` in the index
/// let nearest = nn.k_nearest(pt.view(), 10).unwrap();
/// // Compute all points within 100 units of `pt`
/// let range = nn.within_range(pt.view(), 100.0).unwrap();
/// ```
#[non_exhaustive]
#[derive(Clone, Debug, Eq, PartialEq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(crate = "serde_crate"))]
pub enum CommonNearestNeighbour {
    /// Linear search, see [`LinearSearch`]
    LinearSearch,
    /// KD Tree, see [`KdTree`]
    KdTree,
    /// Ball Tree, see [`BallTree`]
    BallTree,
}
163
164impl NearestNeighbour for CommonNearestNeighbour {
165 fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
166 &self,
167 batch: &'a ArrayBase<DT, Ix2>,
168 leaf_size: usize,
169 dist_fn: D,
170 ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
171 match self {
172 Self::LinearSearch => LinearSearch.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
173 Self::KdTree => KdTree.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
174 Self::BallTree => BallTree.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
175 }
176 }
177}
178
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the crate's core types are `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        // Instantiating this helper only type-checks the bounds; it does nothing at runtime.
        fn assert_autotraits<T>()
        where
            T: Send + Sync + Sized + Unpin,
        {
        }

        assert_autotraits::<CommonNearestNeighbour>();
        assert_autotraits::<NearestNeighbourBox<'static, f64>>();
        assert_autotraits::<BuildError>();
        assert_autotraits::<NnError>();
    }
}