Skip to main content

oximedia_gpu/
workgroup.rs

1#![allow(dead_code)]
2#![allow(clippy::cast_precision_loss)]
3//! GPU workgroup configuration and dispatch sizing.
4//!
5//! This module provides utilities for computing optimal workgroup sizes
6//! and dispatch dimensions for GPU compute shaders. Proper workgroup sizing
7//! is critical for achieving good GPU utilization.
8
/// Maximum workgroup size per dimension assumed for typical GPUs.
const MAX_WORKGROUP_DIM: u32 = 1024;

/// Maximum total invocations per workgroup assumed for typical GPUs.
const MAX_WORKGROUP_TOTAL: u32 = 1024;

/// Assumed subgroup (warp/wavefront) size used for alignment checks.
///
/// 32 matches NVIDIA warps; some AMD hardware runs 64-wide wavefronts, so
/// this is a heuristic rather than a universal hardware constant.
const WARP_SIZE: u32 = 32;

/// Workgroup size in 3 dimensions (invocations along X, Y and Z).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WorkgroupSize {
    /// Size in X dimension.
    pub x: u32,
    /// Size in Y dimension.
    pub y: u32,
    /// Size in Z dimension.
    pub z: u32,
}

impl WorkgroupSize {
    /// Create a new workgroup size from explicit `x`, `y` and `z` extents.
    #[must_use]
    pub fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Create a 1D workgroup size (`size` x 1 x 1).
    #[must_use]
    pub fn linear(size: u32) -> Self {
        Self {
            x: size,
            y: 1,
            z: 1,
        }
    }

    /// Create a 2D workgroup size (`x` x `y` x 1).
    #[must_use]
    pub fn flat(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Total number of invocations in this workgroup (`x * y * z`).
    ///
    /// The product saturates at `u32::MAX` instead of overflowing, so this
    /// never panics (debug builds) or wraps to a small bogus value (release
    /// builds) for out-of-range extents. A saturated result is still far
    /// above any valid workgroup total.
    #[must_use]
    pub fn total(&self) -> u32 {
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Check if the workgroup size is valid against the typical limits:
    /// every extent is non-zero, no extent exceeds `MAX_WORKGROUP_DIM`, and
    /// the total does not exceed `MAX_WORKGROUP_TOTAL`.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        self.x > 0
            && self.y > 0
            && self.z > 0
            && self.x <= MAX_WORKGROUP_DIM
            && self.y <= MAX_WORKGROUP_DIM
            && self.z <= MAX_WORKGROUP_DIM
            && self.total() <= MAX_WORKGROUP_TOTAL
    }

    /// Check if the total invocation count is a multiple of `WARP_SIZE`.
    ///
    /// A workgroup with a zero extent has a total of 0 and is therefore
    /// reported as aligned; combine with [`Self::is_valid`] to reject it.
    #[must_use]
    pub fn is_warp_aligned(&self) -> bool {
        // Widen to u128 so the exact product (at most 96 bits) is tested,
        // rather than an overflowed or saturated 32-bit value.
        let exact = u128::from(self.x) * u128::from(self.y) * u128::from(self.z);
        exact % u128::from(WARP_SIZE) == 0
    }
}

impl Default for WorkgroupSize {
    /// An 8x8x1 workgroup (64 invocations).
    fn default() -> Self {
        Self { x: 8, y: 8, z: 1 }
    }
}
82
83/// Dispatch dimensions for launching a compute shader.
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
85pub struct DispatchDimensions {
86    /// Number of workgroups in X.
87    pub groups_x: u32,
88    /// Number of workgroups in Y.
89    pub groups_y: u32,
90    /// Number of workgroups in Z.
91    pub groups_z: u32,
92}
93
94impl DispatchDimensions {
95    /// Create new dispatch dimensions.
96    #[must_use]
97    pub fn new(groups_x: u32, groups_y: u32, groups_z: u32) -> Self {
98        Self {
99            groups_x,
100            groups_y,
101            groups_z,
102        }
103    }
104
105    /// Create 1D dispatch dimensions.
106    #[must_use]
107    pub fn linear(groups: u32) -> Self {
108        Self {
109            groups_x: groups,
110            groups_y: 1,
111            groups_z: 1,
112        }
113    }
114
115    /// Total number of workgroups.
116    #[must_use]
117    pub fn total_groups(&self) -> u64 {
118        u64::from(self.groups_x) * u64::from(self.groups_y) * u64::from(self.groups_z)
119    }
120
121    /// Total number of invocations given a workgroup size.
122    #[must_use]
123    pub fn total_invocations(&self, workgroup: &WorkgroupSize) -> u64 {
124        self.total_groups() * u64::from(workgroup.total())
125    }
126}
127
/// Shape heuristic used by [`WorkgroupPlanner`] when choosing a fixed
/// workgroup size.
///
/// [`WorkgroupPlanner::plan_1d`] only distinguishes `WarpAligned` (256),
/// `Minimal` (64) and every other strategy (128). The 2D shapes listed on
/// each variant are the ones [`WorkgroupPlanner::plan_2d`] uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WorkgroupStrategy {
    /// Square workgroups: 16x16 in 2D.
    Square,
    /// Wide workgroups favouring the X axis: 32x8 in 2D.
    Wide,
    /// Tall workgroups favouring the Y axis: 8x32 in 2D.
    Tall,
    /// Warp/wavefront-aligned workgroups: 16x16 in 2D, 256 in 1D.
    WarpAligned,
    /// Small workgroups: 8x8 in 2D, 64 in 1D (not the smallest possible size).
    Minimal,
}
142
143/// Compute optimal workgroup size and dispatch dimensions.
144pub struct WorkgroupPlanner;
145
146impl WorkgroupPlanner {
147    /// Compute 1D dispatch dimensions for a linear problem.
148    ///
149    /// Returns `(workgroup_size, dispatch_dims)`.
150    #[must_use]
151    pub fn plan_1d(
152        total_elements: u32,
153        strategy: WorkgroupStrategy,
154    ) -> (WorkgroupSize, DispatchDimensions) {
155        let wg_size = match strategy {
156            WorkgroupStrategy::WarpAligned => 256,
157            WorkgroupStrategy::Minimal => 64,
158            _ => 128,
159        };
160        let wg = WorkgroupSize::linear(wg_size);
161        let groups = div_ceil(total_elements, wg_size);
162        (wg, DispatchDimensions::linear(groups))
163    }
164
165    /// Compute 2D dispatch dimensions for an image-like problem.
166    ///
167    /// Returns `(workgroup_size, dispatch_dims)`.
168    #[must_use]
169    pub fn plan_2d(
170        width: u32,
171        height: u32,
172        strategy: WorkgroupStrategy,
173    ) -> (WorkgroupSize, DispatchDimensions) {
174        let (wg_x, wg_y) = match strategy {
175            WorkgroupStrategy::Square => (16, 16),
176            WorkgroupStrategy::Wide => (32, 8),
177            WorkgroupStrategy::Tall => (8, 32),
178            WorkgroupStrategy::WarpAligned => (16, 16),
179            WorkgroupStrategy::Minimal => (8, 8),
180        };
181        let wg = WorkgroupSize::flat(wg_x, wg_y);
182        let groups_x = div_ceil(width, wg_x);
183        let groups_y = div_ceil(height, wg_y);
184        (wg, DispatchDimensions::new(groups_x, groups_y, 1))
185    }
186
187    /// Compute 3D dispatch dimensions.
188    ///
189    /// Returns `(workgroup_size, dispatch_dims)`.
190    #[must_use]
191    pub fn plan_3d(width: u32, height: u32, depth: u32) -> (WorkgroupSize, DispatchDimensions) {
192        let wg = WorkgroupSize::new(8, 8, 4);
193        let groups_x = div_ceil(width, 8);
194        let groups_y = div_ceil(height, 8);
195        let groups_z = div_ceil(depth, 4);
196        (wg, DispatchDimensions::new(groups_x, groups_y, groups_z))
197    }
198
199    /// Estimate efficiency ratio of a dispatch (useful work / total work).
200    #[allow(clippy::cast_precision_loss)]
201    #[allow(clippy::manual_checked_ops)]
202    #[must_use]
203    pub fn efficiency(
204        problem_size: (u32, u32),
205        workgroup: &WorkgroupSize,
206        dispatch: &DispatchDimensions,
207    ) -> f64 {
208        let useful = u64::from(problem_size.0) * u64::from(problem_size.1);
209        let total = dispatch.total_invocations(workgroup);
210        if total == 0 {
211            return 0.0;
212        }
213        useful as f64 / total as f64
214    }
215}
216
/// Integer ceiling division: the smallest `q` such that `q * b >= a`.
///
/// # Panics
///
/// Panics if `b` is zero.
fn div_ceil(a: u32, b: u32) -> u32 {
    u32::div_ceil(a, b)
}
221
222// ============================================================================
223// Device-aware auto-tuning
224// ============================================================================
225
226/// GPU device limits relevant to workgroup sizing.
227#[derive(Debug, Clone, Copy)]
228pub struct DeviceLimits {
229    /// Maximum workgroup size per dimension.
230    pub max_workgroup_size_per_dim: u32,
231    /// Maximum total invocations per workgroup.
232    pub max_workgroup_total_invocations: u32,
233    /// Maximum shared memory per workgroup in bytes.
234    pub max_shared_memory_bytes: u32,
235    /// Preferred warp/wavefront size (0 = unknown).
236    pub subgroup_size: u32,
237    /// Maximum dispatch groups per dimension.
238    pub max_dispatch_per_dim: u32,
239}
240
241impl Default for DeviceLimits {
242    fn default() -> Self {
243        Self {
244            max_workgroup_size_per_dim: MAX_WORKGROUP_DIM,
245            max_workgroup_total_invocations: MAX_WORKGROUP_TOTAL,
246            max_shared_memory_bytes: 49152, // 48 KB
247            subgroup_size: WARP_SIZE,
248            max_dispatch_per_dim: 65535,
249        }
250    }
251}
252
253impl DeviceLimits {
254    /// Create `DeviceLimits` from `wgpu::Limits`.
255    #[must_use]
256    pub fn from_wgpu(limits: &wgpu::Limits) -> Self {
257        Self {
258            max_workgroup_size_per_dim: limits
259                .max_compute_workgroup_size_x
260                .min(limits.max_compute_workgroup_size_y)
261                .min(limits.max_compute_workgroup_size_z),
262            max_workgroup_total_invocations: limits.max_compute_invocations_per_workgroup,
263            max_shared_memory_bytes: limits.max_compute_workgroup_storage_size,
264            subgroup_size: WARP_SIZE, // wgpu doesn't expose this directly
265            max_dispatch_per_dim: limits.max_compute_workgroups_per_dimension,
266        }
267    }
268}
269
270/// Auto-tuner that selects optimal workgroup sizes based on device limits.
271pub struct WorkgroupAutoTuner {
272    limits: DeviceLimits,
273}
274
275impl WorkgroupAutoTuner {
276    /// Create a new auto-tuner with the given device limits.
277    #[must_use]
278    pub fn new(limits: DeviceLimits) -> Self {
279        Self { limits }
280    }
281
282    /// Create a new auto-tuner with default limits.
283    #[must_use]
284    pub fn with_defaults() -> Self {
285        Self::new(DeviceLimits::default())
286    }
287
288    /// Get the device limits.
289    #[must_use]
290    pub fn limits(&self) -> &DeviceLimits {
291        &self.limits
292    }
293
294    /// Auto-tune a 1D workgroup size for a linear problem.
295    ///
296    /// Picks the largest warp-aligned size that fits within device limits.
297    #[must_use]
298    pub fn tune_1d(&self, total_elements: u32) -> (WorkgroupSize, DispatchDimensions) {
299        let subgroup = self.limits.subgroup_size.max(1);
300        let max_total = self.limits.max_workgroup_total_invocations;
301        let max_dim = self.limits.max_workgroup_size_per_dim;
302
303        // Start from 256, clamp to device limits, align to subgroup size.
304        let mut size = 256u32.min(max_total).min(max_dim);
305        // Round down to subgroup alignment.
306        if let Some(aligned) = size.checked_div(subgroup) {
307            size = aligned * subgroup;
308        }
309        size = size.max(subgroup).max(1);
310
311        // If problem is small, use smaller workgroups.
312        if total_elements < size * 4 {
313            let smaller = (total_elements.div_ceil(subgroup.max(1))) * subgroup.max(1);
314            size = smaller.max(subgroup.max(1)).min(size);
315        }
316
317        let wg = WorkgroupSize::linear(size);
318        let groups = div_ceil(total_elements, size).min(self.limits.max_dispatch_per_dim);
319        (wg, DispatchDimensions::linear(groups))
320    }
321
322    /// Auto-tune a 2D workgroup size for an image-like problem.
323    ///
324    /// Balances squareness with warp alignment and device limits.
325    #[must_use]
326    #[allow(clippy::manual_checked_ops)]
327    pub fn tune_2d(&self, width: u32, height: u32) -> (WorkgroupSize, DispatchDimensions) {
328        let max_total = self.limits.max_workgroup_total_invocations;
329        let max_dim = self.limits.max_workgroup_size_per_dim;
330        let subgroup = self.limits.subgroup_size.max(1);
331
332        // Candidate workgroup sizes (prefer multiples of subgroup_size).
333        let candidates: [(u32, u32); 6] = [
334            (16, 16), // 256 threads — good default
335            (32, 8),  // 256 threads — wide, good for row-major access
336            (8, 32),  // 256 threads — tall
337            (16, 8),  // 128 threads — smaller
338            (8, 8),   // 64 threads — small
339            (32, 16), // 512 threads — large
340        ];
341
342        let mut best_wg = WorkgroupSize::flat(8, 8);
343        let mut best_efficiency = 0.0_f64;
344
345        for &(wx, wy) in &candidates {
346            if wx > max_dim || wy > max_dim || wx * wy > max_total {
347                continue;
348            }
349            // Prefer warp-aligned total.
350            let total = wx * wy;
351            if total % subgroup != 0 {
352                continue;
353            }
354
355            let gx = div_ceil(width, wx).min(self.limits.max_dispatch_per_dim);
356            let gy = div_ceil(height, wy).min(self.limits.max_dispatch_per_dim);
357            let total_invocations = (gx as u64) * (gy as u64) * (total as u64);
358            let useful = (width as u64) * (height as u64);
359            let eff = if total_invocations > 0 {
360                useful as f64 / total_invocations as f64
361            } else {
362                0.0
363            };
364
365            if eff > best_efficiency {
366                best_efficiency = eff;
367                best_wg = WorkgroupSize::flat(wx, wy);
368            }
369        }
370
371        let gx = div_ceil(width, best_wg.x).min(self.limits.max_dispatch_per_dim);
372        let gy = div_ceil(height, best_wg.y).min(self.limits.max_dispatch_per_dim);
373        (best_wg, DispatchDimensions::new(gx, gy, 1))
374    }
375
376    /// Auto-tune for a 2D problem with shared memory requirements.
377    ///
378    /// Takes into account the per-pixel shared memory usage and ensures
379    /// the workgroup's shared memory fits within device limits.
380    #[must_use]
381    pub fn tune_2d_with_shared_memory(
382        &self,
383        width: u32,
384        height: u32,
385        shared_bytes_per_pixel: u32,
386    ) -> (WorkgroupSize, DispatchDimensions) {
387        let max_shared = self.limits.max_shared_memory_bytes;
388        let max_total = self.limits.max_workgroup_total_invocations;
389        let max_dim = self.limits.max_workgroup_size_per_dim;
390        let subgroup = self.limits.subgroup_size.max(1);
391
392        // Find largest square-ish workgroup whose shared mem fits.
393        let mut best_side = 8u32;
394        for candidate_side in &[32u32, 24, 16, 12, 8] {
395            let side = *candidate_side;
396            let total = side * side;
397            if total > max_total || side > max_dim {
398                continue;
399            }
400            if total % subgroup != 0 {
401                continue;
402            }
403            let shared_needed = total * shared_bytes_per_pixel;
404            if shared_needed <= max_shared {
405                best_side = side;
406                break;
407            }
408        }
409
410        let wg = WorkgroupSize::flat(best_side, best_side);
411        let gx = div_ceil(width, best_side).min(self.limits.max_dispatch_per_dim);
412        let gy = div_ceil(height, best_side).min(self.limits.max_dispatch_per_dim);
413        (wg, DispatchDimensions::new(gx, gy, 1))
414    }
415
416    /// Estimate the efficiency of a given configuration.
417    #[must_use]
418    pub fn estimate_efficiency(
419        &self,
420        problem_width: u32,
421        problem_height: u32,
422        workgroup: &WorkgroupSize,
423    ) -> f64 {
424        let gx = div_ceil(problem_width, workgroup.x);
425        let gy = div_ceil(problem_height, workgroup.y);
426        let dispatch = DispatchDimensions::new(gx, gy, 1);
427        WorkgroupPlanner::efficiency((problem_width, problem_height), workgroup, &dispatch)
428    }
429}
430
/// Shared memory layout descriptor for a workgroup.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SharedMemoryLayout {
    /// Total size in bytes per workgroup. Equals `element_count` times
    /// `element_size` rounded up to `alignment`, saturating at `u32::MAX`.
    pub size_bytes: u32,
    /// Alignment requirement in bytes. Each element is padded to a multiple
    /// of it; `0` disables padding.
    pub alignment: u32,
    /// Number of elements.
    pub element_count: u32,
    /// Unpadded size per element in bytes.
    pub element_size: u32,
}

impl SharedMemoryLayout {
    /// Typical per-workgroup shared memory limit (48 KiB), the same value
    /// `DeviceLimits::default()` reports.
    pub const TYPICAL_LIMIT_BYTES: u32 = 48 * 1024;

    /// Create a layout of `element_count` elements of `element_size` bytes.
    ///
    /// Each element is padded up to a multiple of `alignment` (no padding
    /// when `alignment` is 0). `size_bytes` saturates at `u32::MAX` instead
    /// of overflowing. An oversized layout is therefore reported as not
    /// fitting, rather than panicking (debug builds) or wrapping to a small,
    /// bogus size (release builds).
    #[must_use]
    pub fn new(element_count: u32, element_size: u32, alignment: u32) -> Self {
        let stride = round_up(element_size, alignment);
        Self {
            size_bytes: element_count.saturating_mul(stride),
            alignment,
            element_count,
            element_size,
        }
    }

    /// Create a layout for `count` 4-byte, 4-aligned floats.
    #[must_use]
    pub fn floats(count: u32) -> Self {
        Self::new(count, 4, 4)
    }

    /// Create a layout for `count` 16-byte, 16-aligned vec4 values.
    #[must_use]
    pub fn vec4s(count: u32) -> Self {
        Self::new(count, 16, 16)
    }

    /// Check if the layout fits within `limit_bytes` of shared memory, e.g. a
    /// device's `max_shared_memory_bytes`.
    #[must_use]
    pub fn fits_within(&self, limit_bytes: u32) -> bool {
        self.size_bytes <= limit_bytes
    }

    /// Check if the layout fits within the typical shared memory limit
    /// ([`Self::TYPICAL_LIMIT_BYTES`], 48 KiB).
    #[must_use]
    pub fn fits_in_shared_memory(&self) -> bool {
        self.fits_within(Self::TYPICAL_LIMIT_BYTES)
    }
}

/// Round `value` up to the next multiple of `alignment`.
///
/// Returns `value` unchanged when `alignment` is 0. When the rounded value is
/// not representable, the result saturates at `u32::MAX`; it is then not a
/// multiple of `alignment`.
fn round_up(value: u32, alignment: u32) -> u32 {
    if alignment == 0 {
        return value;
    }
    value.div_ceil(alignment).saturating_mul(alignment)
}
483
#[cfg(test)]
mod tests {
    //! Unit tests for workgroup sizing, dispatch planning, auto-tuning and
    //! shared-memory layouts.

    use super::*;

    /// Checks that `dispatch` with workgroup `wg` launches at least
    /// `width` x `height` invocations.
    fn assert_covers(
        wg: &WorkgroupSize,
        dispatch: &DispatchDimensions,
        width: u32,
        height: u32,
    ) {
        assert!(dispatch.groups_x * wg.x >= width);
        assert!(dispatch.groups_y * wg.y >= height);
    }

    #[test]
    fn test_workgroup_size_default() {
        let wg = WorkgroupSize::default();
        assert_eq!((wg.x, wg.y, wg.z), (8, 8, 1));
        assert_eq!(wg.total(), 64);
    }

    #[test]
    fn test_workgroup_size_linear() {
        let wg = WorkgroupSize::linear(256);
        assert_eq!(wg.total(), 256);
        assert!(wg.is_valid() && wg.is_warp_aligned());
    }

    #[test]
    fn test_workgroup_size_flat() {
        let wg = WorkgroupSize::flat(16, 16);
        assert_eq!(wg.total(), 256);
        assert!(wg.is_valid());
    }

    #[test]
    fn test_workgroup_size_3d() {
        let wg = WorkgroupSize::new(8, 8, 4);
        assert_eq!(wg.total(), 256);
        assert!(wg.is_valid());
    }

    #[test]
    fn test_workgroup_size_invalid_exceeds_max() {
        assert!(!WorkgroupSize::new(1025, 1, 1).is_valid());
    }

    #[test]
    fn test_workgroup_size_invalid_exceeds_total() {
        let oversized = WorkgroupSize::new(32, 64, 1);
        assert_eq!(oversized.total(), 2048);
        assert!(!oversized.is_valid());
    }

    #[test]
    fn test_dispatch_dimensions_linear() {
        assert_eq!(DispatchDimensions::linear(10).total_groups(), 10);
    }

    #[test]
    fn test_dispatch_total_invocations() {
        let dispatch = DispatchDimensions::new(4, 4, 1);
        assert_eq!(dispatch.total_invocations(&WorkgroupSize::flat(16, 16)), 4096);
    }

    #[test]
    fn test_plan_1d() {
        let (wg, dispatch) = WorkgroupPlanner::plan_1d(1000, WorkgroupStrategy::WarpAligned);
        assert_eq!(wg.x, 256);
        assert!(dispatch.groups_x * wg.x >= 1000);
    }

    #[test]
    fn test_plan_2d_square() {
        let (wg, dispatch) = WorkgroupPlanner::plan_2d(1920, 1080, WorkgroupStrategy::Square);
        assert_eq!((wg.x, wg.y), (16, 16));
        assert_covers(&wg, &dispatch, 1920, 1080);
    }

    #[test]
    fn test_plan_2d_wide() {
        let (wg, dispatch) = WorkgroupPlanner::plan_2d(3840, 2160, WorkgroupStrategy::Wide);
        assert_eq!((wg.x, wg.y), (32, 8));
        assert_covers(&wg, &dispatch, 3840, 2160);
    }

    #[test]
    fn test_plan_3d() {
        let (wg, dispatch) = WorkgroupPlanner::plan_3d(64, 64, 16);
        assert_eq!(wg.total(), 256);
        assert_eq!(
            (dispatch.groups_x, dispatch.groups_y, dispatch.groups_z),
            (8, 8, 4)
        );
    }

    #[test]
    fn test_efficiency_perfect() {
        let eff = WorkgroupPlanner::efficiency(
            (32, 32),
            &WorkgroupSize::flat(16, 16),
            &DispatchDimensions::new(2, 2, 1),
        );
        assert!((eff - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_efficiency_partial() {
        let eff = WorkgroupPlanner::efficiency(
            (10, 10),
            &WorkgroupSize::flat(16, 16),
            &DispatchDimensions::new(1, 1, 1),
        );
        assert!(eff > 0.0 && eff < 1.0);
    }

    #[test]
    fn test_shared_memory_floats() {
        let layout = SharedMemoryLayout::floats(256);
        assert_eq!(layout.size_bytes, 1024);
        assert!(layout.fits_in_shared_memory());
    }

    #[test]
    fn test_shared_memory_vec4s() {
        let layout = SharedMemoryLayout::vec4s(64);
        assert_eq!(layout.size_bytes, 1024);
        assert!(layout.fits_in_shared_memory());
    }

    #[test]
    fn test_shared_memory_exceeds_limit() {
        assert!(!SharedMemoryLayout::new(50000, 4, 4).fits_in_shared_memory());
    }

    #[test]
    fn test_div_ceil() {
        for (numerator, divisor, expected) in [(10, 3, 4), (9, 3, 3), (1, 256, 1)] {
            assert_eq!(div_ceil(numerator, divisor), expected);
        }
    }

    #[test]
    fn test_round_up() {
        for (value, alignment, expected) in [(5, 4, 8), (8, 4, 8), (0, 4, 0), (7, 0, 7)] {
            assert_eq!(round_up(value, alignment), expected);
        }
    }

    #[test]
    fn test_warp_alignment() {
        assert!(WorkgroupSize::linear(64).is_warp_aligned());
        assert!(!WorkgroupSize::linear(33).is_warp_aligned());
    }

    // --- Auto-tuner tests ---

    #[test]
    fn test_auto_tuner_1d_default_limits() {
        let (wg, dispatch) = WorkgroupAutoTuner::with_defaults().tune_1d(10000);
        assert!(wg.is_valid(), "workgroup must be valid");
        assert!(wg.is_warp_aligned(), "should be warp-aligned");
        assert!(dispatch.groups_x * wg.x >= 10000, "must cover all elements");
    }

    #[test]
    fn test_auto_tuner_1d_small_problem() {
        let (wg, dispatch) = WorkgroupAutoTuner::with_defaults().tune_1d(64);
        assert!(wg.is_valid());
        assert!(dispatch.groups_x * wg.x >= 64);
    }

    #[test]
    fn test_auto_tuner_2d_1080p() {
        let (wg, dispatch) = WorkgroupAutoTuner::with_defaults().tune_2d(1920, 1080);
        assert!(wg.is_valid());
        assert!(wg.is_warp_aligned());
        assert_covers(&wg, &dispatch, 1920, 1080);
    }

    #[test]
    fn test_auto_tuner_2d_4k() {
        let (wg, dispatch) = WorkgroupAutoTuner::with_defaults().tune_2d(3840, 2160);
        assert!(wg.is_valid());
        assert_covers(&wg, &dispatch, 3840, 2160);
    }

    #[test]
    fn test_auto_tuner_2d_small_image() {
        let (wg, dispatch) = WorkgroupAutoTuner::with_defaults().tune_2d(16, 16);
        assert!(wg.is_valid());
        assert_covers(&wg, &dispatch, 16, 16);
    }

    #[test]
    fn test_auto_tuner_2d_non_square() {
        let (wg, dispatch) = WorkgroupAutoTuner::with_defaults().tune_2d(4096, 32);
        assert!(wg.is_valid());
        assert_covers(&wg, &dispatch, 4096, 32);
    }

    #[test]
    fn test_auto_tuner_with_shared_memory() {
        let tuner = WorkgroupAutoTuner::with_defaults();
        let budget = tuner.limits().max_shared_memory_bytes;
        // 64 bytes of shared memory per pixel should force a smaller workgroup.
        let (wg, dispatch) = tuner.tune_2d_with_shared_memory(1920, 1080, 64);
        let shared_used = wg.total() * 64;
        assert!(
            shared_used <= budget,
            "shared memory {} must fit in {} bytes",
            shared_used,
            budget
        );
        assert_covers(&wg, &dispatch, 1920, 1080);
    }

    #[test]
    fn test_auto_tuner_with_large_shared_memory() {
        let tuner = WorkgroupAutoTuner::with_defaults();
        // A very large per-pixel footprint must fall back to a small workgroup.
        let (wg, dispatch) = tuner.tune_2d_with_shared_memory(256, 256, 512);
        assert!(wg.total() * 512 <= tuner.limits().max_shared_memory_bytes);
        assert!(dispatch.groups_x * wg.x >= 256);
    }

    #[test]
    fn test_auto_tuner_respects_constrained_limits() {
        let tuner = WorkgroupAutoTuner::new(DeviceLimits {
            max_workgroup_size_per_dim: 128,
            max_workgroup_total_invocations: 128,
            max_shared_memory_bytes: 16384,
            subgroup_size: 16,
            max_dispatch_per_dim: 32768,
        });
        let (wg, _) = tuner.tune_2d(1920, 1080);
        assert!(wg.x <= 128 && wg.y <= 128);
        assert!(wg.total() <= 128);
    }

    #[test]
    fn test_auto_tuner_efficiency_estimate() {
        let tuner = WorkgroupAutoTuner::with_defaults();
        let wg = WorkgroupSize::flat(16, 16);

        let exact_fit = tuner.estimate_efficiency(32, 32, &wg);
        assert!(
            (exact_fit - 1.0).abs() < 1e-9,
            "perfect fit should have efficiency 1.0"
        );

        let padded = tuner.estimate_efficiency(17, 17, &wg);
        assert!(
            padded < 1.0,
            "non-aligned problem should have < 1.0 efficiency"
        );
        assert!(padded > 0.0);
    }

    #[test]
    fn test_device_limits_default() {
        let DeviceLimits {
            max_workgroup_size_per_dim,
            max_workgroup_total_invocations,
            max_shared_memory_bytes,
            subgroup_size,
            ..
        } = DeviceLimits::default();
        assert_eq!(max_workgroup_size_per_dim, 1024);
        assert_eq!(max_workgroup_total_invocations, 1024);
        assert_eq!(max_shared_memory_bytes, 49152);
        assert_eq!(subgroup_size, 32);
    }
}