/// Dimensionality of the grid a stencil kernel is generated for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Grid {
    /// One-dimensional grid; [`Grid::parse`] accepts `"1d"` or `"grid1d"`.
    Grid1D,
    /// Two-dimensional grid; [`Grid::parse`] accepts `"2d"` or `"grid2d"`.
    Grid2D,
    /// Three-dimensional grid; [`Grid::parse`] accepts `"3d"` or `"grid3d"`.
    Grid3D,
}

impl Grid {
    /// Parses a grid name case-insensitively.
    ///
    /// Recognises the short spellings `1d`/`2d`/`3d` and the long spellings
    /// `grid1d`/`grid2d`/`grid3d`. The input is not trimmed, so surrounding whitespace or any
    /// other spelling yields `None`.
    pub fn parse(s: &str) -> Option<Self> {
        // (short alias, long alias, variant). The aliases are pure ASCII, so ASCII case
        // folding accepts the same inputs that Unicode lower-casing would map onto them.
        const ALIASES: [(&str, &str, Grid); 3] = [
            ("1d", "grid1d", Grid::Grid1D),
            ("2d", "grid2d", Grid::Grid2D),
            ("3d", "grid3d", Grid::Grid3D),
        ];

        ALIASES
            .iter()
            .copied()
            .find(|&(short, long, _)| {
                s.eq_ignore_ascii_case(short) || s.eq_ignore_ascii_case(long)
            })
            .map(|(_, _, grid)| grid)
    }

    /// Number of spatial axes of this grid: 1, 2 or 3.
    pub fn dimensions(&self) -> usize {
        match *self {
            Self::Grid1D => 1,
            Self::Grid2D => 2,
            Self::Grid3D => 3,
        }
    }
}

38#[derive(Debug, Clone)]
40pub struct StencilConfig {
41 pub id: String,
43 pub grid: Grid,
45 pub tile_size: (usize, usize),
47 pub halo: usize,
49}
50
51impl Default for StencilConfig {
52 fn default() -> Self {
53 Self {
54 id: "stencil_kernel".to_string(),
55 grid: Grid::Grid2D,
56 tile_size: (16, 16),
57 halo: 1,
58 }
59 }
60}
61
62impl StencilConfig {
63 pub fn new(id: impl Into<String>) -> Self {
65 Self {
66 id: id.into(),
67 ..Default::default()
68 }
69 }
70
71 pub fn with_grid(mut self, grid: Grid) -> Self {
73 self.grid = grid;
74 self
75 }
76
77 pub fn with_tile_size(mut self, width: usize, height: usize) -> Self {
79 self.tile_size = (width, height);
80 self
81 }
82
83 pub fn with_halo(mut self, halo: usize) -> Self {
85 self.halo = halo;
86 self
87 }
88
89 pub fn buffer_width(&self) -> usize {
91 self.tile_size.0 + 2 * self.halo
92 }
93
94 pub fn buffer_height(&self) -> usize {
96 self.tile_size.1 + 2 * self.halo
97 }
98
99 pub fn generate_preamble(&self) -> String {
101 match self.grid {
102 Grid::Grid1D => self.generate_1d_preamble(),
103 Grid::Grid2D => self.generate_2d_preamble(),
104 Grid::Grid3D => self.generate_3d_preamble(),
105 }
106 }
107
108 fn generate_1d_preamble(&self) -> String {
109 let tile_size = self.tile_size.0;
110 let buffer_width = self.buffer_width();
111
112 format!(
113 r#" int lx = threadIdx.x;
114 if (lx >= {tile_size}) return;
115
116 int buffer_width = {buffer_width};
117 int idx = lx + {halo};
118"#,
119 tile_size = tile_size,
120 buffer_width = buffer_width,
121 halo = self.halo,
122 )
123 }
124
125 fn generate_2d_preamble(&self) -> String {
126 let (tile_w, tile_h) = self.tile_size;
127 let buffer_width = self.buffer_width();
128
129 format!(
130 r#" int lx = threadIdx.x;
131 int ly = threadIdx.y;
132 if (lx >= {tile_w} || ly >= {tile_h}) return;
133
134 int buffer_width = {buffer_width};
135 int idx = (ly + {halo}) * buffer_width + (lx + {halo});
136"#,
137 tile_w = tile_w,
138 tile_h = tile_h,
139 buffer_width = buffer_width,
140 halo = self.halo,
141 )
142 }
143
144 fn generate_3d_preamble(&self) -> String {
145 let (tile_w, tile_h) = self.tile_size;
146 let buffer_width = self.buffer_width();
147 let buffer_height = self.buffer_height();
148
149 format!(
150 r#" int lx = threadIdx.x;
151 int ly = threadIdx.y;
152 int lz = threadIdx.z;
153 if (lx >= {tile_w} || ly >= {tile_h}) return;
154
155 int buffer_width = {buffer_width};
156 int buffer_height = {buffer_height};
157 int buffer_slice = buffer_width * buffer_height;
158 int idx = (lz + {halo}) * buffer_slice + (ly + {halo}) * buffer_width + (lx + {halo});
159"#,
160 tile_w = tile_w,
161 tile_h = tile_h,
162 buffer_width = buffer_width,
163 buffer_height = buffer_height,
164 halo = self.halo,
165 )
166 }
167
168 pub fn generate_launch_bounds(&self) -> String {
170 let threads = self.tile_size.0 * self.tile_size.1;
171 format!("__launch_bounds__({threads})")
172 }
173}
174
/// Rust-side stand-in for the current cell's position inside a stencil kernel.
///
/// `_private` prevents construction outside this module, and nothing in this module constructs
/// a `GridPos`, so ordinary Rust code never reaches the methods below.
/// NOTE(review): they presumably exist so kernel source written against `GridPos` type-checks
/// before a stencil translator rewrites the calls — confirm against the translator.
#[derive(Debug, Clone, Copy)]
pub struct GridPos {
    _private: (),
}

impl GridPos {
    /// Linear index of the current cell; this Rust placeholder always returns 0.
    #[inline]
    pub fn idx(&self) -> usize {
        0
    }

    /// Marker for reading the `north` neighbour from `buf`.
    ///
    /// # Panics
    ///
    /// Always: the access has no implementation in Rust.
    #[inline]
    pub fn north<T: Copy>(&self, _buf: &[T]) -> T {
        Self::marker("north")
    }

    /// Marker for reading the `south` neighbour from `buf`.
    ///
    /// # Panics
    ///
    /// Always: the access has no implementation in Rust.
    #[inline]
    pub fn south<T: Copy>(&self, _buf: &[T]) -> T {
        Self::marker("south")
    }

    /// Marker for reading the `east` neighbour from `buf`.
    ///
    /// # Panics
    ///
    /// Always: the access has no implementation in Rust.
    #[inline]
    pub fn east<T: Copy>(&self, _buf: &[T]) -> T {
        Self::marker("east")
    }

    /// Marker for reading the `west` neighbour from `buf`.
    ///
    /// # Panics
    ///
    /// Always: the access has no implementation in Rust.
    #[inline]
    pub fn west<T: Copy>(&self, _buf: &[T]) -> T {
        Self::marker("west")
    }

    /// Marker for reading the cell at relative offset `(_dx, _dy)` from `buf`.
    ///
    /// # Panics
    ///
    /// Always: the access has no implementation in Rust.
    #[inline]
    pub fn at<T: Copy>(&self, _buf: &[T], _dx: i32, _dy: i32) -> T {
        Self::marker("at")
    }

    /// Marker for reading the `up` neighbour from `buf`.
    ///
    /// # Panics
    ///
    /// Always: the access has no implementation in Rust.
    #[inline]
    pub fn up<T: Copy>(&self, _buf: &[T]) -> T {
        Self::marker("up")
    }

    /// Marker for reading the `down` neighbour from `buf`.
    ///
    /// # Panics
    ///
    /// Always: the access has no implementation in Rust.
    #[inline]
    pub fn down<T: Copy>(&self, _buf: &[T]) -> T {
        Self::marker("down")
    }

    /// Shared body of the neighbour markers.
    ///
    /// It panics rather than fabricating a value. `std::mem::zeroed` is undefined behaviour
    /// for any `T: Copy` with no valid all-zero value (e.g. `&U`, `NonZeroU32`, fn pointers),
    /// and these accessors are generic over all such `T`.
    #[cold]
    #[inline(never)]
    fn marker<T>(accessor: &str) -> T {
        panic!(
            "GridPos::{} is a stencil marker and has no implementation in Rust",
            accessor
        )
    }
}

/// Launch geometry for a stencil kernel.
#[derive(Debug, Clone)]
pub struct StencilLaunchConfig {
    /// Block shape `(x, y, z)`; the constructors set it to the tile size with `z = 1`.
    pub block_dim: (u32, u32, u32),
    /// Grid shape `(x, y, z)` in blocks.
    pub grid_dim: (u32, u32, u32),
    /// Shared-memory amount passed at launch (presumably bytes of dynamic shared memory —
    /// confirm); both constructors here set it to 0.
    pub shared_mem: u32,
}

impl StencilLaunchConfig {
    /// One block per tile covering a `grid_width` x `grid_height` domain.
    ///
    /// Block counts are rounded up, so a domain that is not a multiple of the tile size still
    /// gets a (partial) block at its far edges.
    ///
    /// # Panics
    ///
    /// Panics if either tile dimension is zero, or if a tile dimension or a block count does
    /// not fit in `u32`. Silently truncating would launch a grid that misses part of the domain.
    pub fn for_2d_grid(grid_width: usize, grid_height: usize, tile_size: (usize, usize)) -> Self {
        let (tile_w, tile_h) = tile_size;
        // `div_ceil` would panic on a zero divisor anyway; fail with a message naming the cause.
        assert!(
            tile_w > 0 && tile_h > 0,
            "stencil tile dimensions must be non-zero, got {}x{}",
            tile_w,
            tile_h
        );
        let tiles_x = grid_width.div_ceil(tile_w);
        let tiles_y = grid_height.div_ceil(tile_h);

        Self {
            block_dim: (
                Self::launch_dim(tile_w, "tile width"),
                Self::launch_dim(tile_h, "tile height"),
                1,
            ),
            grid_dim: (
                Self::launch_dim(tiles_x, "block count along x"),
                Self::launch_dim(tiles_y, "block count along y"),
                1,
            ),
            shared_mem: 0,
        }
    }

    /// One block per tile for `num_tiles` tiles laid out along the grid's x axis.
    ///
    /// # Panics
    ///
    /// Panics if `num_tiles` or a tile dimension does not fit in `u32`.
    pub fn for_packed_tiles(num_tiles: usize, tile_size: (usize, usize)) -> Self {
        Self {
            block_dim: (
                Self::launch_dim(tile_size.0, "tile width"),
                Self::launch_dim(tile_size.1, "tile height"),
                1,
            ),
            grid_dim: (Self::launch_dim(num_tiles, "tile count"), 1, 1),
            shared_mem: 0,
        }
    }

    /// Narrows a launch dimension to `u32`, panicking instead of truncating like `as` would.
    fn launch_dim(value: usize, what: &str) -> u32 {
        match <u32 as std::convert::TryFrom<usize>>::try_from(value) {
            Ok(dim) => dim,
            Err(_) => panic!("{} {} does not fit in a u32 launch dimension", what, value),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_grid_parsing() {
        assert_eq!(Grid::parse("2d"), Some(Grid::Grid2D));
        assert_eq!(Grid::parse("Grid3D"), Some(Grid::Grid3D));
        assert_eq!(Grid::parse("1D"), Some(Grid::Grid1D));
        assert_eq!(Grid::parse("GRID2D"), Some(Grid::Grid2D));
        assert_eq!(Grid::parse("invalid"), None);
        assert_eq!(Grid::parse(""), None);
    }

    #[test]
    fn test_grid_dimensions() {
        assert_eq!(Grid::Grid1D.dimensions(), 1);
        assert_eq!(Grid::Grid2D.dimensions(), 2);
        assert_eq!(Grid::Grid3D.dimensions(), 3);
    }

    #[test]
    fn test_stencil_config_defaults() {
        let config = StencilConfig::default();
        assert_eq!(config.id, "stencil_kernel");
        assert_eq!(config.tile_size, (16, 16));
        assert_eq!(config.halo, 1);
        assert_eq!(config.grid, Grid::Grid2D);
    }

    #[test]
    fn test_new_overrides_only_id() {
        let config = StencilConfig::new("custom");
        assert_eq!(config.id, "custom");
        assert_eq!(config.tile_size, (16, 16));
        assert_eq!(config.halo, 1);
        assert_eq!(config.grid, Grid::Grid2D);
    }

    #[test]
    fn test_buffer_dimensions() {
        let config = StencilConfig::new("test")
            .with_tile_size(16, 16)
            .with_halo(1);

        assert_eq!(config.buffer_width(), 18);
        assert_eq!(config.buffer_height(), 18);

        // Non-square tile and wider halo: each axis is padded independently.
        let config = StencilConfig::new("test")
            .with_tile_size(32, 8)
            .with_halo(3);

        assert_eq!(config.buffer_width(), 38);
        assert_eq!(config.buffer_height(), 14);
    }

    #[test]
    fn test_2d_preamble_generation() {
        let config = StencilConfig::new("fdtd")
            .with_grid(Grid::Grid2D)
            .with_tile_size(16, 16)
            .with_halo(1);

        let preamble = config.generate_preamble();

        assert!(preamble.contains("threadIdx.x"));
        assert!(preamble.contains("threadIdx.y"));
        assert!(!preamble.contains("threadIdx.z"));
        assert!(preamble.contains("buffer_width = 18"));
        assert!(preamble.contains("if (lx >= 16 || ly >= 16) return;"));
        assert!(preamble.contains("int idx = (ly + 1) * buffer_width + (lx + 1);"));
    }

    #[test]
    fn test_1d_preamble_generation() {
        let config = StencilConfig::new("blur")
            .with_grid(Grid::Grid1D)
            .with_tile_size(256, 1)
            .with_halo(2);

        let preamble = config.generate_preamble();

        assert!(preamble.contains("threadIdx.x"));
        assert!(!preamble.contains("threadIdx.y"));
        assert!(preamble.contains("buffer_width = 260"));
        assert!(preamble.contains("if (lx >= 256) return;"));
        assert!(preamble.contains("int idx = lx + 2;"));
    }

    #[test]
    fn test_3d_preamble_generation() {
        let config = StencilConfig::new("heat")
            .with_grid(Grid::Grid3D)
            .with_tile_size(8, 4)
            .with_halo(1);

        let preamble = config.generate_preamble();

        assert!(preamble.contains("threadIdx.x"));
        assert!(preamble.contains("threadIdx.y"));
        assert!(preamble.contains("threadIdx.z"));
        assert!(preamble.contains("if (lx >= 8 || ly >= 4) return;"));
        assert!(preamble.contains("buffer_width = 10"));
        assert!(preamble.contains("buffer_height = 6"));
        assert!(preamble.contains("int buffer_slice = buffer_width * buffer_height;"));
        assert!(preamble.contains(
            "int idx = (lz + 1) * buffer_slice + (ly + 1) * buffer_width + (lx + 1);"
        ));
    }

    #[test]
    fn test_launch_bounds() {
        let square = StencilConfig::new("k").with_tile_size(16, 16);
        assert_eq!(square.generate_launch_bounds(), "__launch_bounds__(256)");

        let line = StencilConfig::new("k")
            .with_grid(Grid::Grid1D)
            .with_tile_size(128, 1);
        assert_eq!(line.generate_launch_bounds(), "__launch_bounds__(128)");
    }

    #[test]
    fn test_launch_config_2d() {
        let config = StencilLaunchConfig::for_2d_grid(256, 256, (16, 16));

        assert_eq!(config.block_dim, (16, 16, 1));
        assert_eq!(config.grid_dim, (16, 16, 1));
        assert_eq!(config.shared_mem, 0);
    }

    #[test]
    fn test_launch_config_2d_partial_tiles() {
        // 100 / 16 and 50 / 16 are not exact, so the block counts round up.
        let config = StencilLaunchConfig::for_2d_grid(100, 50, (16, 16));

        assert_eq!(config.block_dim, (16, 16, 1));
        assert_eq!(config.grid_dim, (7, 4, 1));
    }

    #[test]
    fn test_launch_config_packed() {
        let config = StencilLaunchConfig::for_packed_tiles(100, (16, 16));

        assert_eq!(config.block_dim, (16, 16, 1));
        assert_eq!(config.grid_dim, (100, 1, 1));
        assert_eq!(config.shared_mem, 0);
    }
}