use std::borrow::Cow;

use linfa::Float;
use ndarray::{aview1, ArrayBase, Data, Ix2};
#[cfg(feature = "serde")]
use serde_crate::{Deserialize, Serialize};

use crate::{
    distance::Distance, BuildError, NearestNeighbour, NearestNeighbourBox, NearestNeighbourIndex,
    NnError, Point,
};
10
11#[derive(Debug)]
13pub struct KdTreeIndex<'a, F: Float, D: Distance<F>>(
14 kdtree::KdTree<F, (Point<'a, F>, usize), &'a [F]>,
15 D,
16);
17
18impl<'a, F: Float, D: Distance<F>> KdTreeIndex<'a, F, D> {
19 pub fn new<DT: Data<Elem = F>>(
21 batch: &'a ArrayBase<DT, Ix2>,
22 leaf_size: usize,
23 dist_fn: D,
24 ) -> Result<Self, BuildError> {
25 if leaf_size == 0 {
26 Err(BuildError::EmptyLeaf)
27 } else if batch.ncols() == 0 {
28 Err(BuildError::ZeroDimension)
29 } else {
30 let mut tree = kdtree::KdTree::with_capacity(batch.ncols().max(1), leaf_size);
31 for (i, point) in batch.rows().into_iter().enumerate() {
32 tree.add(
33 point.to_slice().expect("views should be contiguous"),
34 (point, i),
35 )
36 .unwrap();
37 }
38 Ok(Self(tree, dist_fn))
39 }
40 }
41}
42
43impl From<kdtree::ErrorKind> for NnError {
44 fn from(err: kdtree::ErrorKind) -> Self {
45 match err {
46 kdtree::ErrorKind::WrongDimension => NnError::WrongDimension,
47 kdtree::ErrorKind::NonFiniteCoordinate => panic!("infinite value found"),
48 _ => unreachable!(),
49 }
50 }
51}
52
53impl<F: Float, D: Distance<F>> NearestNeighbourIndex<F> for KdTreeIndex<'_, F, D> {
54 fn k_nearest(
55 &self,
56 point: Point<'_, F>,
57 k: usize,
58 ) -> Result<Vec<(Point<'_, F>, usize)>, NnError> {
59 Ok(self
60 .0
61 .nearest(
62 point.to_slice().expect("views should be contiguous"),
63 k,
64 &|a, b| self.1.rdistance(aview1(a), aview1(b)),
65 )?
66 .into_iter()
67 .map(|(_, (pt, pos))| (pt.reborrow(), *pos))
68 .collect())
69 }
70
71 fn within_range(
72 &self,
73 point: Point<'_, F>,
74 range: F,
75 ) -> Result<Vec<(Point<'_, F>, usize)>, NnError> {
76 let range = self.1.dist_to_rdist(range);
77 Ok(self
78 .0
79 .within(
80 point.to_slice().expect("views should be contiguous"),
81 range,
82 &|a, b| self.1.rdistance(aview1(a), aview1(b)),
83 )?
84 .into_iter()
85 .map(|(_, (pt, pos))| (pt.reborrow(), *pos))
86 .collect())
87 }
88}
89
/// Builder that implements [`NearestNeighbour`] by constructing a [`KdTreeIndex`].
///
/// It carries no configuration; all instances are equal.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct KdTree;

impl KdTree {
    /// Returns a `KdTree` builder; equivalent to `KdTree::default()`.
    pub fn new() -> Self {
        KdTree
    }
}
113
114impl NearestNeighbour for KdTree {
115 fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
116 &self,
117 batch: &'a ArrayBase<DT, Ix2>,
118 leaf_size: usize,
119 dist_fn: D,
120 ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
121 KdTreeIndex::new(batch, leaf_size, dist_fn).map(|v| Box::new(v) as NearestNeighbourBox<F>)
122 }
123}