1use std::collections::HashMap;
4use std::sync::Mutex;
5
6use crate::backends::SpatialBackend;
7use crate::types::{BBox, BackendKind, CoordType, EntryId, Point};
8
9const LEAF_CAPACITY: usize = 8;
10
11enum KDNode<C, const D: usize> {
12 Leaf {
13 entries: Vec<(Point<C, D>, EntryId)>,
14 },
15 Split {
16 axis: usize,
17 split_val: C,
18 left: Box<KDNode<C, D>>,
19 right: Box<KDNode<C, D>>,
20 },
21}
22
23fn entries_bbox<C: CoordType, const D: usize>(entries: &[(Point<C, D>, EntryId)]) -> BBox<C, D> {
24 debug_assert!(!entries.is_empty());
25 let mut min_c = *entries[0].0.coords();
26 let mut max_c = *entries[0].0.coords();
27 for (p, _) in entries.iter().skip(1) {
28 for d in 0..D {
29 let v = p.coords()[d];
30 if v < min_c[d] {
31 min_c[d] = v;
32 }
33 if v > max_c[d] {
34 max_c[d] = v;
35 }
36 }
37 }
38 BBox::new(Point::new(min_c), Point::new(max_c))
39}
40
41fn build_node<C: CoordType, const D: usize>(
42 mut entries: Vec<(Point<C, D>, EntryId)>,
43 bbox: BBox<C, D>,
44) -> KDNode<C, D> {
45 if entries.len() <= LEAF_CAPACITY {
46 return KDNode::Leaf { entries };
47 }
48
49 let mut best_axis = 0;
51 let mut best_span = C::zero();
52 for d in 0..D {
53 let span = bbox.max.coords()[d] - bbox.min.coords()[d];
54 if span > best_span {
55 best_span = span;
56 best_axis = d;
57 }
58 }
59
60 if best_span == C::zero() {
61 return KDNode::Leaf { entries };
62 }
63
64 entries.sort_by(|(a, _), (b, _)| {
66 let av: f64 = a.coords()[best_axis].into();
67 let bv: f64 = b.coords()[best_axis].into();
68 av.partial_cmp(&bv).unwrap_or(std::cmp::Ordering::Equal)
69 });
70
71 let mid_idx = entries.len() / 2;
72 let right_entries = entries.split_off(mid_idx);
73 let left_entries = entries;
74
75 let actual_split_f64 = left_entries
76 .iter()
77 .map(|(p, _)| {
78 let v: f64 = p.coords()[best_axis].into();
79 v
80 })
81 .fold(f64::NEG_INFINITY, f64::max);
82
83 let split_val: C = C::from(actual_split_f64 as f32);
84
85 let mut left_bbox = bbox;
86 left_bbox.max.coords_mut()[best_axis] = split_val;
87
88 let mut right_bbox = bbox;
89 right_bbox.min.coords_mut()[best_axis] = split_val;
90
91 let left = Box::new(build_node(left_entries, left_bbox));
92 let right = Box::new(build_node(right_entries, right_bbox));
93
94 KDNode::Split {
95 axis: best_axis,
96 split_val,
97 left,
98 right,
99 }
100}
101
102fn node_range<C: CoordType, const D: usize>(
103 node: &KDNode<C, D>,
104 bbox: &BBox<C, D>,
105 node_bbox: &BBox<C, D>,
106 out: &mut Vec<EntryId>,
107) {
108 if !node_bbox.intersects(bbox) {
109 return;
110 }
111 match node {
112 KDNode::Leaf { entries } => {
113 for (p, id) in entries {
114 if bbox.contains_point(p) {
115 out.push(*id);
116 }
117 }
118 }
119 KDNode::Split {
120 axis,
121 split_val,
122 left,
123 right,
124 } => {
125 let mut left_bbox = *node_bbox;
126 left_bbox.max.coords_mut()[*axis] = *split_val;
127 let mut right_bbox = *node_bbox;
128 right_bbox.min.coords_mut()[*axis] = *split_val;
129
130 node_range(left, bbox, &left_bbox, out);
131 node_range(right, bbox, &right_bbox, out);
132 }
133 }
134}
135
136fn node_collect<C: CoordType, const D: usize>(
137 node: &KDNode<C, D>,
138 out: &mut Vec<(Point<C, D>, EntryId)>,
139) {
140 match node {
141 KDNode::Leaf { entries } => {
142 out.extend_from_slice(entries);
143 }
144 KDNode::Split { left, right, .. } => {
145 node_collect(left, out);
146 node_collect(right, out);
147 }
148 }
149}
150
151fn node_depth<C, const D: usize>(node: &KDNode<C, D>) -> usize {
152 match node {
153 KDNode::Leaf { .. } => 1,
154 KDNode::Split { left, right, .. } => 1 + node_depth(left).max(node_depth(right)),
155 }
156}
157
158fn point_dist_sq<C: CoordType, const D: usize>(a: &Point<C, D>, b: &Point<C, D>) -> f64 {
159 let mut sum = 0.0_f64;
160 for d in 0..D {
161 let da: f64 = a.coords()[d].into();
162 let db: f64 = b.coords()[d].into();
163 let diff = da - db;
164 sum += diff * diff;
165 }
166 sum
167}
168
169pub struct KDTree<T, C, const D: usize> {
173 tree: Mutex<(KDNode<C, D>, BBox<C, D>)>,
174 flat: Vec<(Point<C, D>, EntryId)>,
175 id_to_payload: HashMap<u64, T>,
176 id_to_point: HashMap<u64, Point<C, D>>,
177 len: usize,
178 next_id: u64,
179 dirty: Mutex<bool>,
180}
181
182impl<T, C: CoordType, const D: usize> KDTree<T, C, D> {
183 pub fn new() -> Self {
184 let empty_bbox = BBox::new(Point::new([C::zero(); D]), Point::new([C::zero(); D]));
185 Self {
186 tree: Mutex::new((
187 KDNode::Leaf {
188 entries: Vec::new(),
189 },
190 empty_bbox,
191 )),
192 flat: Vec::new(),
193 id_to_payload: HashMap::new(),
194 id_to_point: HashMap::new(),
195 len: 0,
196 next_id: 0,
197 dirty: Mutex::new(false),
198 }
199 }
200
201 fn alloc_id(&mut self) -> EntryId {
202 let id = EntryId(self.next_id);
203 self.next_id += 1;
204 id
205 }
206
207 fn rebuild(&self) {
208 let empty_bbox = BBox::new(Point::new([C::zero(); D]), Point::new([C::zero(); D]));
209 if self.flat.is_empty() {
210 *self.tree.lock().unwrap() = (
211 KDNode::Leaf {
212 entries: Vec::new(),
213 },
214 empty_bbox,
215 );
216 *self.dirty.lock().unwrap() = false;
217 return;
218 }
219 let bbox = entries_bbox(&self.flat);
220 let root = build_node(self.flat.clone(), bbox);
221 *self.tree.lock().unwrap() = (root, bbox);
222 *self.dirty.lock().unwrap() = false;
223 }
224
225 fn ensure_built(&self) {
226 if *self.dirty.lock().unwrap() {
227 self.rebuild();
228 }
229 }
230
231 pub fn depth(&self) -> usize {
233 self.ensure_built();
234 node_depth(&self.tree.lock().unwrap().0)
235 }
236
237 pub fn insert_entry(&mut self, point: Point<C, D>, payload: T) -> EntryId {
238 let id = self.alloc_id();
239 self.id_to_point.insert(id.0, point);
240 self.id_to_payload.insert(id.0, payload);
241 self.flat.push((point, id));
242 self.len += 1;
243 *self.dirty.lock().unwrap() = true;
244 id
245 }
246
247 pub fn remove_entry(&mut self, id: EntryId) -> Option<T> {
248 self.id_to_point.remove(&id.0)?;
249 let payload = self.id_to_payload.remove(&id.0)?;
250 if let Some(pos) = self.flat.iter().position(|(_, eid)| *eid == id) {
251 self.flat.swap_remove(pos);
252 }
253 self.len -= 1;
254 *self.dirty.lock().unwrap() = true;
255 Some(payload)
256 }
257
258 pub fn range_query_impl<'a>(&'a self, bbox: &BBox<C, D>) -> Vec<(EntryId, &'a T)> {
259 self.ensure_built();
260 if self.len == 0 {
261 return Vec::new();
262 }
263 let tree = self.tree.lock().unwrap();
264 let (root, root_bbox) = &*tree;
265 let mut ids = Vec::new();
266 node_range(root, bbox, root_bbox, &mut ids);
267 drop(tree);
268 ids.into_iter()
269 .filter_map(|id| self.id_to_payload.get(&id.0).map(|p| (id, p)))
270 .collect()
271 }
272
273 pub fn knn_query_impl<'a>(
274 &'a self,
275 point: &Point<C, D>,
276 k: usize,
277 ) -> Vec<(f64, EntryId, &'a T)> {
278 self.ensure_built();
279 if k == 0 {
280 return Vec::new();
281 }
282 let tree = self.tree.lock().unwrap();
283 let (root, _) = &*tree;
284 let mut raw: Vec<(Point<C, D>, EntryId)> = Vec::new();
285 node_collect(root, &mut raw);
286 drop(tree);
287 let mut all: Vec<(f64, EntryId, &'a T)> = raw
288 .into_iter()
289 .filter_map(|(p, id)| {
290 self.id_to_payload
291 .get(&id.0)
292 .map(|payload| (point_dist_sq(&p, point).sqrt(), id, payload))
293 })
294 .collect();
295 all.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
296 all.truncate(k);
297 all
298 }
299
300 fn collect_all(&self) -> Vec<(Point<C, D>, EntryId, &T)> {
301 self.ensure_built();
302 let tree = self.tree.lock().unwrap();
303 let (root, _) = &*tree;
304 let mut raw: Vec<(Point<C, D>, EntryId)> = Vec::new();
305 node_collect(root, &mut raw);
306 drop(tree);
307 raw.into_iter()
308 .filter_map(|(p, id)| {
309 self.id_to_payload
310 .get(&id.0)
311 .map(|payload| (p, id, payload))
312 })
313 .collect()
314 }
315}
316
317impl<T, C: CoordType, const D: usize> Default for KDTree<T, C, D> {
318 fn default() -> Self {
319 Self::new()
320 }
321}
322
323impl<T: Send + Sync + 'static, C: CoordType, const D: usize> SpatialBackend<T, C, D>
324 for KDTree<T, C, D>
325{
326 fn insert(&mut self, point: Point<C, D>, payload: T) -> EntryId {
327 self.insert_entry(point, payload)
328 }
329
330 fn remove(&mut self, id: EntryId) -> Option<T> {
331 self.remove_entry(id)
332 }
333
334 fn range_query(&self, bbox: &BBox<C, D>) -> Vec<(EntryId, &T)> {
335 self.range_query_impl(bbox)
336 }
337
338 fn knn_query(&self, point: &Point<C, D>, k: usize) -> Vec<(f64, EntryId, &T)> {
339 self.knn_query_impl(point, k)
340 }
341
342 fn spatial_join(&self, other: &dyn SpatialBackend<T, C, D>) -> Vec<(EntryId, EntryId)> {
343 let self_entries = self.collect_all();
344 let other_entries = other.all_entries();
345 let mut pairs = Vec::new();
346 for (pa, id_a, _) in &self_entries {
347 let bbox_a = BBox::new(*pa, *pa);
348 for (pb, id_b, _) in &other_entries {
349 if bbox_a.intersects(&BBox::new(*pb, *pb)) {
350 pairs.push((*id_a, *id_b));
351 }
352 }
353 }
354 pairs
355 }
356
357 fn bulk_load(entries: Vec<(Point<C, D>, T)>) -> Self {
358 let mut tree = Self::new();
359 for (point, payload) in entries {
360 tree.insert_entry(point, payload);
361 }
362 tree
363 }
364
365 fn len(&self) -> usize {
366 self.len
367 }
368
369 fn kind(&self) -> BackendKind {
370 BackendKind::KDTree
371 }
372
373 fn all_entries(&self) -> Vec<(Point<C, D>, EntryId, &T)> {
374 self.collect_all()
375 }
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    /// Minimal 64-bit LCG so fixtures are deterministic without extra dependencies.
    struct Lcg(u64);
    impl Lcg {
        fn new(seed: u64) -> Self {
            Self(seed)
        }
        /// Uniform sample in `[0, 1)` from the top 53 bits of the state.
        fn next_f64(&mut self) -> f64 {
            self.0 = self
                .0
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            (self.0 >> 11) as f64 / (1u64 << 53) as f64
        }
        /// Standard normal sample via the Box-Muller transform.
        fn next_normal(&mut self) -> f64 {
            let u1 = self.next_f64().max(1e-15);
            let u2 = self.next_f64();
            (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()
        }
    }

    fn rand_pts_2d(n: usize, seed: u64) -> Vec<Point<f64, 2>> {
        let mut rng = Lcg::new(seed);
        (0..n)
            .map(|_| Point::new([rng.next_f64() * 1000.0, rng.next_f64() * 1000.0]))
            .collect()
    }

    fn clustered_pts_2d(n: usize, clusters: usize, seed: u64) -> Vec<Point<f64, 2>> {
        let mut rng = Lcg::new(seed);
        let centres: Vec<[f64; 2]> = (0..clusters)
            .map(|_| {
                [
                    rng.next_f64() * 800.0 + 100.0,
                    rng.next_f64() * 800.0 + 100.0,
                ]
            })
            .collect();
        (0..n)
            .map(|i| {
                let c = &centres[i % clusters];
                let x = c[0] + rng.next_normal() * 20.0;
                let y = c[1] + rng.next_normal() * 20.0;
                Point::new([x.clamp(0.0, 1000.0), y.clamp(0.0, 1000.0)])
            })
            .collect()
    }

    /// Oracle: ids of all points inside `bbox`, sorted by raw id.
    fn brute_range<C: CoordType, const D: usize>(
        pts: &[(Point<C, D>, EntryId)],
        bbox: &BBox<C, D>,
    ) -> Vec<EntryId> {
        let mut ids: Vec<EntryId> = pts
            .iter()
            .filter(|(p, _)| bbox.contains_point(p))
            .map(|(_, id)| *id)
            .collect();
        ids.sort_by_key(|id| id.0);
        ids
    }

    #[test]
    fn insert_and_len() {
        let mut tree = KDTree::<u32, f64, 2>::new();
        assert_eq!(tree.len(), 0);
        tree.insert(Point::new([0.5, 0.5]), 1u32);
        assert_eq!(tree.len(), 1);
        tree.insert(Point::new([1.5, 1.5]), 2u32);
        assert_eq!(tree.len(), 2);
    }

    #[test]
    fn range_query_basic() {
        let mut tree = KDTree::<u32, f64, 2>::new();
        let id1 = tree.insert(Point::new([0.5, 0.5]), 1u32);
        let id2 = tree.insert(Point::new([1.5, 1.5]), 2u32);
        let _id3 = tree.insert(Point::new([5.0, 5.0]), 3u32);
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([2.0, 2.0]));
        let mut got: Vec<EntryId> = tree
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        got.sort_by_key(|id| id.0);
        assert_eq!(got, vec![id1, id2]);
    }

    #[test]
    fn remove_works() {
        let mut tree = KDTree::<u32, f64, 2>::new();
        let id1 = tree.insert(Point::new([1.0, 1.0]), 10u32);
        let id2 = tree.insert(Point::new([2.0, 2.0]), 20u32);
        assert_eq!(tree.len(), 2);
        assert_eq!(tree.remove(id1), Some(10u32));
        assert_eq!(tree.len(), 1);
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([3.0, 3.0]));
        let ids: Vec<EntryId> = tree
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        assert!(!ids.contains(&id1));
        assert!(ids.contains(&id2));
    }

    #[test]
    fn kind_is_kdtree() {
        assert_eq!(KDTree::<u32, f64, 2>::new().kind(), BackendKind::KDTree);
    }

    #[test]
    fn range_query_vs_brute_force_2d_10k() {
        let pts = rand_pts_2d(10_000, 42);
        let mut tree = KDTree::<usize, f64, 2>::new();
        let mut pt_ids = Vec::new();
        for (i, &p) in pts.iter().enumerate() {
            let id = tree.insert(p, i);
            pt_ids.push((p, id));
        }
        let bbox = BBox::new(Point::new([0.0, 0.0]), Point::new([500.0, 500.0]));
        let mut got: Vec<EntryId> = tree
            .range_query(&bbox)
            .into_iter()
            .map(|(id, _)| id)
            .collect();
        got.sort_by_key(|id| id.0);
        let expected = brute_range(&pt_ids, &bbox);
        assert_eq!(got, expected);
    }

    #[test]
    fn knn_matches_brute_force_2d() {
        let pts = rand_pts_2d(500, 7);
        let mut tree = KDTree::<usize, f64, 2>::new();
        let mut pt_ids = Vec::new();
        for (i, &p) in pts.iter().enumerate() {
            let id = tree.insert(p, i);
            pt_ids.push((p, id));
        }
        let query = Point::new([500.0, 500.0]);
        let k = 10;
        let got = tree.knn_query(&query, k);
        assert_eq!(got.len(), k);

        // Same distance function as the tree, so exact float equality is expected.
        let mut expected: Vec<f64> = pt_ids
            .iter()
            .map(|(p, _)| point_dist_sq(p, &query).sqrt())
            .collect();
        expected.sort_by(|a, b| a.partial_cmp(b).unwrap());
        expected.truncate(k);
        let got_dists: Vec<f64> = got.iter().map(|(d, _, _)| *d).collect();
        assert_eq!(got_dists, expected);

        // Each result's id, payload and distance must belong to the same entry.
        for (d, id, payload) in &got {
            let i = **payload;
            assert_eq!(pt_ids[i].1, *id);
            assert_eq!(point_dist_sq(&pt_ids[i].0, &query).sqrt(), *d);
        }
    }

    /// Uniform random points in `D` dimensions must keep the tree depth within
    /// `ceil(log2 n) + 1`, since every split halves the entry count.
    fn check_depth_bound<const D: usize>(n: usize, seed: u64) {
        let mut rng = Lcg::new(seed);
        let mut tree = KDTree::<usize, f64, D>::new();
        for i in 0..n {
            let coords: [f64; D] = std::array::from_fn(|_| rng.next_f64() * 1000.0);
            tree.insert(Point::new(coords), i);
        }
        let depth = tree.depth();
        let bound = (n as f64).log2().ceil() as usize + 1;
        assert!(
            depth <= bound,
            "D={D}: depth {depth} exceeds bound {bound} for n={n}"
        );
    }

    #[test]
    fn depth_bound_d2() {
        check_depth_bound::<2>(1000, 1);
    }
    #[test]
    fn depth_bound_d3() {
        check_depth_bound::<3>(1000, 2);
    }
    #[test]
    fn depth_bound_d4() {
        check_depth_bound::<4>(1000, 3);
    }
    #[test]
    fn depth_bound_d5() {
        check_depth_bound::<5>(1000, 4);
    }
    #[test]
    fn depth_bound_d6() {
        check_depth_bound::<6>(1000, 5);
    }
    #[test]
    fn depth_bound_d7() {
        check_depth_bound::<7>(1000, 6);
    }
    #[test]
    fn depth_bound_d8() {
        check_depth_bound::<8>(1000, 7);
    }

    #[test]
    fn depth_bound_clustered_2d() {
        let pts = clustered_pts_2d(5000, 20, 99);
        let mut tree = KDTree::<usize, f64, 2>::new();
        for (i, &p) in pts.iter().enumerate() {
            tree.insert(p, i);
        }
        let n = pts.len();
        let depth = tree.depth();
        let bound = (n as f64).log2().ceil() as usize + 1;
        assert!(
            depth <= bound,
            "clustered: depth {depth} exceeds bound {bound} for n={n}"
        );
    }

    fn pt2d() -> impl Strategy<Value = Point<f64, 2>> {
        (0.0_f64..1000.0, 0.0_f64..1000.0).prop_map(|(x, y)| Point::new([x, y]))
    }

    fn bbox2d() -> impl Strategy<Value = BBox<f64, 2>> {
        (
            0.0_f64..900.0,
            0.0_f64..900.0,
            10.0_f64..200.0,
            10.0_f64..200.0,
        )
            .prop_map(|(x, y, w, h)| BBox::new(Point::new([x, y]), Point::new([x + w, y + h])))
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config {
            cases: 100,
            ..Default::default()
        })]

        #[test]
        fn prop_kdtree_depth_bound(pts in prop::collection::vec(pt2d(), 2..200)) {
            let mut tree = KDTree::<usize, f64, 2>::new();
            for (i, p) in pts.iter().enumerate() {
                tree.insert(*p, i);
            }
            let n = pts.len();
            let depth = tree.depth();
            let bound = (n as f64).log2().ceil() as usize + 1;
            prop_assert!(depth <= bound, "depth {} exceeds bound {} for n={}", depth, bound, n);
        }
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config {
            cases: 100,
            ..Default::default()
        })]

        #[test]
        fn prop_insert_remove_round_trip_kdtree(
            pts in prop::collection::vec(pt2d(), 1..50),
            remove_indices in prop::collection::vec(0usize..50, 0..25),
        ) {
            let mut tree = KDTree::<usize, f64, 2>::new();
            let mut inserted: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
            for (i, &p) in pts.iter().enumerate() {
                let id = tree.insert(p, i);
                inserted.push((p, id));
            }
            let mut removed_ids: Vec<EntryId> = Vec::new();
            for &ri in &remove_indices {
                let idx = ri % inserted.len();
                let (_, id) = inserted[idx];
                if !removed_ids.contains(&id) {
                    let result = tree.remove(id);
                    prop_assert!(result.is_some(), "remove returned None for inserted id");
                    removed_ids.push(id);
                }
            }
            let full_bbox = BBox::new(Point::new([-1.0e9, -1.0e9]), Point::new([1.0e9, 1.0e9]));
            let mut remaining_ids: Vec<EntryId> = tree.range_query(&full_bbox)
                .into_iter()
                .map(|(id, _)| id)
                .collect();
            for &removed_id in &removed_ids {
                prop_assert!(
                    !remaining_ids.contains(&removed_id),
                    "removed entry {:?} still appears in range query",
                    removed_id
                );
            }
            // Every surviving entry, and only those, must still be reachable.
            remaining_ids.sort_by_key(|id| id.0);
            let mut expected_ids: Vec<EntryId> = inserted
                .iter()
                .map(|(_, id)| *id)
                .filter(|id| !removed_ids.contains(id))
                .collect();
            expected_ids.sort_by_key(|id| id.0);
            prop_assert_eq!(remaining_ids, expected_ids);

            let expected_len = inserted.len() - removed_ids.len();
            prop_assert_eq!(tree.len(), expected_len);
        }
    }

    proptest! {
        #![proptest_config(proptest::test_runner::Config {
            cases: 100,
            ..Default::default()
        })]

        #[test]
        fn prop_range_query_oracle_kdtree(
            pts in prop::collection::vec(pt2d(), 1..100),
            bbox in bbox2d(),
        ) {
            let mut tree = KDTree::<usize, f64, 2>::new();
            let mut pt_ids: Vec<(Point<f64, 2>, EntryId)> = Vec::new();
            for (i, p) in pts.iter().enumerate() {
                let id = tree.insert(*p, i);
                pt_ids.push((*p, id));
            }
            let mut got: Vec<EntryId> =
                tree.range_query(&bbox).into_iter().map(|(id, _)| id).collect();
            got.sort_by_key(|id| id.0);
            let expected = brute_range(&pt_ids, &bbox);
            prop_assert_eq!(got, expected);
        }
    }
}