1use bytemuck::{Pod, Zeroable};
6use std::sync::Arc;
7use phyz_model::{Model, State};
8
/// GPU-resident state for a batch of `nworld` independent simulation worlds.
///
/// All per-world arrays are stored flat and world-major: world `i`'s data
/// occupies elements `[i * n, (i + 1) * n)` of the corresponding buffer
/// (see `upload_states`). Readback of `q`/`v` goes through the MAP_READ
/// staging buffers (see `download_states`).
pub struct GpuState {
    // Device/queue handles shared with the rest of the pipeline.
    pub device: Arc<wgpu::Device>,
    pub queue: Arc<wgpu::Queue>,
    // Number of parallel worlds in the batch.
    pub nworld: usize,
    // Per-world position and velocity dimensions (copied from the Model).
    pub nq: usize,
    pub nv: usize,

    // Storage buffers: positions (nworld*nq floats), velocities and
    // controls (nworld*nv floats each).
    pub q_buffer: wgpu::Buffer,
    pub v_buffer: wgpu::Buffer,
    pub ctrl_buffer: wgpu::Buffer,

    // Accelerations; GPU-only (STORAGE usage, never copied to the host here).
    pub qdd_buffer: wgpu::Buffer,

    // MAP_READ staging buffers used to read q/v back to the host.
    pub q_staging: wgpu::Buffer,
    pub v_staging: wgpu::Buffer,
}
31
/// Fixed-capacity, C-layout snapshot of one world's state (16 float slots
/// each for q, v and ctrl), byte-castable via `bytemuck` (`Pod`/`Zeroable`).
///
/// NOTE(review): not referenced anywhere in this chunk — possibly used by
/// another file (e.g. as a shader-side layout) or dead code; confirm before
/// removing. The hard-coded 16 caps nq/nv if it is ever used for packing.
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct PackedState {
    q: [f32; 16], v: [f32; 16],
    ctrl: [f32; 16],
}
40
41impl GpuState {
42 pub fn new(
44 device: Arc<wgpu::Device>,
45 queue: Arc<wgpu::Queue>,
46 model: &Model,
47 nworld: usize,
48 ) -> Self {
49 let nq = model.nq;
50 let nv = model.nv;
51
52 let q_buffer = device.create_buffer(&wgpu::BufferDescriptor {
54 label: Some("q_buffer"),
55 size: (nworld * nq * std::mem::size_of::<f32>()) as u64,
56 usage: wgpu::BufferUsages::STORAGE
57 | wgpu::BufferUsages::COPY_DST
58 | wgpu::BufferUsages::COPY_SRC,
59 mapped_at_creation: false,
60 });
61
62 let v_buffer = device.create_buffer(&wgpu::BufferDescriptor {
63 label: Some("v_buffer"),
64 size: (nworld * nv * std::mem::size_of::<f32>()) as u64,
65 usage: wgpu::BufferUsages::STORAGE
66 | wgpu::BufferUsages::COPY_DST
67 | wgpu::BufferUsages::COPY_SRC,
68 mapped_at_creation: false,
69 });
70
71 let ctrl_buffer = device.create_buffer(&wgpu::BufferDescriptor {
72 label: Some("ctrl_buffer"),
73 size: (nworld * nv * std::mem::size_of::<f32>()) as u64,
74 usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_DST,
75 mapped_at_creation: false,
76 });
77
78 let qdd_buffer = device.create_buffer(&wgpu::BufferDescriptor {
79 label: Some("qdd_buffer"),
80 size: (nworld * nv * std::mem::size_of::<f32>()) as u64,
81 usage: wgpu::BufferUsages::STORAGE,
82 mapped_at_creation: false,
83 });
84
85 let q_staging = device.create_buffer(&wgpu::BufferDescriptor {
87 label: Some("q_staging"),
88 size: (nworld * nq * std::mem::size_of::<f32>()) as u64,
89 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
90 mapped_at_creation: false,
91 });
92
93 let v_staging = device.create_buffer(&wgpu::BufferDescriptor {
94 label: Some("v_staging"),
95 size: (nworld * nv * std::mem::size_of::<f32>()) as u64,
96 usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
97 mapped_at_creation: false,
98 });
99
100 Self {
101 device,
102 queue,
103 nworld,
104 nq,
105 nv,
106 q_buffer,
107 v_buffer,
108 ctrl_buffer,
109 qdd_buffer,
110 q_staging,
111 v_staging,
112 }
113 }
114
115 pub fn upload_states(&self, states: &[State]) {
117 assert_eq!(states.len(), self.nworld);
118
119 let mut q_data = vec![0.0f32; self.nworld * self.nq];
121 let mut v_data = vec![0.0f32; self.nworld * self.nv];
122 let mut ctrl_data = vec![0.0f32; self.nworld * self.nv];
123
124 for (i, state) in states.iter().enumerate() {
125 for j in 0..self.nq {
126 q_data[i * self.nq + j] = state.q[j] as f32;
127 }
128 for j in 0..self.nv {
129 v_data[i * self.nv + j] = state.v[j] as f32;
130 ctrl_data[i * self.nv + j] = state.ctrl[j] as f32;
131 }
132 }
133
134 self.queue
136 .write_buffer(&self.q_buffer, 0, bytemuck::cast_slice(&q_data));
137 self.queue
138 .write_buffer(&self.v_buffer, 0, bytemuck::cast_slice(&v_data));
139 self.queue
140 .write_buffer(&self.ctrl_buffer, 0, bytemuck::cast_slice(&ctrl_data));
141 }
142
143 pub async fn download_states(&self) -> Result<(Vec<f32>, Vec<f32>), String> {
145 let mut encoder = self
147 .device
148 .create_command_encoder(&wgpu::CommandEncoderDescriptor {
149 label: Some("download_encoder"),
150 });
151
152 encoder.copy_buffer_to_buffer(
153 &self.q_buffer,
154 0,
155 &self.q_staging,
156 0,
157 (self.nworld * self.nq * std::mem::size_of::<f32>()) as u64,
158 );
159
160 encoder.copy_buffer_to_buffer(
161 &self.v_buffer,
162 0,
163 &self.v_staging,
164 0,
165 (self.nworld * self.nv * std::mem::size_of::<f32>()) as u64,
166 );
167
168 self.queue.submit(Some(encoder.finish()));
169
170 let q_slice = self.q_staging.slice(..);
172 let v_slice = self.v_staging.slice(..);
173
174 let (q_tx, q_rx) = futures_intrusive::channel::shared::oneshot_channel();
175 let (v_tx, v_rx) = futures_intrusive::channel::shared::oneshot_channel();
176
177 q_slice.map_async(wgpu::MapMode::Read, move |result| {
178 q_tx.send(result).ok();
179 });
180
181 v_slice.map_async(wgpu::MapMode::Read, move |result| {
182 v_tx.send(result).ok();
183 });
184
185 self.device.poll(wgpu::Maintain::Wait);
186
187 q_rx.receive()
188 .await
189 .ok_or("Failed to map q buffer")?
190 .map_err(|e| format!("GPU buffer mapping failed: {:?}", e))?;
191 v_rx.receive()
192 .await
193 .ok_or("Failed to map v buffer")?
194 .map_err(|e| format!("GPU buffer mapping failed: {:?}", e))?;
195
196 let q_data = q_slice.get_mapped_range();
197 let v_data = v_slice.get_mapped_range();
198
199 let q_vec: Vec<f32> = bytemuck::cast_slice(&q_data).to_vec();
200 let v_vec: Vec<f32> = bytemuck::cast_slice(&v_data).to_vec();
201
202 drop(q_data);
203 drop(v_data);
204 self.q_staging.unmap();
205 self.v_staging.unmap();
206
207 Ok((q_vec, v_vec))
208 }
209}