//! CUDA compute backend for the tile-based wave simulation
//! (ringkernel_wavesim/simulation/cuda_compute.rs).

use ringkernel_core::error::{Result, RingKernelError};
19use ringkernel_core::memory::GpuBuffer;
20use ringkernel_cuda::{CudaBuffer, CudaDevice};
21
22use super::gpu_backend::{Edge, FdtdParams, TileGpuBackend, TileGpuBuffers};
23
/// Returns the CUDA C source for the tile kernels.
///
/// With the `cuda-codegen` feature enabled, the source is produced at
/// runtime by the generator in `super::kernels`.
#[cfg(feature = "cuda-codegen")]
fn get_cuda_source() -> String {
    super::kernels::generate_tile_kernels()
}

/// Returns the CUDA C source for the tile kernels.
///
/// Without the `cuda-codegen` feature, a pre-written kernel file is
/// embedded into the binary at compile time.
#[cfg(not(feature = "cuda-codegen"))]
fn get_cuda_source() -> String {
    include_str!("../shaders/fdtd_tile.cu").to_string()
}
37
/// Name under which the compiled PTX module is registered on the device.
const MODULE_NAME: &str = "fdtd_tile";

// Kernel entry-point names; these must match the function names in the
// CUDA source returned by `get_cuda_source`.
/// One FDTD time step over a tile.
const FN_FDTD_STEP: &str = "fdtd_tile_step";
/// Copies one edge of the tile into a halo staging buffer.
const FN_EXTRACT_HALO: &str = "extract_halo";
/// Copies a halo staging buffer into one edge of the tile.
const FN_INJECT_HALO: &str = "inject_halo";
/// Copies the tile interior (without halo) to an output buffer.
const FN_READ_INTERIOR: &str = "read_interior";
46
/// CUDA implementation of the tile GPU backend.
///
/// Compiles the FDTD tile kernels with NVRTC at construction time and
/// loads them as a PTX module onto the selected device.
pub struct CudaTileBackend {
    // Handle to the CUDA device the kernels are loaded on.
    device: CudaDevice,
    // Interior tile edge length in cells; buffers add a 1-cell halo on
    // each side (see `buffer_width`).
    tile_size: u32,
}
56
57impl CudaTileBackend {
58 pub fn new(tile_size: u32) -> Result<Self> {
62 Self::with_device(0, tile_size)
63 }
64
65 pub fn with_device(ordinal: usize, tile_size: u32) -> Result<Self> {
67 let device = CudaDevice::new(ordinal)?;
68
69 tracing::info!(
70 "CUDA tile backend: {} (CC {}.{})",
71 device.name(),
72 device.compute_capability().0,
73 device.compute_capability().1
74 );
75
76 let cuda_source = get_cuda_source();
78 let ptx = cudarc::nvrtc::compile_ptx(&cuda_source).map_err(|e| {
79 RingKernelError::BackendError(format!("NVRTC compilation failed: {}", e))
80 })?;
81
82 device
83 .inner()
84 .load_ptx(
85 ptx,
86 MODULE_NAME,
87 &[
88 FN_FDTD_STEP,
89 FN_EXTRACT_HALO,
90 FN_INJECT_HALO,
91 FN_READ_INTERIOR,
92 ],
93 )
94 .map_err(|e| {
95 RingKernelError::BackendError(format!("Failed to load PTX module: {}", e))
96 })?;
97
98 Ok(Self { device, tile_size })
99 }
100
101 pub fn buffer_width(&self) -> u32 {
103 self.tile_size + 2
104 }
105
106 pub fn buffer_size_bytes(&self) -> usize {
108 let bw = self.buffer_width() as usize;
109 bw * bw * std::mem::size_of::<f32>()
110 }
111
112 pub fn halo_size_bytes(&self) -> usize {
114 self.tile_size as usize * std::mem::size_of::<f32>()
115 }
116
117 pub fn interior_size_bytes(&self) -> usize {
119 (self.tile_size * self.tile_size) as usize * std::mem::size_of::<f32>()
120 }
121
122 pub fn device(&self) -> &CudaDevice {
124 &self.device
125 }
126}
127
impl TileGpuBackend for CudaTileBackend {
    type Buffer = CudaBuffer;

    /// Allocates and zero-initializes the per-tile GPU buffers: two
    /// halo-padded pressure fields (ping-ponged between time steps) and
    /// four per-edge halo staging buffers.
    fn create_tile_buffers(&self, tile_size: u32) -> Result<TileGpuBuffers<Self::Buffer>> {
        // Each pressure buffer carries a 1-cell halo on every side.
        let buffer_width = tile_size + 2;
        let buffer_size = (buffer_width * buffer_width) as usize * std::mem::size_of::<f32>();
        let halo_size = tile_size as usize * std::mem::size_of::<f32>();

        let pressure_a = CudaBuffer::new(&self.device, buffer_size)?;
        let pressure_b = CudaBuffer::new(&self.device, buffer_size)?;

        let halo_north = CudaBuffer::new(&self.device, halo_size)?;
        let halo_south = CudaBuffer::new(&self.device, halo_size)?;
        let halo_west = CudaBuffer::new(&self.device, halo_size)?;
        let halo_east = CudaBuffer::new(&self.device, halo_size)?;

        // Zero-fill everything so the simulation starts from a quiescent
        // state regardless of what the allocator returned.
        let zeros_buffer = vec![0u8; buffer_size];
        let zeros_halo = vec![0u8; halo_size];

        pressure_a.copy_from_host(&zeros_buffer)?;
        pressure_b.copy_from_host(&zeros_buffer)?;
        halo_north.copy_from_host(&zeros_halo)?;
        halo_south.copy_from_host(&zeros_halo)?;
        halo_west.copy_from_host(&zeros_halo)?;
        halo_east.copy_from_host(&zeros_halo)?;

        Ok(TileGpuBuffers {
            pressure_a,
            pressure_b,
            halo_north,
            halo_south,
            halo_west,
            halo_east,
            current_is_a: true,
            tile_size,
            buffer_width,
        })
    }

    /// Uploads the initial pressure fields from host memory.
    ///
    /// `pressure` goes into buffer A (the "current" field, since
    /// `create_tile_buffers` sets `current_is_a: true`) and
    /// `pressure_prev` into buffer B. Both slices must be full
    /// halo-padded fields (`buffer_width * buffer_width` f32s) —
    /// the device copy will fail on a size mismatch.
    fn upload_initial_state(
        &self,
        buffers: &TileGpuBuffers<Self::Buffer>,
        pressure: &[f32],
        pressure_prev: &[f32],
    ) -> Result<()> {
        // Reinterpret the f32 slices as raw bytes for the device copy.
        let pressure_bytes: &[u8] = bytemuck::cast_slice(pressure);
        let pressure_prev_bytes: &[u8] = bytemuck::cast_slice(pressure_prev);

        buffers.pressure_a.copy_from_host(pressure_bytes)?;
        buffers.pressure_b.copy_from_host(pressure_prev_bytes)?;

        Ok(())
    }

    /// Launches one FDTD time step on the device.
    ///
    /// NOTE(review): the launch config is a single 16x16 block, so this
    /// only covers tiles up to 16x16 — confirm against the kernel source
    /// before using a larger `tile_size`.
    ///
    /// The launch is asynchronous; no synchronize here. Callers that read
    /// results back go through `read_interior_pressure`/`extract_halo`,
    /// which do synchronize.
    fn fdtd_step(&self, buffers: &TileGpuBuffers<Self::Buffer>, params: &FdtdParams) -> Result<()> {
        use cudarc::driver::LaunchAsync;

        // Select the ping-pong pair: `current` is the field being read,
        // `prev` the other buffer.
        let (current, prev) = if buffers.current_is_a {
            (&buffers.pressure_a, &buffers.pressure_b)
        } else {
            (&buffers.pressure_b, &buffers.pressure_a)
        };

        let kernel_fn = self
            .device
            .inner()
            .get_func(MODULE_NAME, FN_FDTD_STEP)
            .ok_or_else(|| RingKernelError::BackendError("FDTD kernel not found".to_string()))?;

        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (1, 1, 1),
            block_dim: (16, 16, 1),
            shared_mem_bytes: 0,
        };

        let current_ptr = current.device_ptr();
        let prev_ptr = prev.device_ptr();

        // SAFETY: argument tuple must match the kernel's signature —
        // assumed to be (current, prev, c2, damping); verify against
        // the CUDA source if the kernel changes.
        unsafe { kernel_fn.launch(cfg, (current_ptr, prev_ptr, params.c2, params.damping)) }
            .map_err(|e| {
                RingKernelError::BackendError(format!("FDTD kernel launch failed: {}", e))
            })?;

        Ok(())
    }

    /// Copies one edge of the current pressure field into the matching
    /// halo staging buffer on-device, then downloads it to the host.
    ///
    /// NOTE(review): block_dim is a fixed 16 threads — assumes
    /// tile_size <= 16; confirm against the kernel source.
    fn extract_halo(&self, buffers: &TileGpuBuffers<Self::Buffer>, edge: Edge) -> Result<Vec<f32>> {
        use cudarc::driver::LaunchAsync;

        let current = if buffers.current_is_a {
            &buffers.pressure_a
        } else {
            &buffers.pressure_b
        };

        // Reuse the preallocated per-edge staging buffer.
        let staging = buffers.halo_buffer(edge);

        let kernel_fn = self
            .device
            .inner()
            .get_func(MODULE_NAME, FN_EXTRACT_HALO)
            .ok_or_else(|| {
                RingKernelError::BackendError("Extract halo kernel not found".to_string())
            })?;

        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (1, 1, 1),
            block_dim: (16, 1, 1),
            shared_mem_bytes: 0,
        };

        let current_ptr = current.device_ptr();
        let staging_ptr = staging.device_ptr();
        // Edge is passed to the kernel as its integer discriminant.
        let edge_val = edge as i32;

        unsafe { kernel_fn.launch(cfg, (current_ptr, staging_ptr, edge_val)) }.map_err(|e| {
            RingKernelError::BackendError(format!("Extract halo kernel launch failed: {}", e))
        })?;

        // Ensure the kernel has finished before reading the staging
        // buffer back to the host.
        self.device.synchronize()?;

        let mut halo = vec![0u8; self.halo_size_bytes()];
        staging.copy_to_host(&mut halo)?;

        Ok(bytemuck::cast_slice(&halo).to_vec())
    }

    /// Uploads neighbor halo data and injects it into one edge of the
    /// current pressure field on-device.
    ///
    /// `data` must hold `tile_size` f32s. The launch is asynchronous;
    /// same fixed-16-thread caveat as `extract_halo`.
    fn inject_halo(
        &self,
        buffers: &TileGpuBuffers<Self::Buffer>,
        edge: Edge,
        data: &[f32],
    ) -> Result<()> {
        use cudarc::driver::LaunchAsync;

        let current = if buffers.current_is_a {
            &buffers.pressure_a
        } else {
            &buffers.pressure_b
        };

        let staging = buffers.halo_buffer(edge);

        // Host -> staging buffer, then a kernel scatters it into the
        // field's halo ring.
        let data_bytes: &[u8] = bytemuck::cast_slice(data);
        staging.copy_from_host(data_bytes)?;

        let kernel_fn = self
            .device
            .inner()
            .get_func(MODULE_NAME, FN_INJECT_HALO)
            .ok_or_else(|| {
                RingKernelError::BackendError("Inject halo kernel not found".to_string())
            })?;

        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (1, 1, 1),
            block_dim: (16, 1, 1),
            shared_mem_bytes: 0,
        };

        let current_ptr = current.device_ptr();
        let staging_ptr = staging.device_ptr();
        let edge_val = edge as i32;

        unsafe { kernel_fn.launch(cfg, (current_ptr, staging_ptr, edge_val)) }.map_err(|e| {
            RingKernelError::BackendError(format!("Inject halo kernel launch failed: {}", e))
        })?;

        Ok(())
    }

    /// Flips the ping-pong flag; no device work is issued.
    fn swap_buffers(&self, buffers: &mut TileGpuBuffers<Self::Buffer>) {
        buffers.current_is_a = !buffers.current_is_a;
    }

    /// Copies the interior (halo stripped) of the current pressure field
    /// into a fresh device buffer, synchronizes, and downloads it.
    ///
    /// Allocates a temporary `interior_size_bytes()` device buffer on
    /// every call; same single-block 16x16 launch caveat as `fdtd_step`.
    fn read_interior_pressure(&self, buffers: &TileGpuBuffers<Self::Buffer>) -> Result<Vec<f32>> {
        use cudarc::driver::LaunchAsync;

        let current = if buffers.current_is_a {
            &buffers.pressure_a
        } else {
            &buffers.pressure_b
        };

        let output_buffer = CudaBuffer::new(&self.device, self.interior_size_bytes())?;

        let kernel_fn = self
            .device
            .inner()
            .get_func(MODULE_NAME, FN_READ_INTERIOR)
            .ok_or_else(|| {
                RingKernelError::BackendError("Read interior kernel not found".to_string())
            })?;

        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (1, 1, 1),
            block_dim: (16, 16, 1),
            shared_mem_bytes: 0,
        };

        let current_ptr = current.device_ptr();
        let output_ptr = output_buffer.device_ptr();

        unsafe { kernel_fn.launch(cfg, (current_ptr, output_ptr)) }.map_err(|e| {
            RingKernelError::BackendError(format!("Read interior kernel launch failed: {}", e))
        })?;

        // Wait for the kernel before the device-to-host copy.
        self.device.synchronize()?;

        let mut output = vec![0u8; self.interior_size_bytes()];
        output_buffer.copy_to_host(&mut output)?;

        Ok(bytemuck::cast_slice(&output).to_vec())
    }

    /// Blocks until all previously issued device work has completed.
    fn synchronize(&self) -> Result<()> {
        self.device.synchronize()
    }
}
372
#[cfg(test)]
mod tests {
    use super::*;

    // All tests require CUDA hardware; run with `cargo test -- --ignored`.

    #[test]
    #[ignore]
    fn test_cuda_backend_creation() {
        let backend = CudaTileBackend::new(16).unwrap();
        assert_eq!(backend.tile_size, 16);
        assert_eq!(backend.buffer_width(), 18);
    }

    #[test]
    #[ignore]
    fn test_cuda_buffer_creation() {
        let backend = CudaTileBackend::new(16).unwrap();
        let buffers = backend.create_tile_buffers(16).unwrap();

        assert_eq!(buffers.tile_size, 16);
        assert_eq!(buffers.buffer_width, 18);
        assert!(buffers.current_is_a);
    }

    #[test]
    #[ignore]
    fn test_cuda_fdtd_step() {
        let backend = CudaTileBackend::new(16).unwrap();
        let mut buffers = backend.create_tile_buffers(16).unwrap();

        // 18x18 halo-padded field with a unit impulse at the center.
        let padded_len = 18 * 18;
        let mut pressure = vec![0.0f32; padded_len];
        let pressure_prev = vec![0.0f32; padded_len];
        pressure[9 * 18 + 9] = 1.0;

        backend
            .upload_initial_state(&buffers, &pressure, &pressure_prev)
            .unwrap();

        let params = FdtdParams::new(16, 0.25, 0.99);
        backend.fdtd_step(&buffers, &params).unwrap();
        backend.swap_buffers(&mut buffers);

        let result = backend.read_interior_pressure(&buffers).unwrap();

        // Padded coordinate (9, 9) maps to interior coordinate (8, 8).
        let center = 8 * 16 + 8;
        assert!(
            result[center].abs() < 1.0,
            "Center should have decreased"
        );
    }

    #[test]
    #[ignore]
    fn test_cuda_halo_exchange() {
        let backend = CudaTileBackend::new(16).unwrap();
        let buffers = backend.create_tile_buffers(16).unwrap();

        // Fill the topmost interior row (padded row 1) with 1..=16.
        let padded_len = 18 * 18;
        let mut pressure = vec![0.0f32; padded_len];
        for col in 0..16 {
            pressure[18 + (col + 1)] = (col + 1) as f32;
        }

        backend
            .upload_initial_state(&buffers, &pressure, &vec![0.0f32; padded_len])
            .unwrap();

        let halo = backend.extract_halo(&buffers, Edge::North).unwrap();

        assert_eq!(halo.len(), 16);
        for (i, &v) in halo.iter().enumerate() {
            assert_eq!(v, (i + 1) as f32, "Halo mismatch at {}", i);
        }
    }
}