Skip to main content

bonsai/backends/
rtree.rs

1//! R*-tree spatial index backend.
2//!
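//! Implements the [`SpatialBackend`] trait using an R-tree with quadratic
//! splits on overflow (min=4, max=16 entries per node). The forced-reinsertion
//! step described in the R*-tree paper is not implemented by the insertion
//! code in this module; overflowing nodes are always split.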
5//!
6//! Reference: Beckmann et al. (1990) "The R*-tree: An Efficient and Robust
7//! Access Method for Points and Rectangles"
8
9use std::collections::HashMap;
10
11use crate::backends::SpatialBackend;
12use crate::types::{BBox, BackendKind, CoordType, EntryId, Point};
13
/// Node capacity: a leaf (entries) or internal node (children) that grows
/// beyond this count is split by the insertion code.
const MAX_ENTRIES: usize = 16;
/// Minimum group size the quadratic split routines enforce: once a group can
/// only reach this size by taking every remaining item, it receives them all.
const MIN_ENTRIES: usize = 4;
16
/// A single indexed item stored in a leaf node.
struct Entry<T, C, const D: usize> {
    /// Location of the item in `D`-dimensional space.
    point: Point<C, D>,
    /// Caller-supplied value associated with the point.
    payload: T,
    /// Identifier handed out by the tree at insertion time; used to find the
    /// entry again on removal.
    id: EntryId,
}
22
/// A node of the tree: either a leaf holding entries or an internal node
/// holding boxed child nodes. Both variants cache a bounding box.
enum RNode<T, C, const D: usize> {
    /// Terminal node that stores the actual entries.
    Leaf {
        /// Bounding box of the points in `entries` (an all-zero placeholder
        /// box when the leaf is empty).
        bbox: BBox<C, D>,
        /// Entries stored directly in this leaf.
        entries: Vec<Entry<T, C, D>>,
    },
    /// Routing node whose children are further nodes.
    Internal {
        /// Union of the children's bounding boxes.
        bbox: BBox<C, D>,
        /// Child subtrees.
        children: Vec<Box<RNode<T, C, D>>>,
    },
}
33
34impl<T, C: CoordType, const D: usize> RNode<T, C, D> {
35    fn bbox(&self) -> &BBox<C, D> {
36        match self {
37            RNode::Leaf { bbox, .. } | RNode::Internal { bbox, .. } => bbox,
38        }
39    }
40}
41
42fn merge_bbox<C: CoordType, const D: usize>(a: &BBox<C, D>, b: &BBox<C, D>) -> BBox<C, D> {
43    let mut min_c = [C::zero(); D];
44    let mut max_c = [C::zero(); D];
45    for d in 0..D {
46        min_c[d] = if a.min.coords()[d] < b.min.coords()[d] {
47            a.min.coords()[d]
48        } else {
49            b.min.coords()[d]
50        };
51        max_c[d] = if a.max.coords()[d] > b.max.coords()[d] {
52            a.max.coords()[d]
53        } else {
54            b.max.coords()[d]
55        };
56    }
57    BBox::new(Point::new(min_c), Point::new(max_c))
58}
59
60fn point_bbox<C: CoordType, const D: usize>(p: &Point<C, D>) -> BBox<C, D> {
61    BBox::new(*p, *p)
62}
63
64fn bbox_area<C: CoordType, const D: usize>(bbox: &BBox<C, D>) -> f64 {
65    let mut area = 1.0_f64;
66    for d in 0..D {
67        let span: f64 = (bbox.max.coords()[d] - bbox.min.coords()[d]).into();
68        area *= span.max(0.0);
69    }
70    area
71}
72
73fn enlargement<C: CoordType, const D: usize>(existing: &BBox<C, D>, new_bbox: &BBox<C, D>) -> f64 {
74    bbox_area(&merge_bbox(existing, new_bbox)) - bbox_area(existing)
75}
76
77fn recompute_leaf_bbox<T, C: CoordType, const D: usize>(entries: &[Entry<T, C, D>]) -> BBox<C, D> {
78    if entries.is_empty() {
79        return BBox::new(Point::new([C::zero(); D]), Point::new([C::zero(); D]));
80    }
81    let mut bbox = point_bbox(&entries[0].point);
82    for e in entries.iter().skip(1) {
83        bbox = merge_bbox(&bbox, &point_bbox(&e.point));
84    }
85    bbox
86}
87
88fn recompute_internal_bbox<T, C: CoordType, const D: usize>(
89    children: &[Box<RNode<T, C, D>>],
90) -> BBox<C, D> {
91    if children.is_empty() {
92        return BBox::new(Point::new([C::zero(); D]), Point::new([C::zero(); D]));
93    }
94    let mut bbox = *children[0].bbox();
95    for c in children.iter().skip(1) {
96        bbox = merge_bbox(&bbox, c.bbox());
97    }
98    bbox
99}
100
101pub(crate) fn point_dist_sq<C: CoordType, const D: usize>(a: &Point<C, D>, b: &Point<C, D>) -> f64 {
102    let mut sum = 0.0_f64;
103    for d in 0..D {
104        let da: f64 = a.coords()[d].into();
105        let db: f64 = b.coords()[d].into();
106        let diff = da - db;
107        sum += diff * diff;
108    }
109    sum
110}
111
112fn pick_seeds_leaf<T, C: CoordType, const D: usize>(entries: &[Entry<T, C, D>]) -> (usize, usize) {
113    let n = entries.len();
114    let (mut si, mut sj) = (0, 1);
115    let mut worst = f64::NEG_INFINITY;
116    for i in 0..n {
117        for j in (i + 1)..n {
118            let combined = merge_bbox(
119                &point_bbox(&entries[i].point),
120                &point_bbox(&entries[j].point),
121            );
122            let waste = bbox_area(&combined)
123                - bbox_area(&point_bbox(&entries[i].point))
124                - bbox_area(&point_bbox(&entries[j].point));
125            if waste > worst {
126                worst = waste;
127                si = i;
128                sj = j;
129            }
130        }
131    }
132    (si, sj)
133}
134
135fn pick_seeds_internal<T, C: CoordType, const D: usize>(
136    children: &[Box<RNode<T, C, D>>],
137) -> (usize, usize) {
138    let n = children.len();
139    let (mut si, mut sj) = (0, 1);
140    let mut worst = f64::NEG_INFINITY;
141    for i in 0..n {
142        for j in (i + 1)..n {
143            let combined = merge_bbox(children[i].bbox(), children[j].bbox());
144            let waste = bbox_area(&combined)
145                - bbox_area(children[i].bbox())
146                - bbox_area(children[j].bbox());
147            if waste > worst {
148                worst = waste;
149                si = i;
150                sj = j;
151            }
152        }
153    }
154    (si, sj)
155}
156
157fn assign_to_group<C: CoordType, const D: usize>(
158    eb: &BBox<C, D>,
159    bbox_a: &mut BBox<C, D>,
160    bbox_b: &mut BBox<C, D>,
161    len_a: usize,
162    len_b: usize,
163) -> bool {
164    let d1 = enlargement(bbox_a, eb);
165    let d2 = enlargement(bbox_b, eb);
166    if d1 < d2 {
167        return true;
168    }
169    if d2 < d1 {
170        return false;
171    }
172    let area_a = bbox_area(bbox_a);
173    let area_b = bbox_area(bbox_b);
174    if area_a < area_b {
175        return true;
176    }
177    if area_b < area_a {
178        return false;
179    }
180    len_a <= len_b
181}
182
183#[allow(clippy::type_complexity)]
184fn quadratic_split_leaf<T, C: CoordType, const D: usize>(
185    mut entries: Vec<Entry<T, C, D>>,
186) -> (Vec<Entry<T, C, D>>, Vec<Entry<T, C, D>>) {
187    let (si, sj) = pick_seeds_leaf(&entries);
188    let (si, sj) = if si < sj { (si, sj) } else { (sj, si) };
189    let seed_j = entries.remove(sj);
190    let seed_i = entries.remove(si);
191    let mut group_a = vec![seed_i];
192    let mut group_b = vec![seed_j];
193    let mut bbox_a = point_bbox(&group_a[0].point);
194    let mut bbox_b = point_bbox(&group_b[0].point);
195
196    while !entries.is_empty() {
197        let remaining = entries.len();
198        if group_a.len() + remaining == MIN_ENTRIES {
199            group_a.extend(entries);
200            break;
201        }
202        if group_b.len() + remaining == MIN_ENTRIES {
203            group_b.extend(entries);
204            break;
205        }
206        let mut best_idx = 0;
207        let mut best_diff = f64::NEG_INFINITY;
208        for (i, e) in entries.iter().enumerate() {
209            let eb = point_bbox(&e.point);
210            let diff = (enlargement(&bbox_a, &eb) - enlargement(&bbox_b, &eb)).abs();
211            if diff > best_diff {
212                best_diff = diff;
213                best_idx = i;
214            }
215        }
216        let e = entries.remove(best_idx);
217        let eb = point_bbox(&e.point);
218        if assign_to_group(&eb, &mut bbox_a, &mut bbox_b, group_a.len(), group_b.len()) {
219            bbox_a = merge_bbox(&bbox_a, &eb);
220            group_a.push(e);
221        } else {
222            bbox_b = merge_bbox(&bbox_b, &eb);
223            group_b.push(e);
224        }
225    }
226    (group_a, group_b)
227}
228
229#[allow(clippy::type_complexity)]
230fn quadratic_split_internal<T, C: CoordType, const D: usize>(
231    mut children: Vec<Box<RNode<T, C, D>>>,
232) -> (Vec<Box<RNode<T, C, D>>>, Vec<Box<RNode<T, C, D>>>) {
233    let (si, sj) = pick_seeds_internal(&children);
234    let (si, sj) = if si < sj { (si, sj) } else { (sj, si) };
235    let seed_j = children.remove(sj);
236    let seed_i = children.remove(si);
237    let mut group_a = vec![seed_i];
238    let mut group_b = vec![seed_j];
239    let mut bbox_a = *group_a[0].bbox();
240    let mut bbox_b = *group_b[0].bbox();
241
242    while !children.is_empty() {
243        let remaining = children.len();
244        if group_a.len() + remaining == MIN_ENTRIES {
245            group_a.extend(children);
246            break;
247        }
248        if group_b.len() + remaining == MIN_ENTRIES {
249            group_b.extend(children);
250            break;
251        }
252        let mut best_idx = 0;
253        let mut best_diff = f64::NEG_INFINITY;
254        for (i, c) in children.iter().enumerate() {
255            let diff = (enlargement(&bbox_a, c.bbox()) - enlargement(&bbox_b, c.bbox())).abs();
256            if diff > best_diff {
257                best_diff = diff;
258                best_idx = i;
259            }
260        }
261        let c = children.remove(best_idx);
262        if assign_to_group(
263            c.bbox(),
264            &mut bbox_a,
265            &mut bbox_b,
266            group_a.len(),
267            group_b.len(),
268        ) {
269            bbox_a = merge_bbox(&bbox_a, c.bbox());
270            group_a.push(c);
271        } else {
272            bbox_b = merge_bbox(&bbox_b, c.bbox());
273            group_b.push(c);
274        }
275    }
276    (group_a, group_b)
277}
278
279fn choose_subtree<T, C: CoordType, const D: usize>(
280    children: &[Box<RNode<T, C, D>>],
281    new_bbox: &BBox<C, D>,
282) -> usize {
283    let mut best = 0;
284    let mut best_enl = f64::INFINITY;
285    let mut best_area = f64::INFINITY;
286    for (i, c) in children.iter().enumerate() {
287        let enl = enlargement(c.bbox(), new_bbox);
288        let area = bbox_area(c.bbox());
289        if enl < best_enl || (enl == best_enl && area < best_area) {
290            best_enl = enl;
291            best_area = area;
292            best = i;
293        }
294    }
295    best
296}
297
298fn node_insert<T, C: CoordType, const D: usize>(
299    node: &mut Box<RNode<T, C, D>>,
300    entry: Entry<T, C, D>,
301) -> Option<Box<RNode<T, C, D>>> {
302    match &mut **node {
303        RNode::Leaf { entries, bbox } => {
304            let pb = point_bbox(&entry.point);
305            *bbox = if entries.is_empty() {
306                pb
307            } else {
308                merge_bbox(bbox, &pb)
309            };
310            entries.push(entry);
311
312            if entries.len() > MAX_ENTRIES {
313                let all = std::mem::take(entries);
314                let (ga, gb) = quadratic_split_leaf(all);
315                let bbox_a = recompute_leaf_bbox(&ga);
316                let bbox_b = recompute_leaf_bbox(&gb);
317                *bbox = bbox_a;
318                *entries = ga;
319                Some(Box::new(RNode::Leaf {
320                    bbox: bbox_b,
321                    entries: gb,
322                }))
323            } else {
324                None
325            }
326        }
327        RNode::Internal { children, bbox } => {
328            let pb = point_bbox(&entry.point);
329            let best_idx = choose_subtree(children, &pb);
330            *bbox = merge_bbox(bbox, &pb);
331
332            let split = node_insert(&mut children[best_idx], entry);
333            *bbox = recompute_internal_bbox(children);
334
335            if let Some(sibling) = split {
336                children.push(sibling);
337                *bbox = recompute_internal_bbox(children);
338
339                if children.len() > MAX_ENTRIES {
340                    let all = std::mem::take(children);
341                    let (ga, gb) = quadratic_split_internal(all);
342                    let bbox_a = recompute_internal_bbox(&ga);
343                    let bbox_b = recompute_internal_bbox(&gb);
344                    *bbox = bbox_a;
345                    *children = ga;
346                    return Some(Box::new(RNode::Internal {
347                        bbox: bbox_b,
348                        children: gb,
349                    }));
350                }
351            }
352            None
353        }
354    }
355}
356
357fn node_remove<T, C: CoordType, const D: usize>(
358    node: &mut Box<RNode<T, C, D>>,
359    id: EntryId,
360    point: &Point<C, D>,
361) -> Option<T> {
362    match &mut **node {
363        RNode::Leaf { entries, bbox } => {
364            if let Some(pos) = entries.iter().position(|e| e.id == id) {
365                let entry = entries.remove(pos);
366                *bbox = recompute_leaf_bbox(entries);
367                Some(entry.payload)
368            } else {
369                None
370            }
371        }
372        RNode::Internal { children, bbox } => {
373            let pb = point_bbox(point);
374            let mut result = None;
375            for child in children.iter_mut() {
376                if child.bbox().intersects(&pb) {
377                    result = node_remove(child, id, point);
378                    if result.is_some() {
379                        break;
380                    }
381                }
382            }
383            // Prune empty children
384            children.retain(|c| match c.as_ref() {
385                RNode::Leaf { entries, .. } => !entries.is_empty(),
386                RNode::Internal { children, .. } => !children.is_empty(),
387            });
388            if !children.is_empty() {
389                *bbox = recompute_internal_bbox(children);
390            }
391            result
392        }
393    }
394}
395
396fn node_range<'a, T, C: CoordType, const D: usize>(
397    node: &'a RNode<T, C, D>,
398    bbox: &BBox<C, D>,
399    out: &mut Vec<(EntryId, &'a T)>,
400) {
401    match node {
402        RNode::Leaf { bbox: nb, entries } => {
403            if !nb.intersects(bbox) {
404                return;
405            }
406            for e in entries {
407                if bbox.contains_point(&e.point) {
408                    out.push((e.id, &e.payload));
409                }
410            }
411        }
412        RNode::Internal { bbox: nb, children } => {
413            if !nb.intersects(bbox) {
414                return;
415            }
416            for child in children {
417                node_range(child, bbox, out);
418            }
419        }
420    }
421}
422
423fn node_collect<'a, T, C: CoordType, const D: usize>(
424    node: &'a RNode<T, C, D>,
425    out: &mut Vec<(Point<C, D>, EntryId, &'a T)>,
426) {
427    match node {
428        RNode::Leaf { entries, .. } => {
429            for e in entries {
430                out.push((e.point, e.id, &e.payload));
431            }
432        }
433        RNode::Internal { children, .. } => {
434            for child in children {
435                node_collect(child, out);
436            }
437        }
438    }
439}
440
/// R-tree spatial index over points.
///
/// Stores points with associated payloads and supports range queries, kNN
/// queries, and spatial joins. Overflowing nodes are divided with quadratic
/// splits.
///
/// NOTE(review): earlier documentation advertised R*-tree forced reinsertion
/// on overflow; the insertion code in this file only splits — confirm whether
/// reinsertion is planned or the claim was simply stale.
pub struct RTree<T, C, const D: usize> {
    /// Root node of the tree; an empty leaf for a freshly created tree.
    root: Box<RNode<T, C, D>>,
    /// Number of entries currently stored.
    len: usize,
    /// Raw value of the next `EntryId` to hand out; incremented per insertion.
    next_id: u64,
    /// Map from EntryId → point, used to look up an entry's location during
    /// remove so the descent can skip non-covering subtrees.
    id_to_point: HashMap<u64, Point<C, D>>,
}
453
454impl<T, C: CoordType, const D: usize> RTree<T, C, D> {
455    /// Create an empty R-tree.
456    pub fn new() -> Self {
457        Self {
458            root: Box::new(RNode::Leaf {
459                bbox: BBox::new(Point::new([C::zero(); D]), Point::new([C::zero(); D])),
460                entries: Vec::new(),
461            }),
462            len: 0,
463            next_id: 0,
464            id_to_point: HashMap::new(),
465        }
466    }
467
468    fn alloc_id(&mut self) -> EntryId {
469        let id = EntryId(self.next_id);
470        self.next_id += 1;
471        id
472    }
473
474    pub fn insert_entry(&mut self, point: Point<C, D>, payload: T) -> EntryId {
475        let id = self.alloc_id();
476        self.id_to_point.insert(id.0, point);
477        let entry = Entry { point, payload, id };
478
479        let split = node_insert(&mut self.root, entry);
480        if let Some(sibling) = split {
481            let old_bbox = *self.root.bbox();
482            let sib_bbox = *sibling.bbox();
483            let new_bbox = merge_bbox(&old_bbox, &sib_bbox);
484            let old_root = std::mem::replace(
485                &mut self.root,
486                Box::new(RNode::Internal {
487                    bbox: new_bbox,
488                    children: Vec::new(),
489                }),
490            );
491            if let RNode::Internal { children, bbox } = &mut *self.root {
492                children.push(old_root);
493                children.push(sibling);
494                *bbox = new_bbox;
495            }
496        }
497        self.len += 1;
498        id
499    }
500
501    pub fn remove_entry(&mut self, id: EntryId) -> Option<T> {
502        let point = self.id_to_point.remove(&id.0)?;
503        let result = node_remove(&mut self.root, id, &point);
504        if result.is_some() {
505            self.len -= 1;
506            loop {
507                let collapse = matches!(&*self.root,
508                    RNode::Internal { children, .. } if children.len() == 1);
509                if !collapse {
510                    break;
511                }
512                let old = std::mem::replace(
513                    &mut self.root,
514                    Box::new(RNode::Leaf {
515                        bbox: BBox::new(Point::new([C::zero(); D]), Point::new([C::zero(); D])),
516                        entries: Vec::new(),
517                    }),
518                );
519                if let RNode::Internal { mut children, .. } = *old {
520                    self.root = children.remove(0);
521                }
522            }
523        }
524        result
525    }
526
527    /// Return all entries whose points lie within `bbox`.
528    pub fn range_query_impl<'a>(&'a self, bbox: &BBox<C, D>) -> Vec<(EntryId, &'a T)> {
529        let mut out = Vec::new();
530        node_range(&self.root, bbox, &mut out);
531        out
532    }
533
534    pub fn knn_query_impl<'a>(
535        &'a self,
536        point: &Point<C, D>,
537        k: usize,
538    ) -> Vec<(f64, EntryId, &'a T)> {
539        if k == 0 {
540            return Vec::new();
541        }
542        let mut raw: Vec<(Point<C, D>, EntryId, &'a T)> = Vec::new();
543        node_collect(&self.root, &mut raw);
544        let mut all: Vec<(f64, EntryId, &'a T)> = raw
545            .into_iter()
546            .map(|(p, id, payload)| (point_dist_sq(&p, point).sqrt(), id, payload))
547            .collect();
548        all.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
549        all.truncate(k);
550        all
551    }
552
553    fn collect_all(&self) -> Vec<(Point<C, D>, EntryId, &T)> {
554        let mut out = Vec::new();
555        node_collect(&self.root, &mut out);
556        out
557    }
558}
559
560impl<T, C: CoordType, const D: usize> Default for RTree<T, C, D> {
561    fn default() -> Self {
562        Self::new()
563    }
564}
565
566impl<T: Send + Sync + 'static, C: CoordType, const D: usize> SpatialBackend<T, C, D>
567    for RTree<T, C, D>
568{
569    fn insert(&mut self, point: Point<C, D>, payload: T) -> EntryId {
570        self.insert_entry(point, payload)
571    }
572
573    fn remove(&mut self, id: EntryId) -> Option<T> {
574        self.remove_entry(id)
575    }
576
577    fn range_query(&self, bbox: &BBox<C, D>) -> Vec<(EntryId, &T)> {
578        self.range_query_impl(bbox)
579    }
580
581    fn knn_query(&self, point: &Point<C, D>, k: usize) -> Vec<(f64, EntryId, &T)> {
582        self.knn_query_impl(point, k)
583    }
584
585    fn spatial_join(&self, other: &dyn SpatialBackend<T, C, D>) -> Vec<(EntryId, EntryId)> {
586        let self_entries = self.collect_all();
587        let other_entries = other.all_entries();
588        let mut pairs = Vec::new();
589        for (pa, id_a, _) in &self_entries {
590            let bbox_a = point_bbox(pa);
591            for (pb, id_b, _) in &other_entries {
592                if bbox_a.intersects(&point_bbox(pb)) {
593                    pairs.push((*id_a, *id_b));
594                }
595            }
596        }
597        pairs
598    }
599
600    fn bulk_load(entries: Vec<(Point<C, D>, T)>) -> Self {
601        let mut tree = Self::new();
602        for (point, payload) in entries {
603            tree.insert_entry(point, payload);
604        }
605        tree
606    }
607
608    fn len(&self) -> usize {
609        self.len
610    }
611
612    fn kind(&self) -> BackendKind {
613        BackendKind::RTree
614    }
615
616    fn all_entries(&self) -> Vec<(Point<C, D>, EntryId, &T)> {
617        self.collect_all()
618    }
619}
620
621#[cfg(test)]
622mod tests {
623    use super::*;
624    use proptest::prelude::*;
625
626    struct Lcg(u64);
627    impl Lcg {
628        fn new(seed: u64) -> Self {
629            Self(seed)
630        }
631        fn next_f64(&mut self) -> f64 {
632            self.0 = self
633                .0
634                .wrapping_mul(6_364_136_223_846_793_005)
635                .wrapping_add(1_442_695_040_888_963_407);
636            (self.0 >> 11) as f64 / (1u64 << 53) as f64
637        }
638    }
639
640    fn rand_pts_2d(n: usize, seed: u64) -> Vec<Point<f64, 2>> {
641        let mut rng = Lcg::new(seed);
642        (0..n)
643            .map(|_| Point::new([rng.next_f64() * 1000.0, rng.next_f64() * 1000.0]))
644            .collect()
645    }
646
647    fn rand_pts_3d(n: usize, seed: u64) -> Vec<Point<f64, 3>> {
648        let mut rng = Lcg::new(seed);
649        (0..n)
650            .map(|_| {
651                Point::new([
652                    rng.next_f64() * 1000.0,
653                    rng.next_f64() * 1000.0,
654                    rng.next_f64() * 1000.0,
655                ])
656            })
657            .collect()
658    }
659
660    fn rand_pts_6d(n: usize, seed: u64) -> Vec<Point<f64, 6>> {
661        let mut rng = Lcg::new(seed);
662        (0..n)
663            .map(|_| {
664                Point::new([
665                    rng.next_f64() * 1000.0,
666                    rng.next_f64() * 1000.0,
667                    rng.next_f64() * 1000.0,
668                    rng.next_f64() * 1000.0,
669                    rng.next_f64() * 1000.0,
670                    rng.next_f64() * 1000.0,
671                ])
672            })
673            .collect()
674    }
675
676    fn brute_range<C: CoordType, const D: usize>(
677        pts: &[(Point<C, D>, EntryId)],
678        bbox: &BBox<C, D>,
679    ) -> Vec<EntryId> {
680        let mut ids: Vec<EntryId> = pts
681            .iter()
682            .filter(|(p, _)| bbox.contains_point(p))
683            .map(|(_, id)| *id)
684            .collect();
685        ids.sort_by_key(|id| id.0);
686        ids
687    }
688
689    #[test]
690    fn insert_and_len() {
691        let mut tree = RTree::<u32, f64, 2>::new();
692        assert_eq!(tree.len(), 0);
693        tree.insert(Point::new([0.5, 0.5]), 1);
694        assert_eq!(tree.len(), 1);
695        tree.insert(Point::new([1.5, 1.5]), 2);
696        assert_eq!(tree.len(), 2);
697    }
698
699    #[test]
700    fn range_query_basic() {
701        let mut tree = RTree::<u32, f64, 2>::new();
702        let id1 = tree.insert(Point::new([0.5, 0.5]), 1);
703        let id2 = tree.insert(Point::new([1.5, 1.5]), 2);
704        let _id3 = tree.insert(Point::new([5.0, 5.0]), 3);
705        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([2.0, 2.0]));
706        let mut got: Vec<EntryId> = tree
707            .range_query(&bbox)
708            .into_iter()
709            .map(|(id, _)| id)
710            .collect();
711        got.sort_by_key(|id| id.0);
712        assert_eq!(got, vec![id1, id2]);
713    }
714
715    #[test]
716    fn range_query_vs_brute_force_2d_10k() {
717        let pts = rand_pts_2d(10_000, 42);
718        let mut tree = RTree::<usize, f64, 2>::new();
719        let mut pt_ids = Vec::new();
720        for (i, &p) in pts.iter().enumerate() {
721            let id = tree.insert(p, i);
722            pt_ids.push((p, id));
723        }
724        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([500.0, 500.0]));
725        let mut got: Vec<EntryId> = tree
726            .range_query(&bbox)
727            .into_iter()
728            .map(|(id, _)| id)
729            .collect();
730        got.sort_by_key(|id| id.0);
731        let expected = brute_range(&pt_ids, &bbox);
732        assert_eq!(got, expected, "2D 10k range query mismatch");
733    }
734
735    #[test]
736    fn range_query_vs_brute_force_3d_10k() {
737        let pts = rand_pts_3d(10_000, 43);
738        let mut tree = RTree::<usize, f64, 3>::new();
739        let mut pt_ids = Vec::new();
740        for (i, &p) in pts.iter().enumerate() {
741            let id = tree.insert(p, i);
742            pt_ids.push((p, id));
743        }
744        let bbox = BBox::new(
745            Point::new([0.0, 0.0, 0.0]),
746            Point::new([500.0, 500.0, 500.0]),
747        );
748        let mut got: Vec<EntryId> = tree
749            .range_query(&bbox)
750            .into_iter()
751            .map(|(id, _)| id)
752            .collect();
753        got.sort_by_key(|id| id.0);
754        let expected = brute_range(&pt_ids, &bbox);
755        assert_eq!(got, expected, "3D 10k range query mismatch");
756    }
757
758    #[test]
759    fn range_query_vs_brute_force_6d_10k() {
760        let pts = rand_pts_6d(10_000, 44);
761        let mut tree = RTree::<usize, f64, 6>::new();
762        let mut pt_ids = Vec::new();
763        for (i, &p) in pts.iter().enumerate() {
764            let id = tree.insert(p, i);
765            pt_ids.push((p, id));
766        }
767        let bbox = BBox::new(Point::new([0.0; 6]), Point::new([500.0; 6]));
768        let mut got: Vec<EntryId> = tree
769            .range_query(&bbox)
770            .into_iter()
771            .map(|(id, _)| id)
772            .collect();
773        got.sort_by_key(|id| id.0);
774        let expected = brute_range(&pt_ids, &bbox);
775        assert_eq!(got, expected, "6D 10k range query mismatch");
776    }
777
778    #[test]
779    fn knn_correctness_2d() {
780        let pts = rand_pts_2d(1000, 99);
781        let mut tree = RTree::<usize, f64, 2>::new();
782        let mut all: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
783        for (i, &p) in pts.iter().enumerate() {
784            let id = tree.insert(p, i);
785            all.push((p, id));
786        }
787        let query = Point::new([500.0, 500.0]);
788        let k = 5;
789        let knn = tree.knn_query(&query, k);
790        assert_eq!(knn.len(), k);
791        for w in knn.windows(2) {
792            assert!(w[0].0 <= w[1].0, "kNN not sorted");
793        }
794        let mut bf: Vec<(f64, EntryId)> = all
795            .iter()
796            .map(|(p, id)| (point_dist_sq(p, &query).sqrt(), *id))
797            .collect();
798        bf.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
799        let bf_ids: Vec<EntryId> = bf[..k].iter().map(|(_, id)| *id).collect();
800        let knn_ids: Vec<EntryId> = knn.iter().map(|(_, id, _)| *id).collect();
801        assert_eq!(knn_ids, bf_ids, "kNN ids don't match brute force");
802    }
803
804    #[test]
805    fn remove_works() {
806        let mut tree = RTree::<&str, f64, 2>::new();
807        let id1 = tree.insert(Point::new([1.0, 1.0]), "a");
808        let id2 = tree.insert(Point::new([2.0, 2.0]), "b");
809        assert_eq!(tree.len(), 2);
810        assert_eq!(tree.remove(id1), Some("a"));
811        assert_eq!(tree.len(), 1);
812        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([3.0, 3.0]));
813        let ids: Vec<EntryId> = tree
814            .range_query(&bbox)
815            .into_iter()
816            .map(|(id, _)| id)
817            .collect();
818        assert!(!ids.contains(&id1));
819        assert!(ids.contains(&id2));
820    }
821
822    #[test]
823    fn kind_is_rtree() {
824        assert_eq!(RTree::<u32, f64, 2>::new().kind(), BackendKind::RTree);
825    }
826
827    #[test]
828    fn bulk_load_correctness() {
829        let pts = rand_pts_2d(500, 77);
830        let entries: Vec<(Point<f64, 2>, usize)> =
831            pts.iter().enumerate().map(|(i, &p)| (p, i)).collect();
832        let tree = RTree::<usize, f64, 2>::bulk_load(entries);
833        assert_eq!(tree.len(), 500);
834        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([1000.0, 1000.0]));
835        assert_eq!(tree.range_query(&bbox).len(), 500);
836    }
837
838    fn pt2d() -> impl Strategy<Value = Point<f64, 2>> {
839        (0.0_f64..1000.0, 0.0_f64..1000.0).prop_map(|(x, y)| Point::new([x, y]))
840    }
841
842    fn bbox2d() -> impl Strategy<Value = BBox<f64, 2>> {
843        (
844            0.0_f64..900.0,
845            0.0_f64..900.0,
846            10.0_f64..200.0,
847            10.0_f64..200.0,
848        )
849            .prop_map(|(x, y, w, h)| BBox::new(Point::new([x, y]), Point::new([x + w, y + h])))
850    }
851
852    fn pt3d() -> impl Strategy<Value = Point<f64, 3>> {
853        (0.0_f64..1000.0, 0.0_f64..1000.0, 0.0_f64..1000.0)
854            .prop_map(|(x, y, z)| Point::new([x, y, z]))
855    }
856
857    fn bbox3d() -> impl Strategy<Value = BBox<f64, 3>> {
858        (
859            0.0_f64..800.0,
860            0.0_f64..800.0,
861            0.0_f64..800.0,
862            10.0_f64..200.0,
863            10.0_f64..200.0,
864            10.0_f64..200.0,
865        )
866            .prop_map(|(x, y, z, w, h, d)| {
867                BBox::new(Point::new([x, y, z]), Point::new([x + w, y + h, z + d]))
868            })
869    }
870
    // For any dataset and bbox, range_query result must equal brute-force linear scan.
    // Runs 100 generated cases, each with 1..100 points from `pt2d()` and one
    // query box from `bbox2d()`.
    proptest! {
        #![proptest_config(proptest::test_runner::Config {
            cases: 100,
            ..Default::default()
        })]

        #[test]
        fn prop_range_query_oracle_2d(
            pts in prop::collection::vec(pt2d(), 1..100),
            bbox in bbox2d(),
        ) {
            // Build the tree and, in parallel, the (point, id) list the oracle scans.
            let mut tree = RTree::<usize, f64, 2>::new();
            let mut pt_ids: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
            for (i, p) in pts.iter().enumerate() {
                let id = tree.insert(*p, i);
                pt_ids.push((*p, id));
            }
            let mut got: Vec<EntryId> =
                tree.range_query(&bbox).into_iter().map(|(id, _)| id).collect();
            // Sort by raw id so the comparison ignores result order.
            got.sort_by_key(|id| id.0);
            let expected = brute_range(&pt_ids, &bbox);
            prop_assert_eq!(got, expected);
        }
    }
896
    // 3-D variant of the range-query oracle: for any dataset and bbox, the
    // tree's range_query result must equal a brute-force linear scan.
    // Runs 100 generated cases, each with 1..80 points from `pt3d()` and one
    // query box from `bbox3d()`.
    proptest! {
        #![proptest_config(proptest::test_runner::Config {
            cases: 100,
            ..Default::default()
        })]

        #[test]
        fn prop_range_query_oracle_3d(
            pts in prop::collection::vec(pt3d(), 1..80),
            bbox in bbox3d(),
        ) {
            // Build the tree and, in parallel, the (point, id) list the oracle scans.
            let mut tree = RTree::<usize, f64, 3>::new();
            let mut pt_ids: Vec<(Point<f64, 3>, EntryId)> = Vec::new();
            for (i, p) in pts.iter().enumerate() {
                let id = tree.insert(*p, i);
                pt_ids.push((*p, id));
            }
            let mut got: Vec<EntryId> =
                tree.range_query(&bbox).into_iter().map(|(id, _)| id).collect();
            // Sort by raw id so the comparison ignores result order.
            got.sort_by_key(|id| id.0);
            let expected = brute_range(&pt_ids, &bbox);
            prop_assert_eq!(got, expected);
        }
    }
921
    // knn_query(k) must return globally nearest k points in ascending distance order.
    // Runs 100 generated cases, each with 5..80 points from `pt2d()`, a query
    // point in [0, 1000) x [0, 1000), and k in 1..5.
    proptest! {
        #![proptest_config(proptest::test_runner::Config {
            cases: 100,
            ..Default::default()
        })]

        #[test]
        fn prop_knn_correctness(
            pts in prop::collection::vec(pt2d(), 5..80),
            qx in 0.0_f64..1000.0,
            qy in 0.0_f64..1000.0,
            k in 1usize..5,
        ) {
            // Build the tree and keep (point, id) pairs for the brute-force ranking.
            let mut tree = RTree::<usize, f64, 2>::new();
            let mut all: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
            for (i, p) in pts.iter().enumerate() {
                let id = tree.insert(*p, i);
                all.push((*p, id));
            }
            let query = Point::new([qx, qy]);
            // Never ask for more neighbours than there are points.
            let actual_k = k.min(all.len());
            let knn = tree.knn_query(&query, actual_k);
            prop_assert_eq!(knn.len(), actual_k);
            // Results must come back in non-decreasing distance order.
            for w in knn.windows(2) {
                prop_assert!(w[0].0 <= w[1].0, "kNN not sorted");
            }
            // Brute force: rank every point by distance and take the first k ids.
            let mut bf: Vec<(f64, EntryId)> = all
                .iter()
                .map(|(p, id)| (point_dist_sq(p, &query).sqrt(), *id))
                .collect();
            bf.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());
            let bf_ids: Vec<EntryId> = bf[..actual_k].iter().map(|(_, id)| *id).collect();
            let knn_ids: Vec<EntryId> = knn.iter().map(|(_, id, _)| *id).collect();
            prop_assert_eq!(knn_ids, bf_ids);
        }
    }
959
960    // Insert-Remove Round Trip
961    proptest! {
962        #![proptest_config(proptest::test_runner::Config {
963            cases: 100,
964            ..Default::default()
965        })]
966
967        #[test]
968        fn prop_insert_remove_round_trip_rtree(
969            pts in prop::collection::vec(pt2d(), 1..50),
970            remove_indices in prop::collection::vec(0usize..50, 0..25),
971        ) {
972            let mut tree = RTree::<usize, f64, 2>::new();
973            let mut inserted: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
974            for (i, &p) in pts.iter().enumerate() {
975                let id = tree.insert(p, i);
976                inserted.push((p, id));
977            }
978            // Remove a subset of entries
979            let mut removed_ids: Vec<EntryId> = Vec::new();
980            for &ri in &remove_indices {
981                let idx = ri % inserted.len();
982                let (_, id) = inserted[idx];
983                if !removed_ids.contains(&id) {
984                    let result = tree.remove(id);
985                    prop_assert!(result.is_some(), "remove returned None for inserted id");
986                    removed_ids.push(id);
987                }
988            }
989            // Verify removed entries no longer appear in range queries
990            let full_bbox = BBox::new(Point::new([-1.0e9, -1.0e9]), Point::new([1.0e9, 1.0e9]));
991            let remaining_ids: Vec<EntryId> = tree.range_query(&full_bbox)
992                .into_iter()
993                .map(|(id, _)| id)
994                .collect();
995            for &removed_id in &removed_ids {
996                prop_assert!(
997                    !remaining_ids.contains(&removed_id),
998                    "removed entry {:?} still appears in range query",
999                    removed_id
1000                );
1001            }
1002            // Verify len is consistent
1003            let expected_len = inserted.len() - removed_ids.len();
1004            prop_assert_eq!(tree.len(), expected_len);
1005        }
1006    }
1007
1008    // spatial_join must return exactly the pairs where bboxes intersect.
1009    proptest! {
1010        #![proptest_config(proptest::test_runner::Config {
1011            cases: 100,
1012            ..Default::default()
1013        })]
1014
1015        #[test]
1016        fn prop_spatial_join_completeness(
1017            pts_a in prop::collection::vec(pt2d(), 1..30),
1018            pts_b in prop::collection::vec(pt2d(), 1..30),
1019        ) {
1020            let mut tree_a = RTree::<usize, f64, 2>::new();
1021            let mut tree_b = RTree::<usize, f64, 2>::new();
1022            let mut ids_a: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
1023            let mut ids_b: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
1024            for (i, p) in pts_a.iter().enumerate() {
1025                let id = tree_a.insert(*p, i);
1026                ids_a.push((*p, id));
1027            }
1028            for (i, p) in pts_b.iter().enumerate() {
1029                let id = tree_b.insert(*p, i);
1030                ids_b.push((*p, id));
1031            }
1032            let mut got = tree_a.spatial_join(&tree_b);
1033            got.sort();
1034            let mut expected: Vec<(EntryId, EntryId)> = Vec::new();
1035            for (pa, id_a) in &ids_a {
1036                for (pb, id_b) in &ids_b {
1037                    if point_bbox(pa).intersects(&point_bbox(pb)) {
1038                        expected.push((*id_a, *id_b));
1039                    }
1040                }
1041            }
1042            expected.sort();
1043            prop_assert_eq!(got, expected);
1044        }
1045    }
1046}