Skip to main content

oxiphysics_gpu/
gpu_rigid.rs

1#![allow(clippy::needless_range_loop)]
2// Copyright 2026 COOLJAPAN OU (Team KitaSan)
3// SPDX-License-Identifier: Apache-2.0
4
5//! GPU-accelerated rigid body batch simulation (CPU mock).
6//!
7//! Provides structs and functions for simulating many rigid bodies in batch,
8//! broadphase collision detection (SAP), and constraint solving via sequential
9//! impulse. All computation is done on the CPU as a reference implementation.
10
11// ── Quaternion helpers ────────────────────────────────────────────────────────
12
/// Rotate the vector `v` by the unit quaternion `q`, stored as \[x, y, z, w\].
///
/// This is the sandwich product `q * v * q⁻¹`, evaluated with the cheaper
/// two-cross-product form `v' = v + w·t + q.xyz × t`, where
/// `t = 2 · (q.xyz × v)`. `q` is expected to have unit length; otherwise the
/// result is not a pure rotation of `v`.
pub fn quat_rotate(q: [f32; 4], v: [f32; 3]) -> [f32; 3] {
    let [x, y, z, w] = q;
    let [vx, vy, vz] = v;

    // t = 2 * (q.xyz × v)
    let t = [
        2.0 * (y * vz - z * vy),
        2.0 * (z * vx - x * vz),
        2.0 * (x * vy - y * vx),
    ];

    // v + w * t + (q.xyz × t)
    [
        vx + w * t[0] + (y * t[2] - z * t[1]),
        vy + w * t[1] + (z * t[0] - x * t[2]),
        vz + w * t[2] + (x * t[1] - y * t[0]),
    ]
}
29
/// Hamilton product `a ⊗ b` of two quaternions.
///
/// Operands and result all use \[x, y, z, w\] order. Used as a rotation, the
/// product applies `b` first and then `a`.
pub fn quat_mul(a: [f32; 4], b: [f32; 4]) -> [f32; 4] {
    let [ax, ay, az, aw] = a;
    let [bx, by, bz, bw] = b;

    let x = aw * bx + ax * bw + ay * bz - az * by;
    let y = aw * by - ax * bz + ay * bw + az * bx;
    let z = aw * bz + ax * by - ay * bx + az * bw;
    let w = aw * bw - ax * bx - ay * by - az * bz;

    [x, y, z, w]
}
43
/// Scale the quaternion `q` to unit length.
///
/// If the Euclidean norm of `q` is below `1e-9`, the quaternion is treated as
/// degenerate and the identity rotation `[0, 0, 0, 1]` is returned instead.
pub fn quat_normalize(q: [f32; 4]) -> [f32; 4] {
    let [x, y, z, w] = q;
    let len = (x * x + y * y + z * z + w * w).sqrt();
    if len < 1e-9 {
        [0.0, 0.0, 0.0, 1.0]
    } else {
        [x / len, y / len, z / len, w / len]
    }
}
54
55/// Integrate a quaternion orientation `q` by angular velocity `omega` (rad/s)
56/// over time step `dt` (seconds).
57///
58/// Uses the first-order approximation `q' = normalize(q + 0.5 * dt * Ω⊗q)`.
59pub fn integrate_orientation(q: [f32; 4], omega: [f32; 3], dt: f32) -> [f32; 4] {
60    // omega as a pure quaternion [ox, oy, oz, 0]
61    let omega_q = [omega[0], omega[1], omega[2], 0.0_f32];
62    let dq = quat_mul(omega_q, q);
63    let q_new = [
64        q[0] + 0.5 * dt * dq[0],
65        q[1] + 0.5 * dt * dq[1],
66        q[2] + 0.5 * dt * dq[2],
67        q[3] + 0.5 * dt * dq[3],
68    ];
69    quat_normalize(q_new)
70}
71
72// ── Core rigid body ───────────────────────────────────────────────────────────
73
/// Rigid-body state kept in flat `f32` arrays, so it can be packed into GPU
/// buffers without conversion.
///
/// * `orientation` is a unit quaternion in \[x, y, z, w\] order.
/// * `inv_inertia` is the inverse of the body-space inertia tensor, stored as a
///   row-major 3×3 matrix.
#[derive(Debug, Clone)]
pub struct GpuRigidBody {
    /// Position of the body in world space, in metres.
    pub position: [f32; 3],
    /// Linear velocity, in metres per second.
    pub velocity: [f32; 3],
    /// Orientation as a quaternion `[x, y, z, w]`.
    pub orientation: [f32; 4],
    /// Angular velocity expressed in world space, in radians per second.
    pub angular_velocity: [f32; 3],
    /// Mass in kilograms. The impulse routines treat a (near-)zero mass as
    /// infinite, i.e. their inverse mass becomes zero.
    pub mass: f32,
    /// Row-major 3×3 inverse of the body-space inertia tensor.
    pub inv_inertia: [f32; 9],
}

impl GpuRigidBody {
    /// Build a body at the origin, at rest, with identity orientation.
    ///
    /// `ixx`, `iyy` and `izz` are the principal moments of inertia, so the
    /// resulting inverse inertia tensor is diagonal. A moment whose magnitude
    /// is 1e-12 or less yields a zero inverse entry instead of a division by
    /// (almost) zero.
    pub fn new(mass: f32, ixx: f32, iyy: f32, izz: f32) -> Self {
        fn invert_or_zero(moment: f32) -> f32 {
            if moment.abs() > 1e-12 {
                1.0 / moment
            } else {
                0.0
            }
        }

        // Only the diagonal (indices 0, 4, 8) is non-zero.
        let mut inv_inertia = [0.0_f32; 9];
        inv_inertia[0] = invert_or_zero(ixx);
        inv_inertia[4] = invert_or_zero(iyy);
        inv_inertia[8] = invert_or_zero(izz);

        Self {
            mass,
            inv_inertia,
            position: [0.0; 3],
            velocity: [0.0; 3],
            orientation: [0.0, 0.0, 0.0, 1.0],
            angular_velocity: [0.0; 3],
        }
    }

    /// Multiply `v` by the stored (body-space) inverse inertia matrix.
    // Kept as an internal convenience; nothing in this module calls it yet.
    #[allow(dead_code)]
    fn apply_inv_inertia(&self, v: [f32; 3]) -> [f32; 3] {
        let m = &self.inv_inertia;
        let row = |r: usize| m[3 * r] * v[0] + m[3 * r + 1] * v[1] + m[3 * r + 2] * v[2];
        [row(0), row(1), row(2)]
    }
}
132
133// ── Batch integration ─────────────────────────────────────────────────────────
134
135/// A batch of GPU rigid bodies supporting bulk integration.
136#[derive(Debug, Clone, Default)]
137pub struct GpuRigidBodyBatch {
138    /// The rigid bodies in this batch.
139    pub bodies: Vec<GpuRigidBody>,
140}
141
142impl GpuRigidBodyBatch {
143    /// Create an empty batch.
144    pub fn new() -> Self {
145        Self::default()
146    }
147
148    /// Add a body to the batch and return its index.
149    pub fn add(&mut self, body: GpuRigidBody) -> usize {
150        let idx = self.bodies.len();
151        self.bodies.push(body);
152        idx
153    }
154
155    /// Integrate all bodies by `dt` seconds under a uniform `gravity` (m/s²).
156    ///
157    /// Uses explicit Euler integration for linear dynamics and a first-order
158    /// quaternion update for rotational dynamics.
159    pub fn integrate_all(&mut self, dt: f32, gravity: [f32; 3]) {
160        for body in &mut self.bodies {
161            // linear: v += g*dt,  p += v*dt
162            body.velocity[0] += gravity[0] * dt;
163            body.velocity[1] += gravity[1] * dt;
164            body.velocity[2] += gravity[2] * dt;
165            body.position[0] += body.velocity[0] * dt;
166            body.position[1] += body.velocity[1] * dt;
167            body.position[2] += body.velocity[2] * dt;
168            // rotational
169            let new_q = integrate_orientation(body.orientation, body.angular_velocity, dt);
170            body.orientation = new_q;
171        }
172    }
173
174    /// Apply a linear impulse `impulse` (N·s) at world-space `point` to body
175    /// at `body_idx`.
176    ///
177    /// Updates both linear and angular velocity.
178    pub fn apply_impulse(&mut self, body_idx: usize, impulse: [f32; 3], point: [f32; 3]) {
179        let body = &mut self.bodies[body_idx];
180        let inv_mass = if body.mass > 1e-12 {
181            1.0 / body.mass
182        } else {
183            0.0
184        };
185        // linear impulse
186        body.velocity[0] += impulse[0] * inv_mass;
187        body.velocity[1] += impulse[1] * inv_mass;
188        body.velocity[2] += impulse[2] * inv_mass;
189        // torque arm: r = point - position
190        let r = [
191            point[0] - body.position[0],
192            point[1] - body.position[1],
193            point[2] - body.position[2],
194        ];
195        // angular impulse: omega += I⁻¹ * (r × impulse)
196        let torque_impulse = cross3f(r, impulse);
197        let delta_omega = apply_mat3(body.inv_inertia, torque_impulse);
198        body.angular_velocity[0] += delta_omega[0];
199        body.angular_velocity[1] += delta_omega[1];
200        body.angular_velocity[2] += delta_omega[2];
201    }
202}
203
204// ── Broadphase SAP ────────────────────────────────────────────────────────────
205
/// A candidate collision pair reported by the broadphase, together with the
/// axis-aligned bounding boxes that were found to overlap.
///
/// Each box is described by its centre and its half-extents along x, y and z.
#[derive(Debug, Clone)]
pub struct BroadphasePairGpu {
    /// Index of the first body of the pair.
    pub body_a: usize,
    /// Index of the second body of the pair.
    pub body_b: usize,
    /// Centre of body A's bounding box.
    pub aabb_a_center: [f32; 3],
    /// Half-extents of body A's bounding box.
    pub aabb_a_half: [f32; 3],
    /// Centre of body B's bounding box.
    pub aabb_b_center: [f32; 3],
    /// Half-extents of body B's bounding box.
    pub aabb_b_half: [f32; 3],
}
222
223/// Broadphase collision detection for GPU rigid bodies using a sweep-and-prune
224/// (SAP) approach.
225///
226/// Each body is approximated by a sphere; overlapping sphere AABBs are reported
227/// as candidate pairs.
228#[derive(Debug, Clone, Default)]
229pub struct GpuBroadphase {
230    /// Bodies to test.
231    pub bodies: Vec<GpuRigidBody>,
232    /// Bounding sphere radii, one per body.
233    pub radii: Vec<f32>,
234}
235
236impl GpuBroadphase {
237    /// Create an empty broadphase structure.
238    pub fn new() -> Self {
239        Self::default()
240    }
241
242    /// Add a body with its bounding sphere radius.
243    pub fn add_body(&mut self, body: GpuRigidBody, radius: f32) {
244        self.bodies.push(body);
245        self.radii.push(radius);
246    }
247
248    /// Run a sort-and-sweep on the X axis and return overlapping pairs.
249    ///
250    /// Only pairs whose AABBs overlap on all three axes are returned.
251    pub fn compute_pairs_sap(&self) -> Vec<BroadphasePairGpu> {
252        let n = self.bodies.len();
253        let mut pairs = Vec::new();
254
255        // Sort by AABB min-x
256        let mut order: Vec<usize> = (0..n).collect();
257        order.sort_by(|&a, &b| {
258            let ax = self.bodies[a].position[0] - self.radii[a];
259            let bx = self.bodies[b].position[0] - self.radii[b];
260            ax.partial_cmp(&bx).unwrap_or(std::cmp::Ordering::Equal)
261        });
262
263        for i in 0..order.len() {
264            let ai = order[i];
265            let a_max_x = self.bodies[ai].position[0] + self.radii[ai];
266            for j in (i + 1)..order.len() {
267                let bi = order[j];
268                let b_min_x = self.bodies[bi].position[0] - self.radii[bi];
269                if b_min_x > a_max_x {
270                    break; // sorted: no further pair can overlap on X
271                }
272                // Check all three axes
273                if self.aabb_overlap(ai, bi) {
274                    let ra = self.radii[ai];
275                    let rb = self.radii[bi];
276                    pairs.push(BroadphasePairGpu {
277                        body_a: ai,
278                        body_b: bi,
279                        aabb_a_center: self.bodies[ai].position,
280                        aabb_a_half: [ra, ra, ra],
281                        aabb_b_center: self.bodies[bi].position,
282                        aabb_b_half: [rb, rb, rb],
283                    });
284                }
285            }
286        }
287        pairs
288    }
289
290    /// Test whether two bodies' sphere AABBs overlap on all three axes.
291    fn aabb_overlap(&self, a: usize, b: usize) -> bool {
292        for k in 0..3 {
293            let a_min = self.bodies[a].position[k] - self.radii[a];
294            let a_max = self.bodies[a].position[k] + self.radii[a];
295            let b_min = self.bodies[b].position[k] - self.radii[b];
296            let b_max = self.bodies[b].position[k] + self.radii[b];
297            if a_max < b_min || b_max < a_min {
298                return false;
299            }
300        }
301        true
302    }
303}
304
305// ── Contact manifold ──────────────────────────────────────────────────────────
306
/// Set of contacts between two bodies, as produced by the narrowphase.
///
/// The three vectors are parallel: entry `i` of `contact_points`, `normals`
/// and `penetrations` together describe contact `i`.
#[derive(Debug, Clone)]
pub struct ContactManifoldGpu {
    /// Index of the first body.
    pub body_a: usize,
    /// Index of the second body.
    pub body_b: usize,
    /// World-space contact positions.
    pub contact_points: Vec<[f32; 3]>,
    /// Contact normals pointing from B towards A, one per contact point.
    pub normals: Vec<[f32; 3]>,
    /// Penetration depth per contact point; positive values mean overlap.
    pub penetrations: Vec<f32>,
}

impl ContactManifoldGpu {
    /// Start an empty manifold between bodies `body_a` and `body_b`.
    pub fn new(body_a: usize, body_b: usize) -> Self {
        Self {
            contact_points: Vec::new(),
            normals: Vec::new(),
            penetrations: Vec::new(),
            body_a,
            body_b,
        }
    }

    /// Record one contact: its world position, normal and penetration depth.
    pub fn add_contact(&mut self, point: [f32; 3], normal: [f32; 3], penetration: f32) {
        self.contact_points.push(point);
        self.normals.push(normal);
        self.penetrations.push(penetration);
    }

    /// Number of contacts recorded so far (the length of `contact_points`).
    pub fn contact_count(&self) -> usize {
        self.contact_points.len()
    }
}
346
347// ── Sequential impulse solver ─────────────────────────────────────────────────
348
349/// Iterative sequential-impulse constraint solver for rigid body contacts.
350#[derive(Debug, Clone, Default)]
351pub struct GpuConstraintSolver {
352    /// Broadphase candidate pairs.
353    pub pairs: Vec<BroadphasePairGpu>,
354    /// Contact manifolds, aligned with `pairs`.
355    pub manifolds: Vec<ContactManifoldGpu>,
356}
357
358impl GpuConstraintSolver {
359    /// Create a new solver with no constraints.
360    pub fn new() -> Self {
361        Self::default()
362    }
363
364    /// Add a manifold (and its corresponding broadphase pair) to the solver.
365    pub fn add_manifold(&mut self, pair: BroadphasePairGpu, manifold: ContactManifoldGpu) {
366        self.pairs.push(pair);
367        self.manifolds.push(manifold);
368    }
369
370    /// Solve all contact constraints via sequential impulses.
371    ///
372    /// Iterates `iterations` times over every contact point and applies a
373    /// non-penetration impulse. A restitution coefficient of 0.3 is used.
374    #[allow(clippy::too_many_arguments)]
375    pub fn solve_sequential_impulse(
376        &self,
377        bodies: &mut [GpuRigidBody],
378        dt: f32,
379        iterations: usize,
380    ) {
381        let _ = dt; // reserved for future bias/baumgarte
382        let restitution = 0.3_f32;
383        for _ in 0..iterations {
384            for manifold in &self.manifolds {
385                let a = manifold.body_a;
386                let b = manifold.body_b;
387                for c in 0..manifold.contact_count() {
388                    let n = manifold.normals[c];
389                    // relative velocity at contact
390                    let va = bodies[a].velocity;
391                    let vb = bodies[b].velocity;
392                    let rv = [va[0] - vb[0], va[1] - vb[1], va[2] - vb[2]];
393                    let vn = dot3f(rv, n);
394                    if vn >= 0.0 {
395                        continue; // separating
396                    }
397                    let inv_ma = if bodies[a].mass > 1e-12 {
398                        1.0 / bodies[a].mass
399                    } else {
400                        0.0
401                    };
402                    let inv_mb = if bodies[b].mass > 1e-12 {
403                        1.0 / bodies[b].mass
404                    } else {
405                        0.0
406                    };
407                    let j = -(1.0 + restitution) * vn / (inv_ma + inv_mb);
408                    // apply
409                    bodies[a].velocity[0] += j * inv_ma * n[0];
410                    bodies[a].velocity[1] += j * inv_ma * n[1];
411                    bodies[a].velocity[2] += j * inv_ma * n[2];
412                    bodies[b].velocity[0] -= j * inv_mb * n[0];
413                    bodies[b].velocity[1] -= j * inv_mb * n[1];
414                    bodies[b].velocity[2] -= j * inv_mb * n[2];
415                }
416            }
417        }
418    }
419}
420
421// ── Internal math helpers ─────────────────────────────────────────────────────
422
/// Cross product `a × b` of two `[f32; 3]` vectors.
fn cross3f(a: [f32; 3], b: [f32; 3]) -> [f32; 3] {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    [ay * bz - az * by, az * bx - ax * bz, ax * by - ay * bx]
}
431
/// Dot product `a · b` of two `[f32; 3]` vectors.
fn dot3f(a: [f32; 3], b: [f32; 3]) -> f32 {
    let [ax, ay, az] = a;
    let [bx, by, bz] = b;
    ax * bx + ay * by + az * bz
}
436
/// Multiply the row-major 3×3 matrix `m` by the column vector `v`.
fn apply_mat3(m: [f32; 9], v: [f32; 3]) -> [f32; 3] {
    let [x, y, z] = v;
    // Row `r` occupies m[3r], m[3r + 1], m[3r + 2].
    let row = |r: usize| m[3 * r] * x + m[3 * r + 1] * y + m[3 * r + 2] * z;
    [row(0), row(1), row(2)]
}
445
446// ── Tests ─────────────────────────────────────────────────────────────────────
447
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons in these tests.
    const EPS: f32 = 1e-5;

    fn approx_eq3(a: [f32; 3], b: [f32; 3]) -> bool {
        (a[0] - b[0]).abs() < EPS && (a[1] - b[1]).abs() < EPS && (a[2] - b[2]).abs() < EPS
    }

    fn approx_eq4(a: [f32; 4], b: [f32; 4]) -> bool {
        (a[0] - b[0]).abs() < EPS
            && (a[1] - b[1]).abs() < EPS
            && (a[2] - b[2]).abs() < EPS
            && (a[3] - b[3]).abs() < EPS
    }

    fn quat_norm(q: [f32; 4]) -> f32 {
        (q[0] * q[0] + q[1] * q[1] + q[2] * q[2] + q[3] * q[3]).sqrt()
    }

    // ── quat_rotate ────────────────────────────────────────────────────────

    #[test]
    fn test_quat_rotate_identity() {
        let q = [0.0, 0.0, 0.0, 1.0_f32];
        let v = [1.0, 2.0, 3.0_f32];
        let r = quat_rotate(q, v);
        assert!(approx_eq3(r, v), "identity quat should not rotate: {r:?}");
    }

    #[test]
    fn test_quat_rotate_180_about_z() {
        // 180° about Z: q = [0, 0, 1, 0]
        let q = [0.0, 0.0, 1.0_f32, 0.0];
        let v = [1.0, 0.0, 0.0_f32];
        let r = quat_rotate(q, v);
        assert!(approx_eq3(r, [-1.0, 0.0, 0.0]), "180 Z rotate: {r:?}");
    }

    #[test]
    fn test_quat_rotate_90_about_y() {
        // 90° about Y: q = [0, sin45°, 0, cos45°]
        let half = std::f32::consts::FRAC_PI_4;
        let q = [0.0, half.sin(), 0.0, half.cos()];
        let v = [1.0, 0.0, 0.0_f32];
        let r = quat_rotate(q, v);
        // +X rotated 90° about +Y lands on -Z.
        assert!((r[0]).abs() < EPS, "x should be ~0: {}", r[0]);
        assert!((r[1]).abs() < EPS, "y should be ~0: {}", r[1]);
        assert!((r[2] + 1.0).abs() < EPS, "z should be ~-1: {}", r[2]);
    }

    #[test]
    fn test_quat_rotate_preserves_length() {
        // Half-angle π/6 → a 60° rotation about X.
        let half = std::f32::consts::FRAC_PI_6;
        let q = quat_normalize([half.sin(), 0.0, 0.0, half.cos()]);
        let v = [3.0, 4.0, 0.0_f32];
        let r = quat_rotate(q, v);
        let len_v = (v[0] * v[0] + v[1] * v[1] + v[2] * v[2]).sqrt();
        let len_r = (r[0] * r[0] + r[1] * r[1] + r[2] * r[2]).sqrt();
        assert!((len_v - len_r).abs() < EPS, "rotation must preserve length");
    }

    // ── quat_mul ───────────────────────────────────────────────────────────

    #[test]
    fn test_quat_mul_identity() {
        let id = [0.0, 0.0, 0.0, 1.0_f32];
        let q = quat_normalize([0.1, 0.2, 0.3_f32, 0.9]);
        assert!(approx_eq4(quat_mul(id, q), q));
        assert!(approx_eq4(quat_mul(q, id), q));
    }

    #[test]
    fn test_quat_mul_unit_norm() {
        let a = quat_normalize([1.0, 0.0, 0.0_f32, 1.0]);
        let b = quat_normalize([0.0, 1.0, 0.0_f32, 1.0]);
        let c = quat_mul(a, b);
        assert!(
            (quat_norm(c) - 1.0).abs() < EPS,
            "product must be unit: {}",
            quat_norm(c)
        );
    }

    #[test]
    fn test_quat_mul_double_rotation() {
        // Two 90° rotations about Z compose to 180° about Z.
        let half = std::f32::consts::FRAC_PI_4;
        let q90 = [0.0, 0.0, half.sin(), half.cos()];
        let q180 = quat_mul(q90, q90);
        let v = [1.0, 0.0, 0.0_f32];
        let r = quat_rotate(q180, v);
        assert!(approx_eq3(r, [-1.0, 0.0, 0.0]), "should be [-1,0,0]: {r:?}");
    }

    // ── integrate_orientation ──────────────────────────────────────────────

    #[test]
    fn test_integrate_orientation_no_rotation() {
        let q = [0.0, 0.0, 0.0, 1.0_f32];
        let q2 = integrate_orientation(q, [0.0; 3], 0.01);
        assert!((quat_norm(q2) - 1.0).abs() < EPS);
        assert!(approx_eq4(q2, q));
    }

    #[test]
    fn test_integrate_orientation_stays_unit() {
        let omega = [0.0, 0.0, 1.0_f32]; // 1 rad/s about Z
        let mut q_cur = [0.0, 0.0, 0.0, 1.0_f32];
        for _ in 0..100 {
            q_cur = integrate_orientation(q_cur, omega, 0.01);
        }
        assert!((quat_norm(q_cur) - 1.0).abs() < 1e-4);
    }

    #[test]
    fn test_integrate_orientation_direction() {
        // A single step about X must introduce a non-zero x component.
        let q = [0.0, 0.0, 0.0, 1.0_f32];
        let omega = [1.0, 0.0, 0.0_f32];
        let q2 = integrate_orientation(q, omega, 0.1);
        assert!(
            q2[0].abs() > 1e-4,
            "qx should be non-zero after rotation: {}",
            q2[0]
        );
    }

    // ── quat_normalize ─────────────────────────────────────────────────────

    #[test]
    fn test_quat_normalize_unit() {
        let q = [1.0_f32, 0.0, 0.0, 0.0];
        assert!(approx_eq4(quat_normalize(q), q));
    }

    #[test]
    fn test_quat_normalize_zero_returns_identity() {
        let q = quat_normalize([0.0; 4]);
        assert!(approx_eq4(q, [0.0, 0.0, 0.0, 1.0]));
    }

    #[test]
    fn test_quat_normalize_scales() {
        let n = quat_normalize([2.0_f32, 0.0, 0.0, 0.0]);
        assert!((n[0] - 1.0).abs() < EPS);
    }

    // ── GpuRigidBody ──────────────────────────────────────────────────────

    #[test]
    fn test_rigid_body_new_defaults() {
        let b = GpuRigidBody::new(1.0, 1.0, 1.0, 1.0);
        assert_eq!(b.position, [0.0; 3]);
        assert_eq!(b.velocity, [0.0; 3]);
        assert_eq!(b.orientation, [0.0, 0.0, 0.0, 1.0]);
        assert_eq!(b.angular_velocity, [0.0; 3]);
        assert!((b.mass - 1.0).abs() < EPS);
    }

    #[test]
    fn test_rigid_body_inv_inertia_diagonal() {
        let b = GpuRigidBody::new(1.0, 2.0, 4.0, 8.0);
        assert!((b.inv_inertia[0] - 0.5).abs() < EPS, "ixx");
        assert!((b.inv_inertia[4] - 0.25).abs() < EPS, "iyy");
        assert!((b.inv_inertia[8] - 0.125).abs() < EPS, "izz");
    }

    // ── GpuRigidBodyBatch::integrate_all ──────────────────────────────────

    #[test]
    fn test_batch_gravity_integration() {
        let mut batch = GpuRigidBodyBatch::new();
        batch.add(GpuRigidBody::new(1.0, 1.0, 1.0, 1.0));
        batch.integrate_all(1.0, [0.0, -9.81, 0.0]);
        // Velocity is updated before position, so one step moves by the new velocity.
        assert!((batch.bodies[0].velocity[1] + 9.81).abs() < EPS);
        assert!((batch.bodies[0].position[1] + 9.81).abs() < EPS);
    }

    #[test]
    fn test_batch_no_gravity_no_motion() {
        let mut batch = GpuRigidBodyBatch::new();
        batch.add(GpuRigidBody::new(1.0, 1.0, 1.0, 1.0));
        batch.integrate_all(1.0, [0.0; 3]);
        assert_eq!(batch.bodies[0].position, [0.0; 3]);
        assert_eq!(batch.bodies[0].velocity, [0.0; 3]);
    }

    #[test]
    fn test_batch_multiple_bodies() {
        let mut batch = GpuRigidBodyBatch::new();
        for _ in 0..5 {
            batch.add(GpuRigidBody::new(1.0, 1.0, 1.0, 1.0));
        }
        batch.integrate_all(0.5, [0.0, -9.81, 0.0]);
        for b in &batch.bodies {
            assert!((b.velocity[1] + 9.81 * 0.5).abs() < EPS);
        }
    }

    #[test]
    fn test_batch_add_returns_index() {
        let mut batch = GpuRigidBodyBatch::new();
        let i0 = batch.add(GpuRigidBody::new(1.0, 1.0, 1.0, 1.0));
        let i1 = batch.add(GpuRigidBody::new(2.0, 1.0, 1.0, 1.0));
        assert_eq!(i0, 0);
        assert_eq!(i1, 1);
    }

    // ── GpuRigidBodyBatch::apply_impulse ──────────────────────────────────

    #[test]
    fn test_apply_impulse_linear() {
        let mut batch = GpuRigidBodyBatch::new();
        batch.add(GpuRigidBody::new(2.0, 1.0, 1.0, 1.0));
        // impulse [2,0,0] on mass 2 => dv = [1,0,0]
        batch.apply_impulse(0, [2.0, 0.0, 0.0], [0.0; 3]);
        assert!((batch.bodies[0].velocity[0] - 1.0).abs() < EPS);
    }

    #[test]
    fn test_apply_impulse_angular() {
        let mut batch = GpuRigidBodyBatch::new();
        batch.add(GpuRigidBody::new(1.0, 1.0, 1.0, 1.0));
        // A Y-axis impulse applied at an X offset gives r × J = [0, 0, 1]; with
        // unit inertia and identity orientation, ω changes by exactly that.
        batch.apply_impulse(0, [0.0, 1.0, 0.0], [1.0, 0.0, 0.0]);
        assert!(
            approx_eq3(batch.bodies[0].angular_velocity, [0.0, 0.0, 1.0]),
            "angular velocity: {:?}",
            batch.bodies[0].angular_velocity
        );
    }

    #[test]
    fn test_apply_impulse_zero_mass() {
        let mut batch = GpuRigidBodyBatch::new();
        batch.add(GpuRigidBody::new(0.0, 1.0, 1.0, 1.0));
        batch.apply_impulse(0, [100.0, 0.0, 0.0], [0.0; 3]);
        assert_eq!(batch.bodies[0].velocity, [0.0; 3]);
    }

    // ── GpuBroadphase ─────────────────────────────────────────────────────

    /// Helper: a unit body placed at `position`.
    fn body_at(position: [f32; 3]) -> GpuRigidBody {
        let mut b = GpuRigidBody::new(1.0, 1.0, 1.0, 1.0);
        b.position = position;
        b
    }

    #[test]
    fn test_broadphase_no_bodies() {
        let bp = GpuBroadphase::new();
        assert!(bp.compute_pairs_sap().is_empty());
    }

    #[test]
    fn test_broadphase_two_overlapping() {
        let mut bp = GpuBroadphase::new();
        bp.add_body(body_at([0.0; 3]), 1.0);
        bp.add_body(body_at([0.5, 0.0, 0.0]), 1.0);
        assert_eq!(bp.compute_pairs_sap().len(), 1);
    }

    #[test]
    fn test_broadphase_two_separated() {
        let mut bp = GpuBroadphase::new();
        bp.add_body(body_at([0.0; 3]), 0.5);
        bp.add_body(body_at([100.0, 0.0, 0.0]), 0.5);
        assert!(bp.compute_pairs_sap().is_empty());
    }

    #[test]
    fn test_broadphase_three_bodies_two_pairs() {
        let mut bp = GpuBroadphase::new();
        for x in [0.0_f32, 1.0, 2.0] {
            bp.add_body(body_at([x, 0.0, 0.0]), 0.8);
        }
        let pairs = bp.compute_pairs_sap();
        // 0-1 and 1-2 overlap; 0-2 are 2.0 apart with a radius sum of 1.6, so they do not.
        assert_eq!(pairs.len(), 2, "expected exactly 2 pairs, got {}", pairs.len());
        assert!(pairs
            .iter()
            .all(|p| !((p.body_a == 0 && p.body_b == 2) || (p.body_a == 2 && p.body_b == 0))));
    }

    #[test]
    fn test_broadphase_pair_indices_valid() {
        let mut bp = GpuBroadphase::new();
        bp.add_body(body_at([0.0; 3]), 0.5);
        bp.add_body(body_at([0.3, 0.0, 0.0]), 0.5);
        let pairs = bp.compute_pairs_sap();
        assert_eq!(pairs.len(), 1);
        let p = &pairs[0];
        assert!(p.body_a < 2 && p.body_b < 2 && p.body_a != p.body_b);
    }

    // ── ContactManifoldGpu ────────────────────────────────────────────────

    #[test]
    fn test_manifold_empty() {
        let m = ContactManifoldGpu::new(0, 1);
        assert_eq!(m.contact_count(), 0);
    }

    #[test]
    fn test_manifold_add_contact() {
        let mut m = ContactManifoldGpu::new(0, 1);
        m.add_contact([0.0; 3], [0.0, 1.0, 0.0], 0.01);
        assert_eq!(m.contact_count(), 1);
        assert!((m.penetrations[0] - 0.01).abs() < EPS);
    }

    #[test]
    fn test_manifold_multiple_contacts() {
        let mut m = ContactManifoldGpu::new(0, 1);
        for i in 0..4 {
            m.add_contact([i as f32, 0.0, 0.0], [0.0, 1.0, 0.0], 0.01 * i as f32);
        }
        assert_eq!(m.contact_count(), 4);
    }

    // ── GpuConstraintSolver ───────────────────────────────────────────────

    #[test]
    fn test_solver_no_manifolds() {
        let solver = GpuConstraintSolver::new();
        let mut bodies = vec![GpuRigidBody::new(1.0, 1.0, 1.0, 1.0)];
        solver.solve_sequential_impulse(&mut bodies, 0.01, 10);
        assert_eq!(bodies[0].velocity, [0.0; 3]);
    }

    #[test]
    fn test_solver_separates_colliding_bodies() {
        // A sits at the origin moving +X; B sits at +0.5 moving -X: approaching.
        let mut b1 = GpuRigidBody::new(1.0, 1.0, 1.0, 1.0);
        b1.velocity = [1.0, 0.0, 0.0];
        let mut b2 = GpuRigidBody::new(1.0, 1.0, 1.0, 1.0);
        b2.velocity = [-1.0, 0.0, 0.0];
        b2.position = [0.5, 0.0, 0.0];

        let mut solver = GpuConstraintSolver::new();
        let pair = BroadphasePairGpu {
            body_a: 0,
            body_b: 1,
            aabb_a_center: [0.0; 3],
            aabb_a_half: [0.5; 3],
            aabb_b_center: [0.5, 0.0, 0.0],
            aabb_b_half: [0.5; 3],
        };
        // The normal points from B towards A, i.e. along -X.
        let n = [-1.0, 0.0, 0.0];
        let mut manifold = ContactManifoldGpu::new(0, 1);
        manifold.add_contact([0.25, 0.0, 0.0], n, 0.01);
        solver.add_manifold(pair, manifold);

        let mut bodies = vec![b1, b2];
        solver.solve_sequential_impulse(&mut bodies, 0.01, 10);

        // The separating velocity along the normal must now be non-negative and
        // equal to restitution (0.3) times the approach speed (2).
        let va = bodies[0].velocity;
        let vb = bodies[1].velocity;
        let sep = (va[0] - vb[0]) * n[0] + (va[1] - vb[1]) * n[1] + (va[2] - vb[2]) * n[2];
        assert!(sep >= -EPS, "bodies should not keep approaching: sep={sep}");
        assert!((sep - 0.6).abs() < 1e-4, "expected restitution bounce: sep={sep}");
        // Equal masses: total momentum stays zero.
        assert!((va[0] + vb[0]).abs() < 1e-4, "momentum: {va:?} {vb:?}");
    }

    #[test]
    fn test_solver_static_body_kinematic() {
        // body_b has zero mass, so it is immovable.
        let mut b1 = GpuRigidBody::new(1.0, 1.0, 1.0, 1.0);
        b1.velocity = [0.0, -1.0, 0.0];
        let b2 = GpuRigidBody::new(0.0, 1.0, 1.0, 1.0); // static

        let mut solver = GpuConstraintSolver::new();
        let pair = BroadphasePairGpu {
            body_a: 0,
            body_b: 1,
            aabb_a_center: [0.0; 3],
            aabb_a_half: [0.5; 3],
            aabb_b_center: [0.0, -0.9, 0.0],
            aabb_b_half: [0.5; 3],
        };
        let mut manifold = ContactManifoldGpu::new(0, 1);
        manifold.add_contact([0.0, -0.5, 0.0], [0.0, 1.0, 0.0], 0.1);
        solver.add_manifold(pair, manifold);

        let mut bodies = vec![b1, b2];
        solver.solve_sequential_impulse(&mut bodies, 0.01, 10);

        assert_eq!(bodies[1].velocity, [0.0; 3], "static body must not move");
        assert!(
            bodies[0].velocity[1] >= -EPS,
            "dynamic body should bounce up"
        );
    }

    // ── Internal helpers ──────────────────────────────────────────────────

    #[test]
    fn test_cross3f() {
        let i = [1.0_f32, 0.0, 0.0];
        let j = [0.0_f32, 1.0, 0.0];
        assert!(approx_eq3(cross3f(i, j), [0.0, 0.0, 1.0]));
    }

    #[test]
    fn test_dot3f() {
        assert!((dot3f([1.0, 2.0, 3.0_f32], [4.0, 5.0, 6.0]) - 32.0).abs() < EPS);
    }

    #[test]
    fn test_apply_mat3_identity() {
        let id = [1.0_f32, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];
        let v = [3.0_f32, 4.0, 5.0];
        assert!(approx_eq3(apply_mat3(id, v), v));
    }

    #[test]
    fn test_apply_mat3_scale() {
        let m = [2.0_f32, 0.0, 0.0, 0.0, 3.0, 0.0, 0.0, 0.0, 4.0];
        let v = [1.0_f32, 1.0, 1.0];
        assert!(approx_eq3(apply_mat3(m, v), [2.0, 3.0, 4.0]));
    }
}