Skip to main content

kiddo/float/query/
within.rs

1use az::Cast;
2
3use crate::float::kdtree::{Axis, KdTree};
4use crate::nearest_neighbour::NearestNeighbour;
5use crate::traits::DistanceMetric;
6use crate::traits::{Content, Index};
7
8use crate::generate_within;
9
// Forwards to the crate-level `generate_within!` macro, supplying the rustdoc text for the
// float `within` query. `$build_tree_snippet` is a string literal of doctest code that is
// spliced between the example's `use` lines and its query; every caller in this file
// passes a snippet that binds a variable named `tree`, which the example then queries.
macro_rules! generate_float_within {
    ($build_tree_snippet:tt) => {
        generate_within!((
            "Finds all elements within `dist` of `query`, using the specified
distance metric function.

Results are returned sorted nearest-first

# Examples

```rust
    use kiddo::KdTree;
    use kiddo::SquaredEuclidean;
    ",
            $build_tree_snippet,
            "

    let within = tree.within::<SquaredEuclidean>(&[1.0, 2.0, 5.0], 10f64);

    assert_eq!(within.len(), 2);
```"
        ));
    };
}
34
35impl<A: Axis, T: Content, const K: usize, const B: usize, IDX: Index<T = IDX>>
36    KdTree<A, T, K, B, IDX>
37where
38    usize: Cast<IDX>,
39{
40    generate_float_within!(
41        "
42let mut tree: KdTree<f64, 3> = KdTree::new();
43tree.add(&[1.0, 2.0, 5.0], 100);
44tree.add(&[2.0, 3.0, 6.0], 101);"
45    );
46}
47
48#[cfg(feature = "rkyv")]
49use crate::float::kdtree::ArchivedKdTree;
// rkyv (0.7) archived float tree: generate `within`, with a doctest that memory-maps a
// pre-built archive from `./examples/` and queries it in place.
#[cfg(feature = "rkyv")]
impl<A, T, const K: usize, const B: usize, IDX> ArchivedKdTree<A, T, K, B, IDX>
where
    A: Axis + rkyv::Archive<Archived = A>,
    T: Content + rkyv::Archive<Archived = T>,
    IDX: Index<T = IDX> + rkyv::Archive<Archived = IDX>,
    usize: Cast<IDX>,
{
    generate_float_within!(
        "use std::fs::File;
use memmap::MmapOptions;

let mmap = unsafe { MmapOptions::new().map(&File::open(\"./examples/float-doctest-tree.rkyv\").expect(\"./examples/float-doctest-tree.rkyv missing\")).unwrap() };
let tree = unsafe { rkyv::archived_root::<KdTree<f64, 3>>(&mmap) };"
    );
}
69
70#[cfg(feature = "rkyv_08")]
71use crate::float::kdtree::ArchivedR8KdTree;
// rkyv 0.8 archived float tree: generate `within`, with a doctest that memory-maps a
// pre-built archive from `./examples/` and accesses it without validation.
#[cfg(feature = "rkyv_08")]
impl<A, T, const K: usize, const B: usize, IDX> ArchivedR8KdTree<A, T, K, B, IDX>
where
    A: Axis + rkyv_08::Archive,
    T: Content + rkyv_08::Archive,
    IDX: Index<T = IDX> + rkyv_08::Archive,
    usize: Cast<IDX>,
{
    generate_float_within!(
        "use std::fs::File;
    use memmap::MmapOptions;
    use kiddo::float::kdtree::ArchivedR8KdTree;

    let mmap = unsafe { MmapOptions::new().map(&File::open(\"./examples/float-doctest-tree-rkyv_08.rkyv\").expect(\"./examples/float-doctest-tree-rkyv_08.rkyv missing\")).unwrap() };
    let tree = unsafe { rkyv_08::access_unchecked::<ArchivedR8KdTree<f64, u64, 3, 32, u32>>(&mmap) };"
    );
}
93
#[cfg(test)]
mod tests {
    use crate::float::distance::Manhattan;
    use crate::float::kdtree::{Axis, KdTree};
    use crate::nearest_neighbour::NearestNeighbour;
    use crate::traits::DistanceMetric;
    use rand::Rng;

    type AX = f32;

    #[test]
    fn can_query_items_within_radius() {
        let mut tree: KdTree<AX, u32, 4, 5, u32> = KdTree::new();

        let content_to_add: [([AX; 4], u32); 16] = [
            ([0.9f32, 0.0f32, 0.9f32, 0.0f32], 9),
            ([0.4f32, 0.5f32, 0.4f32, 0.5f32], 4),
            ([0.12f32, 0.3f32, 0.12f32, 0.3f32], 12),
            ([0.7f32, 0.2f32, 0.7f32, 0.2f32], 7),
            ([0.13f32, 0.4f32, 0.13f32, 0.4f32], 13),
            ([0.6f32, 0.3f32, 0.6f32, 0.3f32], 6),
            ([0.2f32, 0.7f32, 0.2f32, 0.7f32], 2),
            ([0.14f32, 0.5f32, 0.14f32, 0.5f32], 14),
            ([0.3f32, 0.6f32, 0.3f32, 0.6f32], 3),
            ([0.10f32, 0.1f32, 0.10f32, 0.1f32], 10),
            ([0.16f32, 0.7f32, 0.16f32, 0.7f32], 16),
            ([0.1f32, 0.8f32, 0.1f32, 0.8f32], 1),
            ([0.15f32, 0.6f32, 0.15f32, 0.6f32], 15),
            ([0.5f32, 0.4f32, 0.5f32, 0.4f32], 5),
            ([0.8f32, 0.1f32, 0.8f32, 0.1f32], 8),
            ([0.11f32, 0.2f32, 0.11f32, 0.2f32], 11),
        ];

        for (point, item) in content_to_add {
            tree.add(&point, item);
        }

        assert_eq!(tree.size(), 16);

        let query_point = [0.78f32, 0.55f32, 0.78f32, 0.55f32];

        let radius = 0.2;
        let expected = linear_search(&content_to_add, &query_point, radius);

        let mut result: Vec<_> = tree.within::<Manhattan>(&query_point, radius);
        stabilize_sort(&mut result);
        assert_eq!(result, expected);

        let mut rng = rand::rng();
        for _i in 0..1000 {
            let query_point = [
                rng.random_range(0f32..1f32),
                rng.random_range(0f32..1f32),
                rng.random_range(0f32..1f32),
                rng.random_range(0f32..1f32),
            ];
            let radius: f32 = 2.0;
            let expected = linear_search(&content_to_add, &query_point, radius);

            let mut result: Vec<_> = tree.within::<Manhattan>(&query_point, radius);
            stabilize_sort(&mut result);

            assert_eq!(result, expected);
        }
    }

    #[test]
    fn can_query_items_within_radius_large_scale() {
        const TREE_SIZE: usize = 100_000;
        const NUM_QUERIES: usize = 100;
        const RADIUS: f32 = 0.2;

        let content_to_add: Vec<([f32; 4], u32)> = (0..TREE_SIZE)
            .map(|_| rand::random::<([f32; 4], u32)>())
            .collect();

        let mut tree: KdTree<AX, u32, 4, 32, u32> = KdTree::with_capacity(TREE_SIZE);
        content_to_add
            .iter()
            .for_each(|(point, content)| tree.add(point, *content));
        assert_eq!(tree.size(), TREE_SIZE as u32);

        let query_points: Vec<[f32; 4]> = (0..NUM_QUERIES)
            .map(|_| rand::random::<[f32; 4]>())
            .collect();

        for query_point in query_points {
            let expected = linear_search(&content_to_add, &query_point, RADIUS);

            let mut result: Vec<_> = tree.within::<Manhattan>(&query_point, RADIUS);

            // Neighbours at exactly the same distance may legitimately come back from the
            // tree in a different order than from the linear scan; normalise the tie order
            // (by item) on both sides before comparing.
            stabilize_sort(&mut result);

            assert_eq!(result, expected);
        }
    }

    /// Brute-force reference for `within`: every item whose Manhattan distance from
    /// `query_point` is strictly less than `radius`, ordered by (distance, item).
    fn linear_search<A: Axis, const K: usize>(
        content: &[([A; K], u32)],
        query_point: &[A; K],
        radius: A,
    ) -> Vec<NearestNeighbour<A, u32>> {
        let mut matching_items: Vec<_> = content
            .iter()
            .map(|&(p, item)| NearestNeighbour {
                distance: Manhattan::dist(query_point, &p),
                item,
            })
            .filter(|neighbour| neighbour.distance < radius)
            .collect();

        stabilize_sort(&mut matching_items);

        matching_items
    }

    /// Sorts neighbours nearest-first, breaking distance ties by ascending item, so that two
    /// result sets containing the same neighbours always compare equal.
    ///
    /// # Panics
    ///
    /// Panics if any distance is NaN (distances are never expected to be NaN here).
    fn stabilize_sort<A: Axis>(matching_items: &mut [NearestNeighbour<A, u32>]) {
        matching_items.sort_unstable_by(|a, b| {
            a.distance
                .partial_cmp(&b.distance)
                .expect("neighbour distances must not be NaN")
                .then_with(|| a.item.cmp(&b.item))
        });
    }
}