// rectutils/quadtree.rs

//! Quad tree (quadtree): recursively divides a rectangular area into four quadrants and is used for space partitioning and fast spatial queries.
2
3use crate::Rect;
4use arrayvec::ArrayVec;
5use nalgebra::Vector2;
6
7enum QuadTreeNode<T> {
8    Leaf {
9        bounds: Rect<f32>,
10        ids: Vec<T>,
11    },
12    Branch {
13        bounds: Rect<f32>,
14        leaves: [usize; 4],
15    },
16}
17
18fn split_rect(rect: &Rect<f32>) -> [Rect<f32>; 4] {
19    let half_size = rect.size.scale(0.5);
20    [
21        Rect {
22            position: rect.position,
23            size: half_size,
24        },
25        Rect {
26            position: Vector2::new(rect.position.x + half_size.x, rect.position.y),
27            size: half_size,
28        },
29        Rect {
30            position: rect.position + half_size,
31            size: half_size,
32        },
33        Rect {
34            position: Vector2::new(rect.position.x, rect.position.y + half_size.y),
35            size: half_size,
36        },
37    ]
38}
39
40/// Quadrilateral (quad) tree is used for space partitioning and fast spatial queries.
41pub struct QuadTree<T> {
42    nodes: Vec<QuadTreeNode<T>>,
43    root: usize,
44    split_threshold: usize,
45}
46
47impl<T: 'static> Default for QuadTree<T> {
48    fn default() -> Self {
49        Self {
50            nodes: Default::default(),
51            root: Default::default(),
52            split_threshold: 16,
53        }
54    }
55}
56
57/// A trait for anything that has rectangular bounds.
58pub trait BoundsProvider {
59    /// Identifier of the bounds provider.
60    type Id: Clone;
61
62    /// Returns bounds of the bounds provider.
63    fn bounds(&self) -> Rect<f32>;
64
65    /// Returns id of the bounds provider.
66    fn id(&self) -> Self::Id;
67}
68
/// An error that may occur while building a [`QuadTree`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum QuadTreeBuildError {
    /// The builder hit its maximum subdivision depth. This typically happens when more than
    /// `split_threshold` objects overlap the same spot, so no amount of splitting can separate
    /// them: make the split threshold larger and try again. It might also mean that the initial
    /// bounds are too small.
    ReachedRecursionLimit,
}

impl std::fmt::Display for QuadTreeBuildError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ReachedRecursionLimit => f.write_str(
                "quad tree build reached its recursion limit; try a larger split threshold",
            ),
        }
    }
}

// No underlying cause to expose, so the default `source()` (None) is correct.
impl std::error::Error for QuadTreeBuildError {}
75
76#[derive(Clone)]
77struct Entry<I: Clone> {
78    id: I,
79    bounds: Rect<f32>,
80}
81
82fn build_recursive<I>(
83    nodes: &mut Vec<QuadTreeNode<I>>,
84    bounds: Rect<f32>,
85    entries: &[Entry<I>],
86    split_threshold: usize,
87    depth: usize,
88) -> Result<usize, QuadTreeBuildError>
89where
90    I: Clone + 'static,
91{
92    if depth >= 64 {
93        Err(QuadTreeBuildError::ReachedRecursionLimit)
94    } else if entries.len() <= split_threshold {
95        let index = nodes.len();
96        nodes.push(QuadTreeNode::Leaf {
97            bounds,
98            ids: entries.iter().map(|e| e.id.clone()).collect::<Vec<_>>(),
99        });
100        Ok(index)
101    } else {
102        let leaf_bounds = split_rect(&bounds);
103        let mut leaves = [usize::MAX; 4];
104
105        for (leaf, &leaf_bounds) in leaves.iter_mut().zip(leaf_bounds.iter()) {
106            let leaf_entries = entries
107                .iter()
108                .filter_map(|e| {
109                    if leaf_bounds.intersects(e.bounds) {
110                        Some(e.clone())
111                    } else {
112                        None
113                    }
114                })
115                .collect::<Vec<_>>();
116
117            *leaf = build_recursive(
118                nodes,
119                leaf_bounds,
120                &leaf_entries,
121                split_threshold,
122                depth + 1,
123            )?;
124        }
125
126        let index = nodes.len();
127        nodes.push(QuadTreeNode::Branch { bounds, leaves });
128        Ok(index)
129    }
130}
131
132impl<I> QuadTree<I>
133where
134    I: Clone + 'static,
135{
136    /// Creates new quad tree from the given initial bounds and the set of objects.
137    pub fn new<T>(
138        root_bounds: Rect<f32>,
139        objects: impl Iterator<Item = T>,
140        split_threshold: usize,
141    ) -> Result<Self, QuadTreeBuildError>
142    where
143        T: BoundsProvider<Id = I>,
144    {
145        let entries = objects
146            .filter_map(|o| {
147                if root_bounds.intersects(o.bounds()) {
148                    Some(Entry {
149                        id: o.id(),
150                        bounds: o.bounds(),
151                    })
152                } else {
153                    None
154                }
155            })
156            .collect::<Vec<_>>();
157
158        let mut nodes = Vec::new();
159        let root = build_recursive(&mut nodes, root_bounds, &entries, split_threshold, 0)?;
160        Ok(Self {
161            nodes,
162            root,
163            split_threshold,
164        })
165    }
166
167    /// Searches for a leaf node in the tree, that contains the given point and writes ids of the
168    /// entities stored in the leaf node to the output storage.
169    pub fn point_query<S>(&self, point: Vector2<f32>, storage: &mut S)
170    where
171        S: QueryStorage<Id = I>,
172    {
173        self.point_query_recursive(self.root, point, storage)
174    }
175
176    fn point_query_recursive<S>(&self, node: usize, point: Vector2<f32>, storage: &mut S)
177    where
178        S: QueryStorage<Id = I>,
179    {
180        if let Some(node) = self.nodes.get(node) {
181            match node {
182                QuadTreeNode::Leaf { bounds, ids } => {
183                    if bounds.contains(point) {
184                        for id in ids {
185                            if !storage.try_push(id.clone()) {
186                                return;
187                            }
188                        }
189                    }
190                }
191                QuadTreeNode::Branch { bounds, leaves } => {
192                    if bounds.contains(point) {
193                        for &leaf in leaves {
194                            self.point_query_recursive(leaf, point, storage)
195                        }
196                    }
197                }
198            }
199        }
200    }
201
202    /// Returns current split threshold, that was used to build the quad tree.
203    pub fn split_threshold(&self) -> usize {
204        self.split_threshold
205    }
206}
207
/// A destination for the ids produced by quad tree queries.
pub trait QueryStorage {
    /// Type of the ids kept in the storage.
    type Id;

    /// Removes every id from the storage.
    fn clear(&mut self);

    /// Attempts to add `id` to the storage. Returns `false` if it could not be stored
    /// (for example because a fixed-capacity storage is full).
    fn try_push(&mut self, id: Self::Id) -> bool;
}
219
220impl<I> QueryStorage for Vec<I> {
221    type Id = I;
222
223    fn try_push(&mut self, intersection: I) -> bool {
224        self.push(intersection);
225        true
226    }
227
228    fn clear(&mut self) {
229        self.clear()
230    }
231}
232
233impl<I, const CAP: usize> QueryStorage for ArrayVec<I, CAP> {
234    type Id = I;
235
236    fn try_push(&mut self, intersection: I) -> bool {
237        self.try_push(intersection).is_ok()
238    }
239
240    fn clear(&mut self) {
241        self.clear()
242    }
243}
244
#[cfg(test)]
mod test {
    use super::*;
    use crate::Rect;

    /// Minimal bounds provider used to feed the tree in tests.
    struct TestObject {
        bounds: Rect<f32>,
        id: usize,
    }

    impl BoundsProvider for &TestObject {
        type Id = usize;

        fn bounds(&self) -> Rect<f32> {
            self.bounds
        }

        fn id(&self) -> Self::Id {
            self.id
        }
    }

    #[test]
    fn test_quad_tree() {
        let root_bounds = Rect::new(0.0, 0.0, 200.0, 200.0);
        let objects = vec![
            TestObject {
                bounds: Rect::new(10.0, 10.0, 10.0, 10.0),
                id: 0,
            },
            TestObject {
                bounds: Rect::new(10.0, 10.0, 10.0, 10.0),
                id: 1,
            },
        ];
        // Infinite recursion prevention check (when multiple objects share the same location).
        assert!(QuadTree::new(root_bounds, objects.iter(), 1).is_err());

        let objects = vec![
            TestObject {
                bounds: Rect::new(10.0, 10.0, 10.0, 10.0),
                id: 0,
            },
            TestObject {
                bounds: Rect::new(20.0, 20.0, 10.0, 10.0),
                id: 1,
            },
        ];
        assert!(QuadTree::new(root_bounds, objects.iter(), 1).is_ok());
    }

    #[test]
    fn default_for_quad_tree() {
        let tree = QuadTree::<u32>::default();

        assert_eq!(tree.split_threshold, 16);
        assert_eq!(tree.root, 0);
    }

    #[test]
    fn quad_tree_point_query() {
        // empty
        let tree = QuadTree::<f32>::default();
        let mut s = Vec::<f32>::new();

        tree.point_query(Vector2::new(0.0, 0.0), &mut s);
        assert_eq!(s, vec![]);

        let root_bounds = Rect::new(0.0, 0.0, 200.0, 200.0);

        // leaf
        let mut s = Vec::<usize>::new();
        let mut pool = Vec::new();
        pool.push(QuadTreeNode::Leaf {
            bounds: root_bounds,
            ids: vec![0, 1],
        });

        let tree = QuadTree {
            root: 0,
            nodes: pool,
            ..Default::default()
        };

        tree.point_query(Vector2::new(10.0, 10.0), &mut s);
        assert_eq!(s, vec![0, 1]);

        // branch
        let mut s = Vec::<usize>::new();
        let mut pool = Vec::new();
        let a = 0;
        pool.push(QuadTreeNode::Leaf {
            bounds: root_bounds,
            ids: vec![0, 1],
        });
        let b = 1;
        pool.push(QuadTreeNode::Branch {
            bounds: root_bounds,
            leaves: [a, a, a, a],
        });

        let tree = QuadTree {
            root: b,
            nodes: pool,
            ..Default::default()
        };

        tree.point_query(Vector2::new(10.0, 10.0), &mut s);
        assert_eq!(s, vec![0, 1, 0, 1, 0, 1, 0, 1]);
    }

    #[test]
    fn point_query_on_built_tree() {
        let root_bounds = Rect::new(0.0, 0.0, 200.0, 200.0);
        // The two objects only touch at the (20, 20) corner.
        let objects = vec![
            TestObject {
                bounds: Rect::new(10.0, 10.0, 10.0, 10.0),
                id: 0,
            },
            TestObject {
                bounds: Rect::new(20.0, 20.0, 10.0, 10.0),
                id: 1,
            },
        ];
        let tree = match QuadTree::new(root_bounds, objects.iter(), 1) {
            Ok(tree) => tree,
            Err(_) => panic!("non-overlapping objects must fit into the tree"),
        };

        let mut s = Vec::<usize>::new();
        tree.point_query(Vector2::new(15.0, 15.0), &mut s);
        assert_eq!(s, vec![0]);

        s.clear();
        tree.point_query(Vector2::new(27.0, 27.0), &mut s);
        assert_eq!(s, vec![1]);

        // Far from both objects: the leaf containing the point is empty.
        s.clear();
        tree.point_query(Vector2::new(170.3, 130.7), &mut s);
        assert!(s.is_empty());
    }

    #[test]
    fn point_query_respects_storage_capacity() {
        let tree = QuadTree {
            root: 0,
            nodes: vec![QuadTreeNode::Leaf {
                bounds: Rect::new(0.0, 0.0, 200.0, 200.0),
                ids: vec![0usize, 1, 2],
            }],
            ..Default::default()
        };

        let mut s = ArrayVec::<usize, 2>::new();
        tree.point_query(Vector2::new(10.0, 10.0), &mut s);
        assert_eq!(s.to_vec(), vec![0, 1]);
    }

    #[test]
    fn quad_tree_split_threshold() {
        let tree = QuadTree::<u32>::default();
        assert_eq!(tree.split_threshold(), 16);

        let objects: Vec<TestObject> = Vec::new();
        let tree = match QuadTree::new(Rect::new(0.0, 0.0, 10.0, 10.0), objects.iter(), 5) {
            Ok(tree) => tree,
            Err(_) => panic!("a tree without objects must always build"),
        };
        assert_eq!(tree.split_threshold(), 5);
    }

    #[test]
    fn query_storage_for_vec() {
        let mut s = vec![1];

        let res = QueryStorage::try_push(&mut s, 2);
        assert!(res);
        assert_eq!(s, vec![1, 2]);

        QueryStorage::clear(&mut s);
        assert!(s.is_empty());
    }

    #[test]
    fn query_storage_for_array_vec() {
        let mut s = ArrayVec::<i32, 3>::new();

        let res = QueryStorage::try_push(&mut s, 1);
        assert!(res);
        assert!(!s.is_empty());

        QueryStorage::clear(&mut s);
        assert!(s.is_empty());

        // A full ArrayVec rejects further ids instead of panicking.
        let mut s = ArrayVec::<i32, 1>::new();
        assert!(QueryStorage::try_push(&mut s, 1));
        assert!(!QueryStorage::try_push(&mut s, 2));
        assert_eq!(s.len(), 1);
    }
}