Skip to main content

oxiphysics_gpu/
gpu_collision_detection.rs

1// Copyright 2026 COOLJAPAN OU (Team KitaSan)
2// SPDX-License-Identifier: Apache-2.0
3
4//! GPU-accelerated collision detection (CPU mock backend via Rayon).
5//!
6//! Implements a broadphase AABB sweep-and-prune followed by a narrowphase
7//! sphere-sphere test and a stub GJK dispatcher.  The "GPU" dispatch is mocked
8//! using Rayon parallel iterators so the module compiles and runs on CPU.
9
10use rayon::prelude::*;
11
12// ---------------------------------------------------------------------------
13// Aabb
14// ---------------------------------------------------------------------------
15
/// An axis-aligned bounding box described by its two extreme corners.
///
/// `min` holds the lower x/y/z coordinates and `max` the upper ones. No
/// validation is performed, so callers are responsible for `min <= max`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Aabb {
    /// Lower corner `[x, y, z]`.
    pub min: [f64; 3],
    /// Upper corner `[x, y, z]`.
    pub max: [f64; 3],
}

impl Aabb {
    /// Builds a box directly from its lower (`min`) and upper (`max`) corners.
    pub fn new(min: [f64; 3], max: [f64; 3]) -> Self {
        Aabb { min, max }
    }

    /// Builds a box centred on `centre` that extends `half[k]` either side of
    /// the centre along each axis `k`.
    pub fn from_center_half(centre: [f64; 3], half: [f64; 3]) -> Self {
        let min = std::array::from_fn(|axis| centre[axis] - half[axis]);
        let max = std::array::from_fn(|axis| centre[axis] + half[axis]);
        Aabb { min, max }
    }

    /// Midpoint of the box: the per-axis average of `min` and `max`.
    pub fn centre(&self) -> [f64; 3] {
        std::array::from_fn(|axis| (self.min[axis] + self.max[axis]) * 0.5)
    }

    /// Reports whether this box and `other` intersect.
    ///
    /// The test is inclusive: boxes that merely touch on a face, edge or
    /// corner count as overlapping.
    pub fn overlaps(&self, other: &Aabb) -> bool {
        (0..3).all(|axis| {
            self.min[axis] <= other.max[axis] && self.max[axis] >= other.min[axis]
        })
    }
}
66
67// ---------------------------------------------------------------------------
68// GpuCollisionBuffer
69// ---------------------------------------------------------------------------
70
71/// Buffer holding the AABB list and broadphase potential pairs.
72#[derive(Debug, Clone)]
73pub struct GpuCollisionBuffer {
74    /// Axis-aligned bounding boxes for each object.
75    pub aabbs: Vec<Aabb>,
76    /// Potential collision pairs `(i, j)` with `i < j` from the broadphase.
77    pub potential_pairs: Vec<(usize, usize)>,
78}
79
80impl GpuCollisionBuffer {
81    /// Create an empty collision buffer.
82    pub fn new() -> Self {
83        Self {
84            aabbs: Vec::new(),
85            potential_pairs: Vec::new(),
86        }
87    }
88
89    /// Add an AABB to the buffer, returning its index.
90    pub fn add_aabb(&mut self, aabb: Aabb) -> usize {
91        let idx = self.aabbs.len();
92        self.aabbs.push(aabb);
93        idx
94    }
95
96    /// Number of objects in the buffer.
97    pub fn n_objects(&self) -> usize {
98        self.aabbs.len()
99    }
100}
101
102impl Default for GpuCollisionBuffer {
103    fn default() -> Self {
104        Self::new()
105    }
106}
107
108// ---------------------------------------------------------------------------
109// gpu_broadphase_sort
110// ---------------------------------------------------------------------------
111
112/// Sort AABBs by their minimum x-coordinate (mock GPU sort).
113///
114/// Returns a permutation index array such that `perm[0]` is the index of the
115/// AABB with the smallest `min.x`.  The buffer itself is not reordered.
116///
117/// # Arguments
118/// * `buffer` - The collision buffer to sort.
119pub fn gpu_broadphase_sort(buffer: &GpuCollisionBuffer) -> Vec<usize> {
120    let mut indices: Vec<usize> = (0..buffer.aabbs.len()).collect();
121    indices.sort_unstable_by(|&a, &b| {
122        buffer.aabbs[a].min[0]
123            .partial_cmp(&buffer.aabbs[b].min[0])
124            .unwrap_or(std::cmp::Ordering::Equal)
125    });
126    indices
127}
128
129// ---------------------------------------------------------------------------
130// gpu_aabb_overlap
131// ---------------------------------------------------------------------------
132
133/// AABB-AABB overlap test kernel (mock GPU dispatch).
134///
135/// Tests all pairs `(i, j)` with `i < j` in parallel and returns those that
136/// overlap.
137///
138/// # Arguments
139/// * `buffer` - Collision buffer with the AABB list.
140pub fn gpu_aabb_overlap(buffer: &GpuCollisionBuffer) -> Vec<(usize, usize)> {
141    let n = buffer.aabbs.len();
142    if n < 2 {
143        return Vec::new();
144    }
145
146    // Generate all pairs and test in parallel
147    let pairs: Vec<(usize, usize)> = (0..n)
148        .into_par_iter()
149        .flat_map(|i| {
150            let aabb_i = &buffer.aabbs[i];
151            (i + 1..n)
152                .filter(|&j| aabb_i.overlaps(&buffer.aabbs[j]))
153                .map(|j| (i, j))
154                .collect::<Vec<_>>()
155        })
156        .collect();
157    pairs
158}
159
160// ---------------------------------------------------------------------------
161// gpu_sphere_collision
162// ---------------------------------------------------------------------------
163
/// A single overlap reported by the sphere-sphere narrowphase.
#[derive(Clone, Copy, Debug)]
pub struct SphereContact {
    /// Object id of the first sphere in the pair.
    pub i: usize,
    /// Object id of the second sphere in the pair.
    pub j: usize,
    /// How far the two spheres interpenetrate; positive while they overlap.
    pub depth: f64,
    /// Contact normal pointing from sphere `i` towards sphere `j`.
    pub normal: [f64; 3],
}
176
/// A sphere given by its centre point and radius.
#[derive(Clone, Copy, Debug)]
pub struct Sphere {
    /// Centre point `[x, y, z]`.
    pub centre: [f64; 3],
    /// Radius of the sphere (stored as given, not validated).
    pub radius: f64,
}

impl Sphere {
    /// Builds a sphere from its `centre` and `radius`, storing both unchanged.
    pub fn new(centre: [f64; 3], radius: f64) -> Self {
        Sphere { radius, centre }
    }
}
192
193/// Test sphere-sphere collisions for the given pairs (mock GPU narrowphase).
194///
195/// For each pair `(i, j)` in `pairs`, tests whether spheres `i` and `j`
196/// overlap and, if so, computes a [`SphereContact`].
197///
198/// # Arguments
199/// * `spheres` - Sphere descriptors, indexed by object id.
200/// * `pairs` - Candidate pairs from the broadphase.
201pub fn gpu_sphere_collision(spheres: &[Sphere], pairs: &[(usize, usize)]) -> Vec<SphereContact> {
202    pairs
203        .par_iter()
204        .filter_map(|&(i, j)| {
205            if i >= spheres.len() || j >= spheres.len() {
206                return None;
207            }
208            let a = &spheres[i];
209            let b = &spheres[j];
210            let dx = b.centre[0] - a.centre[0];
211            let dy = b.centre[1] - a.centre[1];
212            let dz = b.centre[2] - a.centre[2];
213            let dist_sq = dx * dx + dy * dy + dz * dz;
214            let r_sum = a.radius + b.radius;
215            if dist_sq >= r_sum * r_sum {
216                return None;
217            }
218            let dist = dist_sq.sqrt().max(1e-15);
219            let depth = r_sum - dist;
220            let normal = [dx / dist, dy / dist, dz / dist];
221            Some(SphereContact {
222                i,
223                j,
224                depth,
225                normal,
226            })
227        })
228        .collect()
229}
230
231// ---------------------------------------------------------------------------
232// build_collision_pairs
233// ---------------------------------------------------------------------------
234
235/// Build the full broadphase collision pair list from a sorted AABB list.
236///
237/// Uses a sweep-and-prune approach along the x-axis: after sorting by `min.x`,
238/// only pairs where the x-extents overlap need to be checked further.
239///
240/// Updates `buffer.potential_pairs` in place and returns the count of pairs.
241///
242/// # Arguments
243/// * `buffer` - Collision buffer (sorted order applied internally).
244pub fn build_collision_pairs(buffer: &mut GpuCollisionBuffer) -> usize {
245    let sorted = gpu_broadphase_sort(buffer);
246    let mut pairs = Vec::new();
247
248    for (si, &i) in sorted.iter().enumerate() {
249        let aabb_i = &buffer.aabbs[i];
250        // Only check objects whose min.x <= max.x of i (sweep prune)
251        for &j in sorted[si + 1..].iter() {
252            let aabb_j = &buffer.aabbs[j];
253            if aabb_j.min[0] > aabb_i.max[0] {
254                break;
255            }
256            if aabb_i.overlaps(aabb_j) {
257                let pair = if i < j { (i, j) } else { (j, i) };
258                pairs.push(pair);
259            }
260        }
261    }
262
263    pairs.sort_unstable();
264    pairs.dedup();
265    buffer.potential_pairs = pairs;
266    buffer.potential_pairs.len()
267}
268
269// ---------------------------------------------------------------------------
270// GjkResult
271// ---------------------------------------------------------------------------
272
/// Outcome of the (stubbed) GJK narrowphase for one candidate pair.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum GjkResult {
    /// The two shapes do not touch.
    Separated {
        /// Size of the gap between the two shapes.
        distance: f64,
    },
    /// The two shapes overlap.
    Intersecting,
}
284
285/// Dispatch GJK for a batch of collision pairs (stub implementation).
286///
287/// In a real GPU pipeline, this would launch one thread per pair.  Here we
288/// use a heuristic based on AABB centre distance as a placeholder.
289///
290/// # Arguments
291/// * `buffer` - Collision buffer with the AABB list and pair list.
292pub fn batch_gjk_dispatch(buffer: &GpuCollisionBuffer) -> Vec<GjkResult> {
293    buffer
294        .potential_pairs
295        .par_iter()
296        .map(|&(i, j)| {
297            let ci = buffer.aabbs[i].centre();
298            let cj = buffer.aabbs[j].centre();
299            let dx = cj[0] - ci[0];
300            let dy = cj[1] - ci[1];
301            let dz = cj[2] - ci[2];
302            let dist = (dx * dx + dy * dy + dz * dz).sqrt();
303            // Approximate "radius" as half the diagonal of the AABB
304            let diag_i = {
305                let d = [
306                    buffer.aabbs[i].max[0] - buffer.aabbs[i].min[0],
307                    buffer.aabbs[i].max[1] - buffer.aabbs[i].min[1],
308                    buffer.aabbs[i].max[2] - buffer.aabbs[i].min[2],
309                ];
310                (d[0] * d[0] + d[1] * d[1] + d[2] * d[2]).sqrt() * 0.5
311            };
312            let diag_j = {
313                let d = [
314                    buffer.aabbs[j].max[0] - buffer.aabbs[j].min[0],
315                    buffer.aabbs[j].max[1] - buffer.aabbs[j].min[1],
316                    buffer.aabbs[j].max[2] - buffer.aabbs[j].min[2],
317                ];
318                (d[0] * d[0] + d[1] * d[1] + d[2] * d[2]).sqrt() * 0.5
319            };
320            if dist < diag_i + diag_j {
321                GjkResult::Intersecting
322            } else {
323                GjkResult::Separated {
324                    distance: dist - diag_i - diag_j,
325                }
326            }
327        })
328        .collect()
329}
330
331// ---------------------------------------------------------------------------
332// Tests
333// ---------------------------------------------------------------------------
334
#[cfg(test)]
mod tests {
    use super::*;

    // ── Aabb ──────────────────────────────────────────────────────────────

    #[test]
    fn test_aabb_new() {
        let a = Aabb::new([0.0, 0.0, 0.0], [1.0, 1.0, 1.0]);
        assert_eq!(a.min, [0.0; 3]);
        assert_eq!(a.max, [1.0; 3]);
    }

    #[test]
    fn test_aabb_from_center_half() {
        let a = Aabb::from_center_half([1.0, 2.0, 3.0], [0.5, 0.5, 0.5]);
        assert!((a.min[0] - 0.5).abs() < 1e-12);
        assert!((a.max[0] - 1.5).abs() < 1e-12);
        assert!((a.min[1] - 1.5).abs() < 1e-12);
        assert!((a.max[1] - 2.5).abs() < 1e-12);
        assert!((a.min[2] - 2.5).abs() < 1e-12);
        assert!((a.max[2] - 3.5).abs() < 1e-12);
    }

    #[test]
    fn test_aabb_centre() {
        let a = Aabb::new([0.0, 0.0, 0.0], [2.0, 4.0, 6.0]);
        let c = a.centre();
        assert!((c[0] - 1.0).abs() < 1e-12);
        assert!((c[1] - 2.0).abs() < 1e-12);
        assert!((c[2] - 3.0).abs() < 1e-12);
    }

    #[test]
    fn test_aabb_overlaps_true() {
        let a = Aabb::new([0.0; 3], [1.0; 3]);
        let b = Aabb::new([0.5; 3], [1.5; 3]);
        assert!(a.overlaps(&b));
    }

    #[test]
    fn test_aabb_overlaps_false_separated_x() {
        let a = Aabb::new([0.0; 3], [1.0; 3]);
        let b = Aabb::new([2.0, 0.0, 0.0], [3.0, 1.0, 1.0]);
        assert!(!a.overlaps(&b));
    }

    #[test]
    fn test_aabb_overlaps_touching_edge() {
        // Touching at x = 1.0 counts as overlap (inclusive comparison).
        let a = Aabb::new([0.0; 3], [1.0; 3]);
        let b = Aabb::new([1.0, 0.0, 0.0], [2.0, 1.0, 1.0]);
        assert!(a.overlaps(&b));
    }

    #[test]
    fn test_aabb_no_overlap_y() {
        let a = Aabb::new([0.0; 3], [1.0; 3]);
        let b = Aabb::new([0.0, 2.0, 0.0], [1.0, 3.0, 1.0]);
        assert!(!a.overlaps(&b));
    }

    // ── GpuCollisionBuffer ────────────────────────────────────────────────

    #[test]
    fn test_buffer_new_empty() {
        let buf = GpuCollisionBuffer::new();
        assert_eq!(buf.n_objects(), 0);
        assert!(buf.potential_pairs.is_empty());
    }

    #[test]
    fn test_buffer_add_aabb() {
        let mut buf = GpuCollisionBuffer::new();
        let idx = buf.add_aabb(Aabb::new([0.0; 3], [1.0; 3]));
        assert_eq!(idx, 0);
        assert_eq!(buf.n_objects(), 1);
    }

    #[test]
    fn test_buffer_default() {
        let buf = GpuCollisionBuffer::default();
        assert_eq!(buf.n_objects(), 0);
    }

    // ── gpu_broadphase_sort ───────────────────────────────────────────────

    #[test]
    fn test_broadphase_sort_order() {
        let mut buf = GpuCollisionBuffer::new();
        buf.add_aabb(Aabb::new([3.0, 0.0, 0.0], [4.0, 1.0, 1.0]));
        buf.add_aabb(Aabb::new([1.0, 0.0, 0.0], [2.0, 1.0, 1.0]));
        buf.add_aabb(Aabb::new([0.0, 0.0, 0.0], [0.5, 1.0, 1.0]));
        let perm = gpu_broadphase_sort(&buf);
        assert_eq!(perm, vec![2, 1, 0]);
    }

    #[test]
    fn test_broadphase_sort_empty() {
        let buf = GpuCollisionBuffer::new();
        let perm = gpu_broadphase_sort(&buf);
        assert!(perm.is_empty());
    }

    // ── gpu_aabb_overlap ──────────────────────────────────────────────────

    #[test]
    fn test_aabb_overlap_two_overlapping() {
        let mut buf = GpuCollisionBuffer::new();
        buf.add_aabb(Aabb::new([0.0; 3], [1.0; 3]));
        buf.add_aabb(Aabb::new([0.5; 3], [1.5; 3]));
        let pairs = gpu_aabb_overlap(&buf);
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0], (0, 1));
    }

    #[test]
    fn test_aabb_overlap_two_separated() {
        let mut buf = GpuCollisionBuffer::new();
        buf.add_aabb(Aabb::new([0.0; 3], [1.0; 3]));
        buf.add_aabb(Aabb::new([2.0; 3], [3.0; 3]));
        let pairs = gpu_aabb_overlap(&buf);
        assert!(pairs.is_empty());
    }

    #[test]
    fn test_aabb_overlap_three_objects() {
        let mut buf = GpuCollisionBuffer::new();
        buf.add_aabb(Aabb::new([0.0; 3], [2.0; 3]));
        buf.add_aabb(Aabb::new([1.0; 3], [3.0; 3]));
        buf.add_aabb(Aabb::new([5.0; 3], [6.0; 3]));
        let pairs = gpu_aabb_overlap(&buf);
        assert_eq!(pairs.len(), 1);
        assert_eq!(pairs[0], (0, 1));
    }

    #[test]
    fn test_aabb_overlap_single_object_no_pairs() {
        let mut buf = GpuCollisionBuffer::new();
        buf.add_aabb(Aabb::new([0.0; 3], [1.0; 3]));
        let pairs = gpu_aabb_overlap(&buf);
        assert!(pairs.is_empty());
    }

    // ── gpu_sphere_collision ──────────────────────────────────────────────

    #[test]
    fn test_sphere_collision_overlapping() {
        let spheres = [
            Sphere::new([0.0, 0.0, 0.0], 1.0),
            Sphere::new([1.5, 0.0, 0.0], 1.0),
        ];
        let pairs = [(0usize, 1usize)];
        let contacts = gpu_sphere_collision(&spheres, &pairs);
        assert_eq!(contacts.len(), 1);
        assert!((contacts[0].depth - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_sphere_collision_separated() {
        let spheres = [
            Sphere::new([0.0, 0.0, 0.0], 1.0),
            Sphere::new([5.0, 0.0, 0.0], 1.0),
        ];
        let pairs = [(0usize, 1usize)];
        let contacts = gpu_sphere_collision(&spheres, &pairs);
        assert!(contacts.is_empty());
    }

    #[test]
    fn test_sphere_collision_touching_not_reported() {
        // Centre distance equals the radius sum exactly: no penetration.
        let spheres = [
            Sphere::new([0.0, 0.0, 0.0], 1.0),
            Sphere::new([2.0, 0.0, 0.0], 1.0),
        ];
        let contacts = gpu_sphere_collision(&spheres, &[(0, 1)]);
        assert!(contacts.is_empty());
    }

    #[test]
    fn test_sphere_collision_normal_direction() {
        let spheres = [
            Sphere::new([0.0, 0.0, 0.0], 1.0),
            Sphere::new([1.0, 0.0, 0.0], 1.0),
        ];
        let pairs = [(0usize, 1usize)];
        let contacts = gpu_sphere_collision(&spheres, &pairs);
        assert_eq!(contacts.len(), 1);
        assert!((contacts[0].normal[0] - 1.0).abs() < 1e-10);
        assert!(contacts[0].normal[1].abs() < 1e-10);
        assert!(contacts[0].normal[2].abs() < 1e-10);
    }

    #[test]
    fn test_sphere_collision_reversed_pair_flips_normal() {
        // The normal points from the pair's first sphere to its second.
        let spheres = [
            Sphere::new([0.0, 0.0, 0.0], 1.0),
            Sphere::new([1.0, 0.0, 0.0], 1.0),
        ];
        let contacts = gpu_sphere_collision(&spheres, &[(1, 0)]);
        assert_eq!(contacts.len(), 1);
        assert_eq!((contacts[0].i, contacts[0].j), (1, 0));
        assert!((contacts[0].normal[0] + 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_sphere_collision_out_of_bounds_index() {
        let spheres = [Sphere::new([0.0; 3], 1.0)];
        let pairs = [(0usize, 99usize)];
        let contacts = gpu_sphere_collision(&spheres, &pairs);
        assert!(contacts.is_empty());
    }

    // ── build_collision_pairs ─────────────────────────────────────────────

    #[test]
    fn test_build_pairs_two_overlapping() {
        let mut buf = GpuCollisionBuffer::new();
        buf.add_aabb(Aabb::new([0.0; 3], [1.0; 3]));
        buf.add_aabb(Aabb::new([0.5; 3], [1.5; 3]));
        let n = build_collision_pairs(&mut buf);
        assert_eq!(n, 1);
        assert_eq!(buf.potential_pairs[0], (0, 1));
    }

    #[test]
    fn test_build_pairs_no_overlap() {
        let mut buf = GpuCollisionBuffer::new();
        buf.add_aabb(Aabb::new([0.0; 3], [1.0; 3]));
        buf.add_aabb(Aabb::new([10.0; 3], [11.0; 3]));
        let n = build_collision_pairs(&mut buf);
        assert_eq!(n, 0);
    }

    #[test]
    fn test_build_pairs_multiple_objects() {
        let mut buf = GpuCollisionBuffer::new();
        // All three boxes overlap each other pairwise.
        buf.add_aabb(Aabb::new([0.0; 3], [2.0; 3]));
        buf.add_aabb(Aabb::new([1.0; 3], [3.0; 3]));
        buf.add_aabb(Aabb::new([1.5; 3], [3.5; 3]));
        let n = build_collision_pairs(&mut buf);
        assert_eq!(n, 3);
        assert_eq!(buf.potential_pairs, vec![(0, 1), (0, 2), (1, 2)]);
    }

    #[test]
    fn test_build_pairs_matches_brute_force() {
        // Sweep-and-prune must find exactly the pairs the all-pairs kernel finds.
        let mut buf = GpuCollisionBuffer::new();
        for k in 0..12u32 {
            let t = f64::from(k);
            let centre = [(t * 0.7) % 5.0, (t * 1.3) % 3.0, 0.0];
            buf.add_aabb(Aabb::from_center_half(centre, [0.6; 3]));
        }
        let mut brute = gpu_aabb_overlap(&buf);
        brute.sort_unstable();
        assert!(!brute.is_empty());
        let n = build_collision_pairs(&mut buf);
        assert_eq!(n, brute.len());
        assert_eq!(buf.potential_pairs, brute);
    }

    // ── batch_gjk_dispatch ─────────────────────────────────────────────────

    #[test]
    fn test_gjk_dispatch_intersecting_pair() {
        let mut buf = GpuCollisionBuffer::new();
        buf.add_aabb(Aabb::new([0.0; 3], [1.0; 3]));
        buf.add_aabb(Aabb::new([0.5; 3], [1.5; 3]));
        build_collision_pairs(&mut buf);
        let results = batch_gjk_dispatch(&buf);
        assert!(!results.is_empty());
        assert_eq!(results[0], GjkResult::Intersecting);
    }

    #[test]
    fn test_gjk_dispatch_empty_pairs() {
        let buf = GpuCollisionBuffer::new();
        let results = batch_gjk_dispatch(&buf);
        assert!(results.is_empty());
    }

    #[test]
    fn test_gjk_dispatch_separated_pair() {
        let mut buf = GpuCollisionBuffer::new();
        buf.add_aabb(Aabb::new([0.0; 3], [0.1; 3]));
        buf.add_aabb(Aabb::new([10.0; 3], [10.1; 3]));
        // Add the pair manually: the broadphase would not report disjoint boxes.
        buf.potential_pairs = vec![(0, 1)];
        let results = batch_gjk_dispatch(&buf);
        assert_eq!(results.len(), 1);
        // Centre distance is 10*sqrt(3); each half-diagonal is 0.05*sqrt(3).
        let expected = 9.9 * 3.0_f64.sqrt();
        match results[0] {
            GjkResult::Separated { distance } => assert!((distance - expected).abs() < 1e-9),
            GjkResult::Intersecting => panic!("expected Separated, got Intersecting"),
        }
    }
}