1#![allow(dead_code)]
2#![allow(clippy::cast_precision_loss)]
/// Largest workgroup extent along any single axis accepted by [`WorkgroupSize::is_valid`].
const MAX_WORKGROUP_DIM: u32 = 1024;

/// Largest invocation count (`x * y * z`) of one workgroup accepted by
/// [`WorkgroupSize::is_valid`].
const MAX_WORKGROUP_TOTAL: u32 = 1024;

/// Subgroup (warp) width assumed by [`WorkgroupSize::is_warp_aligned`].
const WARP_SIZE: u32 = 32;

/// Local dimensions `(x, y, z)` of a compute workgroup.
///
/// Constructors do not validate; call [`WorkgroupSize::is_valid`] to check the limits.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WorkgroupSize {
    /// Invocations along the X axis.
    pub x: u32,
    /// Invocations along the Y axis.
    pub y: u32,
    /// Invocations along the Z axis.
    pub z: u32,
}

impl WorkgroupSize {
    /// Creates a workgroup with explicit `x`, `y` and `z` extents.
    #[must_use]
    pub fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Creates a one-dimensional workgroup of `size` invocations along X.
    #[must_use]
    pub fn linear(size: u32) -> Self {
        Self::new(size, 1, 1)
    }

    /// Creates a two-dimensional workgroup (`z == 1`).
    #[must_use]
    pub fn flat(x: u32, y: u32) -> Self {
        Self::new(x, y, 1)
    }

    /// Total invocations `x * y * z`, saturating at `u32::MAX`.
    ///
    /// The fields are public and unchecked, so the exact product can exceed `u32`.
    /// Saturating avoids a debug-build panic and a silent release-build wrap. Any
    /// saturated result is far above [`MAX_WORKGROUP_TOTAL`], so it still reads as invalid.
    #[must_use]
    pub fn total(&self) -> u32 {
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Returns `true` when every axis is in `1..=MAX_WORKGROUP_DIM` and the total
    /// invocation count does not exceed [`MAX_WORKGROUP_TOTAL`].
    #[must_use]
    pub fn is_valid(&self) -> bool {
        let axes = [self.x, self.y, self.z];
        axes.iter()
            .all(|&extent| (1..=MAX_WORKGROUP_DIM).contains(&extent))
            && self.total() <= MAX_WORKGROUP_TOTAL
    }

    /// Returns `true` when the exact invocation count is a multiple of `granularity`
    /// (for example a device subgroup size). A `granularity` of 0 is treated as 1.
    #[must_use]
    pub fn is_aligned_to(&self, granularity: u32) -> bool {
        // u128 holds the product of three u32 values exactly, so no overflow is possible.
        let exact = u128::from(self.x) * u128::from(self.y) * u128::from(self.z);
        exact % u128::from(granularity.max(1)) == 0
    }

    /// Returns `true` when the invocation count is a multiple of [`WARP_SIZE`].
    #[must_use]
    pub fn is_warp_aligned(&self) -> bool {
        self.is_aligned_to(WARP_SIZE)
    }
}

impl Default for WorkgroupSize {
    /// An 8×8×1 workgroup (64 invocations).
    fn default() -> Self {
        Self::flat(8, 8)
    }
}
82
83#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
85pub struct DispatchDimensions {
86 pub groups_x: u32,
88 pub groups_y: u32,
90 pub groups_z: u32,
92}
93
94impl DispatchDimensions {
95 #[must_use]
97 pub fn new(groups_x: u32, groups_y: u32, groups_z: u32) -> Self {
98 Self {
99 groups_x,
100 groups_y,
101 groups_z,
102 }
103 }
104
105 #[must_use]
107 pub fn linear(groups: u32) -> Self {
108 Self {
109 groups_x: groups,
110 groups_y: 1,
111 groups_z: 1,
112 }
113 }
114
115 #[must_use]
117 pub fn total_groups(&self) -> u64 {
118 u64::from(self.groups_x) * u64::from(self.groups_y) * u64::from(self.groups_z)
119 }
120
121 #[must_use]
123 pub fn total_invocations(&self, workgroup: &WorkgroupSize) -> u64 {
124 self.total_groups() * u64::from(workgroup.total())
125 }
126}
127
/// Shape preference passed to `WorkgroupPlanner::plan_1d` and `WorkgroupPlanner::plan_2d`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkgroupStrategy {
    /// `plan_2d` uses a 16×16 workgroup; `plan_1d` uses 128 invocations.
    Square,
    /// `plan_2d` uses a 32×8 workgroup; `plan_1d` uses 128 invocations.
    Wide,
    /// `plan_2d` uses an 8×32 workgroup; `plan_1d` uses 128 invocations.
    Tall,
    /// `plan_2d` uses a 16×16 workgroup; `plan_1d` uses 256 invocations.
    WarpAligned,
    /// `plan_2d` uses an 8×8 workgroup; `plan_1d` uses 64 invocations.
    Minimal,
}
142
143pub struct WorkgroupPlanner;
145
146impl WorkgroupPlanner {
147 #[must_use]
151 pub fn plan_1d(
152 total_elements: u32,
153 strategy: WorkgroupStrategy,
154 ) -> (WorkgroupSize, DispatchDimensions) {
155 let wg_size = match strategy {
156 WorkgroupStrategy::WarpAligned => 256,
157 WorkgroupStrategy::Minimal => 64,
158 _ => 128,
159 };
160 let wg = WorkgroupSize::linear(wg_size);
161 let groups = div_ceil(total_elements, wg_size);
162 (wg, DispatchDimensions::linear(groups))
163 }
164
165 #[must_use]
169 pub fn plan_2d(
170 width: u32,
171 height: u32,
172 strategy: WorkgroupStrategy,
173 ) -> (WorkgroupSize, DispatchDimensions) {
174 let (wg_x, wg_y) = match strategy {
175 WorkgroupStrategy::Square => (16, 16),
176 WorkgroupStrategy::Wide => (32, 8),
177 WorkgroupStrategy::Tall => (8, 32),
178 WorkgroupStrategy::WarpAligned => (16, 16),
179 WorkgroupStrategy::Minimal => (8, 8),
180 };
181 let wg = WorkgroupSize::flat(wg_x, wg_y);
182 let groups_x = div_ceil(width, wg_x);
183 let groups_y = div_ceil(height, wg_y);
184 (wg, DispatchDimensions::new(groups_x, groups_y, 1))
185 }
186
187 #[must_use]
191 pub fn plan_3d(width: u32, height: u32, depth: u32) -> (WorkgroupSize, DispatchDimensions) {
192 let wg = WorkgroupSize::new(8, 8, 4);
193 let groups_x = div_ceil(width, 8);
194 let groups_y = div_ceil(height, 8);
195 let groups_z = div_ceil(depth, 4);
196 (wg, DispatchDimensions::new(groups_x, groups_y, groups_z))
197 }
198
199 #[allow(clippy::cast_precision_loss)]
201 #[allow(clippy::manual_checked_ops)]
202 #[must_use]
203 pub fn efficiency(
204 problem_size: (u32, u32),
205 workgroup: &WorkgroupSize,
206 dispatch: &DispatchDimensions,
207 ) -> f64 {
208 let useful = u64::from(problem_size.0) * u64::from(problem_size.1);
209 let total = dispatch.total_invocations(workgroup);
210 if total == 0 {
211 return 0.0;
212 }
213 useful as f64 / total as f64
214 }
215}
216
/// Ceiling division: the smallest `q` such that `q * b >= a`.
///
/// # Panics
///
/// Panics if `b` is zero.
fn div_ceil(a: u32, b: u32) -> u32 {
    let whole = a / b;
    // The +1 cannot overflow: a non-zero remainder implies b >= 2, so whole <= u32::MAX / 2.
    if a % b == 0 {
        whole
    } else {
        whole + 1
    }
}
221
222#[derive(Debug, Clone, Copy)]
228pub struct DeviceLimits {
229 pub max_workgroup_size_per_dim: u32,
231 pub max_workgroup_total_invocations: u32,
233 pub max_shared_memory_bytes: u32,
235 pub subgroup_size: u32,
237 pub max_dispatch_per_dim: u32,
239}
240
241impl Default for DeviceLimits {
242 fn default() -> Self {
243 Self {
244 max_workgroup_size_per_dim: MAX_WORKGROUP_DIM,
245 max_workgroup_total_invocations: MAX_WORKGROUP_TOTAL,
246 max_shared_memory_bytes: 49152, subgroup_size: WARP_SIZE,
248 max_dispatch_per_dim: 65535,
249 }
250 }
251}
252
253impl DeviceLimits {
254 #[must_use]
256 pub fn from_wgpu(limits: &wgpu::Limits) -> Self {
257 Self {
258 max_workgroup_size_per_dim: limits
259 .max_compute_workgroup_size_x
260 .min(limits.max_compute_workgroup_size_y)
261 .min(limits.max_compute_workgroup_size_z),
262 max_workgroup_total_invocations: limits.max_compute_invocations_per_workgroup,
263 max_shared_memory_bytes: limits.max_compute_workgroup_storage_size,
264 subgroup_size: WARP_SIZE, max_dispatch_per_dim: limits.max_compute_workgroups_per_dimension,
266 }
267 }
268}
269
270pub struct WorkgroupAutoTuner {
272 limits: DeviceLimits,
273}
274
275impl WorkgroupAutoTuner {
276 #[must_use]
278 pub fn new(limits: DeviceLimits) -> Self {
279 Self { limits }
280 }
281
282 #[must_use]
284 pub fn with_defaults() -> Self {
285 Self::new(DeviceLimits::default())
286 }
287
288 #[must_use]
290 pub fn limits(&self) -> &DeviceLimits {
291 &self.limits
292 }
293
294 #[must_use]
298 pub fn tune_1d(&self, total_elements: u32) -> (WorkgroupSize, DispatchDimensions) {
299 let subgroup = self.limits.subgroup_size.max(1);
300 let max_total = self.limits.max_workgroup_total_invocations;
301 let max_dim = self.limits.max_workgroup_size_per_dim;
302
303 let mut size = 256u32.min(max_total).min(max_dim);
305 if let Some(aligned) = size.checked_div(subgroup) {
307 size = aligned * subgroup;
308 }
309 size = size.max(subgroup).max(1);
310
311 if total_elements < size * 4 {
313 let smaller = (total_elements.div_ceil(subgroup.max(1))) * subgroup.max(1);
314 size = smaller.max(subgroup.max(1)).min(size);
315 }
316
317 let wg = WorkgroupSize::linear(size);
318 let groups = div_ceil(total_elements, size).min(self.limits.max_dispatch_per_dim);
319 (wg, DispatchDimensions::linear(groups))
320 }
321
322 #[must_use]
326 #[allow(clippy::manual_checked_ops)]
327 pub fn tune_2d(&self, width: u32, height: u32) -> (WorkgroupSize, DispatchDimensions) {
328 let max_total = self.limits.max_workgroup_total_invocations;
329 let max_dim = self.limits.max_workgroup_size_per_dim;
330 let subgroup = self.limits.subgroup_size.max(1);
331
332 let candidates: [(u32, u32); 6] = [
334 (16, 16), (32, 8), (8, 32), (16, 8), (8, 8), (32, 16), ];
341
342 let mut best_wg = WorkgroupSize::flat(8, 8);
343 let mut best_efficiency = 0.0_f64;
344
345 for &(wx, wy) in &candidates {
346 if wx > max_dim || wy > max_dim || wx * wy > max_total {
347 continue;
348 }
349 let total = wx * wy;
351 if total % subgroup != 0 {
352 continue;
353 }
354
355 let gx = div_ceil(width, wx).min(self.limits.max_dispatch_per_dim);
356 let gy = div_ceil(height, wy).min(self.limits.max_dispatch_per_dim);
357 let total_invocations = (gx as u64) * (gy as u64) * (total as u64);
358 let useful = (width as u64) * (height as u64);
359 let eff = if total_invocations > 0 {
360 useful as f64 / total_invocations as f64
361 } else {
362 0.0
363 };
364
365 if eff > best_efficiency {
366 best_efficiency = eff;
367 best_wg = WorkgroupSize::flat(wx, wy);
368 }
369 }
370
371 let gx = div_ceil(width, best_wg.x).min(self.limits.max_dispatch_per_dim);
372 let gy = div_ceil(height, best_wg.y).min(self.limits.max_dispatch_per_dim);
373 (best_wg, DispatchDimensions::new(gx, gy, 1))
374 }
375
376 #[must_use]
381 pub fn tune_2d_with_shared_memory(
382 &self,
383 width: u32,
384 height: u32,
385 shared_bytes_per_pixel: u32,
386 ) -> (WorkgroupSize, DispatchDimensions) {
387 let max_shared = self.limits.max_shared_memory_bytes;
388 let max_total = self.limits.max_workgroup_total_invocations;
389 let max_dim = self.limits.max_workgroup_size_per_dim;
390 let subgroup = self.limits.subgroup_size.max(1);
391
392 let mut best_side = 8u32;
394 for candidate_side in &[32u32, 24, 16, 12, 8] {
395 let side = *candidate_side;
396 let total = side * side;
397 if total > max_total || side > max_dim {
398 continue;
399 }
400 if total % subgroup != 0 {
401 continue;
402 }
403 let shared_needed = total * shared_bytes_per_pixel;
404 if shared_needed <= max_shared {
405 best_side = side;
406 break;
407 }
408 }
409
410 let wg = WorkgroupSize::flat(best_side, best_side);
411 let gx = div_ceil(width, best_side).min(self.limits.max_dispatch_per_dim);
412 let gy = div_ceil(height, best_side).min(self.limits.max_dispatch_per_dim);
413 (wg, DispatchDimensions::new(gx, gy, 1))
414 }
415
416 #[must_use]
418 pub fn estimate_efficiency(
419 &self,
420 problem_width: u32,
421 problem_height: u32,
422 workgroup: &WorkgroupSize,
423 ) -> f64 {
424 let gx = div_ceil(problem_width, workgroup.x);
425 let gy = div_ceil(problem_height, workgroup.y);
426 let dispatch = DispatchDimensions::new(gx, gy, 1);
427 WorkgroupPlanner::efficiency((problem_width, problem_height), workgroup, &dispatch)
428 }
429}
430
/// Byte layout of an array of equally sized elements placed in workgroup-shared memory.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SharedMemoryLayout {
    /// Total bytes: `element_count` times `element_size` rounded up to `alignment`,
    /// saturating at `u32::MAX`.
    pub size_bytes: u32,
    /// Alignment each element is padded to (0 means no padding).
    pub alignment: u32,
    /// Number of elements.
    pub element_count: u32,
    /// Unpadded size of one element in bytes.
    pub element_size: u32,
}

impl SharedMemoryLayout {
    /// Budget used by [`Self::fits_in_shared_memory`] (48 KiB).
    const DEFAULT_LIMIT_BYTES: u32 = 49152;

    /// Creates a layout of `element_count` elements of `element_size` bytes, each padded
    /// up to a multiple of `alignment`.
    ///
    /// A total that would overflow `u32` saturates to `u32::MAX`, so it never fits any
    /// budget instead of panicking or wrapping.
    #[must_use]
    pub fn new(element_count: u32, element_size: u32, alignment: u32) -> Self {
        let stride = round_up(element_size, alignment);
        Self {
            size_bytes: element_count.saturating_mul(stride),
            alignment,
            element_count,
            element_size,
        }
    }

    /// Layout for `count` 4-byte, 4-aligned floats.
    #[must_use]
    pub fn floats(count: u32) -> Self {
        Self::new(count, 4, 4)
    }

    /// Layout for `count` 16-byte, 16-aligned four-component vectors.
    #[must_use]
    pub fn vec4s(count: u32) -> Self {
        Self::new(count, 16, 16)
    }

    /// Whether the layout fits in the default 49152-byte (48 KiB) budget.
    #[must_use]
    pub fn fits_in_shared_memory(&self) -> bool {
        self.fits_within(Self::DEFAULT_LIMIT_BYTES)
    }

    /// Whether the layout fits in a budget of `max_bytes` bytes.
    #[must_use]
    pub fn fits_within(&self, max_bytes: u32) -> bool {
        self.size_bytes <= max_bytes
    }
}

/// Rounds `value` up to the next multiple of `alignment`.
///
/// An `alignment` of 0 returns `value` unchanged. If the rounded value does not fit in
/// `u32`, the result saturates to `u32::MAX`.
fn round_up(value: u32, alignment: u32) -> u32 {
    if alignment == 0 {
        return value;
    }
    value.div_ceil(alignment).saturating_mul(alignment)
}
483
// Unit tests for the workgroup types, the fixed-strategy planner, shared-memory layouts,
// the integer helpers and the limit-aware auto-tuner.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- WorkgroupSize -------------------------------------------------------------

    #[test]
    fn test_workgroup_size_default() {
        let wg = WorkgroupSize::default();
        assert_eq!(wg.x, 8);
        assert_eq!(wg.y, 8);
        assert_eq!(wg.z, 1);
        assert_eq!(wg.total(), 64);
    }

    #[test]
    fn test_workgroup_size_linear() {
        let wg = WorkgroupSize::linear(256);
        assert_eq!(wg.total(), 256);
        assert!(wg.is_valid());
        assert!(wg.is_warp_aligned());
    }

    #[test]
    fn test_workgroup_size_flat() {
        let wg = WorkgroupSize::flat(16, 16);
        assert_eq!(wg.total(), 256);
        assert!(wg.is_valid());
    }

    #[test]
    fn test_workgroup_size_3d() {
        let wg = WorkgroupSize::new(8, 8, 4);
        assert_eq!(wg.total(), 256);
        assert!(wg.is_valid());
    }

    #[test]
    fn test_workgroup_size_invalid_exceeds_max() {
        // 1025 is one past the 1024 per-axis limit.
        let wg = WorkgroupSize::new(1025, 1, 1);
        assert!(!wg.is_valid());
    }

    #[test]
    fn test_workgroup_size_invalid_exceeds_total() {
        // Each axis is within the per-axis limit, but the product (2048) exceeds the
        // 1024 total-invocation limit.
        let wg = WorkgroupSize::new(32, 64, 1);
        assert_eq!(wg.total(), 2048);
        assert!(!wg.is_valid());
    }

    // ---- DispatchDimensions --------------------------------------------------------

    #[test]
    fn test_dispatch_dimensions_linear() {
        let d = DispatchDimensions::linear(10);
        assert_eq!(d.total_groups(), 10);
    }

    #[test]
    fn test_dispatch_total_invocations() {
        // 4 * 4 groups of 256 invocations each.
        let wg = WorkgroupSize::flat(16, 16);
        let d = DispatchDimensions::new(4, 4, 1);
        assert_eq!(d.total_invocations(&wg), 4096);
    }

    // ---- WorkgroupPlanner ----------------------------------------------------------

    #[test]
    fn test_plan_1d() {
        let (wg, d) = WorkgroupPlanner::plan_1d(1000, WorkgroupStrategy::WarpAligned);
        assert_eq!(wg.x, 256);
        // The dispatch must cover every element.
        assert!(d.groups_x * wg.x >= 1000);
    }

    #[test]
    fn test_plan_2d_square() {
        let (wg, d) = WorkgroupPlanner::plan_2d(1920, 1080, WorkgroupStrategy::Square);
        assert_eq!(wg.x, 16);
        assert_eq!(wg.y, 16);
        assert!(d.groups_x * wg.x >= 1920);
        assert!(d.groups_y * wg.y >= 1080);
    }

    #[test]
    fn test_plan_2d_wide() {
        let (wg, d) = WorkgroupPlanner::plan_2d(3840, 2160, WorkgroupStrategy::Wide);
        assert_eq!(wg.x, 32);
        assert_eq!(wg.y, 8);
        assert!(d.groups_x * wg.x >= 3840);
        assert!(d.groups_y * wg.y >= 2160);
    }

    #[test]
    fn test_plan_3d() {
        // 64/8, 64/8 and 16/4 divide exactly.
        let (wg, d) = WorkgroupPlanner::plan_3d(64, 64, 16);
        assert_eq!(wg.total(), 256);
        assert_eq!(d.groups_x, 8);
        assert_eq!(d.groups_y, 8);
        assert_eq!(d.groups_z, 4);
    }

    #[test]
    fn test_efficiency_perfect() {
        // 2×2 groups of 16×16 cover 32×32 exactly, so no invocation is wasted.
        let wg = WorkgroupSize::flat(16, 16);
        let d = DispatchDimensions::new(2, 2, 1);
        let eff = WorkgroupPlanner::efficiency((32, 32), &wg, &d);
        assert!((eff - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_efficiency_partial() {
        // 100 useful invocations out of 256 launched.
        let wg = WorkgroupSize::flat(16, 16);
        let d = DispatchDimensions::new(1, 1, 1);
        let eff = WorkgroupPlanner::efficiency((10, 10), &wg, &d);
        assert!(eff < 1.0);
        assert!(eff > 0.0);
    }

    // ---- SharedMemoryLayout --------------------------------------------------------

    #[test]
    fn test_shared_memory_floats() {
        let layout = SharedMemoryLayout::floats(256);
        assert_eq!(layout.size_bytes, 1024);
        assert!(layout.fits_in_shared_memory());
    }

    #[test]
    fn test_shared_memory_vec4s() {
        let layout = SharedMemoryLayout::vec4s(64);
        assert_eq!(layout.size_bytes, 1024);
        assert!(layout.fits_in_shared_memory());
    }

    #[test]
    fn test_shared_memory_exceeds_limit() {
        // 200000 bytes is well above the 49152-byte budget.
        let layout = SharedMemoryLayout::new(50000, 4, 4);
        assert!(!layout.fits_in_shared_memory());
    }

    // ---- Helpers -------------------------------------------------------------------

    #[test]
    fn test_div_ceil() {
        assert_eq!(div_ceil(10, 3), 4);
        assert_eq!(div_ceil(9, 3), 3);
        assert_eq!(div_ceil(1, 256), 1);
    }

    #[test]
    fn test_round_up() {
        assert_eq!(round_up(5, 4), 8);
        assert_eq!(round_up(8, 4), 8);
        assert_eq!(round_up(0, 4), 0);
        // Alignment 0 leaves the value unchanged.
        assert_eq!(round_up(7, 0), 7);
    }

    #[test]
    fn test_warp_alignment() {
        let wg = WorkgroupSize::linear(64);
        assert!(wg.is_warp_aligned());
        let wg2 = WorkgroupSize::linear(33);
        assert!(!wg2.is_warp_aligned());
    }

    // ---- WorkgroupAutoTuner --------------------------------------------------------

    #[test]
    fn test_auto_tuner_1d_default_limits() {
        let tuner = WorkgroupAutoTuner::with_defaults();
        let (wg, dispatch) = tuner.tune_1d(10000);
        assert!(wg.is_valid(), "workgroup must be valid");
        assert!(wg.is_warp_aligned(), "should be warp-aligned");
        assert!(dispatch.groups_x * wg.x >= 10000, "must cover all elements");
    }

    #[test]
    fn test_auto_tuner_1d_small_problem() {
        // Exercises the small-problem shrinking path.
        let tuner = WorkgroupAutoTuner::with_defaults();
        let (wg, dispatch) = tuner.tune_1d(64);
        assert!(wg.is_valid());
        assert!(dispatch.groups_x * wg.x >= 64);
    }

    #[test]
    fn test_auto_tuner_2d_1080p() {
        let tuner = WorkgroupAutoTuner::with_defaults();
        let (wg, dispatch) = tuner.tune_2d(1920, 1080);
        assert!(wg.is_valid());
        assert!(wg.is_warp_aligned());
        assert!(dispatch.groups_x * wg.x >= 1920);
        assert!(dispatch.groups_y * wg.y >= 1080);
    }

    #[test]
    fn test_auto_tuner_2d_4k() {
        let tuner = WorkgroupAutoTuner::with_defaults();
        let (wg, dispatch) = tuner.tune_2d(3840, 2160);
        assert!(wg.is_valid());
        assert!(dispatch.groups_x * wg.x >= 3840);
        assert!(dispatch.groups_y * wg.y >= 2160);
    }

    #[test]
    fn test_auto_tuner_2d_small_image() {
        let tuner = WorkgroupAutoTuner::with_defaults();
        let (wg, dispatch) = tuner.tune_2d(16, 16);
        assert!(wg.is_valid());
        assert!(dispatch.groups_x * wg.x >= 16);
        assert!(dispatch.groups_y * wg.y >= 16);
    }

    #[test]
    fn test_auto_tuner_2d_non_square() {
        // Very wide, short problem.
        let tuner = WorkgroupAutoTuner::with_defaults();
        let (wg, dispatch) = tuner.tune_2d(4096, 32);
        assert!(wg.is_valid());
        assert!(dispatch.groups_x * wg.x >= 4096);
        assert!(dispatch.groups_y * wg.y >= 32);
    }

    #[test]
    fn test_auto_tuner_with_shared_memory() {
        let tuner = WorkgroupAutoTuner::with_defaults();
        let (wg, dispatch) = tuner.tune_2d_with_shared_memory(1920, 1080, 64);
        // Footprint is one 64-byte slot per invocation.
        let shared_used = wg.total() * 64;
        assert!(
            shared_used <= tuner.limits().max_shared_memory_bytes,
            "shared memory {} must fit in {} bytes",
            shared_used,
            tuner.limits().max_shared_memory_bytes
        );
        assert!(dispatch.groups_x * wg.x >= 1920);
        assert!(dispatch.groups_y * wg.y >= 1080);
    }

    #[test]
    fn test_auto_tuner_with_large_shared_memory() {
        // 512 bytes per invocation only fits small tiles in the default 48 KiB budget.
        let tuner = WorkgroupAutoTuner::with_defaults();
        let (wg, dispatch) = tuner.tune_2d_with_shared_memory(256, 256, 512);
        let shared_used = wg.total() * 512;
        assert!(shared_used <= tuner.limits().max_shared_memory_bytes);
        assert!(dispatch.groups_x * wg.x >= 256);
    }

    #[test]
    fn test_auto_tuner_respects_constrained_limits() {
        // A device far smaller than the defaults; every candidate above 128 invocations
        // must be rejected.
        let limits = DeviceLimits {
            max_workgroup_size_per_dim: 128,
            max_workgroup_total_invocations: 128,
            max_shared_memory_bytes: 16384,
            subgroup_size: 16,
            max_dispatch_per_dim: 32768,
        };
        let tuner = WorkgroupAutoTuner::new(limits);
        let (wg, _) = tuner.tune_2d(1920, 1080);
        assert!(wg.x <= 128);
        assert!(wg.y <= 128);
        assert!(wg.total() <= 128);
    }

    #[test]
    fn test_auto_tuner_efficiency_estimate() {
        let tuner = WorkgroupAutoTuner::with_defaults();
        let wg = WorkgroupSize::flat(16, 16);
        let eff = tuner.estimate_efficiency(32, 32, &wg);
        assert!(
            (eff - 1.0).abs() < 1e-9,
            "perfect fit should have efficiency 1.0"
        );

        // 17×17 needs 2×2 groups of 16×16, so most launched invocations are wasted.
        let eff2 = tuner.estimate_efficiency(17, 17, &wg);
        assert!(
            eff2 < 1.0,
            "non-aligned problem should have < 1.0 efficiency"
        );
        assert!(eff2 > 0.0);
    }

    #[test]
    fn test_device_limits_default() {
        let limits = DeviceLimits::default();
        assert_eq!(limits.max_workgroup_size_per_dim, 1024);
        assert_eq!(limits.max_workgroup_total_invocations, 1024);
        assert_eq!(limits.max_shared_memory_bytes, 49152);
        assert_eq!(limits.subgroup_size, 32);
    }
}