// rectutils/quadtree.rs

//! A quadrilateral (quad) tree, used for space partitioning and fast spatial queries.
2
3use crate::Rect;
4use arrayvec::ArrayVec;
5use nalgebra::Vector2;
6
7#[derive(Clone)]
8enum QuadTreeNode<T: Clone> {
9    Leaf {
10        bounds: Rect<f32>,
11        ids: Vec<T>,
12    },
13    Branch {
14        bounds: Rect<f32>,
15        leaves: [usize; 4],
16    },
17}
18
19fn split_rect(rect: &Rect<f32>) -> [Rect<f32>; 4] {
20    let half_size = rect.size.scale(0.5);
21    [
22        Rect {
23            position: rect.position,
24            size: half_size,
25        },
26        Rect {
27            position: Vector2::new(rect.position.x + half_size.x, rect.position.y),
28            size: half_size,
29        },
30        Rect {
31            position: rect.position + half_size,
32            size: half_size,
33        },
34        Rect {
35            position: Vector2::new(rect.position.x, rect.position.y + half_size.y),
36            size: half_size,
37        },
38    ]
39}
40
41/// Quadrilateral (quad) tree is used for space partitioning and fast spatial queries.
42#[derive(Clone)]
43pub struct QuadTree<T: Clone> {
44    nodes: Vec<QuadTreeNode<T>>,
45    root: usize,
46    split_threshold: usize,
47}
48
49impl<T: Clone + 'static> Default for QuadTree<T> {
50    fn default() -> Self {
51        Self {
52            nodes: Default::default(),
53            root: Default::default(),
54            split_threshold: 16,
55        }
56    }
57}
58
59/// A trait for anything that has rectangular bounds.
60pub trait BoundsProvider {
61    /// Identifier of the bounds provider.
62    type Id: Clone;
63
64    /// Returns bounds of the bounds provider.
65    fn bounds(&self) -> Rect<f32>;
66
67    /// Returns id of the bounds provider.
68    fn id(&self) -> Self::Id;
69}
70
/// An error that may occur while building a quad tree.
///
/// Implements [`std::fmt::Debug`], [`std::fmt::Display`] and [`std::error::Error`], so a
/// failed build can be unwrapped, logged, or propagated through `Box<dyn Error>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum QuadTreeBuildError {
    /// The builder hit its recursion limit while subdividing the space.
    ///
    /// It means that the given split threshold is too low for the algorithm to build the
    /// quad tree. Make it larger and try again. This might also mean that the initial
    /// bounds are too small.
    ReachedRecursionLimit,
}

impl std::fmt::Display for QuadTreeBuildError {
    /// Writes a short, human-readable description of the failure.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ReachedRecursionLimit => f.write_str(
                "quad tree build reached the recursion limit; \
                 increase the split threshold or the initial bounds",
            ),
        }
    }
}

impl std::error::Error for QuadTreeBuildError {}
77
78#[derive(Clone)]
79struct Entry<I: Clone> {
80    id: I,
81    bounds: Rect<f32>,
82}
83
84fn build_recursive<I>(
85    nodes: &mut Vec<QuadTreeNode<I>>,
86    bounds: Rect<f32>,
87    entries: &[Entry<I>],
88    split_threshold: usize,
89    depth: usize,
90) -> Result<usize, QuadTreeBuildError>
91where
92    I: Clone + 'static,
93{
94    if depth >= 64 {
95        Err(QuadTreeBuildError::ReachedRecursionLimit)
96    } else if entries.len() <= split_threshold {
97        let index = nodes.len();
98        nodes.push(QuadTreeNode::Leaf {
99            bounds,
100            ids: entries.iter().map(|e| e.id.clone()).collect::<Vec<_>>(),
101        });
102        Ok(index)
103    } else {
104        let leaf_bounds = split_rect(&bounds);
105        let mut leaves = [usize::MAX; 4];
106
107        for (leaf, &leaf_bounds) in leaves.iter_mut().zip(leaf_bounds.iter()) {
108            let leaf_entries = entries
109                .iter()
110                .filter_map(|e| {
111                    if leaf_bounds.intersects(e.bounds) {
112                        Some(e.clone())
113                    } else {
114                        None
115                    }
116                })
117                .collect::<Vec<_>>();
118
119            *leaf = build_recursive(
120                nodes,
121                leaf_bounds,
122                &leaf_entries,
123                split_threshold,
124                depth + 1,
125            )?;
126        }
127
128        let index = nodes.len();
129        nodes.push(QuadTreeNode::Branch { bounds, leaves });
130        Ok(index)
131    }
132}
133
134impl<I> QuadTree<I>
135where
136    I: Clone + 'static,
137{
138    /// Creates new quad tree from the given initial bounds and the set of objects.
139    pub fn new<T>(
140        root_bounds: Rect<f32>,
141        objects: impl Iterator<Item = T>,
142        split_threshold: usize,
143    ) -> Result<Self, QuadTreeBuildError>
144    where
145        T: BoundsProvider<Id = I>,
146    {
147        let entries = objects
148            .filter_map(|o| {
149                if root_bounds.intersects(o.bounds()) {
150                    Some(Entry {
151                        id: o.id(),
152                        bounds: o.bounds(),
153                    })
154                } else {
155                    None
156                }
157            })
158            .collect::<Vec<_>>();
159
160        let mut nodes = Vec::new();
161        let root = build_recursive(&mut nodes, root_bounds, &entries, split_threshold, 0)?;
162        Ok(Self {
163            nodes,
164            root,
165            split_threshold,
166        })
167    }
168
169    /// Searches for a leaf node in the tree, that contains the given point and writes ids of the
170    /// entities stored in the leaf node to the output storage.
171    pub fn point_query<S>(&self, point: Vector2<f32>, storage: &mut S)
172    where
173        S: QueryStorage<Id = I>,
174    {
175        self.point_query_recursive(self.root, point, storage)
176    }
177
178    fn point_query_recursive<S>(&self, node: usize, point: Vector2<f32>, storage: &mut S)
179    where
180        S: QueryStorage<Id = I>,
181    {
182        if let Some(node) = self.nodes.get(node) {
183            match node {
184                QuadTreeNode::Leaf { bounds, ids } => {
185                    if bounds.contains(point) {
186                        for id in ids {
187                            if !storage.try_push(id.clone()) {
188                                return;
189                            }
190                        }
191                    }
192                }
193                QuadTreeNode::Branch { bounds, leaves } => {
194                    if bounds.contains(point) {
195                        for &leaf in leaves {
196                            self.point_query_recursive(leaf, point, storage)
197                        }
198                    }
199                }
200            }
201        }
202    }
203
204    /// Returns current split threshold, that was used to build the quad tree.
205    pub fn split_threshold(&self) -> usize {
206        self.split_threshold
207    }
208}
209
/// A container that query results can be written into.
///
/// Implementing this trait lets the caller choose where results go, for example a growable
/// vector or a fixed-capacity buffer.
pub trait QueryStorage {
    /// Type of the ids kept in the storage.
    type Id;

    /// Attempts to add `id` to the storage.
    ///
    /// Returns `true` if the id was accepted and `false` if it was rejected.
    fn try_push(&mut self, id: Self::Id) -> bool;

    /// Removes every id from the storage.
    fn clear(&mut self);
}
221
222impl<I> QueryStorage for Vec<I> {
223    type Id = I;
224
225    fn try_push(&mut self, intersection: I) -> bool {
226        self.push(intersection);
227        true
228    }
229
230    fn clear(&mut self) {
231        self.clear()
232    }
233}
234
235impl<I, const CAP: usize> QueryStorage for ArrayVec<I, CAP> {
236    type Id = I;
237
238    fn try_push(&mut self, intersection: I) -> bool {
239        self.try_push(intersection).is_ok()
240    }
241
242    fn clear(&mut self) {
243        self.clear()
244    }
245}
246
#[cfg(test)]
mod test {
    use super::*;
    use crate::Rect;

    /// Minimal object with bounds and a numeric id, used to feed the tree builder.
    struct TestObject {
        bounds: Rect<f32>,
        id: usize,
    }

    impl BoundsProvider for &TestObject {
        type Id = usize;

        fn bounds(&self) -> Rect<f32> {
            self.bounds
        }

        fn id(&self) -> Self::Id {
            self.id
        }
    }

    /// Bounds shared by every tree built in these tests.
    fn root_bounds() -> Rect<f32> {
        Rect::new(0.0, 0.0, 200.0, 200.0)
    }

    /// Creates a 10x10 test object positioned at `(x, y)`.
    fn object_at(x: f32, y: f32, id: usize) -> TestObject {
        TestObject {
            bounds: Rect::new(x, y, 10.0, 10.0),
            id,
        }
    }

    #[test]
    fn test_quad_tree() {
        // Infinite recursion prevention check (when multiple objects share the same location).
        let stacked = vec![object_at(10.0, 10.0, 0), object_at(10.0, 10.0, 1)];
        assert!(QuadTree::new(root_bounds(), stacked.iter(), 1).is_err());

        let separated = vec![object_at(10.0, 10.0, 0), object_at(20.0, 20.0, 1)];
        assert!(QuadTree::new(root_bounds(), separated.iter(), 1).is_ok());
    }

    #[test]
    fn default_for_quad_tree() {
        let tree = QuadTree::<u32>::default();

        assert_eq!(tree.split_threshold, 16);
        assert_eq!(tree.root, 0);
    }

    #[test]
    fn quad_tree_point_query() {
        // An empty tree yields nothing.
        let empty_tree = QuadTree::<f32>::default();
        let mut found = Vec::<f32>::new();

        empty_tree.point_query(Vector2::new(0.0, 0.0), &mut found);
        assert_eq!(found, vec![]);

        // A single leaf as the root.
        let mut found = Vec::<usize>::new();
        let tree = QuadTree {
            root: 0,
            nodes: vec![QuadTreeNode::Leaf {
                bounds: root_bounds(),
                ids: vec![0, 1],
            }],
            ..Default::default()
        };

        tree.point_query(Vector2::new(10.0, 10.0), &mut found);
        assert_eq!(found, vec![0, 1]);

        // A branch whose four children all refer to the same leaf.
        let mut found = Vec::<usize>::new();
        let leaf_index = 0;
        let branch_index = 1;
        let tree = QuadTree {
            root: branch_index,
            nodes: vec![
                QuadTreeNode::Leaf {
                    bounds: root_bounds(),
                    ids: vec![0, 1],
                },
                QuadTreeNode::Branch {
                    bounds: root_bounds(),
                    leaves: [leaf_index; 4],
                },
            ],
            ..Default::default()
        };

        tree.point_query(Vector2::new(10.0, 10.0), &mut found);
        assert_eq!(found, vec![0, 1, 0, 1, 0, 1, 0, 1]);
    }

    #[test]
    fn quad_tree_split_threshold() {
        let tree = QuadTree::<u32>::default();

        assert_eq!(tree.split_threshold(), tree.split_threshold);
    }

    #[test]
    fn query_storage_for_vec() {
        let mut storage = vec![1];

        let accepted = QueryStorage::try_push(&mut storage, 2);
        assert!(accepted);
        assert_eq!(storage, vec![1, 2]);

        QueryStorage::clear(&mut storage);
        assert!(storage.is_empty());
    }

    #[test]
    fn query_storage_for_array_vec() {
        let mut storage = ArrayVec::<i32, 3>::new();

        let accepted = QueryStorage::try_push(&mut storage, 1);
        assert!(accepted);
        assert!(!storage.is_empty());

        QueryStorage::clear(&mut storage);
        assert!(storage.is_empty());
    }
}