// phyz_gpu/gpu_simulator.rs

1//! GPU-accelerated batch simulator.
2//!
3//! Orchestrates compute pipelines for parallel simulation of multiple environments.
4
5use crate::gpu_state::GpuState;
6use crate::shaders::{ABA_SIMPLE_SHADER, INTEGRATE_SHADER};
7use bytemuck::{Pod, Zeroable};
8use std::sync::Arc;
9use phyz_model::{Model, State};
10
11/// Parameters passed to compute shaders.
12#[repr(C)]
13#[derive(Copy, Clone, Pod, Zeroable)]
14struct SimParams {
15    nworld: u32,
16    nv: u32,
17    dt: f32,
18    _padding: u32,
19}
20
21/// Body parameters for simple pendulum dynamics.
22#[repr(C)]
23#[derive(Copy, Clone, Pod, Zeroable)]
24struct BodyParams {
25    mass: f32,
26    inertia: f32,
27    com_y: f32,
28    damping: f32,
29    gravity_y: f32,
30    _padding: [f32; 3],
31}
32
33/// GPU-accelerated batch simulator.
34pub struct GpuSimulator {
35    pub device: Arc<wgpu::Device>,
36    pub queue: Arc<wgpu::Queue>,
37    pub state: GpuState,
38    pub model: Model,
39
40    // Compute pipelines
41    aba_pipeline: wgpu::ComputePipeline,
42    integrate_pipeline: wgpu::ComputePipeline,
43
44    // Bind groups
45    aba_bind_group: wgpu::BindGroup,
46    integrate_bind_group: wgpu::BindGroup,
47
48    // Uniform buffers (kept for potential future updates)
49    _sim_params_buffer: wgpu::Buffer,
50    _body_params_buffer: wgpu::Buffer,
51}
52
53impl GpuSimulator {
54    /// Create a new GPU simulator for batch simulation.
55    pub fn new(model: Model, nworld: usize) -> Result<Self, String> {
56        // Initialize wgpu
57        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor {
58            backends: wgpu::Backends::all(),
59            ..Default::default()
60        });
61
62        let adapter = pollster::block_on(instance.request_adapter(&wgpu::RequestAdapterOptions {
63            power_preference: wgpu::PowerPreference::HighPerformance,
64            compatible_surface: None,
65            force_fallback_adapter: false,
66        }))
67        .ok_or("Failed to find GPU adapter")?;
68
69        let (device, queue) = pollster::block_on(adapter.request_device(
70            &wgpu::DeviceDescriptor {
71                label: Some("phyz-gpu-device"),
72                required_features: wgpu::Features::empty(),
73                required_limits: wgpu::Limits::default(),
74                memory_hints: Default::default(),
75            },
76            None,
77        ))
78        .map_err(|e| format!("Failed to create device: {}", e))?;
79
80        let device = Arc::new(device);
81        let queue = Arc::new(queue);
82
83        // Create GPU state buffers
84        let state = GpuState::new(device.clone(), queue.clone(), &model, nworld);
85
86        // Create uniform buffers
87        let sim_params = SimParams {
88            nworld: nworld as u32,
89            nv: model.nv as u32,
90            dt: model.dt as f32,
91            _padding: 0,
92        };
93
94        let sim_params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
95            label: Some("sim_params"),
96            size: std::mem::size_of::<SimParams>() as u64,
97            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
98            mapped_at_creation: false,
99        });
100
101        queue.write_buffer(&sim_params_buffer, 0, bytemuck::bytes_of(&sim_params));
102
103        // Extract body parameters from model (assuming single-body pendulum)
104        let body_params = if !model.bodies.is_empty() {
105            let body = &model.bodies[0];
106            let joint = &model.joints[body.joint_idx];
107
108            // For revolute joint about Z axis: total inertia = I_zz + m*(com.x² + com.y²)
109            let i_zz = body.inertia.inertia[(2, 2)];
110            let com_xy_sq =
111                body.inertia.com.x * body.inertia.com.x + body.inertia.com.y * body.inertia.com.y;
112            let total_inertia = i_zz + body.inertia.mass * com_xy_sq;
113
114            BodyParams {
115                mass: body.inertia.mass as f32,
116                inertia: total_inertia as f32,
117                com_y: body.inertia.com.y as f32,
118                damping: joint.damping as f32,
119                gravity_y: model.gravity.y.abs() as f32,
120                _padding: [0.0; 3],
121            }
122        } else {
123            BodyParams {
124                mass: 1.0,
125                inertia: 1.0,
126                com_y: -0.5,
127                damping: 0.0,
128                gravity_y: 9.81,
129                _padding: [0.0; 3],
130            }
131        };
132
133        let body_params_buffer = device.create_buffer(&wgpu::BufferDescriptor {
134            label: Some("body_params"),
135            size: std::mem::size_of::<BodyParams>() as u64,
136            usage: wgpu::BufferUsages::UNIFORM | wgpu::BufferUsages::COPY_DST,
137            mapped_at_creation: false,
138        });
139
140        queue.write_buffer(&body_params_buffer, 0, bytemuck::bytes_of(&body_params));
141
142        // Create shader modules
143        let aba_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
144            label: Some("aba_shader"),
145            source: wgpu::ShaderSource::Wgsl(ABA_SIMPLE_SHADER.into()),
146        });
147
148        let integrate_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
149            label: Some("integrate_shader"),
150            source: wgpu::ShaderSource::Wgsl(INTEGRATE_SHADER.into()),
151        });
152
153        // Create bind group layouts
154        let aba_bind_group_layout =
155            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
156                label: Some("aba_bind_group_layout"),
157                entries: &[
158                    wgpu::BindGroupLayoutEntry {
159                        binding: 0,
160                        visibility: wgpu::ShaderStages::COMPUTE,
161                        ty: wgpu::BindingType::Buffer {
162                            ty: wgpu::BufferBindingType::Uniform,
163                            has_dynamic_offset: false,
164                            min_binding_size: None,
165                        },
166                        count: None,
167                    },
168                    wgpu::BindGroupLayoutEntry {
169                        binding: 1,
170                        visibility: wgpu::ShaderStages::COMPUTE,
171                        ty: wgpu::BindingType::Buffer {
172                            ty: wgpu::BufferBindingType::Uniform,
173                            has_dynamic_offset: false,
174                            min_binding_size: None,
175                        },
176                        count: None,
177                    },
178                    wgpu::BindGroupLayoutEntry {
179                        binding: 2,
180                        visibility: wgpu::ShaderStages::COMPUTE,
181                        ty: wgpu::BindingType::Buffer {
182                            ty: wgpu::BufferBindingType::Storage { read_only: true },
183                            has_dynamic_offset: false,
184                            min_binding_size: None,
185                        },
186                        count: None,
187                    },
188                    wgpu::BindGroupLayoutEntry {
189                        binding: 3,
190                        visibility: wgpu::ShaderStages::COMPUTE,
191                        ty: wgpu::BindingType::Buffer {
192                            ty: wgpu::BufferBindingType::Storage { read_only: true },
193                            has_dynamic_offset: false,
194                            min_binding_size: None,
195                        },
196                        count: None,
197                    },
198                    wgpu::BindGroupLayoutEntry {
199                        binding: 4,
200                        visibility: wgpu::ShaderStages::COMPUTE,
201                        ty: wgpu::BindingType::Buffer {
202                            ty: wgpu::BufferBindingType::Storage { read_only: true },
203                            has_dynamic_offset: false,
204                            min_binding_size: None,
205                        },
206                        count: None,
207                    },
208                    wgpu::BindGroupLayoutEntry {
209                        binding: 5,
210                        visibility: wgpu::ShaderStages::COMPUTE,
211                        ty: wgpu::BindingType::Buffer {
212                            ty: wgpu::BufferBindingType::Storage { read_only: false },
213                            has_dynamic_offset: false,
214                            min_binding_size: None,
215                        },
216                        count: None,
217                    },
218                ],
219            });
220
221        let integrate_bind_group_layout =
222            device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
223                label: Some("integrate_bind_group_layout"),
224                entries: &[
225                    wgpu::BindGroupLayoutEntry {
226                        binding: 0,
227                        visibility: wgpu::ShaderStages::COMPUTE,
228                        ty: wgpu::BindingType::Buffer {
229                            ty: wgpu::BufferBindingType::Uniform,
230                            has_dynamic_offset: false,
231                            min_binding_size: None,
232                        },
233                        count: None,
234                    },
235                    wgpu::BindGroupLayoutEntry {
236                        binding: 1,
237                        visibility: wgpu::ShaderStages::COMPUTE,
238                        ty: wgpu::BindingType::Buffer {
239                            ty: wgpu::BufferBindingType::Storage { read_only: false },
240                            has_dynamic_offset: false,
241                            min_binding_size: None,
242                        },
243                        count: None,
244                    },
245                    wgpu::BindGroupLayoutEntry {
246                        binding: 2,
247                        visibility: wgpu::ShaderStages::COMPUTE,
248                        ty: wgpu::BindingType::Buffer {
249                            ty: wgpu::BufferBindingType::Storage { read_only: false },
250                            has_dynamic_offset: false,
251                            min_binding_size: None,
252                        },
253                        count: None,
254                    },
255                    wgpu::BindGroupLayoutEntry {
256                        binding: 3,
257                        visibility: wgpu::ShaderStages::COMPUTE,
258                        ty: wgpu::BindingType::Buffer {
259                            ty: wgpu::BufferBindingType::Storage { read_only: true },
260                            has_dynamic_offset: false,
261                            min_binding_size: None,
262                        },
263                        count: None,
264                    },
265                ],
266            });
267
268        // Create pipelines
269        let aba_pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
270            label: Some("aba_pipeline_layout"),
271            bind_group_layouts: &[&aba_bind_group_layout],
272            push_constant_ranges: &[],
273        });
274
275        let aba_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
276            label: Some("aba_pipeline"),
277            layout: Some(&aba_pipeline_layout),
278            module: &aba_module,
279            entry_point: Some("main"),
280            compilation_options: Default::default(),
281            cache: None,
282        });
283
284        let integrate_pipeline_layout =
285            device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
286                label: Some("integrate_pipeline_layout"),
287                bind_group_layouts: &[&integrate_bind_group_layout],
288                push_constant_ranges: &[],
289            });
290
291        let integrate_pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
292            label: Some("integrate_pipeline"),
293            layout: Some(&integrate_pipeline_layout),
294            module: &integrate_module,
295            entry_point: Some("main"),
296            compilation_options: Default::default(),
297            cache: None,
298        });
299
300        // Create bind groups
301        let aba_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
302            label: Some("aba_bind_group"),
303            layout: &aba_bind_group_layout,
304            entries: &[
305                wgpu::BindGroupEntry {
306                    binding: 0,
307                    resource: sim_params_buffer.as_entire_binding(),
308                },
309                wgpu::BindGroupEntry {
310                    binding: 1,
311                    resource: body_params_buffer.as_entire_binding(),
312                },
313                wgpu::BindGroupEntry {
314                    binding: 2,
315                    resource: state.q_buffer.as_entire_binding(),
316                },
317                wgpu::BindGroupEntry {
318                    binding: 3,
319                    resource: state.v_buffer.as_entire_binding(),
320                },
321                wgpu::BindGroupEntry {
322                    binding: 4,
323                    resource: state.ctrl_buffer.as_entire_binding(),
324                },
325                wgpu::BindGroupEntry {
326                    binding: 5,
327                    resource: state.qdd_buffer.as_entire_binding(),
328                },
329            ],
330        });
331
332        let integrate_bind_group = device.create_bind_group(&wgpu::BindGroupDescriptor {
333            label: Some("integrate_bind_group"),
334            layout: &integrate_bind_group_layout,
335            entries: &[
336                wgpu::BindGroupEntry {
337                    binding: 0,
338                    resource: sim_params_buffer.as_entire_binding(),
339                },
340                wgpu::BindGroupEntry {
341                    binding: 1,
342                    resource: state.q_buffer.as_entire_binding(),
343                },
344                wgpu::BindGroupEntry {
345                    binding: 2,
346                    resource: state.v_buffer.as_entire_binding(),
347                },
348                wgpu::BindGroupEntry {
349                    binding: 3,
350                    resource: state.qdd_buffer.as_entire_binding(),
351                },
352            ],
353        });
354
355        Ok(Self {
356            device,
357            queue,
358            state,
359            model,
360            aba_pipeline,
361            integrate_pipeline,
362            aba_bind_group,
363            integrate_bind_group,
364            _sim_params_buffer: sim_params_buffer,
365            _body_params_buffer: body_params_buffer,
366        })
367    }
368
369    /// Upload initial states to GPU.
370    pub fn load_states(&self, states: &[State]) {
371        self.state.upload_states(states);
372    }
373
374    /// Run one simulation step on GPU.
375    pub fn step(&self) {
376        let mut encoder = self
377            .device
378            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
379                label: Some("step_encoder"),
380            });
381
382        // Compute pass 1: ABA (compute accelerations)
383        {
384            let mut aba_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
385                label: Some("aba_pass"),
386                timestamp_writes: None,
387            });
388            aba_pass.set_pipeline(&self.aba_pipeline);
389            aba_pass.set_bind_group(0, &self.aba_bind_group, &[]);
390            let workgroups = (self.state.nworld as u32).div_ceil(256);
391            aba_pass.dispatch_workgroups(workgroups, 1, 1);
392        }
393
394        // Compute pass 2: Integration
395        {
396            let mut integrate_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
397                label: Some("integrate_pass"),
398                timestamp_writes: None,
399            });
400            integrate_pass.set_pipeline(&self.integrate_pipeline);
401            integrate_pass.set_bind_group(0, &self.integrate_bind_group, &[]);
402            let total_dofs = (self.state.nworld * self.state.nv) as u32;
403            let workgroups = total_dofs.div_ceil(256);
404            integrate_pass.dispatch_workgroups(workgroups, 1, 1);
405        }
406
407        self.queue.submit(Some(encoder.finish()));
408    }
409
410    /// Download states from GPU to CPU.
411    pub fn readback_states(&self) -> Vec<State> {
412        let (q_data, v_data) =
413            pollster::block_on(self.state.download_states()).expect("Failed to download states");
414
415        let mut states = Vec::new();
416        for i in 0..self.state.nworld {
417            let mut state = self.model.default_state();
418            for j in 0..self.state.nq {
419                state.q[j] = q_data[i * self.state.nq + j] as f64;
420            }
421            for j in 0..self.state.nv {
422                state.v[j] = v_data[i * self.state.nv + j] as f64;
423            }
424            states.push(state);
425        }
426
427        states
428    }
429}