Skip to main content

ringkernel_wgpu_codegen/
stencil.rs

//! Stencil kernel support for WGSL code generation.
//!
//! Provides the `GridPos` abstraction and the stencil kernel configuration types.
4
/// Settings that drive generation of a stencil kernel.
#[derive(Debug, Clone)]
pub struct StencilConfig {
    /// Name of the kernel.
    pub name: String,
    /// Width of one tile; used as the workgroup size along X.
    pub tile_width: u32,
    /// Height of one tile; used as the workgroup size along Y.
    pub tile_height: u32,
    /// Halo size for neighbor access; counted on both sides of each axis.
    pub halo: u32,
    /// Whether the tile should use shared memory.
    pub use_shared_memory: bool,
}

impl StencilConfig {
    /// Build a configuration called `name` with the defaults: a 16x16 tile,
    /// a halo of 1, and shared memory enabled.
    pub fn new(name: &str) -> Self {
        Self {
            name: String::from(name),
            tile_width: 16,
            tile_height: 16,
            halo: 1,
            use_shared_memory: true,
        }
    }

    /// Replace the tile dimensions, leaving every other setting untouched.
    pub fn with_tile_size(self, width: u32, height: u32) -> Self {
        Self {
            tile_width: width,
            tile_height: height,
            ..self
        }
    }

    /// Replace the halo size, leaving every other setting untouched.
    pub fn with_halo(self, halo: u32) -> Self {
        Self {
            halo,
            ..self
        }
    }

    /// Turn off shared memory usage.
    pub fn without_shared_memory(self) -> Self {
        Self {
            use_shared_memory: false,
            ..self
        }
    }

    /// Tile width plus the halo on both sides.
    pub fn buffer_width(&self) -> u32 {
        let halo_both_sides = 2 * self.halo;
        self.tile_width + halo_both_sides
    }

    /// Tile height plus the halo on both sides.
    pub fn buffer_height(&self) -> u32 {
        let halo_both_sides = 2 * self.halo;
        self.tile_height + halo_both_sides
    }

    /// Render the `@workgroup_size` annotation for this tile.
    ///
    /// X is the tile width, Y is the tile height, and Z is always 1.
    pub fn workgroup_size_annotation(&self) -> String {
        format!(
            "@workgroup_size({}, {}, 1)",
            self.tile_width, self.tile_height
        )
    }
}
69
70/// Launch configuration for stencil kernels.
71#[derive(Debug, Clone)]
72pub struct StencilLaunchConfig {
73    /// Grid dimensions.
74    pub grid: Grid,
75    /// Block dimensions (derived from StencilConfig).
76    pub block_width: u32,
77    pub block_height: u32,
78}
79
80impl StencilLaunchConfig {
81    /// Create a new launch configuration.
82    pub fn new(grid: Grid, config: &StencilConfig) -> Self {
83        Self {
84            grid,
85            block_width: config.tile_width,
86            block_height: config.tile_height,
87        }
88    }
89}
90
/// Extent of the grid a kernel is dispatched over.
#[derive(Debug, Clone, Copy)]
pub struct Grid {
    /// Number of cells along X.
    pub width: u32,
    /// Number of cells along Y.
    pub height: u32,
    /// Depth of the grid (for 3D grids); 2D grids use 1.
    pub depth: u32,
}

impl Grid {
    /// Create a 2D grid; the depth is fixed at 1.
    pub fn new_2d(width: u32, height: u32) -> Self {
        Self::new_3d(width, height, 1)
    }

    /// Create a 3D grid from explicit width, height and depth.
    pub fn new_3d(width: u32, height: u32, depth: u32) -> Self {
        Self {
            width,
            height,
            depth,
        }
    }

    /// Number of workgroups needed to cover this grid with tiles of the given size.
    ///
    /// The X and Y counts are rounded up, so a partially covered tile still
    /// gets a workgroup. The Z count is the grid depth, unchanged.
    ///
    /// # Panics
    ///
    /// Panics if `tile_width` or `tile_height` is zero.
    pub fn workgroups(&self, tile_width: u32, tile_height: u32) -> (u32, u32, u32) {
        (
            self.width.div_ceil(tile_width),
            self.height.div_ceil(tile_height),
            self.depth,
        )
    }
}
128
/// Position within the grid, handed to stencil kernels.
///
/// This is a compile-time marker type: the transpiler recognizes method calls
/// on `GridPos` and converts them to WGSL index calculations.
#[derive(Debug, Clone, Copy)]
pub struct GridPos {
    /// X coordinate in the grid.
    pub x: i32,
    /// Y coordinate in the grid.
    pub y: i32,
    /// Linear index of this position in the buffer.
    pub idx: usize,
    /// Buffer stride (width): the index distance between vertically adjacent cells.
    pub stride: usize,
}

impl GridPos {
    /// Linear index of the current position.
    ///
    /// Maps to: `idx` (local variable in generated WGSL)
    #[inline(always)]
    pub fn idx(&self) -> usize {
        self.idx
    }

    /// X coordinate of the current position.
    #[inline(always)]
    pub fn x(&self) -> i32 {
        self.x
    }

    /// Y coordinate of the current position.
    #[inline(always)]
    pub fn y(&self) -> i32 {
        self.y
    }

    /// Read the north neighbor (y - 1), one stride before the current index.
    ///
    /// Maps to: `buffer[idx - buffer_width]`
    ///
    /// # Panics
    ///
    /// Panics if the neighbor index falls outside `buffer`.
    #[inline(always)]
    pub fn north<T: Copy>(&self, buffer: &[T]) -> T {
        let above = self.idx - self.stride;
        buffer[above]
    }

    /// Read the south neighbor (y + 1), one stride after the current index.
    ///
    /// Maps to: `buffer[idx + buffer_width]`
    ///
    /// # Panics
    ///
    /// Panics if the neighbor index falls outside `buffer`.
    #[inline(always)]
    pub fn south<T: Copy>(&self, buffer: &[T]) -> T {
        let below = self.idx + self.stride;
        buffer[below]
    }

    /// Read the east neighbor (x + 1), the element right after the current index.
    ///
    /// Maps to: `buffer[idx + 1]`
    ///
    /// # Panics
    ///
    /// Panics if the neighbor index falls outside `buffer`.
    #[inline(always)]
    pub fn east<T: Copy>(&self, buffer: &[T]) -> T {
        let right = self.idx + 1;
        buffer[right]
    }

    /// Read the west neighbor (x - 1), the element right before the current index.
    ///
    /// Maps to: `buffer[idx - 1]`
    ///
    /// # Panics
    ///
    /// Panics if the neighbor index falls outside `buffer`.
    #[inline(always)]
    pub fn west<T: Copy>(&self, buffer: &[T]) -> T {
        let left = self.idx - 1;
        buffer[left]
    }

    /// Read the neighbor at the relative offset (`dx`, `dy`).
    ///
    /// Maps to: `buffer[idx + dy * buffer_width + dx]`
    ///
    /// # Panics
    ///
    /// Panics if the resulting index falls outside `buffer`.
    #[inline(always)]
    pub fn at<T: Copy>(&self, buffer: &[T], dx: i32, dy: i32) -> T {
        let row_shift = dy as isize * self.stride as isize;
        let delta = row_shift + dx as isize;
        // A negative delta moves backwards through the buffer; the wrapping
        // add applies it to the unsigned index without a signed round-trip.
        buffer[self.idx.wrapping_add_signed(delta)]
    }
}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder calls set the fields, and the halo pads both buffer dimensions.
    #[test]
    fn test_stencil_config() {
        let cfg = StencilConfig::new("heat")
            .with_tile_size(16, 16)
            .with_halo(1);

        assert_eq!(cfg.name, "heat");
        assert_eq!((cfg.tile_width, cfg.tile_height), (16, 16));
        assert_eq!(cfg.halo, 1);
        assert_eq!((cfg.buffer_width(), cfg.buffer_height()), (18, 18));
    }

    /// A 256x256 grid split into 16x16 tiles needs 16x16x1 workgroups.
    #[test]
    fn test_grid_workgroups() {
        let counts = Grid::new_2d(256, 256).workgroups(16, 16);
        assert_eq!(counts, (16, 16, 1));
    }

    /// The annotation reflects the configured tile size with Z fixed at 1.
    #[test]
    fn test_workgroup_size_annotation() {
        let annotation = StencilConfig::new("test")
            .with_tile_size(8, 8)
            .workgroup_size_annotation();
        assert_eq!(annotation, "@workgroup_size(8, 8, 1)");
    }
}