Skip to main content

nodedb_spatial/
spatial_join.rs

// SPDX-License-Identifier: Apache-2.0

//! Spatial join strategy: R-tree probe join.
//!
//! When two collections are joined on a spatial predicate
//! (`ST_Intersects(a.geom, b.geom)`), the naive approach is a nested loop,
//! O(N*M). With an R-tree index on one side, we get O(N * log M):
//!
//! 1. Build an R-tree on the smaller collection (or use an existing index).
//! 2. For each geometry in the larger collection, run an R-tree range search.
//! 3. For each candidate pair, apply the exact predicate.
//!
//! This module provides the join logic; the planner decides which side
//! to index based on collection cardinality.
15
16use nodedb_types::geometry::Geometry;
17use nodedb_types::{BoundingBox, geometry_bbox};
18
19use crate::predicates;
20use crate::rtree::{RTree, RTreeEntry};
21
/// Result of a spatial join between two collections.
///
/// Produced by [`spatial_join`]. Besides the matches themselves it carries two
/// counters that show how much work the R-tree bbox filter saved.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SpatialJoinResult {
    /// Matched pairs as `(probe_entry_id, indexed_entry_id)`: the first id comes
    /// from the probe entries, the second from the R-tree (indexed) side.
    pub pairs: Vec<(u64, u64)>,
    /// Number of R-tree probes performed (one per probe entry).
    pub probes: usize,
    /// Number of exact predicate evaluations (after the R-tree bbox filter).
    pub exact_evals: usize,
}
31
32/// Execute a spatial join using R-tree probe.
33///
34/// `indexed_side`: R-tree built on one collection.
35/// `probe_side`: entries from the other collection to probe against.
36/// `get_geometry`: callback to retrieve the full geometry for an entry ID
37///   (needed for exact predicate evaluation after R-tree bbox filter).
38/// `predicate`: which spatial predicate to apply (intersects, contains, etc.).
39pub fn spatial_join(
40    indexed_side: &RTree,
41    probe_entries: &[(u64, BoundingBox)],
42    get_indexed_geom: &dyn Fn(u64) -> Option<Geometry>,
43    get_probe_geom: &dyn Fn(u64) -> Option<Geometry>,
44    predicate: SpatialJoinPredicate,
45) -> SpatialJoinResult {
46    let mut pairs = Vec::new();
47    let mut probes = 0;
48    let mut exact_evals = 0;
49
50    for &(probe_id, ref probe_bbox) in probe_entries {
51        // R-tree range search: find indexed entries whose bbox intersects probe bbox.
52        let candidates = indexed_side.search(probe_bbox);
53        probes += 1;
54
55        for candidate in &candidates {
56            // Exact predicate evaluation.
57            let Some(indexed_geom) = get_indexed_geom(candidate.id) else {
58                continue;
59            };
60            let Some(probe_geom) = get_probe_geom(probe_id) else {
61                continue;
62            };
63            exact_evals += 1;
64
65            let matches = match predicate {
66                SpatialJoinPredicate::Intersects => {
67                    predicates::st_intersects(&probe_geom, &indexed_geom)
68                }
69                SpatialJoinPredicate::Contains => {
70                    predicates::st_contains(&probe_geom, &indexed_geom)
71                }
72                SpatialJoinPredicate::Within => predicates::st_within(&probe_geom, &indexed_geom),
73                SpatialJoinPredicate::DWithin(dist) => {
74                    predicates::st_dwithin(&probe_geom, &indexed_geom, dist)
75                }
76            };
77
78            if matches {
79                pairs.push((probe_id, candidate.id));
80            }
81        }
82    }
83
84    SpatialJoinResult {
85        pairs,
86        probes,
87        exact_evals,
88    }
89}
90
91/// Build an R-tree from a list of (entry_id, geometry) pairs for join.
92pub fn build_join_index(entries: &[(u64, Geometry)]) -> RTree {
93    let rtree_entries: Vec<RTreeEntry> = entries
94        .iter()
95        .map(|(id, geom)| RTreeEntry {
96            id: *id,
97            bbox: geometry_bbox(geom),
98        })
99        .collect();
100    RTree::bulk_load(rtree_entries)
101}
102
/// Which spatial predicate to use for the join.
///
/// [`spatial_join`] evaluates every variant as `predicate(probe, indexed)`:
/// the probe-side geometry is always the first argument. This matters for the
/// asymmetric predicates `Contains` and `Within`.
#[derive(Debug, Clone, Copy, PartialEq)]
#[non_exhaustive]
pub enum SpatialJoinPredicate {
    /// Exact check via `predicates::st_intersects(probe, indexed)`.
    Intersects,
    /// Exact check via `predicates::st_contains(probe, indexed)`:
    /// the probe geometry contains the indexed one.
    Contains,
    /// Exact check via `predicates::st_within(probe, indexed)`:
    /// the probe geometry lies within the indexed one.
    Within,
    /// Exact check via `predicates::st_dwithin(probe, indexed, distance)`.
    /// The distance is used as meters in this module.
    DWithin(f64),
}
112
#[cfg(test)]
mod tests {
    use std::collections::HashMap;

    use super::*;

    /// Closed, axis-aligned square ring with lower-left corner `(x, y)`.
    fn square(x: f64, y: f64, side: f64) -> Geometry {
        Geometry::polygon(vec![vec![
            [x, y],
            [x + side, y],
            [x + side, y + side],
            [x, y + side],
            [x, y],
        ]])
    }

    /// Index `left`, probe with `right` (each probe bbox passed through
    /// `probe_bbox`), and run the join with `predicate`.
    fn join_with_bbox(
        left: Vec<(u64, Geometry)>,
        right: Vec<(u64, Geometry)>,
        predicate: SpatialJoinPredicate,
        probe_bbox: impl Fn(BoundingBox) -> BoundingBox,
    ) -> SpatialJoinResult {
        let index = build_join_index(&left);
        let probe_entries: Vec<(u64, BoundingBox)> = right
            .iter()
            .map(|(id, geom)| (*id, probe_bbox(geometry_bbox(geom))))
            .collect();

        let left_map: HashMap<u64, Geometry> = left.into_iter().collect();
        let right_map: HashMap<u64, Geometry> = right.into_iter().collect();

        spatial_join(
            &index,
            &probe_entries,
            &|id| left_map.get(&id).cloned(),
            &|id| right_map.get(&id).cloned(),
            predicate,
        )
    }

    /// Same as `join_with_bbox`, using each probe geometry's own bounding box.
    fn join(
        left: Vec<(u64, Geometry)>,
        right: Vec<(u64, Geometry)>,
        predicate: SpatialJoinPredicate,
    ) -> SpatialJoinResult {
        join_with_bbox(left, right, predicate, |bbox| bbox)
    }

    /// Result pairs in a deterministic order (R-tree search order is unspecified).
    fn sorted_pairs(result: &SpatialJoinResult) -> Vec<(u64, u64)> {
        let mut pairs = result.pairs.clone();
        pairs.sort_unstable();
        pairs
    }

    #[test]
    fn join_overlapping_squares() {
        // Left: 4x4 squares at (0,0), (5,5), (10,10), (15,15), (20,20).
        let left: Vec<(u64, Geometry)> = (0..5u64)
            .map(|i| {
                let base = (i * 5) as f64;
                (i, square(base, base, 4.0))
            })
            .collect();
        // Right: the same squares shifted by +2 on both axes, so each overlaps
        // its left counterpart and the next left square.
        let right: Vec<(u64, Geometry)> = (0..5u64)
            .map(|i| {
                let base = (i * 5 + 2) as f64;
                (100 + i, square(base, base, 4.0))
            })
            .collect();

        let result = join(left, right, SpatialJoinPredicate::Intersects);

        // Pairs are (probe_id, indexed_id) = (right_id, left_id).
        let expected = vec![
            (100, 0),
            (100, 1),
            (101, 1),
            (101, 2),
            (102, 2),
            (102, 3),
            (103, 3),
            (103, 4),
            (104, 4),
        ];
        assert_eq!(sorted_pairs(&result), expected);
        assert_eq!(result.probes, 5); // One probe per right entry.
    }

    #[test]
    fn join_no_overlap() {
        let left = vec![(1u64, square(0.0, 0.0, 1.0))];
        let right = vec![(100u64, square(50.0, 50.0, 1.0))];

        let result = join(left, right, SpatialJoinPredicate::Intersects);

        assert!(result.pairs.is_empty());
        assert_eq!(result.probes, 1);
        // The bbox filter rejects everything, so no exact predicate runs.
        assert_eq!(result.exact_evals, 0);
    }

    #[test]
    fn join_with_dwithin() {
        let left = vec![(1u64, Geometry::point(0.0, 0.0))];
        let right = vec![(100u64, Geometry::point(0.001, 0.0))]; // ~111m away

        // Expand bbox by distance for R-tree search.
        let result = join_with_bbox(
            left,
            right,
            SpatialJoinPredicate::DWithin(500.0), // 500m
            |bbox| bbox.expand_meters(500.0),
        );

        assert_eq!(result.pairs, vec![(100, 1)]);
    }

    #[test]
    fn join_with_dwithin_out_of_range() {
        let left = vec![(1u64, Geometry::point(0.0, 0.0))];
        let right = vec![(100u64, Geometry::point(0.01, 0.0))]; // ~1.1km away

        let result = join_with_bbox(
            left,
            right,
            SpatialJoinPredicate::DWithin(500.0), // 500m
            |bbox| bbox.expand_meters(500.0),
        );

        assert!(result.pairs.is_empty());
    }

    #[test]
    fn join_contains_uses_probe_as_container() {
        // Indexed side: one point inside the probe square, one far outside.
        let left = vec![
            (1u64, Geometry::point(5.0, 5.0)),
            (2u64, Geometry::point(20.0, 20.0)),
        ];
        let right = vec![(100u64, square(0.0, 0.0, 10.0))];

        let result = join(left, right, SpatialJoinPredicate::Contains);

        assert_eq!(sorted_pairs(&result), vec![(100, 1)]);
    }

    #[test]
    fn join_within_uses_probe_as_contained() {
        let left = vec![(1u64, square(0.0, 0.0, 10.0))];
        let right = vec![(100u64, Geometry::point(5.0, 5.0))];

        let result = join(left, right, SpatialJoinPredicate::Within);

        assert_eq!(result.pairs, vec![(100, 1)]);
    }

    #[test]
    fn join_skips_entries_without_geometry() {
        let left = vec![(1u64, square(0.0, 0.0, 4.0))];
        let index = build_join_index(&left);
        let probes = vec![(100u64, geometry_bbox(&square(1.0, 1.0, 2.0)))];
        let left_map: HashMap<u64, Geometry> = left.into_iter().collect();

        // The probe geometry cannot be resolved, so no exact evaluation happens.
        let result = spatial_join(
            &index,
            &probes,
            &|id| left_map.get(&id).cloned(),
            &|_| None,
            SpatialJoinPredicate::Intersects,
        );

        assert!(result.pairs.is_empty());
        assert_eq!(result.probes, 1);
        assert_eq!(result.exact_evals, 0);
    }
}
244        );
245
246        assert_eq!(result.pairs.len(), 1);
247    }
248}