scirs2_spatial/rtree/
query.rs

1use crate::error::SpatialResult;
2use crate::rtree::node::{Entry, EntryWithDistance, Node, RTree, Rectangle};
3use scirs2_core::ndarray::ArrayView1;
4use std::cmp::Ordering;
5use std::collections::BinaryHeap;
6
7impl<T: Clone> RTree<T> {
8    /// Search for data points within a range
9    ///
10    /// # Arguments
11    ///
12    /// * `min` - Minimum coordinates of the search range
13    /// * `max` - Maximum coordinates of the search range
14    ///
15    /// # Returns
16    ///
17    /// A `SpatialResult` containing a vector of (index, data) pairs for data points within the range,
18    /// or an error if the range has invalid dimensions
19    pub fn search_range(
20        &self,
21        min: &ArrayView1<f64>,
22        max: &ArrayView1<f64>,
23    ) -> SpatialResult<Vec<(usize, T)>> {
24        if min.len() != self.ndim() || max.len() != self.ndim() {
25            return Err(crate::error::SpatialError::DimensionError(format!(
26                "Search range dimensions ({}, {}) do not match RTree dimension {}",
27                min.len(),
28                max.len(),
29                self.ndim()
30            )));
31        }
32
33        // Create a search rectangle
34        let rect = Rectangle::new(min.to_owned(), max.to_owned())?;
35
36        // Perform the search
37        let mut results = Vec::new();
38        self.search_range_internal(&rect, &self.root, &mut results)?;
39
40        Ok(results)
41    }
42
43    /// Recursively search for points within a range
44    #[allow(clippy::only_used_in_recursion)]
45    fn search_range_internal(
46        &self,
47        rect: &Rectangle,
48        node: &Node<T>,
49        results: &mut Vec<(usize, T)>,
50    ) -> SpatialResult<()> {
51        // Process each entry in the node
52        for entry in &node.entries {
53            // Check if this entry's MBR intersects with the search rectangle
54            if entry.mbr().intersects(rect)? {
55                match entry {
56                    // If this is a leaf entry, add the data to the results
57                    Entry::Leaf { data, index, .. } => {
58                        results.push((*index, data.clone()));
59                    }
60                    // If this is a non-leaf entry, recursively search its child
61                    Entry::NonLeaf { child, .. } => {
62                        self.search_range_internal(rect, child, results)?;
63                    }
64                }
65            }
66        }
67
68        Ok(())
69    }
70
71    /// Find the k nearest neighbors to a query point
72    ///
73    /// # Arguments
74    ///
75    /// * `point` - The query point
76    /// * `k` - The number of nearest neighbors to find
77    ///
78    /// # Returns
79    ///
80    /// A `SpatialResult` containing a vector of (index, data, distance) tuples for the k nearest data points,
81    /// sorted by distance (closest first), or an error if the point has invalid dimensions
82    pub fn nearest(
83        &self,
84        point: &ArrayView1<f64>,
85        k: usize,
86    ) -> SpatialResult<Vec<(usize, T, f64)>> {
87        if point.len() != self.ndim() {
88            return Err(crate::error::SpatialError::DimensionError(format!(
89                "Point dimension {} does not match RTree dimension {}",
90                point.len(),
91                self.ndim()
92            )));
93        }
94
95        if k == 0 || self.is_empty() {
96            return Ok(Vec::new());
97        }
98
99        // Use a priority queue to keep track of nodes to visit
100        let mut pq = BinaryHeap::new();
101        let mut results = Vec::new();
102
103        // Initialize with root node
104        if let Ok(Some(root_mbr)) = self.root.mbr() {
105            let _distance = root_mbr.min_distance_to_point(point)?;
106
107            // Add all entries from the root
108            for entry in &self.root.entries {
109                let entry_distance = entry.mbr().min_distance_to_point(point)?;
110                pq.push(EntryWithDistance {
111                    entry: entry.clone(),
112                    distance: entry_distance,
113                });
114            }
115        }
116
117        // Current maximum distance in the result set
118        let mut max_distance = f64::MAX;
119
120        // Process the priority queue
121        while let Some(item) = pq.pop() {
122            // If the minimum distance is greater than our current maximum, we can stop
123            if item.distance > max_distance && results.len() >= k {
124                break;
125            }
126
127            match item.entry {
128                // If this is a leaf entry, add it to the results
129                Entry::Leaf { data, index, .. } => {
130                    results.push((index, data, item.distance));
131
132                    // Update max_distance if we have enough results
133                    if results.len() >= k {
134                        // Sort results by distance
135                        results.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(Ordering::Equal));
136
137                        // Keep only the k closest
138                        results.truncate(k);
139
140                        // Update max_distance
141                        if let Some((_, _, dist)) = results.last() {
142                            max_distance = *dist;
143                        }
144                    }
145                }
146                // If this is a non-leaf entry, add its children to the queue
147                Entry::NonLeaf { child, .. } => {
148                    for entry in &child.entries {
149                        let entry_distance = entry.mbr().min_distance_to_point(point)?;
150
151                        // Only add entries that could be closer than our current maximum
152                        if entry_distance <= max_distance || results.len() < k {
153                            pq.push(EntryWithDistance {
154                                entry: entry.clone(),
155                                distance: entry_distance,
156                            });
157                        }
158                    }
159                }
160            }
161        }
162
163        // Sort final results by distance
164        results.sort_by(|a, b| a.2.partial_cmp(&b.2).unwrap_or(Ordering::Equal));
165
166        // Truncate to k results
167        results.truncate(k);
168
169        Ok(results)
170    }
171
172    /// Perform a spatial join between this R-tree and another
173    ///
174    /// # Arguments
175    ///
176    /// * `other` - The other R-tree to join with
177    /// * `predicate` - A function that takes MBRs from both trees and returns true
178    ///   if they should be joined, e.g., for an intersection join: `|mbr1, mbr2| mbr1.intersects(mbr2)`
179    ///
180    /// # Returns
181    ///
182    /// A `SpatialResult` containing a vector of pairs of data from both trees that satisfy the predicate,
183    /// or an error if the R-trees have different dimensions
184    pub fn spatial_join<U, P>(&self, other: &RTree<U>, predicate: P) -> SpatialResult<Vec<(T, U)>>
185    where
186        U: Clone,
187        P: Fn(&Rectangle, &Rectangle) -> SpatialResult<bool>,
188    {
189        if self.ndim() != other.ndim() {
190            return Err(crate::error::SpatialError::DimensionError(format!(
191                "RTrees have different dimensions: {} and {}",
192                self.ndim(),
193                other.ndim()
194            )));
195        }
196
197        let mut results = Vec::new();
198
199        // If either tree is empty, return an empty result
200        if self.is_empty() || other.is_empty() {
201            return Ok(results);
202        }
203
204        // Perform the join
205        self.spatial_join_internal(&self.root, &other.root, &predicate, &mut results)?;
206
207        Ok(results)
208    }
209
210    /// Recursively perform a spatial join between two nodes
211    #[allow(clippy::only_used_in_recursion)]
212    fn spatial_join_internal<U, P>(
213        &self,
214        node1: &Node<T>,
215        node2: &Node<U>,
216        predicate: &P,
217        results: &mut Vec<(T, U)>,
218    ) -> SpatialResult<()>
219    where
220        U: Clone,
221        P: Fn(&Rectangle, &Rectangle) -> SpatialResult<bool>,
222    {
223        // Process each pair of entries
224        for entry1 in &node1.entries {
225            for entry2 in &node2.entries {
226                // Check if the entries satisfy the predicate
227                if predicate(entry1.mbr(), entry2.mbr())? {
228                    match (entry1, entry2) {
229                        // If both are leaf entries, add to results
230                        (Entry::Leaf { data: data1, .. }, Entry::Leaf { data: data2, .. }) => {
231                            results.push((data1.clone(), data2.clone()));
232                        }
233                        // If entry1 is a non-leaf, recurse with its children
234                        (Entry::NonLeaf { child: child1, .. }, Entry::Leaf { .. }) => {
235                            self.spatial_join_internal(
236                                child1,
237                                &Node {
238                                    entries: vec![entry2.clone()],
239                                    _isleaf: true,
240                                    level: 0,
241                                },
242                                predicate,
243                                results,
244                            )?;
245                        }
246                        // If entry2 is a non-leaf, recurse with its children
247                        (Entry::Leaf { .. }, Entry::NonLeaf { child: child2, .. }) => {
248                            self.spatial_join_internal(
249                                &Node {
250                                    entries: vec![entry1.clone()],
251                                    _isleaf: true,
252                                    level: 0,
253                                },
254                                child2,
255                                predicate,
256                                results,
257                            )?;
258                        }
259                        // If both are non-leaf entries, recurse with both children
260                        (
261                            Entry::NonLeaf { child: child1, .. },
262                            Entry::NonLeaf { child: child2, .. },
263                        ) => {
264                            self.spatial_join_internal(child1, child2, predicate, results)?;
265                        }
266                    }
267                }
268            }
269        }
270
271        Ok(())
272    }
273}
274
275#[cfg(test)]
276mod tests {
277    use super::*;
278    use approx::assert_relative_eq;
279    use scirs2_core::ndarray::array;
280
281    #[test]
282    fn test_rtree_nearest_neighbors() {
283        // Create a new R-tree
284        let mut rtree: RTree<i32> = RTree::new(2, 2, 4).unwrap();
285
286        // Insert some points
287        let points = vec![
288            (array![0.0, 0.0], 0),
289            (array![1.0, 0.0], 1),
290            (array![0.0, 1.0], 2),
291            (array![1.0, 1.0], 3),
292            (array![0.5, 0.5], 4),
293            (array![2.0, 2.0], 5),
294            (array![3.0, 3.0], 6),
295            (array![4.0, 4.0], 7),
296            (array![5.0, 5.0], 8),
297            (array![6.0, 6.0], 9),
298        ];
299
300        for (point, value) in points {
301            rtree.insert(point, value).unwrap();
302        }
303
304        // Find the nearest neighbor to (0.6, 0.6)
305        let nn_results = rtree.nearest(&array![0.6, 0.6].view(), 1).unwrap();
306
307        // Should be (0.5, 0.5)
308        assert_eq!(nn_results.len(), 1);
309        assert_eq!(nn_results[0].1, 4);
310
311        // Find the 3 nearest neighbors to (0.0, 0.0)
312        let nn_results = rtree.nearest(&array![0.0, 0.0].view(), 3).unwrap();
313
314        // Should be (0.0, 0.0), (1.0, 0.0), and (0.0, 1.0)
315        assert_eq!(nn_results.len(), 3);
316
317        // The results should be sorted by distance
318        assert_eq!(nn_results[0].1, 0); // (0.0, 0.0) - distance 0
319        assert_eq!(nn_results[1].1, 4); // (0.5, 0.5) - distance ~0.707
320
321        // The third one could be either (1.0, 0.0) or (0.0, 1.0) - distance 1.0
322        assert!(nn_results[2].1 == 1 || nn_results[2].1 == 2);
323
324        // Check distances
325        assert_relative_eq!(nn_results[0].2, 0.0);
326        assert_relative_eq!(
327            nn_results[1].2,
328            (0.5_f64.powi(2) + 0.5_f64.powi(2)).sqrt(),
329            epsilon = 1e-10
330        );
331        assert_relative_eq!(nn_results[2].2, 1.0);
332
333        // Test k=0
334        let nn_empty = rtree.nearest(&array![0.0, 0.0].view(), 0).unwrap();
335        assert_eq!(nn_empty.len(), 0);
336
337        // Test k > size
338        let nn_all = rtree.nearest(&array![0.0, 0.0].view(), 20).unwrap();
339        assert_eq!(nn_all.len(), 10); // Should return all points
340    }
341
342    #[test]
343    fn test_rtree_spatial_join() {
344        // Create two R-trees
345        let mut rtree1: RTree<i32> = RTree::new(2, 2, 4).unwrap();
346        let mut rtree2: RTree<char> = RTree::new(2, 2, 4).unwrap();
347
348        // Insert rectangles into the first R-tree
349        let rectangles1 = vec![
350            (array![0.0, 0.0], array![0.6, 0.6], 0),
351            (array![0.4, 0.0], array![1.0, 0.6], 1),
352            (array![0.0, 0.4], array![0.6, 1.0], 2),
353            (array![0.4, 0.4], array![1.0, 1.0], 3),
354        ];
355
356        for (min_corner, max_corner, value) in rectangles1 {
357            rtree1
358                .insert_rectangle(min_corner, max_corner, value)
359                .unwrap();
360        }
361
362        // Insert rectangles into the second R-tree
363        let rectangles2 = vec![
364            (array![0.3, 0.3], array![0.7, 0.7], 'A'),
365            (array![0.8, 0.3], array![1.2, 0.7], 'B'),
366            (array![0.3, 0.8], array![0.7, 1.2], 'C'),
367            (array![0.8, 0.8], array![1.2, 1.2], 'D'),
368        ];
369
370        for (min_corner, max_corner, value) in rectangles2 {
371            rtree2
372                .insert_rectangle(min_corner, max_corner, value)
373                .unwrap();
374        }
375
376        // Perform a spatial join with an intersection predicate
377        let join_results = rtree1
378            .spatial_join(&rtree2, |mbr1, mbr2| mbr1.intersects(mbr2))
379            .unwrap();
380
381        // There should be multiple pairs since several rectangles intersect
382        assert!(
383            !join_results.is_empty(),
384            "Expected spatial join to find intersecting rectangles"
385        );
386
387        // With the given rectangles:
388        // Rectangle A [0.3,0.3]x[0.7,0.7] intersects with all 4 rectangles (0,1,2,3)
389        // Rectangle B [0.8,0.3]x[1.2,0.7] intersects with rectangles 1 and 3
390        // Rectangle C [0.3,0.8]x[0.7,1.2] intersects with rectangles 2 and 3
391        // Rectangle D [0.8,0.8]x[1.2,1.2] intersects with rectangle 3
392        // Total expected intersections: 4 + 2 + 2 + 1 = 9
393        assert_eq!(
394            join_results.len(),
395            9,
396            "Expected 9 intersections, found {}",
397            join_results.len()
398        );
399
400        // Test a more restrictive join predicate (contains)
401        let strict_join_results = rtree1
402            .spatial_join(&rtree2, |mbr1, mbr2| mbr1.contains_rectangle(mbr2))
403            .unwrap();
404
405        // Should be fewer results than with just intersection
406        assert!(strict_join_results.len() <= join_results.len());
407    }
408}