/// WGSL compute shader that advances every degree of freedom (DoF) of every
/// world by one semi-implicit (symplectic) Euler step:
///
/// ```text
/// v' = v + dt * qdd
/// q' = q + dt * v'
/// ```
///
/// Each invocation handles one flat DoF index in `0..nworld * nv`. The
/// dispatch size is a multiple of the 256-wide workgroup, so invocations past
/// the last DoF return early.
///
/// Bindings (group 0):
/// - `0`: `SimParams` uniform (`nworld`, `nv`, `dt`, padding to 16 bytes)
/// - `1`: `q` positions (read/write)
/// - `2`: `v` velocities (read/write)
/// - `3`: `qdd` accelerations (read-only)
///
/// NOTE(review): `q` is integrated directly from `v`, which presumes position
/// and velocity coordinates coincide (nq == nv, e.g. revolute/prismatic joints
/// only). Confirm before using this with quaternion-based joints.
pub const INTEGRATE_SHADER: &str = r#"
// Semi-implicit Euler integration over a flat array of nworld * nv DoFs.

struct SimParams {
    nworld: u32,
    nv: u32,
    dt: f32,
    _padding: u32,
}

@group(0) @binding(0) var<uniform> params: SimParams;
@group(0) @binding(1) var<storage, read_write> q: array<f32>;
@group(0) @binding(2) var<storage, read_write> v: array<f32>;
@group(0) @binding(3) var<storage, read> qdd: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let idx = gid.x;
    let total_dofs = params.nworld * params.nv;

    // The dispatch covers whole workgroups of 256; skip invocations beyond
    // the last DoF.
    if (idx >= total_dofs) {
        return;
    }

    let dt = params.dt;

    // Semi-implicit Euler: update v first, then integrate q with the NEW v.
    let v_old = v[idx];
    let qdd_val = qdd[idx];
    let v_new = v_old + dt * qdd_val;
    let q_old = q[idx];
    let q_new = q_old + dt * v_new;

    v[idx] = v_new;
    q[idx] = q_new;
}
"#;

/// WGSL compute shader computing the joint acceleration of a single revolute
/// joint (a damped, torque-driven pendulum) in each world:
///
/// ```text
/// qdd = (tau + m * g * com_y * sin(q) - damping * v) / I
/// ```
///
/// `gravity_y` is a positive magnitude and `com_y` is typically negative
/// (COM below the pivot). With `L = -com_y`, the gravity term equals
/// `-m * g * L * sin(q)`, which is a restoring torque.
///
/// One invocation per world. Invocations with `world_idx >= nworld` return
/// early.
///
/// Bindings (group 0):
/// - `0`: `SimParams` uniform
/// - `1`: `BodyParams` uniform (`mass`, `inertia`, `com_y`, `damping`, `gravity_y`, padding to 32 bytes)
/// - `2`: `q` (read-only)
/// - `3`: `v` (read-only)
/// - `4`: `ctrl` applied torques (read-only)
/// - `5`: `qdd` output (read/write)
///
/// NOTE(review): each world's state is read/written at index `world_idx`,
/// which matches a layout with one DoF per world (nv == 1). Confirm the
/// indexing before using this with nv > 1.
pub const ABA_SIMPLE_SHADER: &str = r#"
// Forward dynamics for a single revolute joint per world.

struct SimParams {
    nworld: u32,
    nv: u32,
    dt: f32,
    _padding: u32,
}

struct BodyParams {
    mass: f32,
    inertia: f32,
    com_y: f32,
    damping: f32,
    gravity_y: f32,
    _padding0: f32,
    _padding1: f32,
    _padding2: f32,
}

@group(0) @binding(0) var<uniform> params: SimParams;
@group(0) @binding(1) var<uniform> body: BodyParams;
@group(0) @binding(2) var<storage, read> q: array<f32>;
@group(0) @binding(3) var<storage, read> v: array<f32>;
@group(0) @binding(4) var<storage, read> ctrl: array<f32>;
@group(0) @binding(5) var<storage, read_write> qdd: array<f32>;

@compute @workgroup_size(256)
fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
    let world_idx = gid.x;

    if (world_idx >= params.nworld) {
        return;
    }

    // Single revolute joint: qdd = (tau + m*g*com_y*sin(q) - damping*v) / I
    // (with com_y = -L this is the familiar (tau - m*g*L*sin(q) - damping*v) / I).
    let idx = world_idx;
    let q_val = q[idx];
    let v_val = v[idx];
    let tau = ctrl[idx];

    // Gravity torque about the joint axis. Assume planar rotation by q of a
    // link whose COM sits at (0, com_y) from the pivot, under gravity
    // (0, -gravity_y). Then r x F = m * gravity_y * com_y * sin(q).
    // gravity_y is a positive magnitude; com_y < 0 (hanging) makes this term
    // restoring.
    let gravity_torque = body.mass * body.gravity_y * body.com_y * sin(q_val);

    // Total torque: applied torque + gravity torque - viscous damping torque.
    let total_torque = tau + gravity_torque - body.damping * v_val;

    // Inertia about the joint axis, computed on the CPU
    // (e.g. I_com + m*L^2 by the parallel-axis theorem).
    let total_inertia = body.inertia;

    qdd[idx] = total_torque / total_inertia;
}
"#;