1use std::num::NonZeroUsize;
2use std::ops::ControlFlow;
3
4use num_traits::Zero;
5
6use crate::{iter::branch_for_each, Distance, Node, Object, Point, RTree, ROOT_IDX};
7
8impl<O, S> RTree<O, S>
9where
10 O: Object,
11 S: AsRef<[Node<O>]>,
12{
13 pub fn look_up_aabb_contains<'a, V, R>(
15 &'a self,
16 query: &(O::Point, O::Point),
17 visitor: V,
18 ) -> ControlFlow<R>
19 where
20 V: FnMut(&'a O) -> ControlFlow<R>,
21 {
22 let query = |node: &Node<O>| match node {
23 Node::Branch { aabb, .. } => intersects(query, aabb),
24 Node::Twig(_) => unreachable!(),
25 Node::Leaf(obj) => contains(query, &obj.aabb()),
26 };
27
28 self.look_up(query, visitor)
29 }
30
31 pub fn look_up_aabb_intersects<'a, V, R>(
33 &'a self,
34 query: &(O::Point, O::Point),
35 visitor: V,
36 ) -> ControlFlow<R>
37 where
38 V: FnMut(&'a O) -> ControlFlow<R>,
39 {
40 let query = |node: &Node<O>| match node {
41 Node::Branch { aabb, .. } => intersects(query, aabb),
42 Node::Twig(_) => unreachable!(),
43 Node::Leaf(obj) => intersects(query, &obj.aabb()),
44 };
45
46 self.look_up(query, visitor)
47 }
48
49 pub fn look_up_at_point<'a, V, R>(&'a self, query: &O::Point, visitor: V) -> ControlFlow<R>
51 where
52 O: Distance<O::Point>,
53 O::Point: Distance<O::Point>,
54 V: FnMut(&'a O) -> ControlFlow<R>,
55 {
56 let query = |node: &Node<O>| match node {
57 Node::Branch { aabb, .. } => aabb.contains(query),
58 Node::Twig(_) => unreachable!(),
59 Node::Leaf(obj) => obj.contains(query),
60 };
61
62 self.look_up(query, visitor)
63 }
64
65 pub fn look_up_within_distance_of_point<'a, V, R>(
67 &'a self,
68 center: &O::Point,
69 distance: <O::Point as Point>::Coord,
70 visitor: V,
71 ) -> ControlFlow<R>
72 where
73 O: Distance<O::Point>,
74 O::Point: Distance<O::Point>,
75 V: FnMut(&'a O) -> ControlFlow<R>,
76 {
77 let distance_2 = distance * distance;
78
79 let query = |node: &Node<O>| match node {
80 Node::Branch { aabb, .. } => aabb.distance_2(center) <= distance_2,
81 Node::Twig(_) => unreachable!(),
82 Node::Leaf(obj) => obj.distance_2(center) <= distance_2,
83 };
84
85 self.look_up(query, visitor)
86 }
87
88 fn look_up<'a, Q, V, R>(&'a self, query: Q, visitor: V) -> ControlFlow<R>
89 where
90 Q: FnMut(&'a Node<O>) -> bool,
91 V: FnMut(&'a O) -> ControlFlow<R>,
92 {
93 let mut args = LookUpArgs {
94 nodes: self.nodes.as_ref(),
95 query,
96 visitor,
97 };
98
99 let [node, rest @ ..] = &args.nodes[ROOT_IDX..] else {
100 unreachable!()
101 };
102
103 if (args.query)(node) {
104 match node {
105 Node::Branch { len, .. } => look_up(&mut args, len, rest)?,
106 Node::Twig(_) | Node::Leaf(_) => unreachable!(),
107 }
108 }
109
110 ControlFlow::Continue(())
111 }
112}
113
114struct LookUpArgs<'a, O, Q, V>
115where
116 O: Object,
117{
118 nodes: &'a [Node<O>],
119 query: Q,
120 visitor: V,
121}
122
123fn look_up<'a, O, Q, V, R>(
124 args: &mut LookUpArgs<'a, O, Q, V>,
125 mut len: &'a NonZeroUsize,
126 mut twigs: &'a [Node<O>],
127) -> ControlFlow<R>
128where
129 O: Object,
130 Q: FnMut(&'a Node<O>) -> bool,
131 V: FnMut(&'a O) -> ControlFlow<R>,
132{
133 loop {
134 let mut branch = None;
135
136 branch_for_each(len, twigs, |idx| {
137 let [node, rest @ ..] = &args.nodes[idx..] else {
138 unreachable!()
139 };
140
141 if (args.query)(node) {
142 match node {
143 Node::Branch { len, .. } => {
144 if let Some((len1, twigs1)) = branch.replace((len, rest)) {
145 look_up(args, len1, twigs1)?;
146 }
147 }
148 Node::Twig(_) => unreachable!(),
149 Node::Leaf(obj) => (args.visitor)(obj)?,
150 }
151 }
152
153 ControlFlow::Continue(())
154 })?;
155
156 if let Some((len1, twigs1)) = branch {
157 len = len1;
158 twigs = twigs1;
159 } else {
160 return ControlFlow::Continue(());
161 }
162 }
163}
164
165fn intersects<P>(lhs: &(P, P), rhs: &(P, P)) -> bool
166where
167 P: Point,
168{
169 (0..P::DIM).all(|axis| {
170 lhs.0.coord(axis) <= rhs.1.coord(axis) && lhs.1.coord(axis) >= rhs.0.coord(axis)
171 })
172}
173
174fn contains<P>(lhs: &(P, P), rhs: &(P, P)) -> bool
175where
176 P: Point,
177{
178 (0..P::DIM).all(|axis| {
179 lhs.0.coord(axis) <= rhs.0.coord(axis) && lhs.1.coord(axis) >= rhs.1.coord(axis)
180 })
181}
182
183impl<P> Distance<P> for (P, P)
184where
185 P: Point + Distance<P>,
186{
187 fn distance_2(&self, point: &P) -> P::Coord {
188 if !self.contains(point) {
189 let min = self.1.min(&self.0.max(point));
190
191 min.distance_2(point)
192 } else {
193 P::Coord::zero()
194 }
195 }
196
197 fn contains(&self, point: &P) -> bool {
198 (0..P::DIM).all(|axis| {
199 self.0.coord(axis) <= point.coord(axis) && point.coord(axis) <= self.1.coord(axis)
200 })
201 }
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    use proptest::{collection::vec, test_runner::TestRunner};

    use crate::{
        tests::{random_objects, random_points},
        DEF_NODE_LEN,
    };

    #[test]
    fn random_look_up_aabb_contains() {
        TestRunner::default()
            .run(
                &(random_objects(100), random_objects(10)),
                |(objects, queries)| {
                    let index = RTree::new(DEF_NODE_LEN, objects);

                    for query in queries {
                        let mut results1 = index
                            .objects()
                            .filter(|obj| contains(&query.aabb(), &obj.aabb()))
                            .collect::<Vec<_>>();

                        let mut results2 = Vec::new();
                        index
                            .look_up_aabb_contains(&query.aabb(), |obj| {
                                results2.push(obj);
                                ControlFlow::<()>::Continue(())
                            })
                            .continue_value()
                            .unwrap();

                        results1.sort_unstable();
                        results2.sort_unstable();
                        assert_eq!(results1, results2);
                    }

                    Ok(())
                },
            )
            .unwrap();
    }

    #[test]
    fn random_look_up_aabb_intersects() {
        TestRunner::default()
            .run(
                &(random_objects(100), random_objects(10)),
                |(objects, queries)| {
                    let index = RTree::new(DEF_NODE_LEN, objects);

                    for query in queries {
                        let mut results1 = index
                            .objects()
                            .filter(|obj| intersects(&query.aabb(), &obj.aabb()))
                            .collect::<Vec<_>>();

                        let mut results2 = Vec::new();
                        index
                            .look_up_aabb_intersects(&query.aabb(), |obj| {
                                results2.push(obj);
                                ControlFlow::<()>::Continue(())
                            })
                            .continue_value()
                            .unwrap();

                        results1.sort_unstable();
                        results2.sort_unstable();
                        assert_eq!(results1, results2);
                    }

                    Ok(())
                },
            )
            .unwrap();
    }

    #[test]
    fn random_look_up_at_point() {
        TestRunner::default()
            .run(
                &(random_objects(100), random_points(10)),
                |(objects, queries)| {
                    let index = RTree::new(DEF_NODE_LEN, objects);

                    for query in queries {
                        let mut results1 = index
                            .objects()
                            .filter(|obj| obj.contains(&query))
                            .collect::<Vec<_>>();

                        let mut results2 = Vec::new();
                        index
                            .look_up_at_point(&query, |obj| {
                                results2.push(obj);
                                ControlFlow::<()>::Continue(())
                            })
                            .continue_value()
                            .unwrap();

                        results1.sort_unstable();
                        results2.sort_unstable();
                        assert_eq!(results1, results2);
                    }

                    Ok(())
                },
            )
            .unwrap();
    }

    #[test]
    fn random_look_up_within_distance_of_point() {
        TestRunner::default()
            .run(
                &(
                    random_objects(100),
                    random_points(10),
                    vec(0.0_f32..=1.0, 10),
                ),
                |(objects, centers, distances)| {
                    let index = RTree::new(DEF_NODE_LEN, objects);

                    for (center, distance) in centers.iter().zip(distances) {
                        let mut results1 = index
                            .objects()
                            .filter(|obj| obj.distance_2(center) <= distance * distance)
                            .collect::<Vec<_>>();

                        let mut results2 = Vec::new();
                        index
                            .look_up_within_distance_of_point(center, distance, |obj| {
                                results2.push(obj);
                                ControlFlow::<()>::Continue(())
                            })
                            .continue_value()
                            .unwrap();

                        results1.sort_unstable();
                        results2.sort_unstable();
                        assert_eq!(results1, results2);
                    }

                    Ok(())
                },
            )
            .unwrap();
    }

    /// A `Break` from the visitor must stop the traversal immediately and reach the caller
    /// unchanged. The other tests only ever return `Continue`, so the early-exit path through
    /// the `?` chain in the traversal is otherwise untested.
    #[test]
    fn random_look_up_stops_on_break() {
        TestRunner::default()
            .run(
                &(random_objects(100), random_objects(10)),
                |(objects, queries)| {
                    let index = RTree::new(DEF_NODE_LEN, objects);

                    for query in queries {
                        let any_match = index
                            .objects()
                            .any(|obj| intersects(&query.aabb(), &obj.aabb()));

                        let mut visited = 0_usize;
                        let flow = index.look_up_aabb_intersects(&query.aabb(), |obj| {
                            visited += 1;
                            ControlFlow::Break(obj)
                        });

                        // The visitor must run exactly once if anything matches, never otherwise.
                        assert_eq!(visited, usize::from(any_match));

                        match flow {
                            ControlFlow::Break(obj) => {
                                assert!(intersects(&query.aabb(), &obj.aabb()));
                            }
                            ControlFlow::Continue(()) => assert!(!any_match),
                        }
                    }

                    Ok(())
                },
            )
            .unwrap();
    }
}