// File: radius_nn/radius_nn.rs

use std::collections::HashSet;

use vpsearch::{BestCandidate, MetricSpace};

/// A point in an n-dimensional space; these are the items stored in the
/// `vpsearch::Tree` built in `main`.
#[derive(Clone, Debug, PartialEq)]
struct PointN {
    /// Coordinates of the point, one entry per dimension.
    data: Vec<f32>,
}

impl PointN {
    /// Creates a point from anything convertible into a `Vec<f32>`,
    /// e.g. a fixed-size array such as `[2.0, 3.0]` or an existing `Vec`.
    pub fn new(data: impl Into<Vec<f32>>) -> Self {
        Self { data: data.into() }
    }
}

17/// The searching function
18impl MetricSpace for PointN {
19    type UserData = ();
20    type Distance = f32;
21
22    fn distance(&self, other: &Self, _: &Self::UserData) -> Self::Distance {
23        self.data
24            .iter()
25            .zip(other.data.iter())
26            .map(|(s, o)| (s - o).powi(2))
27            .sum::<f32>()
28            .sqrt()
29    }
30}
31
32/// Add custom search for finding the index of multiple points in a radius
33/// The index of all point with a euclidean distance strictly less than
34/// `max_distance` will be returned.
35struct RadiusBasedNeighborhood<Item: MetricSpace<Impl>, Impl> {
36    max_distance: Item::Distance,
37    ids: HashSet<usize>,
38}
39
40impl<Item: MetricSpace<Impl>, Impl> RadiusBasedNeighborhood<Item, Impl> {
41    /// Helper function for creating the `RadiusBasedNeighborhood` struct.
42    /// Here `max_distance` is an exclusive upper bound to the euclidean distance.
43    fn new(max_distance: Item::Distance) -> Self {
44        Self {
45            max_distance,
46            ids: HashSet::<usize>::new(),
47        }
48    }
49}
50
51/// Best candidate definitions that tracks of the index all the points
52/// within the radius of `distance` as specified in the `RadiusBasedNeighborhood`.
53impl<Item: MetricSpace<Impl> + Clone, Impl> BestCandidate<Item, Impl>
54    for RadiusBasedNeighborhood<Item, Impl>
55{
56    type Output = HashSet<usize>;
57
58    #[inline]
59    fn consider(
60        &mut self,
61        _: &Item,
62        distance: Item::Distance,
63        candidate_index: usize,
64        _: &Item::UserData,
65    ) {
66        // If the distance is lower than the bound we include the index
67        // in the result.
68        if distance < self.max_distance {
69            self.ids.insert(candidate_index);
70        }
71    }
72
73    #[inline]
74    fn distance(&self) -> Item::Distance {
75        self.max_distance
76    }
77    fn result(self, _: &Item::UserData) -> Self::Output {
78        self.ids
79    }
80}
81
82fn main() {
83    let points = vec![
84        PointN::new([2.0, 3.0]),
85        PointN::new([0.0, 1.0]),
86        PointN::new([4.0, 5.0]),
87    ];
88    let tree = vpsearch::Tree::new(&points);
89
90    // Search with a distance of 0, expect no points to be returned
91    let expected = HashSet::new();
92    let actual = tree.find_nearest_custom(
93        &PointN::new([1.0, 2.0]),
94        &(),
95        RadiusBasedNeighborhood::new(0.0f32),
96    );
97    assert_eq!(actual, expected);
98
99    // Search with a distance of 100, expect all points to be returned
100    let expected = [0, 1, 2].iter().copied().collect::<HashSet<usize>>();
101    let actual = tree.find_nearest_custom(
102        &PointN::new([1.0, 2.0]),
103        &(),
104        RadiusBasedNeighborhood::new(100.0f32),
105    );
106    assert_eq!(actual, expected);
107}