Skip to main content

phyz_gpu/
shaders.rs

1//! WGSL compute shaders for GPU-accelerated physics.
2
3/// WGSL compute shader that advances `q` and `v` by one semi-implicit Euler step.
4///
5/// Each invocation handles exactly one flat index `gid.x` into the state buffers,
/// i.e. one entry out of `params.nworld * params.nv`; invocations whose index is
/// at or beyond that product return without touching any buffer.
///
/// For an in-range index the shader computes `v' = v + dt * qdd` first and then
/// `q' = q + dt * v'` (the *updated* velocity is used for the position update),
/// and writes both results back in place.
///
/// Bind group 0 layout expected by the shader:
///
/// | binding | kind                  | contents                                   |
/// |---------|-----------------------|--------------------------------------------|
/// | 0       | uniform               | `SimParams { nworld, nv, dt, _padding }`   |
/// | 1       | storage, read_write   | `q: array<f32>`                            |
/// | 2       | storage, read_write   | `v: array<f32>`                            |
/// | 3       | storage, read         | `qdd: array<f32>`                          |
///
/// The entry point is `main` with `@workgroup_size(256)`, and only the x
/// component of the global invocation id is used, so covering every entry
/// requires at least `ceil(nworld * nv / 256)` workgroups along x.
///
/// NOTE(review): `q`, `v` and `qdd` are all addressed with the same index, so the
/// shader presumably relies on positions and velocities having the same number of
/// entries per world (`nq == nv`) — confirm against the code that fills the buffers.
6pub const INTEGRATE_SHADER: &str = r#"
7struct SimParams {
8    nworld: u32,
9    nv: u32,
10    dt: f32,
11    _padding: u32,
12}
13
14@group(0) @binding(0) var<uniform> params: SimParams;
15@group(0) @binding(1) var<storage, read_write> q: array<f32>;
16@group(0) @binding(2) var<storage, read_write> v: array<f32>;
17@group(0) @binding(3) var<storage, read> qdd: array<f32>;
18
19@compute @workgroup_size(256)
20fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
21    let idx = gid.x;
22    let total_dofs = params.nworld * params.nv;
23
24    if (idx >= total_dofs) {
25        return;
26    }
27
28    let dt = params.dt;
29
30    // Semi-implicit Euler: v' = v + dt * qdd, q' = q + dt * v'
31    let v_old = v[idx];
32    let qdd_val = qdd[idx];
33    let v_new = v_old + dt * qdd_val;
34    let q_old = q[idx];
35    let q_new = q_old + dt * v_new;
36
37    v[idx] = v_new;
38    q[idx] = q_new;
39}
40"#;
41
42/// WGSL compute shader for a simplified ABA: forward dynamics of a single revolute joint.
43///
44/// This is a reduced version intended for pendulum-like systems with one
45/// revolute joint per world. Multi-body systems would need a more complex
46/// shader with tree traversal.
///
/// Each invocation handles one world (`gid.x`); invocations with
/// `gid.x >= params.nworld` return immediately. For an in-range world `w` the
/// shader reads `q[w]`, `v[w]` and `ctrl[w]` and stores
///
/// ```text
/// qdd[w] = (ctrl[w] + mass * gravity_y * com_y * sin(q[w]) - damping * v[w]) / inertia
/// ```
///
/// where `mass`, `gravity_y`, `com_y`, `damping` and `inertia` come from the
/// `BodyParams` uniform. According to the comments inside the shader,
/// `gravity_y` is meant to be a positive magnitude and `com_y` is typically
/// negative, and `inertia` is the total joint inertia precomputed on the CPU.
///
/// Bind group 0 layout expected by the shader:
///
/// | binding | kind                  | contents                                   |
/// |---------|-----------------------|--------------------------------------------|
/// | 0       | uniform               | `SimParams { nworld, nv, dt, _padding }`   |
/// | 1       | uniform               | `BodyParams` (five used `f32` + 3 padding) |
/// | 2       | storage, read         | `q: array<f32>`                            |
/// | 3       | storage, read         | `v: array<f32>`                            |
/// | 4       | storage, read         | `ctrl: array<f32>`                         |
/// | 5       | storage, read_write   | `qdd: array<f32>`                          |
///
/// The entry point is `main` with `@workgroup_size(256)`; only the x component
/// of the global invocation id is used.
///
/// NOTE(review): `params.nv` and `params.dt` are declared but never read here,
/// and the buffers are indexed by the world index alone — this presumably
/// requires `nv == 1`; confirm against the caller before using it with more DOFs.
///
/// NOTE(review): the division by `body.inertia` is unguarded, so a zero inertia
/// would yield a non-finite `qdd` — verify the CPU side never uploads zero.
47pub const ABA_SIMPLE_SHADER: &str = r#"
48struct SimParams {
49    nworld: u32,
50    nv: u32,
51    dt: f32,
52    _padding: u32,
53}
54
55struct BodyParams {
56    mass: f32,
57    inertia: f32,
58    com_y: f32,
59    damping: f32,
60    gravity_y: f32,
61    _padding0: f32,
62    _padding1: f32,
63    _padding2: f32,
64}
65
66@group(0) @binding(0) var<uniform> params: SimParams;
67@group(0) @binding(1) var<uniform> body: BodyParams;
68@group(0) @binding(2) var<storage, read> q: array<f32>;
69@group(0) @binding(3) var<storage, read> v: array<f32>;
70@group(0) @binding(4) var<storage, read> ctrl: array<f32>;
71@group(0) @binding(5) var<storage, read_write> qdd: array<f32>;
72
73@compute @workgroup_size(256)
74fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
75    let world_idx = gid.x;
76
77    if (world_idx >= params.nworld) {
78        return;
79    }
80
81    // For single revolute joint: qdd = (tau - damping*v - m*g*L*sin(q)) / I
82    let idx = world_idx;
83    let q_val = q[idx];
84    let v_val = v[idx];
85    let tau = ctrl[idx];
86
87    // Gravity torque: m * g * L * sin(q)
88    // Note: gravity_y is magnitude (positive), com_y is typically negative
89    // ABA uses base acceleration trick, so we need positive sign here
90    let gravity_torque = body.mass * body.gravity_y * body.com_y * sin(q_val);
91
92    // Total torque: applied torque + gravity torque - damping torque
93    let total_torque = tau + gravity_torque - body.damping * v_val;
94
95    // Total inertia for pendulum: I = m*L²/3 (parallel axis theorem)
96    // For simplicity, we pass the computed inertia from CPU
97    let total_inertia = body.inertia;
98
99    // qdd = torque / inertia
100    qdd[idx] = total_torque / total_inertia;
101}
102"#;