knn/
knn.rs

1use vpsearch::{BestCandidate, MetricSpace};
2
3use std::collections::HashSet;
4use num_traits::Bounded;
5
/// Point structure that will end up in the tree.
///
/// The point is a plain list of `f32` coordinates, so it is not tied to a
/// fixed number of dimensions.
#[derive(Debug, Clone)]
struct PointN {
    /// Coordinates of the point, one entry per dimension.
    data: Vec<f32>,
}

impl PointN {
    /// Builds a point from anything that converts into a `Vec<f32>`,
    /// e.g. an array literal or an existing vector.
    pub fn new(data: impl Into<Vec<f32>>) -> Self {
        let data = data.into();
        PointN { data }
    }
}
17
18/// The searching function
19impl MetricSpace for PointN {
20    type UserData = ();
21    type Distance = f32;
22
23    fn distance(&self, other: &Self, _: &Self::UserData) -> Self::Distance {
24        self.data
25            .iter()
26            .zip(other.data.iter())
27            .map(|(s, o)| (s - o).powi(2))
28            .sum::<f32>()
29            .sqrt()
30    }
31}
32
33/// Add custom search for finding the index of the N nearest points
34struct CountBasedNeighborhood<Item: MetricSpace<Impl>, Impl> {
35    // Max amount of items
36    max_item_count: usize,
37    // The max distance we have observed so far
38    max_observed_distance: Item::Distance,
39    // A list of indexes no longer than max_item_count sorted by distance
40    distance_x_index: Vec<(Item::Distance, usize)>,
41}
42
43impl<Item: MetricSpace<Impl>, Impl> CountBasedNeighborhood<Item, Impl> {
44    /// Helper function for creating the `CountBasedNeighborhood` struct.
45    /// Here `item_count` is the amount of items returned, the k in knn.
46    fn new(item_count: usize) -> Self {
47        Self {
48            max_item_count: item_count,
49            max_observed_distance: <Item::Distance as Bounded>::min_value(),
50            distance_x_index: Vec::<(Item::Distance, usize)>::new(),
51        }
52    }
53
54    /// Insert a single index in the correct possition given that the
55    /// `distance_x_index` is already sorted.
56    fn insert_index(&mut self, index: usize, distance: Item::Distance) {
57        // Add the new item at the end of the list.
58        self.distance_x_index.push((distance, index));
59        // We only need to sort lists with more than one entry
60        if self.distance_x_index.len() > 1 {
61            // Start indexing at the end of the vector. Note that len() is 1 indexed.
62            let mut n = self.distance_x_index.len() - 1;
63            // at n is further than n -1 we swap the two.
64            // Prefrom a single insertion sort pass. If the distance of the element
65            while n > 0 && self.distance_x_index[n].0 < self.distance_x_index[n - 1].0 {
66                self.distance_x_index.swap(n, n - 1);
67                n -= 1;
68            }
69            self.distance_x_index.truncate(self.max_item_count);
70        }
71        // Update the max observed distance, unwrap is safe because this function
72        // inserts a point and the `max_item_count` is more then 0.
73        self.max_observed_distance = self.distance_x_index.last().unwrap().0;
74    }
75}
76
77/// Best candidate definitions that tracks of the index all the points
78/// within the radius of `distance` as specified in the `RadiusBasedNeighborhood`.
79impl<Item: MetricSpace<Impl> + Clone, Impl> BestCandidate<Item, Impl>
80    for CountBasedNeighborhood<Item, Impl>
81{
82    type Output = HashSet<usize>;
83
84    #[inline]
85    fn consider(
86        &mut self,
87        _: &Item,
88        distance: Item::Distance,
89        candidate_index: usize,
90        _: &Item::UserData,
91    ) {
92        // Early out, no need to do track any points if the max return size is 0
93        if self.max_item_count == 0 {
94            return;
95        }
96
97        // If the distance is lower than the max_observed distance we
98        // need to add that index into the sorted_ids and update the
99        // `max_observed_distance`. If the sorted_ids is already at max
100        // capacity we drop the point with the max distance and find
101        // out what the new `max_observed_distance` is by looking at
102        // the last entry in the `distance_x_index` vector. We also
103        // include the point if the `distance_x_index` is not full yet.
104        if distance < self.max_observed_distance
105            || self.distance_x_index.len() < self.max_item_count
106        {
107            self.insert_index(candidate_index, distance);
108        }
109    }
110
111    #[inline]
112    fn distance(&self) -> Item::Distance {
113        // return distance of the Nth farthest as we have currently observed it.
114        // All other points currently in the state will be closer than this.
115        self.max_observed_distance
116    }
117
118    fn result(self, _: &Item::UserData) -> Self::Output {
119        // Convert the sorted indexes into a hash set droping the distance value.
120        self.distance_x_index
121            .into_iter()
122            .map(|(_, index)| index)
123            .collect::<HashSet<usize>>()
124    }
125}
126
127fn main() {
128    let points = vec![
129        PointN::new([2.0, 3.0]),
130        PointN::new([0.0, 1.0]),
131        PointN::new([4.0, 5.0]),
132    ];
133    let tree = vpsearch::Tree::new(&points);
134
135    // Search with a neigboord size of 1, expect a single points to be returned
136    let actual = tree.find_nearest_custom(
137        &PointN::new([1.0, 2.0]),
138        &(),
139        CountBasedNeighborhood::new(1),
140    );
141    assert_eq!(actual.len(), 1);
142
143    // Search with a neigboord size of 2, expect a two points to be returned
144    let expected = [0, 1].iter().copied().collect::<HashSet<usize>>();
145    let actual = tree.find_nearest_custom(
146        &PointN::new([1.0, 2.0]),
147        &(),
148        CountBasedNeighborhood::new(2),
149    );
150    assert_eq!(actual, expected);
151
152    // Search with a neigboord size of 10, expect all points to be returned
153    let expected = [0, 1, 2].iter().copied().collect::<HashSet<usize>>();
154    let actual = tree.find_nearest_custom(
155        &PointN::new([1.0, 2.0]),
156        &(),
157        CountBasedNeighborhood::new(10),
158    );
159    assert_eq!(actual, expected);
160}