scirs2_spatial/rtree/
optimization.rs

1use crate::error::SpatialResult;
2use crate::rtree::node::{Entry, Node, RTree};
3use crate::rtree::Rectangle;
4use scirs2_core::ndarray::Array1;
5
6impl<T: Clone> RTree<T> {
7    /// Optimize the R-tree by rebuilding it with the current data
8    ///
9    /// This can significantly improve query performance by reducing overlap
10    /// and creating a more balanced tree.
11    ///
12    /// # Returns
13    ///
14    /// A `SpatialResult` containing nothing if successful
15    pub fn optimize(&mut self) -> SpatialResult<()> {
16        // Collect all data points
17        let data_points = self.collect_all_data_points()?;
18
19        if data_points.is_empty() {
20            return Ok(());
21        }
22
23        // Create a new, empty R-tree
24        // These parameters are not used in this method but would be used
25        // if we were creating a new R-tree
26        // let _ndim = self.ndim();
27        // let _min_entries = self.min_entries;
28        // let _max_entries = self.maxentries;
29
30        // Save current size to check at the end
31        let size = self.size();
32
33        // Clear current tree
34        self.clear();
35
36        // Re-insert all data points (bulk loading would be more efficient)
37        for (point, data, _) in data_points {
38            self.insert(point, data)?;
39        }
40
41        // Verify integrity
42        assert_eq!(self.size(), size, "Size mismatch after optimization");
43
44        Ok(())
45    }
46
47    /// Collect all data points in the R-tree
48    ///
49    /// # Returns
50    ///
51    /// A `SpatialResult` containing a vector of (point, data, index) tuples
52    fn collect_all_data_points(&self) -> SpatialResult<Vec<(Array1<f64>, T, usize)>> {
53        let mut points = Vec::new();
54        self.collect_data_points_recursive(&self.root, &mut points)?;
55        Ok(points)
56    }
57
58    /// Recursively collect data points from a node
59    #[allow(clippy::only_used_in_recursion)]
60    fn collect_data_points_recursive(
61        &self,
62        node: &Node<T>,
63        points: &mut Vec<(Array1<f64>, T, usize)>,
64    ) -> SpatialResult<()> {
65        for entry in &node.entries {
66            match entry {
67                Entry::Leaf { mbr, data, index } => {
68                    // For leaf entries, the MBR should be a point (min == max)
69                    points.push((mbr.min.clone(), data.clone(), *index));
70                }
71                Entry::NonLeaf { child, .. } => {
72                    // For non-leaf entries, recursively collect from children
73                    self.collect_data_points_recursive(child, points)?;
74                }
75            }
76        }
77        Ok(())
78    }
79
80    /// Perform bulk loading of the R-tree with sorted data points
81    ///
82    /// This is more efficient than inserting points one by one.
83    ///
84    /// # Arguments
85    ///
86    /// * `points` - A vector of (point, data) pairs to insert
87    ///
88    /// # Returns
89    ///
90    /// A `SpatialResult` containing a new R-tree built from the data points
91    pub fn bulk_load(
92        ndim: usize,
93        min_entries: usize,
94        max_entries: usize,
95        points: Vec<(Array1<f64>, T)>,
96    ) -> SpatialResult<Self> {
97        // Create a new, empty R-tree
98        let mut rtree = RTree::new(ndim, min_entries, max_entries)?;
99
100        if points.is_empty() {
101            return Ok(rtree);
102        }
103
104        // Implement Sort-Tile-Recursive (STR) bulk loading algorithm
105
106        // Validate all points have correct dimensions
107        for (i, (point, _)) in points.iter().enumerate() {
108            if point.len() != ndim {
109                return Err(crate::error::SpatialError::DimensionError(format!(
110                    "Point at index {} has dimension {} but tree dimension is {}",
111                    i,
112                    point.len(),
113                    ndim
114                )));
115            }
116        }
117
118        // Convert points to leaf _entries
119        let mut entries: Vec<Entry<T>> = points
120            .into_iter()
121            .enumerate()
122            .map(|(index, (point, data))| Entry::Leaf {
123                mbr: Rectangle::from_point(&point.view()),
124                data,
125                index,
126            })
127            .collect();
128
129        // Store the number of points for size tracking
130        let num_points = entries.len();
131
132        // Build the tree recursively
133        rtree.root = rtree.str_build_node(&mut entries, 0)?;
134        rtree.root._isleaf =
135            rtree.root.entries.is_empty() || matches!(rtree.root.entries[0], Entry::Leaf { .. });
136
137        // Update tree height
138        let height = rtree.calculate_height(&rtree.root);
139        for _ in 1..height {
140            rtree.increment_height();
141        }
142
143        // Update the tree size to reflect the number of data points loaded
144        for _ in 0..num_points {
145            rtree.increment_size();
146        }
147
148        Ok(rtree)
149    }
150
151    /// Build a node using the STR algorithm
152    fn str_build_node(&self, entries: &mut Vec<Entry<T>>, level: usize) -> SpatialResult<Node<T>> {
153        let n = entries.len();
154
155        if n == 0 {
156            return Ok(Node::new(level == 0, level));
157        }
158
159        // If we can fit all _entries in one node, create it
160        if n <= self.maxentries {
161            let mut node = Node::new(level == 0, level);
162            node.entries = std::mem::take(entries);
163            return Ok(node);
164        }
165
166        // Calculate the number of leaf nodes needed
167        let leaf_capacity = self.maxentries;
168        let num_leaves = n.div_ceil(leaf_capacity);
169
170        // Calculate the number of slices along each dimension
171        let slice_count = (num_leaves as f64).powf(1.0 / self.ndim() as f64).ceil() as usize;
172
173        // Sort _entries by the first dimension
174        let dim = level % self.ndim();
175        entries.sort_by(|a, b| {
176            let a_center = (a.mbr().min[dim] + a.mbr().max[dim]) / 2.0;
177            let b_center = (b.mbr().min[dim] + b.mbr().max[dim]) / 2.0;
178            a_center
179                .partial_cmp(&b_center)
180                .unwrap_or(std::cmp::Ordering::Equal)
181        });
182
183        // Create child nodes
184        let mut children = Vec::new();
185        let entries_per_slice = n.div_ceil(slice_count);
186
187        for i in 0..slice_count {
188            let start = i * entries_per_slice;
189            let end = ((i + 1) * entries_per_slice).min(n);
190
191            if start >= n {
192                break;
193            }
194
195            let mut slice_entries: Vec<Entry<T>> = entries[start..end].to_vec();
196
197            // Recursively build child nodes
198            if level == 0 {
199                // These are leaf entries, group them into leaf nodes
200                while !slice_entries.is_empty() {
201                    let mut node = Node::new(true, 0);
202                    let take_count = slice_entries.len().min(self.maxentries);
203                    node.entries = slice_entries.drain(..take_count).collect();
204
205                    if let Ok(Some(mbr)) = node.mbr() {
206                        children.push(Entry::NonLeaf {
207                            mbr,
208                            child: Box::new(node),
209                        });
210                    }
211                }
212            } else {
213                // Build non-leaf nodes recursively
214                let child_node = self.str_build_node(&mut slice_entries, level - 1)?;
215                if let Ok(Some(mbr)) = child_node.mbr() {
216                    children.push(Entry::NonLeaf {
217                        mbr,
218                        child: Box::new(child_node),
219                    });
220                }
221            }
222        }
223
224        // Clear the input _entries as they've been moved to children
225        entries.clear();
226
227        // If we have too many children, build another level
228        if children.len() > self.maxentries {
229            self.str_build_node(&mut children, level + 1)
230        } else {
231            let mut node = Node::new(false, level + 1);
232            node.entries = children;
233            Ok(node)
234        }
235    }
236
237    /// Calculate the height of the tree
238    #[allow(clippy::only_used_in_recursion)]
239    fn calculate_height(&self, node: &Node<T>) -> usize {
240        if node._isleaf {
241            1
242        } else if let Some(Entry::NonLeaf { child, .. }) = node.entries.first() {
243            1 + self.calculate_height(child)
244        } else {
245            1
246        }
247    }
248
249    /// Calculate the total overlap in the R-tree
250    ///
251    /// This is a quality metric for the tree. Lower overlap generally means
252    /// better query performance.
253    ///
254    /// # Returns
255    ///
256    /// The total overlap area between all pairs of nodes at each level
257    pub fn calculate_total_overlap(&self) -> SpatialResult<f64> {
258        let mut total_overlap = 0.0;
259
260        // Calculate overlap at each level, starting from the root
261        let mut current_level_nodes = vec![&self.root];
262
263        while !current_level_nodes.is_empty() {
264            // Calculate overlap between nodes at this level
265            for i in 0..current_level_nodes.len() - 1 {
266                let node_i_mbr = match current_level_nodes[i].mbr() {
267                    Ok(Some(mbr)) => mbr,
268                    _ => continue,
269                };
270
271                for node_j in current_level_nodes.iter().skip(i + 1) {
272                    let node_j_mbr = match node_j.mbr() {
273                        Ok(Some(mbr)) => mbr,
274                        _ => continue,
275                    };
276
277                    // Check if MBRs intersect
278                    if node_i_mbr.intersects(&node_j_mbr)? {
279                        // Calculate intersection area
280                        if let Ok(intersection) = node_i_mbr.intersection(&node_j_mbr) {
281                            total_overlap += intersection.area();
282                        }
283                    }
284                }
285            }
286
287            // Move to the next level
288            let mut next_level_nodes = Vec::new();
289            for node in current_level_nodes {
290                for entry in &node.entries {
291                    if let Entry::NonLeaf { child, .. } = entry {
292                        next_level_nodes.push(&**child);
293                    }
294                }
295            }
296
297            current_level_nodes = next_level_nodes;
298        }
299
300        Ok(total_overlap)
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    /// Ten labelled 2-D points shared by the tests; only the point labelled `4`
    /// lies inside the query box used by `assert_sample_tree`.
    fn sample_points() -> Vec<(Array1<f64>, i32)> {
        vec![
            (array![0.0, 0.0], 0),
            (array![1.0, 0.0], 1),
            (array![0.0, 1.0], 2),
            (array![1.0, 1.0], 3),
            (array![0.5, 0.5], 4),
            (array![2.0, 2.0], 5),
            (array![3.0, 3.0], 6),
            (array![4.0, 4.0], 7),
            (array![5.0, 5.0], 8),
            (array![6.0, 6.0], 9),
        ]
    }

    /// Checks that `rtree` holds all sample points and that a range query
    /// around (0.5, 0.5) returns exactly the point labelled `4`.
    fn assert_sample_tree(rtree: &RTree<i32>) {
        // All data must be present
        assert_eq!(rtree.size(), 10);

        // A small box around (0.5, 0.5) contains a single sample point
        let lower = array![0.4, 0.4];
        let upper = array![0.6, 0.6];
        let results = rtree.search_range(&lower.view(), &upper.view()).unwrap();

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].1, 4);
    }

    #[test]
    fn test_rtree_optimize() {
        // Populate a tree one insertion at a time
        let mut rtree: RTree<i32> = RTree::new(2, 2, 4).unwrap();
        for (point, value) in sample_points() {
            rtree.insert(point, value).unwrap();
        }

        // Rebuild it in place
        rtree.optimize().unwrap();

        // The rebuilt tree must still answer queries correctly
        assert_sample_tree(&rtree);
    }

    #[test]
    fn test_rtree_bulk_load() {
        // Build the tree in one go from the sample points
        let rtree = RTree::bulk_load(2, 2, 4, sample_points()).unwrap();

        // The bulk-loaded tree must contain and find the data
        assert_sample_tree(&rtree);
    }
}
377}