// oximedia_gpu/indirect_dispatch.rs

1#![allow(dead_code)]
2//! Indirect dispatch support for GPU compute kernels.
3//!
4//! Indirect dispatch allows the GPU to determine workgroup counts from
5//! a buffer rather than from CPU-side values. This is essential for
6//! data-dependent workloads where the number of items is computed by
7//! a previous GPU pass (e.g., compaction, prefix-sum driven dispatch).
8
9use std::fmt;
10
/// Arguments for an indirect compute dispatch.
///
/// The struct is `#[repr(C)]` and holds three consecutive `u32` workgroup
/// counts (12 bytes), matching the layout expected in the indirect buffer by
/// `wgpu::ComputePass::dispatch_workgroups_indirect`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct IndirectDispatchArgs {
    /// Number of workgroups in the X dimension.
    pub x: u32,
    /// Number of workgroups in the Y dimension.
    pub y: u32,
    /// Number of workgroups in the Z dimension.
    pub z: u32,
}

impl IndirectDispatchArgs {
    /// Create dispatch arguments from explicit workgroup counts per axis.
    pub fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Create a 1D dispatch: `x` workgroups along X, one along Y and Z.
    pub fn one_d(x: u32) -> Self {
        Self::new(x, 1, 1)
    }

    /// Create a 2D dispatch: `x` by `y` workgroups, one along Z.
    pub fn two_d(x: u32, y: u32) -> Self {
        Self::new(x, y, 1)
    }

    /// Total number of workgroups dispatched (`x * y * z`).
    ///
    /// The product of three `u32` values can exceed `u64::MAX`; in that case
    /// the result saturates at `u64::MAX` instead of panicking or wrapping.
    pub fn total_workgroups(&self) -> u64 {
        // x * y always fits in a u64; only the multiplication by z can
        // overflow, but saturating both steps keeps the intent obvious.
        u64::from(self.x)
            .saturating_mul(u64::from(self.y))
            .saturating_mul(u64::from(self.z))
    }

    /// Serialize to a byte array (12 bytes): `x`, `y`, `z` in that order,
    /// each encoded little-endian.
    pub fn to_bytes(&self) -> [u8; 12] {
        let counts = [self.x, self.y, self.z];
        let mut buf = [0u8; 12];
        for (slot, count) in buf.chunks_exact_mut(4).zip(counts.iter()) {
            slot.copy_from_slice(&count.to_le_bytes());
        }
        buf
    }

    /// Deserialize from a byte slice (must be at least 12 bytes).
    ///
    /// Only the first 12 bytes are read; any trailing bytes are ignored.
    /// Returns `None` if the slice is too short.
    pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
        let head = bytes.get(..12)?;
        let mut words = head
            .chunks_exact(4)
            .map(|w| u32::from_le_bytes([w[0], w[1], w[2], w[3]]));
        // Field initializers run in source order, so the words are consumed
        // as x, y, z.
        Some(Self {
            x: words.next()?,
            y: words.next()?,
            z: words.next()?,
        })
    }

    /// Check whether all dimensions are non-zero.
    pub fn is_valid(&self) -> bool {
        self.x > 0 && self.y > 0 && self.z > 0
    }
}
72
73impl Default for IndirectDispatchArgs {
74    fn default() -> Self {
75        Self { x: 1, y: 1, z: 1 }
76    }
77}
78
79impl fmt::Display for IndirectDispatchArgs {
80    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81        write!(f, "Dispatch({}x{}x{})", self.x, self.y, self.z)
82    }
83}
84
/// Strategy for computing indirect dispatch arguments from an element count
/// and a workgroup size.
///
/// The extents carried by `Tiled2D` and `Volumetric` are measured in elements
/// (for example pixels or voxels), not in workgroups: `compute_dispatch`
/// divides each of them by the workgroup size, rounding up, to obtain the
/// workgroup count for that axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DispatchStrategy {
    /// Simple 1D linear dispatch: ceil(elements / workgroup_size).
    Linear,
    /// 2D tiled dispatch for image-like workloads.
    Tiled2D {
        /// Extent along X in elements; yields ceil(tile_w / workgroup_size)
        /// workgroups in X.
        tile_w: u32,
        /// Extent along Y in elements; yields ceil(tile_h / workgroup_size)
        /// workgroups in Y.
        tile_h: u32,
    },
    /// 3D volumetric dispatch.
    Volumetric {
        /// Extent along X in elements; yields ceil(vol_w / workgroup_size)
        /// workgroups in X.
        vol_w: u32,
        /// Extent along Y in elements; yields ceil(vol_h / workgroup_size)
        /// workgroups in Y.
        vol_h: u32,
        /// Extent along Z in elements; yields ceil(vol_d / workgroup_size)
        /// workgroups in Z.
        vol_d: u32,
    },
}
108
109/// Compute dispatch arguments from an element count and strategy.
110#[allow(clippy::cast_precision_loss)]
111pub fn compute_dispatch(
112    element_count: u32,
113    workgroup_size: u32,
114    strategy: DispatchStrategy,
115) -> IndirectDispatchArgs {
116    match strategy {
117        DispatchStrategy::Linear => {
118            let groups = (element_count + workgroup_size - 1) / workgroup_size;
119            IndirectDispatchArgs::one_d(groups)
120        }
121        DispatchStrategy::Tiled2D { tile_w, tile_h } => {
122            let gx = (tile_w + workgroup_size - 1) / workgroup_size;
123            let gy = (tile_h + workgroup_size - 1) / workgroup_size;
124            IndirectDispatchArgs::two_d(gx, gy)
125        }
126        DispatchStrategy::Volumetric {
127            vol_w,
128            vol_h,
129            vol_d,
130        } => {
131            let gx = (vol_w + workgroup_size - 1) / workgroup_size;
132            let gy = (vol_h + workgroup_size - 1) / workgroup_size;
133            let gz = (vol_d + workgroup_size - 1) / workgroup_size;
134            IndirectDispatchArgs::new(gx, gy, gz)
135        }
136    }
137}
138
139/// A buffer that holds indirect dispatch arguments.
140///
141/// This represents the GPU-side buffer that would be used for
142/// `dispatch_workgroups_indirect`. In CPU simulation mode it
143/// stores the arguments in-memory for testing.
144pub struct IndirectBuffer {
145    /// The current dispatch arguments.
146    args: IndirectDispatchArgs,
147    /// Label for debugging.
148    label: String,
149    /// Generation counter for change tracking.
150    generation: u64,
151}
152
153impl IndirectBuffer {
154    /// Create a new indirect buffer with default (1,1,1) dispatch.
155    pub fn new(label: &str) -> Self {
156        Self {
157            args: IndirectDispatchArgs::default(),
158            label: label.to_string(),
159            generation: 0,
160        }
161    }
162
163    /// Create an indirect buffer with specific initial arguments.
164    pub fn with_args(label: &str, args: IndirectDispatchArgs) -> Self {
165        Self {
166            args,
167            label: label.to_string(),
168            generation: 0,
169        }
170    }
171
172    /// Update the dispatch arguments.
173    pub fn update(&mut self, args: IndirectDispatchArgs) {
174        self.args = args;
175        self.generation += 1;
176    }
177
178    /// Get the current dispatch arguments.
179    pub fn args(&self) -> IndirectDispatchArgs {
180        self.args
181    }
182
183    /// Get the buffer label.
184    pub fn label(&self) -> &str {
185        &self.label
186    }
187
188    /// Get the generation counter.
189    pub fn generation(&self) -> u64 {
190        self.generation
191    }
192
193    /// Get the buffer size in bytes (always 12 bytes for dispatch args).
194    pub fn size_bytes(&self) -> usize {
195        12
196    }
197
198    /// Serialize the current arguments to bytes.
199    pub fn to_bytes(&self) -> [u8; 12] {
200        self.args.to_bytes()
201    }
202}
203
204/// Validates that dispatch arguments do not exceed device limits.
205pub fn validate_dispatch_limits(
206    args: &IndirectDispatchArgs,
207    max_per_dimension: u32,
208) -> Result<(), String> {
209    if args.x > max_per_dimension {
210        return Err(format!(
211            "X workgroup count {} exceeds limit {}",
212            args.x, max_per_dimension
213        ));
214    }
215    if args.y > max_per_dimension {
216        return Err(format!(
217            "Y workgroup count {} exceeds limit {}",
218            args.y, max_per_dimension
219        ));
220    }
221    if args.z > max_per_dimension {
222        return Err(format!(
223            "Z workgroup count {} exceeds limit {}",
224            args.z, max_per_dimension
225        ));
226    }
227    Ok(())
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    /// Flatten dispatch arguments into an `(x, y, z)` tuple so a whole
    /// dispatch can be compared in a single assertion.
    fn dims(args: IndirectDispatchArgs) -> (u32, u32, u32) {
        (args.x, args.y, args.z)
    }

    // --- IndirectDispatchArgs -------------------------------------------

    #[test]
    fn test_dispatch_args_new() {
        assert_eq!(dims(IndirectDispatchArgs::new(4, 8, 2)), (4, 8, 2));
    }

    #[test]
    fn test_dispatch_args_one_d() {
        assert_eq!(dims(IndirectDispatchArgs::one_d(16)), (16, 1, 1));
    }

    #[test]
    fn test_total_workgroups() {
        let total = IndirectDispatchArgs::new(4, 8, 2).total_workgroups();
        assert_eq!(total, 64);
    }

    #[test]
    fn test_to_from_bytes_roundtrip() {
        let original = IndirectDispatchArgs::new(123, 456, 789);
        let restored = IndirectDispatchArgs::from_bytes(&original.to_bytes())
            .expect("deserialization from bytes should succeed");
        assert_eq!(original, restored);
    }

    #[test]
    fn test_from_bytes_too_short() {
        let short = [0u8; 8];
        assert!(IndirectDispatchArgs::from_bytes(&short).is_none());
    }

    #[test]
    fn test_is_valid() {
        assert!(IndirectDispatchArgs::new(1, 1, 1).is_valid());
        // A zero in any single dimension makes the dispatch invalid.
        for &(x, y, z) in &[(0, 1, 1), (1, 0, 1), (1, 1, 0)] {
            assert!(!IndirectDispatchArgs::new(x, y, z).is_valid());
        }
    }

    #[test]
    fn test_display() {
        let rendered = IndirectDispatchArgs::new(4, 8, 2).to_string();
        assert_eq!(rendered, "Dispatch(4x8x2)");
    }

    #[test]
    fn test_default_dispatch_args() {
        let args = IndirectDispatchArgs::default();
        assert_eq!(dims(args), (1, 1, 1));
        assert!(args.is_valid());
    }

    // --- compute_dispatch -----------------------------------------------

    #[test]
    fn test_compute_dispatch_linear() {
        // ceil(1000/64) = 16
        let args = compute_dispatch(1000, 64, DispatchStrategy::Linear);
        assert_eq!(dims(args), (16, 1, 1));
    }

    #[test]
    fn test_compute_dispatch_tiled() {
        let strategy = DispatchStrategy::Tiled2D {
            tile_w: 1920,
            tile_h: 1080,
        };
        // ceil(1920/16) = 120, ceil(1080/16) = 68
        assert_eq!(dims(compute_dispatch(0, 16, strategy)), (120, 68, 1));
    }

    #[test]
    fn test_compute_dispatch_volumetric() {
        let strategy = DispatchStrategy::Volumetric {
            vol_w: 64,
            vol_h: 64,
            vol_d: 32,
        };
        assert_eq!(dims(compute_dispatch(0, 8, strategy)), (8, 8, 4));
    }

    // --- IndirectBuffer -------------------------------------------------

    #[test]
    fn test_indirect_buffer_new() {
        let buf = IndirectBuffer::new("test_buf");
        assert_eq!(buf.label(), "test_buf");
        assert_eq!(buf.args(), IndirectDispatchArgs::default());
        assert_eq!(buf.generation(), 0);
        assert_eq!(buf.size_bytes(), 12);
    }

    #[test]
    fn test_indirect_buffer_update() {
        let mut buf = IndirectBuffer::new("buf");

        buf.update(IndirectDispatchArgs::new(10, 20, 30));
        assert_eq!(buf.args().x, 10);
        assert_eq!(buf.generation(), 1);

        buf.update(IndirectDispatchArgs::one_d(5));
        assert_eq!(buf.generation(), 2);
    }

    // --- validate_dispatch_limits ---------------------------------------

    #[test]
    fn test_validate_dispatch_limits_ok() {
        let within = IndirectDispatchArgs::new(100, 100, 100);
        assert!(validate_dispatch_limits(&within, 65535).is_ok());
    }

    #[test]
    fn test_validate_dispatch_limits_exceeded() {
        let too_wide = IndirectDispatchArgs::new(70000, 1, 1);
        assert!(validate_dispatch_limits(&too_wide, 65535).is_err());
    }
}