Skip to main content

oxiphysics_collision/
kdtree_collision.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! KD-tree based spatial queries and collision detection.
5//!
6//! Provides a median-split KD-tree ([`KdTree`]) that supports nearest-neighbour
7//! search, k-nearest neighbours, radius queries, and self-collision pair
8//! detection.  A dynamic wrapper ([`KdTreeCollisionDetector`]) allows
9//! incremental insertions followed by bulk rebuilds.
10
11use std::collections::BinaryHeap;
12
13// ─────────────────────────────────────────────────────────────────────────────
14// Geometry helpers
15// ─────────────────────────────────────────────────────────────────────────────
16
/// Squared Euclidean distance between two 3-D points.
#[inline]
fn dist_sq(a: &[f64; 3], b: &[f64; 3]) -> f64 {
    let [ax, ay, az] = *a;
    let [bx, by, bz] = *b;
    let dx = ax - bx;
    let dy = ay - by;
    let dz = az - bz;
    dx.powi(2) + dy.powi(2) + dz.powi(2)
}
21
22// ─────────────────────────────────────────────────────────────────────────────
23// Aabb3
24// ─────────────────────────────────────────────────────────────────────────────
25
/// Axis-aligned bounding box in 3-D.
#[derive(Debug, Clone, PartialEq)]
pub struct Aabb3 {
    /// Minimum corner `[x, y, z]`.
    pub min: [f64; 3],
    /// Maximum corner `[x, y, z]`.
    pub max: [f64; 3],
}

impl Aabb3 {
    /// Create an empty, inverted AABB: `min` is `+∞` and `max` is `-∞` on
    /// every axis.
    ///
    /// Because the corners are inverted, the first finite point passed to
    /// [`Aabb3::expand`] becomes both the minimum and the maximum corner.
    pub fn empty() -> Self {
        Self::new([f64::INFINITY; 3], [f64::NEG_INFINITY; 3])
    }

    /// Create an AABB from explicit min/max corners.
    ///
    /// The corners are stored as given; no ordering check is performed.
    pub fn new(min: [f64; 3], max: [f64; 3]) -> Self {
        Aabb3 { min, max }
    }

    /// Grow this AABB (if necessary) so that it contains `point`.
    pub fn expand(&mut self, point: &[f64; 3]) {
        for (axis, &coord) in point.iter().enumerate() {
            if coord < self.min[axis] {
                self.min[axis] = coord;
            }
            if coord > self.max[axis] {
                self.max[axis] = coord;
            }
        }
    }

    /// Return `true` if this AABB overlaps `other`.
    ///
    /// Boxes that merely touch on a face, edge, or corner count as overlapping.
    pub fn overlaps(&self, other: &Self) -> bool {
        (0..3).all(|axis| self.min[axis] <= other.max[axis] && self.max[axis] >= other.min[axis])
    }

    /// Return `true` if `point` is inside or on the boundary of this AABB.
    pub fn contains_point(&self, point: &[f64; 3]) -> bool {
        (0..3).all(|axis| point[axis] >= self.min[axis] && point[axis] <= self.max[axis])
    }

    /// Return the squared minimum distance from `point` to this AABB.
    ///
    /// The result is `0.0` when the point lies inside or on the boundary.
    pub fn min_dist_sq(&self, point: &[f64; 3]) -> f64 {
        let mut total = 0.0;
        for (axis, &coord) in point.iter().enumerate() {
            let lo = self.min[axis];
            let hi = self.max[axis];
            // Per-axis gap between the point and the slab `[lo, hi]`.
            let gap = if coord < lo {
                lo - coord
            } else if coord > hi {
                coord - hi
            } else {
                0.0
            };
            total += gap.powi(2);
        }
        total
    }

    /// Compute the AABB that encloses a slice of points.
    ///
    /// An empty slice yields [`Aabb3::empty`].
    pub fn from_points(pts: &[[f64; 3]]) -> Self {
        pts.iter().fold(Self::empty(), |mut bounds, p| {
            bounds.expand(p);
            bounds
        })
    }
}
108
109// ─────────────────────────────────────────────────────────────────────────────
110// KdPoint trait
111// ─────────────────────────────────────────────────────────────────────────────
112
/// Types that expose a 3-D position and can therefore be used as KD-tree
/// points.
pub trait KdPoint {
    /// Return the `[x, y, z]` position of this point.
    fn position(&self) -> [f64; 3];
}

/// A raw coordinate triple is its own position.
impl KdPoint for [f64; 3] {
    fn position(&self) -> [f64; 3] {
        let [x, y, z] = *self;
        [x, y, z]
    }
}
124
125// ─────────────────────────────────────────────────────────────────────────────
126// KdNode
127// ─────────────────────────────────────────────────────────────────────────────
128
129/// A node in a KD-tree.
130#[derive(Debug)]
131pub enum KdNode {
132    /// A leaf holding a small bucket of point indices.
133    Leaf {
134        /// Indices into the owning [`KdTree`]'s points array.
135        indices: Vec<usize>,
136        /// Tight AABB around the leaf points.
137        aabb: Aabb3,
138    },
139    /// An internal split node.
140    Internal {
141        /// Axis along which the split was made (0 = X, 1 = Y, 2 = Z).
142        split_dim: usize,
143        /// The split value (median coordinate along `split_dim`).
144        split_val: f64,
145        /// Left subtree (values ≤ `split_val`).
146        left: Box<KdNode>,
147        /// Right subtree (values > `split_val`).
148        right: Box<KdNode>,
149        /// Tight AABB for all points in this subtree.
150        aabb: Aabb3,
151    },
152}
153
154impl KdNode {
155    /// Return the AABB of this node.
156    pub fn aabb(&self) -> &Aabb3 {
157        match self {
158            KdNode::Leaf { aabb, .. } => aabb,
159            KdNode::Internal { aabb, .. } => aabb,
160        }
161    }
162}
163
164// ─────────────────────────────────────────────────────────────────────────────
165// KdTree
166// ─────────────────────────────────────────────────────────────────────────────
167
168/// Maximum number of points per leaf bucket.
169const LEAF_SIZE: usize = 8;
170
171/// KD-tree for 3-D point sets.
172#[derive(Debug)]
173pub struct KdTree {
174    /// The root node of the tree (or `None` when the tree is empty).
175    pub root: Option<KdNode>,
176    /// The point coordinates stored in the tree.
177    pub points: Vec<[f64; 3]>,
178}
179
180impl KdTree {
181    /// Build a KD-tree from a set of 3-D points using median splits.
182    pub fn build(points: Vec<[f64; 3]>) -> Self {
183        if points.is_empty() {
184            return Self { root: None, points };
185        }
186        let n = points.len();
187        let mut indices: Vec<usize> = (0..n).collect();
188        let root = Some(build_node(&points, &mut indices));
189        Self { root, points }
190    }
191
192    /// Find the single nearest neighbour to `query`.
193    ///
194    /// Returns `Some((index, dist_sq))` or `None` when the tree is empty.
195    pub fn nearest_neighbor(&self, query: &[f64; 3]) -> Option<(usize, f64)> {
196        let root = self.root.as_ref()?;
197        let mut best = (usize::MAX, f64::INFINITY);
198        nn_search(root, query, &self.points, &mut best);
199        if best.0 == usize::MAX {
200            None
201        } else {
202            Some(best)
203        }
204    }
205
206    /// Return the `k` nearest neighbours sorted by ascending squared distance.
207    pub fn k_nearest(&self, query: &[f64; 3], k: usize) -> Vec<(usize, f64)> {
208        if k == 0 {
209            return vec![];
210        }
211        let root = match &self.root {
212            Some(r) => r,
213            None => return vec![],
214        };
215        // Max-heap ordered by dist_sq (negated for a min-heap behaviour).
216        // We use OrderedFloat-like wrapper via std BinaryHeap<(OrdF64, usize)>.
217        let mut heap: BinaryHeap<OrdF64Pair> = BinaryHeap::new();
218        knn_search(root, query, &self.points, k, &mut heap);
219        let mut result: Vec<(usize, f64)> = heap.into_iter().map(|p| (p.idx, p.dist_sq)).collect();
220        result.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
221        result
222    }
223
224    /// Return all point indices whose distance to `center` is ≤ `radius`.
225    pub fn range_query(&self, center: &[f64; 3], radius: f64) -> Vec<usize> {
226        let root = match &self.root {
227            Some(r) => r,
228            None => return vec![],
229        };
230        let r2 = radius * radius;
231        let mut result = Vec::new();
232        range_search(root, center, r2, &self.points, &mut result);
233        result
234    }
235
236    /// Find all pairs `(i, j)` with `i < j` where the distance between
237    /// points `i` and `j` is ≤ `radius`.
238    pub fn self_collision_pairs(&self, radius: f64) -> Vec<(usize, usize)> {
239        let _n = self.points.len();
240        let r2 = radius * radius;
241        let mut pairs = Vec::new();
242        for (i, pt) in self.points.iter().enumerate() {
243            let candidates = self.range_query(pt, radius);
244            for j in candidates {
245                if j > i && dist_sq(pt, &self.points[j]) <= r2 {
246                    pairs.push((i, j));
247                }
248            }
249        }
250        pairs.sort_unstable();
251        pairs
252    }
253}
254
255// ─────────────────────────────────────────────────────────────────────────────
256// Tree-building internals
257// ─────────────────────────────────────────────────────────────────────────────
258
259fn build_node(points: &[[f64; 3]], indices: &mut [usize]) -> KdNode {
260    let aabb = Aabb3::from_points(&indices.iter().map(|&i| points[i]).collect::<Vec<_>>());
261
262    if indices.len() <= LEAF_SIZE {
263        return KdNode::Leaf {
264            indices: indices.to_vec(),
265            aabb,
266        };
267    }
268
269    // Choose split axis: the axis with the widest span (sliding midpoint
270    // strategy).
271    let mut split_dim = 0;
272    let mut max_span = aabb.max[0] - aabb.min[0];
273    for d in 1..3 {
274        let span = aabb.max[d] - aabb.min[d];
275        if span > max_span {
276            max_span = span;
277            split_dim = d;
278        }
279    }
280
281    // Median split.
282    let mid = indices.len() / 2;
283    indices.select_nth_unstable_by(mid, |&a, &b| {
284        points[a][split_dim]
285            .partial_cmp(&points[b][split_dim])
286            .unwrap_or(std::cmp::Ordering::Equal)
287    });
288    let split_val = points[indices[mid]][split_dim];
289
290    let (left_idx, right_idx) = indices.split_at_mut(mid);
291    let left = Box::new(build_node(points, left_idx));
292    let right = Box::new(build_node(points, right_idx));
293
294    KdNode::Internal {
295        split_dim,
296        split_val,
297        left,
298        right,
299        aabb,
300    }
301}
302
303// ─────────────────────────────────────────────────────────────────────────────
304// Search routines
305// ─────────────────────────────────────────────────────────────────────────────
306
307fn nn_search(node: &KdNode, query: &[f64; 3], points: &[[f64; 3]], best: &mut (usize, f64)) {
308    match node {
309        KdNode::Leaf { indices, .. } => {
310            for &i in indices {
311                let d = dist_sq(query, &points[i]);
312                if d < best.1 {
313                    *best = (i, d);
314                }
315            }
316        }
317        KdNode::Internal {
318            split_dim,
319            split_val,
320            left,
321            right,
322            ..
323        } => {
324            let go_left = query[*split_dim] <= *split_val;
325            let (near, far) = if go_left {
326                (left.as_ref(), right.as_ref())
327            } else {
328                (right.as_ref(), left.as_ref())
329            };
330            nn_search(near, query, points, best);
331            // Prune: only visit the far side if its AABB might beat the current best.
332            if far.aabb().min_dist_sq(query) < best.1 {
333                nn_search(far, query, points, best);
334            }
335        }
336    }
337}
338
/// Pair stored in max-heap for k-NN (largest dist_sq at top).
///
/// Ordering and equality look at `dist_sq` only and both go through
/// [`f64::total_cmp`], so they agree with each other and form a total order
/// even for NaN, as `Eq` and `Ord` require.  `idx` is payload and does not
/// take part in comparisons.
#[derive(Debug, Clone, Copy)]
struct OrdF64Pair {
    /// Squared distance from the query to the point.
    dist_sq: f64,
    /// Index of the point in the tree's point array.
    idx: usize,
}

impl PartialEq for OrdF64Pair {
    fn eq(&self, other: &Self) -> bool {
        // Defined via `cmp` so that `a == b` exactly when `a.cmp(&b)` is
        // `Equal`.
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}
impl Eq for OrdF64Pair {}
impl PartialOrd for OrdF64Pair {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
impl Ord for OrdF64Pair {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.dist_sq.total_cmp(&other.dist_sq)
    }
}
363
364fn knn_search(
365    node: &KdNode,
366    query: &[f64; 3],
367    points: &[[f64; 3]],
368    k: usize,
369    heap: &mut BinaryHeap<OrdF64Pair>,
370) {
371    // Prune: if the heap is full and the AABB is farther than the worst
372    // current best, skip.
373    let worst = heap.peek().map(|p| p.dist_sq).unwrap_or(f64::INFINITY);
374    if node.aabb().min_dist_sq(query) >= worst && heap.len() >= k {
375        return;
376    }
377    match node {
378        KdNode::Leaf { indices, .. } => {
379            for &i in indices {
380                let d = dist_sq(query, &points[i]);
381                if heap.len() < k {
382                    heap.push(OrdF64Pair { dist_sq: d, idx: i });
383                } else if heap.peek().is_none_or(|top| d < top.dist_sq) {
384                    heap.pop();
385                    heap.push(OrdF64Pair { dist_sq: d, idx: i });
386                }
387            }
388        }
389        KdNode::Internal {
390            split_dim,
391            split_val,
392            left,
393            right,
394            ..
395        } => {
396            let go_left = query[*split_dim] <= *split_val;
397            let (near, far) = if go_left {
398                (left.as_ref(), right.as_ref())
399            } else {
400                (right.as_ref(), left.as_ref())
401            };
402            knn_search(near, query, points, k, heap);
403            knn_search(far, query, points, k, heap);
404        }
405    }
406}
407
408fn range_search(
409    node: &KdNode,
410    center: &[f64; 3],
411    r2: f64,
412    points: &[[f64; 3]],
413    result: &mut Vec<usize>,
414) {
415    if node.aabb().min_dist_sq(center) > r2 {
416        return;
417    }
418    match node {
419        KdNode::Leaf { indices, .. } => {
420            for &i in indices {
421                if dist_sq(center, &points[i]) <= r2 {
422                    result.push(i);
423                }
424            }
425        }
426        KdNode::Internal { left, right, .. } => {
427            range_search(left, center, r2, points, result);
428            range_search(right, center, r2, points, result);
429        }
430    }
431}
432
433// ─────────────────────────────────────────────────────────────────────────────
434// KdTreeCollisionDetector
435// ─────────────────────────────────────────────────────────────────────────────
436
437/// Dynamic KD-tree collision detector that supports incremental insertion and
438/// bulk rebuild.
439#[derive(Debug)]
440pub struct KdTreeCollisionDetector {
441    /// Points collected since the last rebuild.
442    pending: Vec<([f64; 3], usize)>,
443    /// The most recently built KD-tree (may be stale if points were added
444    /// after the last [`Self::rebuild`] call).
445    tree: KdTree,
446    /// User-supplied IDs for each point in the tree.
447    ids: Vec<usize>,
448}
449
450impl KdTreeCollisionDetector {
451    /// Create a new detector with a capacity hint.
452    pub fn new(capacity: usize) -> Self {
453        Self {
454            pending: Vec::with_capacity(capacity),
455            tree: KdTree::build(vec![]),
456            ids: Vec::with_capacity(capacity),
457        }
458    }
459
460    /// Insert a point at `pos` with user-supplied `id`.
461    pub fn insert(&mut self, pos: [f64; 3], id: usize) {
462        self.pending.push((pos, id));
463    }
464
465    /// Rebuild the internal KD-tree from all inserted points.
466    ///
467    /// Returns `&mut Self` for chaining.
468    pub fn rebuild(&mut self) -> &mut Self {
469        let (positions, ids): (Vec<[f64; 3]>, Vec<usize>) = self.pending.iter().cloned().unzip();
470        self.ids = ids;
471        self.tree = KdTree::build(positions);
472        self
473    }
474
475    /// Query all points within `r` of `pos`.
476    ///
477    /// Returns user-supplied IDs.  The tree must have been rebuilt after the
478    /// last insertion for results to be accurate.
479    pub fn query_radius(&self, pos: &[f64; 3], r: f64) -> Vec<usize> {
480        self.tree
481            .range_query(pos, r)
482            .into_iter()
483            .map(|tree_idx| self.ids[tree_idx])
484            .collect()
485    }
486
487    /// Return the number of inserted points.
488    pub fn len(&self) -> usize {
489        self.pending.len()
490    }
491
492    /// Return `true` if no points have been inserted.
493    pub fn is_empty(&self) -> bool {
494        self.pending.is_empty()
495    }
496}
497
498// ─────────────────────────────────────────────────────────────────────────────
499// BVH-style batch leaf grouping
500// ─────────────────────────────────────────────────────────────────────────────
501
502/// A flat list of AABB leaf groups for batch broadphase queries.
503///
504/// Constructed by collecting the leaf AABBs from a [`KdTree`].
505#[derive(Debug, Clone)]
506pub struct BvhLeafGroups {
507    /// Each entry is the AABB of one leaf in the KD-tree.
508    pub groups: Vec<Aabb3>,
509}
510
511impl BvhLeafGroups {
512    /// Extract leaf AABBs from a [`KdTree`].
513    pub fn from_tree(tree: &KdTree) -> Self {
514        let mut groups = Vec::new();
515        if let Some(root) = &tree.root {
516            collect_leaf_aabbs(root, &mut groups);
517        }
518        Self { groups }
519    }
520
521    /// Return the indices of leaf groups whose AABB overlaps the query sphere.
522    pub fn query_sphere(&self, center: &[f64; 3], radius: f64) -> Vec<usize> {
523        let r2 = radius * radius;
524        self.groups
525            .iter()
526            .enumerate()
527            .filter(|(_, aabb)| aabb.min_dist_sq(center) <= r2)
528            .map(|(i, _)| i)
529            .collect()
530    }
531}
532
533fn collect_leaf_aabbs(node: &KdNode, out: &mut Vec<Aabb3>) {
534    match node {
535        KdNode::Leaf { aabb, .. } => out.push(aabb.clone()),
536        KdNode::Internal { left, right, .. } => {
537            collect_leaf_aabbs(left, out);
538            collect_leaf_aabbs(right, out);
539        }
540    }
541}
542
543// ─────────────────────────────────────────────────────────────────────────────
544// Tests
545// ─────────────────────────────────────────────────────────────────────────────
546
#[cfg(test)]
mod tests {
    use super::*;

    // ── helpers ────────────────────────────────────────────────────────────

    /// Return the first `n` points of an integer lattice whose side is the
    /// smallest cube edge that can hold `n` points.
    fn grid_points(n: usize) -> Vec<[f64; 3]> {
        let side = (n as f64).cbrt().ceil() as usize;
        let mut pts = Vec::new();
        'outer: for x in 0..side {
            for y in 0..side {
                for z in 0..side {
                    pts.push([x as f64, y as f64, z as f64]);
                    if pts.len() == n {
                        break 'outer;
                    }
                }
            }
        }
        pts
    }

    // ── Aabb3 ──────────────────────────────────────────────────────────────

    #[test]
    fn test_aabb_empty_is_inverted() {
        let aabb = Aabb3::empty();
        // Every axis must be inverted, not just X.
        for axis in 0..3 {
            assert!(aabb.min[axis] > aabb.max[axis]);
        }
    }

    #[test]
    fn test_aabb_expand() {
        let mut aabb = Aabb3::empty();
        aabb.expand(&[1.0, 2.0, 3.0]);
        aabb.expand(&[-1.0, 0.0, 5.0]);
        assert_eq!(aabb.min, [-1.0, 0.0, 3.0]);
        assert_eq!(aabb.max, [1.0, 2.0, 5.0]);
    }

    #[test]
    fn test_aabb_overlaps_true() {
        let a = Aabb3::new([0.0; 3], [2.0; 3]);
        let b = Aabb3::new([1.0; 3], [3.0; 3]);
        assert!(a.overlaps(&b));
    }

    #[test]
    fn test_aabb_overlaps_false() {
        let a = Aabb3::new([0.0; 3], [1.0; 3]);
        let b = Aabb3::new([2.0; 3], [3.0; 3]);
        assert!(!a.overlaps(&b));
    }

    #[test]
    fn test_aabb_contains_point() {
        let aabb = Aabb3::new([0.0; 3], [1.0; 3]);
        assert!(aabb.contains_point(&[0.5, 0.5, 0.5]));
        assert!(!aabb.contains_point(&[1.5, 0.5, 0.5]));
    }

    #[test]
    fn test_aabb_min_dist_sq_inside() {
        let aabb = Aabb3::new([0.0; 3], [1.0; 3]);
        assert_eq!(aabb.min_dist_sq(&[0.5, 0.5, 0.5]), 0.0);
    }

    #[test]
    fn test_aabb_min_dist_sq_outside() {
        let aabb = Aabb3::new([0.0; 3], [1.0; 3]);
        let d = aabb.min_dist_sq(&[2.0, 0.5, 0.5]);
        assert!((d - 1.0).abs() < 1e-12);
    }

    #[test]
    fn test_aabb_from_points() {
        let pts = vec![[1.0, 2.0, 3.0], [-1.0, 0.0, 5.0], [0.0, 4.0, -1.0]];
        let aabb = Aabb3::from_points(&pts);
        assert_eq!(aabb.min, [-1.0, 0.0, -1.0]);
        assert_eq!(aabb.max, [1.0, 4.0, 5.0]);
    }

    // ── KdTree build & basics ───────────────────────────────────────────────

    #[test]
    fn test_kdtree_empty() {
        let tree = KdTree::build(vec![]);
        assert!(tree.root.is_none());
        assert!(tree.nearest_neighbor(&[0.0; 3]).is_none());
    }

    #[test]
    fn test_kdtree_single_point() {
        let tree = KdTree::build(vec![[1.0, 2.0, 3.0]]);
        let nn = tree.nearest_neighbor(&[0.0; 3]).unwrap();
        assert_eq!(nn.0, 0);
    }

    #[test]
    fn test_kdtree_build_grid() {
        let pts = grid_points(64);
        let tree = KdTree::build(pts);
        assert!(tree.root.is_some());
    }

    // ── nearest_neighbor ───────────────────────────────────────────────────

    #[test]
    fn test_nn_exact_match() {
        let pts = vec![[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [2.0, 0.0, 0.0]];
        let tree = KdTree::build(pts);
        let (idx, d) = tree.nearest_neighbor(&[1.0, 0.0, 0.0]).unwrap();
        assert_eq!(idx, 1);
        assert!(d < 1e-12);
    }

    #[test]
    fn test_nn_closest_of_many() {
        let pts = grid_points(27);
        let tree = KdTree::build(pts.clone());
        let query = [1.1, 1.1, 1.1];
        let (idx, _) = tree.nearest_neighbor(&query).unwrap();
        // Brute-force check
        let bf_idx = pts
            .iter()
            .enumerate()
            .min_by(|(_, a), (_, b)| dist_sq(a, &query).partial_cmp(&dist_sq(b, &query)).unwrap())
            .unwrap()
            .0;
        assert_eq!(idx, bf_idx);
    }

    #[test]
    fn test_nn_large_set() {
        let pts: Vec<[f64; 3]> = (0..500).map(|i| [i as f64 * 0.1, 0.0, 0.0]).collect();
        let tree = KdTree::build(pts);
        let (idx, _) = tree.nearest_neighbor(&[25.05, 0.0, 0.0]).unwrap();
        // The query sits between indices 250 and 251; allow a small window
        // so floating-point rounding of `i * 0.1` cannot flip the result out
        // of range.
        assert!((248..=252).contains(&idx));
    }

    // ── k_nearest ──────────────────────────────────────────────────────────

    #[test]
    fn test_knn_k_zero() {
        let pts = grid_points(8);
        let tree = KdTree::build(pts);
        assert!(tree.k_nearest(&[0.0; 3], 0).is_empty());
    }

    #[test]
    fn test_knn_k_equals_n() {
        let pts = grid_points(8);
        let n = pts.len();
        let tree = KdTree::build(pts);
        let result = tree.k_nearest(&[0.0; 3], n);
        assert_eq!(result.len(), n);
    }

    #[test]
    fn test_knn_sorted_ascending() {
        let pts = grid_points(27);
        let tree = KdTree::build(pts);
        let result = tree.k_nearest(&[1.5, 1.5, 1.5], 5);
        for w in result.windows(2) {
            assert!(w[0].1 <= w[1].1);
        }
    }

    #[test]
    fn test_knn_matches_brute_force() {
        let pts: Vec<[f64; 3]> = (0..30).map(|i| [i as f64, 0.0, 0.0]).collect();
        let query = [14.3, 0.0, 0.0];
        let tree = KdTree::build(pts.clone());
        let knn = tree.k_nearest(&query, 3);
        // Brute force top-3
        let mut bf: Vec<(usize, f64)> = pts
            .iter()
            .enumerate()
            .map(|(i, p)| (i, dist_sq(p, &query)))
            .collect();
        bf.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
        bf.truncate(3);
        let knn_idxs: Vec<usize> = knn.iter().map(|&(i, _)| i).collect();
        let bf_idxs: Vec<usize> = bf.iter().map(|&(i, _)| i).collect();
        assert_eq!(knn_idxs, bf_idxs);
    }

    // ── range_query ─────────────────────────────────────────────────────────

    #[test]
    fn test_range_query_empty_tree() {
        let tree = KdTree::build(vec![]);
        assert!(tree.range_query(&[0.0; 3], 1.0).is_empty());
    }

    #[test]
    fn test_range_query_all_in_radius() {
        let pts = vec![[0.0, 0.0, 0.0], [0.1, 0.0, 0.0], [0.0, 0.1, 0.0]];
        let tree = KdTree::build(pts);
        let mut result = tree.range_query(&[0.05, 0.05, 0.0], 1.0);
        result.sort_unstable();
        assert_eq!(result, vec![0, 1, 2]);
    }

    #[test]
    fn test_range_query_none_in_radius() {
        let pts = vec![[10.0, 0.0, 0.0], [20.0, 0.0, 0.0]];
        let tree = KdTree::build(pts);
        assert!(tree.range_query(&[0.0; 3], 5.0).is_empty());
    }

    #[test]
    fn test_range_query_matches_brute_force() {
        let pts = grid_points(64);
        let tree = KdTree::build(pts.clone());
        let center = [2.5, 2.5, 2.5];
        let r = 1.8;
        let r2 = r * r;
        let mut kd_result = tree.range_query(&center, r);
        kd_result.sort_unstable();
        let mut bf: Vec<usize> = pts
            .iter()
            .enumerate()
            .filter(|(_, p)| dist_sq(p, &center) <= r2)
            .map(|(i, _)| i)
            .collect();
        bf.sort_unstable();
        assert_eq!(kd_result, bf);
    }

    // ── self_collision_pairs ─────────────────────────────────────────────────

    #[test]
    fn test_self_collision_no_pairs_far_apart() {
        let pts = vec![[0.0, 0.0, 0.0], [100.0, 0.0, 0.0], [200.0, 0.0, 0.0]];
        let tree = KdTree::build(pts);
        assert!(tree.self_collision_pairs(1.0).is_empty());
    }

    #[test]
    fn test_self_collision_all_close() {
        let pts = vec![[0.0; 3], [0.1, 0.0, 0.0], [0.0, 0.1, 0.0]];
        let tree = KdTree::build(pts);
        let pairs = tree.self_collision_pairs(0.2);
        // All three points are mutually within 0.2, and the result is sorted,
        // so the exact pair list is known.
        assert_eq!(pairs, vec![(0, 1), (0, 2), (1, 2)]);
    }

    #[test]
    fn test_self_collision_pairs_ordered() {
        let pts = grid_points(16);
        let tree = KdTree::build(pts);
        let pairs = tree.self_collision_pairs(1.5);
        for &(a, b) in &pairs {
            assert!(a < b);
        }
    }

    // ── KdTreeCollisionDetector ─────────────────────────────────────────────

    #[test]
    fn test_detector_empty() {
        let det = KdTreeCollisionDetector::new(10);
        assert!(det.is_empty());
        assert_eq!(det.len(), 0);
    }

    #[test]
    fn test_detector_insert_and_rebuild() {
        let mut det = KdTreeCollisionDetector::new(4);
        det.insert([0.0, 0.0, 0.0], 10);
        det.insert([1.0, 0.0, 0.0], 20);
        det.rebuild();
        assert_eq!(det.len(), 2);
        assert!(!det.is_empty());
    }

    #[test]
    fn test_detector_query_radius() {
        let mut det = KdTreeCollisionDetector::new(5);
        det.insert([0.0, 0.0, 0.0], 1);
        det.insert([0.5, 0.0, 0.0], 2);
        det.insert([10.0, 0.0, 0.0], 3);
        det.rebuild();
        let result = det.query_radius(&[0.0; 3], 1.0);
        assert!(result.contains(&1));
        assert!(result.contains(&2));
        assert!(!result.contains(&3));
    }

    #[test]
    fn test_detector_rebuild_chaining() {
        let mut det = KdTreeCollisionDetector::new(4);
        det.insert([0.0; 3], 0);
        // Use the reference returned by `rebuild` directly, so the chaining
        // contract is what is actually being exercised.
        let r = det.rebuild().query_radius(&[0.0; 3], 0.1);
        assert_eq!(r, vec![0]);
    }

    // ── BvhLeafGroups ───────────────────────────────────────────────────────

    #[test]
    fn test_bvh_leaf_groups_from_empty_tree() {
        let tree = KdTree::build(vec![]);
        let groups = BvhLeafGroups::from_tree(&tree);
        assert!(groups.groups.is_empty());
    }

    #[test]
    fn test_bvh_leaf_groups_non_empty() {
        let pts = grid_points(32);
        let tree = KdTree::build(pts);
        let groups = BvhLeafGroups::from_tree(&tree);
        assert!(!groups.groups.is_empty());
    }

    #[test]
    fn test_bvh_leaf_groups_query_sphere() {
        let pts = grid_points(32);
        let tree = KdTree::build(pts);
        let groups = BvhLeafGroups::from_tree(&tree);
        let hit = groups.query_sphere(&[1.5, 1.5, 1.5], 2.0);
        assert!(!hit.is_empty());
    }

    #[test]
    fn test_bvh_leaf_query_sphere_far_away() {
        let pts = grid_points(32);
        let tree = KdTree::build(pts);
        let groups = BvhLeafGroups::from_tree(&tree);
        // Query far from all points.
        let hit = groups.query_sphere(&[1000.0, 1000.0, 1000.0], 0.1);
        assert!(hit.is_empty());
    }
}