1use az::{Az, Cast};
2use std::collections::BinaryHeap;
3use std::ops::Rem;
4
5use crate::float::kdtree::{Axis, KdTree};
6use crate::nearest_neighbour::NearestNeighbour;
7use crate::rkyv_utils::transform;
8use crate::traits::DistanceMetric;
9use crate::traits::{is_stem_index, Content, Index};
10
11use crate::generate_nearest_n;
12
// Thin wrapper around `generate_nearest_n!` for the float tree types in this file.
//
// The single argument is a string literal holding the doctest preamble that
// creates a variable named `tree`. It is passed through, sandwiched between the
// shared documentation header and the shared doctest assertions, so every
// invocation documents the same query against its own kind of tree.
macro_rules! generate_float_nearest_n {
    ($tree_setup:tt) => {
        generate_nearest_n!((
            "Finds the nearest `qty` elements to `query`, using the specified
distance metric function.
# Examples

```rust
 use kiddo::KdTree;
 use kiddo::SquaredEuclidean;

 ",
            $tree_setup,
            "

 let nearest: Vec<_> = tree.nearest_n::<SquaredEuclidean>(&[1.0, 2.0, 5.1], 1);

 assert_eq!(nearest.len(), 1);
 assert!((nearest[0].distance - 0.01f64).abs() < f64::EPSILON);
 assert_eq!(nearest[0].item, 100);
```"
        ));
    };
}
37
38impl<A: Axis, T: Content, const K: usize, const B: usize, IDX: Index<T = IDX>>
39 KdTree<A, T, K, B, IDX>
40where
41 usize: Cast<IDX>,
42{
43 generate_float_nearest_n!(
44 "let mut tree: KdTree<f64, 3> = KdTree::new();
45 tree.add(&[1.0, 2.0, 5.0], 100);
46 tree.add(&[2.0, 3.0, 6.0], 101);"
47 );
48}
49
50#[cfg(feature = "rkyv")]
51use crate::float::kdtree::ArchivedKdTree;
// Query support for the `rkyv`-archived tree (only built with the `rkyv` feature).
//
// Every archivable parameter is required to be its own archived form
// (`Archived = Self`). The doctest preamble memory-maps a pre-built example
// file instead of constructing a tree.
#[cfg(feature = "rkyv")]
impl<A, T, const K: usize, const B: usize, IDX> ArchivedKdTree<A, T, K, B, IDX>
where
    A: Axis + rkyv::Archive<Archived = A>,
    T: Content + rkyv::Archive<Archived = T>,
    IDX: Index<T = IDX> + rkyv::Archive<Archived = IDX>,
    usize: Cast<IDX>,
{
    generate_float_nearest_n!(
        "use std::fs::File;
 use memmap::MmapOptions;

 let mmap = unsafe { MmapOptions::new().map(&File::open(\"./examples/float-doctest-tree.rkyv\").expect(\"./examples/float-doctest-tree.rkyv missing\")).unwrap() };
 let tree = unsafe { rkyv::archived_root::<KdTree<f64, 3>>(&mmap) };"
    );
}
71
72#[cfg(feature = "rkyv_08")]
73use crate::float::kdtree::ArchivedR8KdTree;
// Query support for the `rkyv_08`-archived tree (only built with the `rkyv_08`
// feature).
//
// The doctest preamble memory-maps a pre-built example file and accesses it
// unchecked as an `ArchivedR8KdTree`.
#[cfg(feature = "rkyv_08")]
impl<A, T, const K: usize, const B: usize, IDX> ArchivedR8KdTree<A, T, K, B, IDX>
where
    A: Axis + rkyv_08::Archive,
    T: Content + rkyv_08::Archive,
    IDX: Index<T = IDX> + rkyv_08::Archive,
    usize: Cast<IDX>,
{
    generate_float_nearest_n!(
        "use std::fs::File;
 use memmap::MmapOptions;
 use kiddo::float::kdtree::ArchivedR8KdTree;

 let mmap = unsafe { MmapOptions::new().map(&File::open(\"./examples/float-doctest-tree-rkyv_08.rkyv\").expect(\"./examples/float-doctest-tree-rkyv_08.rkyv missing\")).unwrap() };
 let tree = unsafe { rkyv_08::access_unchecked::<ArchivedR8KdTree<f64, u64, 3, 32, u32>>(&mmap) };"
    );
}
95
#[cfg(test)]
mod tests {
    use crate::float::distance::SquaredEuclidean;
    use crate::float::kdtree::{Axis, KdTree};
    use crate::traits::DistanceMetric;
    use rand::Rng;

    type AX = f32;

    /// Checks `nearest_n` on a 16-point tree, first against hard-coded
    /// `(distance, item)` pairs and then against a brute-force scan for 1000
    /// random query points.
    #[test]
    fn can_query_nearest_n_item() {
        let mut tree: KdTree<AX, u32, 4, 8, u32> = KdTree::new();

        let content_to_add: [([AX; 4], u32); 16] = [
            ([0.9f32, 0.0f32, 0.9f32, 0.0f32], 9),
            ([0.4f32, 0.5f32, 0.4f32, 0.51f32], 4),
            ([0.12f32, 0.3f32, 0.12f32, 0.3f32], 12),
            ([0.7f32, 0.2f32, 0.7f32, 0.22f32], 7),
            ([0.13f32, 0.4f32, 0.13f32, 0.4f32], 13),
            ([0.6f32, 0.3f32, 0.6f32, 0.33f32], 6),
            ([0.2f32, 0.7f32, 0.2f32, 0.7f32], 2),
            ([0.14f32, 0.5f32, 0.14f32, 0.5f32], 14),
            ([0.3f32, 0.6f32, 0.3f32, 0.6f32], 3),
            ([0.10f32, 0.1f32, 0.10f32, 0.1f32], 10),
            ([0.16f32, 0.7f32, 0.16f32, 0.7f32], 16),
            ([0.1f32, 0.8f32, 0.1f32, 0.8f32], 1),
            ([0.15f32, 0.6f32, 0.15f32, 0.6f32], 15),
            ([0.5f32, 0.4f32, 0.5f32, 0.44f32], 5),
            ([0.8f32, 0.1f32, 0.8f32, 0.15f32], 8),
            ([0.11f32, 0.2f32, 0.11f32, 0.2f32], 11),
        ];

        for (point, item) in content_to_add {
            tree.add(&point, item);
        }

        assert_eq!(tree.size(), 16);

        // Fixed query with known-good results.
        let query_point = [0.78f32, 0.55f32, 0.78f32, 0.55f32];

        let expected = vec![(0.17569996, 6), (0.19139998, 5), (0.24420004, 7)];

        let result: Vec<_> = tree
            .nearest_n::<SquaredEuclidean>(&query_point, 3)
            .into_iter()
            .map(|n| (n.distance, n.item))
            .collect();
        assert_eq!(result, expected);

        // Randomised queries, cross-checked against the brute-force reference.
        let qty = 10;
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let query_point = [
                rng.random_range(0f32..1f32),
                rng.random_range(0f32..1f32),
                rng.random_range(0f32..1f32),
                rng.random_range(0f32..1f32),
            ];

            // Only the distances are compared, not the items.
            let expected_dists: Vec<AX> = linear_search(&content_to_add, qty, &query_point)
                .into_iter()
                .map(|(dist, _)| dist)
                .collect();

            let result_dists: Vec<AX> = tree
                .nearest_n::<SquaredEuclidean>(&query_point, qty)
                .into_iter()
                .map(|n| n.distance)
                .collect();

            assert_eq!(result_dists, expected_dists);
        }
    }

    /// Checks `nearest_n` with `N = 10` on a tree of 100 000 random points,
    /// cross-checking 100 random queries against a brute-force scan.
    #[test]
    fn can_query_nearest_10_items_large_scale() {
        const TREE_SIZE: usize = 100_000;
        const NUM_QUERIES: usize = 100;
        const N: usize = 10;

        let content_to_add: Vec<([f32; 4], u32)> = (0..TREE_SIZE)
            .map(|_| rand::random::<([f32; 4], u32)>())
            .collect();

        let mut tree: KdTree<AX, u32, 4, 32, u32> = KdTree::with_capacity(TREE_SIZE);
        for (point, content) in &content_to_add {
            tree.add(point, *content);
        }
        assert_eq!(tree.size(), TREE_SIZE as u32);

        let query_points: Vec<[f32; 4]> = (0..NUM_QUERIES)
            .map(|_| rand::random::<[f32; 4]>())
            .collect();

        for query_point in query_points {
            // Only the distances are compared, not the items.
            let expected_dists: Vec<AX> = linear_search(&content_to_add, N, &query_point)
                .into_iter()
                .map(|(dist, _)| dist)
                .collect();

            let result_dists: Vec<AX> = tree
                .nearest_n::<SquaredEuclidean>(&query_point, N)
                .into_iter()
                .map(|n| n.distance)
                .collect();

            assert_eq!(result_dists, expected_dists);
        }
    }

    /// Brute-force reference implementation used to validate `nearest_n`.
    ///
    /// Scans every entry of `content` and returns at most `qty`
    /// `(squared distance, item)` pairs, ordered by ascending distance from
    /// `query_point`. Among equal distances the entry seen first in `content`
    /// comes first, and an already-kept entry is never displaced by a later
    /// entry at the same distance.
    ///
    /// Returns an empty vector when `qty` is zero (the previous version
    /// underflowed on `qty - 1` in that case).
    fn linear_search<A: Axis, const K: usize>(
        content: &[([A; K], u32)],
        qty: usize,
        query_point: &[A; K],
    ) -> Vec<(A, u32)> {
        if qty == 0 {
            return Vec::new();
        }

        // Kept sorted by ascending distance at all times; never longer than `qty`.
        let mut results: Vec<(A, u32)> = Vec::with_capacity(qty.min(content.len()));

        for (point, item) in content {
            let dist = SquaredEuclidean::dist(query_point, point);

            if results.len() == qty {
                // Buffer is full: only a strictly closer point may evict the
                // current worst entry, which sits at the end.
                if dist < results[qty - 1].0 {
                    results.pop();
                } else {
                    continue;
                }
            }

            // Insert after any entries with an equal distance so that earlier
            // points keep precedence, matching a stable sort.
            let insert_at = results.partition_point(|(kept, _)| *kept <= dist);
            results.insert(insert_at, (dist, *item));
        }

        results
    }
}