1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
//! `linfa-nn` provides Rust implementations of common spatial indexing algorithms, as well as a
//! trait-based interface for performing nearest-neighbour and range queries using these
//! algorithms.
//!
//! ## The big picture
//!
//! `linfa-nn` is a crate in the `linfa` ecosystem, a wider effort to
//! bootstrap a toolkit for classical Machine Learning implemented in pure Rust,
//! kin in spirit to Python's `scikit-learn`.
//!
//! You can find a roadmap (and a selection of good first issues)
//! [here](https://github.com/LukeMathWalker/linfa/issues) - contributors are more than welcome!
//!
//! ## Current state
//!
//! Right now `linfa-nn` provides the following algorithms:
//! * [Linear Scan](struct.LinearSearch.html)
//! * [KD Tree](struct.KdTree.html)
//! * [Ball Tree](struct.BallTree.html)
//!
//! The [`CommonNearestNeighbour`](enum.CommonNearestNeighbour.html) enum should be used to dispatch
//! between all of the above algorithms flexibly.

use distance::Distance;
use linfa::Float;
use ndarray::{ArrayBase, ArrayView1, Data, Ix2};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use thiserror::Error;

mod balltree;
mod heap_elem;
mod kdtree;
mod linear;

pub mod distance;

pub use crate::{balltree::*, kdtree::*, linear::*};

/// Crate-internal shorthand: a borrowed, one-dimensional view of a single point's coordinates.
pub(crate) type Point<'a, F> = ArrayView1<'a, F>;

/// Error returned when building nearest neighbour indices
#[derive(Error, Debug)]
pub enum BuildError {
    /// The input batch has zero columns, so the points have no coordinates to index.
    #[error("points have dimension of 0")]
    ZeroDimension,
    /// A `leaf_size` of 0 was supplied; leaf nodes must hold at least one point.
    #[error("leaf size is 0")]
    EmptyLeaf,
}

/// Error returned when performing spatial queries on nearest neighbour indices
#[derive(Error, Debug)]
pub enum NnError {
    /// The query point's dimensionality does not match that of the indexed points.
    #[error("dimensions of query point and stored points are different")]
    WrongDimension,
}

/// Nearest neighbour algorithm builds a spatial index structure out of a batch of points. The
/// distance between points is calculated using a provided distance function. The index implements
/// the [`NearestNeighbourIndex`](trait.NearestNeighbourIndex.html) trait and allows for efficient
/// computing of nearest neighbour and range queries.
pub trait NearestNeighbour: std::fmt::Debug + Send + Sync + Unpin {
    /// Builds a spatial index using a MxN two-dimensional array representing M points with N
    /// dimensions. Also takes `leaf_size`, which specifies the number of elements in the leaf
    /// nodes of tree-like index structures.
    ///
    /// # Errors
    ///
    /// Returns an error if the points have dimensionality of 0 or if the leaf size is 0. If any
    /// value in the batch is NaN or infinite, the behaviour is unspecified.
    fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
        &self,
        batch: &'a ArrayBase<DT, Ix2>,
        leaf_size: usize,
        dist_fn: D,
    ) -> Result<Box<dyn 'a + NearestNeighbourIndex<F>>, BuildError>;

    /// Builds a spatial index using a default leaf size of 16. See
    /// `from_batch_with_leaf_size` for more information.
    fn from_batch<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
        &self,
        batch: &'a ArrayBase<DT, Ix2>,
        dist_fn: D,
    ) -> Result<Box<dyn 'a + NearestNeighbourIndex<F>>, BuildError> {
        // Named constant instead of the opaque `2usize.pow(4)`; 16 points per leaf is a
        // common default that trades tree depth against per-leaf scan cost.
        const DEFAULT_LEAF_SIZE: usize = 16;
        self.from_batch_with_leaf_size(batch, DEFAULT_LEAF_SIZE, dist_fn)
    }
}

/// A spatial index structure over a set of points, created by `NearestNeighbour`. Allows efficient
/// computation of nearest neighbour and range queries over the set of points. Individual points
/// are represented as one-dimensional array views borrowed from the indexed batch.
pub trait NearestNeighbourIndex<F: Float>: Send + Sync + Unpin {
    /// Returns the `k` points in the index that are the closest to the provided point, along with
    /// their positions in the original dataset. Points are returned in ascending order of the
    /// distance away from the provided points, and less than `k` points will be returned if the
    /// index contains fewer than `k`.
    ///
    /// # Errors
    ///
    /// Returns an error if the provided point has different dimensionality than the index's
    /// points.
    fn k_nearest<'b>(
        &self,
        point: Point<'b, F>,
        k: usize,
    ) -> Result<Vec<(Point<F>, usize)>, NnError>;

    /// Returns all the points in the index that are within the specified distance to the provided
    /// point, along with their positions in the original dataset. The points are not guaranteed to
    /// be in any order, though many algorithms return the points in order of distance.
    ///
    /// # Errors
    ///
    /// Returns an error if the provided point has different dimensionality than the index's
    /// points.
    fn within_range<'b>(
        &self,
        point: Point<'b, F>,
        range: F,
    ) -> Result<Vec<(Point<F>, usize)>, NnError>;
}

/// Enum that dispatches to one of the crate's [`NearestNeighbour`](trait.NearestNeighbour.html)
/// implementations based on value. This enum should be used instead of using types like
/// `LinearSearch` and `KdTree` directly.
///
/// ## Example
///
/// ```rust
/// use rand_xoshiro::Xoshiro256Plus;
/// use ndarray_rand::{rand::SeedableRng, rand_distr::Uniform, RandomExt};
/// use ndarray::{Array1, Array2};
/// use linfa_nn::{distance::*, CommonNearestNeighbour, NearestNeighbour};
///
/// // Use seedable RNG for generating points
/// let mut rng = Xoshiro256Plus::seed_from_u64(40);
/// let n_features = 3;
/// let distr = Uniform::new(-500., 500.);
/// // Randomly generate points for building the index
/// let points = Array2::random_using((5000, n_features), distr, &mut rng);
///
/// // Build a K-D tree with Euclidean distance as the distance function
/// let nn = CommonNearestNeighbour::KdTree.from_batch(&points, L2Dist).unwrap();
///
/// let pt = Array1::random_using(n_features, distr, &mut rng);
/// // Compute the 10 nearest points to `pt` in the index
/// let nearest = nn.k_nearest(pt.view(), 10).unwrap();
/// // Compute all points within 100 units of `pt`
/// let range = nn.within_range(pt.view(), 100.0).unwrap();
/// ```
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub enum CommonNearestNeighbour {
    /// Brute-force linear scan over all points ([`LinearSearch`](struct.LinearSearch.html))
    LinearSearch,
    /// KD Tree ([`KdTree`](struct.KdTree.html))
    KdTree,
    /// Ball Tree ([`BallTree`](struct.BallTree.html))
    BallTree,
}

impl NearestNeighbour for CommonNearestNeighbour {
    fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
        &self,
        batch: &'a ArrayBase<DT, Ix2>,
        leaf_size: usize,
        dist_fn: D,
    ) -> Result<Box<dyn 'a + NearestNeighbourIndex<F>>, BuildError> {
        match self {
            Self::LinearSearch => LinearSearch.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
            Self::KdTree => KdTree.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
            Self::BallTree => BallTree.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn autotraits() {
        // Compile-time assertion that the crate's public types satisfy the auto
        // traits promised by the `NearestNeighbour`/`NearestNeighbourIndex` bounds.
        fn assert_impl<T: Send + Sync + Sized + Unpin>() {}
        assert_impl::<BuildError>();
        assert_impl::<NnError>();
        assert_impl::<CommonNearestNeighbour>();
        assert_impl::<Box<dyn NearestNeighbourIndex<f64>>>();
    }
}