1use crate::Rect;
4use arrayvec::ArrayVec;
5use nalgebra::Vector2;
6
7enum QuadTreeNode<T> {
8 Leaf {
9 bounds: Rect<f32>,
10 ids: Vec<T>,
11 },
12 Branch {
13 bounds: Rect<f32>,
14 leaves: [usize; 4],
15 },
16}
17
18fn split_rect(rect: &Rect<f32>) -> [Rect<f32>; 4] {
19 let half_size = rect.size.scale(0.5);
20 [
21 Rect {
22 position: rect.position,
23 size: half_size,
24 },
25 Rect {
26 position: Vector2::new(rect.position.x + half_size.x, rect.position.y),
27 size: half_size,
28 },
29 Rect {
30 position: rect.position + half_size,
31 size: half_size,
32 },
33 Rect {
34 position: Vector2::new(rect.position.x, rect.position.y + half_size.y),
35 size: half_size,
36 },
37 ]
38}
39
40pub struct QuadTree<T> {
42 nodes: Vec<QuadTreeNode<T>>,
43 root: usize,
44 split_threshold: usize,
45}
46
47impl<T: 'static> Default for QuadTree<T> {
48 fn default() -> Self {
49 Self {
50 nodes: Default::default(),
51 root: Default::default(),
52 split_threshold: 16,
53 }
54 }
55}
56
57pub trait BoundsProvider {
59 type Id: Clone;
61
62 fn bounds(&self) -> Rect<f32>;
64
65 fn id(&self) -> Self::Id;
67}
68
/// An error that may occur while building a quad tree.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuadTreeBuildError {
    /// A node still had to be split at the maximum nesting depth.
    ///
    /// This happens, for example, when more than `split_threshold` objects overlap
    /// the same region, so no amount of subdivision can separate them.
    ReachedRecursionLimit,
}

impl std::fmt::Display for QuadTreeBuildError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ReachedRecursionLimit => {
                f.write_str("reached recursion limit while building the quad tree")
            }
        }
    }
}

impl std::error::Error for QuadTreeBuildError {}
75
76#[derive(Clone)]
77struct Entry<I: Clone> {
78 id: I,
79 bounds: Rect<f32>,
80}
81
82fn build_recursive<I>(
83 nodes: &mut Vec<QuadTreeNode<I>>,
84 bounds: Rect<f32>,
85 entries: &[Entry<I>],
86 split_threshold: usize,
87 depth: usize,
88) -> Result<usize, QuadTreeBuildError>
89where
90 I: Clone + 'static,
91{
92 if depth >= 64 {
93 Err(QuadTreeBuildError::ReachedRecursionLimit)
94 } else if entries.len() <= split_threshold {
95 let index = nodes.len();
96 nodes.push(QuadTreeNode::Leaf {
97 bounds,
98 ids: entries.iter().map(|e| e.id.clone()).collect::<Vec<_>>(),
99 });
100 Ok(index)
101 } else {
102 let leaf_bounds = split_rect(&bounds);
103 let mut leaves = [usize::MAX; 4];
104
105 for (leaf, &leaf_bounds) in leaves.iter_mut().zip(leaf_bounds.iter()) {
106 let leaf_entries = entries
107 .iter()
108 .filter_map(|e| {
109 if leaf_bounds.intersects(e.bounds) {
110 Some(e.clone())
111 } else {
112 None
113 }
114 })
115 .collect::<Vec<_>>();
116
117 *leaf = build_recursive(
118 nodes,
119 leaf_bounds,
120 &leaf_entries,
121 split_threshold,
122 depth + 1,
123 )?;
124 }
125
126 let index = nodes.len();
127 nodes.push(QuadTreeNode::Branch { bounds, leaves });
128 Ok(index)
129 }
130}
131
132impl<I> QuadTree<I>
133where
134 I: Clone + 'static,
135{
136 pub fn new<T>(
138 root_bounds: Rect<f32>,
139 objects: impl Iterator<Item = T>,
140 split_threshold: usize,
141 ) -> Result<Self, QuadTreeBuildError>
142 where
143 T: BoundsProvider<Id = I>,
144 {
145 let entries = objects
146 .filter_map(|o| {
147 if root_bounds.intersects(o.bounds()) {
148 Some(Entry {
149 id: o.id(),
150 bounds: o.bounds(),
151 })
152 } else {
153 None
154 }
155 })
156 .collect::<Vec<_>>();
157
158 let mut nodes = Vec::new();
159 let root = build_recursive(&mut nodes, root_bounds, &entries, split_threshold, 0)?;
160 Ok(Self {
161 nodes,
162 root,
163 split_threshold,
164 })
165 }
166
167 pub fn point_query<S>(&self, point: Vector2<f32>, storage: &mut S)
170 where
171 S: QueryStorage<Id = I>,
172 {
173 self.point_query_recursive(self.root, point, storage)
174 }
175
176 fn point_query_recursive<S>(&self, node: usize, point: Vector2<f32>, storage: &mut S)
177 where
178 S: QueryStorage<Id = I>,
179 {
180 if let Some(node) = self.nodes.get(node) {
181 match node {
182 QuadTreeNode::Leaf { bounds, ids } => {
183 if bounds.contains(point) {
184 for id in ids {
185 if !storage.try_push(id.clone()) {
186 return;
187 }
188 }
189 }
190 }
191 QuadTreeNode::Branch { bounds, leaves } => {
192 if bounds.contains(point) {
193 for &leaf in leaves {
194 self.point_query_recursive(leaf, point, storage)
195 }
196 }
197 }
198 }
199 }
200 }
201
202 pub fn split_threshold(&self) -> usize {
204 self.split_threshold
205 }
206}
207
/// Destination for the ids produced by a quad tree query.
pub trait QueryStorage {
    /// Type of id the storage accepts.
    type Id;

    /// Attempts to store `id`.
    ///
    /// Returns `false` when the storage cannot accept it, for example because a
    /// fixed-capacity storage is full.
    fn try_push(&mut self, id: Self::Id) -> bool;

    /// Removes every stored id.
    fn clear(&mut self);
}
219
220impl<I> QueryStorage for Vec<I> {
221 type Id = I;
222
223 fn try_push(&mut self, intersection: I) -> bool {
224 self.push(intersection);
225 true
226 }
227
228 fn clear(&mut self) {
229 self.clear()
230 }
231}
232
233impl<I, const CAP: usize> QueryStorage for ArrayVec<I, CAP> {
234 type Id = I;
235
236 fn try_push(&mut self, intersection: I) -> bool {
237 self.try_push(intersection).is_ok()
238 }
239
240 fn clear(&mut self) {
241 self.clear()
242 }
243}
244
#[cfg(test)]
mod test {
    // The glob also brings in the parent's `Rect`, `Vector2` and `ArrayVec` imports.
    use super::*;

    struct TestObject {
        bounds: Rect<f32>,
        id: usize,
    }

    impl BoundsProvider for &TestObject {
        type Id = usize;

        fn bounds(&self) -> Rect<f32> {
            self.bounds
        }

        fn id(&self) -> Self::Id {
            self.id
        }
    }

    /// Makes two test objects with ids 0 and 1 and the given bounds.
    fn object_pair(first: Rect<f32>, second: Rect<f32>) -> Vec<TestObject> {
        vec![
            TestObject {
                bounds: first,
                id: 0,
            },
            TestObject {
                bounds: second,
                id: 1,
            },
        ]
    }

    /// Wraps a hand-built node pool in a tree with the default split threshold.
    fn tree_with_pool(root: usize, nodes: Vec<QuadTreeNode<usize>>) -> QuadTree<usize> {
        QuadTree {
            root,
            nodes,
            ..Default::default()
        }
    }

    #[test]
    fn test_quad_tree() {
        let root_bounds = Rect::new(0.0, 0.0, 200.0, 200.0);

        // Two identical objects can never be separated with a threshold of one.
        let overlapping = object_pair(
            Rect::new(10.0, 10.0, 10.0, 10.0),
            Rect::new(10.0, 10.0, 10.0, 10.0),
        );
        assert!(QuadTree::new(root_bounds, overlapping.iter(), 1).is_err());

        let apart = object_pair(
            Rect::new(10.0, 10.0, 10.0, 10.0),
            Rect::new(20.0, 20.0, 10.0, 10.0),
        );
        assert!(QuadTree::new(root_bounds, apart.iter(), 1).is_ok());
    }

    #[test]
    fn default_for_quad_tree() {
        let QuadTree {
            root,
            split_threshold,
            ..
        } = QuadTree::<u32>::default();

        assert_eq!(split_threshold, 16);
        assert_eq!(root, 0);
    }

    #[test]
    fn quad_tree_point_query() {
        // A default tree has no nodes, so nothing is found.
        let empty = QuadTree::<f32>::default();
        let mut found = Vec::<f32>::new();
        empty.point_query(Vector2::new(0.0, 0.0), &mut found);
        assert_eq!(found, vec![]);

        let root_bounds = Rect::new(0.0, 0.0, 200.0, 200.0);

        // A lone leaf as root reports all of its ids.
        let single_leaf = tree_with_pool(
            0,
            vec![QuadTreeNode::Leaf {
                bounds: root_bounds,
                ids: vec![0, 1],
            }],
        );
        let mut found = Vec::<usize>::new();
        single_leaf.point_query(Vector2::new(10.0, 10.0), &mut found);
        assert_eq!(found, vec![0, 1]);

        // A branch whose four children are the same leaf visits it four times.
        let leaf = 0;
        let branch = 1;
        let shared_leaf = tree_with_pool(
            branch,
            vec![
                QuadTreeNode::Leaf {
                    bounds: root_bounds,
                    ids: vec![0, 1],
                },
                QuadTreeNode::Branch {
                    bounds: root_bounds,
                    leaves: [leaf; 4],
                },
            ],
        );
        let mut found = Vec::<usize>::new();
        shared_leaf.point_query(Vector2::new(10.0, 10.0), &mut found);
        assert_eq!(found, vec![0, 1, 0, 1, 0, 1, 0, 1]);
    }

    #[test]
    fn quad_tree_split_threshold() {
        let tree = QuadTree::<u32>::default();

        assert_eq!(tree.split_threshold(), tree.split_threshold);
    }

    #[test]
    fn query_storage_for_vec() {
        let mut storage = vec![1];

        assert!(QueryStorage::try_push(&mut storage, 2));
        assert_eq!(storage, vec![1, 2]);

        QueryStorage::clear(&mut storage);
        assert!(storage.is_empty());
    }

    #[test]
    fn query_storage_for_array_vec() {
        let mut storage = ArrayVec::<i32, 3>::new();

        assert!(QueryStorage::try_push(&mut storage, 1));
        assert!(!storage.is_empty());

        QueryStorage::clear(&mut storage);
        assert!(storage.is_empty());
    }
}