ringkernel_wgpu_codegen/
/// Tiling parameters for a stencil kernel.
///
/// Holds a name, the tile dimensions, the halo width that is added on both
/// sides of a tile when sizing its buffer, and a shared-memory flag. Values
/// are adjusted through the chainable `with_*` / `without_*` methods.
#[derive(Debug, Clone)]
pub struct StencilConfig {
    /// Name given to this stencil configuration.
    pub name: String,
    /// Tile width; also emitted as the first workgroup-size component.
    pub tile_width: u32,
    /// Tile height; also emitted as the second workgroup-size component.
    pub tile_height: u32,
    /// Halo width, counted once per side in the buffer dimensions.
    pub halo: u32,
    /// Shared-memory flag; `true` unless cleared by
    /// [`StencilConfig::without_shared_memory`].
    pub use_shared_memory: bool,
}

impl StencilConfig {
    /// Creates a configuration called `name` with the defaults: a 16x16 tile,
    /// a halo of 1 and the shared-memory flag set.
    pub fn new(name: &str) -> Self {
        Self {
            name: String::from(name),
            tile_width: 16,
            tile_height: 16,
            halo: 1,
            use_shared_memory: true,
        }
    }

    /// Returns the configuration with its tile dimensions replaced by
    /// `width` x `height`; every other field is carried over.
    pub fn with_tile_size(self, width: u32, height: u32) -> Self {
        Self {
            tile_width: width,
            tile_height: height,
            ..self
        }
    }

    /// Returns the configuration with its halo width replaced by `halo`.
    pub fn with_halo(self, halo: u32) -> Self {
        Self { halo, ..self }
    }

    /// Returns the configuration with the shared-memory flag cleared.
    pub fn without_shared_memory(self) -> Self {
        Self {
            use_shared_memory: false,
            ..self
        }
    }

    /// Width of the tile buffer including the halo: `tile_width + 2 * halo`.
    pub fn buffer_width(&self) -> u32 {
        let halo_both_sides = 2 * self.halo;
        self.tile_width + halo_both_sides
    }

    /// Height of the tile buffer including the halo: `tile_height + 2 * halo`.
    pub fn buffer_height(&self) -> u32 {
        let halo_both_sides = 2 * self.halo;
        self.tile_height + halo_both_sides
    }

    /// Renders the workgroup-size attribute for this tile, e.g.
    /// `@workgroup_size(16, 16, 1)`. The third component is always `1`.
    pub fn workgroup_size_annotation(&self) -> String {
        let Self {
            tile_width,
            tile_height,
            ..
        } = self;
        format!("@workgroup_size({}, {}, 1)", tile_width, tile_height)
    }
}
69
70#[derive(Debug, Clone)]
72pub struct StencilLaunchConfig {
73 pub grid: Grid,
75 pub block_width: u32,
77 pub block_height: u32,
78}
79
80impl StencilLaunchConfig {
81 pub fn new(grid: Grid, config: &StencilConfig) -> Self {
83 Self {
84 grid,
85 block_width: config.tile_width,
86 block_height: config.tile_height,
87 }
88 }
89}
90
/// Extents of a grid along three axes.
#[derive(Debug, Clone, Copy)]
pub struct Grid {
    /// Extent along the first axis.
    pub width: u32,
    /// Extent along the second axis.
    pub height: u32,
    /// Extent along the third axis; `1` for grids built with [`Grid::new_2d`].
    pub depth: u32,
}

impl Grid {
    /// Creates a `width` x `height` grid with a depth of `1`.
    pub fn new_2d(width: u32, height: u32) -> Self {
        Self::new_3d(width, height, 1)
    }

    /// Creates a `width` x `height` x `depth` grid.
    pub fn new_3d(width: u32, height: u32, depth: u32) -> Self {
        Self {
            width,
            height,
            depth,
        }
    }

    /// Number of workgroups needed to cover the grid with tiles of
    /// `tile_width` x `tile_height`.
    ///
    /// Width and height are divided by the tile size rounding up, so a
    /// partially filled tile at the edge still counts as one workgroup. The
    /// depth is returned unchanged as the third component.
    ///
    /// # Panics
    ///
    /// Panics if `tile_width` or `tile_height` is zero (division by zero).
    pub fn workgroups(&self, tile_width: u32, tile_height: u32) -> (u32, u32, u32) {
        let Self {
            width,
            height,
            depth,
        } = *self;
        (
            width.div_ceil(tile_width),
            height.div_ceil(tile_height),
            depth,
        )
    }
}
128
/// A position inside a linearly indexed buffer.
///
/// Neighbour lookups treat the buffer as rows of `stride` elements: stepping
/// north/south moves the index by `stride`, stepping west/east moves it by 1.
/// How `x` and `y` relate to `idx` is decided by whoever constructs the value;
/// this type only stores and returns them.
#[derive(Debug, Clone, Copy)]
pub struct GridPos {
    /// X coordinate of the position.
    pub x: i32,
    /// Y coordinate of the position.
    pub y: i32,
    /// Linear index of the position in the buffer.
    pub idx: usize,
    /// Index distance between vertically adjacent elements.
    pub stride: usize,
}

impl GridPos {
    /// Linear index of this position.
    #[inline(always)]
    pub fn idx(&self) -> usize {
        self.idx
    }

    /// X coordinate of this position.
    #[inline(always)]
    pub fn x(&self) -> i32 {
        self.x
    }

    /// Y coordinate of this position.
    #[inline(always)]
    pub fn y(&self) -> i32 {
        self.y
    }

    /// Reads the element one row up, at `idx - stride`.
    ///
    /// # Panics
    ///
    /// Panics if that index is not inside `buffer`.
    #[inline(always)]
    pub fn north<T: Copy>(&self, buffer: &[T]) -> T {
        let above = self.idx - self.stride;
        buffer[above]
    }

    /// Reads the element one row down, at `idx + stride`.
    ///
    /// # Panics
    ///
    /// Panics if that index is not inside `buffer`.
    #[inline(always)]
    pub fn south<T: Copy>(&self, buffer: &[T]) -> T {
        let below = self.idx + self.stride;
        buffer[below]
    }

    /// Reads the element one column to the right, at `idx + 1`.
    ///
    /// # Panics
    ///
    /// Panics if that index is not inside `buffer`.
    #[inline(always)]
    pub fn east<T: Copy>(&self, buffer: &[T]) -> T {
        let right = self.idx + 1;
        buffer[right]
    }

    /// Reads the element one column to the left, at `idx - 1`.
    ///
    /// # Panics
    ///
    /// Panics if that index is not inside `buffer`.
    #[inline(always)]
    pub fn west<T: Copy>(&self, buffer: &[T]) -> T {
        let left = self.idx - 1;
        buffer[left]
    }

    /// Reads the element at a relative offset of `dx` columns and `dy` rows,
    /// i.e. at index `idx + dy * stride + dx`. Negative offsets look up/left.
    ///
    /// # Panics
    ///
    /// Panics if the resulting index is not inside `buffer`. An offset that
    /// would land before the start of the buffer wraps around to a very large
    /// index and therefore fails the same bounds check.
    #[inline(always)]
    pub fn at<T: Copy>(&self, buffer: &[T], dx: i32, dy: i32) -> T {
        let row_shift = dy as isize * self.stride as isize;
        let delta = row_shift + dx as isize;
        // Signed wrapping add is the two's-complement equivalent of casting
        // the signed delta to usize and adding with wrap-around.
        buffer[self.idx.wrapping_add_signed(delta)]
    }
}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder calls set the fields, and the buffer dimensions include the
    /// halo on both sides of the tile.
    #[test]
    fn test_stencil_config() {
        let heat = StencilConfig::new("heat")
            .with_tile_size(16, 16)
            .with_halo(1);

        assert_eq!(heat.name, "heat");
        assert_eq!((heat.tile_width, heat.tile_height), (16, 16));
        assert_eq!(heat.halo, 1);
        assert_eq!((heat.buffer_width(), heat.buffer_height()), (18, 18));
    }

    /// A 256x256 grid tiled 16x16 needs 16x16 workgroups and a depth of 1.
    #[test]
    fn test_grid_workgroups() {
        let counts = Grid::new_2d(256, 256).workgroups(16, 16);
        assert_eq!(counts, (16, 16, 1));
    }

    /// The annotation carries the tile dimensions and a fixed third component.
    #[test]
    fn test_workgroup_size_annotation() {
        let annotation = StencilConfig::new("test")
            .with_tile_size(8, 8)
            .workgroup_size_annotation();
        assert_eq!(annotation, "@workgroup_size(8, 8, 1)");
    }
}