//! Spatial join, nearest-neighbour and range queries over point geometries.
//!
//! Source path: `oxigdal_algorithms/vector/spatial_join.rs`.

use crate::error::Result;
use oxigdal_core::vector::Point;
use rstar::{AABB, PointDistance, RTree, RTreeObject};

/// Spatial relationship that a `(left, right)` pair of geometries must satisfy
/// to be reported by a spatial join.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SpatialJoinPredicate {
    /// The `Intersects` relationship between left and right.
    Intersects,
    /// The left geometry `Contains` the right geometry.
    Contains,
    /// The left geometry is `Within` the right geometry.
    Within,
    /// The `Touches` relationship between left and right.
    Touches,
    /// The pair lies within `SpatialJoinOptions::distance` of each other.
    WithinDistance,
}

24#[derive(Debug, Clone)]
26pub struct SpatialJoinOptions {
27 pub predicate: SpatialJoinPredicate,
29 pub distance: f64,
31 pub use_index: bool,
33}
34
35impl Default for SpatialJoinOptions {
36 fn default() -> Self {
37 Self {
38 predicate: SpatialJoinPredicate::Intersects,
39 distance: 0.0,
40 use_index: true,
41 }
42 }
43}
44
/// Output of [`spatial_join_points`].
#[derive(Clone, Debug)]
pub struct SpatialJoinResult {
    /// Matching `(left_index, right_index)` pairs.
    pub matches: Vec<(usize, usize)>,
    /// Number of entries in `matches`.
    pub num_matches: usize,
}

54#[derive(Debug, Clone)]
56struct IndexedPoint {
57 point: Point,
58 index: usize,
59}
60
61impl RTreeObject for IndexedPoint {
62 type Envelope = AABB<[f64; 2]>;
63
64 fn envelope(&self) -> Self::Envelope {
65 AABB::from_point([self.point.coord.x, self.point.coord.y])
66 }
67}
68
69impl PointDistance for IndexedPoint {
70 fn distance_2(&self, point: &[f64; 2]) -> f64 {
71 let dx = self.point.coord.x - point[0];
72 let dy = self.point.coord.y - point[1];
73 dx * dx + dy * dy
74 }
75}
76
77pub fn spatial_join_points(
119 left_points: &[Point],
120 right_points: &[Point],
121 options: &SpatialJoinOptions,
122) -> Result<SpatialJoinResult> {
123 if left_points.is_empty() || right_points.is_empty() {
124 return Ok(SpatialJoinResult {
125 matches: Vec::new(),
126 num_matches: 0,
127 });
128 }
129
130 let matches = if options.use_index {
131 let indexed_points: Vec<IndexedPoint> = right_points
133 .iter()
134 .enumerate()
135 .map(|(idx, point)| IndexedPoint {
136 point: point.clone(),
137 index: idx,
138 })
139 .collect();
140
141 let rtree = RTree::bulk_load(indexed_points);
142
143 let mut all_matches = Vec::new();
145
146 for (left_idx, left_point) in left_points.iter().enumerate() {
147 let nearby = match options.predicate {
148 SpatialJoinPredicate::WithinDistance => {
149 let envelope = AABB::from_corners(
151 [
152 left_point.coord.x - options.distance,
153 left_point.coord.y - options.distance,
154 ],
155 [
156 left_point.coord.x + options.distance,
157 left_point.coord.y + options.distance,
158 ],
159 );
160
161 rtree
162 .locate_in_envelope(&envelope)
163 .filter(|indexed| {
164 point_distance(left_point, &indexed.point) <= options.distance
165 })
166 .map(|indexed| indexed.index)
167 .collect::<Vec<_>>()
168 }
169 SpatialJoinPredicate::Intersects => {
170 let mut matches = Vec::new();
172 for indexed in rtree.locate_at_point(&[left_point.coord.x, left_point.coord.y])
173 {
174 matches.push(indexed.index);
175 }
176 matches
177 }
178 _ => {
179 Vec::new()
181 }
182 };
183
184 for right_idx in nearby {
185 all_matches.push((left_idx, right_idx));
186 }
187 }
188
189 all_matches
190 } else {
191 let mut all_matches = Vec::new();
193
194 for (left_idx, left_point) in left_points.iter().enumerate() {
195 for (right_idx, right_point) in right_points.iter().enumerate() {
196 if matches_predicate(left_point, right_point, options) {
197 all_matches.push((left_idx, right_idx));
198 }
199 }
200 }
201
202 all_matches
203 };
204
205 Ok(SpatialJoinResult {
206 num_matches: matches.len(),
207 matches,
208 })
209}
210
211fn matches_predicate(left: &Point, right: &Point, options: &SpatialJoinOptions) -> bool {
213 match options.predicate {
214 SpatialJoinPredicate::Intersects => {
215 (left.coord.x - right.coord.x).abs() < 1e-10
216 && (left.coord.y - right.coord.y).abs() < 1e-10
217 }
218 SpatialJoinPredicate::WithinDistance => point_distance(left, right) <= options.distance,
219 _ => false,
220 }
221}
222
223fn point_distance(p1: &Point, p2: &Point) -> f64 {
225 let dx = p1.coord.x - p2.coord.x;
226 let dy = p1.coord.y - p2.coord.y;
227 (dx * dx + dy * dy).sqrt()
228}
229
230pub fn nearest_neighbor(query: &Point, points: &[Point]) -> Option<(usize, f64)> {
232 let indexed_points: Vec<IndexedPoint> = points
233 .iter()
234 .enumerate()
235 .map(|(idx, point)| IndexedPoint {
236 point: point.clone(),
237 index: idx,
238 })
239 .collect();
240
241 if indexed_points.is_empty() {
242 return None;
243 }
244
245 let rtree = RTree::bulk_load(indexed_points);
246 let nearest = rtree.nearest_neighbor(&[query.coord.x, query.coord.y])?;
247
248 let distance = point_distance(query, &nearest.point);
249
250 Some((nearest.index, distance))
251}
252
253pub fn k_nearest_neighbors(query: &Point, points: &[Point], k: usize) -> Vec<(usize, f64)> {
255 let indexed_points: Vec<IndexedPoint> = points
256 .iter()
257 .enumerate()
258 .map(|(idx, point)| IndexedPoint {
259 point: point.clone(),
260 index: idx,
261 })
262 .collect();
263
264 if indexed_points.is_empty() {
265 return Vec::new();
266 }
267
268 let rtree = RTree::bulk_load(indexed_points);
269
270 rtree
271 .nearest_neighbor_iter(&[query.coord.x, query.coord.y])
272 .take(k)
273 .map(|indexed| {
274 let dist = point_distance(query, &indexed.point);
275 (indexed.index, dist)
276 })
277 .collect()
278}
279
280pub fn range_query(query: &Point, points: &[Point], distance: f64) -> Vec<usize> {
282 let options = SpatialJoinOptions {
283 predicate: SpatialJoinPredicate::WithinDistance,
284 distance,
285 use_index: true,
286 };
287
288 let result = spatial_join_points(std::slice::from_ref(query), points, &options);
289
290 result
291 .map(|r| {
292 r.matches
293 .into_iter()
294 .map(|(_, right_idx)| right_idx)
295 .collect()
296 })
297 .unwrap_or_else(|_| Vec::new())
298}
299
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_spatial_join_within_distance() {
        let left = vec![Point::new(0.0, 0.0), Point::new(10.0, 10.0)];
        let right = vec![Point::new(0.1, 0.1), Point::new(5.0, 5.0)];

        let options = SpatialJoinOptions {
            predicate: SpatialJoinPredicate::WithinDistance,
            distance: 0.5,
            use_index: true,
        };

        let join_result = spatial_join_points(&left, &right, &options).expect("Join failed");

        // Only left[0] (0,0) and right[0] (0.1,0.1) are closer than 0.5 (~0.141 apart).
        assert_eq!(join_result.matches, vec![(0, 0)]);
        assert_eq!(join_result.num_matches, 1);
    }

    #[test]
    fn test_spatial_join_index_and_brute_force_agree() {
        let left = vec![
            Point::new(0.0, 0.0),
            Point::new(3.0, 4.0),
            Point::new(-1.0, -1.0),
        ];
        let right = vec![
            Point::new(0.0, 1.0),
            Point::new(3.0, 4.0),
            Point::new(100.0, 100.0),
        ];

        for predicate in [
            SpatialJoinPredicate::WithinDistance,
            SpatialJoinPredicate::Intersects,
        ] {
            let indexed = SpatialJoinOptions {
                predicate,
                distance: 2.0,
                use_index: true,
            };
            let brute = SpatialJoinOptions {
                use_index: false,
                ..indexed.clone()
            };

            let mut via_index = spatial_join_points(&left, &right, &indexed)
                .expect("indexed join failed")
                .matches;
            let mut via_loop = spatial_join_points(&left, &right, &brute)
                .expect("brute-force join failed")
                .matches;
            via_index.sort_unstable();
            via_loop.sort_unstable();
            assert_eq!(via_index, via_loop, "{predicate:?}");
        }
    }

    #[test]
    fn test_spatial_join_empty_input() {
        let options = SpatialJoinOptions::default();
        let result =
            spatial_join_points(&[], &[Point::new(0.0, 0.0)], &options).expect("Join failed");
        assert!(result.matches.is_empty());
        assert_eq!(result.num_matches, 0);
    }

    #[test]
    fn test_nearest_neighbor() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(5.0, 5.0),
            Point::new(10.0, 10.0),
        ];

        let query = Point::new(0.1, 0.1);
        let (idx, dist) = nearest_neighbor(&query, &points).expect("Nearest neighbor failed");

        assert_eq!(idx, 0);
        assert!((dist - 0.02_f64.sqrt()).abs() < 1e-12);
        assert!(nearest_neighbor(&query, &[]).is_none());
    }

    #[test]
    fn test_k_nearest_neighbors() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(1.0, 1.0),
            Point::new(2.0, 2.0),
            Point::new(10.0, 10.0),
        ];

        let query = Point::new(0.0, 0.0);
        let result = k_nearest_neighbors(&query, &points, 2);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].0, 0);
        assert_eq!(result[1].0, 1);
        assert!(result[0].1.abs() < 1e-12);
        assert!((result[1].1 - 2.0_f64.sqrt()).abs() < 1e-12);
    }

    #[test]
    fn test_range_query() {
        let points = vec![
            Point::new(0.0, 0.0),
            Point::new(0.5, 0.5),
            Point::new(10.0, 10.0),
        ];

        let query = Point::new(0.0, 0.0);
        let mut result = range_query(&query, &points, 1.0);
        result.sort_unstable();

        // (0,0) and (0.5,0.5) are within 1.0; (10,10) is not.
        assert_eq!(result, vec![0, 1]);
    }

    #[test]
    fn test_point_distance() {
        let p1 = Point::new(0.0, 0.0);
        let p2 = Point::new(3.0, 4.0);

        let dist = point_distance(&p1, &p2);
        assert!((dist - 5.0).abs() < 1e-6);
    }
}