use nodedb_types::geometry::Geometry;
use nodedb_types::{BoundingBox, geometry_bbox};
use crate::predicates;
use crate::rtree::{RTree, RTreeEntry};
/// Outcome of a [`spatial_join`] run: the matching pairs plus simple work counters.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct SpatialJoinResult {
    /// Matching `(probe_id, indexed_id)` pairs, in the order they were found.
    pub pairs: Vec<(u64, u64)>,
    /// Number of probe entries that were looked up in the index.
    pub probes: usize,
    /// Number of exact predicate evaluations performed on full geometries.
    pub exact_evals: usize,
}
/// Joins `probe_entries` against the geometries indexed in `indexed_side`.
///
/// This is a two-phase join:
/// 1. **Filter:** each probe's bounding box is used to search the R-tree for candidates.
/// 2. **Refine:** `predicate` is evaluated exactly on the full geometries, which are
///    resolved through `get_probe_geom` / `get_indexed_geom`.
///
/// The probe geometry is always the *first* predicate argument. So `Contains` keeps
/// pairs where the probe contains the indexed geometry, and `Within` keeps pairs where
/// the probe lies within the indexed geometry. Probes or candidates whose geometry
/// lookup returns `None` are skipped without error.
///
/// # DWithin
///
/// The filter phase searches with each probe bounding box exactly as given. For
/// [`SpatialJoinPredicate::DWithin`] the caller must pre-expand the probe boxes by the
/// search distance (e.g. with `BoundingBox::expand_meters`). Otherwise, pairs whose
/// boxes do not overlap are never evaluated, even if they are within range.
///
/// # Returns
///
/// A [`SpatialJoinResult`] containing:
/// - the matching `(probe_id, indexed_id)` pairs;
/// - the number of probes issued (one per entry in `probe_entries`);
/// - the number of exact predicate evaluations.
pub fn spatial_join(
    indexed_side: &RTree,
    probe_entries: &[(u64, BoundingBox)],
    get_indexed_geom: &dyn Fn(u64) -> Option<Geometry>,
    get_probe_geom: &dyn Fn(u64) -> Option<Geometry>,
    predicate: SpatialJoinPredicate,
) -> SpatialJoinResult {
    let mut pairs = Vec::new();
    let mut probes = 0;
    let mut exact_evals = 0;

    for &(probe_id, ref probe_bbox) in probe_entries {
        let candidates = indexed_side.search(probe_bbox);
        probes += 1;

        // No index hits: avoid the (possibly expensive) probe geometry lookup entirely.
        if candidates.is_empty() {
            continue;
        }
        // The probe geometry does not depend on the candidate, so resolve it once per
        // probe instead of once (plus a clone) per candidate.
        let Some(probe_geom) = get_probe_geom(probe_id) else {
            continue;
        };

        for candidate in &candidates {
            let Some(indexed_geom) = get_indexed_geom(candidate.id) else {
                continue;
            };
            exact_evals += 1;
            let matches = match predicate {
                SpatialJoinPredicate::Intersects => {
                    predicates::st_intersects(&probe_geom, &indexed_geom)
                }
                SpatialJoinPredicate::Contains => {
                    predicates::st_contains(&probe_geom, &indexed_geom)
                }
                SpatialJoinPredicate::Within => {
                    predicates::st_within(&probe_geom, &indexed_geom)
                }
                SpatialJoinPredicate::DWithin(dist) => {
                    predicates::st_dwithin(&probe_geom, &indexed_geom, dist)
                }
            };
            if matches {
                pairs.push((probe_id, candidate.id));
            }
        }
    }

    SpatialJoinResult {
        pairs,
        probes,
        exact_evals,
    }
}
/// Bulk-loads an [`RTree`] from `(id, geometry)` pairs.
///
/// Each geometry is reduced to its bounding box via `geometry_bbox`. The tree stores
/// only `(id, bbox)`; the full geometries are not retained.
pub fn build_join_index(entries: &[(u64, Geometry)]) -> RTree {
    let mut leaves: Vec<RTreeEntry> = Vec::with_capacity(entries.len());
    for (entry_id, shape) in entries {
        leaves.push(RTreeEntry {
            id: *entry_id,
            bbox: geometry_bbox(shape),
        });
    }
    RTree::bulk_load(leaves)
}
/// Exact predicate applied in the refine phase of [`spatial_join`].
///
/// The probe geometry is always the first argument and the indexed geometry the second.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SpatialJoinPredicate {
    /// The probe and indexed geometries intersect (`predicates::st_intersects`).
    Intersects,
    /// The probe geometry contains the indexed geometry (`predicates::st_contains`).
    Contains,
    /// The probe geometry lies within the indexed geometry (`predicates::st_within`).
    Within,
    /// The geometries are within the given distance of each other
    /// (`predicates::st_dwithin`).
    ///
    /// NOTE(review): the unit is whatever `st_dwithin` expects. The tests pair it with
    /// `BoundingBox::expand_meters`, which suggests meters — confirm.
    DWithin(f64),
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a closed, axis-aligned square polygon with lower-left corner `(min, min)`.
    fn square(min: f64, side: f64) -> Geometry {
        Geometry::polygon(vec![vec![
            [min, min],
            [min + side, min],
            [min + side, min + side],
            [min, min + side],
            [min, min],
        ]])
    }

    /// Indexes `left`, then probes it with `right`.
    ///
    /// Probe boxes are computed by `probe_bbox`. The returned pairs are sorted so they
    /// can be compared independently of R-tree traversal order.
    fn run_join(
        left: Vec<(u64, Geometry)>,
        right: Vec<(u64, Geometry)>,
        probe_bbox: impl Fn(&Geometry) -> BoundingBox,
        predicate: SpatialJoinPredicate,
    ) -> SpatialJoinResult {
        let index = build_join_index(&left);
        let probe_entries: Vec<(u64, BoundingBox)> =
            right.iter().map(|(id, g)| (*id, probe_bbox(g))).collect();
        let left_map: HashMap<u64, Geometry> = left.into_iter().collect();
        let right_map: HashMap<u64, Geometry> = right.into_iter().collect();
        let mut result = spatial_join(
            &index,
            &probe_entries,
            &|id| left_map.get(&id).cloned(),
            &|id| right_map.get(&id).cloned(),
            predicate,
        );
        result.pairs.sort_unstable();
        result
    }

    #[test]
    fn join_overlapping_squares() {
        // Left square i spans [5i, 5i+4]; right square 100+j spans [5j+2, 5j+6].
        // Each right square therefore overlaps left squares j and j+1 (when j+1 exists).
        let left: Vec<(u64, Geometry)> =
            (0..5u64).map(|i| (i, square((i * 5) as f64, 4.0))).collect();
        let right: Vec<(u64, Geometry)> = (0..5u64)
            .map(|i| (100 + i, square((i * 5 + 2) as f64, 4.0)))
            .collect();

        let result = run_join(left, right, geometry_bbox, SpatialJoinPredicate::Intersects);

        let expected: Vec<(u64, u64)> = vec![
            (100, 0),
            (100, 1),
            (101, 1),
            (101, 2),
            (102, 2),
            (102, 3),
            (103, 3),
            (103, 4),
            (104, 4),
        ];
        assert_eq!(result.pairs, expected);
        assert_eq!(result.probes, 5);
    }

    #[test]
    fn join_no_overlap() {
        let result = run_join(
            vec![(1, square(0.0, 1.0))],
            vec![(100, square(50.0, 1.0))],
            geometry_bbox,
            SpatialJoinPredicate::Intersects,
        );
        assert!(result.pairs.is_empty());
        assert_eq!(result.probes, 1);
        // Disjoint bounding boxes must be pruned by the index before any exact test.
        assert_eq!(result.exact_evals, 0);
    }

    #[test]
    fn join_with_dwithin() {
        // 0.001 degrees of longitude at the equator is roughly 111 m, well inside 500 m.
        // The probe box is pre-expanded because spatial_join filters with it as given.
        let result = run_join(
            vec![(1, Geometry::point(0.0, 0.0))],
            vec![(100, Geometry::point(0.001, 0.0))],
            |g| geometry_bbox(g).expand_meters(500.0),
            SpatialJoinPredicate::DWithin(500.0),
        );
        assert_eq!(result.pairs, vec![(100, 1)]);
    }

    #[test]
    fn join_skips_missing_geometries() {
        let left = vec![(1u64, square(0.0, 2.0))];
        let index = build_join_index(&left);
        // The probe box overlaps the indexed square, so the index yields a candidate.
        let probe_entries = vec![(100u64, geometry_bbox(&square(1.0, 2.0)))];
        let left_map: HashMap<u64, Geometry> = left.into_iter().collect();

        let result = spatial_join(
            &index,
            &probe_entries,
            &|id| left_map.get(&id).cloned(),
            &|_| None,
            SpatialJoinPredicate::Intersects,
        );
        assert!(result.pairs.is_empty());
        assert_eq!(result.probes, 1);
        assert_eq!(result.exact_evals, 0);
    }
}