ringkernel_wavesim/simulation/
gpu_backend.rs

//! Unified GPU backend trait for tile-based FDTD simulation.
//!
//! This module provides an abstraction layer over the CUDA and WGPU backends,
//! keeping simulation state resident on the GPU with minimal host transfers.
//!
//! ## Design
//!
//! The key optimization is keeping simulation state on the GPU:
//! - Pressure buffers persist on the GPU across steps (no per-step upload/download).
//! - Only halo data (64 bytes per tile edge: 16 f32 values for a 16x16 tile) is
//!   transferred for K2K messaging.
//! - The full grid is read back only when the GUI needs to render.
//!
//! ## Buffer Layout
//!
//! Each 16x16 tile uses an 18x18 buffer (324 floats = 1,296 bytes), i.e. the
//! interior surrounded by a one-cell halo ring:
//!
//! ```text
//! +---+----------------+---+
//! | NW|   North Halo   |NE |  <- Row 0 (from neighbor)
//! +---+----------------+---+
//! |   |                |   |
//! | W |   16x16 Tile   | E |  <- Rows 1-16 (owned)
//! |   |    Interior    |   |
//! +---+----------------+---+
//! | SW|   South Halo   |SE |  <- Row 17 (from neighbor)
//! +---+----------------+---+
//! ```
28
29use ringkernel_core::error::Result;
30
/// Side of a tile, used to address halo regions during exchange.
///
/// The explicit `u8` discriminants give every side a stable numeric id.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u8)]
pub enum Edge {
    /// North side: the top row of the interior.
    North = 0,
    /// South side: the bottom row of the interior.
    South = 1,
    /// West side: the left column of the interior.
    West = 2,
    /// East side: the right column of the interior.
    East = 3,
}

impl Edge {
    /// Every edge, listed as North, South, West, East.
    pub const ALL: [Edge; 4] = [Edge::North, Edge::South, Edge::West, Edge::East];

    /// Returns the side facing this one on the neighbouring tile (used for K2K routing).
    ///
    /// North pairs with South and West pairs with East, so applying this twice
    /// yields the original edge.
    pub fn opposite(self) -> Self {
        match self {
            Self::North => Self::South,
            Self::South => Self::North,
            Self::East => Self::West,
            Self::West => Self::East,
        }
    }
}
59
/// Parameters handed to each `fdtd_step` call.
#[derive(Debug, Clone, Copy)]
pub struct FdtdParams {
    /// Width/height of a tile's interior in cells (e.g. 16 for a 16x16 tile).
    pub tile_size: u32,
    /// Squared Courant number: `c² = (speed * dt / dx)²`.
    pub c2: f32,
    /// Damping multiplier, i.e. `1.0 - damping_coefficient`.
    pub damping: f32,
}

impl FdtdParams {
    /// Bundles the tile size, squared Courant number and damping factor.
    pub fn new(tile_size: u32, c2: f32, damping: f32) -> Self {
        Self { tile_size, c2, damping }
    }
}
81
/// Unified trait for GPU backends (CUDA and WGPU).
///
/// Implementations must provide GPU-resident tile buffers with:
/// - Double-buffered pressure (ping-pong)
/// - Halo staging buffers for K2K exchange
/// - Minimal transfer operations
///
/// Both the backend and its `Buffer` type are required to be `Send + Sync`.
pub trait TileGpuBackend: Send + Sync {
    /// GPU buffer type for this backend; every pressure and halo buffer in a
    /// `TileGpuBuffers` is one value of this type.
    type Buffer: Send + Sync;

    /// Create GPU buffers for a tile whose interior is `tile_size` x `tile_size` cells.
    ///
    /// Returns a `TileGpuBuffers` holding:
    /// - `pressure_a` / `pressure_b`: two ping-pong pressure buffers, each
    ///   `(tile_size + 2)²` f32 values (18x18 for a 16x16 tile) so a one-cell
    ///   halo ring surrounds the interior
    /// - `halo_north` / `halo_south` / `halo_west` / `halo_east`: per-edge halo
    ///   staging buffers (16 f32 values each for a 16x16 tile)
    fn create_tile_buffers(&self, tile_size: u32) -> Result<TileGpuBuffers<Self::Buffer>>;

    /// Upload initial pressure data (one-time, at tile creation).
    ///
    /// `pressure` and `pressure_prev` are the host-side initial contents of the
    /// current and previous pressure buffers.
    /// NOTE(review): presumably each slice is a full `buffer_size()`-element
    /// buffer including halos — confirm against the implementations.
    fn upload_initial_state(
        &self,
        buffers: &TileGpuBuffers<Self::Buffer>,
        pressure: &[f32],
        pressure_prev: &[f32],
    ) -> Result<()>;

    /// Execute one FDTD step entirely on the GPU (no host transfer).
    ///
    /// Reads from the current buffer and writes to the previous buffer
    /// (ping-pong); `swap_buffers` is what exchanges which buffer is current.
    /// `params` supplies the tile size, squared Courant number and damping.
    fn fdtd_step(&self, buffers: &TileGpuBuffers<Self::Buffer>, params: &FdtdParams) -> Result<()>;

    /// Extract the halo for `edge` from GPU to host (small transfer for K2K).
    ///
    /// Returns one f32 per cell along the edge (16 values for a 16x16 tile).
    fn extract_halo(&self, buffers: &TileGpuBuffers<Self::Buffer>, edge: Edge) -> Result<Vec<f32>>;

    /// Inject a neighbor's halo, received via a K2K message, into the GPU buffer.
    ///
    /// `data` holds the values to write into the halo region on `edge`
    /// (16 values for a 16x16 tile).
    fn inject_halo(
        &self,
        buffers: &TileGpuBuffers<Self::Buffer>,
        edge: Edge,
        data: &[f32],
    ) -> Result<()>;

    /// Swap ping-pong buffers (pointer swap, no data movement).
    ///
    /// Takes the buffers mutably because it changes which pressure buffer is
    /// current. NOTE(review): presumably by toggling `current_is_a` — confirm.
    fn swap_buffers(&self, buffers: &mut TileGpuBuffers<Self::Buffer>);

    /// Read the tile's interior pressure for visualization.
    ///
    /// Only called when the GUI needs to render (once per frame).
    /// Returns only the `tile_size` x `tile_size` interior values (16x16 for a
    /// 16x16 tile), not the full buffer with its halo ring (18x18).
    fn read_interior_pressure(&self, buffers: &TileGpuBuffers<Self::Buffer>) -> Result<Vec<f32>>;

    /// Synchronize GPU operations.
    fn synchronize(&self) -> Result<()>;
}
140
141/// GPU buffers for a single tile.
142///
143/// Uses generic Buffer type to support both CUDA and WGPU backends.
144pub struct TileGpuBuffers<B> {
145    /// Current pressure buffer (18x18 with halos).
146    pub pressure_a: B,
147    /// Previous pressure buffer (18x18 with halos).
148    pub pressure_b: B,
149    /// Staging buffers for halo extraction (4 edges × 16 cells).
150    pub halo_north: B,
151    pub halo_south: B,
152    pub halo_west: B,
153    pub halo_east: B,
154    /// Which buffer is "current" (true = A, false = B).
155    pub current_is_a: bool,
156    /// Tile interior size.
157    pub tile_size: u32,
158    /// Buffer width including halos.
159    pub buffer_width: u32,
160}
161
162impl<B> TileGpuBuffers<B> {
163    /// Get buffer width (tile_size + 2 for halos).
164    pub fn buffer_width(&self) -> u32 {
165        self.buffer_width
166    }
167
168    /// Get total buffer size in f32 elements.
169    pub fn buffer_size(&self) -> usize {
170        (self.buffer_width * self.buffer_width) as usize
171    }
172
173    /// Get halo buffer for an edge.
174    pub fn halo_buffer(&self, edge: Edge) -> &B {
175        match edge {
176            Edge::North => &self.halo_north,
177            Edge::South => &self.halo_south,
178            Edge::West => &self.halo_west,
179            Edge::East => &self.halo_east,
180        }
181    }
182
183    /// Get mutable halo buffer for an edge.
184    pub fn halo_buffer_mut(&mut self, edge: Edge) -> &mut B {
185        match edge {
186            Edge::North => &mut self.halo_north,
187            Edge::South => &mut self.halo_south,
188            Edge::West => &mut self.halo_west,
189            Edge::East => &mut self.halo_east,
190        }
191    }
192}
193
/// Map 0-indexed interior coordinates to a flat index into a row-major buffer
/// of width `buffer_width` that is surrounded by a one-cell halo ring.
#[inline(always)]
pub fn buffer_index(local_x: usize, local_y: usize, buffer_width: usize) -> usize {
    // Step past the halo border on both axes, then flatten row-major.
    let row = local_y + 1;
    let col = local_x + 1;
    row * buffer_width + col
}
200
201/// Calculate halo row index (row 0 or row buffer_width-1).
202#[inline(always)]
203pub fn halo_row_start(edge: Edge, buffer_width: usize) -> usize {
204    match edge {
205        Edge::North => 1,                                     // Row 0, cols 1..tile_size+1
206        Edge::South => (buffer_width - 1) * buffer_width + 1, // Last row, cols 1..tile_size+1
207        Edge::West | Edge::East => unreachable!(),
208    }
209}
210
211/// Calculate halo column index.
212#[inline(always)]
213pub fn halo_col_index(edge: Edge, y: usize, buffer_width: usize) -> usize {
214    match edge {
215        Edge::West => (y + 1) * buffer_width, // Col 0
216        Edge::East => (y + 1) * buffer_width + buffer_width - 1, // Last col
217        Edge::North | Edge::South => unreachable!(),
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// Build buffers with distinct `u8` stand-ins so accessor routing is observable.
    fn dummy_buffers(tile_size: u32) -> TileGpuBuffers<u8> {
        TileGpuBuffers {
            pressure_a: 0,
            pressure_b: 1,
            halo_north: 2,
            halo_south: 3,
            halo_west: 4,
            halo_east: 5,
            current_is_a: true,
            tile_size,
            buffer_width: tile_size + 2,
        }
    }

    #[test]
    fn test_edge_opposite() {
        assert_eq!(Edge::North.opposite(), Edge::South);
        assert_eq!(Edge::South.opposite(), Edge::North);
        assert_eq!(Edge::East.opposite(), Edge::West);
        assert_eq!(Edge::West.opposite(), Edge::East);
    }

    #[test]
    fn test_edge_opposite_is_involution() {
        for edge in Edge::ALL {
            assert_ne!(edge.opposite(), edge);
            assert_eq!(edge.opposite().opposite(), edge);
        }
    }

    #[test]
    fn test_buffer_index() {
        // 18x18 buffer for 16x16 tile
        let bw = 18;

        // Interior (0,0) should be at buffer position (1,1) = row 1, col 1
        assert_eq!(buffer_index(0, 0, bw), 19);

        // Interior (15,15) should be at buffer position (16,16)
        assert_eq!(buffer_index(15, 15, bw), 16 * 18 + 16);
        assert_eq!(buffer_index(15, 15, bw), 304);
    }

    #[test]
    fn test_halo_row_start() {
        let bw = 18;
        assert_eq!(halo_row_start(Edge::North, bw), 1);
        assert_eq!(halo_row_start(Edge::South, bw), 17 * bw + 1);

        // Horizontal halo cells sit directly above/below the interior edge rows.
        assert_eq!(halo_row_start(Edge::North, bw) + bw, buffer_index(0, 0, bw));
        assert_eq!(halo_row_start(Edge::South, bw) - bw, buffer_index(0, 15, bw));
    }

    #[test]
    fn test_halo_col_index() {
        let bw = 18;
        // Interior row 0 lives in buffer row 1.
        assert_eq!(halo_col_index(Edge::West, 0, bw), bw);
        assert_eq!(halo_col_index(Edge::East, 0, bw), 2 * bw - 1);

        // Vertical halo cells sit immediately beside the interior edge columns.
        assert_eq!(halo_col_index(Edge::West, 15, bw) + 1, buffer_index(0, 15, bw));
        assert_eq!(halo_col_index(Edge::East, 15, bw) - 1, buffer_index(15, 15, bw));
    }

    #[test]
    #[should_panic]
    fn test_halo_row_start_rejects_vertical_edges() {
        let _ = halo_row_start(Edge::West, 18);
    }

    #[test]
    #[should_panic]
    fn test_halo_col_index_rejects_horizontal_edges() {
        let _ = halo_col_index(Edge::North, 0, 18);
    }

    #[test]
    fn test_tile_buffer_accessors() {
        let mut bufs = dummy_buffers(16);
        assert_eq!(bufs.buffer_width(), 18);
        assert_eq!(bufs.buffer_size(), 324);

        assert_eq!(*bufs.halo_buffer(Edge::North), 2);
        assert_eq!(*bufs.halo_buffer(Edge::South), 3);
        assert_eq!(*bufs.halo_buffer(Edge::West), 4);
        assert_eq!(*bufs.halo_buffer(Edge::East), 5);

        *bufs.halo_buffer_mut(Edge::East) = 9;
        assert_eq!(bufs.halo_east, 9);
    }

    #[test]
    fn test_fdtd_params() {
        let params = FdtdParams::new(16, 0.25, 0.99);
        assert_eq!(params.tile_size, 16);
        assert_eq!(params.c2, 0.25);
        assert_eq!(params.damping, 0.99);
    }
}