sif_rtree/
build.rs

1use std::marker::PhantomData;
2use std::num::NonZeroUsize;
3
4use crate::{iter::twig_len_pad, Node, Object, Point, RTree, TWIG_LEN};
5
/// A sensible default value for the node length, balancing query efficiency against memory overhead
pub const DEF_NODE_LEN: usize = 6;
8
9impl<O> RTree<O>
10where
11    O: Object,
12{
13    /// Builds a new [R-tree](https://en.wikipedia.org/wiki/R-tree) from a given set of `objects`
14    ///
15    /// The `node_len` parameter determines the length of branch nodes and thereby the three depth. It must be larger than one. [`DEF_NODE_LEN`] provides a sensible default.
16    ///
17    /// The `objects` parameter must not be empty.
18    pub fn new(node_len: usize, objects: Vec<O>) -> Self {
19        assert!(node_len > 1);
20        assert!(!objects.is_empty());
21
22        let mut nodes = Vec::new();
23        let mut next_nodes = Vec::new();
24
25        let root_idx = build(node_len, objects, &mut nodes, &mut next_nodes);
26        debug_assert_eq!(root_idx, nodes.len() - 1);
27
28        // The whole tree is reversed, so that iteration visits increasing memory addresses which measurably improves performance.
29        nodes.reverse();
30
31        for node in &mut nodes {
32            if let Node::Twig(twig) = node {
33                for idx in twig {
34                    *idx = root_idx - *idx;
35                }
36            }
37        }
38
39        Self {
40            nodes: nodes.into_boxed_slice(),
41            _marker: PhantomData,
42        }
43    }
44}
45
46/// A reimplementation of the overlap-minimizing top-down bulk loading algorithm used by the [`rstar`] crate
47///
48/// For a given value of `node_len` (which is equivalent to [`rstar::RTreeParams::MAX_SIZE`]) and a given list of `objects`, it should produce the same tree structure.
49fn build<O>(
50    node_len: usize,
51    objects: Vec<O>,
52    nodes: &mut Vec<Node<O>>,
53    next_nodes: &mut Vec<usize>,
54) -> usize
55where
56    O: Object,
57{
58    let next_nodes_len = next_nodes.len();
59
60    if objects.len() > node_len {
61        let num_clusters = num_clusters(node_len, O::Point::DIM, objects.len()).max(2);
62
63        struct State<O> {
64            objects: Vec<O>,
65            axis: usize,
66        }
67
68        let mut state = vec![State {
69            objects,
70            axis: O::Point::DIM,
71        }];
72
73        while let Some(State {
74            mut objects,
75            mut axis,
76        }) = state.pop()
77        {
78            if axis != 0 {
79                axis -= 1;
80
81                let cluster_len = (objects.len() + num_clusters - 1) / num_clusters;
82
83                while objects.len() > cluster_len {
84                    objects.select_nth_unstable_by(cluster_len, |lhs, rhs| {
85                        let lhs = lhs.aabb().0.coord(axis);
86                        let rhs = rhs.aabb().0.coord(axis);
87                        lhs.partial_cmp(&rhs).unwrap()
88                    });
89
90                    let next_objects = objects.split_off(cluster_len);
91                    state.push(State { objects, axis });
92                    objects = next_objects;
93                }
94
95                if !objects.is_empty() {
96                    state.push(State { objects, axis });
97                }
98            } else {
99                let node = build(node_len, objects, nodes, next_nodes);
100                next_nodes.push(node);
101            }
102        }
103    } else {
104        next_nodes.extend(nodes.len()..nodes.len() + objects.len());
105        nodes.extend(objects.into_iter().map(Node::Leaf));
106    }
107
108    let node = add_branch(nodes, &next_nodes[next_nodes_len..]);
109    next_nodes.truncate(next_nodes_len);
110    node
111}
112
/// Computes into how many clusters `num_objects` objects are split along each of the `point_dim` axes when building a tree with branches of length `node_len`
///
/// The number of subtrees required below the root is determined first and the result is the smallest integer whose `point_dim`-th power covers that number.
// NOTE(review): the computation is presumably kept in `f32` to match the arithmetic of `rstar` bit for bit; confirm before changing the precision.
fn num_clusters(node_len: usize, point_dim: usize, num_objects: usize) -> usize {
    let fanout = node_len as f32;
    let total = num_objects as f32;

    // Number of levels required to hold all objects at the given fanout.
    let levels = total.log(fanout).ceil() as i32;

    // Capacity of a single subtree directly below the root and the number of such subtrees.
    let per_subtree = fanout.powi(levels - 1);
    let subtrees = (total / per_subtree).ceil();

    subtrees.powf(1.0 / point_dim as f32).ceil() as usize
}
125
126fn add_branch<O>(nodes: &mut Vec<Node<O>>, next_nodes: &[usize]) -> usize
127where
128    O: Object,
129{
130    let len = NonZeroUsize::new(next_nodes.len()).unwrap();
131
132    let aabb = merge_aabb(nodes, next_nodes);
133
134    {
135        // Padding is inserted into the first twig, so that iteration is uniform over the following twigs.
136        let (len, pad) = twig_len_pad(&len);
137
138        nodes.reserve(len + 1);
139
140        let mut twig = [0; TWIG_LEN];
141        let mut pos = TWIG_LEN;
142
143        // The twigs in the branch are reversed, so that after reversing the whole tree, they will follow the branch in ascending order.
144        for next_node in next_nodes.iter().rev() {
145            pos -= 1;
146            twig[pos] = *next_node;
147
148            if pos == 0 {
149                nodes.push(Node::Twig(twig));
150                pos = TWIG_LEN;
151            }
152        }
153
154        if pos != TWIG_LEN {
155            debug_assert_eq!(pos, pad);
156            nodes.push(Node::Twig(twig));
157        }
158    }
159
160    let node = nodes.len();
161    nodes.push(Node::Branch { len, aabb });
162    node
163}
164
165fn merge_aabb<O>(nodes: &[Node<O>], next_nodes: &[usize]) -> (O::Point, O::Point)
166where
167    O: Object,
168{
169    next_nodes
170        .iter()
171        .map(|idx| match &nodes[*idx] {
172            Node::Branch { aabb, .. } => aabb.clone(),
173            Node::Twig(_) => unreachable!(),
174            Node::Leaf(obj) => obj.aabb(),
175        })
176        .reduce(|mut res, aabb| {
177            res.0 = res.0.min(&aabb.0);
178            res.1 = res.1.max(&aabb.1);
179
180            res
181        })
182        .unwrap()
183}
184
#[cfg(test)]
mod tests {
    use super::*;

    use std::ops::ControlFlow;

    use proptest::test_runner::TestRunner;

    use crate::{
        iter::branch_for_each,
        tests::{random_objects, RandomObject},
    };

    // Lets `rstar` index the very same objects, so that the resulting tree structures can be compared.
    impl rstar::RTreeObject for RandomObject {
        type Envelope = rstar::AABB<[f32; 3]>;

        fn envelope(&self) -> Self::Envelope {
            rstar::AABB::from_corners(self.0, self.1)
        }
    }

    /// Records the structure below the branch at `idx` in depth-first order
    ///
    /// Each branch contributes its length to `branches`, each leaf contributes a zero to `branches` and its object to `leaves`.
    fn collect_index<'a>(
        nodes: &'a [Node<RandomObject>],
        idx: usize,
        branches: &mut Vec<usize>,
        leaves: &mut Vec<&'a RandomObject>,
    ) {
        // Traversal must always start at a branch which is followed by its twigs.
        let [Node::Branch { len, .. }, twigs @ ..] = &nodes[idx..] else {
            unreachable!()
        };

        branches.push(len.get());

        branch_for_each(len, twigs, |child| {
            match &nodes[child] {
                Node::Leaf(obj) => {
                    branches.push(0);
                    leaves.push(obj);
                }
                Node::Branch { .. } => collect_index(nodes, child, branches, leaves),
                Node::Twig(_) => unreachable!(),
            }
            ControlFlow::<()>::Continue(())
        })
        .continue_value()
        .unwrap();
    }

    /// Records the structure below the `rstar` parent `node` using the same encoding as `collect_index`
    fn collect_rstar_index<'a>(
        node: &'a rstar::ParentNode<RandomObject>,
        branches: &mut Vec<usize>,
        leaves: &mut Vec<&'a RandomObject>,
    ) {
        let children = node.children();
        branches.push(children.len());

        for child in children {
            match child {
                rstar::RTreeNode::Leaf(obj) => {
                    branches.push(0);
                    leaves.push(obj);
                }
                rstar::RTreeNode::Parent(parent) => collect_rstar_index(parent, branches, leaves),
            }
        }
    }

    /// Checks that bulk loading random objects yields the same tree structure as `rstar`
    #[test]
    fn random_trees() {
        TestRunner::default()
            .run(&random_objects(100), |objects| {
                let ours = RTree::new(DEF_NODE_LEN, objects.clone());

                let mut our_branches = Vec::new();
                let mut our_leaves = Vec::new();

                collect_index(&ours, 0, &mut our_branches, &mut our_leaves);

                let theirs = rstar::RTree::bulk_load(objects);

                let mut their_branches = Vec::new();
                let mut their_leaves = Vec::new();

                collect_rstar_index(theirs.root(), &mut their_branches, &mut their_leaves);

                assert_eq!(our_branches, their_branches);
                assert_eq!(our_leaves, their_leaves);

                Ok(())
            })
            .unwrap();
    }
}