1use crate::Rect;
4use arrayvec::ArrayVec;
5use nalgebra::Vector2;
6
7#[derive(Clone)]
8enum QuadTreeNode<T: Clone> {
9 Leaf {
10 bounds: Rect<f32>,
11 ids: Vec<T>,
12 },
13 Branch {
14 bounds: Rect<f32>,
15 leaves: [usize; 4],
16 },
17}
18
19fn split_rect(rect: &Rect<f32>) -> [Rect<f32>; 4] {
20 let half_size = rect.size.scale(0.5);
21 [
22 Rect {
23 position: rect.position,
24 size: half_size,
25 },
26 Rect {
27 position: Vector2::new(rect.position.x + half_size.x, rect.position.y),
28 size: half_size,
29 },
30 Rect {
31 position: rect.position + half_size,
32 size: half_size,
33 },
34 Rect {
35 position: Vector2::new(rect.position.x, rect.position.y + half_size.y),
36 size: half_size,
37 },
38 ]
39}
40
41#[derive(Clone)]
43pub struct QuadTree<T: Clone> {
44 nodes: Vec<QuadTreeNode<T>>,
45 root: usize,
46 split_threshold: usize,
47}
48
49impl<T: Clone + 'static> Default for QuadTree<T> {
50 fn default() -> Self {
51 Self {
52 nodes: Default::default(),
53 root: Default::default(),
54 split_threshold: 16,
55 }
56 }
57}
58
59pub trait BoundsProvider {
61 type Id: Clone;
63
64 fn bounds(&self) -> Rect<f32>;
66
67 fn id(&self) -> Self::Id;
69}
70
/// Error returned when a [`QuadTree`] cannot be built.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuadTreeBuildError {
    /// Subdivision reached the maximum tree depth while a node still held more
    /// entries than the split threshold. This happens, for example, when more than
    /// `split_threshold` objects overlap the same spot, so splitting can never
    /// separate them.
    ReachedRecursionLimit,
}

impl std::fmt::Display for QuadTreeBuildError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            QuadTreeBuildError::ReachedRecursionLimit => {
                f.write_str("quad tree build reached the recursion depth limit")
            }
        }
    }
}

impl std::error::Error for QuadTreeBuildError {}
77
78#[derive(Clone)]
79struct Entry<I: Clone> {
80 id: I,
81 bounds: Rect<f32>,
82}
83
84fn build_recursive<I>(
85 nodes: &mut Vec<QuadTreeNode<I>>,
86 bounds: Rect<f32>,
87 entries: &[Entry<I>],
88 split_threshold: usize,
89 depth: usize,
90) -> Result<usize, QuadTreeBuildError>
91where
92 I: Clone + 'static,
93{
94 if depth >= 64 {
95 Err(QuadTreeBuildError::ReachedRecursionLimit)
96 } else if entries.len() <= split_threshold {
97 let index = nodes.len();
98 nodes.push(QuadTreeNode::Leaf {
99 bounds,
100 ids: entries.iter().map(|e| e.id.clone()).collect::<Vec<_>>(),
101 });
102 Ok(index)
103 } else {
104 let leaf_bounds = split_rect(&bounds);
105 let mut leaves = [usize::MAX; 4];
106
107 for (leaf, &leaf_bounds) in leaves.iter_mut().zip(leaf_bounds.iter()) {
108 let leaf_entries = entries
109 .iter()
110 .filter_map(|e| {
111 if leaf_bounds.intersects(e.bounds) {
112 Some(e.clone())
113 } else {
114 None
115 }
116 })
117 .collect::<Vec<_>>();
118
119 *leaf = build_recursive(
120 nodes,
121 leaf_bounds,
122 &leaf_entries,
123 split_threshold,
124 depth + 1,
125 )?;
126 }
127
128 let index = nodes.len();
129 nodes.push(QuadTreeNode::Branch { bounds, leaves });
130 Ok(index)
131 }
132}
133
134impl<I> QuadTree<I>
135where
136 I: Clone + 'static,
137{
138 pub fn new<T>(
140 root_bounds: Rect<f32>,
141 objects: impl Iterator<Item = T>,
142 split_threshold: usize,
143 ) -> Result<Self, QuadTreeBuildError>
144 where
145 T: BoundsProvider<Id = I>,
146 {
147 let entries = objects
148 .filter_map(|o| {
149 if root_bounds.intersects(o.bounds()) {
150 Some(Entry {
151 id: o.id(),
152 bounds: o.bounds(),
153 })
154 } else {
155 None
156 }
157 })
158 .collect::<Vec<_>>();
159
160 let mut nodes = Vec::new();
161 let root = build_recursive(&mut nodes, root_bounds, &entries, split_threshold, 0)?;
162 Ok(Self {
163 nodes,
164 root,
165 split_threshold,
166 })
167 }
168
169 pub fn point_query<S>(&self, point: Vector2<f32>, storage: &mut S)
172 where
173 S: QueryStorage<Id = I>,
174 {
175 self.point_query_recursive(self.root, point, storage)
176 }
177
178 fn point_query_recursive<S>(&self, node: usize, point: Vector2<f32>, storage: &mut S)
179 where
180 S: QueryStorage<Id = I>,
181 {
182 if let Some(node) = self.nodes.get(node) {
183 match node {
184 QuadTreeNode::Leaf { bounds, ids } => {
185 if bounds.contains(point) {
186 for id in ids {
187 if !storage.try_push(id.clone()) {
188 return;
189 }
190 }
191 }
192 }
193 QuadTreeNode::Branch { bounds, leaves } => {
194 if bounds.contains(point) {
195 for &leaf in leaves {
196 self.point_query_recursive(leaf, point, storage)
197 }
198 }
199 }
200 }
201 }
202 }
203
204 pub fn split_threshold(&self) -> usize {
206 self.split_threshold
207 }
208}
209
/// Destination for the ids reported by quad tree queries.
pub trait QueryStorage {
    /// Type of id accepted by the storage.
    type Id;

    /// Attempts to store `item`.
    ///
    /// Returns `true` if it was stored and `false` if it was rejected, for example
    /// because a fixed-capacity storage is full.
    fn try_push(&mut self, item: Self::Id) -> bool;

    /// Removes every stored id.
    fn clear(&mut self);
}
221
222impl<I> QueryStorage for Vec<I> {
223 type Id = I;
224
225 fn try_push(&mut self, intersection: I) -> bool {
226 self.push(intersection);
227 true
228 }
229
230 fn clear(&mut self) {
231 self.clear()
232 }
233}
234
235impl<I, const CAP: usize> QueryStorage for ArrayVec<I, CAP> {
236 type Id = I;
237
238 fn try_push(&mut self, intersection: I) -> bool {
239 self.try_push(intersection).is_ok()
240 }
241
242 fn clear(&mut self) {
243 self.clear()
244 }
245}
246
#[cfg(test)]
mod test {
    use super::*;
    use crate::Rect;

    struct TestObject {
        bounds: Rect<f32>,
        id: usize,
    }

    impl BoundsProvider for &TestObject {
        type Id = usize;

        fn bounds(&self) -> Rect<f32> {
            self.bounds
        }

        fn id(&self) -> Self::Id {
            self.id
        }
    }

    #[test]
    fn test_quad_tree() {
        let root_bounds = Rect::new(0.0, 0.0, 200.0, 200.0);
        // Two identical rectangles can never be separated, so a threshold of 1
        // must hit the recursion limit.
        let objects = vec![
            TestObject {
                bounds: Rect::new(10.0, 10.0, 10.0, 10.0),
                id: 0,
            },
            TestObject {
                bounds: Rect::new(10.0, 10.0, 10.0, 10.0),
                id: 1,
            },
        ];
        assert!(QuadTree::new(root_bounds, objects.iter(), 1).is_err());

        let objects = vec![
            TestObject {
                bounds: Rect::new(10.0, 10.0, 10.0, 10.0),
                id: 0,
            },
            TestObject {
                bounds: Rect::new(20.0, 20.0, 10.0, 10.0),
                id: 1,
            },
        ];
        assert!(QuadTree::new(root_bounds, objects.iter(), 1).is_ok());
    }

    #[test]
    fn quad_tree_skips_objects_outside_root_bounds() {
        let root_bounds = Rect::new(0.0, 0.0, 200.0, 200.0);
        let objects = vec![
            TestObject {
                bounds: Rect::new(10.0, 10.0, 10.0, 10.0),
                id: 0,
            },
            TestObject {
                bounds: Rect::new(500.0, 500.0, 10.0, 10.0),
                id: 1,
            },
        ];
        let tree = match QuadTree::new(root_bounds, objects.iter(), 16) {
            Ok(tree) => tree,
            Err(_) => panic!("a single in-bounds object must build"),
        };

        let mut s = Vec::<usize>::new();
        tree.point_query(Vector2::new(15.0, 15.0), &mut s);
        assert_eq!(s, vec![0]);
    }

    #[test]
    fn default_for_quad_tree() {
        let tree = QuadTree::<u32>::default();

        assert_eq!(tree.split_threshold, 16);
        assert_eq!(tree.root, 0);
    }

    #[test]
    fn quad_tree_point_query() {
        let tree = QuadTree::<f32>::default();
        let mut s = Vec::<f32>::new();

        tree.point_query(Vector2::new(0.0, 0.0), &mut s);
        assert_eq!(s, vec![]);

        let root_bounds = Rect::new(0.0, 0.0, 200.0, 200.0);

        let mut s = Vec::<usize>::new();
        let mut pool = Vec::new();
        pool.push(QuadTreeNode::Leaf {
            bounds: root_bounds,
            ids: vec![0, 1],
        });

        let tree = QuadTree {
            root: 0,
            nodes: pool,
            ..Default::default()
        };

        tree.point_query(Vector2::new(10.0, 10.0), &mut s);
        assert_eq!(s, vec![0, 1]);

        let mut s = Vec::<usize>::new();
        let mut pool = Vec::new();
        let a = 0;
        pool.push(QuadTreeNode::Leaf {
            bounds: root_bounds,
            ids: vec![0, 1],
        });
        let b = 1;
        pool.push(QuadTreeNode::Branch {
            bounds: root_bounds,
            leaves: [a, a, a, a],
        });

        let tree = QuadTree {
            root: b,
            nodes: pool,
            ..Default::default()
        };

        tree.point_query(Vector2::new(10.0, 10.0), &mut s);
        assert_eq!(s, vec![0, 1, 0, 1, 0, 1, 0, 1]);
    }

    #[test]
    fn quad_tree_point_query_respects_storage_capacity() {
        let root_bounds = Rect::new(0.0, 0.0, 200.0, 200.0);
        let nodes = vec![
            QuadTreeNode::Leaf {
                bounds: root_bounds,
                ids: vec![0usize, 1],
            },
            QuadTreeNode::Branch {
                bounds: root_bounds,
                leaves: [0, 0, 0, 0],
            },
        ];
        let tree = QuadTree {
            root: 1,
            nodes,
            ..Default::default()
        };

        let mut s = ArrayVec::<usize, 3>::new();
        tree.point_query(Vector2::new(10.0, 10.0), &mut s);
        // Ids beyond the capacity are dropped instead of panicking.
        assert_eq!(&s[..], &[0usize, 1, 0][..]);
    }

    #[test]
    fn quad_tree_split_threshold() {
        let tree = QuadTree::<u32>::default();

        // Compare against a concrete value; comparing the getter with its own
        // field alone would pass for any threshold.
        assert_eq!(tree.split_threshold(), 16);
        assert_eq!(tree.split_threshold(), tree.split_threshold);
    }

    #[test]
    fn query_storage_for_vec() {
        let mut s = vec![1];

        let res = QueryStorage::try_push(&mut s, 2);
        assert!(res);
        assert_eq!(s, vec![1, 2]);

        QueryStorage::clear(&mut s);
        assert!(s.is_empty());
    }

    #[test]
    fn query_storage_for_array_vec() {
        let mut s = ArrayVec::<i32, 3>::new();

        let res = QueryStorage::try_push(&mut s, 1);
        assert!(res);
        assert!(!s.is_empty());

        QueryStorage::clear(&mut s);
        assert!(s.is_empty());
    }
}