1use crate::gpu_state::GpuState;
6use crate::shaders::{ABA_SIMPLE_SHADER, INTEGRATE_SHADER};
7use bytemuck::{Pod, Zeroable};
8use std::sync::Arc;
9use phyz_model::{Model, State};
10
/// Per-run simulation parameters, uploaded once as a uniform buffer
/// (binding 0 in both compute passes).
///
/// `#[repr(C)]` + `Pod`/`Zeroable` allow byte-casting for the GPU upload;
/// the field order and padding must match the WGSL uniform declaration.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct SimParams {
    nworld: u32,   // number of independent simulation worlds run in parallel
    nv: u32,       // velocity-space degrees of freedom per world
    dt: f32,       // integration timestep (seconds), narrowed from the model's f64 dt
    _padding: u32, // pads the struct to 16 bytes for GPU-side uniform layout
}
20
/// Physical parameters of the single simulated body, uploaded as a uniform
/// buffer for the simplified ABA shader (binding 1 of the ABA pass).
///
/// `#[repr(C)]` + `Pod`/`Zeroable` allow byte-casting for the GPU upload;
/// layout must match the WGSL uniform declaration.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct BodyParams {
    mass: f32,
    inertia: f32,       // effective rotational inertia about the joint (parallel-axis corrected)
    com_y: f32,         // center-of-mass y offset in the body frame
    damping: f32,       // joint damping coefficient
    gravity_y: f32,     // gravity magnitude (stored positive; sign applied shader-side — see construction)
    _padding: [f32; 3], // pads the struct to 32 bytes for GPU-side uniform layout
}
32
/// GPU-backed batch simulator: steps `nworld` independent copies of one
/// `Model` with two compute passes per step (simplified ABA forward
/// dynamics, then explicit integration).
pub struct GpuSimulator {
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub state: GpuState, // per-world q/v/ctrl/qdd buffers living on the GPU
    pub model: Model,

    // Compute pipelines for the two per-step passes.
    aba_pipeline: wgpu::ComputePipeline,
    integrate_pipeline: wgpu::ComputePipeline,

    // Bind groups wiring the uniform/state buffers into each pass.
    aba_bind_group: wgpu::BindGroup,
    integrate_bind_group: wgpu::BindGroup,

    // Never read after construction, but must stay alive: the bind groups
    // above reference these uniform buffers.
    _sim_params_buffer: wgpu::Buffer,
    _body_params_buffer: wgpu::Buffer,
}
52
impl GpuSimulator {
    /// Creates a simulator for `nworld` parallel copies of `model`.
    ///
    /// Acquires a GPU adapter and device (blocking via `pollster`), uploads
    /// the simulation and body parameters as uniform buffers, compiles the
    /// ABA and integration WGSL shaders, and wires bind groups over the
    /// shared `GpuState` buffers.
    ///
    /// # Errors
    /// Returns a `String` message if no GPU adapter is found or device
    /// creation fails.
    pub fn new(model: Model, nworld: usize) -> Result<Self, String> {
        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
            backends: wgpu::Backends::all(),
            ..Default::default()
        });

        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
            power_preference: wgpu::PowerPreference::HighPerformance,
            compatible_surface: None, // headless compute — no window surface needed
            force_fallback_adapter: false,
        }))
        .ok_or("Failed to find GPU adapter")?;

        let (device, queue) = pollster::block_on(adapter.request_device(
            &wgpu::DeviceDescriptor {
                label: Some("phyz-gpu-device"),
                required_features: wgpu::Features::empty(),
                required_limits: wgpu::Limits::default(),
                memory_hints: Default::default(),
            },
            None,
        ))
        .map_err(|e| format!("Failed to create device: {}", e))?;

        let device = Arc::new(device);
        let queue = Arc::new(queue);

        // Allocates the per-world q/v/ctrl/qdd GPU buffers.
        let state = GpuState::new(device.clone(), queue.clone(), &model, nworld);

        let sim_params = SimParams {
            nworld: nworld as u32,
            nv: model.nv as u32,
            dt: model.dt as f32,
            _padding: 0,
        };

        let sim_params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("sim_params"),
            size: std::mem::size_of::<SimParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        queue.write_buffer(&sim_params_buffer, 0, bytemuck::bytes_of(&sim_params));

        // The simplified ABA shader models a single body, so only the first
        // body/joint of the model is uploaded; an empty model falls back to
        // unit-pendulum defaults.
        let body_params = if !model.bodies.is_empty() {
            let body = &model.bodies[0];
            let joint = &model.joints[body.joint_idx];

            // Parallel-axis theorem: effective inertia about the joint axis
            // = I_zz + m * (com_x^2 + com_y^2).
            let i_zz = body.inertia.inertia[(2, 2)];
            let com_xy_sq =
                body.inertia.com.x * body.inertia.com.x + body.inertia.com.y * body.inertia.com.y;
            let total_inertia = i_zz + body.inertia.mass * com_xy_sq;

            BodyParams {
                mass: body.inertia.mass as f32,
                inertia: total_inertia as f32,
                com_y: body.inertia.com.y as f32,
                // Magnitude only; NOTE(review): the shader presumably applies
                // the -y sign itself — confirm against ABA_SIMPLE_SHADER.
                gravity_y: model.gravity.y.abs() as f32,
                damping: joint.damping as f32,
                _padding: [0.0; 3],
            }
        } else {
            // Fallback: unit-mass, unit-inertia pendulum under standard gravity.
            BodyParams {
                mass: 1.0,
                inertia: 1.0,
                com_y: -0.5,
                damping: 0.0,
                gravity_y: 9.81,
                _padding: [0.0; 3],
            }
        };

        let body_params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("body_params"),
            size: std::mem::size_of::<BodyParams>() as u64,
            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        queue.write_buffer(&body_params_buffer, 0, bytemuck::bytes_of(&body_params));

        let aba_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("aba_shader"),
            source: wgpu::ShaderSource::Wgsl(ABA_SIMPLE_SHADER.into()),
        });

        let integrate_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("integrate_shader"),
            source: wgpu::ShaderSource::Wgsl(INTEGRATE_SHADER.into()),
        });

        // ABA pass bindings (must mirror the shader's declarations):
        //   0: sim params (uniform)    1: body params (uniform)
        //   2: q (read)  3: v (read)  4: ctrl (read)  5: qdd (write)
        let aba_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("aba_bind_group_layout"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 3,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 4,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 5,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

        // Integration pass bindings:
        //   0: sim params (uniform)
        //   1: q (read-write)  2: v (read-write)  3: qdd (read)
        let integrate_bind_group_layout =
            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
                label: Some("integrate_bind_group_layout"),
                entries: &[
                    wgpu::BindGroupLayoutEntry {
                        binding: 0,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Uniform,
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 1,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 2,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: false },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                    wgpu::BindGroupLayoutEntry {
                        binding: 3,
                        visibility: wgpu::ShaderStages::COMPUTE,
                        ty: wgpu::BindingType::Buffer {
                            ty: wgpu::BufferBindingType::Storage { read_only: true },
                            has_dynamic_offset: false,
                            min_binding_size: None,
                        },
                        count: None,
                    },
                ],
            });

        let aba_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("aba_pipeline_layout"),
            bind_group_layouts: &[&aba_bind_group_layout],
            push_constant_ranges: &[],
        });

        let aba_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("aba_pipeline"),
            layout: Some(&aba_pipeline_layout),
            module: &aba_module,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        let integrate_pipeline_layout =
            device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
                label: Some("integrate_pipeline_layout"),
                bind_group_layouts: &[&integrate_bind_group_layout],
                push_constant_ranges: &[],
            });

        let integrate_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("integrate_pipeline"),
            layout: Some(&integrate_pipeline_layout),
            module: &integrate_module,
            entry_point: Some("main"),
            compilation_options: Default::default(),
            cache: None,
        });

        // Bind group entries must line up 1:1 with the layout entries above.
        let aba_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("aba_bind_group"),
            layout: &aba_bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: sim_params_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: body_params_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: state.q_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: state.v_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 4,
                    resource: state.ctrl_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 5,
                    resource: state.qdd_buffer.as_entire_binding(),
                },
            ],
        });

        let integrate_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("integrate_bind_group"),
            layout: &integrate_bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: sim_params_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: state.q_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 2,
                    resource: state.v_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 3,
                    resource: state.qdd_buffer.as_entire_binding(),
                },
            ],
        });

        Ok(Self {
            device,
            queue,
            state,
            model,
            aba_pipeline,
            integrate_pipeline,
            aba_bind_group,
            integrate_bind_group,
            _sim_params_buffer: sim_params_buffer,
            _body_params_buffer: body_params_buffer,
        })
    }

    /// Uploads host-side per-world states into the GPU buffers.
    ///
    /// NOTE(review): any `states.len()` vs `nworld` mismatch handling is
    /// delegated to `GpuState::upload_states` — confirm its contract there.
    pub fn load_states(&self, states: &[State]) {
        self.state.upload_states(states);
    }

    /// Advances every world by one timestep: records an ABA pass (computes
    /// accelerations into `qdd`) followed by an integration pass (updates
    /// `q`/`v`), then submits both in a single command buffer.
    ///
    /// Submission is asynchronous; results are observed via
    /// [`Self::readback_states`].
    pub fn step(&self) {
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("step_encoder"),
            });

        {
            // One thread per world; div_ceil(256) assumes the shader's
            // @workgroup_size is 256 — keep in sync with ABA_SIMPLE_SHADER.
            let mut aba_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("aba_pass"),
                timestamp_writes: None,
            });
            aba_pass.set_pipeline(&self.aba_pipeline);
            aba_pass.set_bind_group(0, &self.aba_bind_group, &[]);
            let workgroups = (self.state.nworld as u32).div_ceil(256);
            aba_pass.dispatch_workgroups(workgroups, 1, 1);
        }

        {
            // One thread per degree of freedom across all worlds.
            let mut integrate_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("integrate_pass"),
                timestamp_writes: None,
            });
            integrate_pass.set_pipeline(&self.integrate_pipeline);
            integrate_pass.set_bind_group(0, &self.integrate_bind_group, &[]);
            let total_dofs = (self.state.nworld * self.state.nv) as u32;
            let workgroups = total_dofs.div_ceil(256);
            integrate_pass.dispatch_workgroups(workgroups, 1, 1);
        }

        self.queue.submit(Some(encoder.finish()));
    }

    /// Downloads `q` and `v` from the GPU and reassembles one `State` per
    /// world (widening the GPU's f32 values back to f64).
    ///
    /// # Panics
    /// Panics if the GPU readback fails.
    pub fn readback_states(&self) -> Vec<State> {
        let (q_data, v_data) =
            pollster::block_on(self.state.download_states()).expect("Failed to download states");

        let mut states = Vec::new();
        for i in 0..self.state.nworld {
            // Start from the model's default state so any fields not stored
            // on the GPU keep their defaults.
            let mut state = self.model.default_state();
            for j in 0..self.state.nq {
                state.q[j] = q_data[i * self.state.nq + j] as f64;
            }
            for j in 0..self.state.nv {
                state.v[j] = v_data[i * self.state.nv + j] as f64;
            }
            states.push(state);
        }

        states
    }
}