kiddo/float/query/nearest_n.rs

1use az::{Az, Cast};
2use std::collections::BinaryHeap;
3use std::ops::Rem;
4
5use crate::float::kdtree::{Axis, KdTree};
6use crate::nearest_neighbour::NearestNeighbour;
7use crate::rkyv_utils::transform;
8use crate::traits::DistanceMetric;
9use crate::traits::{is_stem_index, Content, Index};
10
11use crate::generate_nearest_n;
12
// Float-tree front end for `generate_nearest_n!`.
//
// The single argument is a doctest snippet (a string literal) that builds a tree named
// `tree`. It is spliced between two fixed string pieces:
// - the shared summary / `# Examples` preamble before it;
// - the shared example query and assertions after it.
//
// The three pieces are handed to `generate_nearest_n!` (defined elsewhere in the
// crate), which presumably concatenates them into the generated method's doc comment.
//
// The shared example queries [1.0, 2.0, 5.1] for its single nearest neighbour. The
// example expects item 100 at squared distance ~0.01, so every snippet must insert a
// point [1.0, 2.0, 5.0] with item 100.
macro_rules! generate_float_nearest_n {
    ($build_tree_snippet:tt) => {
        generate_nearest_n!((
            "Finds the nearest `qty` elements to `query`, using the specified
distance metric function.
# Examples

```rust
    use kiddo::KdTree;
    use kiddo::SquaredEuclidean;

    ",
            $build_tree_snippet,
            "

    let nearest: Vec<_> = tree.nearest_n::<SquaredEuclidean>(&[1.0, 2.0, 5.1], 1);

    assert_eq!(nearest.len(), 1);
    assert!((nearest[0].distance - 0.01f64).abs() < f64::EPSILON);
    assert_eq!(nearest[0].item, 100);
```"
        ));
    };
}
37
38impl<A: Axis, T: Content, const K: usize, const B: usize, IDX: Index<T = IDX>>
39    KdTree<A, T, K, B, IDX>
40where
41    usize: Cast<IDX>,
42{
43    generate_float_nearest_n!(
44        "let mut tree: KdTree<f64, 3> = KdTree::new();
45    tree.add(&[1.0, 2.0, 5.0], 100);
46    tree.add(&[2.0, 3.0, 6.0], 101);"
47    );
48}
49
#[cfg(feature = "rkyv")]
use crate::float::kdtree::ArchivedKdTree;

// `nearest_n` for a zero-copy rkyv 0.7 archive of a float `KdTree`.
//
// The bounds require each of `A`, `T` and `IDX` to archive to itself. Presumably this
// lets the generated query code read archived values as plain `A`/`T`/`IDX`.
//
// The doctest memory-maps ./examples/float-doctest-tree.rkyv and views it with
// `rkyv::archived_root`.
#[cfg(feature = "rkyv")]
impl<A, T, const K: usize, const B: usize, IDX> ArchivedKdTree<A, T, K, B, IDX>
where
    A: Axis + rkyv::Archive<Archived = A>,
    T: Content + rkyv::Archive<Archived = T>,
    IDX: Index<T = IDX> + rkyv::Archive<Archived = IDX>,
    usize: Cast<IDX>,
{
    generate_float_nearest_n!(
        "use std::fs::File;
    use memmap::MmapOptions;

    let mmap = unsafe { MmapOptions::new().map(&File::open(\"./examples/float-doctest-tree.rkyv\").expect(\"./examples/float-doctest-tree.rkyv missing\")).unwrap() };
    let tree = unsafe { rkyv::archived_root::<KdTree<f64, 3>>(&mmap) };"
    );
}
71
#[cfg(feature = "rkyv_08")]
use crate::float::kdtree::ArchivedR8KdTree;

// `nearest_n` for a zero-copy rkyv 0.8 archive of a float `KdTree`.
//
// Unlike the rkyv 0.7 impl, `A`, `T` and `IDX` only need to be `rkyv_08::Archive`; no
// `Archived = Self` constraint is imposed here.
//
// The doctest memory-maps ./examples/float-doctest-tree-rkyv_08.rkyv and views it
// with `rkyv_08::access_unchecked`.
#[cfg(feature = "rkyv_08")]
impl<A, T, const K: usize, const B: usize, IDX> ArchivedR8KdTree<A, T, K, B, IDX>
where
    A: Axis + rkyv_08::Archive,
    T: Content + rkyv_08::Archive,
    IDX: Index<T = IDX> + rkyv_08::Archive,
    usize: Cast<IDX>,
{
    generate_float_nearest_n!(
        "use std::fs::File;
    use memmap::MmapOptions;
    use kiddo::float::kdtree::ArchivedR8KdTree;

    let mmap = unsafe { MmapOptions::new().map(&File::open(\"./examples/float-doctest-tree-rkyv_08.rkyv\").expect(\"./examples/float-doctest-tree-rkyv_08.rkyv missing\")).unwrap() };
    let tree = unsafe { rkyv_08::access_unchecked::<ArchivedR8KdTree<f64, u64, 3, 32, u32>>(&mmap) };"
    );
}
95
#[cfg(test)]
mod tests {
    use crate::float::distance::SquaredEuclidean;
    use crate::float::kdtree::{Axis, KdTree};
    use crate::traits::DistanceMetric;
    use rand::Rng;

    type AX = f32;

    /// `nearest_n` returns the expected three neighbours for a small hand-built tree,
    /// and agrees with a brute-force search for 1000 random query points.
    #[test]
    fn can_query_nearest_n_item() {
        let mut tree: KdTree<AX, u32, 4, 8, u32> = KdTree::new();

        // Trailing comments: each point's squared Euclidean distance to the fixed
        // `query_point` used below (exact decimal arithmetic; f32 results differ slightly).
        let content_to_add: [([AX; 4], u32); 16] = [
            ([0.9f32, 0.0f32, 0.9f32, 0.0f32], 9),    // 0.6338
            ([0.4f32, 0.5f32, 0.4f32, 0.51f32], 4),   // 0.2929
            ([0.12f32, 0.3f32, 0.12f32, 0.3f32], 12), // 0.9962
            ([0.7f32, 0.2f32, 0.7f32, 0.22f32], 7),   // 0.2442
            ([0.13f32, 0.4f32, 0.13f32, 0.4f32], 13), // 0.8900
            ([0.6f32, 0.3f32, 0.6f32, 0.33f32], 6),   // 0.1757
            ([0.2f32, 0.7f32, 0.2f32, 0.7f32], 2),    // 0.7178
            ([0.14f32, 0.5f32, 0.14f32, 0.5f32], 14), // 0.8242
            ([0.3f32, 0.6f32, 0.3f32, 0.6f32], 3),    // 0.4658
            ([0.10f32, 0.1f32, 0.10f32, 0.1f32], 10), // 1.3298
            ([0.16f32, 0.7f32, 0.16f32, 0.7f32], 16), // 0.8138
            ([0.1f32, 0.8f32, 0.1f32, 0.8f32], 1),    // 1.0498
            ([0.15f32, 0.6f32, 0.15f32, 0.6f32], 15), // 0.7988
            ([0.5f32, 0.4f32, 0.5f32, 0.44f32], 5),   // 0.1914
            ([0.8f32, 0.1f32, 0.8f32, 0.15f32], 8),   // 0.3633
            ([0.11f32, 0.2f32, 0.11f32, 0.2f32], 11), // 1.1428
        ];

        for (point, item) in content_to_add {
            tree.add(&point, item);
        }

        assert_eq!(tree.size(), 16);

        let query_point = [0.78f32, 0.55f32, 0.78f32, 0.55f32];

        // The three smallest distances listed above: items 6, 5 and 7, in that order.
        let expected = vec![(0.17569996, 6), (0.19139998, 5), (0.24420004, 7)];

        let result: Vec<_> = tree
            .nearest_n::<SquaredEuclidean>(&query_point, 3)
            .into_iter()
            .map(|n| (n.distance, n.item))
            .collect();
        assert_eq!(result, expected);

        let qty = 10;
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let query_point = [
                rng.random_range(0f32..1f32),
                rng.random_range(0f32..1f32),
                rng.random_range(0f32..1f32),
                rng.random_range(0f32..1f32),
            ];
            let expected = linear_search(&content_to_add, qty, &query_point);

            let result: Vec<_> = tree
                .nearest_n::<SquaredEuclidean>(&query_point, qty)
                .into_iter()
                .map(|n| (n.distance, n.item))
                .collect();

            // Only distances are compared: items at equal distance may legitimately come
            // back in a different order than the brute-force search produces.
            let result_dists: Vec<_> = result.iter().map(|(d, _)| d).collect();
            let expected_dists: Vec<_> = expected.iter().map(|(d, _)| d).collect();

            assert_eq!(result_dists, expected_dists);
        }
    }

    /// `nearest_n` agrees with a brute-force search on a 100k-point random tree.
    #[test]
    fn can_query_nearest_10_items_large_scale() {
        const TREE_SIZE: usize = 100_000;
        const NUM_QUERIES: usize = 100;
        const N: usize = 10;

        let content_to_add: Vec<([f32; 4], u32)> = (0..TREE_SIZE)
            .map(|_| rand::random::<([f32; 4], u32)>())
            .collect();

        let mut tree: KdTree<AX, u32, 4, 32, u32> = KdTree::with_capacity(TREE_SIZE);
        content_to_add
            .iter()
            .for_each(|(point, content)| tree.add(point, *content));
        assert_eq!(tree.size(), TREE_SIZE as u32);

        let query_points: Vec<[f32; 4]> = (0..NUM_QUERIES)
            .map(|_| rand::random::<[f32; 4]>())
            .collect();

        for query_point in query_points {
            let expected = linear_search(&content_to_add, N, &query_point);

            let result: Vec<_> = tree
                .nearest_n::<SquaredEuclidean>(&query_point, N)
                .into_iter()
                .map(|n| (n.distance, n.item))
                .collect();

            // Distances only, for the same tie-ordering reason as above.
            let result_dists: Vec<_> = result.iter().map(|(d, _)| d).collect();
            let expected_dists: Vec<_> = expected.iter().map(|(d, _)| d).collect();

            assert_eq!(result_dists, expected_dists);
        }
    }

    /// Brute-force reference for `nearest_n`.
    ///
    /// Returns the `qty` entries of `content` closest to `query_point` under squared
    /// Euclidean distance, sorted by ascending distance. Among equal distances, the
    /// entry that appears first in `content` comes first, and a later entry never
    /// displaces an earlier one at the same distance. `qty == 0` yields an empty `Vec`.
    fn linear_search<A: Axis, const K: usize>(
        content: &[([A; K], u32)],
        qty: usize,
        query_point: &[A; K],
    ) -> Vec<(A, u32)> {
        // Invariant: `results` is sorted by distance and never holds more than `qty` entries.
        let mut results: Vec<(A, u32)> = Vec::with_capacity(qty.min(content.len()));

        for &(p, item) in content {
            let dist = SquaredEuclidean::dist(query_point, &p);

            if results.len() == qty {
                // Full (or `qty == 0`): only a strictly closer entry may evict the worst.
                match results.last() {
                    Some(&(worst, _)) if dist < worst => {
                        results.pop();
                    }
                    _ => continue,
                }
            }

            // Insert after every entry with distance <= `dist`: keeps the buffer sorted and
            // preserves first-seen order among ties.
            let pos = results.partition_point(|&(d, _)| d <= dist);
            results.insert(pos, (dist, item));
        }

        results
    }
}
225}