ringkernel_wavesim/simulation/cuda_compute.rs

//! CUDA backend implementation for tile-based FDTD simulation.
//!
//! This module implements the `TileGpuBackend` trait for NVIDIA CUDA GPUs,
//! providing GPU-resident tile buffers with minimal host transfers.
//!
//! ## Features
//!
//! - PTX compilation via NVRTC
//! - Persistent GPU buffers (pressure stays on GPU)
//! - Halo-only transfers (64 bytes per edge per step)
//! - Full grid readback only when rendering
//!
//! ## Kernel Source
//!
//! When the `cuda-codegen` feature is enabled, kernels are generated from Rust DSL.
//! Otherwise, handwritten CUDA source from `shaders/fdtd_tile.cu` is used.
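//!
//! ## Usage Sketch
//!
//! A minimal single-tile flow (a sketch: requires CUDA hardware, error
//! handling elided; `pressure` and `pressure_prev` stand in for
//! caller-provided 18x18 `f32` slices):
//!
//! ```ignore
//! let backend = CudaTileBackend::new(16)?;
//! let mut buffers = backend.create_tile_buffers(16)?;
//! backend.upload_initial_state(&buffers, &pressure, &pressure_prev)?;
//! backend.fdtd_step(&buffers, &FdtdParams::new(16, 0.25, 0.99))?;
//! backend.swap_buffers(&mut buffers);
//! let frame = backend.read_interior_pressure(&buffers)?;
//! ```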

use ringkernel_core::error::{Result, RingKernelError};
use ringkernel_core::memory::GpuBuffer;
use ringkernel_cuda::{CudaBuffer, CudaDevice};

use super::gpu_backend::{Edge, FdtdParams, TileGpuBackend, TileGpuBuffers};

/// CUDA kernel source for tile FDTD.
///
/// Uses generated DSL kernels when the `cuda-codegen` feature is enabled,
/// otherwise falls back to handwritten CUDA source.
#[cfg(feature = "cuda-codegen")]
fn get_cuda_source() -> String {
    super::kernels::generate_tile_kernels()
}

#[cfg(not(feature = "cuda-codegen"))]
fn get_cuda_source() -> String {
    include_str!("../shaders/fdtd_tile.cu").to_string()
}

/// Module name for loaded kernels.
const MODULE_NAME: &str = "fdtd_tile";

/// Kernel function names.
const FN_FDTD_STEP: &str = "fdtd_tile_step";
const FN_EXTRACT_HALO: &str = "extract_halo";
const FN_INJECT_HALO: &str = "inject_halo";
const FN_READ_INTERIOR: &str = "read_interior";

/// CUDA tile GPU backend.
///
/// Manages the CUDA device and compiled kernels, and provides the `TileGpuBackend` implementation.
pub struct CudaTileBackend {
    /// CUDA device.
    device: CudaDevice,
    /// Tile size (interior cells per side).
    tile_size: u32,
}

impl CudaTileBackend {
    /// Create a new CUDA tile backend on the default device (ordinal 0).
    ///
    /// Compiles the FDTD kernels using NVRTC.
    pub fn new(tile_size: u32) -> Result<Self> {
        Self::with_device(0, tile_size)
    }

    /// Create a CUDA tile backend on a specific device.
    pub fn with_device(ordinal: usize, tile_size: u32) -> Result<Self> {
        let device = CudaDevice::new(ordinal)?;

        tracing::info!(
            "CUDA tile backend: {} (CC {}.{})",
            device.name(),
            device.compute_capability().0,
            device.compute_capability().1
        );

        // Compile CUDA source to PTX and load the module
        let cuda_source = get_cuda_source();
        let ptx = cudarc::nvrtc::compile_ptx(&cuda_source).map_err(|e| {
            RingKernelError::BackendError(format!("NVRTC compilation failed: {}", e))
        })?;

        device
            .inner()
            .load_ptx(
                ptx,
                MODULE_NAME,
                &[
                    FN_FDTD_STEP,
                    FN_EXTRACT_HALO,
                    FN_INJECT_HALO,
                    FN_READ_INTERIOR,
                ],
            )
            .map_err(|e| {
                RingKernelError::BackendError(format!("Failed to load PTX module: {}", e))
            })?;

        Ok(Self { device, tile_size })
    }

    /// Get buffer width in cells (tile_size + 2 for the halo ring).
    pub fn buffer_width(&self) -> u32 {
        self.tile_size + 2
    }

    /// Get pressure buffer size in bytes (e.g. 18x18 floats for a 16x16 tile).
    pub fn buffer_size_bytes(&self) -> usize {
        let bw = self.buffer_width() as usize;
        bw * bw * std::mem::size_of::<f32>()
    }

    /// Get halo buffer size in bytes (tile_size floats).
    pub fn halo_size_bytes(&self) -> usize {
        self.tile_size as usize * std::mem::size_of::<f32>()
    }

    /// Get interior buffer size in bytes (tile_size x tile_size floats).
    pub fn interior_size_bytes(&self) -> usize {
        (self.tile_size * self.tile_size) as usize * std::mem::size_of::<f32>()
    }
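
    // Worked sizes for the default 16x16 tile (f32 cells, 4 bytes each):
    //   pressure buffer: 18 * 18 * 4 = 1296 bytes (interior plus 1-cell halo ring)
    //   halo edge:       16 * 4      =   64 bytes (the "64 bytes per edge" above)
    //   interior:        16 * 16 * 4 = 1024 bytes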

    /// Get device reference.
    pub fn device(&self) -> &CudaDevice {
        &self.device
    }
}

impl TileGpuBackend for CudaTileBackend {
    type Buffer = CudaBuffer;

    fn create_tile_buffers(&self, tile_size: u32) -> Result<TileGpuBuffers<Self::Buffer>> {
        let buffer_width = tile_size + 2;
        let buffer_size = (buffer_width * buffer_width) as usize * std::mem::size_of::<f32>();
        let halo_size = tile_size as usize * std::mem::size_of::<f32>();

        // Allocate pressure buffers (18x18 for a 16x16 tile)
        let pressure_a = CudaBuffer::new(&self.device, buffer_size)?;
        let pressure_b = CudaBuffer::new(&self.device, buffer_size)?;

        // Allocate halo staging buffers (one edge of tile_size floats each)
        let halo_north = CudaBuffer::new(&self.device, halo_size)?;
        let halo_south = CudaBuffer::new(&self.device, halo_size)?;
        let halo_west = CudaBuffer::new(&self.device, halo_size)?;
        let halo_east = CudaBuffer::new(&self.device, halo_size)?;

        // Initialize buffers to zero
        let zeros_buffer = vec![0u8; buffer_size];
        let zeros_halo = vec![0u8; halo_size];

        pressure_a.copy_from_host(&zeros_buffer)?;
        pressure_b.copy_from_host(&zeros_buffer)?;
        halo_north.copy_from_host(&zeros_halo)?;
        halo_south.copy_from_host(&zeros_halo)?;
        halo_west.copy_from_host(&zeros_halo)?;
        halo_east.copy_from_host(&zeros_halo)?;

        Ok(TileGpuBuffers {
            pressure_a,
            pressure_b,
            halo_north,
            halo_south,
            halo_west,
            halo_east,
            current_is_a: true,
            tile_size,
            buffer_width,
        })
    }
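
    // Resulting per-tile layout, row-major in an 18x18 grid (a sketch for
    // tile_size = 16; north-as-row-0 is inferred from the halo-exchange test
    // at the bottom of this file):
    //
    //   row 0 / row 17   -> north / south halo rows
    //   col 0 / col 17   -> west / east halo columns
    //   (1..=16, 1..=16) -> interior cells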

    fn upload_initial_state(
        &self,
        buffers: &TileGpuBuffers<Self::Buffer>,
        pressure: &[f32],
        pressure_prev: &[f32],
    ) -> Result<()> {
        // Reinterpret the f32 slices as byte slices for the device copies
        let pressure_bytes: &[u8] = bytemuck::cast_slice(pressure);
        let pressure_prev_bytes: &[u8] = bytemuck::cast_slice(pressure_prev);

        buffers.pressure_a.copy_from_host(pressure_bytes)?;
        buffers.pressure_b.copy_from_host(pressure_prev_bytes)?;

        Ok(())
    }

    fn fdtd_step(&self, buffers: &TileGpuBuffers<Self::Buffer>, params: &FdtdParams) -> Result<()> {
        use cudarc::driver::LaunchAsync;

        // Ping-pong: read from the current buffer, write into the other one;
        // `swap_buffers` later makes the freshly written buffer current.
        let (current, prev) = if buffers.current_is_a {
            (&buffers.pressure_a, &buffers.pressure_b)
        } else {
            (&buffers.pressure_b, &buffers.pressure_a)
        };

        // Get kernel function
        let kernel_fn = self
            .device
            .inner()
            .get_func(MODULE_NAME, FN_FDTD_STEP)
            .ok_or_else(|| RingKernelError::BackendError("FDTD kernel not found".to_string()))?;

        // One thread per interior cell; the 16x16 block matches the default
        // tile_size of 16 assumed by the kernel source.
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (1, 1, 1),
            block_dim: (16, 16, 1),
            shared_mem_bytes: 0,
        };

        // Get device pointers
        let current_ptr = current.device_ptr();
        let prev_ptr = prev.device_ptr();

        // Launch kernel
        unsafe { kernel_fn.launch(cfg, (current_ptr, prev_ptr, params.c2, params.damping)) }
            .map_err(|e| {
                RingKernelError::BackendError(format!("FDTD kernel launch failed: {}", e))
            })?;

        Ok(())
    }
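
    // The launch above implies a kernel signature along these lines (a
    // sketch; the authoritative definition lives in `shaders/fdtd_tile.cu`
    // or the `cuda-codegen` DSL output):
    //
    //   extern "C" __global__ void fdtd_tile_step(
    //       const float* current, float* out, float c2, float damping);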

    fn extract_halo(&self, buffers: &TileGpuBuffers<Self::Buffer>, edge: Edge) -> Result<Vec<f32>> {
        use cudarc::driver::LaunchAsync;

        // Get current pressure buffer
        let current = if buffers.current_is_a {
            &buffers.pressure_a
        } else {
            &buffers.pressure_b
        };

        // Get staging buffer for this edge
        let staging = buffers.halo_buffer(edge);

        // Get kernel function
        let kernel_fn = self
            .device
            .inner()
            .get_func(MODULE_NAME, FN_EXTRACT_HALO)
            .ok_or_else(|| {
                RingKernelError::BackendError("Extract halo kernel not found".to_string())
            })?;

        // One thread per halo cell; the 16-wide block matches the default tile size
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (1, 1, 1),
            block_dim: (16, 1, 1),
            shared_mem_bytes: 0,
        };

        let current_ptr = current.device_ptr();
        let staging_ptr = staging.device_ptr();
        let edge_val = edge as i32;

        // Launch kernel
        unsafe { kernel_fn.launch(cfg, (current_ptr, staging_ptr, edge_val)) }.map_err(|e| {
            RingKernelError::BackendError(format!("Extract halo kernel launch failed: {}", e))
        })?;

        // Synchronize and read back
        self.device.synchronize()?;

        let mut halo = vec![0u8; self.halo_size_bytes()];
        staging.copy_to_host(&mut halo)?;

        // Convert bytes to f32
        Ok(bytemuck::cast_slice(&halo).to_vec())
    }

    fn inject_halo(
        &self,
        buffers: &TileGpuBuffers<Self::Buffer>,
        edge: Edge,
        data: &[f32],
    ) -> Result<()> {
        use cudarc::driver::LaunchAsync;

        // Get current pressure buffer
        let current = if buffers.current_is_a {
            &buffers.pressure_a
        } else {
            &buffers.pressure_b
        };

        // Get staging buffer for this edge
        let staging = buffers.halo_buffer(edge);

        // Upload halo data to the staging buffer
        let data_bytes: &[u8] = bytemuck::cast_slice(data);
        staging.copy_from_host(data_bytes)?;

        // Get kernel function
        let kernel_fn = self
            .device
            .inner()
            .get_func(MODULE_NAME, FN_INJECT_HALO)
            .ok_or_else(|| {
                RingKernelError::BackendError("Inject halo kernel not found".to_string())
            })?;

        // One thread per halo cell; the 16-wide block matches the default tile size
        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (1, 1, 1),
            block_dim: (16, 1, 1),
            shared_mem_bytes: 0,
        };

        let current_ptr = current.device_ptr();
        let staging_ptr = staging.device_ptr();
        let edge_val = edge as i32;

        // Launch kernel
        unsafe { kernel_fn.launch(cfg, (current_ptr, staging_ptr, edge_val)) }.map_err(|e| {
            RingKernelError::BackendError(format!("Inject halo kernel launch failed: {}", e))
        })?;

        Ok(())
    }
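
    // Exchanging one shared edge between two adjacent tiles is then a
    // pull/push pair (a sketch; `tile_a` and `tile_b` are hypothetical
    // `TileGpuBuffers` for neighbors, `tile_b` directly south of `tile_a`):
    //
    //   let row = backend.extract_halo(&tile_a, Edge::South)?;
    //   backend.inject_halo(&tile_b, Edge::North, &row)?;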

    fn swap_buffers(&self, buffers: &mut TileGpuBuffers<Self::Buffer>) {
        buffers.current_is_a = !buffers.current_is_a;
    }

    fn read_interior_pressure(&self, buffers: &TileGpuBuffers<Self::Buffer>) -> Result<Vec<f32>> {
        use cudarc::driver::LaunchAsync;

        // Get current pressure buffer
        let current = if buffers.current_is_a {
            &buffers.pressure_a
        } else {
            &buffers.pressure_b
        };

        // Create temporary buffer for readback (16x16 = 256 floats)
        let output_buffer = CudaBuffer::new(&self.device, self.interior_size_bytes())?;

        // Get kernel function
        let kernel_fn = self
            .device
            .inner()
            .get_func(MODULE_NAME, FN_READ_INTERIOR)
            .ok_or_else(|| {
                RingKernelError::BackendError("Read interior kernel not found".to_string())
            })?;

        let cfg = cudarc::driver::LaunchConfig {
            grid_dim: (1, 1, 1),
            block_dim: (16, 16, 1),
            shared_mem_bytes: 0,
        };

        let current_ptr = current.device_ptr();
        let output_ptr = output_buffer.device_ptr();

        // Launch kernel
        unsafe { kernel_fn.launch(cfg, (current_ptr, output_ptr)) }.map_err(|e| {
            RingKernelError::BackendError(format!("Read interior kernel launch failed: {}", e))
        })?;

        // Synchronize and read back
        self.device.synchronize()?;

        let mut output = vec![0u8; self.interior_size_bytes()];
        output_buffer.copy_to_host(&mut output)?;

        // Convert bytes to f32
        Ok(bytemuck::cast_slice(&output).to_vec())
    }

    fn synchronize(&self) -> Result<()> {
        self.device.synchronize()
    }
}
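
// One plausible per-step ordering for a multi-tile driver (a sketch; the
// scheduler itself lives outside this module), mirroring the step/swap/read
// sequence used in the tests below:
//
//   1. fdtd_step                  -- stencil update into the back buffer
//   2. swap_buffers               -- the freshly written buffer becomes current
//   3. extract_halo / inject_halo -- share interior edges with neighbors
//   4. read_interior_pressure     -- only when a frame is rendered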

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    #[ignore] // Requires CUDA hardware
    fn test_cuda_backend_creation() {
        let backend = CudaTileBackend::new(16).unwrap();
        assert_eq!(backend.tile_size, 16);
        assert_eq!(backend.buffer_width(), 18);
    }

    #[test]
    #[ignore] // Requires CUDA hardware
    fn test_cuda_buffer_creation() {
        let backend = CudaTileBackend::new(16).unwrap();
        let buffers = backend.create_tile_buffers(16).unwrap();

        assert_eq!(buffers.tile_size, 16);
        assert_eq!(buffers.buffer_width, 18);
        assert!(buffers.current_is_a);
    }

    #[test]
    #[ignore] // Requires CUDA hardware
    fn test_cuda_fdtd_step() {
        let backend = CudaTileBackend::new(16).unwrap();
        let mut buffers = backend.create_tile_buffers(16).unwrap();

        // Create test data with an impulse at the center
        let buffer_size = 18 * 18;
        let mut pressure = vec![0.0f32; buffer_size];
        let pressure_prev = vec![0.0f32; buffer_size];

        // Set the center cell to 1.0 (the interior center is at buffer position (9, 9))
        let center_idx = 9 * 18 + 9;
        pressure[center_idx] = 1.0;

        // Upload initial state
        backend
            .upload_initial_state(&buffers, &pressure, &pressure_prev)
            .unwrap();

        // Run one FDTD step
        let params = FdtdParams::new(16, 0.25, 0.99);
        backend.fdtd_step(&buffers, &params).unwrap();
        backend.swap_buffers(&mut buffers);

        // Read back the result
        let result = backend.read_interior_pressure(&buffers).unwrap();

        // The impulse spreads to its neighbors, so the center's magnitude
        // must have dropped below the initial 1.0
        let center_interior = 8 * 16 + 8; // Buffer cell (9, 9) maps to interior (8, 8)
        assert!(
            result[center_interior].abs() < 1.0,
            "Center should have decreased"
        );
    }

    #[test]
    #[ignore] // Requires CUDA hardware
    fn test_cuda_halo_exchange() {
        let backend = CudaTileBackend::new(16).unwrap();
        let buffers = backend.create_tile_buffers(16).unwrap();

        // Create test data with a gradient
        let buffer_size = 18 * 18;
        let mut pressure = vec![0.0f32; buffer_size];

        // Fill the first interior row (row 1) with the values 1..=16
        for x in 0..16 {
            let idx = 18 + (x + 1);
            pressure[idx] = (x + 1) as f32;
        }

        // Upload
        backend
            .upload_initial_state(&buffers, &pressure, &vec![0.0f32; buffer_size])
            .unwrap();

        // Extract the north edge
        let halo = backend.extract_halo(&buffers, Edge::North).unwrap();

        assert_eq!(halo.len(), 16);
        for (i, &v) in halo.iter().enumerate() {
            assert_eq!(v, (i + 1) as f32, "Halo mismatch at {}", i);
        }
    }
}