1use std::{cmp::Reverse, collections::BinaryHeap};
2
3use linfa::Float;
4use ndarray::{ArrayBase, ArrayView2, Data, Ix2};
5use noisy_float::NoisyFloat;
6#[cfg(feature = "serde")]
7use serde_crate::{Deserialize, Serialize};
8
9use crate::{
10 distance::Distance, heap_elem::MinHeapElem, BuildError, NearestNeighbour, NearestNeighbourBox,
11 NearestNeighbourIndex, NnError, Point,
12};
13
/// Nearest-neighbour index that answers queries by scanning every row of a
/// borrowed 2-D array.
///
/// * Field `0` is a view over the batch of points, one point per row; queries
///   must have as many elements as this view has columns.
/// * Field `1` is the distance function used to compare a query point with
///   each row.
#[derive(Debug, Clone, PartialEq)]
pub struct LinearSearchIndex<'a, F: Float, D: Distance<F>>(ArrayView2<'a, F>, D);
17
18impl<'a, F: Float, D: Distance<F>> LinearSearchIndex<'a, F, D> {
19 pub fn new<DT: Data<Elem = F>>(
21 batch: &'a ArrayBase<DT, Ix2>,
22 dist_fn: D,
23 ) -> Result<Self, BuildError> {
24 if batch.ncols() == 0 {
25 Err(BuildError::ZeroDimension)
26 } else {
27 Ok(Self(batch.view(), dist_fn))
28 }
29 }
30}
31
32impl<F: Float, D: Distance<F>> NearestNeighbourIndex<F> for LinearSearchIndex<'_, F, D> {
33 fn k_nearest(
34 &self,
35 point: Point<'_, F>,
36 k: usize,
37 ) -> Result<Vec<(Point<'_, F>, usize)>, NnError> {
38 if self.0.ncols() != point.len() {
39 Err(NnError::WrongDimension)
40 } else {
41 let mut heap = BinaryHeap::with_capacity(self.0.nrows());
42 for (i, pt) in self.0.rows().into_iter().enumerate() {
43 let dist = self.1.rdistance(point.reborrow(), pt.reborrow());
44 heap.push(MinHeapElem {
45 elem: (pt.reborrow(), i),
46 dist: Reverse(NoisyFloat::new(dist)),
47 });
48 }
49
50 Ok((0..k.min(heap.len()))
51 .map(|_| heap.pop().unwrap().elem)
52 .collect())
53 }
54 }
55
56 fn within_range(
57 &self,
58 point: Point<'_, F>,
59 range: F,
60 ) -> Result<Vec<(Point<'_, F>, usize)>, NnError> {
61 if self.0.ncols() != point.len() {
62 Err(NnError::WrongDimension)
63 } else {
64 let range = self.1.dist_to_rdist(range);
65 Ok(self
66 .0
67 .rows()
68 .into_iter()
69 .enumerate()
70 .filter(|(_, pt)| self.1.rdistance(point.reborrow(), pt.reborrow()) < range)
71 .map(|(i, pt)| (pt, i))
72 .collect())
73 }
74 }
75}
76
/// Marker type selecting linear search as the nearest-neighbour algorithm.
///
/// It carries no state; its `NearestNeighbour` implementation in this file
/// builds a `LinearSearchIndex` over the supplied batch.
#[derive(Default, Clone, Debug, PartialEq, Eq)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(crate = "serde_crate")
)]
pub struct LinearSearch;
87
88impl LinearSearch {
89 pub fn new() -> Self {
91 Self
92 }
93}
94
95impl NearestNeighbour for LinearSearch {
96 fn from_batch_with_leaf_size<'a, F: Float, DT: Data<Elem = F>, D: 'a + Distance<F>>(
97 &self,
98 batch: &'a ArrayBase<DT, Ix2>,
99 leaf_size: usize,
100 dist_fn: D,
101 ) -> Result<NearestNeighbourBox<'a, F>, BuildError> {
102 if leaf_size == 0 {
103 return Err(BuildError::EmptyLeaf);
104 }
105 LinearSearchIndex::new(batch, dist_fn).map(|v| Box::new(v) as NearestNeighbourBox<F>)
106 }
107}