1#![allow(dead_code)]
2#![allow(clippy::cast_precision_loss)]
/// Largest extent accepted for any single axis by [`WorkgroupSize::is_valid`].
const MAX_WORKGROUP_DIM: u32 = 1024;

/// Largest number of invocations per workgroup accepted by [`WorkgroupSize::is_valid`].
const MAX_WORKGROUP_TOTAL: u32 = 1024;

/// Invocation-count granularity tested by [`WorkgroupSize::is_warp_aligned`].
const WARP_SIZE: u32 = 32;

/// Extent of a single workgroup along the `x`, `y` and `z` axes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WorkgroupSize {
    /// Number of invocations along the x axis.
    pub x: u32,
    /// Number of invocations along the y axis.
    pub y: u32,
    /// Number of invocations along the z axis.
    pub z: u32,
}

impl WorkgroupSize {
    /// Creates a workgroup size from explicit extents. No validation is performed;
    /// call [`WorkgroupSize::is_valid`] to check the result.
    #[must_use]
    pub fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Creates a one-dimensional workgroup of `size` invocations (`y` and `z` are 1).
    #[must_use]
    pub fn linear(size: u32) -> Self {
        Self::new(size, 1, 1)
    }

    /// Creates a two-dimensional `x` by `y` workgroup (`z` is 1).
    #[must_use]
    pub fn flat(x: u32, y: u32) -> Self {
        Self::new(x, y, 1)
    }

    /// Returns the number of invocations in the workgroup, `x * y * z`.
    ///
    /// The fields are public and unvalidated, so the exact product may not fit in
    /// a `u32`. In that case the result saturates at `u32::MAX` instead of
    /// panicking (debug builds) or wrapping to an unrelated small value (release
    /// builds).
    #[must_use]
    pub fn total(&self) -> u32 {
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Returns `true` when every axis lies in `1..=1024` and the total invocation
    /// count does not exceed 1024.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        let every_axis_in_range = [self.x, self.y, self.z]
            .iter()
            .all(|&extent| (1..=MAX_WORKGROUP_DIM).contains(&extent));
        every_axis_in_range && self.total() <= MAX_WORKGROUP_TOTAL
    }

    /// Returns `true` when the exact invocation count `x * y * z` is a multiple of 32.
    ///
    /// A size with a zero axis has a product of zero and is therefore reported as
    /// aligned.
    #[must_use]
    pub fn is_warp_aligned(&self) -> bool {
        // Reduce each factor first: the product of three residues below 32 always
        // fits in a `u32`, so the answer is exact even when `x * y * z` would not fit.
        let residue = (self.x % WARP_SIZE) * (self.y % WARP_SIZE) * (self.z % WARP_SIZE);
        residue % WARP_SIZE == 0
    }
}

impl Default for WorkgroupSize {
    /// Returns an 8 x 8 x 1 workgroup (64 invocations).
    fn default() -> Self {
        Self::new(8, 8, 1)
    }
}
82
83#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
85pub struct DispatchDimensions {
86 pub groups_x: u32,
88 pub groups_y: u32,
90 pub groups_z: u32,
92}
93
94impl DispatchDimensions {
95 #[must_use]
97 pub fn new(groups_x: u32, groups_y: u32, groups_z: u32) -> Self {
98 Self {
99 groups_x,
100 groups_y,
101 groups_z,
102 }
103 }
104
105 #[must_use]
107 pub fn linear(groups: u32) -> Self {
108 Self {
109 groups_x: groups,
110 groups_y: 1,
111 groups_z: 1,
112 }
113 }
114
115 #[must_use]
117 pub fn total_groups(&self) -> u64 {
118 u64::from(self.groups_x) * u64::from(self.groups_y) * u64::from(self.groups_z)
119 }
120
121 #[must_use]
123 pub fn total_invocations(&self, workgroup: &WorkgroupSize) -> u64 {
124 self.total_groups() * u64::from(workgroup.total())
125 }
126}
127
/// Shape preference consulted by `WorkgroupPlanner` when it picks a workgroup size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkgroupStrategy {
    /// `plan_2d` uses a `(16, 16)` workgroup; `plan_1d` uses 128 invocations.
    Square,
    /// `plan_2d` uses a `(32, 8)` workgroup; `plan_1d` uses 128 invocations.
    Wide,
    /// `plan_2d` uses an `(8, 32)` workgroup; `plan_1d` uses 128 invocations.
    Tall,
    /// `plan_2d` uses a `(16, 16)` workgroup; `plan_1d` uses 256 invocations.
    WarpAligned,
    /// `plan_2d` uses an `(8, 8)` workgroup; `plan_1d` uses 64 invocations.
    Minimal,
}
142
143pub struct WorkgroupPlanner;
145
146impl WorkgroupPlanner {
147 #[must_use]
151 pub fn plan_1d(
152 total_elements: u32,
153 strategy: WorkgroupStrategy,
154 ) -> (WorkgroupSize, DispatchDimensions) {
155 let wg_size = match strategy {
156 WorkgroupStrategy::WarpAligned => 256,
157 WorkgroupStrategy::Minimal => 64,
158 _ => 128,
159 };
160 let wg = WorkgroupSize::linear(wg_size);
161 let groups = div_ceil(total_elements, wg_size);
162 (wg, DispatchDimensions::linear(groups))
163 }
164
165 #[must_use]
169 pub fn plan_2d(
170 width: u32,
171 height: u32,
172 strategy: WorkgroupStrategy,
173 ) -> (WorkgroupSize, DispatchDimensions) {
174 let (wg_x, wg_y) = match strategy {
175 WorkgroupStrategy::Square => (16, 16),
176 WorkgroupStrategy::Wide => (32, 8),
177 WorkgroupStrategy::Tall => (8, 32),
178 WorkgroupStrategy::WarpAligned => (16, 16),
179 WorkgroupStrategy::Minimal => (8, 8),
180 };
181 let wg = WorkgroupSize::flat(wg_x, wg_y);
182 let groups_x = div_ceil(width, wg_x);
183 let groups_y = div_ceil(height, wg_y);
184 (wg, DispatchDimensions::new(groups_x, groups_y, 1))
185 }
186
187 #[must_use]
191 pub fn plan_3d(width: u32, height: u32, depth: u32) -> (WorkgroupSize, DispatchDimensions) {
192 let wg = WorkgroupSize::new(8, 8, 4);
193 let groups_x = div_ceil(width, 8);
194 let groups_y = div_ceil(height, 8);
195 let groups_z = div_ceil(depth, 4);
196 (wg, DispatchDimensions::new(groups_x, groups_y, groups_z))
197 }
198
199 #[allow(clippy::cast_precision_loss)]
201 #[must_use]
202 pub fn efficiency(
203 problem_size: (u32, u32),
204 workgroup: &WorkgroupSize,
205 dispatch: &DispatchDimensions,
206 ) -> f64 {
207 let useful = u64::from(problem_size.0) * u64::from(problem_size.1);
208 let total = dispatch.total_invocations(workgroup);
209 if total == 0 {
210 return 0.0;
211 }
212 useful as f64 / total as f64
213 }
214}
215
/// Divides `a` by `b` and rounds the quotient up to the next whole number.
///
/// Panics when `b` is zero, exactly like integer division.
fn div_ceil(a: u32, b: u32) -> u32 {
    u32::div_ceil(a, b)
}
220
/// Byte-size description of an array of equally sized elements, each padded to
/// a common alignment.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SharedMemoryLayout {
    /// Total footprint: `element_count` times `element_size` rounded up to `alignment`.
    pub size_bytes: u32,
    /// Alignment each element is padded to, as passed to [`SharedMemoryLayout::new`].
    pub alignment: u32,
    /// Number of elements in the array.
    pub element_count: u32,
    /// Unpadded size of one element in bytes.
    pub element_size: u32,
}

impl SharedMemoryLayout {
    /// Builds a layout for `element_count` elements of `element_size` bytes, each
    /// rounded up to a multiple of `alignment` (an `alignment` of 0 means no padding).
    ///
    /// If the total does not fit in a `u32`, `size_bytes` saturates at
    /// `u32::MAX`. A wrapped value could otherwise look small enough to pass
    /// [`SharedMemoryLayout::fits_in_shared_memory`].
    #[must_use]
    pub fn new(element_count: u32, element_size: u32, alignment: u32) -> Self {
        let stride = round_up(element_size, alignment);
        Self {
            size_bytes: element_count.saturating_mul(stride),
            alignment,
            element_count,
            element_size,
        }
    }

    /// Layout for `count` 4-byte elements with 4-byte alignment.
    #[must_use]
    pub fn floats(count: u32) -> Self {
        Self::new(count, 4, 4)
    }

    /// Layout for `count` 16-byte elements with 16-byte alignment.
    #[must_use]
    pub fn vec4s(count: u32) -> Self {
        Self::new(count, 16, 16)
    }

    /// Returns `true` when `size_bytes` is at most 49 152 bytes (48 KiB), the
    /// budget this module checks against.
    #[must_use]
    pub fn fits_in_shared_memory(&self) -> bool {
        const SHARED_MEMORY_BUDGET_BYTES: u32 = 48 * 1024;
        self.size_bytes <= SHARED_MEMORY_BUDGET_BYTES
    }
}

/// Rounds `value` up to the nearest multiple of `alignment`.
///
/// An `alignment` of 0 returns `value` unchanged. If the rounded value would
/// exceed `u32::MAX`, the result saturates at `u32::MAX`.
fn round_up(value: u32, alignment: u32) -> u32 {
    if alignment == 0 {
        return value;
    }
    value.div_ceil(alignment).saturating_mul(alignment)
}
273
// Unit tests for the workgroup sizing, dispatch planning and shared-memory
// layout helpers defined in this file.
#[cfg(test)]
mod tests {
    use super::*;

    // The default workgroup is 8 x 8 x 1, i.e. 64 invocations.
    #[test]
    fn test_workgroup_size_default() {
        let wg = WorkgroupSize::default();
        assert_eq!(wg.x, 8);
        assert_eq!(wg.y, 8);
        assert_eq!(wg.z, 1);
        assert_eq!(wg.total(), 64);
    }

    // A linear 256-wide workgroup is valid and a multiple of the warp size.
    #[test]
    fn test_workgroup_size_linear() {
        let wg = WorkgroupSize::linear(256);
        assert_eq!(wg.total(), 256);
        assert!(wg.is_valid());
        assert!(wg.is_warp_aligned());
    }

    // A flat 16 x 16 workgroup totals 256 invocations and is valid.
    #[test]
    fn test_workgroup_size_flat() {
        let wg = WorkgroupSize::flat(16, 16);
        assert_eq!(wg.total(), 256);
        assert!(wg.is_valid());
    }

    // An 8 x 8 x 4 workgroup totals 256 invocations and is valid.
    #[test]
    fn test_workgroup_size_3d() {
        let wg = WorkgroupSize::new(8, 8, 4);
        assert_eq!(wg.total(), 256);
        assert!(wg.is_valid());
    }

    // A single axis above MAX_WORKGROUP_DIM makes the size invalid.
    #[test]
    fn test_workgroup_size_invalid_exceeds_max() {
        let wg = WorkgroupSize::new(1025, 1, 1);
        assert!(!wg.is_valid());
    }

    // Each axis is in range here, but the product exceeds MAX_WORKGROUP_TOTAL.
    #[test]
    fn test_workgroup_size_invalid_exceeds_total() {
        let wg = WorkgroupSize::new(32, 64, 1);
        assert_eq!(wg.total(), 2048);
        assert!(!wg.is_valid());
    }

    // A linear dispatch reports exactly the requested number of groups.
    #[test]
    fn test_dispatch_dimensions_linear() {
        let d = DispatchDimensions::linear(10);
        assert_eq!(d.total_groups(), 10);
    }

    // 16 groups of 256 invocations each give 4096 invocations.
    #[test]
    fn test_dispatch_total_invocations() {
        let wg = WorkgroupSize::flat(16, 16);
        let d = DispatchDimensions::new(4, 4, 1);
        assert_eq!(d.total_invocations(&wg), 4096);
    }

    // The 1D plan must launch at least as many invocations as there are elements.
    #[test]
    fn test_plan_1d() {
        let (wg, d) = WorkgroupPlanner::plan_1d(1000, WorkgroupStrategy::WarpAligned);
        assert_eq!(wg.x, 256);
        assert!(d.groups_x * wg.x >= 1000);
    }

    // The Square strategy uses 16 x 16 tiles that cover the whole 1920 x 1080 grid.
    #[test]
    fn test_plan_2d_square() {
        let (wg, d) = WorkgroupPlanner::plan_2d(1920, 1080, WorkgroupStrategy::Square);
        assert_eq!(wg.x, 16);
        assert_eq!(wg.y, 16);
        assert!(d.groups_x * wg.x >= 1920);
        assert!(d.groups_y * wg.y >= 1080);
    }

    // The Wide strategy uses 32 x 8 tiles that cover the whole 3840 x 2160 grid.
    #[test]
    fn test_plan_2d_wide() {
        let (wg, d) = WorkgroupPlanner::plan_2d(3840, 2160, WorkgroupStrategy::Wide);
        assert_eq!(wg.x, 32);
        assert_eq!(wg.y, 8);
        assert!(d.groups_x * wg.x >= 3840);
        assert!(d.groups_y * wg.y >= 2160);
    }

    // The 3D plan uses 8 x 8 x 4 tiles, so 64 x 64 x 16 needs 8 x 8 x 4 groups.
    #[test]
    fn test_plan_3d() {
        let (wg, d) = WorkgroupPlanner::plan_3d(64, 64, 16);
        assert_eq!(wg.total(), 256);
        assert_eq!(d.groups_x, 8);
        assert_eq!(d.groups_y, 8);
        assert_eq!(d.groups_z, 4);
    }

    // When the launch matches the problem size exactly, efficiency is 1.0.
    #[test]
    fn test_efficiency_perfect() {
        let wg = WorkgroupSize::flat(16, 16);
        let d = DispatchDimensions::new(2, 2, 1);
        let eff = WorkgroupPlanner::efficiency((32, 32), &wg, &d);
        assert!((eff - 1.0).abs() < 1e-9);
    }

    // A 10 x 10 problem in one 16 x 16 group uses only part of the launch.
    #[test]
    fn test_efficiency_partial() {
        let wg = WorkgroupSize::flat(16, 16);
        let d = DispatchDimensions::new(1, 1, 1);
        let eff = WorkgroupPlanner::efficiency((10, 10), &wg, &d);
        assert!(eff < 1.0);
        assert!(eff > 0.0);
    }

    // 256 four-byte elements occupy 1024 bytes.
    #[test]
    fn test_shared_memory_floats() {
        let layout = SharedMemoryLayout::floats(256);
        assert_eq!(layout.size_bytes, 1024);
        assert!(layout.fits_in_shared_memory());
    }

    // 64 sixteen-byte elements occupy 1024 bytes.
    #[test]
    fn test_shared_memory_vec4s() {
        let layout = SharedMemoryLayout::vec4s(64);
        assert_eq!(layout.size_bytes, 1024);
        assert!(layout.fits_in_shared_memory());
    }

    // 50000 four-byte elements (200000 bytes) exceed the 49152-byte limit.
    #[test]
    fn test_shared_memory_exceeds_limit() {
        let layout = SharedMemoryLayout::new(50000, 4, 4);
        assert!(!layout.fits_in_shared_memory());
    }

    // Ceiling division rounds up only when there is a remainder.
    #[test]
    fn test_div_ceil() {
        assert_eq!(div_ceil(10, 3), 4);
        assert_eq!(div_ceil(9, 3), 3);
        assert_eq!(div_ceil(1, 256), 1);
    }

    // round_up leaves multiples alone and treats an alignment of 0 as "no rounding".
    #[test]
    fn test_round_up() {
        assert_eq!(round_up(5, 4), 8);
        assert_eq!(round_up(8, 4), 8);
        assert_eq!(round_up(0, 4), 0);
        assert_eq!(round_up(7, 0), 7);
    }

    // 64 is a multiple of the warp size; 33 is not.
    #[test]
    fn test_warp_alignment() {
        let wg = WorkgroupSize::linear(64);
        assert!(wg.is_warp_aligned());
        let wg2 = WorkgroupSize::linear(33);
        assert!(!wg2.is_warp_aligned());
    }
}