Skip to main content

ringkernel_cuda_codegen/
stencil.rs

//! Stencil-specific patterns and configuration for CUDA code generation.
//!
//! This module provides the configuration types and patterns for generating
//! stencil/grid kernels that follow the WaveSim FDTD style.
5
/// Dimensionality of the grid a stencil kernel runs over.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Grid {
    /// One-dimensional grid (a line of cells).
    Grid1D,
    /// Two-dimensional grid (a plane of cells).
    Grid2D,
    /// Three-dimensional grid (a volume of cells).
    Grid3D,
}

impl Grid {
    /// Parse a grid kind from its textual name.
    ///
    /// Matching is case-insensitive. The accepted spellings are `"1d"`,
    /// `"2d"` and `"3d"`, each optionally preceded by `grid` (for example
    /// `"Grid3D"`). Any other input, including surrounding whitespace,
    /// yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        let lowered = s.to_lowercase();
        // The `grid` prefix is optional; remove it at most once before matching.
        let suffix = lowered.strip_prefix("grid").unwrap_or(lowered.as_str());

        match suffix {
            "1d" => Some(Self::Grid1D),
            "2d" => Some(Self::Grid2D),
            "3d" => Some(Self::Grid3D),
            _ => None,
        }
    }

    /// Number of spatial dimensions of this grid kind (1, 2 or 3).
    pub fn dimensions(&self) -> usize {
        match *self {
            Self::Grid1D => 1,
            Self::Grid2D => 2,
            Self::Grid3D => 3,
        }
    }
}
37
38/// Configuration for a stencil kernel.
39#[derive(Debug, Clone)]
40pub struct StencilConfig {
41    /// Kernel identifier.
42    pub id: String,
43    /// Grid dimensionality.
44    pub grid: Grid,
45    /// Tile/block size (width, height) for 2D, or (x, y, z) for 3D.
46    pub tile_size: (usize, usize),
47    /// Halo/ghost cell width (stencil radius).
48    pub halo: usize,
49}
50
51impl Default for StencilConfig {
52    fn default() -> Self {
53        Self {
54            id: "stencil_kernel".to_string(),
55            grid: Grid::Grid2D,
56            tile_size: (16, 16),
57            halo: 1,
58        }
59    }
60}
61
62impl StencilConfig {
63    /// Create a new stencil configuration.
64    pub fn new(id: impl Into<String>) -> Self {
65        Self {
66            id: id.into(),
67            ..Default::default()
68        }
69    }
70
71    /// Set the grid dimensionality.
72    pub fn with_grid(mut self, grid: Grid) -> Self {
73        self.grid = grid;
74        self
75    }
76
77    /// Set the tile size.
78    pub fn with_tile_size(mut self, width: usize, height: usize) -> Self {
79        self.tile_size = (width, height);
80        self
81    }
82
83    /// Set the halo width.
84    pub fn with_halo(mut self, halo: usize) -> Self {
85        self.halo = halo;
86        self
87    }
88
89    /// Get the buffer width including halos.
90    pub fn buffer_width(&self) -> usize {
91        self.tile_size.0 + 2 * self.halo
92    }
93
94    /// Get the buffer height including halos (for 2D+).
95    pub fn buffer_height(&self) -> usize {
96        self.tile_size.1 + 2 * self.halo
97    }
98
99    /// Generate the CUDA kernel preamble with thread index calculations.
100    pub fn generate_preamble(&self) -> String {
101        match self.grid {
102            Grid::Grid1D => self.generate_1d_preamble(),
103            Grid::Grid2D => self.generate_2d_preamble(),
104            Grid::Grid3D => self.generate_3d_preamble(),
105        }
106    }
107
108    fn generate_1d_preamble(&self) -> String {
109        let tile_size = self.tile_size.0;
110        let buffer_width = self.buffer_width();
111
112        format!(
113            r#"    int lx = threadIdx.x;
114    if (lx >= {tile_size}) return;
115
116    int buffer_width = {buffer_width};
117    int idx = lx + {halo};
118"#,
119            tile_size = tile_size,
120            buffer_width = buffer_width,
121            halo = self.halo,
122        )
123    }
124
125    fn generate_2d_preamble(&self) -> String {
126        let (tile_w, tile_h) = self.tile_size;
127        let buffer_width = self.buffer_width();
128
129        format!(
130            r#"    int lx = threadIdx.x;
131    int ly = threadIdx.y;
132    if (lx >= {tile_w} || ly >= {tile_h}) return;
133
134    int buffer_width = {buffer_width};
135    int idx = (ly + {halo}) * buffer_width + (lx + {halo});
136"#,
137            tile_w = tile_w,
138            tile_h = tile_h,
139            buffer_width = buffer_width,
140            halo = self.halo,
141        )
142    }
143
144    fn generate_3d_preamble(&self) -> String {
145        let (tile_w, tile_h) = self.tile_size;
146        let buffer_width = self.buffer_width();
147        let buffer_height = self.buffer_height();
148
149        format!(
150            r#"    int lx = threadIdx.x;
151    int ly = threadIdx.y;
152    int lz = threadIdx.z;
153    if (lx >= {tile_w} || ly >= {tile_h}) return;
154
155    int buffer_width = {buffer_width};
156    int buffer_height = {buffer_height};
157    int buffer_slice = buffer_width * buffer_height;
158    int idx = (lz + {halo}) * buffer_slice + (ly + {halo}) * buffer_width + (lx + {halo});
159"#,
160            tile_w = tile_w,
161            tile_h = tile_h,
162            buffer_width = buffer_width,
163            buffer_height = buffer_height,
164            halo = self.halo,
165        )
166    }
167
168    /// Generate launch bounds for the kernel.
169    pub fn generate_launch_bounds(&self) -> String {
170        let threads = self.tile_size.0 * self.tile_size.1;
171        format!("__launch_bounds__({threads})")
172    }
173}
174
/// Handle for the current cell's position inside a stencil kernel.
///
/// This is a marker type: its methods document the CUDA expression each
/// call is translated to. The Rust bodies are placeholders only.
#[derive(Debug, Clone, Copy)]
pub struct GridPos {
    // Phantom field - this struct is never actually instantiated in GPU code.
    _private: (),
}

impl GridPos {
    /// Shared placeholder value returned by every neighbor accessor.
    #[inline]
    fn placeholder<T: Copy>() -> T {
        // SAFETY: NOTE(review) - an all-zero bit pattern is a valid `T` only
        // for types such as integers, floats and `bool`; the `T: Copy` bound
        // alone does not guarantee that (references and `NonZero*` are `Copy`
        // but not zeroable). Sound only while these placeholders are never
        // evaluated with such types - confirm against callers.
        unsafe { std::mem::zeroed() }
    }

    /// Linear index of the current cell.
    ///
    /// In CUDA: `idx`
    #[inline]
    pub fn idx(&self) -> usize {
        // Placeholder - the real value exists only in generated code.
        0
    }

    /// Read the north neighbor (y - 1).
    ///
    /// In CUDA: `buf[idx - buffer_width]`
    #[inline]
    pub fn north<T: Copy>(&self, _buf: &[T]) -> T {
        Self::placeholder()
    }

    /// Read the south neighbor (y + 1).
    ///
    /// In CUDA: `buf[idx + buffer_width]`
    #[inline]
    pub fn south<T: Copy>(&self, _buf: &[T]) -> T {
        Self::placeholder()
    }

    /// Read the east neighbor (x + 1).
    ///
    /// In CUDA: `buf[idx + 1]`
    #[inline]
    pub fn east<T: Copy>(&self, _buf: &[T]) -> T {
        Self::placeholder()
    }

    /// Read the west neighbor (x - 1).
    ///
    /// In CUDA: `buf[idx - 1]`
    #[inline]
    pub fn west<T: Copy>(&self, _buf: &[T]) -> T {
        Self::placeholder()
    }

    /// Read the neighbor at an arbitrary `(dx, dy)` offset.
    ///
    /// In CUDA: `buf[idx + dy * buffer_width + dx]`
    #[inline]
    pub fn at<T: Copy>(&self, _buf: &[T], _dx: i32, _dy: i32) -> T {
        Self::placeholder()
    }

    /// Read the neighbor above (z - 1, 3D only).
    ///
    /// In CUDA: `buf[idx - buffer_slice]`
    #[inline]
    pub fn up<T: Copy>(&self, _buf: &[T]) -> T {
        Self::placeholder()
    }

    /// Read the neighbor below (z + 1, 3D only).
    ///
    /// In CUDA: `buf[idx + buffer_slice]`
    #[inline]
    pub fn down<T: Copy>(&self, _buf: &[T]) -> T {
        Self::placeholder()
    }
}
252
/// Launch configuration for stencil kernels.
#[derive(Debug, Clone)]
pub struct StencilLaunchConfig {
    /// Block dimensions `(x, y, z)`.
    pub block_dim: (u32, u32, u32),
    /// Grid dimensions `(x, y, z)`.
    pub grid_dim: (u32, u32, u32),
    /// Shared memory size in bytes.
    pub shared_mem: u32,
}

impl StencilLaunchConfig {
    /// Create a launch config for a 2D stencil.
    ///
    /// The block is one tile (`tile_size.0` x `tile_size.1` x 1). The grid
    /// has enough tiles to cover `grid_width` x `grid_height`, rounding up
    /// when the grid is not a multiple of the tile size. `shared_mem` is 0.
    ///
    /// # Panics
    ///
    /// Panics if either tile extent is zero, or if any resulting block or
    /// grid dimension does not fit in a `u32`.
    pub fn for_2d_grid(grid_width: usize, grid_height: usize, tile_size: (usize, usize)) -> Self {
        let (tile_w, tile_h) = tile_size;
        // A zero extent would otherwise surface as an opaque divide-by-zero
        // panic inside `div_ceil`.
        assert!(
            tile_w > 0 && tile_h > 0,
            "stencil tile extents must be non-zero, got {tile_size:?}"
        );

        Self {
            block_dim: Self::block_dim_for(tile_size),
            grid_dim: (
                Self::launch_dim(grid_width.div_ceil(tile_w), "grid_dim.x"),
                Self::launch_dim(grid_height.div_ceil(tile_h), "grid_dim.y"),
                1,
            ),
            shared_mem: 0,
        }
    }

    /// Create a launch config for packed tile execution (one block per tile).
    ///
    /// The block is one tile (`tile_size.0` x `tile_size.1` x 1) and the
    /// grid is `num_tiles` x 1 x 1. `shared_mem` is 0.
    ///
    /// # Panics
    ///
    /// Panics if `num_tiles` or either tile extent does not fit in a `u32`.
    pub fn for_packed_tiles(num_tiles: usize, tile_size: (usize, usize)) -> Self {
        Self {
            block_dim: Self::block_dim_for(tile_size),
            grid_dim: (Self::launch_dim(num_tiles, "grid_dim.x"), 1, 1),
            shared_mem: 0,
        }
    }

    /// Block dimensions for a single `(width, height)` tile.
    fn block_dim_for(tile_size: (usize, usize)) -> (u32, u32, u32) {
        (
            Self::launch_dim(tile_size.0, "block_dim.x"),
            Self::launch_dim(tile_size.1, "block_dim.y"),
            1,
        )
    }

    /// Convert a dimension to `u32`, panicking instead of silently
    /// truncating (a wrapped value would yield a plausible-looking but
    /// wrong launch configuration).
    fn launch_dim(value: usize, what: &str) -> u32 {
        // Fully qualified so the call resolves without relying on `TryFrom`
        // being in the edition prelude.
        <u32 as std::convert::TryFrom<usize>>::try_from(value).unwrap_or_else(|_| {
            panic!("{what} = {value} does not fit in a u32 CUDA launch dimension")
        })
    }
}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// `Grid::parse` ignores case and rejects unknown names.
    #[test]
    fn test_grid_parsing() {
        let cases = [
            ("2d", Some(Grid::Grid2D)),
            ("Grid3D", Some(Grid::Grid3D)),
            ("1D", Some(Grid::Grid1D)),
            ("invalid", None),
        ];

        for (input, expected) in cases {
            assert_eq!(Grid::parse(input), expected);
        }
    }

    /// The default configuration is a 2D grid, 16x16 tiles, halo of one.
    #[test]
    fn test_stencil_config_defaults() {
        let StencilConfig {
            grid,
            tile_size,
            halo,
            ..
        } = StencilConfig::default();

        assert_eq!(tile_size, (16, 16));
        assert_eq!(halo, 1);
        assert_eq!(grid, Grid::Grid2D);
    }

    /// Buffer extents are the tile extents plus a halo on each side.
    #[test]
    fn test_buffer_dimensions() {
        let cfg = StencilConfig::new("test").with_tile_size(16, 16).with_halo(1);
        let extents = (cfg.buffer_width(), cfg.buffer_height());

        assert_eq!(extents.0, 18);
        assert_eq!(extents.1, 18);
    }

    /// The 2D preamble uses both thread indices and guards both extents.
    #[test]
    fn test_2d_preamble_generation() {
        let code = StencilConfig::new("fdtd")
            .with_grid(Grid::Grid2D)
            .with_tile_size(16, 16)
            .with_halo(1)
            .generate_preamble();

        for needle in [
            "threadIdx.x",
            "threadIdx.y",
            "buffer_width = 18",
            "if (lx >= 16 || ly >= 16) return;",
        ] {
            assert!(code.contains(needle));
        }
    }

    /// The 1D preamble uses only `threadIdx.x`.
    #[test]
    fn test_1d_preamble_generation() {
        let code = StencilConfig::new("blur")
            .with_grid(Grid::Grid1D)
            .with_tile_size(256, 1)
            .with_halo(2)
            .generate_preamble();

        assert!(code.contains("threadIdx.x"));
        assert!(!code.contains("threadIdx.y"));
        // 256 + 2*2
        assert!(code.contains("buffer_width = 260"));
    }

    /// A 256x256 grid with 16x16 tiles needs 16 tiles per dimension.
    #[test]
    fn test_launch_config_2d() {
        let StencilLaunchConfig {
            block_dim,
            grid_dim,
            ..
        } = StencilLaunchConfig::for_2d_grid(256, 256, (16, 16));

        assert_eq!(block_dim, (16, 16, 1));
        assert_eq!(grid_dim, (16, 16, 1));
    }

    /// Packed execution launches one block per tile along x.
    #[test]
    fn test_launch_config_packed() {
        let StencilLaunchConfig {
            block_dim,
            grid_dim,
            ..
        } = StencilLaunchConfig::for_packed_tiles(100, (16, 16));

        assert_eq!(block_dim, (16, 16, 1));
        assert_eq!(grid_dim, (100, 1, 1));
    }
}