Skip to main content

bonsai/backends/
kdtree.rs

//! KD-tree spatial index backend with median splits.
2
3use std::collections::HashMap;
4use std::sync::Mutex;
5
6use crate::backends::SpatialBackend;
7use crate::types::{BBox, BackendKind, CoordType, EntryId, Point};
8
/// Maximum number of entries `build_node` keeps in a leaf before splitting.
/// Leaves larger than this are only produced for cells with zero extent.
const LEAF_CAPACITY: usize = 8;
10
/// A node of the KD-tree.
enum KDNode<C, const D: usize> {
    /// Terminal node holding its `(point, id)` entries directly.
    Leaf {
        entries: Vec<(Point<C, D>, EntryId)>,
    },
    /// Internal node dividing its cell along one coordinate axis.
    Split {
        /// Index of the coordinate the split is made on.
        axis: usize,
        /// Split coordinate: the left child's cell ends at `split_val` on `axis`
        /// (cell max) and the right child's cell starts there (cell min).
        split_val: C,
        left: Box<KDNode<C, D>>,
        right: Box<KDNode<C, D>>,
    },
}
22
23fn entries_bbox<C: CoordType, const D: usize>(entries: &[(Point<C, D>, EntryId)]) -> BBox<C, D> {
24    debug_assert!(!entries.is_empty());
25    let mut min_c = *entries[0].0.coords();
26    let mut max_c = *entries[0].0.coords();
27    for (p, _) in entries.iter().skip(1) {
28        for d in 0..D {
29            let v = p.coords()[d];
30            if v < min_c[d] {
31                min_c[d] = v;
32            }
33            if v > max_c[d] {
34                max_c[d] = v;
35            }
36        }
37    }
38    BBox::new(Point::new(min_c), Point::new(max_c))
39}
40
41fn build_node<C: CoordType, const D: usize>(
42    mut entries: Vec<(Point<C, D>, EntryId)>,
43    bbox: BBox<C, D>,
44) -> KDNode<C, D> {
45    if entries.len() <= LEAF_CAPACITY {
46        return KDNode::Leaf { entries };
47    }
48
49    // Choose the axis with the largest spread.
50    let mut best_axis = 0;
51    let mut best_span = C::zero();
52    for d in 0..D {
53        let span = bbox.max.coords()[d] - bbox.min.coords()[d];
54        if span > best_span {
55            best_span = span;
56            best_axis = d;
57        }
58    }
59
60    if best_span == C::zero() {
61        return KDNode::Leaf { entries };
62    }
63
64    // Median split guarantees ceil(log2(n)) + 1 depth.
65    entries.sort_by(|(a, _), (b, _)| {
66        let av: f64 = a.coords()[best_axis].into();
67        let bv: f64 = b.coords()[best_axis].into();
68        av.partial_cmp(&bv).unwrap_or(std::cmp::Ordering::Equal)
69    });
70
71    let mid_idx = entries.len() / 2;
72    let right_entries = entries.split_off(mid_idx);
73    let left_entries = entries;
74
75    let actual_split_f64 = left_entries
76        .iter()
77        .map(|(p, _)| {
78            let v: f64 = p.coords()[best_axis].into();
79            v
80        })
81        .fold(f64::NEG_INFINITY, f64::max);
82
83    let split_val: C = C::from(actual_split_f64 as f32);
84
85    let mut left_bbox = bbox;
86    left_bbox.max.coords_mut()[best_axis] = split_val;
87
88    let mut right_bbox = bbox;
89    right_bbox.min.coords_mut()[best_axis] = split_val;
90
91    let left = Box::new(build_node(left_entries, left_bbox));
92    let right = Box::new(build_node(right_entries, right_bbox));
93
94    KDNode::Split {
95        axis: best_axis,
96        split_val,
97        left,
98        right,
99    }
100}
101
102fn node_range<C: CoordType, const D: usize>(
103    node: &KDNode<C, D>,
104    bbox: &BBox<C, D>,
105    node_bbox: &BBox<C, D>,
106    out: &mut Vec<EntryId>,
107) {
108    if !node_bbox.intersects(bbox) {
109        return;
110    }
111    match node {
112        KDNode::Leaf { entries } => {
113            for (p, id) in entries {
114                if bbox.contains_point(p) {
115                    out.push(*id);
116                }
117            }
118        }
119        KDNode::Split {
120            axis,
121            split_val,
122            left,
123            right,
124        } => {
125            let mut left_bbox = *node_bbox;
126            left_bbox.max.coords_mut()[*axis] = *split_val;
127            let mut right_bbox = *node_bbox;
128            right_bbox.min.coords_mut()[*axis] = *split_val;
129
130            node_range(left, bbox, &left_bbox, out);
131            node_range(right, bbox, &right_bbox, out);
132        }
133    }
134}
135
136fn node_collect<C: CoordType, const D: usize>(
137    node: &KDNode<C, D>,
138    out: &mut Vec<(Point<C, D>, EntryId)>,
139) {
140    match node {
141        KDNode::Leaf { entries } => {
142            out.extend_from_slice(entries);
143        }
144        KDNode::Split { left, right, .. } => {
145            node_collect(left, out);
146            node_collect(right, out);
147        }
148    }
149}
150
151fn node_depth<C, const D: usize>(node: &KDNode<C, D>) -> usize {
152    match node {
153        KDNode::Leaf { .. } => 1,
154        KDNode::Split { left, right, .. } => 1 + node_depth(left).max(node_depth(right)),
155    }
156}
157
158fn point_dist_sq<C: CoordType, const D: usize>(a: &Point<C, D>, b: &Point<C, D>) -> f64 {
159    let mut sum = 0.0_f64;
160    for d in 0..D {
161        let da: f64 = a.coords()[d].into();
162        let db: f64 = b.coords()[d].into();
163        let diff = da - db;
164        sum += diff * diff;
165    }
166    sum
167}
168
/// KD-tree spatial index built with median splits.
///
/// `insert_entry` and `remove_entry` only update the flat entry list and the id
/// maps, then mark the tree dirty. The tree itself is rebuilt from scratch on the
/// next query. That is why it lives behind a `Mutex` and can be replaced through
/// `&self`.
///
/// Median splits keep the tree depth at most `ceil(log2(n)) + 1`.
pub struct KDTree<T, C, const D: usize> {
    /// Current tree together with the bounding box of its root cell.
    tree: Mutex<(KDNode<C, D>, BBox<C, D>)>,
    /// Every live `(point, id)` pair; the tree is rebuilt from this list.
    flat: Vec<(Point<C, D>, EntryId)>,
    /// Payload of each live entry, keyed by the raw id (`EntryId.0`).
    id_to_payload: HashMap<u64, T>,
    /// Point of each live entry, keyed by the raw id; `remove_entry` uses it to
    /// reject unknown ids.
    id_to_point: HashMap<u64, Point<C, D>>,
    /// Number of live entries.
    len: usize,
    /// Next raw id to hand out; ids only ever increase and are never reused.
    next_id: u64,
    /// `true` when `flat` has changed since the tree was last rebuilt.
    dirty: Mutex<bool>,
}
181
182impl<T, C: CoordType, const D: usize> KDTree<T, C, D> {
183    pub fn new() -> Self {
184        let empty_bbox = BBox::new(Point::new([C::zero(); D]), Point::new([C::zero(); D]));
185        Self {
186            tree: Mutex::new((
187                KDNode::Leaf {
188                    entries: Vec::new(),
189                },
190                empty_bbox,
191            )),
192            flat: Vec::new(),
193            id_to_payload: HashMap::new(),
194            id_to_point: HashMap::new(),
195            len: 0,
196            next_id: 0,
197            dirty: Mutex::new(false),
198        }
199    }
200
201    fn alloc_id(&mut self) -> EntryId {
202        let id = EntryId(self.next_id);
203        self.next_id += 1;
204        id
205    }
206
207    fn rebuild(&self) {
208        let empty_bbox = BBox::new(Point::new([C::zero(); D]), Point::new([C::zero(); D]));
209        if self.flat.is_empty() {
210            *self.tree.lock().unwrap() = (
211                KDNode::Leaf {
212                    entries: Vec::new(),
213                },
214                empty_bbox,
215            );
216            *self.dirty.lock().unwrap() = false;
217            return;
218        }
219        let bbox = entries_bbox(&self.flat);
220        let root = build_node(self.flat.clone(), bbox);
221        *self.tree.lock().unwrap() = (root, bbox);
222        *self.dirty.lock().unwrap() = false;
223    }
224
225    fn ensure_built(&self) {
226        if *self.dirty.lock().unwrap() {
227            self.rebuild();
228        }
229    }
230
231    /// Return the current depth of the tree.
232    pub fn depth(&self) -> usize {
233        self.ensure_built();
234        node_depth(&self.tree.lock().unwrap().0)
235    }
236
237    pub fn insert_entry(&mut self, point: Point<C, D>, payload: T) -> EntryId {
238        let id = self.alloc_id();
239        self.id_to_point.insert(id.0, point);
240        self.id_to_payload.insert(id.0, payload);
241        self.flat.push((point, id));
242        self.len += 1;
243        *self.dirty.lock().unwrap() = true;
244        id
245    }
246
247    pub fn remove_entry(&mut self, id: EntryId) -> Option<T> {
248        self.id_to_point.remove(&id.0)?;
249        let payload = self.id_to_payload.remove(&id.0)?;
250        if let Some(pos) = self.flat.iter().position(|(_, eid)| *eid == id) {
251            self.flat.swap_remove(pos);
252        }
253        self.len -= 1;
254        *self.dirty.lock().unwrap() = true;
255        Some(payload)
256    }
257
258    pub fn range_query_impl<'a>(&'a self, bbox: &BBox<C, D>) -> Vec<(EntryId, &'a T)> {
259        self.ensure_built();
260        if self.len == 0 {
261            return Vec::new();
262        }
263        let tree = self.tree.lock().unwrap();
264        let (root, root_bbox) = &*tree;
265        let mut ids = Vec::new();
266        node_range(root, bbox, root_bbox, &mut ids);
267        drop(tree);
268        ids.into_iter()
269            .filter_map(|id| self.id_to_payload.get(&id.0).map(|p| (id, p)))
270            .collect()
271    }
272
273    pub fn knn_query_impl<'a>(
274        &'a self,
275        point: &Point<C, D>,
276        k: usize,
277    ) -> Vec<(f64, EntryId, &'a T)> {
278        self.ensure_built();
279        if k == 0 {
280            return Vec::new();
281        }
282        let tree = self.tree.lock().unwrap();
283        let (root, _) = &*tree;
284        let mut raw: Vec<(Point<C, D>, EntryId)> = Vec::new();
285        node_collect(root, &mut raw);
286        drop(tree);
287        let mut all: Vec<(f64, EntryId, &'a T)> = raw
288            .into_iter()
289            .filter_map(|(p, id)| {
290                self.id_to_payload
291                    .get(&id.0)
292                    .map(|payload| (point_dist_sq(&p, point).sqrt(), id, payload))
293            })
294            .collect();
295        all.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
296        all.truncate(k);
297        all
298    }
299
300    fn collect_all(&self) -> Vec<(Point<C, D>, EntryId, &T)> {
301        self.ensure_built();
302        let tree = self.tree.lock().unwrap();
303        let (root, _) = &*tree;
304        let mut raw: Vec<(Point<C, D>, EntryId)> = Vec::new();
305        node_collect(root, &mut raw);
306        drop(tree);
307        raw.into_iter()
308            .filter_map(|(p, id)| {
309                self.id_to_payload
310                    .get(&id.0)
311                    .map(|payload| (p, id, payload))
312            })
313            .collect()
314    }
315}
316
317impl<T, C: CoordType, const D: usize> Default for KDTree<T, C, D> {
318    fn default() -> Self {
319        Self::new()
320    }
321}
322
323impl<T: Send + Sync + 'static, C: CoordType, const D: usize> SpatialBackend<T, C, D>
324    for KDTree<T, C, D>
325{
326    fn insert(&mut self, point: Point<C, D>, payload: T) -> EntryId {
327        self.insert_entry(point, payload)
328    }
329
330    fn remove(&mut self, id: EntryId) -> Option<T> {
331        self.remove_entry(id)
332    }
333
334    fn range_query(&self, bbox: &BBox<C, D>) -> Vec<(EntryId, &T)> {
335        self.range_query_impl(bbox)
336    }
337
338    fn knn_query(&self, point: &Point<C, D>, k: usize) -> Vec<(f64, EntryId, &T)> {
339        self.knn_query_impl(point, k)
340    }
341
342    fn spatial_join(&self, other: &dyn SpatialBackend<T, C, D>) -> Vec<(EntryId, EntryId)> {
343        let self_entries = self.collect_all();
344        let other_entries = other.all_entries();
345        let mut pairs = Vec::new();
346        for (pa, id_a, _) in &self_entries {
347            let bbox_a = BBox::new(*pa, *pa);
348            for (pb, id_b, _) in &other_entries {
349                if bbox_a.intersects(&BBox::new(*pb, *pb)) {
350                    pairs.push((*id_a, *id_b));
351                }
352            }
353        }
354        pairs
355    }
356
357    fn bulk_load(entries: Vec<(Point<C, D>, T)>) -> Self {
358        let mut tree = Self::new();
359        for (point, payload) in entries {
360            tree.insert_entry(point, payload);
361        }
362        tree
363    }
364
365    fn len(&self) -> usize {
366        self.len
367    }
368
369    fn kind(&self) -> BackendKind {
370        BackendKind::KDTree
371    }
372
373    fn all_entries(&self) -> Vec<(Point<C, D>, EntryId, &T)> {
374        self.collect_all()
375    }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Small deterministic PRNG (64-bit linear congruential generator), so the
    /// generated datasets are reproducible from a seed.
    struct Lcg(u64);
    impl Lcg {
        fn new(seed: u64) -> Self {
            Self(seed)
        }
        /// Advances the state and maps its top 53 bits to a uniform value in `[0, 1)`.
        fn next_f64(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            (self.0 >> 11) as f64 / (1u64 << 53) as f64
        }
        /// Standard normal sample via the Box–Muller transform.
        fn next_normal(&mut self) -> f64 {
            // Clamp away from 0 so `ln` stays finite.
            let u1 = self.next_f64().max(1e-15);
            let u2 = self.next_f64();
            (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
        }
    }

    /// `n` points drawn uniformly from `[0, 1000)²`.
    fn rand_pts_2d(n: usize, seed: u64) -> Vec<Point<f64, 2>> {
        let mut rng = Lcg::new(seed);
        (0..n)
            .map(|_| Point::new([rng.next_f64() * 1000.0, rng.next_f64() * 1000.0]))
            .collect()
    }

    /// `n` points assigned round-robin to `clusters` centres drawn from `[100, 900)²`.
    /// Each point is offset by normal noise (standard deviation 20 per axis) and
    /// clamped to `[0, 1000]`. Panics if `clusters` is 0 and `n > 0`.
    fn clustered_pts_2d(n: usize, clusters: usize, seed: u64) -> Vec<Point<f64, 2>> {
        let mut rng = Lcg::new(seed);
        let centres: Vec<[f64; 2]> = (0..clusters)
            .map(|_| {
                [
                    rng.next_f64() * 800.0 + 100.0,
                    rng.next_f64() * 800.0 + 100.0,
                ]
            })
            .collect();
        (0..n)
            .map(|i| {
                let c = &centres[i % clusters];
                let x = c[0] + rng.next_normal() * 20.0;
                let y = c[1] + rng.next_normal() * 20.0;
                Point::new([x.clamp(0.0, 1000.0), y.clamp(0.0, 1000.0)])
            })
            .collect()
    }

    /// Linear-scan oracle: ids of all `pts` inside `bbox`, sorted by raw id so the
    /// result can be compared with a sorted query result.
    fn brute_range<C: CoordType, const D: usize>(
        pts: &[(Point<C, D>, EntryId)],
        bbox: &BBox<C, D>,
    ) -> Vec<EntryId> {
        let mut ids: Vec<EntryId> = pts
            .iter()
            .filter(|(p, _)| bbox.contains_point(p))
            .map(|(_, id)| *id)
            .collect();
        ids.sort_by_key(|id| id.0);
        ids
    }

    /// `len` tracks the number of inserts.
    #[test]
    fn insert_and_len() {
        let mut tree = KDTree::<u32, f64, 2>::new();
        assert_eq!(tree.len(), 0);
        tree.insert(Point::new([0.5, 0.5]), 1u32);
        assert_eq!(tree.len(), 1);
        tree.insert(Point::new([1.5, 1.5]), 2u32);
        assert_eq!(tree.len(), 2);
    }

    /// Two of three points fall inside the query box; the third (5, 5) does not.
    #[test]
    fn range_query_basic() {
        let mut tree = KDTree::<u32, f64, 2>::new();
        let id1 = tree.insert(Point::new([0.5, 0.5]), 1u32);
        let id2 = tree.insert(Point::new([1.5, 1.5]), 2u32);
        let _id3 = tree.insert(Point::new([5.0, 5.0]), 3u32);
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([2.0, 2.0]));
        let mut got: Vec<EntryId> = tree
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        got.sort_by_key(|id| id.0);
        assert_eq!(got, vec![id1, id2]);
    }

    /// `remove` returns the payload, decrements `len`, and hides the entry from queries.
    #[test]
    fn remove_works() {
        let mut tree = KDTree::<u32, f64, 2>::new();
        let id1 = tree.insert(Point::new([1.0, 1.0]), 10u32);
        let id2 = tree.insert(Point::new([2.0, 2.0]), 20u32);
        assert_eq!(tree.len(), 2);
        assert_eq!(tree.remove(id1), Some(10u32));
        assert_eq!(tree.len(), 1);
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([3.0, 3.0]));
        let ids: Vec<EntryId> = tree
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        assert!(!ids.contains(&id1));
        assert!(ids.contains(&id2));
    }

    #[test]
    fn kind_is_kdtree() {
        assert_eq!(KDTree::<u32, f64, 2>::new().kind(), BackendKind::KDTree);
    }

    /// Range query over 10k uniform points must match the linear-scan oracle exactly.
    #[test]
    fn range_query_vs_brute_force_2d_10k() {
        let pts = rand_pts_2d(10_000, 42);
        let mut tree = KDTree::<usize, f64, 2>::new();
        let mut pt_ids = Vec::new();
        for (i, &p) in pts.iter().enumerate() {
            let id = tree.insert(p, i);
            pt_ids.push((p, id));
        }
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([500.0, 500.0]));
        let mut got: Vec<EntryId> = tree
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        got.sort_by_key(|id| id.0);
        let expected = brute_range(&pt_ids, &bbox);
        assert_eq!(got, expected);
    }

    /// Inserts `n` uniform random points in `[0, 1000)^D` and asserts the tree depth
    /// is at most `ceil(log2(n)) + 1`.
    fn check_depth_bound<const D: usize>(n: usize, seed: u64)
    where
        [(); D]:,
    {
        let mut rng = Lcg::new(seed);
        let mut tree = KDTree::<usize, f64, D>::new();
        for i in 0..n {
            let coords: [f64; D] = std::array::from_fn(|_| rng.next_f64() * 1000.0);
            tree.insert(Point::new(coords), i);
        }
        let depth = tree.depth();
        let bound = (n as f64).log2().ceil() as usize + 1;
        assert!(
            depth <= bound,
            "D={D}: depth {depth} exceeds bound {bound} for n={n}"
        );
    }

    // Depth bound on uniform data, dimensions 2 through 8.
    #[test]
    fn depth_bound_d2() {
        check_depth_bound::<2>(1000, 1);
    }
    #[test]
    fn depth_bound_d3() {
        check_depth_bound::<3>(1000, 2);
    }
    #[test]
    fn depth_bound_d4() {
        check_depth_bound::<4>(1000, 3);
    }
    #[test]
    fn depth_bound_d5() {
        check_depth_bound::<5>(1000, 4);
    }
    #[test]
    fn depth_bound_d6() {
        check_depth_bound::<6>(1000, 5);
    }
    #[test]
    fn depth_bound_d7() {
        check_depth_bound::<7>(1000, 6);
    }
    #[test]
    fn depth_bound_d8() {
        check_depth_bound::<8>(1000, 7);
    }

    /// Same depth bound on strongly clustered data (5000 points in 20 clusters).
    #[test]
    fn depth_bound_clustered_2d() {
        let pts = clustered_pts_2d(5000, 20, 99);
        let mut tree = KDTree::<usize, f64, 2>::new();
        for (i, &p) in pts.iter().enumerate() {
            tree.insert(p, i);
        }
        let n = pts.len();
        let depth = tree.depth();
        let bound = (n as f64).log2().ceil() as usize + 1;
        assert!(
            depth <= bound,
            "clustered: depth {depth} exceeds bound {bound} for n={n}"
        );
    }

    /// Proptest strategy: a point uniform in `[0, 1000)²`.
    fn pt2d() -> impl Strategy<Value = Point<f64, 2>> {
        (0.0_f64..1000.0, 0.0_f64..1000.0).prop_map(|(x, y)| Point::new([x, y]))
    }

    /// Proptest strategy: an axis-aligned box with its min corner in `[0, 900)²`
    /// and width and height in `[10, 200)`.
    fn bbox2d() -> impl Strategy<Value = BBox<f64, 2>> {
        (
            0.0_f64..900.0,
            0.0_f64..900.0,
            10.0_f64..200.0,
            10.0_f64..200.0,
        )
            .prop_map(|(x, y, w, h)| BBox::new(Point::new([x, y]), Point::new([x + w, y + h])))
    }

    // KD-tree depth must not exceed ceil(log2(n)) + 1 for any dataset.
    proptest! {
        #![proptest_config(proptest::test_runner::Config {
            cases: 100,
            ..Default::default()
        })]

        #[test]
        fn prop_kdtree_depth_bound(pts in prop::collection::vec(pt2d(), 2..200)) {
            let mut tree = KDTree::<usize, f64, 2>::new();
            for (i, p) in pts.iter().enumerate() {
                tree.insert(*p, i);
            }
            let n = pts.len();
            let depth = tree.depth();
            let bound = (n as f64).log2().ceil() as usize + 1;
            prop_assert!(depth <= bound, "depth {} exceeds bound {} for n={}", depth, bound, n);
        }
    }

    // Insert-remove round trip. Removing an arbitrary subset of inserted ids must
    // (1) succeed once per id (repeated indices are skipped), (2) hide every
    // removed id from a range query covering all points, and (3) shrink `len`
    // by the number of removals.
    proptest! {
        #![proptest_config(proptest::test_runner::Config {
            cases: 100,
            ..Default::default()
        })]

        #[test]
        fn prop_insert_remove_round_trip_kdtree(
            pts in prop::collection::vec(pt2d(), 1..50),
            remove_indices in prop::collection::vec(0usize..50, 0..25),
        ) {
            let mut tree = KDTree::<usize, f64, 2>::new();
            let mut inserted: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
            for (i, &p) in pts.iter().enumerate() {
                let id = tree.insert(p, i);
                inserted.push((p, id));
            }
            let mut removed_ids: Vec<EntryId> = Vec::new();
            for &ri in &remove_indices {
                // Wrap the index into range; `pts` may be shorter than 50.
                let idx = ri % inserted.len();
                let (_, id) = inserted[idx];
                if !removed_ids.contains(&id) {
                    let result = tree.remove(id);
                    prop_assert!(result.is_some(), "remove returned None for inserted id");
                    removed_ids.push(id);
                }
            }
            let full_bbox = BBox::new(Point::new([-1.0e9, -1.0e9]), Point::new([1.0e9, 1.0e9]));
            let remaining_ids: Vec<EntryId> = tree.range_query(&full_bbox)
                .into_iter()
                .map(|(id, _)| id)
                .collect();
            for &removed_id in &removed_ids {
                prop_assert!(
                    !remaining_ids.contains(&removed_id),
                    "removed entry {:?} still appears in range query",
                    removed_id
                );
            }
            let expected_len = inserted.len() - removed_ids.len();
            prop_assert_eq!(tree.len(), expected_len);
        }
    }

    // Range query result must equal brute-force linear scan for any dataset and bbox.
    proptest! {
        #![proptest_config(proptest::test_runner::Config {
            cases: 100,
            ..Default::default()
        })]

        #[test]
        fn prop_range_query_oracle_kdtree(
            pts in prop::collection::vec(pt2d(), 1..100),
            bbox in bbox2d(),
        ) {
            let mut tree = KDTree::<usize, f64, 2>::new();
            let mut pt_ids: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
            for (i, p) in pts.iter().enumerate() {
                let id = tree.insert(*p, i);
                pt_ids.push((*p, id));
            }
            let mut got: Vec<EntryId> =
                tree.range_query(&bbox).into_iter().map(|(id, _)| id).collect();
            got.sort_by_key(|id| id.0);
            let expected = brute_range(&pt_ids, &bbox);
            prop_assert_eq!(got, expected);
        }
    }
}