use distance::Distance;
use linfa::Float;
use ndarray::{ArrayBase, ArrayView1, Data, Ix2};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};
use thiserror::Error;
mod balltree;
mod heap_elem;
mod kdtree;
mod linear;
pub mod distance;
pub use crate::{balltree::*, kdtree::*, linear::*};
/// A single point: a borrowed one-dimensional view over `F` values.
pub(crate) type Point<'a, F> = ArrayView1<'a, F>;
/// Boxed, type-erased [`NearestNeighbourIndex`] that is `Send + Sync` and may borrow data
/// for the lifetime `'a`.
///
/// This is the handle returned by the [`NearestNeighbour`] builder methods.
pub(crate) type NearestNeighbourBox<'a, F> =
    Box<dyn NearestNeighbourIndex<F> + Send + Sync + 'a>;
#[derive(Error, Debug)]
pub enum BuildError {
#[error("points have dimension of 0")]
ZeroDimension,
#[error("leaf size is 0")]
EmptyLeaf,
}
#[derive(Error, Debug)]
pub enum NnError {
#[error("dimensions of query point and stored points are different")]
WrongDimension,
}
/// Builder trait for nearest neighbour indices.
///
/// An implementor stands for a search algorithm. Its methods borrow a two-dimensional batch
/// of data and return a boxed [`NearestNeighbourIndex`] that borrows that batch for `'a`.
pub trait NearestNeighbour: std::fmt::Debug + Send + Sync + Unpin {
    /// Builds an index over `batch` using an explicit `leaf_size` and the distance
    /// function `dist_fn`.
    ///
    /// NOTE(review): `batch` is presumably laid out with one point per row, and what
    /// `leaf_size` controls is decided by each implementor — confirm against the
    /// implementations.
    ///
    /// # Errors
    ///
    /// Returns a [`BuildError`] when the index cannot be built.
    // The `from_*` name would normally imply a constructor without a receiver. Here the
    // method deliberately takes `&self`, so the lint is silenced.
    #[allow(clippy::wrong_self_convention)]
    fn from_batch_with_leaf_size<'a, F, DT, D>(
        &self,
        batch: &'a ArrayBase<DT, Ix2>,
        leaf_size: usize,
        dist_fn: D,
    ) -> Result<NearestNeighbourBox<'a, F>, BuildError>
    where
        F: Float,
        DT: Data<Elem = F>,
        D: 'a + Distance<F>;

    /// Builds an index over `batch` with the default leaf size of 16.
    ///
    /// Equivalent to calling [`NearestNeighbour::from_batch_with_leaf_size`] with a
    /// `leaf_size` of 16.
    ///
    /// # Errors
    ///
    /// Returns whatever [`BuildError`] the underlying
    /// [`NearestNeighbour::from_batch_with_leaf_size`] call reports.
    // Same reasoning as above: `&self` is intentional despite the `from_*` name.
    #[allow(clippy::wrong_self_convention)]
    fn from_batch<'a, F, DT, D>(
        &self,
        batch: &'a ArrayBase<DT, Ix2>,
        dist_fn: D,
    ) -> Result<NearestNeighbourBox<'a, F>, BuildError>
    where
        F: Float,
        DT: Data<Elem = F>,
        D: 'a + Distance<F>,
    {
        // Default leaf size: 2^4 = 16.
        const DEFAULT_LEAF_SIZE: usize = 1 << 4;
        self.from_batch_with_leaf_size(batch, DEFAULT_LEAF_SIZE, dist_fn)
    }
}
/// A queryable spatial index over points with element type `F`.
///
/// Instances are obtained, boxed, from the [`NearestNeighbour`] builder methods.
pub trait NearestNeighbourIndex<F: Float>: Send + Sync + Unpin {
    /// Looks up the neighbours of `point`, with `k` as the requested neighbour count.
    ///
    /// Each entry of the returned vector pairs a point with a `usize`.
    /// NOTE(review): the `usize` is presumably the position of that point in the data the
    /// index was built from, and the ordering of the results is not visible here — confirm
    /// against the implementations.
    ///
    /// # Errors
    ///
    /// Returns an [`NnError`]; its only variant reports a dimension mismatch between the
    /// query point and the stored points.
    fn k_nearest(
        &self,
        point: Point<'_, F>,
        k: usize,
    ) -> Result<Vec<(Point<'_, F>, usize)>, NnError>;
    /// Looks up the stored points that lie within `range` of `point`.
    ///
    /// Each entry of the returned vector pairs a point with a `usize`.
    /// NOTE(review): the `usize` is presumably the position of that point in the data the
    /// index was built from; whether the `range` bound is inclusive is not visible here —
    /// confirm against the implementations.
    ///
    /// # Errors
    ///
    /// Returns an [`NnError`]; its only variant reports a dimension mismatch between the
    /// query point and the stored points.
    fn within_range(
        &self,
        point: Point<'_, F>,
        range: F,
    ) -> Result<Vec<(Point<'_, F>, usize)>, NnError>;
}
/// Selector for the nearest neighbour algorithms shipped with this crate.
///
/// Implements [`NearestNeighbour`] by forwarding to the builder named by the variant.
/// The enum is `#[non_exhaustive]`, so downstream `match` expressions need a wildcard arm.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub enum CommonNearestNeighbour {
    /// Build the index with the `LinearSearch` builder.
    LinearSearch,
    /// Build the index with the `KdTree` builder.
    KdTree,
    /// Build the index with the `BallTree` builder.
    BallTree,
}
impl NearestNeighbour for CommonNearestNeighbour {
fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
&self,
batch: &'a ArrayBase<DT, Ix2>,
leaf_size: usize,
dist_fn: D,
) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
match self {
Self::LinearSearch => LinearSearch.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
Self::KdTree => KdTree.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
Self::BallTree => BallTree.from_batch_with_leaf_size(batch, leaf_size, dist_fn),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that the crate's public types are `Send + Sync + Sized + Unpin`.
    #[test]
    fn autotraits() {
        // Instantiating this helper only compiles when `T` satisfies every bound, so the
        // body is intentionally empty.
        fn assert_autotraits<T>()
        where
            T: Send + Sync + Sized + Unpin,
        {
        }

        assert_autotraits::<CommonNearestNeighbour>();
        assert_autotraits::<NearestNeighbourBox<'static, f64>>();
        assert_autotraits::<BuildError>();
        assert_autotraits::<NnError>();
    }
}