scirs2_spatial/rtree/
optimization.rs

1use crate::error::SpatialResult;
2use crate::rtree::node::{Entry, Node, RTree};
3use crate::rtree::Rectangle;
4use ndarray::Array1;
5
6impl<T: Clone> RTree<T> {
7    /// Optimize the R-tree by rebuilding it with the current data
8    ///
9    /// This can significantly improve query performance by reducing overlap
10    /// and creating a more balanced tree.
11    ///
12    /// # Returns
13    ///
14    /// A `SpatialResult` containing nothing if successful
15    pub fn optimize(&mut self) -> SpatialResult<()> {
16        // Collect all data points
17        let data_points = self.collect_all_data_points()?;
18
19        if data_points.is_empty() {
20            return Ok(());
21        }
22
23        // Create a new, empty R-tree
24        // These parameters are not used in this method but would be used
25        // if we were creating a new R-tree
26        // let _ndim = self.ndim();
27        // let _min_entries = self.min_entries;
28        // let _max_entries = self.maxentries;
29
30        // Save current size to check at the end
31        let size = self.size();
32
33        // Clear current tree
34        self.clear();
35
36        // Re-insert all data points (bulk loading would be more efficient)
37        for (point, data, _) in data_points {
38            self.insert(point, data)?;
39        }
40
41        // Verify integrity
42        assert_eq!(self.size(), size, "Size mismatch after optimization");
43
44        Ok(())
45    }
46
47    /// Collect all data points in the R-tree
48    ///
49    /// # Returns
50    ///
51    /// A `SpatialResult` containing a vector of (point, data, index) tuples
52    fn collect_all_data_points(&self) -> SpatialResult<Vec<(Array1<f64>, T, usize)>> {
53        let mut points = Vec::new();
54        self.collect_data_points_recursive(&self.root, &mut points)?;
55        Ok(points)
56    }
57
58    /// Recursively collect data points from a node
59    #[allow(clippy::only_used_in_recursion)]
60    fn collect_data_points_recursive(
61        &self,
62        node: &Node<T>,
63        points: &mut Vec<(Array1<f64>, T, usize)>,
64    ) -> SpatialResult<()> {
65        for entry in &node.entries {
66            match entry {
67                Entry::Leaf { mbr, data, index } => {
68                    // For leaf entries, the MBR should be a point (min == max)
69                    points.push((mbr.min.clone(), data.clone(), *index));
70                }
71                Entry::NonLeaf { child, .. } => {
72                    // For non-leaf entries, recursively collect from children
73                    self.collect_data_points_recursive(child, points)?;
74                }
75            }
76        }
77        Ok(())
78    }
79
80    /// Perform bulk loading of the R-tree with sorted data points
81    ///
82    /// This is more efficient than inserting points one by one.
83    ///
84    /// # Arguments
85    ///
86    /// * `points` - A vector of (point, data) pairs to insert
87    ///
88    /// # Returns
89    ///
90    /// A `SpatialResult` containing a new R-tree built from the data points
91    pub fn bulk_load(
92        ndim: usize,
93        min_entries: usize,
94        max_entries: usize,
95        points: Vec<(Array1<f64>, T)>,
96    ) -> SpatialResult<Self> {
97        // Create a new, empty R-tree
98        let mut rtree = RTree::new(ndim, min_entries, max_entries)?;
99
100        if points.is_empty() {
101            return Ok(rtree);
102        }
103
104        // Implement Sort-Tile-Recursive (STR) bulk loading algorithm
105
106        // Validate all points have correct dimensions
107        for (i, (point, _)) in points.iter().enumerate() {
108            if point.len() != ndim {
109                return Err(crate::error::SpatialError::DimensionError(format!(
110                    "Point at index {} has dimension {} but tree dimension is {}",
111                    i,
112                    point.len(),
113                    ndim
114                )));
115            }
116        }
117
118        // Convert points to leaf _entries
119        let mut entries: Vec<Entry<T>> = points
120            .into_iter()
121            .enumerate()
122            .map(|(index, (point, data))| Entry::Leaf {
123                mbr: Rectangle::from_point(&point.view()),
124                data,
125                index,
126            })
127            .collect();
128
129        // Build the tree recursively
130        rtree.root = rtree.str_build_node(&mut entries, 0)?;
131        rtree.root._isleaf =
132            rtree.root.entries.is_empty() || matches!(rtree.root.entries[0], Entry::Leaf { .. });
133
134        // Update tree height
135        let height = rtree.calculate_height(&rtree.root);
136        for _ in 1..height {
137            rtree.increment_height();
138        }
139
140        Ok(rtree)
141    }
142
143    /// Build a node using the STR algorithm
144    fn str_build_node(&self, entries: &mut Vec<Entry<T>>, level: usize) -> SpatialResult<Node<T>> {
145        let n = entries.len();
146
147        if n == 0 {
148            return Ok(Node::new(level == 0, level));
149        }
150
151        // If we can fit all _entries in one node, create it
152        if n <= self.maxentries {
153            let mut node = Node::new(level == 0, level);
154            node.entries = std::mem::take(entries);
155            return Ok(node);
156        }
157
158        // Calculate the number of leaf nodes needed
159        let leaf_capacity = self.maxentries;
160        let num_leaves = n.div_ceil(leaf_capacity);
161
162        // Calculate the number of slices along each dimension
163        let slice_count = (num_leaves as f64).powf(1.0 / self.ndim() as f64).ceil() as usize;
164
165        // Sort _entries by the first dimension
166        let dim = level % self.ndim();
167        entries.sort_by(|a, b| {
168            let a_center = (a.mbr().min[dim] + a.mbr().max[dim]) / 2.0;
169            let b_center = (b.mbr().min[dim] + b.mbr().max[dim]) / 2.0;
170            a_center
171                .partial_cmp(&b_center)
172                .unwrap_or(std::cmp::Ordering::Equal)
173        });
174
175        // Create child nodes
176        let mut children = Vec::new();
177        let entries_per_slice = n.div_ceil(slice_count);
178
179        for i in 0..slice_count {
180            let start = i * entries_per_slice;
181            let end = ((i + 1) * entries_per_slice).min(n);
182
183            if start >= n {
184                break;
185            }
186
187            let mut slice_entries: Vec<Entry<T>> = entries[start..end].to_vec();
188
189            // Recursively build child nodes
190            if level == 0 {
191                // These are leaf entries, group them into leaf nodes
192                while !slice_entries.is_empty() {
193                    let mut node = Node::new(true, 0);
194                    let take_count = slice_entries.len().min(self.maxentries);
195                    node.entries = slice_entries.drain(..take_count).collect();
196
197                    if let Ok(Some(mbr)) = node.mbr() {
198                        children.push(Entry::NonLeaf {
199                            mbr,
200                            child: Box::new(node),
201                        });
202                    }
203                }
204            } else {
205                // Build non-leaf nodes recursively
206                let child_node = self.str_build_node(&mut slice_entries, level - 1)?;
207                if let Ok(Some(mbr)) = child_node.mbr() {
208                    children.push(Entry::NonLeaf {
209                        mbr,
210                        child: Box::new(child_node),
211                    });
212                }
213            }
214        }
215
216        // Clear the input _entries as they've been moved to children
217        entries.clear();
218
219        // If we have too many children, build another level
220        if children.len() > self.maxentries {
221            self.str_build_node(&mut children, level + 1)
222        } else {
223            let mut node = Node::new(false, level + 1);
224            node.entries = children;
225            Ok(node)
226        }
227    }
228
229    /// Calculate the height of the tree
230    #[allow(clippy::only_used_in_recursion)]
231    fn calculate_height(&self, node: &Node<T>) -> usize {
232        if node._isleaf {
233            1
234        } else if let Some(Entry::NonLeaf { child, .. }) = node.entries.first() {
235            1 + self.calculate_height(child)
236        } else {
237            1
238        }
239    }
240
241    /// Calculate the total overlap in the R-tree
242    ///
243    /// This is a quality metric for the tree. Lower overlap generally means
244    /// better query performance.
245    ///
246    /// # Returns
247    ///
248    /// The total overlap area between all pairs of nodes at each level
249    pub fn calculate_total_overlap(&self) -> SpatialResult<f64> {
250        let mut total_overlap = 0.0;
251
252        // Calculate overlap at each level, starting from the root
253        let mut current_level_nodes = vec![&self.root];
254
255        while !current_level_nodes.is_empty() {
256            // Calculate overlap between nodes at this level
257            for i in 0..current_level_nodes.len() - 1 {
258                let node_i_mbr = match current_level_nodes[i].mbr() {
259                    Ok(Some(mbr)) => mbr,
260                    _ => continue,
261                };
262
263                for node_j in current_level_nodes.iter().skip(i + 1) {
264                    let node_j_mbr = match node_j.mbr() {
265                        Ok(Some(mbr)) => mbr,
266                        _ => continue,
267                    };
268
269                    // Check if MBRs intersect
270                    if node_i_mbr.intersects(&node_j_mbr)? {
271                        // Calculate intersection area
272                        if let Ok(intersection) = node_i_mbr.intersection(&node_j_mbr) {
273                            total_overlap += intersection.area();
274                        }
275                    }
276                }
277            }
278
279            // Move to the next level
280            let mut next_level_nodes = Vec::new();
281            for node in current_level_nodes {
282                for entry in &node.entries {
283                    if let Entry::NonLeaf { child, .. } = entry {
284                        next_level_nodes.push(&**child);
285                    }
286                }
287            }
288
289            current_level_nodes = next_level_nodes;
290        }
291
292        Ok(total_overlap)
293    }
294}
295
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Ten 2-D points paired with integer payloads, shared by both tests.
    fn sample_points() -> Vec<(Array1<f64>, i32)> {
        vec![
            (array![0.0, 0.0], 0),
            (array![1.0, 0.0], 1),
            (array![0.0, 1.0], 2),
            (array![1.0, 1.0], 3),
            (array![0.5, 0.5], 4),
            (array![2.0, 2.0], 5),
            (array![3.0, 3.0], 6),
            (array![4.0, 4.0], 7),
            (array![5.0, 5.0], 8),
            (array![6.0, 6.0], 9),
        ]
    }

    /// Checks that the tree reports all ten sample points and that a range
    /// search around (0.5, 0.5) returns exactly the entry with payload 4.
    fn assert_holds_sample_points(rtree: &RTree<i32>) {
        // All data must be present.
        assert_eq!(rtree.size(), 10);

        // Only the point (0.5, 0.5) lies inside this box.
        let results = rtree
            .search_range(&array![0.4, 0.4].view(), &array![0.6, 0.6].view())
            .unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, 4);
    }

    #[test]
    fn test_rtree_optimize() {
        // Populate a tree by plain insertion.
        let mut rtree: RTree<i32> = RTree::new(2, 2, 4).unwrap();
        for (point, value) in sample_points() {
            rtree.insert(point, value).unwrap();
        }

        // Rebuilding must keep the contents intact.
        rtree.optimize().unwrap();

        assert_holds_sample_points(&rtree);
    }

    // NOTE(review): this test is ignored; the reason is not recorded in this
    // file.
    #[test]
    #[ignore]
    fn test_rtree_bulk_load() {
        // Build the tree from the whole batch in one call.
        let rtree = RTree::bulk_load(2, 2, 4, sample_points()).unwrap();

        assert_holds_sample_points(&rtree);
    }
}