linfa_nn/
linear.rs

1use std::{cmp::Reverse, collections::BinaryHeap};
2
3use linfa::Float;
4use ndarray::{ArrayBase, ArrayView2, Data, Ix2};
5use noisy_float::NoisyFloat;
6#[cfg(feature = "serde")]
7use serde_crate::{Deserialize, Serialize};
8
9use crate::{
10    distance::Distance, heap_elem::MinHeapElem, BuildError, NearestNeighbour, NearestNeighbourBox,
11    NearestNeighbourIndex, NnError, Point,
12};
13
14/// Spatial indexing structure created by [`LinearSearch`]
15#[derive(Debug, Clone, PartialEq)]
16pub struct LinearSearchIndex<'a, F: Float, D: Distance<F>>(ArrayView2<'a, F>, D);
17
18impl<'a, F: Float, D: Distance<F>> LinearSearchIndex<'a, F, D> {
19    /// Creates a new `LinearSearchIndex`
20    pub fn new<DT: Data<Elem = F>>(
21        batch: &'a ArrayBase<DT, Ix2>,
22        dist_fn: D,
23    ) -> Result<Self, BuildError> {
24        if batch.ncols() == 0 {
25            Err(BuildError::ZeroDimension)
26        } else {
27            Ok(Self(batch.view(), dist_fn))
28        }
29    }
30}
31
32impl<F: Float, D: Distance<F>> NearestNeighbourIndex<F> for LinearSearchIndex<'_, F, D> {
33    fn k_nearest(
34        &self,
35        point: Point<'_, F>,
36        k: usize,
37    ) -> Result<Vec<(Point<'_, F>, usize)>, NnError> {
38        if self.0.ncols() != point.len() {
39            Err(NnError::WrongDimension)
40        } else {
41            let mut heap = BinaryHeap::with_capacity(self.0.nrows());
42            for (i, pt) in self.0.rows().into_iter().enumerate() {
43                let dist = self.1.rdistance(point.reborrow(), pt.reborrow());
44                heap.push(MinHeapElem {
45                    elem: (pt.reborrow(), i),
46                    dist: Reverse(NoisyFloat::new(dist)),
47                });
48            }
49
50            Ok((0..k.min(heap.len()))
51                .map(|_| heap.pop().unwrap().elem)
52                .collect())
53        }
54    }
55
56    fn within_range(
57        &self,
58        point: Point<'_, F>,
59        range: F,
60    ) -> Result<Vec<(Point<'_, F>, usize)>, NnError> {
61        if self.0.ncols() != point.len() {
62            Err(NnError::WrongDimension)
63        } else {
64            let range = self.1.dist_to_rdist(range);
65            Ok(self
66                .0
67                .rows()
68                .into_iter()
69                .enumerate()
70                .filter(|(_, pt)| self.1.rdistance(point.reborrow(), pt.reborrow()) < range)
71                .map(|(i, pt)| (pt, i))
72                .collect())
73        }
74    }
75}
76
/// Linear search, the simplest nearest neighbour algorithm: every query scans through every
/// indexed point, so each query is `O(N)`. Calling `from_batch` yields a
/// [`LinearSearchIndex`].
#[derive(Clone, Debug, Default, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct LinearSearch;

impl LinearSearch {
    /// Returns a new `LinearSearch`; identical to [`LinearSearch::default`].
    pub fn new() -> Self {
        LinearSearch
    }
}
94
95impl NearestNeighbour for LinearSearch {
96    fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
97        &self,
98        batch: &'a ArrayBase<DT, Ix2>,
99        leaf_size: usize,
100        dist_fn: D,
101    ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
102        if leaf_size == 0 {
103            return Err(BuildError::EmptyLeaf);
104        }
105        LinearSearchIndex::new(batch, dist_fn).map(|v| Box::new(v) as NearestNeighbourBox<F>)
106    }
107}