ringkernel_wavesim/simulation/gpu_backend.rs
1//! Unified GPU backend trait for tile-based FDTD simulation.
2//!
3//! This module provides an abstraction layer over CUDA and WGPU backends,
4//! enabling GPU-resident state with minimal host transfers.
5//!
6//! ## Design
7//!
8//! The key optimization is keeping simulation state on the GPU:
9//! - Pressure buffers persist on GPU across steps (no per-step upload/download)
10//! - Only halo data (64 bytes per tile edge) transfers for K2K messaging
11//! - Full grid readback only when GUI needs to render
12//!
13//! ## Buffer Layout
14//!
15//! Each tile uses an 18x18 buffer (324 floats = 1,296 bytes):
16//!
17//! ```text
18//! +---+----------------+---+
//! | NW|   North Halo   |NE |  <- Row 0 (from neighbor)
//! +---+----------------+---+
//! |   |                |   |
//! | W |   16x16 Tile   | E |  <- Rows 1-16 (owned)
//! |   |    Interior    |   |
//! +---+----------------+---+
//! | SW|   South Halo   |SE |  <- Row 17 (from neighbor)
26//! +---+----------------+---+
27//! ```
28
29use ringkernel_core::error::Result;
30
/// Side of a tile, used to address halo regions during exchange.
///
/// Each variant names one side of a tile. `#[repr(u8)]` plus explicit
/// discriminants gives every edge a fixed numeric value in `0..=3`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Edge {
    /// Top side of the tile (first interior row).
    North = 0,
    /// Bottom side of the tile (last interior row).
    South = 1,
    /// Left side of the tile (first interior column).
    West = 2,
    /// Right side of the tile (last interior column).
    East = 3,
}
44
/// How the halo of a tile on the domain boundary is filled.
///
/// Applies to tile edges that have no neighbouring tile. Discriminants are
/// fixed via `#[repr(u8)]` (`Absorbing = 0`, `Reflecting = 1`, `Periodic = 2`),
/// and `Absorbing` is the `Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum BoundaryCondition {
    /// Absorbing boundary: the halo is held at zero pressure (the default).
    /// Models an open, infinite domain where outgoing waves never return.
    #[default]
    Absorbing = 0,
    /// Reflecting boundary: interior values are mirrored into the halo.
    /// Models a rigid wall that bounces waves back.
    Reflecting = 1,
    /// Periodic boundary: the halo wraps around from the opposite edge.
    /// Models an endlessly repeating domain.
    Periodic = 2,
}
60
61impl Edge {
62 /// Get the opposite edge (for K2K routing).
63 pub fn opposite(self) -> Self {
64 match self {
65 Edge::North => Edge::South,
66 Edge::South => Edge::North,
67 Edge::West => Edge::East,
68 Edge::East => Edge::West,
69 }
70 }
71
72 /// All edges in order.
73 pub const ALL: [Edge; 4] = [Edge::North, Edge::South, Edge::West, Edge::East];
74}
75
/// Scalar parameters consumed by one FDTD update step.
#[derive(Debug, Clone, Copy)]
pub struct FdtdParams {
    /// Number of interior cells along one tile side (e.g. 16 for a 16x16 tile).
    pub tile_size: u32,
    /// Squared Courant number: `c² = (speed * dt / dx)²`.
    pub c2: f32,
    /// Multiplicative damping factor, i.e. `1.0 - damping_coefficient`.
    pub damping: f32,
}
86
87impl FdtdParams {
88 /// Create new FDTD parameters.
89 pub fn new(tile_size: u32, c2: f32, damping: f32) -> Self {
90 Self {
91 tile_size,
92 c2,
93 damping,
94 }
95 }
96}
97
/// Unified trait for GPU backends (CUDA and WGPU).
///
/// Implementations must provide GPU-resident tile buffers with:
/// - Double-buffered pressure (ping-pong)
/// - Halo staging buffers for K2K exchange
/// - Minimal transfer operations
///
/// Sizes quoted as "16" / "18x18" below refer to the default 16x16 tile. In
/// general a tile has `tile_size` interior cells per side, and its padded
/// pressure buffer is `tile_size + 2` cells per side (one halo cell each side).
///
/// The `Send + Sync` supertraits allow a backend to be shared across threads.
pub trait TileGpuBackend: Send + Sync {
    /// GPU buffer type for this backend.
    ///
    /// Bounded by `Send + Sync` so buffers can be shared across threads
    /// together with the backend.
    type Buffer: Send + Sync;

    /// Create GPU buffers for a tile with `tile_size` interior cells per side.
    ///
    /// The returned [`TileGpuBuffers`] holds:
    /// - `pressure_a` / `pressure_b`: ping-pong pair of `(tile_size + 2)²` f32
    ///   buffers (18x18 for a 16x16 tile); `current_is_a` selects the current one
    /// - `halo_north` / `halo_south` / `halo_west` / `halo_east`: four staging
    ///   buffers of `tile_size` f32 each (one per edge)
    ///
    /// Returns `Err` on backend failure; the exact conditions are backend-defined.
    fn create_tile_buffers(&self, tile_size: u32) -> Result<TileGpuBuffers<Self::Buffer>>;

    /// Upload initial pressure data (one-time, at tile creation).
    ///
    /// NOTE(review): the expected slice length (padded `buffers.buffer_size()`
    /// vs. interior-only) is not fixed by this signature — confirm per backend.
    fn upload_initial_state(
        &self,
        buffers: &TileGpuBuffers<Self::Buffer>,
        pressure: &[f32],
        pressure_prev: &[f32],
    ) -> Result<()>;

    /// Execute FDTD step entirely on GPU (no host transfer).
    ///
    /// Reads from the current buffer and writes the new state into the previous
    /// buffer (ping-pong). Callers presumably follow this with
    /// [`swap_buffers`](Self::swap_buffers) so the newly written buffer becomes
    /// current — confirm against callers.
    fn fdtd_step(&self, buffers: &TileGpuBuffers<Self::Buffer>, params: &FdtdParams) -> Result<()>;

    /// Extract halo from GPU to host (small transfer for K2K).
    ///
    /// Returns `tile_size` f32 values (16 for a 16x16 tile) taken from the
    /// interior row/column adjacent to `edge`, to be sent to the neighbour.
    fn extract_halo(&self, buffers: &TileGpuBuffers<Self::Buffer>, edge: Edge) -> Result<Vec<f32>>;

    /// Inject neighbor halo from K2K message into GPU buffer.
    ///
    /// Takes `tile_size` f32 values (16 for a 16x16 tile) and writes them to the
    /// halo row/column on side `edge` of the padded buffer (see the module-level
    /// buffer layout).
    fn inject_halo(
        &self,
        buffers: &TileGpuBuffers<Self::Buffer>,
        edge: Edge,
        data: &[f32],
    ) -> Result<()>;

    /// Swap ping-pong buffers (pointer swap, no data movement).
    ///
    /// Takes `&mut` because it updates the bookkeeping in `buffers` (e.g. the
    /// `current_is_a` flag) rather than moving GPU data.
    fn swap_buffers(&self, buffers: &mut TileGpuBuffers<Self::Buffer>);

    /// Read full tile pressure for visualization.
    ///
    /// Only called when GUI needs to render (once per frame).
    /// Returns the `tile_size × tile_size` interior values (16x16 for the default
    /// tile), not the full padded buffer with halos. Presumably row-major —
    /// NOTE(review): confirm ordering per backend.
    fn read_interior_pressure(&self, buffers: &TileGpuBuffers<Self::Buffer>) -> Result<Vec<f32>>;

    /// Apply boundary condition for a domain edge tile.
    ///
    /// For tiles at the domain boundary (without a neighbor on one or more edges),
    /// this method applies the specified boundary condition to the halo region:
    /// - Absorbing: Sets halo to zero (waves exit and don't return)
    /// - Reflecting: Mirrors interior values to halo (hard wall)
    /// - Periodic: Would need opposite edge data (handled separately)
    ///
    /// Intended to run as a GPU operation, avoiding CPU-side manipulation of
    /// the halo.
    fn apply_boundary(
        &self,
        buffers: &TileGpuBuffers<Self::Buffer>,
        edge: Edge,
        condition: BoundaryCondition,
    ) -> Result<()>;

    /// Synchronize GPU operations.
    ///
    /// Presumably blocks until previously submitted GPU work has completed —
    /// NOTE(review): confirm exact semantics per backend.
    fn synchronize(&self) -> Result<()>;
}
172
/// GPU buffers for a single tile.
///
/// Generic over the backend's buffer handle type `B`, so the same layout
/// serves both CUDA and WGPU backends.
///
/// The two pressure buffers form a ping-pong pair. `current_is_a` records which
/// one holds the latest state, so neither field is permanently "current" or
/// "previous".
pub struct TileGpuBuffers<B> {
    /// Pressure buffer A (`buffer_width × buffer_width`, including halos).
    /// Holds the current state when `current_is_a` is `true`.
    pub pressure_a: B,
    /// Pressure buffer B (`buffer_width × buffer_width`, including halos).
    /// Holds the current state when `current_is_a` is `false`.
    pub pressure_b: B,
    /// Staging buffer for the north-edge halo (`tile_size` cells; 16 for 16x16 tiles).
    pub halo_north: B,
    /// Staging buffer for the south-edge halo (`tile_size` cells).
    pub halo_south: B,
    /// Staging buffer for the west-edge halo (`tile_size` cells).
    pub halo_west: B,
    /// Staging buffer for the east-edge halo (`tile_size` cells).
    pub halo_east: B,
    /// Which pressure buffer is "current" (`true` = A, `false` = B).
    pub current_is_a: bool,
    /// Tile interior size (cells per side, excluding halos).
    pub tile_size: u32,
    /// Buffer width including halos (expected to be `tile_size + 2`).
    pub buffer_width: u32,
}
193
194impl<B> TileGpuBuffers<B> {
195 /// Get buffer width (tile_size + 2 for halos).
196 pub fn buffer_width(&self) -> u32 {
197 self.buffer_width
198 }
199
200 /// Get total buffer size in f32 elements.
201 pub fn buffer_size(&self) -> usize {
202 (self.buffer_width * self.buffer_width) as usize
203 }
204
205 /// Get halo buffer for an edge.
206 pub fn halo_buffer(&self, edge: Edge) -> &B {
207 match edge {
208 Edge::North => &self.halo_north,
209 Edge::South => &self.halo_south,
210 Edge::West => &self.halo_west,
211 Edge::East => &self.halo_east,
212 }
213 }
214
215 /// Get mutable halo buffer for an edge.
216 pub fn halo_buffer_mut(&mut self, edge: Edge) -> &mut B {
217 match edge {
218 Edge::North => &mut self.halo_north,
219 Edge::South => &mut self.halo_south,
220 Edge::West => &mut self.halo_west,
221 Edge::East => &mut self.halo_east,
222 }
223 }
224}
225
/// Map 0-indexed interior coordinates to a flat index in the padded buffer.
///
/// The padded buffer carries a one-cell halo on every side and is stored
/// row-major, so interior cell `(local_x, local_y)` sits at padded row
/// `local_y + 1`, column `local_x + 1`.
#[inline(always)]
pub fn buffer_index(local_x: usize, local_y: usize, buffer_width: usize) -> usize {
    // Skip the halo row above and the halo column on the left.
    let row_start = (local_y + 1) * buffer_width;
    let col = local_x + 1;
    row_start + col
}
232
233/// Calculate halo row index (row 0 or row buffer_width-1).
234#[inline(always)]
235pub fn halo_row_start(edge: Edge, buffer_width: usize) -> usize {
236 match edge {
237 Edge::North => 1, // Row 0, cols 1..tile_size+1
238 Edge::South => (buffer_width - 1) * buffer_width + 1, // Last row, cols 1..tile_size+1
239 Edge::West | Edge::East => unreachable!(),
240 }
241}
242
243/// Calculate halo column index.
244#[inline(always)]
245pub fn halo_col_index(edge: Edge, y: usize, buffer_width: usize) -> usize {
246 match edge {
247 Edge::West => (y + 1) * buffer_width, // Col 0
248 Edge::East => (y + 1) * buffer_width + buffer_width - 1, // Last col
249 Edge::North | Edge::South => unreachable!(),
250 }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_edge_opposite() {
        assert_eq!(Edge::North.opposite(), Edge::South);
        assert_eq!(Edge::South.opposite(), Edge::North);
        assert_eq!(Edge::East.opposite(), Edge::West);
        assert_eq!(Edge::West.opposite(), Edge::East);
    }

    #[test]
    fn test_edge_opposite_is_involution() {
        for &edge in Edge::ALL.iter() {
            assert_eq!(edge.opposite().opposite(), edge);
            assert_ne!(edge.opposite(), edge);
        }
    }

    #[test]
    fn test_buffer_index() {
        // 18x18 buffer for 16x16 tile
        let bw = 18;

        // Interior (0,0) should be at buffer position (1,1) = row 1, col 1
        assert_eq!(buffer_index(0, 0, bw), 18 + 1);
        assert_eq!(buffer_index(0, 0, bw), 19);

        // Interior (15,15) should be at buffer position (16,16)
        assert_eq!(buffer_index(15, 15, bw), 16 * 18 + 16);
        assert_eq!(buffer_index(15, 15, bw), 304);
    }

    #[test]
    fn test_halo_row_start() {
        let bw = 18;
        // Row 0, column 1 (skipping the NW corner).
        assert_eq!(halo_row_start(Edge::North, bw), 1);
        // Row 17, column 1 (skipping the SW corner).
        assert_eq!(halo_row_start(Edge::South, bw), 17 * 18 + 1);
    }

    #[test]
    fn test_halo_col_index() {
        let bw = 18;
        assert_eq!(halo_col_index(Edge::West, 0, bw), 18);
        assert_eq!(halo_col_index(Edge::East, 0, bw), 18 + 17);
        assert_eq!(halo_col_index(Edge::West, 15, bw), 16 * 18);
        assert_eq!(halo_col_index(Edge::East, 15, bw), 16 * 18 + 17);
    }

    #[test]
    fn test_halo_cells_border_interior() {
        let bw = 18;
        for y in 0..16 {
            // West halo cell is immediately left of interior column 0.
            assert_eq!(halo_col_index(Edge::West, y, bw) + 1, buffer_index(0, y, bw));
            // East halo cell is immediately right of interior column 15.
            assert_eq!(halo_col_index(Edge::East, y, bw), buffer_index(15, y, bw) + 1);
        }
        // North halo row sits directly above interior row 0.
        assert_eq!(halo_row_start(Edge::North, bw) + bw, buffer_index(0, 0, bw));
        // South halo row sits directly below interior row 15.
        assert_eq!(halo_row_start(Edge::South, bw), buffer_index(0, 15, bw) + bw);
    }

    #[test]
    #[should_panic]
    fn test_halo_row_start_rejects_column_edges() {
        halo_row_start(Edge::West, 18);
    }

    #[test]
    #[should_panic]
    fn test_halo_col_index_rejects_row_edges() {
        halo_col_index(Edge::North, 0, 18);
    }

    #[test]
    fn test_tile_gpu_buffers_geometry() {
        let mut bufs = TileGpuBuffers {
            pressure_a: 0u8,
            pressure_b: 1u8,
            halo_north: 2u8,
            halo_south: 3u8,
            halo_west: 4u8,
            halo_east: 5u8,
            current_is_a: true,
            tile_size: 16,
            buffer_width: 18,
        };
        assert_eq!(bufs.buffer_width(), 18);
        assert_eq!(bufs.buffer_size(), 324);
        assert_eq!(*bufs.halo_buffer(Edge::North), 2);
        assert_eq!(*bufs.halo_buffer(Edge::South), 3);
        assert_eq!(*bufs.halo_buffer(Edge::West), 4);
        assert_eq!(*bufs.halo_buffer(Edge::East), 5);
        *bufs.halo_buffer_mut(Edge::East) = 9;
        assert_eq!(bufs.halo_east, 9);
    }

    #[test]
    fn test_fdtd_params() {
        let params = FdtdParams::new(16, 0.25, 0.99);
        assert_eq!(params.tile_size, 16);
        assert_eq!(params.c2, 0.25);
        assert_eq!(params.damping, 0.99);
    }
}