1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188
//! `linfa-nn` provides Rust implementations of common spatial indexing algorithms, as well as a
//! trait-based interface for performing nearest-neighbour and range queries using these
//! algorithms.
//!
//! ## The big picture
//!
//! `linfa-nn` is a crate in the `linfa` ecosystem, a wider effort to
//! bootstrap a toolkit for classical Machine Learning implemented in pure Rust,
//! kin in spirit to Python's `scikit-learn`.
//!
//! You can find a roadmap (and a selection of good first issues)
//! [here](https://github.com/rust-ml/linfa/issues) - contributors are more than welcome!
//!
//! ## Current state
//!
//! Right now `linfa-nn` provides the following algorithms:
//! * [Linear Scan](struct.LinearSearch.html)
//! * [KD Tree](struct.KdTree.html)
//! * [Ball Tree](struct.BallTree.html)
//!
//! The [`CommonNearestNeighbour`](enum.CommonNearestNeighbour.html) enum should be used to dispatch
//! between all of the above algorithms flexibly.
use distance::Distance;
use linfa::Float;
use ndarray::{ArrayBase, ArrayView1, Data, Ix2};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use thiserror::Error;
mod balltree;
mod heap_elem;
mod kdtree;
mod linear;
pub mod distance;
pub use crate::{balltree::*, kdtree::*, linear::*};
/// Crate-internal alias for a single point: a borrowed one-dimensional `ndarray` view
/// (`ArrayView1`) over elements of type `F`, valid for the lifetime `'a`.
pub(crate) type Point<'a, F> = ArrayView1<'a, F>;
/// Error returned when building nearest neighbour indices
#[derive(Error, Debug)]
pub enum BuildError {
#[error("points have dimension of 0")]
ZeroDimension,
#[error("leaf size is 0")]
EmptyLeaf,
}
/// Error returned when performing spatial queries on nearest neighbour indices
#[derive(Error, Debug)]
pub enum NnError {
#[error("dimensions of query point and stored points are different")]
WrongDimension,
}
/// A nearest neighbour algorithm, which turns a batch of points into a spatial index structure.
///
/// Distances between points are computed with a caller-supplied distance function. The
/// resulting index implements the [`NearestNeighbourIndex`](trait.NearestNeighbourIndex.html)
/// trait, through which nearest neighbour and range queries can be computed efficiently.
pub trait NearestNeighbour: std::fmt::Debug + Send + Sync + Unpin {
    /// Builds a spatial index from an MxN two-dimensional array, interpreted as M points with
    /// N dimensions each. `leaf_size` specifies the number of elements in the leaf nodes of
    /// tree-like index structures.
    ///
    /// Returns an error if the points have dimensionality of 0 or if the leaf size is 0. The
    /// behaviour is unspecified if any value in the batch is NaN or infinite.
    fn from_batch_with_leaf_size<'a, F, DT, D>(
        &self,
        batch: &'a ArrayBase<DT, Ix2>,
        leaf_size: usize,
        dist_fn: D,
    ) -> Result<Box<dyn 'a + NearestNeighbourIndex<F>>, BuildError>
    where
        F: Float,
        DT: Data<Elem = F>,
        D: 'a + Distance<F>;

    /// Builds a spatial index with the default leaf size by delegating to
    /// `from_batch_with_leaf_size`; see that method for details on arguments and errors.
    fn from_batch<'a, F, DT, D>(
        &self,
        batch: &'a ArrayBase<DT, Ix2>,
        dist_fn: D,
    ) -> Result<Box<dyn 'a + NearestNeighbourIndex<F>>, BuildError>
    where
        F: Float,
        DT: Data<Elem = F>,
        D: 'a + Distance<F>,
    {
        // Default number of elements per leaf: 2^4.
        const DEFAULT_LEAF_SIZE: usize = 1 << 4;
        self.from_batch_with_leaf_size(batch, DEFAULT_LEAF_SIZE, dist_fn)
    }
}
/// A spatial index over a set of points, as produced by a `NearestNeighbour` implementation.
///
/// The index supports efficient nearest neighbour and range queries over its points. Every
/// individual point is represented as a one-dimensional array view.
pub trait NearestNeighbourIndex<F>: Send + Sync + Unpin
where
    F: Float,
{
    /// Finds the `k` points in the index closest to `point`, paired with their positions in
    /// the original dataset.
    ///
    /// Results are sorted by ascending distance from `point`. If the index holds fewer than
    /// `k` points, fewer than `k` results are returned.
    ///
    /// Returns an error if `point` has a different dimensionality than the points stored in
    /// the index.
    fn k_nearest<'b>(
        &self,
        point: Point<'b, F>,
        k: usize,
    ) -> Result<Vec<(Point<'_, F>, usize)>, NnError>;

    /// Finds every point in the index lying within distance `range` of `point`, paired with
    /// its position in the original dataset.
    ///
    /// No ordering of the results is guaranteed, although many algorithms return them in
    /// order of distance.
    ///
    /// Returns an error if `point` has a different dimensionality than the points stored in
    /// the index.
    fn within_range<'b>(
        &self,
        point: Point<'b, F>,
        range: F,
    ) -> Result<Vec<(Point<'_, F>, usize)>, NnError>;
}
/// Value-level selector for the crate's [`NearestNeighbour`](trait.NearestNeighbour.html)
/// implementations.
///
/// Each variant dispatches to one concrete algorithm. Prefer this enum over using types like
/// `LinearSearch` and `KdTree` directly.
///
/// ## Example
///
/// ```rust
/// use rand_xoshiro::Xoshiro256Plus;
/// use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
/// use ndarray::{Array1, Array2};
/// use linfa_nn::{distance::*, CommonNearestNeighbour, NearestNeighbour};
///
/// // Use seedable RNG for generating points
/// let mut rng = Xoshiro256Plus::seed_from_u64(40);
/// let n_features = 3;
/// let distr = Uniform::new(-500., 500.);
/// // Randomly generate points for building the index
/// let points = Array2::random_using((5000, n_features), distr, &mut rng);
///
/// // Build a K-D tree with Euclidean distance as the distance function
/// let nn = CommonNearestNeighbour::KdTree.from_batch(&points, L2Dist).unwrap();
///
/// let pt = Array1::random_using(n_features, distr, &mut rng);
/// // Compute the 10 nearest points to `pt` in the index
/// let nearest = nn.k_nearest(pt.view(), 10).unwrap();
/// // Compute all points within 100 units of `pt`
/// let range = nn.within_range(pt.view(), 100.0).unwrap();
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(crate = "serde_crate"))]
pub enum CommonNearestNeighbour {
    /// Linear search; selects the `LinearSearch` algorithm.
    LinearSearch,
    /// KD Tree; selects the `KdTree` algorithm.
    KdTree,
    /// Ball Tree; selects the `BallTree` algorithm.
    BallTree,
}
impl NearestNeighbour for CommonNearestNeighbour {
fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
&self,
batch: &'a ArrayBase<DT, Ix2>,
leaf_size: usize,
dist_fn: D,
) -> Result<Box<dyn 'a + NearestNeighbourIndex<F>>, BuildError> {
match self {
Self::LinearSearch => LinearSearch.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
Self::KdTree => KdTree.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
Self::BallTree => BallTree.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the crate's public types are `Send`, `Sync`, `Sized` and
    /// `Unpin`; the test passes as soon as it compiles.
    #[test]
    fn autotraits() {
        fn assert_autotraits<T>()
        where
            T: Send + Sync + Sized + Unpin,
        {
        }

        assert_autotraits::<CommonNearestNeighbour>();
        assert_autotraits::<Box<dyn NearestNeighbourIndex<f64>>>();
        assert_autotraits::<BuildError>();
        assert_autotraits::<NnError>();
    }
}