phyz_gpu/gpu_state.rs

//! GPU state buffer management.
//!
//! Handles allocation and synchronization of GPU buffers for batch simulation.
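//!
//! The typical round trip is [`GpuState::upload_states`] → compute passes that
//! read and write the storage buffers → [`GpuState::download_states`].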

use bytemuck::{Pod, Zeroable};
use phyz_model::{Model, State};
use std::sync::Arc;

/// GPU-backed batch simulation state.
///
/// Stores state for `nworld` parallel environments in GPU memory.
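///
/// Buffers are packed world-major: element `j` of world `i` lives at flat
/// index `i * nq + j` (or `i * nv + j`), matching [`Self::upload_states`].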
pub struct GpuState {
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    pub nworld: usize,
    pub nq: usize,
    pub nv: usize,

    // State buffers (nworld × ndof)
    pub q_buffer: wgpu::Buffer,
    pub v_buffer: wgpu::Buffer,
    pub ctrl_buffer: wgpu::Buffer,

    // Scratch buffers for computation
    pub qdd_buffer: wgpu::Buffer,

    // Staging buffers for CPU ↔ GPU transfer
    pub q_staging: wgpu::Buffer,
    pub v_staging: wgpu::Buffer,
}

/// GPU-friendly packed state data for a single environment.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct PackedState {
    q: [f32; 16], // max 16 DOF for simplicity
    v: [f32; 16],
    ctrl: [f32; 16],
}

impl GpuState {
    /// Create GPU buffers for batch simulation.
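    ///
    /// # Examples
    ///
    /// A minimal sketch, assuming a `device`/`queue` pair and a loaded `Model`
    /// already exist (obtaining them is outside this module's scope):
    ///
    /// ```ignore
    /// let gpu = GpuState::new(device.clone(), queue.clone(), &model, 4096);
    /// assert_eq!(gpu.nworld, 4096);
    /// ```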
    pub fn new(
        device: Arc<wgpu::Device>,
        queue: Arc<wgpu::Queue>,
        model: &Model,
        nworld: usize,
    ) -> Self {
        let nq = model.nq;
        let nv = model.nv;

        // Create buffers with STORAGE usage for compute shaders
        let q_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("q_buffer"),
            size: (nworld * nq * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_DST
                | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let v_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("v_buffer"),
            size: (nworld * nv * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_DST
                | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let ctrl_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("ctrl_buffer"),
            size: (nworld * nv * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let qdd_buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("qdd_buffer"),
            size: (nworld * nv * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE,
            mapped_at_creation: false,
        });

        // Staging buffers for readback. `MAP_READ` may only be combined with
        // `COPY_DST` (absent `Features::MAPPABLE_PRIMARY_BUFFERS`), so readback
        // goes through dedicated buffers rather than mapping storage directly.
        let q_staging = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("q_staging"),
            size: (nworld * nq * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let v_staging = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("v_staging"),
            size: (nworld * nv * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        Self {
            device,
            queue,
            nworld,
            nq,
            nv,
            q_buffer,
            v_buffer,
            ctrl_buffer,
            qdd_buffer,
            q_staging,
            v_staging,
        }
    }

    /// Upload states from CPU to GPU.
    ///
    /// Values are narrowed to `f32` for GPU storage.
    ///
    /// # Panics
    ///
    /// Panics if `states.len() != self.nworld`.
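    ///
    /// # Examples
    ///
    /// A minimal sketch; how a `phyz_model::State` is constructed is an
    /// assumption here (`State::default()` may not exist in that crate):
    ///
    /// ```ignore
    /// let states = vec![State::default(); gpu.nworld];
    /// gpu.upload_states(&states);
    /// ```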
    pub fn upload_states(&self, states: &[State]) {
        assert_eq!(states.len(), self.nworld);

        // Pack states into flat f32 arrays
        let mut q_data = vec![0.0f32; self.nworld * self.nq];
        let mut v_data = vec![0.0f32; self.nworld * self.nv];
        let mut ctrl_data = vec![0.0f32; self.nworld * self.nv];

        for (i, state) in states.iter().enumerate() {
            for j in 0..self.nq {
                q_data[i * self.nq + j] = state.q[j] as f32;
            }
            for j in 0..self.nv {
                v_data[i * self.nv + j] = state.v[j] as f32;
                ctrl_data[i * self.nv + j] = state.ctrl[j] as f32;
            }
        }

        // Upload to GPU
        self.queue
            .write_buffer(&self.q_buffer, 0, bytemuck::cast_slice(&q_data));
        self.queue
            .write_buffer(&self.v_buffer, 0, bytemuck::cast_slice(&v_data));
        self.queue
            .write_buffer(&self.ctrl_buffer, 0, bytemuck::cast_slice(&ctrl_data));
    }

    /// Download states from GPU to CPU.
    ///
    /// Returns the flattened `(q, v)` arrays, laid out world-major.
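    ///
    /// # Examples
    ///
    /// A minimal sketch; blocking on the future with `pollster` is an
    /// assumption, not a dependency this crate is known to declare:
    ///
    /// ```ignore
    /// let (q, v) = pollster::block_on(gpu.download_states())?;
    /// assert_eq!(q.len(), gpu.nworld * gpu.nq);
    /// assert_eq!(v.len(), gpu.nworld * gpu.nv);
    /// ```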
    pub async fn download_states(&self) -> Result<(Vec<f32>, Vec<f32>), String> {
        // Copy from storage buffers to staging buffers
        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("download_encoder"),
            });

        encoder.copy_buffer_to_buffer(
            &self.q_buffer,
            0,
            &self.q_staging,
            0,
            (self.nworld * self.nq * std::mem::size_of::<f32>()) as u64,
        );

        encoder.copy_buffer_to_buffer(
            &self.v_buffer,
            0,
            &self.v_staging,
            0,
            (self.nworld * self.nv * std::mem::size_of::<f32>()) as u64,
        );

        self.queue.submit(Some(encoder.finish()));

        // Map and read staging buffers
        let q_slice = self.q_staging.slice(..);
        let v_slice = self.v_staging.slice(..);

        let (q_tx, q_rx) = futures_intrusive::channel::shared::oneshot_channel();
        let (v_tx, v_rx) = futures_intrusive::channel::shared::oneshot_channel();

        q_slice.map_async(wgpu::MapMode::Read, move |result| {
            q_tx.send(result).ok();
        });

        v_slice.map_async(wgpu::MapMode::Read, move |result| {
            v_tx.send(result).ok();
        });

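        // Block until the copies and map requests complete; the `map_async`
        // callbacks above fire during this poll.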
        self.device.poll(wgpu::Maintain::Wait);

        q_rx.receive()
            .await
            .ok_or("Failed to map q buffer")?
            .map_err(|e| format!("GPU buffer mapping failed: {:?}", e))?;
        v_rx.receive()
            .await
            .ok_or("Failed to map v buffer")?
            .map_err(|e| format!("GPU buffer mapping failed: {:?}", e))?;

        let q_data = q_slice.get_mapped_range();
        let v_data = v_slice.get_mapped_range();

        let q_vec: Vec<f32> = bytemuck::cast_slice(&q_data).to_vec();
        let v_vec: Vec<f32> = bytemuck::cast_slice(&v_data).to_vec();

        // The mapped views must be dropped before the buffers can be unmapped
        drop(q_data);
        drop(v_data);
        self.q_staging.unmap();
        self.v_staging.unmap();

        Ok((q_vec, v_vec))
    }
}