1use az::Cast;
2
3use crate::float::kdtree::{Axis, KdTree};
4use crate::nearest_neighbour::NearestNeighbour;
5use crate::traits::DistanceMetric;
6use crate::traits::{Content, Index};
7
8use crate::generate_within;
9
// Thin wrapper around `generate_within!` that supplies the documentation text
// shared by every float tree type implemented in this file.
//
// The single argument is the doctest snippet that makes a `tree` binding
// available; it is placed between the common doc header and the common
// example tail, and the three pieces are forwarded as one parenthesised group.
macro_rules! generate_float_within {
    ($tree_setup:tt) => {
        generate_within!((
            "Finds all elements within `dist` of `query`, using the specified
distance metric function.

Results are returned sorted nearest-first

# Examples

```rust
 use kiddo::KdTree;
 use kiddo::SquaredEuclidean;
 ",
            $tree_setup,
            "

 let within = tree.within::<SquaredEuclidean>(&[1.0, 2.0, 5.0], 10f64);

 assert_eq!(within.len(), 2);
```"
        ));
    };
}
34
35impl<A: Axis, T: Content, const K: usize, const B: usize, IDX: Index<T = IDX>>
36 KdTree<A, T, K, B, IDX>
37where
38 usize: Cast<IDX>,
39{
40 generate_float_within!(
41 "
42let mut tree: KdTree<f64, 3> = KdTree::new();
43tree.add(&[1.0, 2.0, 5.0], 100);
44tree.add(&[2.0, 3.0, 6.0], 101);"
45 );
46}
47
48#[cfg(feature = "rkyv")]
49use crate::float::kdtree::ArchivedKdTree;
// Same generated item(s) as for `KdTree`, but on the `rkyv` archived tree type.
// Only compiled when the `rkyv` feature is enabled.
#[cfg(feature = "rkyv")]
impl<A, T, const K: usize, const B: usize, IDX> ArchivedKdTree<A, T, K, B, IDX>
where
    A: Axis + rkyv::Archive<Archived = A>,
    T: Content + rkyv::Archive<Archived = T>,
    IDX: Index<T = IDX> + rkyv::Archive<Archived = IDX>,
    usize: Cast<IDX>,
{
    // Doctest preamble: memory-map a pre-built example file and view it as an
    // archived tree.
    generate_float_within!(
        "use std::fs::File;
use memmap::MmapOptions;

let mmap = unsafe { MmapOptions::new().map(&File::open(\"./examples/float-doctest-tree.rkyv\").expect(\"./examples/float-doctest-tree.rkyv missing\")).unwrap() };
let tree = unsafe { rkyv::archived_root::<KdTree<f64, 3>>(&mmap) };"
    );
}
69
70#[cfg(feature = "rkyv_08")]
71use crate::float::kdtree::ArchivedR8KdTree;
// Same generated item(s) as for `KdTree`, but on the `rkyv_08` archived tree type.
// Only compiled when the `rkyv_08` feature is enabled.
#[cfg(feature = "rkyv_08")]
impl<A, T, const K: usize, const B: usize, IDX> ArchivedR8KdTree<A, T, K, B, IDX>
where
    A: Axis + rkyv_08::Archive,
    T: Content + rkyv_08::Archive,
    IDX: Index<T = IDX> + rkyv_08::Archive,
    usize: Cast<IDX>,
{
    // Doctest preamble: memory-map a pre-built example file and access it
    // (unchecked) as an archived tree.
    generate_float_within!(
        "use std::fs::File;
 use memmap::MmapOptions;
 use kiddo::float::kdtree::ArchivedR8KdTree;

 let mmap = unsafe { MmapOptions::new().map(&File::open(\"./examples/float-doctest-tree-rkyv_08.rkyv\").expect(\"./examples/float-doctest-tree-rkyv_08.rkyv missing\")).unwrap() };
 let tree = unsafe { rkyv_08::access_unchecked::<ArchivedR8KdTree<f64, u64, 3, 32, u32>>(&mmap) };"
    );
}
93
#[cfg(test)]
mod tests {
    use crate::float::distance::Manhattan;
    use crate::float::kdtree::{Axis, KdTree};
    use crate::nearest_neighbour::NearestNeighbour;
    use crate::traits::DistanceMetric;
    use rand::Rng;
    use std::cmp::Ordering;

    type AX = f32;

    /// Compares `within` against a brute-force scan on a small hand-written
    /// data set: once for a fixed query, then for 1000 random queries.
    #[test]
    fn can_query_items_within_radius() {
        let content_to_add: [([AX; 4], u32); 16] = [
            ([0.9f32, 0.0f32, 0.9f32, 0.0f32], 9),
            ([0.4f32, 0.5f32, 0.4f32, 0.5f32], 4),
            ([0.12f32, 0.3f32, 0.12f32, 0.3f32], 12),
            ([0.7f32, 0.2f32, 0.7f32, 0.2f32], 7),
            ([0.13f32, 0.4f32, 0.13f32, 0.4f32], 13),
            ([0.6f32, 0.3f32, 0.6f32, 0.3f32], 6),
            ([0.2f32, 0.7f32, 0.2f32, 0.7f32], 2),
            ([0.14f32, 0.5f32, 0.14f32, 0.5f32], 14),
            ([0.3f32, 0.6f32, 0.3f32, 0.6f32], 3),
            ([0.10f32, 0.1f32, 0.10f32, 0.1f32], 10),
            ([0.16f32, 0.7f32, 0.16f32, 0.7f32], 16),
            ([0.1f32, 0.8f32, 0.1f32, 0.8f32], 1),
            ([0.15f32, 0.6f32, 0.15f32, 0.6f32], 15),
            ([0.5f32, 0.4f32, 0.5f32, 0.4f32], 5),
            ([0.8f32, 0.1f32, 0.8f32, 0.1f32], 8),
            ([0.11f32, 0.2f32, 0.11f32, 0.2f32], 11),
        ];

        let mut tree: KdTree<AX, u32, 4, 5, u32> = KdTree::new();
        content_to_add
            .iter()
            .for_each(|(point, item)| tree.add(point, *item));
        assert_eq!(tree.size(), 16);

        // Fixed query with a small radius.
        let fixed_query = [0.78f32, 0.55f32, 0.78f32, 0.55f32];
        let fixed_radius = 0.2;
        let expected = linear_search(&content_to_add, &fixed_query, fixed_radius);

        let mut result: Vec<_> = tree.within::<Manhattan>(&fixed_query, fixed_radius);
        stabilize_sort(&mut result);
        assert_eq!(result, expected);

        // Random queries drawn from the unit hypercube with a larger radius.
        let mut rng = rand::rng();
        for _ in 0..1000 {
            let random_query: [AX; 4] = std::array::from_fn(|_| rng.random_range(0f32..1f32));
            let wide_radius: f32 = 2.0;
            let expected = linear_search(&content_to_add, &random_query, wide_radius);

            let mut result: Vec<_> = tree.within::<Manhattan>(&random_query, wide_radius);
            stabilize_sort(&mut result);

            assert_eq!(result, expected);
        }
    }

    /// Same brute-force comparison as above, on a large randomly generated tree.
    #[test]
    fn can_query_items_within_radius_large_scale() {
        const TREE_SIZE: usize = 100_000;
        const NUM_QUERIES: usize = 100;
        const RADIUS: f32 = 0.2;

        let content_to_add: Vec<([f32; 4], u32)> =
            std::iter::repeat_with(rand::random::<([f32; 4], u32)>)
                .take(TREE_SIZE)
                .collect();

        let mut tree: KdTree<AX, u32, 4, 32, u32> = KdTree::with_capacity(TREE_SIZE);
        for (point, content) in &content_to_add {
            tree.add(point, *content);
        }
        assert_eq!(tree.size(), TREE_SIZE as u32);

        let query_points: Vec<[f32; 4]> = std::iter::repeat_with(rand::random::<[f32; 4]>)
            .take(NUM_QUERIES)
            .collect();

        for query_point in &query_points {
            let expected = linear_search(&content_to_add, query_point, RADIUS);

            let mut result: Vec<_> = tree.within::<Manhattan>(query_point, RADIUS);
            stabilize_sort(&mut result);

            assert_eq!(result, expected);
        }
    }

    /// Brute-force reference implementation: collects every entry of `content`
    /// whose Manhattan distance to `query_point` is strictly less than `radius`,
    /// then orders the matches with `stabilize_sort`.
    fn linear_search<A: Axis, const K: usize>(
        content: &[([A; K], u32)],
        query_point: &[A; K],
        radius: A,
    ) -> Vec<NearestNeighbour<A, u32>> {
        let mut matching_items: Vec<NearestNeighbour<A, u32>> = content
            .iter()
            .map(|&(point, item)| NearestNeighbour {
                distance: Manhattan::dist(query_point, &point),
                item,
            })
            .filter(|candidate| candidate.distance < radius)
            .collect();

        stabilize_sort(&mut matching_items);

        matching_items
    }

    /// Sorts by ascending distance, breaking ties on `item`, so that two result
    /// sets can be compared element by element.
    ///
    /// Panics if two distances cannot be compared (`partial_cmp` returns `None`).
    fn stabilize_sort<A: Axis>(matching_items: &mut [NearestNeighbour<A, u32>]) {
        matching_items.sort_unstable_by(|lhs, rhs| {
            match lhs.distance.partial_cmp(&rhs.distance).unwrap() {
                Ordering::Equal => lhs.item.cmp(&rhs.item),
                unequal => unequal,
            }
        });
    }
}