use crate::error::Result;
use oxigdal_core::vector::Point;
use rstar::{AABB, PointDistance, RTree, RTreeObject};
/// Spatial relationship a pair of points must satisfy to be reported by a join.
///
/// In this file only `Intersects` and `WithinDistance` are evaluated for points; the
/// remaining variants never produce a match in `spatial_join_points`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SpatialJoinPredicate {
/// The two points coincide.
Intersects,
/// Not evaluated by the point join in this file; yields no matches.
Contains,
/// Not evaluated by the point join in this file; yields no matches.
Within,
/// Not evaluated by the point join in this file; yields no matches.
Touches,
/// The Euclidean distance between the two points is at most `SpatialJoinOptions::distance`.
WithinDistance,
}
/// Configuration for `spatial_join_points`.
#[derive(Debug, Clone)]
pub struct SpatialJoinOptions {
/// Relationship that a left/right pair must satisfy to be reported.
pub predicate: SpatialJoinPredicate,
/// Distance threshold compared against the Euclidean point distance; only read when
/// `predicate` is `SpatialJoinPredicate::WithinDistance`.
pub distance: f64,
/// When `true` the right-hand points are loaded into an R-tree and queried per left
/// point; when `false` every left/right pair is compared in a nested loop.
pub use_index: bool,
}
impl Default for SpatialJoinOptions {
fn default() -> Self {
Self {
predicate: SpatialJoinPredicate::Intersects,
distance: 0.0,
use_index: true,
}
}
}
/// Outcome of a spatial join between two point slices.
#[derive(Debug, Clone)]
pub struct SpatialJoinResult {
/// Matching pairs as `(left index, right index)` into the two input slices.
pub matches: Vec<(usize, usize)>,
/// Number of matching pairs; `spatial_join_points` sets this to `matches.len()`.
pub num_matches: usize,
}
/// A point paired with its position in the slice it was taken from, so that R-tree hits
/// can be reported as indices into that slice.
#[derive(Debug, Clone)]
struct IndexedPoint {
// Copy of the source point stored in the tree.
point: Point,
// Position of `point` in the original input slice.
index: usize,
}
impl RTreeObject for IndexedPoint {
    type Envelope = AABB<[f64; 2]>;

    /// Returns the degenerate bounding box located at the point's `(x, y)` position.
    fn envelope(&self) -> Self::Envelope {
        let position = [self.point.coord.x, self.point.coord.y];
        AABB::from_point(position)
    }
}
impl PointDistance for IndexedPoint {
    /// Returns the squared Euclidean distance from the stored point to `point`.
    fn distance_2(&self, point: &[f64; 2]) -> f64 {
        let [query_x, query_y] = *point;
        let (dx, dy) = (
            self.point.coord.x - query_x,
            self.point.coord.y - query_y,
        );
        dx * dx + dy * dy
    }
}
/// Joins two point sets and reports every index pair that satisfies the configured predicate.
///
/// Each match is `(left_idx, right_idx)`, indexing into `left_points` and `right_points`
/// respectively. When `options.use_index` is set, the right-hand points are bulk-loaded into
/// an R-tree and each left point queries a small window around itself; otherwise every pair
/// is compared in a nested loop. Both strategies decide matches with `matches_predicate`, so
/// the flag affects speed and the order in which matches are listed, not which pairs are
/// reported.
///
/// `Contains`, `Within` and `Touches` are not evaluated for points and yield no matches.
/// If either input slice is empty the result is empty.
///
/// # Errors
///
/// No code path in this function currently produces an error; the `Result` wrapper is part
/// of the existing signature.
pub fn spatial_join_points(
    left_points: &[Point],
    right_points: &[Point],
    options: &SpatialJoinOptions,
) -> Result<SpatialJoinResult> {
    // Per-axis tolerance that `matches_predicate` applies to `Intersects`. The index window
    // must be at least this wide so the indexed path sees every candidate the brute-force
    // path would accept; keep the two values in sync.
    const INTERSECTS_TOLERANCE: f64 = 1e-10;

    if left_points.is_empty() || right_points.is_empty() {
        return Ok(SpatialJoinResult {
            matches: Vec::new(),
            num_matches: 0,
        });
    }

    let matches = if options.use_index {
        // Half-width of the square candidate window around each left point, or `None` for
        // predicates that are not evaluated for points.
        let search_radius = match options.predicate {
            SpatialJoinPredicate::WithinDistance => Some(options.distance),
            SpatialJoinPredicate::Intersects => Some(INTERSECTS_TOLERANCE),
            SpatialJoinPredicate::Contains
            | SpatialJoinPredicate::Within
            | SpatialJoinPredicate::Touches => None,
        };

        if let Some(radius) = search_radius {
            let indexed_points: Vec<IndexedPoint> = right_points
                .iter()
                .enumerate()
                .map(|(index, point)| IndexedPoint {
                    point: point.clone(),
                    index,
                })
                .collect();
            let rtree = RTree::bulk_load(indexed_points);

            let mut all_matches = Vec::new();
            for (left_idx, left_point) in left_points.iter().enumerate() {
                let envelope = AABB::from_corners(
                    [left_point.coord.x - radius, left_point.coord.y - radius],
                    [left_point.coord.x + radius, left_point.coord.y + radius],
                );
                // The window only narrows the candidates; the exact decision is delegated
                // to the same predicate check the brute-force path uses.
                all_matches.extend(
                    rtree
                        .locate_in_envelope(&envelope)
                        .filter(|indexed| matches_predicate(left_point, &indexed.point, options))
                        .map(|indexed| (left_idx, indexed.index)),
                );
            }
            all_matches
        } else {
            // Nothing can match, so skip building the index altogether.
            Vec::new()
        }
    } else {
        let mut all_matches = Vec::new();
        for (left_idx, left_point) in left_points.iter().enumerate() {
            all_matches.extend(
                right_points
                    .iter()
                    .enumerate()
                    .filter(|(_, right_point)| matches_predicate(left_point, right_point, options))
                    .map(|(right_idx, _)| (left_idx, right_idx)),
            );
        }
        all_matches
    };

    Ok(SpatialJoinResult {
        num_matches: matches.len(),
        matches,
    })
}
/// Evaluates the configured predicate for a single pair of points.
///
/// * `Intersects` is true when the x and y coordinates each differ by less than `1e-10`.
/// * `WithinDistance` is true when the Euclidean distance is at most `options.distance`.
/// * Every other predicate is always false.
fn matches_predicate(left: &Point, right: &Point, options: &SpatialJoinOptions) -> bool {
    const COORD_TOLERANCE: f64 = 1e-10;

    match options.predicate {
        SpatialJoinPredicate::WithinDistance => point_distance(left, right) <= options.distance,
        SpatialJoinPredicate::Intersects => {
            let x_gap = (left.coord.x - right.coord.x).abs();
            let y_gap = (left.coord.y - right.coord.y).abs();
            x_gap < COORD_TOLERANCE && y_gap < COORD_TOLERANCE
        }
        SpatialJoinPredicate::Contains
        | SpatialJoinPredicate::Within
        | SpatialJoinPredicate::Touches => false,
    }
}
/// Returns the Euclidean distance between `p1` and `p2` in the XY plane.
fn point_distance(p1: &Point, p2: &Point) -> f64 {
    let (delta_x, delta_y) = (p1.coord.x - p2.coord.x, p1.coord.y - p2.coord.y);
    let squared = delta_x * delta_x + delta_y * delta_y;
    squared.sqrt()
}
/// Finds the point in `points` that is closest to `query`.
///
/// Returns the index of that point within `points` together with its Euclidean distance to
/// `query`, or `None` when `points` is empty. An R-tree over `points` is built on every call.
pub fn nearest_neighbor(query: &Point, points: &[Point]) -> Option<(usize, f64)> {
    if points.is_empty() {
        return None;
    }

    let mut entries: Vec<IndexedPoint> = Vec::with_capacity(points.len());
    for (index, point) in points.iter().enumerate() {
        entries.push(IndexedPoint {
            point: point.clone(),
            index,
        });
    }

    let tree = RTree::bulk_load(entries);
    let target = [query.coord.x, query.coord.y];
    let closest = tree.nearest_neighbor(&target);
    closest.map(|hit| (hit.index, point_distance(query, &hit.point)))
}
/// Collects up to `k` points from `points`, taken in the order the R-tree's nearest-neighbour
/// iterator yields them for `query`.
///
/// Each entry is the point's index within `points` paired with its Euclidean distance to
/// `query`. The result is empty when `points` is empty and holds fewer than `k` entries when
/// `points` has fewer than `k` elements. An R-tree over `points` is built on every call.
pub fn k_nearest_neighbors(query: &Point, points: &[Point], k: usize) -> Vec<(usize, f64)> {
    if points.is_empty() {
        return Vec::new();
    }

    let mut entries: Vec<IndexedPoint> = Vec::with_capacity(points.len());
    for (index, point) in points.iter().enumerate() {
        entries.push(IndexedPoint {
            point: point.clone(),
            index,
        });
    }

    let tree = RTree::bulk_load(entries);
    let target = [query.coord.x, query.coord.y];

    let mut neighbors = Vec::new();
    for hit in tree.nearest_neighbor_iter(&target).take(k) {
        neighbors.push((hit.index, point_distance(query, &hit.point)));
    }
    neighbors
}
/// Returns the indices of all entries in `points` whose Euclidean distance to `query` is at
/// most `distance`.
///
/// This is an indexed `WithinDistance` join with `query` as the only left-hand point. Should
/// the underlying join report an error, an empty list is returned instead.
pub fn range_query(query: &Point, points: &[Point], distance: f64) -> Vec<usize> {
    let options = SpatialJoinOptions {
        predicate: SpatialJoinPredicate::WithinDistance,
        distance,
        use_index: true,
    };
    let single_query = std::slice::from_ref(query);

    match spatial_join_points(single_query, points, &options) {
        Ok(joined) => {
            let mut hits = Vec::with_capacity(joined.matches.len());
            for (_, right_idx) in joined.matches {
                hits.push(right_idx);
            }
            hits
        }
        Err(_) => Vec::new(),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Only the (0, 0) / (0.1, 0.1) pair lies within 0.5 units, whichever strategy is used.
    #[test]
    fn test_spatial_join_within_distance() {
        let left = vec![Point::new(0.0, 0.0), Point::new(10.0, 10.0)];
        let right = vec![Point::new(0.1, 0.1), Point::new(5.0, 5.0)];
        for use_index in [true, false] {
            let options = SpatialJoinOptions {
                predicate: SpatialJoinPredicate::WithinDistance,
                distance: 0.5,
                use_index,
            };
            let join_result = spatial_join_points(&left, &right, &options).expect("Join failed");
            assert_eq!(join_result.matches, vec![(0, 0)]);
            assert_eq!(join_result.num_matches, 1);
        }
    }

    /// Identical coordinates intersect; distinct ones do not, for both strategies.
    #[test]
    fn test_spatial_join_intersects_identical_points() {
        let left = vec![Point::new(1.0, 2.0), Point::new(3.0, 4.0)];
        let right = vec![Point::new(3.0, 4.0), Point::new(7.0, 8.0)];
        for use_index in [true, false] {
            let options = SpatialJoinOptions {
                predicate: SpatialJoinPredicate::Intersects,
                distance: 0.0,
                use_index,
            };
            let join_result = spatial_join_points(&left, &right, &options).expect("Join failed");
            assert_eq!(join_result.matches, vec![(1, 0)]);
            assert_eq!(join_result.num_matches, 1);
        }
    }

    /// An empty side short-circuits to an empty result.
    #[test]
    fn test_spatial_join_empty_input() {
        let left: Vec<Point> = Vec::new();
        let right = vec![Point::new(0.0, 0.0)];
        let join_result = spatial_join_points(&left, &right, &SpatialJoinOptions::default())
            .expect("Join failed");
        assert!(join_result.matches.is_empty());
        assert_eq!(join_result.num_matches, 0);
    }

    #[test]
    fn test_nearest_neighbor() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(5.0, 5.0),
            Point::new(10.0, 10.0),
        ];
        let query = Point::new(0.1, 0.1);
        let (idx, dist) = nearest_neighbor(&query, &points).expect("Nearest neighbor failed");
        assert_eq!(idx, 0);
        // Distance from (0.1, 0.1) to the origin is sqrt(0.02).
        assert!((dist - 0.02_f64.sqrt()).abs() < 1e-9);
    }

    #[test]
    fn test_nearest_neighbor_empty() {
        let query = Point::new(0.0, 0.0);
        assert!(nearest_neighbor(&query, &[]).is_none());
        assert!(k_nearest_neighbors(&query, &[], 3).is_empty());
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(1.0, 1.0),
            Point::new(2.0, 2.0),
            Point::new(10.0, 10.0),
        ];
        let query = Point::new(0.0, 0.0);
        let result = k_nearest_neighbors(&query, &points, 2);
        assert_eq!(result.len(), 2);
        // Closest first: the query point itself, then (1, 1) at distance sqrt(2).
        assert_eq!(result[0].0, 0);
        assert!(result[0].1.abs() < 1e-12);
        assert_eq!(result[1].0, 1);
        assert!((result[1].1 - 2.0_f64.sqrt()).abs() < 1e-9);
    }

    #[test]
    fn test_range_query() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(0.5, 0.5),
            Point::new(10.0, 10.0),
        ];
        let query = Point::new(0.0, 0.0);
        let mut result = range_query(&query, &points, 1.0);
        // The index does not promise an iteration order, so compare as a sorted list.
        result.sort_unstable();
        assert_eq!(result, vec![0, 1]);
    }

    #[test]
    fn test_point_distance() {
        let p1 = Point::new(0.0, 0.0);
        let p2 = Point::new(3.0, 4.0);
        let dist = point_distance(&p1, &p2);
        assert!((dist - 5.0).abs() < 1e-6);
    }
}