Skip to main content

nodedb_spatial/
spatial_join.rs

1//! Spatial join strategy: R-tree probe join.
2//!
3//! When two collections are joined on a spatial predicate
4//! (`ST_Intersects(a.geom, b.geom)`), the naive approach is nested-loop
5//! O(N*M). With an R-tree index on one side, we get O(N * log M):
6//!
7//! 1. Build R-tree on the smaller collection (or use existing index)
8//! 2. For each geometry in the larger collection, R-tree range search
9//! 3. For each candidate pair, apply exact predicate
10//!
11//! This module provides the join logic; the planner decides which side
12//! to index based on collection cardinality.
13
14use nodedb_types::geometry::Geometry;
15use nodedb_types::{BoundingBox, geometry_bbox};
16
17use crate::predicates;
18use crate::rtree::{RTree, RTreeEntry};
19
/// Result of a spatial join between two collections.
///
/// Carries the matched ID pairs plus two counters describing how much work
/// the join performed.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SpatialJoinResult {
    /// Matched pairs: `(probe_entry_id, indexed_entry_id)`.
    ///
    /// The first element is the ID from the probe side, the second is the ID
    /// of the R-tree entry it matched.
    pub pairs: Vec<(u64, u64)>,
    /// Number of R-tree probes performed (one per probe entry).
    pub probes: usize,
    /// Number of exact predicate evaluations (after R-tree filter).
    pub exact_evals: usize,
}
29
30/// Execute a spatial join using R-tree probe.
31///
32/// `indexed_side`: R-tree built on one collection.
33/// `probe_side`: entries from the other collection to probe against.
34/// `get_geometry`: callback to retrieve the full geometry for an entry ID
35///   (needed for exact predicate evaluation after R-tree bbox filter).
36/// `predicate`: which spatial predicate to apply (intersects, contains, etc.).
37pub fn spatial_join(
38    indexed_side: &RTree,
39    probe_entries: &[(u64, BoundingBox)],
40    get_indexed_geom: &dyn Fn(u64) -> Option<Geometry>,
41    get_probe_geom: &dyn Fn(u64) -> Option<Geometry>,
42    predicate: SpatialJoinPredicate,
43) -> SpatialJoinResult {
44    let mut pairs = Vec::new();
45    let mut probes = 0;
46    let mut exact_evals = 0;
47
48    for &(probe_id, ref probe_bbox) in probe_entries {
49        // R-tree range search: find indexed entries whose bbox intersects probe bbox.
50        let candidates = indexed_side.search(probe_bbox);
51        probes += 1;
52
53        for candidate in &candidates {
54            // Exact predicate evaluation.
55            let Some(indexed_geom) = get_indexed_geom(candidate.id) else {
56                continue;
57            };
58            let Some(probe_geom) = get_probe_geom(probe_id) else {
59                continue;
60            };
61            exact_evals += 1;
62
63            let matches = match predicate {
64                SpatialJoinPredicate::Intersects => {
65                    predicates::st_intersects(&probe_geom, &indexed_geom)
66                }
67                SpatialJoinPredicate::Contains => {
68                    predicates::st_contains(&probe_geom, &indexed_geom)
69                }
70                SpatialJoinPredicate::Within => predicates::st_within(&probe_geom, &indexed_geom),
71                SpatialJoinPredicate::DWithin(dist) => {
72                    predicates::st_dwithin(&probe_geom, &indexed_geom, dist)
73                }
74            };
75
76            if matches {
77                pairs.push((probe_id, candidate.id));
78            }
79        }
80    }
81
82    SpatialJoinResult {
83        pairs,
84        probes,
85        exact_evals,
86    }
87}
88
89/// Build an R-tree from a list of (entry_id, geometry) pairs for join.
90pub fn build_join_index(entries: &[(u64, Geometry)]) -> RTree {
91    let rtree_entries: Vec<RTreeEntry> = entries
92        .iter()
93        .map(|(id, geom)| RTreeEntry {
94            id: *id,
95            bbox: geometry_bbox(geom),
96        })
97        .collect();
98    RTree::bulk_load(rtree_entries)
99}
100
/// Which spatial predicate to use for the join.
///
/// In `spatial_join` every variant is evaluated with the probe geometry as
/// the first argument and the indexed geometry as the second.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SpatialJoinPredicate {
    /// Probe geometry intersects the indexed geometry.
    Intersects,
    /// Probe geometry contains the indexed geometry.
    Contains,
    /// Probe geometry is within the indexed geometry.
    Within,
    /// Probe and indexed geometries are within the given distance of each
    /// other. The value is passed straight to the distance predicate;
    /// presumably meters — confirm against that predicate's contract.
    DWithin(f64),
}
109
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Closed axis-aligned square ring with its minimum corner at
    /// `(base, base)` and the given side length.
    fn square(base: f64, side: f64) -> Geometry {
        Geometry::polygon(vec![vec![
            [base, base],
            [base + side, base],
            [base + side, base + side],
            [base, base + side],
            [base, base],
        ]])
    }

    /// Pair every entry ID with the (unexpanded) bounding box of its geometry.
    fn bboxes(entries: &[(u64, Geometry)]) -> Vec<(u64, BoundingBox)> {
        entries
            .iter()
            .map(|(id, geom)| (*id, geometry_bbox(geom)))
            .collect()
    }

    #[test]
    fn join_overlapping_squares() {
        // Left side: 5 squares at positions (0,0), (5,5), (10,10), (15,15), (20,20).
        let left: Vec<(u64, Geometry)> = (0..5u64)
            .map(|i| (i, square((i * 5) as f64, 4.0)))
            .collect();

        // Right side: 5 squares offset by 2.
        let right: Vec<(u64, Geometry)> = (0..5u64)
            .map(|i| (100 + i, square((i * 5 + 2) as f64, 4.0)))
            .collect();

        // Build index on left side, probe with right side.
        let index = build_join_index(&left);
        let probe_entries = bboxes(&right);

        let left_map: HashMap<u64, Geometry> = left.into_iter().collect();
        let right_map: HashMap<u64, Geometry> = right.into_iter().collect();

        let result = spatial_join(
            &index,
            &probe_entries,
            &|id| left_map.get(&id).cloned(),
            &|id| right_map.get(&id).cloned(),
            SpatialJoinPredicate::Intersects,
        );

        // Adjacent overlapping squares should produce matches.
        assert!(!result.pairs.is_empty(), "expected some join matches");
        // One probe per right entry.
        assert_eq!(result.probes, 5);
        // Pairs are (probe_id, indexed_id): probe IDs are 100.., indexed IDs are 0..5.
        assert!(
            result
                .pairs
                .iter()
                .all(|&(probe_id, indexed_id)| probe_id >= 100 && indexed_id < 5),
            "pairs must be (probe_id, indexed_id), got {:?}",
            result.pairs
        );
        // Every reported pair went through an exact evaluation.
        assert!(result.exact_evals >= result.pairs.len());
    }

    #[test]
    fn join_no_overlap() {
        let left = vec![(1u64, square(0.0, 1.0))];
        let right = vec![(100u64, square(50.0, 1.0))];

        let index = build_join_index(&left);
        let probes = bboxes(&right);

        let left_map: HashMap<u64, Geometry> = left.into_iter().collect();
        let right_map: HashMap<u64, Geometry> = right.into_iter().collect();

        let result = spatial_join(
            &index,
            &probes,
            &|id| left_map.get(&id).cloned(),
            &|id| right_map.get(&id).cloned(),
            SpatialJoinPredicate::Intersects,
        );

        assert!(result.pairs.is_empty());
        assert_eq!(result.probes, 1);
    }

    #[test]
    fn join_with_dwithin() {
        let left = vec![(1u64, Geometry::point(0.0, 0.0))];
        let right = vec![(100u64, Geometry::point(0.001, 0.0))]; // ~111m away

        let index = build_join_index(&left);
        let probes: Vec<(u64, BoundingBox)> = right
            .iter()
            .map(|(id, g)| {
                // Expand bbox by distance for R-tree search.
                (*id, geometry_bbox(g).expand_meters(500.0))
            })
            .collect();

        let left_map: HashMap<u64, Geometry> = left.into_iter().collect();
        let right_map: HashMap<u64, Geometry> = right.into_iter().collect();

        let result = spatial_join(
            &index,
            &probes,
            &|id| left_map.get(&id).cloned(),
            &|id| right_map.get(&id).cloned(),
            SpatialJoinPredicate::DWithin(500.0), // 500m
        );

        // The single probe matches the single indexed point.
        assert_eq!(result.pairs, vec![(100, 1)]);
        assert_eq!(result.probes, 1);
    }
}