ringkernel_wavesim/simulation/gpu_backend.rs
1//! Unified GPU backend trait for tile-based FDTD simulation.
2//!
3//! This module provides an abstraction layer over CUDA and WGPU backends,
4//! enabling GPU-resident state with minimal host transfers.
5//!
6//! ## Design
7//!
8//! The key optimization is keeping simulation state on the GPU:
9//! - Pressure buffers persist on GPU across steps (no per-step upload/download)
10//! - Only halo data (64 bytes per tile edge) transfers for K2K messaging
11//! - Full grid readback only when GUI needs to render
12//!
13//! ## Buffer Layout
14//!
15//! Each tile uses an 18x18 buffer (324 floats = 1,296 bytes):
16//!
17//! ```text
//! +---+----------------+---+
//! | NW|   North Halo   |NE |  <- Row 0 (from neighbor)
//! +---+----------------+---+
//! |   |                |   |
//! | W |   16x16 Tile   | E |  <- Rows 1-16 (owned)
//! |   |    Interior    |   |
//! +---+----------------+---+
//! | SW|   South Halo   |SE |  <- Row 17 (from neighbor)
//! +---+----------------+---+
27//! ```
28
29use ringkernel_core::error::Result;
30
/// Side of a tile across which halo data is exchanged.
///
/// The explicit `u8` discriminants (`North = 0` through `East = 3`) follow the
/// same order as the entries of [`Edge::ALL`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Edge {
    /// Top side (top row of the interior).
    North = 0,
    /// Bottom side (bottom row of the interior).
    South = 1,
    /// Left side (left column of the interior).
    West = 2,
    /// Right side (right column of the interior).
    East = 3,
}

impl Edge {
    /// Every edge, in the order North, South, West, East.
    pub const ALL: [Edge; 4] = [Edge::North, Edge::South, Edge::West, Edge::East];

    /// Returns the edge on the opposite side (North <-> South, West <-> East).
    ///
    /// Used for K2K routing.
    pub fn opposite(self) -> Self {
        match self {
            Self::North => Self::South,
            Self::South => Self::North,
            Self::East => Self::West,
            Self::West => Self::East,
        }
    }
}
59
/// Numerical parameters consumed by an FDTD computation.
#[derive(Debug, Clone, Copy)]
pub struct FdtdParams {
    /// Side length of the tile interior in cells (e.g. 16 for a 16x16 tile).
    pub tile_size: u32,
    /// Squared Courant number, c² = (speed * dt / dx)².
    pub c2: f32,
    /// Damping factor, expressed as `1.0 - damping_coefficient`.
    pub damping: f32,
}

impl FdtdParams {
    /// Bundles the given values into an [`FdtdParams`].
    ///
    /// The arguments are stored as-is; no validation is performed.
    pub fn new(tile_size: u32, c2: f32, damping: f32) -> Self {
        FdtdParams {
            tile_size,
            c2,
            damping,
        }
    }
}
81
82/// Unified trait for GPU backends (CUDA and WGPU).
83///
84/// Implementations must provide GPU-resident tile buffers with:
85/// - Double-buffered pressure (ping-pong)
86/// - Halo staging buffers for K2K exchange
87/// - Minimal transfer operations
88pub trait TileGpuBackend: Send + Sync {
89 /// GPU buffer type for this backend.
90 type Buffer: Send + Sync;
91
92 /// Create GPU buffers for a tile.
93 ///
94 /// Returns buffers for:
95 /// - pressure: 18x18 f32 buffer (current state)
96 /// - pressure_prev: 18x18 f32 buffer (previous state)
97 /// - halo_staging: 4 × 16 f32 buffers (one per edge)
98 fn create_tile_buffers(&self, tile_size: u32) -> Result<TileGpuBuffers<Self::Buffer>>;
99
100 /// Upload initial pressure data (one-time, at tile creation).
101 fn upload_initial_state(
102 &self,
103 buffers: &TileGpuBuffers<Self::Buffer>,
104 pressure: &[f32],
105 pressure_prev: &[f32],
106 ) -> Result<()>;
107
108 /// Execute FDTD step entirely on GPU (no host transfer).
109 ///
110 /// Reads from current buffer, writes to previous buffer (ping-pong).
111 fn fdtd_step(&self, buffers: &TileGpuBuffers<Self::Buffer>, params: &FdtdParams) -> Result<()>;
112
113 /// Extract halo from GPU to host (small transfer for K2K).
114 ///
115 /// Returns 16 f32 values for the specified edge.
116 fn extract_halo(&self, buffers: &TileGpuBuffers<Self::Buffer>, edge: Edge) -> Result<Vec<f32>>;
117
118 /// Inject neighbor halo from K2K message into GPU buffer.
119 ///
120 /// Takes 16 f32 values to write to the specified halo region.
121 fn inject_halo(
122 &self,
123 buffers: &TileGpuBuffers<Self::Buffer>,
124 edge: Edge,
125 data: &[f32],
126 ) -> Result<()>;
127
128 /// Swap ping-pong buffers (pointer swap, no data movement).
129 fn swap_buffers(&self, buffers: &mut TileGpuBuffers<Self::Buffer>);
130
131 /// Read full tile pressure for visualization.
132 ///
133 /// Only called when GUI needs to render (once per frame).
134 /// Returns 16x16 interior values (not the full 18x18 buffer).
135 fn read_interior_pressure(&self, buffers: &TileGpuBuffers<Self::Buffer>) -> Result<Vec<f32>>;
136
137 /// Synchronize GPU operations.
138 fn synchronize(&self) -> Result<()>;
139}
140
141/// GPU buffers for a single tile.
142///
143/// Uses generic Buffer type to support both CUDA and WGPU backends.
144pub struct TileGpuBuffers<B> {
145 /// Current pressure buffer (18x18 with halos).
146 pub pressure_a: B,
147 /// Previous pressure buffer (18x18 with halos).
148 pub pressure_b: B,
149 /// Staging buffers for halo extraction (4 edges × 16 cells).
150 pub halo_north: B,
151 pub halo_south: B,
152 pub halo_west: B,
153 pub halo_east: B,
154 /// Which buffer is "current" (true = A, false = B).
155 pub current_is_a: bool,
156 /// Tile interior size.
157 pub tile_size: u32,
158 /// Buffer width including halos.
159 pub buffer_width: u32,
160}
161
162impl<B> TileGpuBuffers<B> {
163 /// Get buffer width (tile_size + 2 for halos).
164 pub fn buffer_width(&self) -> u32 {
165 self.buffer_width
166 }
167
168 /// Get total buffer size in f32 elements.
169 pub fn buffer_size(&self) -> usize {
170 (self.buffer_width * self.buffer_width) as usize
171 }
172
173 /// Get halo buffer for an edge.
174 pub fn halo_buffer(&self, edge: Edge) -> &B {
175 match edge {
176 Edge::North => &self.halo_north,
177 Edge::South => &self.halo_south,
178 Edge::West => &self.halo_west,
179 Edge::East => &self.halo_east,
180 }
181 }
182
183 /// Get mutable halo buffer for an edge.
184 pub fn halo_buffer_mut(&mut self, edge: Edge) -> &mut B {
185 match edge {
186 Edge::North => &mut self.halo_north,
187 Edge::South => &mut self.halo_south,
188 Edge::West => &mut self.halo_west,
189 Edge::East => &mut self.halo_east,
190 }
191 }
192}
193
/// Maps 0-indexed interior coordinates to a flat index into the halo-padded,
/// row-major tile buffer.
///
/// Both coordinates are shifted by one to step over the halo border, so
/// interior cell `(0, 0)` lands on buffer row 1, column 1.
#[inline(always)]
pub fn buffer_index(local_x: usize, local_y: usize, buffer_width: usize) -> usize {
    let row = local_y + 1;
    let col = local_x + 1;
    row * buffer_width + col
}
200
201/// Calculate halo row index (row 0 or row buffer_width-1).
202#[inline(always)]
203pub fn halo_row_start(edge: Edge, buffer_width: usize) -> usize {
204 match edge {
205 Edge::North => 1, // Row 0, cols 1..tile_size+1
206 Edge::South => (buffer_width - 1) * buffer_width + 1, // Last row, cols 1..tile_size+1
207 Edge::West | Edge::East => unreachable!(),
208 }
209}
210
211/// Calculate halo column index.
212#[inline(always)]
213pub fn halo_col_index(edge: Edge, y: usize, buffer_width: usize) -> usize {
214 match edge {
215 Edge::West => (y + 1) * buffer_width, // Col 0
216 Edge::East => (y + 1) * buffer_width + buffer_width - 1, // Last col
217 Edge::North | Edge::South => unreachable!(),
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Buffer width of the 18x18 halo-padded buffer used for a 16x16 tile.
    const BW: usize = 18;

    #[test]
    fn test_edge_opposite() {
        assert_eq!(Edge::North.opposite(), Edge::South);
        assert_eq!(Edge::South.opposite(), Edge::North);
        assert_eq!(Edge::East.opposite(), Edge::West);
        assert_eq!(Edge::West.opposite(), Edge::East);
    }

    #[test]
    fn test_edge_opposite_is_involution() {
        // Taking the opposite twice must give back the starting edge.
        for edge in Edge::ALL.iter() {
            assert_eq!(edge.opposite().opposite(), *edge);
        }
    }

    #[test]
    fn test_edge_all_order() {
        assert_eq!(
            Edge::ALL,
            [Edge::North, Edge::South, Edge::West, Edge::East]
        );
    }

    #[test]
    fn test_buffer_index() {
        // Interior (0,0) should be at buffer position (1,1) = row 1, col 1
        assert_eq!(buffer_index(0, 0, BW), 18 + 1);
        assert_eq!(buffer_index(0, 0, BW), 19);

        // Interior (15,15) should be at buffer position (16,16)
        assert_eq!(buffer_index(15, 15, BW), 16 * 18 + 16);
        assert_eq!(buffer_index(15, 15, BW), 304);
    }

    #[test]
    fn test_halo_row_start() {
        // North halo: row 0, first cell after the corner.
        assert_eq!(halo_row_start(Edge::North, BW), 1);

        // South halo: row 17, first cell after the corner.
        assert_eq!(halo_row_start(Edge::South, BW), 17 * 18 + 1);
        assert_eq!(halo_row_start(Edge::South, BW), 307);
    }

    #[test]
    fn test_halo_col_index() {
        // West halo: column 0 of buffer rows 1 and 16.
        assert_eq!(halo_col_index(Edge::West, 0, BW), 18);
        assert_eq!(halo_col_index(Edge::West, 15, BW), 16 * 18);

        // East halo: column 17 of buffer rows 1 and 16.
        assert_eq!(halo_col_index(Edge::East, 0, BW), 18 + 17);
        assert_eq!(halo_col_index(Edge::East, 15, BW), 16 * 18 + 17);
    }

    #[test]
    fn test_fdtd_params() {
        let params = FdtdParams::new(16, 0.25, 0.99);
        assert_eq!(params.tile_size, 16);
        assert_eq!(params.c2, 0.25);
        assert_eq!(params.damping, 0.99);
    }
}
254}