Skip to main content

ringkernel_wavesim/simulation/
gpu_backend.rs

//! Unified GPU backend trait for tile-based FDTD simulation.
//!
//! This module provides an abstraction layer over the CUDA and WGPU backends,
//! enabling GPU-resident state with minimal host transfers.
//!
//! ## Design
//!
//! The key optimization is keeping simulation state on the GPU:
//! - Pressure buffers persist on the GPU across steps (no per-step upload/download).
//! - Only halo data (16 f32 values = 64 bytes per tile edge) is transferred for K2K messaging.
//! - The full grid is read back only when the GUI needs to render.
//!
//! ## Buffer Layout
//!
//! Each 16x16 tile uses an 18x18 buffer (324 floats = 1,296 bytes):
//!
//! ```text
//! +---+----------------+---+
//! | NW|   North Halo   |NE |  <- Row 0 (from neighbor)
//! +---+----------------+---+
//! |   |                |   |
//! | W |   16x16 Tile   | E |  <- Rows 1-16 (owned)
//! |   |    Interior    |   |
//! +---+----------------+---+
//! | SW|   South Halo   |SE |  <- Row 17 (from neighbor)
//! +---+----------------+---+
//! ```
28
29use ringkernel_core::error::Result;
30
/// Edge direction for halo exchange.
///
/// Each variant carries an explicit `u8` discriminant, fixed by `#[repr(u8)]`.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum Edge {
    /// Top row of the tile interior.
    North = 0,
    /// Bottom row of the tile interior.
    South = 1,
    /// Leftmost column of the tile interior.
    West = 2,
    /// Rightmost column of the tile interior.
    East = 3,
}
44
/// How the halo of a tile is filled on an edge that lies on the domain boundary.
#[repr(u8)]
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum BoundaryCondition {
    /// Absorbing boundary (the default): pressure in the halo goes to zero,
    /// so waves leave the domain and never come back, as in an infinite domain.
    #[default]
    Absorbing = 0,
    /// Reflecting boundary: interior values are mirrored into the halo,
    /// modelling a hard wall that bounces waves back.
    Reflecting = 1,
    /// Periodic boundary: the halo wraps around to the opposite edge,
    /// modelling an infinitely repeating pattern.
    Periodic = 2,
}
60
61impl Edge {
62    /// Get the opposite edge (for K2K routing).
63    pub fn opposite(self) -> Self {
64        match self {
65            Edge::North => Edge::South,
66            Edge::South => Edge::North,
67            Edge::West => Edge::East,
68            Edge::East => Edge::West,
69        }
70    }
71
72    /// All edges in order.
73    pub const ALL: [Edge; 4] = [Edge::North, Edge::South, Edge::West, Edge::East];
74}
75
/// Parameters handed to each FDTD step.
#[derive(Debug, Clone, Copy)]
pub struct FdtdParams {
    /// Cells per side of the tile interior (16 for a 16x16 tile).
    pub tile_size: u32,
    /// Squared Courant number: `c² = (speed * dt / dx)²`.
    pub c2: f32,
    /// Damping factor, expressed as `1.0 - damping_coefficient`.
    pub damping: f32,
}

impl FdtdParams {
    /// Bundle the tile size, squared Courant number and damping factor.
    pub fn new(tile_size: u32, c2: f32, damping: f32) -> Self {
        FdtdParams { tile_size, c2, damping }
    }
}
97
98/// Unified trait for GPU backends (CUDA and WGPU).
99///
100/// Implementations must provide GPU-resident tile buffers with:
101/// - Double-buffered pressure (ping-pong)
102/// - Halo staging buffers for K2K exchange
103/// - Minimal transfer operations
104pub trait TileGpuBackend: Send + Sync {
105    /// GPU buffer type for this backend.
106    type Buffer: Send + Sync;
107
108    /// Create GPU buffers for a tile.
109    ///
110    /// Returns buffers for:
111    /// - pressure: 18x18 f32 buffer (current state)
112    /// - pressure_prev: 18x18 f32 buffer (previous state)
113    /// - halo_staging: 4 × 16 f32 buffers (one per edge)
114    fn create_tile_buffers(&self, tile_size: u32) -> Result<TileGpuBuffers<Self::Buffer>>;
115
116    /// Upload initial pressure data (one-time, at tile creation).
117    fn upload_initial_state(
118        &self,
119        buffers: &TileGpuBuffers<Self::Buffer>,
120        pressure: &[f32],
121        pressure_prev: &[f32],
122    ) -> Result<()>;
123
124    /// Execute FDTD step entirely on GPU (no host transfer).
125    ///
126    /// Reads from current buffer, writes to previous buffer (ping-pong).
127    fn fdtd_step(&self, buffers: &TileGpuBuffers<Self::Buffer>, params: &FdtdParams) -> Result<()>;
128
129    /// Extract halo from GPU to host (small transfer for K2K).
130    ///
131    /// Returns 16 f32 values for the specified edge.
132    fn extract_halo(&self, buffers: &TileGpuBuffers<Self::Buffer>, edge: Edge) -> Result<Vec<f32>>;
133
134    /// Inject neighbor halo from K2K message into GPU buffer.
135    ///
136    /// Takes 16 f32 values to write to the specified halo region.
137    fn inject_halo(
138        &self,
139        buffers: &TileGpuBuffers<Self::Buffer>,
140        edge: Edge,
141        data: &[f32],
142    ) -> Result<()>;
143
144    /// Swap ping-pong buffers (pointer swap, no data movement).
145    fn swap_buffers(&self, buffers: &mut TileGpuBuffers<Self::Buffer>);
146
147    /// Read full tile pressure for visualization.
148    ///
149    /// Only called when GUI needs to render (once per frame).
150    /// Returns 16x16 interior values (not the full 18x18 buffer).
151    fn read_interior_pressure(&self, buffers: &TileGpuBuffers<Self::Buffer>) -> Result<Vec<f32>>;
152
153    /// Apply boundary condition for a domain edge tile.
154    ///
155    /// For tiles at the domain boundary (without a neighbor on one or more edges),
156    /// this method applies the specified boundary condition to the halo region:
157    /// - Absorbing: Sets halo to zero (waves exit and don't return)
158    /// - Reflecting: Mirrors interior values to halo (hard wall)
159    /// - Periodic: Would need opposite edge data (handled separately)
160    ///
161    /// This is a GPU operation - more efficient than CPU-side manipulation.
162    fn apply_boundary(
163        &self,
164        buffers: &TileGpuBuffers<Self::Buffer>,
165        edge: Edge,
166        condition: BoundaryCondition,
167    ) -> Result<()>;
168
169    /// Synchronize GPU operations.
170    fn synchronize(&self) -> Result<()>;
171}
172
173/// GPU buffers for a single tile.
174///
175/// Uses generic Buffer type to support both CUDA and WGPU backends.
176pub struct TileGpuBuffers<B> {
177    /// Current pressure buffer (18x18 with halos).
178    pub pressure_a: B,
179    /// Previous pressure buffer (18x18 with halos).
180    pub pressure_b: B,
181    /// Staging buffers for halo extraction (4 edges × 16 cells).
182    pub halo_north: B,
183    pub halo_south: B,
184    pub halo_west: B,
185    pub halo_east: B,
186    /// Which buffer is "current" (true = A, false = B).
187    pub current_is_a: bool,
188    /// Tile interior size.
189    pub tile_size: u32,
190    /// Buffer width including halos.
191    pub buffer_width: u32,
192}
193
194impl<B> TileGpuBuffers<B> {
195    /// Get buffer width (tile_size + 2 for halos).
196    pub fn buffer_width(&self) -> u32 {
197        self.buffer_width
198    }
199
200    /// Get total buffer size in f32 elements.
201    pub fn buffer_size(&self) -> usize {
202        (self.buffer_width * self.buffer_width) as usize
203    }
204
205    /// Get halo buffer for an edge.
206    pub fn halo_buffer(&self, edge: Edge) -> &B {
207        match edge {
208            Edge::North => &self.halo_north,
209            Edge::South => &self.halo_south,
210            Edge::West => &self.halo_west,
211            Edge::East => &self.halo_east,
212        }
213    }
214
215    /// Get mutable halo buffer for an edge.
216    pub fn halo_buffer_mut(&mut self, edge: Edge) -> &mut B {
217        match edge {
218            Edge::North => &mut self.halo_north,
219            Edge::South => &mut self.halo_south,
220            Edge::West => &mut self.halo_west,
221            Edge::East => &mut self.halo_east,
222        }
223    }
224}
225
/// Map 0-based interior coordinates to a flat index into the halo-padded buffer.
///
/// The halo ring adds one row above and one column to the left of the
/// interior, so interior `(0, 0)` lands at buffer row 1, column 1.
#[inline(always)]
pub fn buffer_index(local_x: usize, local_y: usize, buffer_width: usize) -> usize {
    let padded_row = local_y + 1;
    let padded_col = local_x + 1;
    padded_row * buffer_width + padded_col
}
232
233/// Calculate halo row index (row 0 or row buffer_width-1).
234#[inline(always)]
235pub fn halo_row_start(edge: Edge, buffer_width: usize) -> usize {
236    match edge {
237        Edge::North => 1,                                     // Row 0, cols 1..tile_size+1
238        Edge::South => (buffer_width - 1) * buffer_width + 1, // Last row, cols 1..tile_size+1
239        Edge::West | Edge::East => unreachable!(),
240    }
241}
242
243/// Calculate halo column index.
244#[inline(always)]
245pub fn halo_col_index(edge: Edge, y: usize, buffer_width: usize) -> usize {
246    match edge {
247        Edge::West => (y + 1) * buffer_width, // Col 0
248        Edge::East => (y + 1) * buffer_width + buffer_width - 1, // Last col
249        Edge::North | Edge::South => unreachable!(),
250    }
251}
252
#[cfg(test)]
mod tests {
    use super::*;

    /// Buffer width of the default 16x16 tile (16 interior cells + 2 halo cells).
    const BW: usize = 18;

    #[test]
    fn test_edge_opposite() {
        assert_eq!(Edge::North.opposite(), Edge::South);
        assert_eq!(Edge::South.opposite(), Edge::North);
        assert_eq!(Edge::East.opposite(), Edge::West);
        assert_eq!(Edge::West.opposite(), Edge::East);
    }

    #[test]
    fn test_edge_all_is_closed_under_opposite() {
        for edge in Edge::ALL {
            assert!(Edge::ALL.contains(&edge.opposite()));
            assert_eq!(edge.opposite().opposite(), edge);
        }
    }

    #[test]
    fn test_boundary_condition_default_is_absorbing() {
        assert_eq!(BoundaryCondition::default(), BoundaryCondition::Absorbing);
    }

    #[test]
    fn test_buffer_index() {
        // Interior (0,0) should be at buffer position (1,1) = row 1, col 1
        assert_eq!(buffer_index(0, 0, BW), 18 + 1);
        assert_eq!(buffer_index(0, 0, BW), 19);

        // Interior (15,15) should be at buffer position (16,16)
        assert_eq!(buffer_index(15, 15, BW), 16 * 18 + 16);
        assert_eq!(buffer_index(15, 15, BW), 304);
    }

    #[test]
    fn test_halo_row_start() {
        assert_eq!(halo_row_start(Edge::North, BW), 1);
        assert_eq!(halo_row_start(Edge::South, BW), 17 * 18 + 1);
        // Each horizontal halo strip sits one full row away from the adjacent interior row.
        assert_eq!(halo_row_start(Edge::North, BW) + BW, buffer_index(0, 0, BW));
        assert_eq!(halo_row_start(Edge::South, BW) - BW, buffer_index(0, 15, BW));
    }

    #[test]
    fn test_halo_col_index() {
        assert_eq!(halo_col_index(Edge::West, 0, BW), 18);
        assert_eq!(halo_col_index(Edge::East, 0, BW), 18 + 17);
        // Vertical halo cells sit directly beside the first/last interior cell of each row.
        for y in 0..16 {
            assert_eq!(halo_col_index(Edge::West, y, BW) + 1, buffer_index(0, y, BW));
            assert_eq!(halo_col_index(Edge::East, y, BW) - 1, buffer_index(15, y, BW));
        }
    }

    #[test]
    #[should_panic]
    fn test_halo_row_start_rejects_vertical_edge() {
        let _ = halo_row_start(Edge::West, BW);
    }

    #[test]
    #[should_panic]
    fn test_halo_col_index_rejects_horizontal_edge() {
        let _ = halo_col_index(Edge::North, 0, BW);
    }

    #[test]
    fn test_tile_buffers_size_and_halo_lookup() {
        let buffers = TileGpuBuffers {
            pressure_a: 0u8,
            pressure_b: 0,
            halo_north: 1,
            halo_south: 2,
            halo_west: 3,
            halo_east: 4,
            current_is_a: true,
            tile_size: 16,
            buffer_width: 18,
        };
        assert_eq!(buffers.buffer_width(), 18);
        assert_eq!(buffers.buffer_size(), 324);
        for (edge, expected) in [
            (Edge::North, 1u8),
            (Edge::South, 2),
            (Edge::West, 3),
            (Edge::East, 4),
        ] {
            assert_eq!(*buffers.halo_buffer(edge), expected);
        }
    }

    #[test]
    fn test_fdtd_params() {
        let params = FdtdParams::new(16, 0.25, 0.99);
        assert_eq!(params.tile_size, 16);
        assert_eq!(params.c2, 0.25);
        assert_eq!(params.damping, 0.99);
    }
}