Skip to main content

oximedia_gpu/
workgroup.rs

1#![allow(dead_code)]
2#![allow(clippy::cast_precision_loss)]
3//! GPU workgroup configuration and dispatch sizing.
4//!
5//! This module provides utilities for computing optimal workgroup sizes
6//! and dispatch dimensions for GPU compute shaders. Proper workgroup sizing
7//! is critical for achieving good GPU utilization.
8
/// Largest extent accepted for a single workgroup axis by
/// [`WorkgroupSize::is_valid`] (a typical hardware limit).
const MAX_WORKGROUP_DIM: u32 = 1024;

/// Largest invocation count (`x * y * z`) accepted by
/// [`WorkgroupSize::is_valid`] (a typical hardware limit).
const MAX_WORKGROUP_TOTAL: u32 = 1024;

/// Warp/wavefront width that [`WorkgroupSize::is_warp_aligned`] checks against.
///
/// NOTE(review): 32 lanes is assumed here; confirm if hardware with a
/// different wavefront width has to be supported.
const WARP_SIZE: u32 = 32;

/// Workgroup size in 3 dimensions.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct WorkgroupSize {
    /// Number of invocations along X.
    pub x: u32,
    /// Number of invocations along Y.
    pub y: u32,
    /// Number of invocations along Z.
    pub z: u32,
}

impl WorkgroupSize {
    /// Create a workgroup size from explicit X, Y and Z extents.
    ///
    /// No validation is performed; use [`WorkgroupSize::is_valid`] to check
    /// the result against the typical limits.
    #[must_use]
    pub fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Create a 1D workgroup of `size` invocations along X (Y and Z are 1).
    #[must_use]
    pub fn linear(size: u32) -> Self {
        Self::new(size, 1, 1)
    }

    /// Create a 2D workgroup of `x` by `y` invocations (Z is 1).
    #[must_use]
    pub fn flat(x: u32, y: u32) -> Self {
        Self::new(x, y, 1)
    }

    /// Total number of invocations in this workgroup (`x * y * z`).
    ///
    /// The product saturates at `u32::MAX` instead of overflowing, so an
    /// absurdly large size is reported as "huge" rather than panicking in
    /// debug builds or wrapping to a small number in release builds.
    #[must_use]
    pub fn total(&self) -> u32 {
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Check if the workgroup size is valid (within typical limits).
    ///
    /// Every extent must be in `1..=1024` and the total invocation count
    /// must not exceed 1024.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        let every_axis_in_range = [self.x, self.y, self.z]
            .iter()
            .all(|extent| (1..=MAX_WORKGROUP_DIM).contains(extent));
        every_axis_in_range && self.total() <= MAX_WORKGROUP_TOTAL
    }

    /// Check if the total size is a multiple of the warp size (32).
    ///
    /// A total of zero counts as aligned, because `0 % 32 == 0`.
    #[must_use]
    pub fn is_warp_aligned(&self) -> bool {
        self.total() % WARP_SIZE == 0
    }
}
76
77impl Default for WorkgroupSize {
78    fn default() -> Self {
79        Self { x: 8, y: 8, z: 1 }
80    }
81}
82
83/// Dispatch dimensions for launching a compute shader.
84#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
85pub struct DispatchDimensions {
86    /// Number of workgroups in X.
87    pub groups_x: u32,
88    /// Number of workgroups in Y.
89    pub groups_y: u32,
90    /// Number of workgroups in Z.
91    pub groups_z: u32,
92}
93
94impl DispatchDimensions {
95    /// Create new dispatch dimensions.
96    #[must_use]
97    pub fn new(groups_x: u32, groups_y: u32, groups_z: u32) -> Self {
98        Self {
99            groups_x,
100            groups_y,
101            groups_z,
102        }
103    }
104
105    /// Create 1D dispatch dimensions.
106    #[must_use]
107    pub fn linear(groups: u32) -> Self {
108        Self {
109            groups_x: groups,
110            groups_y: 1,
111            groups_z: 1,
112        }
113    }
114
115    /// Total number of workgroups.
116    #[must_use]
117    pub fn total_groups(&self) -> u64 {
118        u64::from(self.groups_x) * u64::from(self.groups_y) * u64::from(self.groups_z)
119    }
120
121    /// Total number of invocations given a workgroup size.
122    #[must_use]
123    pub fn total_invocations(&self, workgroup: &WorkgroupSize) -> u64 {
124        self.total_groups() * u64::from(workgroup.total())
125    }
126}
127
/// Strategy for choosing workgroup sizes.
///
/// The sizes quoted on each variant are the ones selected by
/// `WorkgroupPlanner::plan_1d` and `WorkgroupPlanner::plan_2d`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum WorkgroupStrategy {
    /// Use a square workgroup: 16x16 in 2D, 128 invocations in 1D.
    Square,
    /// Prefer wide workgroups: 32x8 in 2D, 128 invocations in 1D.
    Wide,
    /// Prefer tall workgroups: 8x32 in 2D, 128 invocations in 1D.
    Tall,
    /// Optimize for warp/wavefront alignment: 256 invocations in 1D,
    /// 16x16 in 2D.
    WarpAligned,
    /// Use the smallest size the planner offers: 64 invocations in 1D,
    /// 8x8 in 2D.
    Minimal,
}
142
143/// Compute optimal workgroup size and dispatch dimensions.
144pub struct WorkgroupPlanner;
145
146impl WorkgroupPlanner {
147    /// Compute 1D dispatch dimensions for a linear problem.
148    ///
149    /// Returns `(workgroup_size, dispatch_dims)`.
150    #[must_use]
151    pub fn plan_1d(
152        total_elements: u32,
153        strategy: WorkgroupStrategy,
154    ) -> (WorkgroupSize, DispatchDimensions) {
155        let wg_size = match strategy {
156            WorkgroupStrategy::WarpAligned => 256,
157            WorkgroupStrategy::Minimal => 64,
158            _ => 128,
159        };
160        let wg = WorkgroupSize::linear(wg_size);
161        let groups = div_ceil(total_elements, wg_size);
162        (wg, DispatchDimensions::linear(groups))
163    }
164
165    /// Compute 2D dispatch dimensions for an image-like problem.
166    ///
167    /// Returns `(workgroup_size, dispatch_dims)`.
168    #[must_use]
169    pub fn plan_2d(
170        width: u32,
171        height: u32,
172        strategy: WorkgroupStrategy,
173    ) -> (WorkgroupSize, DispatchDimensions) {
174        let (wg_x, wg_y) = match strategy {
175            WorkgroupStrategy::Square => (16, 16),
176            WorkgroupStrategy::Wide => (32, 8),
177            WorkgroupStrategy::Tall => (8, 32),
178            WorkgroupStrategy::WarpAligned => (16, 16),
179            WorkgroupStrategy::Minimal => (8, 8),
180        };
181        let wg = WorkgroupSize::flat(wg_x, wg_y);
182        let groups_x = div_ceil(width, wg_x);
183        let groups_y = div_ceil(height, wg_y);
184        (wg, DispatchDimensions::new(groups_x, groups_y, 1))
185    }
186
187    /// Compute 3D dispatch dimensions.
188    ///
189    /// Returns `(workgroup_size, dispatch_dims)`.
190    #[must_use]
191    pub fn plan_3d(width: u32, height: u32, depth: u32) -> (WorkgroupSize, DispatchDimensions) {
192        let wg = WorkgroupSize::new(8, 8, 4);
193        let groups_x = div_ceil(width, 8);
194        let groups_y = div_ceil(height, 8);
195        let groups_z = div_ceil(depth, 4);
196        (wg, DispatchDimensions::new(groups_x, groups_y, groups_z))
197    }
198
199    /// Estimate efficiency ratio of a dispatch (useful work / total work).
200    #[allow(clippy::cast_precision_loss)]
201    #[must_use]
202    pub fn efficiency(
203        problem_size: (u32, u32),
204        workgroup: &WorkgroupSize,
205        dispatch: &DispatchDimensions,
206    ) -> f64 {
207        let useful = u64::from(problem_size.0) * u64::from(problem_size.1);
208        let total = dispatch.total_invocations(workgroup);
209        if total == 0 {
210            return 0.0;
211        }
212        useful as f64 / total as f64
213    }
214}
215
/// Divide `a` by `b`, rounding the quotient up to the next whole number.
///
/// Thin wrapper over the standard library's `u32::div_ceil`.
fn div_ceil(a: u32, b: u32) -> u32 {
    u32::div_ceil(a, b)
}
220
/// Shared memory layout descriptor for a workgroup.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SharedMemoryLayout {
    /// Size in bytes per workgroup: `element_count` times the element size
    /// rounded up to `alignment`. Saturates at `u32::MAX` when the true size
    /// does not fit in a `u32`.
    pub size_bytes: u32,
    /// Alignment requirement in bytes.
    pub alignment: u32,
    /// Number of elements (stride-based).
    pub element_count: u32,
    /// Size per element in bytes, before rounding up to `alignment`.
    pub element_size: u32,
}

impl SharedMemoryLayout {
    /// Typical per-workgroup shared memory budget: 48 KiB.
    const TYPICAL_LIMIT_BYTES: u32 = 48 * 1024;

    /// Create a new shared memory layout.
    ///
    /// Each element occupies `element_size` rounded up to a multiple of
    /// `alignment` (an `alignment` of 0 means "no rounding"). The total is
    /// computed with saturating arithmetic, so an oversized request yields
    /// `size_bytes == u32::MAX` and fails
    /// [`SharedMemoryLayout::fits_in_shared_memory`], instead of panicking
    /// in debug builds or wrapping to a small value in release builds.
    #[must_use]
    pub fn new(element_count: u32, element_size: u32, alignment: u32) -> Self {
        let stride = round_up(element_size, alignment);
        Self {
            size_bytes: element_count.saturating_mul(stride),
            alignment,
            element_count,
            element_size,
        }
    }

    /// Create a layout for `count` 4-byte, 4-byte-aligned elements.
    #[must_use]
    pub fn floats(count: u32) -> Self {
        Self::new(count, 4, 4)
    }

    /// Create a layout for `count` 16-byte, 16-byte-aligned elements.
    #[must_use]
    pub fn vec4s(count: u32) -> Self {
        Self::new(count, 16, 16)
    }

    /// Check if the layout fits within the typical shared memory limit (48 KB).
    #[must_use]
    pub fn fits_in_shared_memory(&self) -> bool {
        self.size_bytes <= Self::TYPICAL_LIMIT_BYTES
    }
}

/// Round a value up to a given alignment.
///
/// An `alignment` of 0 returns `value` unchanged. If the next multiple of
/// `alignment` does not fit in a `u32`, the result saturates at `u32::MAX`.
fn round_up(value: u32, alignment: u32) -> u32 {
    if alignment == 0 {
        return value;
    }
    value.div_ceil(alignment).saturating_mul(alignment)
}
273
#[cfg(test)]
mod tests {
    //! Unit tests for workgroup sizing, dispatch planning and shared memory
    //! layout helpers.

    use super::*;

    // ---- WorkgroupSize ----

    #[test]
    fn test_workgroup_size_default() {
        let size = WorkgroupSize::default();
        assert_eq!((size.x, size.y, size.z), (8, 8, 1));
        assert_eq!(size.total(), 64);
    }

    #[test]
    fn test_workgroup_size_linear() {
        let size = WorkgroupSize::linear(256);
        assert_eq!(size.total(), 256);
        assert!(size.is_valid());
        assert!(size.is_warp_aligned());
    }

    #[test]
    fn test_workgroup_size_flat() {
        let size = WorkgroupSize::flat(16, 16);
        assert_eq!(size.total(), 256);
        assert!(size.is_valid());
    }

    #[test]
    fn test_workgroup_size_3d() {
        let size = WorkgroupSize::new(8, 8, 4);
        assert_eq!(size.total(), 256);
        assert!(size.is_valid());
    }

    #[test]
    fn test_workgroup_size_invalid_exceeds_max() {
        // A single axis above the per-dimension limit is rejected.
        assert!(!WorkgroupSize::new(1025, 1, 1).is_valid());
    }

    #[test]
    fn test_workgroup_size_invalid_exceeds_total() {
        // Each axis is within range, but the product is not.
        let size = WorkgroupSize::new(32, 64, 1);
        assert_eq!(size.total(), 2048);
        assert!(!size.is_valid());
    }

    // ---- DispatchDimensions ----

    #[test]
    fn test_dispatch_dimensions_linear() {
        assert_eq!(DispatchDimensions::linear(10).total_groups(), 10);
    }

    #[test]
    fn test_dispatch_total_invocations() {
        let size = WorkgroupSize::flat(16, 16);
        let grid = DispatchDimensions::new(4, 4, 1);
        assert_eq!(grid.total_invocations(&size), 4096);
    }

    // ---- WorkgroupPlanner ----

    #[test]
    fn test_plan_1d() {
        let (size, grid) = WorkgroupPlanner::plan_1d(1000, WorkgroupStrategy::WarpAligned);
        assert_eq!(size.x, 256);
        assert!(grid.groups_x * size.x >= 1000);
    }

    #[test]
    fn test_plan_2d_square() {
        let (size, grid) = WorkgroupPlanner::plan_2d(1920, 1080, WorkgroupStrategy::Square);
        assert_eq!((size.x, size.y), (16, 16));
        assert!(grid.groups_x * size.x >= 1920);
        assert!(grid.groups_y * size.y >= 1080);
    }

    #[test]
    fn test_plan_2d_wide() {
        let (size, grid) = WorkgroupPlanner::plan_2d(3840, 2160, WorkgroupStrategy::Wide);
        assert_eq!((size.x, size.y), (32, 8));
        assert!(grid.groups_x * size.x >= 3840);
        assert!(grid.groups_y * size.y >= 2160);
    }

    #[test]
    fn test_plan_3d() {
        let (size, grid) = WorkgroupPlanner::plan_3d(64, 64, 16);
        assert_eq!(size.total(), 256);
        assert_eq!((grid.groups_x, grid.groups_y, grid.groups_z), (8, 8, 4));
    }

    #[test]
    fn test_efficiency_perfect() {
        let size = WorkgroupSize::flat(16, 16);
        let grid = DispatchDimensions::new(2, 2, 1);
        let ratio = WorkgroupPlanner::efficiency((32, 32), &size, &grid);
        assert!((ratio - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_efficiency_partial() {
        let size = WorkgroupSize::flat(16, 16);
        let grid = DispatchDimensions::new(1, 1, 1);
        let ratio = WorkgroupPlanner::efficiency((10, 10), &size, &grid);
        assert!(ratio < 1.0);
        assert!(ratio > 0.0);
    }

    // ---- SharedMemoryLayout ----

    #[test]
    fn test_shared_memory_floats() {
        let layout = SharedMemoryLayout::floats(256);
        assert_eq!(layout.size_bytes, 1024);
        assert!(layout.fits_in_shared_memory());
    }

    #[test]
    fn test_shared_memory_vec4s() {
        let layout = SharedMemoryLayout::vec4s(64);
        assert_eq!(layout.size_bytes, 1024);
        assert!(layout.fits_in_shared_memory());
    }

    #[test]
    fn test_shared_memory_exceeds_limit() {
        assert!(!SharedMemoryLayout::new(50000, 4, 4).fits_in_shared_memory());
    }

    // ---- free helpers ----

    #[test]
    fn test_div_ceil() {
        for (a, b, expected) in [(10, 3, 4), (9, 3, 3), (1, 256, 1)] {
            assert_eq!(div_ceil(a, b), expected);
        }
    }

    #[test]
    fn test_round_up() {
        for (value, alignment, expected) in [(5, 4, 8), (8, 4, 8), (0, 4, 0), (7, 0, 7)] {
            assert_eq!(round_up(value, alignment), expected);
        }
    }

    #[test]
    fn test_warp_alignment() {
        assert!(WorkgroupSize::linear(64).is_warp_aligned());
        assert!(!WorkgroupSize::linear(33).is_warp_aligned());
    }
}