linfa_nn/
kdtree.rs

1use linfa::Float;
2use ndarray::{aview1, ArrayBase, Data, Ix2};
3#[cfg(feature = "serde")]
4use serde_crate::{Deserialize, Serialize};
5
6use crate::{
7    distance::Distance, BuildError, NearestNeighbour, NearestNeighbourBox, NearestNeighbourIndex,
8    NnError, Point,
9};
10
11/// Spatial indexing structure created by [`KdTree`]
12#[derive(Debug)]
13pub struct KdTreeIndex<'a, F: Float, D: Distance<F>>(
14    kdtree::KdTree<F, (Point<'a, F>, usize), &'a [F]>,
15    D,
16);
17
18impl<'a, F: Float, D: Distance<F>> KdTreeIndex<'a, F, D> {
19    /// Creates a new `KdTreeIndex`
20    pub fn new<DT: Data<Elem = F>>(
21        batch: &'a ArrayBase<DT, Ix2>,
22        leaf_size: usize,
23        dist_fn: D,
24    ) -> Result<Self, BuildError> {
25        if leaf_size == 0 {
26            Err(BuildError::EmptyLeaf)
27        } else if batch.ncols() == 0 {
28            Err(BuildError::ZeroDimension)
29        } else {
30            let mut tree = kdtree::KdTree::with_capacity(batch.ncols().max(1), leaf_size);
31            for (i, point) in batch.rows().into_iter().enumerate() {
32                tree.add(
33                    point.to_slice().expect("views should be contiguous"),
34                    (point, i),
35                )
36                .unwrap();
37            }
38            Ok(Self(tree, dist_fn))
39        }
40    }
41}
42
43impl From<kdtree::ErrorKind> for NnError {
44    fn from(err: kdtree::ErrorKind) -> Self {
45        match err {
46            kdtree::ErrorKind::WrongDimension => NnError::WrongDimension,
47            kdtree::ErrorKind::NonFiniteCoordinate => panic!("infinite value found"),
48            _ => unreachable!(),
49        }
50    }
51}
52
53impl<F: Float, D: Distance<F>> NearestNeighbourIndex<F> for KdTreeIndex<'_, F, D> {
54    fn k_nearest(
55        &self,
56        point: Point<'_, F>,
57        k: usize,
58    ) -> Result<Vec<(Point<'_, F>, usize)>, NnError> {
59        Ok(self
60            .0
61            .nearest(
62                point.to_slice().expect("views should be contiguous"),
63                k,
64                &|a, b| self.1.rdistance(aview1(a), aview1(b)),
65            )?
66            .into_iter()
67            .map(|(_, (pt, pos))| (pt.reborrow(), *pos))
68            .collect())
69    }
70
71    fn within_range(
72        &self,
73        point: Point<'_, F>,
74        range: F,
75    ) -> Result<Vec<(Point<'_, F>, usize)>, NnError> {
76        let range = self.1.dist_to_rdist(range);
77        Ok(self
78            .0
79            .within(
80                point.to_slice().expect("views should be contiguous"),
81                range,
82                &|a, b| self.1.rdistance(aview1(a), aview1(b)),
83            )?
84            .into_iter()
85            .map(|(_, (pt, pos))| (pt.reborrow(), *pos))
86            .collect())
87    }
88}
89
/// K-D tree, a fast space-partitioning data structure. At every parent node the indexed points
/// are split by a hyperplane into two child nodes. Thanks to this tree-like structure, spatial
/// queries run in `O(k * logN)` time, where `k` is the number of points returned by the query.
///
/// `KdTree` itself holds no data. Building an index through its [`NearestNeighbour`]
/// implementation yields a boxed [`KdTreeIndex`] over the supplied batch.
///
/// More details can be found [here](https://en.wikipedia.org/wiki/K-d_tree).
///
/// Unlike other `NearestNeighbour` implementations, `KdTree` requires that points be laid out
/// contiguously in memory and will panic otherwise.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct KdTree;
106
107impl KdTree {
108    /// Creates an instance of `KdTree`
109    pub fn new() -> Self {
110        Self
111    }
112}
113
114impl NearestNeighbour for KdTree {
115    fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
116        &self,
117        batch: &'a ArrayBase<DT, Ix2>,
118        leaf_size: usize,
119        dist_fn: D,
120    ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
121        KdTreeIndex::new(batch, leaf_size, dist_fn).map(|v| Box::new(v) as NearestNeighbourBox<F>)
122    }
123}