Skip to main content

oxiphysics_gpu/kernels/
broadphase.rs

1#![allow(clippy::ptr_arg)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! Broadphase AABB kernels for parallel overlap detection.
6
7#![allow(dead_code)]
8
9use crate::compute::ComputeKernel;
10use std::collections::HashMap;
11
12// ---------------------------------------------------------------------------
13// Legacy f64 kernels (keep existing API)
14// ---------------------------------------------------------------------------
15
16/// Kernel that sorts AABBs along the X-axis and outputs sorted indices.
17///
18/// **Input layout** (flat f64 array, 6 values per object):
19///   `[min_x, max_x, min_y, max_y, min_z, max_z, ...]`
20///
21/// **Output\[0\]**: sorted indices (as f64, cast back to usize by caller).
22pub struct AabbSortKernel;
23
24impl ComputeKernel for AabbSortKernel {
25    fn name(&self) -> &str {
26        "AabbSortKernel"
27    }
28
29    fn execute(&self, inputs: &[&[f64]], outputs: &mut [Vec<f64>], _work_size: usize) {
30        if inputs.is_empty() || outputs.is_empty() {
31            return;
32        }
33        let aabbs = inputs[0];
34        let n = aabbs.len() / 6;
35        // Build (index, min_x) pairs and sort by min_x.
36        let mut indices: Vec<usize> = (0..n).collect();
37        indices.sort_by(|&a, &b| {
38            let ax = aabbs[a * 6];
39            let bx = aabbs[b * 6];
40            ax.partial_cmp(&bx).unwrap_or(std::cmp::Ordering::Equal)
41        });
42        outputs[0] = indices.iter().map(|&i| i as f64).collect();
43    }
44}
45
46/// Kernel that detects overlapping AABB pairs.
47///
48/// **Input layout** (flat f64 array, 6 values per object):
49///   `[min_x, max_x, min_y, max_y, min_z, max_z, ...]`
50///
51/// **Output\[0\]**: flat pairs `[i, j, i, j, ...]` of overlapping indices (as f64).
52pub struct AabbOverlapKernel;
53
54impl AabbOverlapKernel {
55    #[inline]
56    fn overlaps(a: &[f64], b: &[f64]) -> bool {
57        // a, b are slices of length 6: [min_x, max_x, min_y, max_y, min_z, max_z]
58        a[0] <= b[1] && a[1] >= b[0] // x overlap
59        && a[2] <= b[3] && a[3] >= b[2] // y overlap
60        && a[4] <= b[5] && a[5] >= b[4] // z overlap
61    }
62}
63
64#[allow(clippy::needless_range_loop)]
65impl ComputeKernel for AabbOverlapKernel {
66    fn name(&self) -> &str {
67        "AabbOverlapKernel"
68    }
69
70    fn execute(&self, inputs: &[&[f64]], outputs: &mut [Vec<f64>], _work_size: usize) {
71        if inputs.is_empty() || outputs.is_empty() {
72            return;
73        }
74        let aabbs = inputs[0];
75        let n = aabbs.len() / 6;
76        let mut pairs = Vec::new();
77        for i in 0..n {
78            for j in (i + 1)..n {
79                if Self::overlaps(&aabbs[i * 6..(i + 1) * 6], &aabbs[j * 6..(j + 1) * 6]) {
80                    pairs.push(i as f64);
81                    pairs.push(j as f64);
82                }
83            }
84        }
85        outputs[0] = pairs;
86    }
87}
88
89// ---------------------------------------------------------------------------
90// GPU-compatible f32 AABB
91// ---------------------------------------------------------------------------
92
/// Axis-aligned bounding box stored with `f32` coordinates (the precision
/// used by the GPU broadphase paths in this module).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct AabbGpu {
    /// Lower corner `[min_x, min_y, min_z]`.
    pub min: [f32; 3],
    /// Upper corner `[max_x, max_y, max_z]`.
    pub max: [f32; 3],
    /// Identifier of the rigid body this box belongs to.
    pub body_id: u32,
}

impl AabbGpu {
    /// Builds a box from its two corners and the id of its owning body.
    pub fn new(min: [f32; 3], max: [f32; 3], body_id: u32) -> Self {
        AabbGpu { body_id, min, max }
    }

    /// Returns `true` when the boxes intersect on every axis.
    ///
    /// Intervals are treated as closed, so boxes that merely touch count as
    /// overlapping; a NaN coordinate on either side makes the test fail.
    #[inline]
    pub fn overlaps(&self, other: &AabbGpu) -> bool {
        (0..3).all(|axis| self.min[axis] <= other.max[axis] && other.min[axis] <= self.max[axis])
    }
}
121
122// ---------------------------------------------------------------------------
123// Sort-and-Sweep (GPU)
124// ---------------------------------------------------------------------------
125
126/// Sort-and-sweep broadphase operating on `AabbGpu` values.
127///
128/// Bodies are sorted by their minimum X coordinate, then swept to find all
129/// overlapping pairs in O(n log n + k) time where k is the number of pairs.
130pub struct SortAndSweepGpu;
131
132impl SortAndSweepGpu {
133    /// Detect all overlapping AABB pairs.
134    ///
135    /// Returns a `Vec` of `(body_id_a, body_id_b)` pairs where the two AABBs
136    /// overlap in all three axes.  Each pair is reported exactly once with
137    /// `body_id_a < body_id_b`.
138    pub fn detect_pairs(aabbs: &[AabbGpu]) -> Vec<(u32, u32)> {
139        if aabbs.is_empty() {
140            return Vec::new();
141        }
142
143        // Sort by min_x.
144        let mut sorted: Vec<&AabbGpu> = aabbs.iter().collect();
145        sorted.sort_by(|a, b| {
146            a.min[0]
147                .partial_cmp(&b.min[0])
148                .unwrap_or(std::cmp::Ordering::Equal)
149        });
150
151        let n = sorted.len();
152        let mut pairs = Vec::new();
153
154        for i in 0..n {
155            for j in (i + 1)..n {
156                // Early-out: if the next AABB's min_x exceeds current max_x, no further overlaps
157                if sorted[j].min[0] > sorted[i].max[0] {
158                    break;
159                }
160                if sorted[i].overlaps(sorted[j]) {
161                    let a = sorted[i].body_id;
162                    let b = sorted[j].body_id;
163                    let pair = if a < b { (a, b) } else { (b, a) };
164                    pairs.push(pair);
165                }
166            }
167        }
168
169        pairs
170    }
171}
172
173// ---------------------------------------------------------------------------
174// Uniform Grid (GPU)
175// ---------------------------------------------------------------------------
176
177/// Uniform 3-D grid for GPU broadphase.
178#[derive(Debug, Clone)]
179pub struct UniformGridGpu {
180    /// Cell size (same in all dimensions).
181    pub cell_size: f32,
182    /// World-space origin of the grid `[ox, oy, oz]`.
183    pub origin: [f32; 3],
184    /// Number of cells in each dimension `[nx, ny, nz]`.
185    pub dims: [u32; 3],
186}
187
188impl UniformGridGpu {
189    /// Create a new uniform grid.
190    pub fn new(cell_size: f32, origin: [f32; 3], dims: [u32; 3]) -> Self {
191        Self {
192            cell_size,
193            origin,
194            dims,
195        }
196    }
197
198    /// Compute the cell index `[ix, iy, iz]` for a world-space position.
199    pub fn cell_of(&self, pos: [f32; 3]) -> [i32; 3] {
200        [
201            ((pos[0] - self.origin[0]) / self.cell_size).floor() as i32,
202            ((pos[1] - self.origin[1]) / self.cell_size).floor() as i32,
203            ((pos[2] - self.origin[2]) / self.cell_size).floor() as i32,
204        ]
205    }
206
207    /// Pack cell indices into a `u64` key for use in a HashMap.
208    fn cell_key(ix: i32, iy: i32, iz: i32) -> u64 {
209        // Pack three i16-range integers into a u64.
210        let x = (ix as i16) as u64 & 0xFFFF;
211        let y = (iy as i16) as u64 & 0xFFFF;
212        let z = (iz as i16) as u64 & 0xFFFF;
213        (z << 32) | (y << 16) | x
214    }
215
216    /// Insert all AABBs into the uniform grid.
217    ///
218    /// Each AABB is inserted into every cell it overlaps.  Returns a map from
219    /// packed cell key to the list of `body_id`s in that cell.
220    pub fn insert_aabbs(&self, aabbs: &[AabbGpu]) -> HashMap<u64, Vec<u32>> {
221        let mut map: HashMap<u64, Vec<u32>> = HashMap::new();
222
223        for aabb in aabbs {
224            let min_cell = self.cell_of(aabb.min);
225            let max_cell = self.cell_of(aabb.max);
226
227            for iz in min_cell[2]..=max_cell[2] {
228                for iy in min_cell[1]..=max_cell[1] {
229                    for ix in min_cell[0]..=max_cell[0] {
230                        let key = Self::cell_key(ix, iy, iz);
231                        map.entry(key).or_default().push(aabb.body_id);
232                    }
233                }
234            }
235        }
236
237        map
238    }
239
240    /// Query all overlapping AABB pairs using the uniform grid.
241    ///
242    /// Returns deduplicated pairs `(a, b)` with `a < b`.
243    pub fn query_pairs(&self, aabbs: &[AabbGpu]) -> Vec<(u32, u32)> {
244        let map = self.insert_aabbs(aabbs);
245
246        let mut seen = std::collections::HashSet::new();
247        let mut pairs = Vec::new();
248
249        for body_list in map.values() {
250            let n = body_list.len();
251            for i in 0..n {
252                for j in (i + 1)..n {
253                    let a = body_list[i];
254                    let b = body_list[j];
255                    let pair = if a < b { (a, b) } else { (b, a) };
256                    if seen.insert(pair) {
257                        // Verify actual AABB overlap
258                        if let (Some(aa), Some(bb)) = (
259                            aabbs.iter().find(|x| x.body_id == pair.0),
260                            aabbs.iter().find(|x| x.body_id == pair.1),
261                        ) && aa.overlaps(bb)
262                        {
263                            pairs.push(pair);
264                        }
265                    }
266                }
267            }
268        }
269
270        pairs.sort();
271        pairs
272    }
273}
274
275// ---------------------------------------------------------------------------
276// Morton code (Z-curve)
277// ---------------------------------------------------------------------------
278
/// Compute the 3-D Morton code (Z-curve) for integer coordinates.
///
/// Bit `i` of `x`, `y` and `z` lands at bit `3*i`, `3*i + 1` and `3*i + 2`
/// of the result respectively.  Only the low 21 bits of each coordinate are
/// used (higher bits are discarded), so the key fits in 63 bits.
pub fn morton_code(x: u32, y: u32, z: u32) -> u64 {
    let dilated_x = spread_bits(u64::from(x));
    let dilated_y = spread_bits(u64::from(y));
    let dilated_z = spread_bits(u64::from(z));
    (dilated_z << 2) | (dilated_y << 1) | dilated_x
}

/// Spread the low 21 bits of `v` so that two zero bits separate each bit.
///
/// Each step ORs the value with a shifted copy of itself, then masks away
/// everything except the bit groups at their new, wider spacing.
#[inline]
fn spread_bits(v: u64) -> u64 {
    const DILATION_STEPS: [(u32, u64); 5] = [
        (32, 0x1f00000000ffff),
        (16, 0x1f0000ff0000ff),
        (8, 0x100f00f00f00f00f),
        (4, 0x10c30c30c30c30c3),
        (2, 0x1249249249249249),
    ];
    DILATION_STEPS
        .iter()
        .fold(v & 0x1fffff, |acc, &(shift, mask)| (acc | (acc << shift)) & mask)
}
298
299// ---------------------------------------------------------------------------
300// Sort-and-Sweep (CPU-side mock for GPU-style parallel data)
301// ---------------------------------------------------------------------------
302
/// Sort-and-sweep on a flat f64 AABB array (GPU-style parallel data layout).
///
/// Input layout per object: `[min_x, max_x, min_y, max_y, min_z, max_z]`;
/// trailing values that do not form a complete record are ignored.
///
/// Returns index pairs `(i, j)` with `i < j`, sorted ascending and without
/// duplicates, for every two boxes that overlap on all three axes (touching
/// faces count as overlapping, NaN bounds never overlap).
pub fn sort_and_sweep_flat(aabbs: &[f64]) -> Vec<(usize, usize)> {
    let boxes: Vec<&[f64]> = aabbs.chunks_exact(6).collect();

    // Order object indices by min_x.  `total_cmp` is a true total order even
    // when NaN is present, which `sort_by` requires (a non-total comparator
    // such as `partial_cmp(..).unwrap_or(Equal)` may make it panic).
    let mut order: Vec<usize> = (0..boxes.len()).collect();
    order.sort_by(|&a, &b| boxes[a][0].total_cmp(&boxes[b][0]));

    let mut pairs = Vec::new();
    for (rank, &si) in order.iter().enumerate() {
        let ai = boxes[si];
        for &sj in &order[rank + 1..] {
            let aj = boxes[sj];
            if aj[0] > ai[1] {
                break; // Sorted by min_x: no later box can overlap in X.
            }
            // Full 3-D closed-interval overlap check.
            let overlap = (0..3).all(|axis| {
                let (lo, hi) = (2 * axis, 2 * axis + 1);
                ai[lo] <= aj[hi] && ai[hi] >= aj[lo]
            });
            if overlap {
                pairs.push((si.min(sj), si.max(sj)));
            }
        }
    }
    pairs.sort_unstable();
    pairs.dedup();
    pairs
}
350
351// ---------------------------------------------------------------------------
352// Uniform Grid (CPU-side parallel data helper)
353// ---------------------------------------------------------------------------
354
/// Assigns each AABB in a flat array to uniform grid cells.
///
/// Layout per object is `[min_x, max_x, min_y, max_y, min_z, max_z]`
/// (incomplete trailing records are ignored).  A box covers every cell from
/// `floor((min - origin) / cell_size)` to `floor((max - origin) / cell_size)`
/// on each axis, inclusive.  Cell keys keep only the low 16 bits of each
/// index (x in bits 0-15, y in 16-31, z in 32-47).
///
/// Returns a `Vec<(cell_key, body_idx)>` (unsorted, one entry per cell overlap),
/// iterating z outermost and x innermost for each body.
pub fn assign_to_grid_cells(aabbs: &[f64], cell_size: f64, origin: [f64; 3]) -> Vec<(u64, usize)> {
    fn pack_cell(ix: i64, iy: i64, iz: i64) -> u64 {
        let lane = |v: i64| ((v as i16) as u64) & 0xFFFF;
        (lane(iz) << 32) | (lane(iy) << 16) | lane(ix)
    }
    let cell_index = |coord: f64, axis: usize| ((coord - origin[axis]) / cell_size).floor() as i64;

    let mut assignments = Vec::new();
    for (body, bounds) in aabbs.chunks_exact(6).enumerate() {
        let (x0, x1) = (cell_index(bounds[0], 0), cell_index(bounds[1], 0));
        let (y0, y1) = (cell_index(bounds[2], 1), cell_index(bounds[3], 1));
        let (z0, z1) = (cell_index(bounds[4], 2), cell_index(bounds[5], 2));
        for iz in z0..=z1 {
            for iy in y0..=y1 {
                for ix in x0..=x1 {
                    assignments.push((pack_cell(ix, iy, iz), body));
                }
            }
        }
    }
    assignments
}
385
/// Extract candidate pairs from a cell-assignment list.
///
/// Groups entries by cell key and emits every pair of *distinct* bodies that
/// share at least one cell.  Returns `(a, b)` pairs with `a < b`, sorted
/// ascending and without duplicates.
///
/// A body that appears more than once in the same cell is never paired with
/// itself.  This can happen when distant cells alias under 16-bit key packing,
/// or when the caller's list has repeats.
pub fn pairs_from_grid_assignments(assignments: &[(u64, usize)]) -> Vec<(usize, usize)> {
    let mut by_cell: HashMap<u64, Vec<usize>> = HashMap::new();
    for &(key, idx) in assignments {
        by_cell.entry(key).or_default().push(idx);
    }

    // Collect into a set so a pair sharing many cells is stored only once.
    let mut unique = std::collections::HashSet::new();
    for bodies in by_cell.values() {
        for (i, &a) in bodies.iter().enumerate() {
            for &b in &bodies[i + 1..] {
                if a != b {
                    unique.insert((a.min(b), a.max(b)));
                }
            }
        }
    }

    let mut pairs: Vec<(usize, usize)> = unique.into_iter().collect();
    pairs.sort_unstable();
    pairs
}
412
413// ---------------------------------------------------------------------------
414// Morton code sort utilities
415// ---------------------------------------------------------------------------
416
417/// Compute a Morton key for a world-space AABB centroid.
418///
419/// Quantises each coordinate to a 21-bit integer using `cell_size` and `origin`.
420pub fn morton_key_for_aabb(aabb: &AabbGpu, cell_size: f32, origin: [f32; 3]) -> u64 {
421    let cx = ((aabb.min[0] + aabb.max[0]) * 0.5 - origin[0]) / cell_size;
422    let cy = ((aabb.min[1] + aabb.max[1]) * 0.5 - origin[1]) / cell_size;
423    let cz = ((aabb.min[2] + aabb.max[2]) * 0.5 - origin[2]) / cell_size;
424
425    let ix = (cx.max(0.0) as u32).min(0x1F_FFFF);
426    let iy = (cy.max(0.0) as u32).min(0x1F_FFFF);
427    let iz = (cz.max(0.0) as u32).min(0x1F_FFFF);
428    morton_code(ix, iy, iz)
429}
430
431/// Sort a slice of `AabbGpu` by Morton code (Z-order curve).
432///
433/// Returns a new `Vec`AabbGpu` sorted by `morton_key_for_aabb`.
434pub fn morton_sort(aabbs: &[AabbGpu], cell_size: f32, origin: [f32; 3]) -> Vec<AabbGpu> {
435    let mut keyed: Vec<(u64, AabbGpu)> = aabbs
436        .iter()
437        .map(|a| (morton_key_for_aabb(a, cell_size, origin), *a))
438        .collect();
439    keyed.sort_by_key(|&(k, _)| k);
440    keyed.into_iter().map(|(_, a)| a).collect()
441}
442
443// ---------------------------------------------------------------------------
444// Compact pairs list
445// ---------------------------------------------------------------------------
446
/// A deduplicated list of overlapping body-ID pairs, each stored as `(low, high)`.
///
/// Pairs keep insertion order until [`CompactPairList::sort`] is called.
/// Membership checks are linear scans, so the type is meant for small lists.
#[derive(Debug, Clone, Default)]
pub struct CompactPairList {
    pairs: Vec<(u32, u32)>,
}

impl CompactPairList {
    /// Create an empty list.
    pub fn new() -> Self {
        CompactPairList { pairs: Vec::new() }
    }

    /// Insert a pair, normalised so the smaller id comes first; duplicates are ignored.
    pub fn insert(&mut self, a: u32, b: u32) {
        let key = (a.min(b), a.max(b));
        if self.pairs.iter().all(|&existing| existing != key) {
            self.pairs.push(key);
        }
    }

    /// Insert every pair from `pairs`, in order, with the same rules as `insert`.
    pub fn insert_all(&mut self, pairs: &[(u32, u32)]) {
        pairs.iter().for_each(|&(a, b)| self.insert(a, b));
    }

    /// Sort the pair list in ascending order.
    pub fn sort(&mut self) {
        self.pairs.sort_unstable();
    }

    /// Borrow all stored pairs.
    pub fn pairs(&self) -> &[(u32, u32)] {
        self.pairs.as_slice()
    }

    /// Number of stored pairs.
    pub fn len(&self) -> usize {
        self.pairs.len()
    }

    /// Returns true when no pairs are stored.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Drop every pair that mentions `body_id` (e.g. after the body left the simulation).
    pub fn remove_body(&mut self, body_id: u32) {
        self.pairs.retain(|&(lo, hi)| !(lo == body_id || hi == body_id));
    }

    /// Returns `true` if the pair `{a, b}` is stored, in either argument order.
    pub fn contains(&self, a: u32, b: u32) -> bool {
        let key = (a.min(b), a.max(b));
        self.pairs.iter().any(|&existing| existing == key)
    }
}
506
507// ---------------------------------------------------------------------------
508// BVH (GPU)
509// ---------------------------------------------------------------------------
510
511/// A single node of a GPU-oriented BVH tree.
512#[derive(Debug, Clone, Copy)]
513pub struct BvhGpuNode {
514    /// Bounding box of this node.
515    pub aabb: AabbGpu,
516    /// Index of the left child node, or `< 0` for a leaf (then `-body_id - 1`).
517    pub left: i32,
518    /// Index of the right child node, or `< 0` for a leaf (then `-body_id - 1`).
519    pub right: i32,
520}
521
522impl BvhGpuNode {
523    /// Create an internal node.
524    pub fn internal(aabb: AabbGpu, left: i32, right: i32) -> Self {
525        Self { aabb, left, right }
526    }
527
528    /// Create a leaf node for a given body.
529    pub fn leaf(aabb: AabbGpu) -> Self {
530        let id = aabb.body_id as i32;
531        Self {
532            aabb,
533            left: -(id + 1),
534            right: -(id + 1),
535        }
536    }
537
538    /// Return true if this is a leaf node.
539    pub fn is_leaf(&self) -> bool {
540        self.left < 0
541    }
542}
543
544/// Build a BVH using a simple midpoint-split heuristic.
545///
546/// Returns a flat `Vec`BvhGpuNode` with node 0 as the root.
547/// Internal nodes reference their children by index into this vec.
548pub fn build_bvh(aabbs: &[AabbGpu]) -> Vec<BvhGpuNode> {
549    let mut nodes = Vec::new();
550    if aabbs.is_empty() {
551        return nodes;
552    }
553    let mut indices: Vec<usize> = (0..aabbs.len()).collect();
554    build_bvh_recursive(aabbs, &mut indices, &mut nodes);
555    nodes
556}
557
558fn merge_aabbs(aabbs: &[AabbGpu], indices: &[usize]) -> AabbGpu {
559    let mut min = aabbs[indices[0]].min;
560    let mut max = aabbs[indices[0]].max;
561    for &idx in &indices[1..] {
562        let a = &aabbs[idx];
563        for k in 0..3 {
564            if a.min[k] < min[k] {
565                min[k] = a.min[k];
566            }
567            if a.max[k] > max[k] {
568                max[k] = a.max[k];
569            }
570        }
571    }
572    AabbGpu {
573        min,
574        max,
575        body_id: 0,
576    }
577}
578
579fn build_bvh_recursive(
580    aabbs: &[AabbGpu],
581    indices: &mut [usize],
582    nodes: &mut Vec<BvhGpuNode>,
583) -> i32 {
584    let n = indices.len();
585    let merged = merge_aabbs(aabbs, indices);
586
587    if n == 1 {
588        let idx = nodes.len() as i32;
589        nodes.push(BvhGpuNode::leaf(aabbs[indices[0]]));
590        return idx;
591    }
592
593    // Find the longest axis of the merged AABB.
594    let extents = [
595        merged.max[0] - merged.min[0],
596        merged.max[1] - merged.min[1],
597        merged.max[2] - merged.min[2],
598    ];
599    let axis = if extents[0] >= extents[1] && extents[0] >= extents[2] {
600        0
601    } else if extents[1] >= extents[2] {
602        1
603    } else {
604        2
605    };
606
607    // Sort by centroid along longest axis.
608    indices.sort_by(|&a, &b| {
609        let ca = (aabbs[a].min[axis] + aabbs[a].max[axis]) * 0.5;
610        let cb = (aabbs[b].min[axis] + aabbs[b].max[axis]) * 0.5;
611        ca.partial_cmp(&cb).unwrap_or(std::cmp::Ordering::Equal)
612    });
613
614    let mid = n / 2;
615    let (left_idx, right_idx) = indices.split_at_mut(mid);
616
617    // Reserve a slot for the internal node before recursing.
618    let node_idx = nodes.len() as i32;
619    nodes.push(BvhGpuNode {
620        aabb: merged,
621        left: 0,
622        right: 0,
623    }); // placeholder
624
625    let mut left_indices = left_idx.to_vec();
626    let mut right_indices = right_idx.to_vec();
627
628    let left = build_bvh_recursive(aabbs, &mut left_indices, nodes);
629    let right = build_bvh_recursive(aabbs, &mut right_indices, nodes);
630
631    nodes[node_idx as usize].left = left;
632    nodes[node_idx as usize].right = right;
633
634    node_idx
635}
636
637// ---------------------------------------------------------------------------
638// LBVH — Linear BVH traversal
639// ---------------------------------------------------------------------------
640
641/// Query overlapping leaf pairs in an LBVH by traversal.
642///
643/// Traverses the BVH from the root to find all pairs of leaf nodes whose
644/// AABBs overlap.  Returns `Vec<(u32, u32)>` with `a < b` and deduplication.
645pub fn lbvh_query_pairs(nodes: &[BvhGpuNode]) -> Vec<(u32, u32)> {
646    if nodes.is_empty() {
647        return Vec::new();
648    }
649    let mut pairs = Vec::new();
650    // Stack-based traversal: (nodeA, nodeB)
651    let mut stack: Vec<(usize, usize)> = Vec::new();
652    // Self-collision query: test root against itself
653    stack.push((0, 0));
654
655    while let Some((a_idx, b_idx)) = stack.pop() {
656        let na = &nodes[a_idx];
657        let nb = &nodes[b_idx];
658
659        if !na.aabb.overlaps(&nb.aabb) {
660            continue;
661        }
662
663        if na.is_leaf() && nb.is_leaf() {
664            if a_idx != b_idx {
665                let id_a = na.aabb.body_id;
666                let id_b = nb.aabb.body_id;
667                let pair = if id_a < id_b {
668                    (id_a, id_b)
669                } else {
670                    (id_b, id_a)
671                };
672                pairs.push(pair);
673            }
674            continue;
675        }
676
677        // Descend into larger node first
678        if na.is_leaf() {
679            // descend b
680            if nb.left >= 0 {
681                stack.push((a_idx, nb.left as usize));
682            }
683            if nb.right >= 0 {
684                stack.push((a_idx, nb.right as usize));
685            }
686        } else if nb.is_leaf() {
687            // descend a
688            if na.left >= 0 {
689                stack.push((na.left as usize, b_idx));
690            }
691            if na.right >= 0 {
692                stack.push((na.right as usize, b_idx));
693            }
694        } else {
695            // Both internal — descend a
696            if na.left >= 0 {
697                stack.push((na.left as usize, b_idx));
698            }
699            if na.right >= 0 {
700                stack.push((na.right as usize, b_idx));
701            }
702        }
703    }
704    pairs.sort();
705    pairs.dedup();
706    pairs
707}
708
709// ---------------------------------------------------------------------------
710// BVH refitting
711// ---------------------------------------------------------------------------
712
713/// Refit BVH bounding boxes bottom-up after leaf AABBs have changed.
714///
715/// For each internal node, recomputes its AABB as the union of its children.
716/// Assumes a flat node array where children always have higher indices
717/// (which is guaranteed by `build_bvh_recursive`).
718pub fn refit_bvh(nodes: &mut Vec<BvhGpuNode>) {
719    // Process nodes in reverse order (children before parents)
720    let n = nodes.len();
721    for i in (0..n).rev() {
722        if nodes[i].is_leaf() {
723            continue; // Leaves are driven by actual AABB data — no refit needed here
724        }
725        let left_idx = nodes[i].left;
726        let right_idx = nodes[i].right;
727        if left_idx < 0 || right_idx < 0 {
728            continue;
729        }
730        let l = &nodes[left_idx as usize];
731        let r = &nodes[right_idx as usize];
732        let mut min = l.aabb.min;
733        let mut max = l.aabb.max;
734        for k in 0..3 {
735            if r.aabb.min[k] < min[k] {
736                min[k] = r.aabb.min[k];
737            }
738            if r.aabb.max[k] > max[k] {
739                max[k] = r.aabb.max[k];
740            }
741        }
742        nodes[i].aabb = AabbGpu {
743            min,
744            max,
745            body_id: 0,
746        };
747    }
748}
749
750// ---------------------------------------------------------------------------
751// BVH quality metrics
752// ---------------------------------------------------------------------------
753
754/// Compute the surface area of an `AabbGpu`.
755#[allow(dead_code)]
756pub fn aabb_surface_area(aabb: &AabbGpu) -> f32 {
757    let dx = aabb.max[0] - aabb.min[0];
758    let dy = aabb.max[1] - aabb.min[1];
759    let dz = aabb.max[2] - aabb.min[2];
760    2.0 * (dx * dy + dy * dz + dz * dx)
761}
762
763/// Surface Area Heuristic (SAH) cost of a BVH tree.
764///
765/// `SAH = sum_{internal nodes} SA(node) / SA(root) * cost_traversal`
766///       `+ sum_{leaf nodes} SA(leaf) / SA(root) * num_primitives`
767///
768/// A lower SAH cost indicates a better-quality BVH.
769#[allow(dead_code)]
770pub fn bvh_sah_cost(nodes: &[BvhGpuNode], cost_traversal: f32, cost_primitive: f32) -> f32 {
771    if nodes.is_empty() {
772        return 0.0;
773    }
774    let root_sa = aabb_surface_area(&nodes[0].aabb);
775    if root_sa < 1e-20 {
776        return 0.0;
777    }
778    let mut cost = 0.0f32;
779    for node in nodes {
780        let sa = aabb_surface_area(&node.aabb);
781        if node.is_leaf() {
782            cost += sa / root_sa * cost_primitive;
783        } else {
784            cost += sa / root_sa * cost_traversal;
785        }
786    }
787    cost
788}
789
790/// Compute the depth of the BVH tree.
791///
792/// Returns the maximum node depth from root (depth 0) to the deepest leaf.
793#[allow(dead_code)]
794pub fn bvh_depth(nodes: &[BvhGpuNode]) -> usize {
795    if nodes.is_empty() {
796        return 0;
797    }
798    let mut max_depth = 0usize;
799    // Stack: (node_idx, current_depth)
800    let mut stack: Vec<(usize, usize)> = vec![(0, 0)];
801    while let Some((idx, depth)) = stack.pop() {
802        if depth > max_depth {
803            max_depth = depth;
804        }
805        let node = &nodes[idx];
806        if !node.is_leaf() {
807            if node.left >= 0 {
808                stack.push((node.left as usize, depth + 1));
809            }
810            if node.right >= 0 {
811                stack.push((node.right as usize, depth + 1));
812            }
813        }
814    }
815    max_depth
816}
817
818/// Count the number of leaf nodes in the BVH.
819#[allow(dead_code)]
820pub fn bvh_leaf_count(nodes: &[BvhGpuNode]) -> usize {
821    nodes.iter().filter(|n| n.is_leaf()).count()
822}
823
824// ---------------------------------------------------------------------------
825// Parallel SAP update (incremental update for moved bodies)
826// ---------------------------------------------------------------------------
827
828/// Update the sort-and-sweep pair list after a set of bodies have moved.
829///
830/// This is an incremental update: only re-run SAP for the bodies that moved
831/// and merge with the existing pair list.  All existing pairs involving moved
832/// bodies are removed and recomputed.
833///
834/// `moved_ids`: set of `body_id`s that have moved.
835/// `aabbs`: updated AABB array (all bodies).
836///
837/// Returns the updated pair list.
838#[allow(dead_code)]
839pub fn sap_incremental_update(
840    existing: &CompactPairList,
841    aabbs: &[AabbGpu],
842    moved_ids: &[u32],
843) -> CompactPairList {
844    let mut new_list = existing.clone();
845
846    // Remove all pairs involving moved bodies
847    for &id in moved_ids {
848        new_list.remove_body(id);
849    }
850
851    // Re-detect pairs for moved bodies against all others
852    for &moved_id in moved_ids {
853        if let Some(moved_aabb) = aabbs.iter().find(|a| a.body_id == moved_id) {
854            for other in aabbs {
855                if other.body_id == moved_id {
856                    continue;
857                }
858                if moved_aabb.overlaps(other) {
859                    new_list.insert(moved_id, other.body_id);
860                }
861            }
862        }
863    }
864
865    new_list.sort();
866    new_list
867}
868
869// ---------------------------------------------------------------------------
870// SAH split heuristic
871// ---------------------------------------------------------------------------
872
873/// Surface Area Heuristic (SAH) split for BVH construction.
874///
875/// Evaluates `num_bins` candidate split planes along the given axis and
876/// returns the bin index (from 0 to `num_bins-1`) that minimises the SAH cost.
877///
878/// Returns `None` if all primitives have the same centroid along the axis.
879#[allow(dead_code)]
880pub fn sah_best_split(
881    aabbs: &[AabbGpu],
882    indices: &[usize],
883    axis: usize,
884    num_bins: usize,
885) -> Option<usize> {
886    if indices.len() < 2 || num_bins < 2 {
887        return None;
888    }
889
890    // Centroid range
891    let min_c = indices
892        .iter()
893        .map(|&i| 0.5 * (aabbs[i].min[axis] + aabbs[i].max[axis]))
894        .fold(f32::INFINITY, f32::min);
895    let max_c = indices
896        .iter()
897        .map(|&i| 0.5 * (aabbs[i].min[axis] + aabbs[i].max[axis]))
898        .fold(f32::NEG_INFINITY, f32::max);
899
900    if (max_c - min_c).abs() < 1e-10 {
901        return None;
902    }
903
904    let bin_width = (max_c - min_c) / num_bins as f32;
905    let mut bin_counts = vec![0usize; num_bins];
906    let mut bin_aabbs: Vec<Option<AabbGpu>> = vec![None; num_bins];
907
908    for &i in indices {
909        let c = 0.5 * (aabbs[i].min[axis] + aabbs[i].max[axis]);
910        let bin = ((c - min_c) / bin_width).floor() as usize;
911        let bin = bin.min(num_bins - 1);
912        bin_counts[bin] += 1;
913        bin_aabbs[bin] = Some(match &bin_aabbs[bin] {
914            None => aabbs[i],
915            Some(prev) => {
916                let mut merged = *prev;
917                for k in 0..3 {
918                    if aabbs[i].min[k] < merged.min[k] {
919                        merged.min[k] = aabbs[i].min[k];
920                    }
921                    if aabbs[i].max[k] > merged.max[k] {
922                        merged.max[k] = aabbs[i].max[k];
923                    }
924                }
925                merged
926            }
927        });
928    }
929
930    let mut best_cost = f32::INFINITY;
931    let mut best_split = None;
932
933    for split in 1..num_bins {
934        let left_count: usize = bin_counts[..split].iter().sum();
935        let right_count: usize = bin_counts[split..].iter().sum();
936        if left_count == 0 || right_count == 0 {
937            continue;
938        }
939
940        // Compute left and right bounding boxes
941        let left_sa = bin_aabbs[..split]
942            .iter()
943            .flatten()
944            .fold(None::<AabbGpu>, |acc, a| {
945                Some(match acc {
946                    None => *a,
947                    Some(prev) => {
948                        let mut m = prev;
949                        for k in 0..3 {
950                            if a.min[k] < m.min[k] {
951                                m.min[k] = a.min[k];
952                            }
953                            if a.max[k] > m.max[k] {
954                                m.max[k] = a.max[k];
955                            }
956                        }
957                        m
958                    }
959                })
960            })
961            .map(|a| aabb_surface_area(&a))
962            .unwrap_or(0.0);
963
964        let right_sa = bin_aabbs[split..]
965            .iter()
966            .flatten()
967            .fold(None::<AabbGpu>, |acc, a| {
968                Some(match acc {
969                    None => *a,
970                    Some(prev) => {
971                        let mut m = prev;
972                        for k in 0..3 {
973                            if a.min[k] < m.min[k] {
974                                m.min[k] = a.min[k];
975                            }
976                            if a.max[k] > m.max[k] {
977                                m.max[k] = a.max[k];
978                            }
979                        }
980                        m
981                    }
982                })
983            })
984            .map(|a| aabb_surface_area(&a))
985            .unwrap_or(0.0);
986
987        let cost = left_sa * left_count as f32 + right_sa * right_count as f32;
988        if cost < best_cost {
989            best_cost = cost;
990            best_split = Some(split);
991        }
992    }
993
994    best_split
995}
996
997// ---------------------------------------------------------------------------
998// LBVH build using Morton codes
999// ---------------------------------------------------------------------------
1000
1001/// Build a Linear BVH (LBVH) by sorting AABBs on their Morton codes.
1002///
1003/// This is a GPU-style construction algorithm:
1004/// 1. Compute Morton code for each AABB centroid.
1005/// 2. Sort AABBs by Morton code.
1006/// 3. Recursively build BVH on sorted order (binary radix tree).
1007///
1008/// Returns a flat node array with node 0 as the root.
1009#[allow(dead_code)]
1010pub fn build_lbvh(aabbs: &[AabbGpu], cell_size: f32, origin: [f32; 3]) -> Vec<BvhGpuNode> {
1011    if aabbs.is_empty() {
1012        return Vec::new();
1013    }
1014    // Sort by Morton code
1015    let sorted = morton_sort(aabbs, cell_size, origin);
1016    // Build BVH on sorted order
1017    build_bvh(&sorted)
1018}
1019
1020// ---------------------------------------------------------------------------
1021// Tests
1022// ---------------------------------------------------------------------------
1023
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn aabb_overlap_detects_overlapping_boxes() {
        // Two boxes that clearly overlap
        #[rustfmt::skip]
        let aabbs: Vec<f64> = vec![
            0.0, 2.0, 0.0, 2.0, 0.0, 2.0, // box 0
            1.0, 3.0, 1.0, 3.0, 1.0, 3.0, // box 1 (overlaps box 0)
        ];
        let mut outputs = vec![Vec::new()];
        AabbOverlapKernel.execute(&[&aabbs], &mut outputs, 2);
        assert_eq!(outputs[0], vec![0.0, 1.0]);
    }

    #[test]
    fn aabb_overlap_rejects_non_overlapping_boxes() {
        #[rustfmt::skip]
        let aabbs: Vec<f64> = vec![
            0.0, 1.0, 0.0, 1.0, 0.0, 1.0, // box 0
            5.0, 6.0, 5.0, 6.0, 5.0, 6.0, // box 1 (far away)
        ];
        let mut outputs = vec![Vec::new()];
        AabbOverlapKernel.execute(&[&aabbs], &mut outputs, 2);
        assert!(outputs[0].is_empty());
    }

    #[test]
    fn test_broadphase_gpu_matches_cpu() {
        // Three AABBs: 0 overlaps 1, 1 overlaps 2, but 0 does NOT overlap 2.
        #[rustfmt::skip]
        let aabbs: Vec<f64> = vec![
            0.0, 1.5, 0.0, 1.5, 0.0, 1.5, // box 0
            1.0, 2.5, 0.0, 1.5, 0.0, 1.5, // box 1 (overlaps 0 in x)
            3.0, 4.0, 0.0, 1.5, 0.0, 1.5, // box 2 (overlaps 1 in x, not 0)
        ];

        // GPU kernel output
        let mut gpu_outputs = vec![Vec::new()];
        AabbOverlapKernel.execute(&[&aabbs], &mut gpu_outputs, 3);

        // CPU brute-force reference
        let n = aabbs.len() / 6;
        let mut cpu_pairs: Vec<(usize, usize)> = Vec::new();
        for i in 0..n {
            for j in (i + 1)..n {
                let a = &aabbs[i * 6..(i + 1) * 6];
                let b = &aabbs[j * 6..(j + 1) * 6];
                let overlaps = a[0] <= b[1]
                    && a[1] >= b[0]
                    && a[2] <= b[3]
                    && a[3] >= b[2]
                    && a[4] <= b[5]
                    && a[5] >= b[4];
                if overlaps {
                    cpu_pairs.push((i, j));
                }
            }
        }

        // Convert GPU output to (usize, usize) pairs
        let raw = &gpu_outputs[0];
        assert_eq!(raw.len() % 2, 0, "GPU output length must be even");
        let gpu_pairs: Vec<(usize, usize)> = raw
            .chunks(2)
            .map(|c| (c[0] as usize, c[1] as usize))
            .collect();

        assert_eq!(
            gpu_pairs, cpu_pairs,
            "GPU broadphase pairs do not match CPU brute-force pairs"
        );
    }

    /// Morton code: (0,0,0) → 0, and each axis occupies its own bit lanes.
    #[test]
    fn test_morton_code_correctness() {
        assert_eq!(morton_code(0, 0, 0), 0);
        // x=1: only bit 0 of x set → Morton bit 0 (axis 0)
        assert_eq!(morton_code(1, 0, 0), 1);
        // y=1: Morton bit 1
        assert_eq!(morton_code(0, 1, 0), 2);
        // z=1: Morton bit 2
        assert_eq!(morton_code(0, 0, 1), 4);
        // x=1, y=1, z=1 → bits 0,1,2 all set
        assert_eq!(morton_code(1, 1, 1), 7);
    }

    /// SAP detects overlapping GPU AABBs.
    #[test]
    fn test_sap_finds_overlap() {
        let a = AabbGpu::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0], 0);
        let b = AabbGpu::new([1.0, 1.0, 1.0], [3.0, 3.0, 3.0], 1);
        let c = AabbGpu::new([10.0, 10.0, 10.0], [12.0, 12.0, 12.0], 2);

        let pairs = SortAndSweepGpu::detect_pairs(&[a, b, c]);
        assert!(
            pairs.contains(&(0, 1)),
            "SAP should find pair (0,1), got {pairs:?}"
        );
        assert!(!pairs.contains(&(0, 2)), "SAP should not find pair (0,2)");
        assert!(!pairs.contains(&(1, 2)), "SAP should not find pair (1,2)");
    }

    /// SAP returns empty for non-overlapping AABBs.
    #[test]
    fn test_sap_no_overlap() {
        let a = AabbGpu::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], 0);
        let b = AabbGpu::new([5.0, 5.0, 5.0], [6.0, 6.0, 6.0], 1);
        let pairs = SortAndSweepGpu::detect_pairs(&[a, b]);
        assert!(
            pairs.is_empty(),
            "Should be no overlapping pairs, got {pairs:?}"
        );
    }

    /// Uniform grid detects the same pairs as brute-force for 3 AABBs.
    #[test]
    fn test_uniform_grid_pair_detection() {
        let a = AabbGpu::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0], 0);
        let b = AabbGpu::new([1.5, 1.5, 1.5], [3.5, 3.5, 3.5], 1);
        let c = AabbGpu::new([10.0, 10.0, 10.0], [12.0, 12.0, 12.0], 2);

        let grid = UniformGridGpu::new(5.0, [0.0, 0.0, 0.0], [10, 10, 10]);
        let pairs = grid.query_pairs(&[a, b, c]);

        // Only a and b overlap.
        assert!(
            pairs.contains(&(0, 1)),
            "Grid should find pair (0,1), got {pairs:?}"
        );
        assert!(
            !pairs
                .iter()
                .any(|&(x, y)| x == 0 && y == 2 || x == 2 && y == 0)
        );
    }

    /// BVH builds without panic and covers all body AABBs.
    #[test]
    fn test_bvh_depth() {
        let aabbs: Vec<AabbGpu> = (0..8)
            .map(|i| {
                let x = (i * 3) as f32;
                AabbGpu::new([x, 0.0, 0.0], [x + 1.0, 1.0, 1.0], i)
            })
            .collect();

        let nodes = build_bvh(&aabbs);

        // There should be at least 1 node, and at most 2*n-1 nodes.
        assert!(!nodes.is_empty(), "BVH should have at least one node");
        assert!(
            nodes.len() <= 2 * aabbs.len(),
            "BVH node count unexpected: {}",
            nodes.len()
        );

        // The root node (index 0) should cover the entire range.
        let root = &nodes[0];
        assert!(root.aabb.min[0] <= 0.0 + 1e-5, "Root min_x too large");
        assert!(
            root.aabb.max[0] >= 22.0 - 1e-5,
            "Root max_x too small: {}",
            root.aabb.max[0]
        );
    }

    /// AabbGpu::overlaps is symmetric.
    #[test]
    fn test_aabb_gpu_overlap_symmetric() {
        let a = AabbGpu::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0], 0);
        let b = AabbGpu::new([1.0, 1.0, 1.0], [3.0, 3.0, 3.0], 1);
        assert_eq!(a.overlaps(&b), b.overlaps(&a));

        let c = AabbGpu::new([5.0, 5.0, 5.0], [6.0, 6.0, 6.0], 2);
        assert_eq!(a.overlaps(&c), c.overlaps(&a));
        assert!(!a.overlaps(&c));
    }

    // ── New expanded tests ──

    #[allow(clippy::nonminimal_bool)]
    #[test]
    fn test_sort_and_sweep_flat_finds_pair() {
        #[rustfmt::skip]
        let aabbs: Vec<f64> = vec![
            0.0, 2.0, 0.0, 2.0, 0.0, 2.0, // box 0
            1.0, 3.0, 1.0, 3.0, 1.0, 3.0, // box 1 overlaps 0
            5.0, 6.0, 5.0, 6.0, 5.0, 6.0, // box 2 no overlap
        ];
        let pairs = sort_and_sweep_flat(&aabbs);
        assert!(pairs.contains(&(0, 1)), "should find (0,1)");
        assert!(
            !pairs
                .iter()
                .any(|&(a, b)| a == 0 && b == 2 || a == 1 && b == 2)
        );
    }

    #[test]
    fn test_sort_and_sweep_flat_empty() {
        let pairs = sort_and_sweep_flat(&[]);
        assert!(pairs.is_empty());
    }

    #[test]
    fn test_sort_and_sweep_flat_no_overlap() {
        #[rustfmt::skip]
        let aabbs: Vec<f64> = vec![
            0.0, 1.0, 0.0, 1.0, 0.0, 1.0,
            2.0, 3.0, 2.0, 3.0, 2.0, 3.0,
        ];
        let pairs = sort_and_sweep_flat(&aabbs);
        assert!(pairs.is_empty());
    }

    #[test]
    fn test_assign_to_grid_cells() {
        #[rustfmt::skip]
        let aabbs: Vec<f64> = vec![
            0.0, 1.0, 0.0, 1.0, 0.0, 1.0, // fits in cell (0,0,0)
        ];
        let cells = assign_to_grid_cells(&aabbs, 2.0, [0.0, 0.0, 0.0]);
        assert!(!cells.is_empty());
        // All entries should have body index 0
        assert!(cells.iter().all(|&(_, idx)| idx == 0));
    }

    #[test]
    fn test_pairs_from_grid_assignments() {
        // Two bodies both assigned to the same cell
        let assignments = vec![(0u64, 0usize), (0u64, 1usize)];
        let pairs = pairs_from_grid_assignments(&assignments);
        assert!(pairs.contains(&(0, 1)));
    }

    #[test]
    fn test_pairs_from_grid_assignments_no_dup() {
        // Same pair in multiple cells
        let assignments = vec![
            (0u64, 0usize),
            (0u64, 1usize),
            (1u64, 0usize),
            (1u64, 1usize),
        ];
        let pairs = pairs_from_grid_assignments(&assignments);
        // Pair (0,1) should appear exactly once
        let count = pairs.iter().filter(|&&p| p == (0, 1)).count();
        assert_eq!(count, 1);
    }

    #[test]
    fn test_morton_sort_orders_aabbs() {
        let aabbs: Vec<AabbGpu> = vec![
            AabbGpu::new([4.0, 0.0, 0.0], [5.0, 1.0, 1.0], 0),
            AabbGpu::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], 1),
            AabbGpu::new([2.0, 2.0, 2.0], [3.0, 3.0, 3.0], 2),
        ];
        let sorted = morton_sort(&aabbs, 1.0, [0.0, 0.0, 0.0]);
        // Body 1 (centroid near origin) should come first in Z-order
        assert_eq!(
            sorted[0].body_id, 1,
            "body at origin should be first in Morton order"
        );
    }

    #[test]
    fn test_morton_key_reproducible() {
        let a = AabbGpu::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], 0);
        let k1 = morton_key_for_aabb(&a, 1.0, [0.0, 0.0, 0.0]);
        let k2 = morton_key_for_aabb(&a, 1.0, [0.0, 0.0, 0.0]);
        assert_eq!(k1, k2);
    }

    #[test]
    fn test_compact_pair_list_insert_dedup() {
        let mut list = CompactPairList::new();
        list.insert(0, 1);
        list.insert(1, 0); // same pair, reversed
        list.insert(0, 1); // duplicate
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_compact_pair_list_contains() {
        let mut list = CompactPairList::new();
        list.insert(2, 5);
        assert!(list.contains(2, 5));
        assert!(list.contains(5, 2));
        assert!(!list.contains(0, 1));
    }

    #[test]
    fn test_compact_pair_list_remove_body() {
        let mut list = CompactPairList::new();
        list.insert(0, 1);
        list.insert(0, 2);
        list.insert(1, 2);
        list.remove_body(0);
        assert!(!list.contains(0, 1));
        assert!(!list.contains(0, 2));
        assert!(list.contains(1, 2));
    }

    #[test]
    fn test_compact_pair_list_insert_all() {
        let mut list = CompactPairList::new();
        list.insert_all(&[(0, 1), (1, 2), (0, 2)]);
        assert_eq!(list.len(), 3);
    }

    #[test]
    fn test_compact_pair_list_sort() {
        let mut list = CompactPairList::new();
        list.insert(3, 4);
        list.insert(0, 1);
        list.insert(1, 2);
        list.sort();
        assert_eq!(list.pairs()[0], (0, 1));
    }

    // ── LBVH traversal tests ─────────────────────────────────────────────────

    #[test]
    fn test_lbvh_query_pairs_finds_overlap() {
        // Build BVH for two overlapping boxes
        let a = AabbGpu::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0], 0);
        let b = AabbGpu::new([1.0, 1.0, 1.0], [3.0, 3.0, 3.0], 1);
        let nodes = build_bvh(&[a, b]);
        let pairs = lbvh_query_pairs(&nodes);
        assert!(
            pairs.contains(&(0, 1)),
            "LBVH traversal should find (0,1): {pairs:?}"
        );
    }

    #[test]
    fn test_lbvh_query_pairs_no_overlap() {
        let a = AabbGpu::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], 0);
        let b = AabbGpu::new([5.0, 5.0, 5.0], [6.0, 6.0, 6.0], 1);
        let nodes = build_bvh(&[a, b]);
        let pairs = lbvh_query_pairs(&nodes);
        assert!(
            pairs.is_empty(),
            "Non-overlapping: should have no pairs: {pairs:?}"
        );
    }

    #[test]
    fn test_lbvh_query_empty_bvh() {
        let pairs = lbvh_query_pairs(&[]);
        assert!(pairs.is_empty());
    }

    // ── BVH refitting tests ──────────────────────────────────────────────────

    #[test]
    fn test_refit_bvh_no_panic() {
        let aabbs: Vec<AabbGpu> = (0..4)
            .map(|i| {
                let x = (i * 2) as f32;
                AabbGpu::new([x, 0.0, 0.0], [x + 1.0, 1.0, 1.0], i)
            })
            .collect();
        let mut nodes = build_bvh(&aabbs);
        // Simulate moving a leaf: every leaf grows by 0.5 along +x, so the
        // right-most leaf now ends at 6 + 1 + 0.5 = 7.5.
        for node in nodes.iter_mut() {
            if node.is_leaf() {
                node.aabb.max[0] += 0.5;
            }
        }
        refit_bvh(&mut nodes);
        // After refitting, root should still cover all children
        assert!(!nodes.is_empty());
        assert!(
            nodes[0].aabb.max[0] >= 7.5 - 1e-5,
            "Root max_x not refitted: {}",
            nodes[0].aabb.max[0]
        );
    }

    #[test]
    fn test_refit_bvh_root_encompasses_leaves() {
        let aabbs: Vec<AabbGpu> = vec![
            AabbGpu::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], 0),
            AabbGpu::new([10.0, 0.0, 0.0], [11.0, 1.0, 1.0], 1),
        ];
        let mut nodes = build_bvh(&aabbs);
        refit_bvh(&mut nodes);
        // Root AABB should span at least [0..11]
        assert!(nodes[0].aabb.min[0] <= 0.0 + 1e-5);
        assert!(nodes[0].aabb.max[0] >= 11.0 - 1e-5);
    }

    // ── BVH quality metrics tests ────────────────────────────────────────────

    #[test]
    fn test_aabb_surface_area() {
        let a = AabbGpu::new([0.0, 0.0, 0.0], [1.0, 2.0, 3.0], 0);
        let sa = aabb_surface_area(&a);
        // SA = 2*(1*2 + 2*3 + 3*1) = 2*(2+6+3) = 22
        assert!((sa - 22.0).abs() < 1e-5, "SA = {sa}");
    }

    #[test]
    fn test_aabb_surface_area_unit_cube() {
        let a = AabbGpu::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], 0);
        let sa = aabb_surface_area(&a);
        assert!((sa - 6.0).abs() < 1e-5, "unit cube SA = {sa}");
    }

    #[test]
    fn test_bvh_sah_cost_positive() {
        let aabbs: Vec<AabbGpu> = (0..4)
            .map(|i| {
                let x = (i * 3) as f32;
                AabbGpu::new([x, 0.0, 0.0], [x + 2.0, 2.0, 2.0], i)
            })
            .collect();
        let nodes = build_bvh(&aabbs);
        let cost = bvh_sah_cost(&nodes, 1.0, 1.0);
        assert!(cost > 0.0, "SAH cost should be positive: {cost}");
        assert!(cost.is_finite(), "SAH cost should be finite");
    }

    #[test]
    fn test_bvh_depth_single_leaf() {
        let aabbs = vec![AabbGpu::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], 0)];
        let nodes = build_bvh(&aabbs);
        let d = bvh_depth(&nodes);
        assert_eq!(d, 0, "single leaf depth = {d}");
    }

    #[test]
    fn test_bvh_depth_two_leaves() {
        let aabbs = vec![
            AabbGpu::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], 0),
            AabbGpu::new([2.0, 0.0, 0.0], [3.0, 1.0, 1.0], 1),
        ];
        let nodes = build_bvh(&aabbs);
        let d = bvh_depth(&nodes);
        assert!(d >= 1, "two-leaf BVH depth >= 1, got {d}");
    }

    #[test]
    fn test_bvh_leaf_count() {
        let aabbs: Vec<AabbGpu> = (0..8)
            .map(|i| {
                let x = (i * 2) as f32;
                AabbGpu::new([x, 0.0, 0.0], [x + 1.0, 1.0, 1.0], i)
            })
            .collect();
        let nodes = build_bvh(&aabbs);
        let leaves = bvh_leaf_count(&nodes);
        assert_eq!(leaves, 8, "Expected 8 leaves, got {leaves}");
    }

    // ── SAP incremental update tests ─────────────────────────────────────────

    #[test]
    fn test_sap_incremental_removes_moved() {
        let a = AabbGpu::new([0.0, 0.0, 0.0], [2.0, 2.0, 2.0], 0);
        let b = AabbGpu::new([1.0, 1.0, 1.0], [3.0, 3.0, 3.0], 1);
        let c = AabbGpu::new([10.0, 10.0, 10.0], [11.0, 11.0, 11.0], 2);
        let initial_pairs = SortAndSweepGpu::detect_pairs(&[a, b, c]);
        let mut existing = CompactPairList::new();
        existing.insert_all(&initial_pairs);

        // Move body 0 far away
        let a_moved = AabbGpu::new([20.0, 20.0, 20.0], [21.0, 21.0, 21.0], 0);
        let updated = sap_incremental_update(&existing, &[a_moved, b, c], &[0]);

        // After move, body 0 overlaps nothing
        assert!(
            !updated.contains(0, 1),
            "pair (0,1) should be removed after move"
        );
    }

    #[test]
    fn test_sap_incremental_adds_new_overlap() {
        let _a = AabbGpu::new([10.0, 0.0, 0.0], [11.0, 1.0, 1.0], 0);
        let b = AabbGpu::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], 1);
        let existing = CompactPairList::new(); // no initial pairs

        // Move body 0 to overlap body 1
        let a_moved = AabbGpu::new([0.5, 0.5, 0.5], [1.5, 1.5, 1.5], 0);
        let updated = sap_incremental_update(&existing, &[a_moved, b], &[0]);
        assert!(
            updated.contains(0, 1),
            "pair (0,1) should be added after move: {:?}",
            updated.pairs()
        );
    }

    // ── SAH split tests ──────────────────────────────────────────────────────

    #[test]
    fn test_sah_best_split_basic() {
        let aabbs: Vec<AabbGpu> = (0..8)
            .map(|i| {
                let x = (i * 2) as f32;
                AabbGpu::new([x, 0.0, 0.0], [x + 1.5, 1.0, 1.0], i)
            })
            .collect();
        let indices: Vec<usize> = (0..8).collect();
        let result = sah_best_split(&aabbs, &indices, 0, 8);
        assert!(result.is_some(), "SAH split should find a valid split");
        // A split is a plane between bins, so it lies in 1..num_bins.
        assert!(
            matches!(result, Some(s) if (1..8).contains(&s)),
            "split out of range: {result:?}"
        );
    }

    #[test]
    fn test_sah_best_split_single_element() {
        let aabbs = vec![AabbGpu::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], 0)];
        let result = sah_best_split(&aabbs, &[0], 0, 4);
        assert!(result.is_none(), "Single element: no split possible");
    }

    #[test]
    fn test_sah_best_split_picks_cheapest_plane() {
        // Centroids 0, 1, 10 with 10 bins of width 1 fall into bins 0, 1, 9.
        // split 1: SA 6 * 1 + SA 42 * 2 = 90
        // splits 2..=9: SA 10 * 2 + SA 6 * 1 = 26 -> the first of them wins.
        let aabbs = vec![
            AabbGpu::new([-0.5, 0.0, 0.0], [0.5, 1.0, 1.0], 0),
            AabbGpu::new([0.5, 0.0, 0.0], [1.5, 1.0, 1.0], 1),
            AabbGpu::new([9.5, 0.0, 0.0], [10.5, 1.0, 1.0], 2),
        ];
        let result = sah_best_split(&aabbs, &[0, 1, 2], 0, 10);
        assert_eq!(result, Some(2));
    }

    #[test]
    fn test_sah_best_split_ties_pick_first_plane() {
        // Centroids 0 and 10 land in bins 0 and 4; every plane costs the same.
        let aabbs = vec![
            AabbGpu::new([-0.5, 0.0, 0.0], [0.5, 1.0, 1.0], 0),
            AabbGpu::new([9.5, 0.0, 0.0], [10.5, 1.0, 1.0], 1),
        ];
        assert_eq!(sah_best_split(&aabbs, &[0, 1], 0, 5), Some(1));
    }

    #[test]
    fn test_sah_best_split_degenerate_inputs() {
        // Different extents but the same x-centroid (1.0).
        let aabbs = vec![
            AabbGpu::new([0.0, 0.0, 0.0], [2.0, 1.0, 1.0], 0),
            AabbGpu::new([0.5, 0.0, 0.0], [1.5, 1.0, 1.0], 1),
            AabbGpu::new([5.0, 0.0, 0.0], [6.0, 1.0, 1.0], 2),
        ];
        assert!(sah_best_split(&aabbs, &[0, 1], 0, 4).is_none());
        // Fewer than two bins leaves no candidate plane.
        assert!(sah_best_split(&aabbs, &[0, 2], 0, 1).is_none());
    }

    // ── LBVH build tests ─────────────────────────────────────────────────────

    #[test]
    fn test_build_lbvh_nonempty() {
        let aabbs: Vec<AabbGpu> = (0..6)
            .map(|i| {
                let x = (i * 2) as f32;
                AabbGpu::new([x, 0.0, 0.0], [x + 1.0, 1.0, 1.0], i)
            })
            .collect();
        let nodes = build_lbvh(&aabbs, 1.0, [0.0, 0.0, 0.0]);
        assert!(!nodes.is_empty(), "LBVH should build non-empty node list");
        assert!(nodes.len() <= 2 * aabbs.len());
    }

    #[test]
    fn test_build_lbvh_empty() {
        let nodes = build_lbvh(&[], 1.0, [0.0, 0.0, 0.0]);
        assert!(nodes.is_empty());
    }

    #[test]
    fn test_build_lbvh_single() {
        let aabbs = vec![AabbGpu::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0], 0)];
        let nodes = build_lbvh(&aabbs, 1.0, [0.0, 0.0, 0.0]);
        assert_eq!(nodes.len(), 1);
        assert!(nodes[0].is_leaf());
    }

    #[test]
    fn test_lbvh_vs_brute_force_pairs() {
        // LBVH traversal should find same pairs as brute-force overlap check
        let aabbs: Vec<AabbGpu> = (0..6)
            .map(|i| {
                let x = (i as f32) * 0.8;
                AabbGpu::new([x, 0.0, 0.0], [x + 1.0, 1.0, 1.0], i as u32)
            })
            .collect();

        let lbvh_nodes = build_lbvh(&aabbs, 0.5, [0.0, 0.0, 0.0]);
        let mut lbvh_pairs = lbvh_query_pairs(&lbvh_nodes);
        lbvh_pairs.sort();
        lbvh_pairs.dedup();

        let mut brute: Vec<(u32, u32)> = Vec::new();
        for i in 0..aabbs.len() {
            for j in (i + 1)..aabbs.len() {
                if aabbs[i].overlaps(&aabbs[j]) {
                    brute.push((aabbs[i].body_id, aabbs[j].body_id));
                }
            }
        }
        brute.sort();

        assert_eq!(
            lbvh_pairs, brute,
            "LBVH pairs {lbvh_pairs:?} != brute force {brute:?}"
        );
    }
}