1use std::collections::HashMap;
10
11use crate::backends::SpatialBackend;
12use crate::types::{BBox, BackendKind, CoordType, EntryId, Point};
13
/// Capacity of a node: a leaf holding more entries, or an internal node holding
/// more children, than this is split in two.
const MAX_ENTRIES: usize = 16;
/// Smallest group size the quadratic split guarantees to each half (a quarter of
/// the capacity). Removals never re-balance, so nodes may later drop below it.
const MIN_ENTRIES: usize = MAX_ENTRIES / 4;
16
17struct Entry<T, C, const D: usize> {
18 point: Point<C, D>,
19 payload: T,
20 id: EntryId,
21}
22
23enum RNode<T, C, const D: usize> {
24 Leaf {
25 bbox: BBox<C, D>,
26 entries: Vec<Entry<T, C, D>>,
27 },
28 Internal {
29 bbox: BBox<C, D>,
30 children: Vec<Box<RNode<T, C, D>>>,
31 },
32}
33
34impl<T, C: CoordType, const D: usize> RNode<T, C, D> {
35 fn bbox(&self) -> &BBox<C, D> {
36 match self {
37 RNode::Leaf { bbox, .. } | RNode::Internal { bbox, .. } => bbox,
38 }
39 }
40}
41
42fn merge_bbox<C: CoordType, const D: usize>(a: &BBox<C, D>, b: &BBox<C, D>) -> BBox<C, D> {
43 let mut min_c = [C::zero(); D];
44 let mut max_c = [C::zero(); D];
45 for d in 0..D {
46 min_c[d] = if a.min.coords()[d] < b.min.coords()[d] {
47 a.min.coords()[d]
48 } else {
49 b.min.coords()[d]
50 };
51 max_c[d] = if a.max.coords()[d] > b.max.coords()[d] {
52 a.max.coords()[d]
53 } else {
54 b.max.coords()[d]
55 };
56 }
57 BBox::new(Point::new(min_c), Point::new(max_c))
58}
59
60fn point_bbox<C: CoordType, const D: usize>(p: &Point<C, D>) -> BBox<C, D> {
61 BBox::new(*p, *p)
62}
63
64fn bbox_area<C: CoordType, const D: usize>(bbox: &BBox<C, D>) -> f64 {
65 let mut area = 1.0_f64;
66 for d in 0..D {
67 let span: f64 = (bbox.max.coords()[d] - bbox.min.coords()[d]).into();
68 area *= span.max(0.0);
69 }
70 area
71}
72
73fn enlargement<C: CoordType, const D: usize>(existing: &BBox<C, D>, new_bbox: &BBox<C, D>) -> f64 {
74 bbox_area(&merge_bbox(existing, new_bbox)) - bbox_area(existing)
75}
76
77fn recompute_leaf_bbox<T, C: CoordType, const D: usize>(entries: &[Entry<T, C, D>]) -> BBox<C, D> {
78 if entries.is_empty() {
79 return BBox::new(Point::new([C::zero(); D]), Point::new([C::zero(); D]));
80 }
81 let mut bbox = point_bbox(&entries[0].point);
82 for e in entries.iter().skip(1) {
83 bbox = merge_bbox(&bbox, &point_bbox(&e.point));
84 }
85 bbox
86}
87
88fn recompute_internal_bbox<T, C: CoordType, const D: usize>(
89 children: &[Box<RNode<T, C, D>>],
90) -> BBox<C, D> {
91 if children.is_empty() {
92 return BBox::new(Point::new([C::zero(); D]), Point::new([C::zero(); D]));
93 }
94 let mut bbox = *children[0].bbox();
95 for c in children.iter().skip(1) {
96 bbox = merge_bbox(&bbox, c.bbox());
97 }
98 bbox
99}
100
101pub(crate) fn point_dist_sq<C: CoordType, const D: usize>(a: &Point<C, D>, b: &Point<C, D>) -> f64 {
102 let mut sum = 0.0_f64;
103 for d in 0..D {
104 let da: f64 = a.coords()[d].into();
105 let db: f64 = b.coords()[d].into();
106 let diff = da - db;
107 sum += diff * diff;
108 }
109 sum
110}
111
112fn pick_seeds_leaf<T, C: CoordType, const D: usize>(entries: &[Entry<T, C, D>]) -> (usize, usize) {
113 let n = entries.len();
114 let (mut si, mut sj) = (0, 1);
115 let mut worst = f64::NEG_INFINITY;
116 for i in 0..n {
117 for j in (i + 1)..n {
118 let combined = merge_bbox(
119 &point_bbox(&entries[i].point),
120 &point_bbox(&entries[j].point),
121 );
122 let waste = bbox_area(&combined)
123 - bbox_area(&point_bbox(&entries[i].point))
124 - bbox_area(&point_bbox(&entries[j].point));
125 if waste > worst {
126 worst = waste;
127 si = i;
128 sj = j;
129 }
130 }
131 }
132 (si, sj)
133}
134
135fn pick_seeds_internal<T, C: CoordType, const D: usize>(
136 children: &[Box<RNode<T, C, D>>],
137) -> (usize, usize) {
138 let n = children.len();
139 let (mut si, mut sj) = (0, 1);
140 let mut worst = f64::NEG_INFINITY;
141 for i in 0..n {
142 for j in (i + 1)..n {
143 let combined = merge_bbox(children[i].bbox(), children[j].bbox());
144 let waste = bbox_area(&combined)
145 - bbox_area(children[i].bbox())
146 - bbox_area(children[j].bbox());
147 if waste > worst {
148 worst = waste;
149 si = i;
150 sj = j;
151 }
152 }
153 }
154 (si, sj)
155}
156
157fn assign_to_group<C: CoordType, const D: usize>(
158 eb: &BBox<C, D>,
159 bbox_a: &mut BBox<C, D>,
160 bbox_b: &mut BBox<C, D>,
161 len_a: usize,
162 len_b: usize,
163) -> bool {
164 let d1 = enlargement(bbox_a, eb);
165 let d2 = enlargement(bbox_b, eb);
166 if d1 < d2 {
167 return true;
168 }
169 if d2 < d1 {
170 return false;
171 }
172 let area_a = bbox_area(bbox_a);
173 let area_b = bbox_area(bbox_b);
174 if area_a < area_b {
175 return true;
176 }
177 if area_b < area_a {
178 return false;
179 }
180 len_a <= len_b
181}
182
183#[allow(clippy::type_complexity)]
184fn quadratic_split_leaf<T, C: CoordType, const D: usize>(
185 mut entries: Vec<Entry<T, C, D>>,
186) -> (Vec<Entry<T, C, D>>, Vec<Entry<T, C, D>>) {
187 let (si, sj) = pick_seeds_leaf(&entries);
188 let (si, sj) = if si < sj { (si, sj) } else { (sj, si) };
189 let seed_j = entries.remove(sj);
190 let seed_i = entries.remove(si);
191 let mut group_a = vec![seed_i];
192 let mut group_b = vec![seed_j];
193 let mut bbox_a = point_bbox(&group_a[0].point);
194 let mut bbox_b = point_bbox(&group_b[0].point);
195
196 while !entries.is_empty() {
197 let remaining = entries.len();
198 if group_a.len() + remaining == MIN_ENTRIES {
199 group_a.extend(entries);
200 break;
201 }
202 if group_b.len() + remaining == MIN_ENTRIES {
203 group_b.extend(entries);
204 break;
205 }
206 let mut best_idx = 0;
207 let mut best_diff = f64::NEG_INFINITY;
208 for (i, e) in entries.iter().enumerate() {
209 let eb = point_bbox(&e.point);
210 let diff = (enlargement(&bbox_a, &eb) - enlargement(&bbox_b, &eb)).abs();
211 if diff > best_diff {
212 best_diff = diff;
213 best_idx = i;
214 }
215 }
216 let e = entries.remove(best_idx);
217 let eb = point_bbox(&e.point);
218 if assign_to_group(&eb, &mut bbox_a, &mut bbox_b, group_a.len(), group_b.len()) {
219 bbox_a = merge_bbox(&bbox_a, &eb);
220 group_a.push(e);
221 } else {
222 bbox_b = merge_bbox(&bbox_b, &eb);
223 group_b.push(e);
224 }
225 }
226 (group_a, group_b)
227}
228
229#[allow(clippy::type_complexity)]
230fn quadratic_split_internal<T, C: CoordType, const D: usize>(
231 mut children: Vec<Box<RNode<T, C, D>>>,
232) -> (Vec<Box<RNode<T, C, D>>>, Vec<Box<RNode<T, C, D>>>) {
233 let (si, sj) = pick_seeds_internal(&children);
234 let (si, sj) = if si < sj { (si, sj) } else { (sj, si) };
235 let seed_j = children.remove(sj);
236 let seed_i = children.remove(si);
237 let mut group_a = vec![seed_i];
238 let mut group_b = vec![seed_j];
239 let mut bbox_a = *group_a[0].bbox();
240 let mut bbox_b = *group_b[0].bbox();
241
242 while !children.is_empty() {
243 let remaining = children.len();
244 if group_a.len() + remaining == MIN_ENTRIES {
245 group_a.extend(children);
246 break;
247 }
248 if group_b.len() + remaining == MIN_ENTRIES {
249 group_b.extend(children);
250 break;
251 }
252 let mut best_idx = 0;
253 let mut best_diff = f64::NEG_INFINITY;
254 for (i, c) in children.iter().enumerate() {
255 let diff = (enlargement(&bbox_a, c.bbox()) - enlargement(&bbox_b, c.bbox())).abs();
256 if diff > best_diff {
257 best_diff = diff;
258 best_idx = i;
259 }
260 }
261 let c = children.remove(best_idx);
262 if assign_to_group(
263 c.bbox(),
264 &mut bbox_a,
265 &mut bbox_b,
266 group_a.len(),
267 group_b.len(),
268 ) {
269 bbox_a = merge_bbox(&bbox_a, c.bbox());
270 group_a.push(c);
271 } else {
272 bbox_b = merge_bbox(&bbox_b, c.bbox());
273 group_b.push(c);
274 }
275 }
276 (group_a, group_b)
277}
278
279fn choose_subtree<T, C: CoordType, const D: usize>(
280 children: &[Box<RNode<T, C, D>>],
281 new_bbox: &BBox<C, D>,
282) -> usize {
283 let mut best = 0;
284 let mut best_enl = f64::INFINITY;
285 let mut best_area = f64::INFINITY;
286 for (i, c) in children.iter().enumerate() {
287 let enl = enlargement(c.bbox(), new_bbox);
288 let area = bbox_area(c.bbox());
289 if enl < best_enl || (enl == best_enl && area < best_area) {
290 best_enl = enl;
291 best_area = area;
292 best = i;
293 }
294 }
295 best
296}
297
298fn node_insert<T, C: CoordType, const D: usize>(
299 node: &mut Box<RNode<T, C, D>>,
300 entry: Entry<T, C, D>,
301) -> Option<Box<RNode<T, C, D>>> {
302 match &mut **node {
303 RNode::Leaf { entries, bbox } => {
304 let pb = point_bbox(&entry.point);
305 *bbox = if entries.is_empty() {
306 pb
307 } else {
308 merge_bbox(bbox, &pb)
309 };
310 entries.push(entry);
311
312 if entries.len() > MAX_ENTRIES {
313 let all = std::mem::take(entries);
314 let (ga, gb) = quadratic_split_leaf(all);
315 let bbox_a = recompute_leaf_bbox(&ga);
316 let bbox_b = recompute_leaf_bbox(&gb);
317 *bbox = bbox_a;
318 *entries = ga;
319 Some(Box::new(RNode::Leaf {
320 bbox: bbox_b,
321 entries: gb,
322 }))
323 } else {
324 None
325 }
326 }
327 RNode::Internal { children, bbox } => {
328 let pb = point_bbox(&entry.point);
329 let best_idx = choose_subtree(children, &pb);
330 *bbox = merge_bbox(bbox, &pb);
331
332 let split = node_insert(&mut children[best_idx], entry);
333 *bbox = recompute_internal_bbox(children);
334
335 if let Some(sibling) = split {
336 children.push(sibling);
337 *bbox = recompute_internal_bbox(children);
338
339 if children.len() > MAX_ENTRIES {
340 let all = std::mem::take(children);
341 let (ga, gb) = quadratic_split_internal(all);
342 let bbox_a = recompute_internal_bbox(&ga);
343 let bbox_b = recompute_internal_bbox(&gb);
344 *bbox = bbox_a;
345 *children = ga;
346 return Some(Box::new(RNode::Internal {
347 bbox: bbox_b,
348 children: gb,
349 }));
350 }
351 }
352 None
353 }
354 }
355}
356
357fn node_remove<T, C: CoordType, const D: usize>(
358 node: &mut Box<RNode<T, C, D>>,
359 id: EntryId,
360 point: &Point<C, D>,
361) -> Option<T> {
362 match &mut **node {
363 RNode::Leaf { entries, bbox } => {
364 if let Some(pos) = entries.iter().position(|e| e.id == id) {
365 let entry = entries.remove(pos);
366 *bbox = recompute_leaf_bbox(entries);
367 Some(entry.payload)
368 } else {
369 None
370 }
371 }
372 RNode::Internal { children, bbox } => {
373 let pb = point_bbox(point);
374 let mut result = None;
375 for child in children.iter_mut() {
376 if child.bbox().intersects(&pb) {
377 result = node_remove(child, id, point);
378 if result.is_some() {
379 break;
380 }
381 }
382 }
383 children.retain(|c| match c.as_ref() {
385 RNode::Leaf { entries, .. } => !entries.is_empty(),
386 RNode::Internal { children, .. } => !children.is_empty(),
387 });
388 if !children.is_empty() {
389 *bbox = recompute_internal_bbox(children);
390 }
391 result
392 }
393 }
394}
395
396fn node_range<'a, T, C: CoordType, const D: usize>(
397 node: &'a RNode<T, C, D>,
398 bbox: &BBox<C, D>,
399 out: &mut Vec<(EntryId, &'a T)>,
400) {
401 match node {
402 RNode::Leaf { bbox: nb, entries } => {
403 if !nb.intersects(bbox) {
404 return;
405 }
406 for e in entries {
407 if bbox.contains_point(&e.point) {
408 out.push((e.id, &e.payload));
409 }
410 }
411 }
412 RNode::Internal { bbox: nb, children } => {
413 if !nb.intersects(bbox) {
414 return;
415 }
416 for child in children {
417 node_range(child, bbox, out);
418 }
419 }
420 }
421}
422
423fn node_collect<'a, T, C: CoordType, const D: usize>(
424 node: &'a RNode<T, C, D>,
425 out: &mut Vec<(Point<C, D>, EntryId, &'a T)>,
426) {
427 match node {
428 RNode::Leaf { entries, .. } => {
429 for e in entries {
430 out.push((e.point, e.id, &e.payload));
431 }
432 }
433 RNode::Internal { children, .. } => {
434 for child in children {
435 node_collect(child, out);
436 }
437 }
438 }
439}
440
441pub struct RTree<T, C, const D: usize> {
447 root: Box<RNode<T, C, D>>,
448 len: usize,
449 next_id: u64,
450 id_to_point: HashMap<u64, Point<C, D>>,
452}
453
454impl<T, C: CoordType, const D: usize> RTree<T, C, D> {
455 pub fn new() -> Self {
457 Self {
458 root: Box::new(RNode::Leaf {
459 bbox: BBox::new(Point::new([C::zero(); D]), Point::new([C::zero(); D])),
460 entries: Vec::new(),
461 }),
462 len: 0,
463 next_id: 0,
464 id_to_point: HashMap::new(),
465 }
466 }
467
468 fn alloc_id(&mut self) -> EntryId {
469 let id = EntryId(self.next_id);
470 self.next_id += 1;
471 id
472 }
473
474 pub fn insert_entry(&mut self, point: Point<C, D>, payload: T) -> EntryId {
475 let id = self.alloc_id();
476 self.id_to_point.insert(id.0, point);
477 let entry = Entry { point, payload, id };
478
479 let split = node_insert(&mut self.root, entry);
480 if let Some(sibling) = split {
481 let old_bbox = *self.root.bbox();
482 let sib_bbox = *sibling.bbox();
483 let new_bbox = merge_bbox(&old_bbox, &sib_bbox);
484 let old_root = std::mem::replace(
485 &mut self.root,
486 Box::new(RNode::Internal {
487 bbox: new_bbox,
488 children: Vec::new(),
489 }),
490 );
491 if let RNode::Internal { children, bbox } = &mut *self.root {
492 children.push(old_root);
493 children.push(sibling);
494 *bbox = new_bbox;
495 }
496 }
497 self.len += 1;
498 id
499 }
500
501 pub fn remove_entry(&mut self, id: EntryId) -> Option<T> {
502 let point = self.id_to_point.remove(&id.0)?;
503 let result = node_remove(&mut self.root, id, &point);
504 if result.is_some() {
505 self.len -= 1;
506 loop {
507 let collapse = matches!(&*self.root,
508 RNode::Internal { children, .. } if children.len() == 1);
509 if !collapse {
510 break;
511 }
512 let old = std::mem::replace(
513 &mut self.root,
514 Box::new(RNode::Leaf {
515 bbox: BBox::new(Point::new([C::zero(); D]), Point::new([C::zero(); D])),
516 entries: Vec::new(),
517 }),
518 );
519 if let RNode::Internal { mut children, .. } = *old {
520 self.root = children.remove(0);
521 }
522 }
523 }
524 result
525 }
526
527 pub fn range_query_impl<'a>(&'a self, bbox: &BBox<C, D>) -> Vec<(EntryId, &'a T)> {
529 let mut out = Vec::new();
530 node_range(&self.root, bbox, &mut out);
531 out
532 }
533
534 pub fn knn_query_impl<'a>(
535 &'a self,
536 point: &Point<C, D>,
537 k: usize,
538 ) -> Vec<(f64, EntryId, &'a T)> {
539 if k == 0 {
540 return Vec::new();
541 }
542 let mut raw: Vec<(Point<C, D>, EntryId, &'a T)> = Vec::new();
543 node_collect(&self.root, &mut raw);
544 let mut all: Vec<(f64, EntryId, &'a T)> = raw
545 .into_iter()
546 .map(|(p, id, payload)| (point_dist_sq(&p, point).sqrt(), id, payload))
547 .collect();
548 all.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
549 all.truncate(k);
550 all
551 }
552
553 fn collect_all(&self) -> Vec<(Point<C, D>, EntryId, &T)> {
554 let mut out = Vec::new();
555 node_collect(&self.root, &mut out);
556 out
557 }
558}
559
560impl<T, C: CoordType, const D: usize> Default for RTree<T, C, D> {
561 fn default() -> Self {
562 Self::new()
563 }
564}
565
566impl<T: Send + Sync + 'static, C: CoordType, const D: usize> SpatialBackend<T, C, D>
567 for RTree<T, C, D>
568{
569 fn insert(&mut self, point: Point<C, D>, payload: T) -> EntryId {
570 self.insert_entry(point, payload)
571 }
572
573 fn remove(&mut self, id: EntryId) -> Option<T> {
574 self.remove_entry(id)
575 }
576
577 fn range_query(&self, bbox: &BBox<C, D>) -> Vec<(EntryId, &T)> {
578 self.range_query_impl(bbox)
579 }
580
581 fn knn_query(&self, point: &Point<C, D>, k: usize) -> Vec<(f64, EntryId, &T)> {
582 self.knn_query_impl(point, k)
583 }
584
585 fn spatial_join(&self, other: &dyn SpatialBackend<T, C, D>) -> Vec<(EntryId, EntryId)> {
586 let self_entries = self.collect_all();
587 let other_entries = other.all_entries();
588 let mut pairs = Vec::new();
589 for (pa, id_a, _) in &self_entries {
590 let bbox_a = point_bbox(pa);
591 for (pb, id_b, _) in &other_entries {
592 if bbox_a.intersects(&point_bbox(pb)) {
593 pairs.push((*id_a, *id_b));
594 }
595 }
596 }
597 pairs
598 }
599
600 fn bulk_load(entries: Vec<(Point<C, D>, T)>) -> Self {
601 let mut tree = Self::new();
602 for (point, payload) in entries {
603 tree.insert_entry(point, payload);
604 }
605 tree
606 }
607
608 fn len(&self) -> usize {
609 self.len
610 }
611
612 fn kind(&self) -> BackendKind {
613 BackendKind::RTree
614 }
615
616 fn all_entries(&self) -> Vec<(Point<C, D>, EntryId, &T)> {
617 self.collect_all()
618 }
619}
620
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Deterministic linear congruential generator so large fixtures are
    /// reproducible without a `rand` dependency.
    struct Lcg(u64);

    impl Lcg {
        fn new(seed: u64) -> Self {
            Self(seed)
        }

        /// Next value in `[0, 1)`, taken from the top 53 bits of the state.
        fn next_f64(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            (self.0 >> 11) as f64 / (1u64 << 53) as f64
        }
    }

    /// `n` pseudo-random points with every coordinate in `[0, 1000)`; the
    /// coordinates of each point are drawn in axis order.
    fn rand_pts<const D: usize>(n: usize, seed: u64) -> Vec<Point<f64, D>> {
        let mut rng = Lcg::new(seed);
        (0..n)
            .map(|_| {
                let mut coords = [0.0_f64; D];
                for c in coords.iter_mut() {
                    *c = rng.next_f64() * 1000.0;
                }
                Point::new(coords)
            })
            .collect()
    }

    /// Builds a tree from `pts`, using each point's index as its payload.
    /// Returns the tree together with the `(point, id)` pairs in insertion order.
    fn build<const D: usize>(
        pts: &[Point<f64, D>],
    ) -> (RTree<usize, f64, D>, Vec<(Point<f64, D>, EntryId)>) {
        let mut tree = RTree::<usize, f64, D>::new();
        let ids: Vec<(Point<f64, D>, EntryId)> = pts
            .iter()
            .enumerate()
            .map(|(i, &p)| (p, tree.insert(p, i)))
            .collect();
        (tree, ids)
    }

    /// Ids returned by the tree's range query, sorted by raw value.
    fn range_ids<const D: usize>(tree: &RTree<usize, f64, D>, bbox: &BBox<f64, D>) -> Vec<EntryId> {
        let mut ids: Vec<EntryId> = tree.range_query(bbox).into_iter().map(|(id, _)| id).collect();
        ids.sort_by_key(|id| id.0);
        ids
    }

    /// Ids of all points inside `bbox` by exhaustive scan, sorted by raw value.
    fn brute_range<C: CoordType, const D: usize>(
        pts: &[(Point<C, D>, EntryId)],
        bbox: &BBox<C, D>,
    ) -> Vec<EntryId> {
        let mut ids: Vec<EntryId> = pts
            .iter()
            .filter(|(p, _)| bbox.contains_point(p))
            .map(|(_, id)| *id)
            .collect();
        ids.sort_by_key(|id| id.0);
        ids
    }

    /// The `k` nearest `(distance, id)` pairs by exhaustive scan, nearest first.
    fn brute_knn<const D: usize>(
        pts: &[(Point<f64, D>, EntryId)],
        query: &Point<f64, D>,
        k: usize,
    ) -> Vec<(f64, EntryId)> {
        let mut all: Vec<(f64, EntryId)> = pts
            .iter()
            .map(|(p, id)| (point_dist_sq(p, query).sqrt(), *id))
            .collect();
        all.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
        all.truncate(k);
        all
    }

    /// Ids of kNN hits, in result order.
    fn knn_ids(hits: &[(f64, EntryId, &usize)]) -> Vec<EntryId> {
        hits.iter().map(|(_, id, _)| *id).collect()
    }

    #[test]
    fn insert_and_len() {
        let mut tree = RTree::<u32, f64, 2>::new();
        assert_eq!(tree.len(), 0);
        tree.insert(Point::new([0.5, 0.5]), 1);
        assert_eq!(tree.len(), 1);
        tree.insert(Point::new([1.5, 1.5]), 2);
        assert_eq!(tree.len(), 2);
    }

    #[test]
    fn range_query_basic() {
        let mut tree = RTree::<u32, f64, 2>::new();
        let id1 = tree.insert(Point::new([0.5, 0.5]), 1);
        let id2 = tree.insert(Point::new([1.5, 1.5]), 2);
        let _id3 = tree.insert(Point::new([5.0, 5.0]), 3);
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([2.0, 2.0]));
        let mut got: Vec<EntryId> = tree
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        got.sort_by_key(|id| id.0);
        assert_eq!(got, vec![id1, id2]);
    }

    #[test]
    fn range_query_vs_brute_force_2d_10k() {
        let (tree, ids) = build(&rand_pts::<2>(10_000, 42));
        let bbox = BBox::new(Point::new([0.0; 2]), Point::new([500.0; 2]));
        assert_eq!(
            range_ids(&tree, &bbox),
            brute_range(&ids, &bbox),
            "2D 10k range query mismatch"
        );
    }

    #[test]
    fn range_query_vs_brute_force_3d_10k() {
        let (tree, ids) = build(&rand_pts::<3>(10_000, 43));
        let bbox = BBox::new(Point::new([0.0; 3]), Point::new([500.0; 3]));
        assert_eq!(
            range_ids(&tree, &bbox),
            brute_range(&ids, &bbox),
            "3D 10k range query mismatch"
        );
    }

    #[test]
    fn range_query_vs_brute_force_6d_10k() {
        let (tree, ids) = build(&rand_pts::<6>(10_000, 44));
        let bbox = BBox::new(Point::new([0.0; 6]), Point::new([500.0; 6]));
        assert_eq!(
            range_ids(&tree, &bbox),
            brute_range(&ids, &bbox),
            "6D 10k range query mismatch"
        );
    }

    #[test]
    fn knn_correctness_2d() {
        let (tree, all) = build(&rand_pts::<2>(1000, 99));
        let query = Point::new([500.0, 500.0]);
        let k = 5;
        let knn = tree.knn_query(&query, k);
        assert_eq!(knn.len(), k);
        assert!(knn.windows(2).all(|w| w[0].0 <= w[1].0), "kNN not sorted");
        let bf_ids: Vec<EntryId> = brute_knn(&all, &query, k)
            .into_iter()
            .map(|(_, id)| id)
            .collect();
        assert_eq!(knn_ids(&knn), bf_ids, "kNN ids don't match brute force");
    }

    #[test]
    fn knn_edge_cases() {
        let empty = RTree::<u32, f64, 2>::new();
        assert!(empty.knn_query(&Point::new([1.0, 1.0]), 3).is_empty());

        let pts = rand_pts::<2>(40, 5);
        let (tree, ids) = build(&pts);
        let query = Point::new([250.0, 750.0]);
        assert!(tree.knn_query(&query, 0).is_empty());

        // Asking for more neighbours than exist returns every entry, nearest first.
        let all = tree.knn_query(&query, 1000);
        assert_eq!(all.len(), pts.len());
        assert!(all.windows(2).all(|w| w[0].0 <= w[1].0), "kNN not sorted");
        let want: Vec<EntryId> = brute_knn(&ids, &query, usize::MAX)
            .into_iter()
            .map(|(_, id)| id)
            .collect();
        assert_eq!(knn_ids(&all), want);
    }

    #[test]
    fn knn_after_removals_matches_brute_force() {
        let (mut tree, ids) = build(&rand_pts::<2>(600, 7));
        let mut survivors: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
        for (i, &(p, id)) in ids.iter().enumerate() {
            if i % 3 == 0 {
                assert_eq!(tree.remove(id), Some(i));
            } else {
                survivors.push((p, id));
            }
        }
        assert_eq!(tree.len(), survivors.len());
        for &(qx, qy) in &[(0.0, 0.0), (500.0, 500.0), (999.0, 10.0)] {
            let query = Point::new([qx, qy]);
            let got = knn_ids(&tree.knn_query(&query, 10));
            let want: Vec<EntryId> = brute_knn(&survivors, &query, 10)
                .into_iter()
                .map(|(_, id)| id)
                .collect();
            assert_eq!(got, want);
        }
    }

    #[test]
    fn knn_distances_match_brute_force_with_ties() {
        // An integer grid produces many equidistant points, so compare
        // distances (tie-independent) and require distinct ids.
        let pts: Vec<Point<f64, 2>> = (0..20)
            .flat_map(|x| (0..20).map(move |y| Point::new([x as f64, y as f64])))
            .collect();
        let (tree, ids) = build(&pts);
        let query = Point::new([10.0, 10.0]);
        for &k in &[1usize, 5, 9, 13, 50] {
            let knn = tree.knn_query(&query, k);
            let got: Vec<f64> = knn.iter().map(|(d, _, _)| *d).collect();
            let want: Vec<f64> = brute_knn(&ids, &query, k).into_iter().map(|(d, _)| d).collect();
            assert_eq!(got, want, "k = {}", k);
            let mut seen = knn_ids(&knn);
            seen.sort_by_key(|id| id.0);
            seen.dedup();
            assert_eq!(seen.len(), k, "duplicate ids for k = {}", k);
        }
    }

    #[test]
    fn remove_works() {
        let mut tree = RTree::<&str, f64, 2>::new();
        let id1 = tree.insert(Point::new([1.0, 1.0]), "a");
        let id2 = tree.insert(Point::new([2.0, 2.0]), "b");
        assert_eq!(tree.len(), 2);
        assert_eq!(tree.remove(id1), Some("a"));
        assert_eq!(tree.len(), 1);
        // A second removal of the same id is a no-op.
        assert_eq!(tree.remove(id1), None);
        assert_eq!(tree.len(), 1);
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([3.0, 3.0]));
        let ids: Vec<EntryId> = tree
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        assert!(!ids.contains(&id1));
        assert!(ids.contains(&id2));
    }

    #[test]
    fn remove_all_then_reinsert() {
        let (mut tree, ids) = build(&rand_pts::<3>(200, 11));
        for (i, &(_, id)) in ids.iter().enumerate() {
            assert_eq!(tree.remove(id), Some(i));
        }
        assert_eq!(tree.len(), 0);
        let everything = BBox::new(Point::new([-1.0; 3]), Point::new([1001.0; 3]));
        assert!(tree.range_query(&everything).is_empty());

        let fresh = tree.insert(Point::new([1.0, 2.0, 3.0]), 7);
        assert_eq!(tree.len(), 1);
        let hits = tree.range_query(&everything);
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].0, fresh);
        assert_eq!(*hits[0].1, 7);
    }

    #[test]
    fn kind_is_rtree() {
        assert_eq!(RTree::<u32, f64, 2>::new().kind(), BackendKind::RTree);
    }

    #[test]
    fn bulk_load_correctness() {
        let pts = rand_pts::<2>(500, 77);
        let entries: Vec<(Point<f64, 2>, usize)> =
            pts.iter().enumerate().map(|(i, &p)| (p, i)).collect();
        let tree = RTree::<usize, f64, 2>::bulk_load(entries);
        assert_eq!(tree.len(), 500);
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([1000.0, 1000.0]));
        assert_eq!(tree.range_query(&bbox).len(), 500);
    }

    fn pt2d() -> impl Strategy<Value = Point<f64, 2>> {
        (0.0_f64..1000.0, 0.0_f64..1000.0).prop_map(|(x, y)| Point::new([x, y]))
    }

    fn bbox2d() -> impl Strategy<Value = BBox<f64, 2>> {
        (
            0.0_f64..900.0,
            0.0_f64..900.0,
            10.0_f64..200.0,
            10.0_f64..200.0,
        )
            .prop_map(|(x, y, w, h)| BBox::new(Point::new([x, y]), Point::new([x + w, y + h])))
    }

    fn pt3d() -> impl Strategy<Value = Point<f64, 3>> {
        (0.0_f64..1000.0, 0.0_f64..1000.0, 0.0_f64..1000.0)
            .prop_map(|(x, y, z)| Point::new([x, y, z]))
    }

    fn bbox3d() -> impl Strategy<Value = BBox<f64, 3>> {
        (
            0.0_f64..800.0,
            0.0_f64..800.0,
            0.0_f64..800.0,
            10.0_f64..200.0,
            10.0_f64..200.0,
            10.0_f64..200.0,
        )
            .prop_map(|(x, y, z, w, h, d)| {
                BBox::new(Point::new([x, y, z]), Point::new([x + w, y + h, z + d]))
            })
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config {
            cases: 100,
            ..Default::default()
        })]

        #[test]
        fn prop_range_query_oracle_2d(
            pts in prop::collection::vec(pt2d(), 1..100),
            bbox in bbox2d(),
        ) {
            let (tree, ids) = build(&pts);
            prop_assert_eq!(range_ids(&tree, &bbox), brute_range(&ids, &bbox));
        }

        #[test]
        fn prop_range_query_oracle_3d(
            pts in prop::collection::vec(pt3d(), 1..80),
            bbox in bbox3d(),
        ) {
            let (tree, ids) = build(&pts);
            prop_assert_eq!(range_ids(&tree, &bbox), brute_range(&ids, &bbox));
        }

        #[test]
        fn prop_knn_correctness(
            pts in prop::collection::vec(pt2d(), 5..80),
            qx in 0.0_f64..1000.0,
            qy in 0.0_f64..1000.0,
            k in 1usize..5,
        ) {
            let (tree, all) = build(&pts);
            let query = Point::new([qx, qy]);
            let actual_k = k.min(all.len());
            let knn = tree.knn_query(&query, actual_k);
            prop_assert_eq!(knn.len(), actual_k);
            for w in knn.windows(2) {
                prop_assert!(w[0].0 <= w[1].0, "kNN not sorted");
            }
            let bf_ids: Vec<EntryId> = brute_knn(&all, &query, actual_k)
                .into_iter()
                .map(|(_, id)| id)
                .collect();
            prop_assert_eq!(knn_ids(&knn), bf_ids);
        }

        #[test]
        fn prop_insert_remove_round_trip_rtree(
            pts in prop::collection::vec(pt2d(), 1..50),
            remove_indices in prop::collection::vec(0usize..50, 0..25),
        ) {
            let (mut tree, inserted) = build(&pts);
            let mut removed_ids: Vec<EntryId> = Vec::new();
            for &ri in &remove_indices {
                let idx = ri % inserted.len();
                let (_, id) = inserted[idx];
                if removed_ids.contains(&id) {
                    prop_assert_eq!(tree.remove(id), None);
                } else {
                    // Payloads are insertion indices, so the exact payload is known.
                    prop_assert_eq!(tree.remove(id), Some(idx));
                    removed_ids.push(id);
                }
            }

            // Survivors must still be present with their payloads, and nothing else.
            let full_bbox = BBox::new(Point::new([-1.0e9, -1.0e9]), Point::new([1.0e9, 1.0e9]));
            let mut remaining: Vec<(EntryId, usize)> = tree
                .range_query(&full_bbox)
                .into_iter()
                .map(|(id, &payload)| (id, payload))
                .collect();
            remaining.sort_by_key(|(id, _)| id.0);
            let expected: Vec<(EntryId, usize)> = inserted
                .iter()
                .enumerate()
                .filter(|(_, (_, id))| !removed_ids.contains(id))
                .map(|(i, (_, id))| (*id, i))
                .collect();
            prop_assert_eq!(remaining, expected);
            prop_assert_eq!(tree.len(), inserted.len() - removed_ids.len());
        }

        #[test]
        fn prop_spatial_join_completeness(
            pts_a in prop::collection::vec(pt2d(), 1..30),
            pts_b in prop::collection::vec(pt2d(), 1..30),
        ) {
            let (tree_a, ids_a) = build(&pts_a);
            let (tree_b, ids_b) = build(&pts_b);
            let mut got = tree_a.spatial_join(&tree_b);
            got.sort();
            let mut expected: Vec<(EntryId, EntryId)> = Vec::new();
            for (pa, id_a) in &ids_a {
                for (pb, id_b) in &ids_b {
                    if point_bbox(pa).intersects(&point_bbox(pb)) {
                        expected.push((*id_a, *id_b));
                    }
                }
            }
            expected.sort();
            prop_assert_eq!(got, expected);
        }
    }
}