Skip to main content

oximedia_gpu/
compute_dispatch.rs

1//! Compute shader dispatch helpers.
2//!
//! Provides workgroup sizing utilities, dispatch grid calculation, and
//! basic barrier (dependency) and dispatch tracking for GPU compute passes.
5
/// Largest workgroup extent, per dimension, that is recommended for most GPUs
/// (2^8 = 256 invocations along a single axis).
pub const MAX_WORKGROUP_DIM: u32 = 1 << 8;
8
/// A 3-D workgroup size (number of invocations along each axis).
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WorkgroupSize {
    pub x: u32,
    pub y: u32,
    pub z: u32,
}

impl WorkgroupSize {
    /// Create a 1-D workgroup (y=1, z=1).
    #[allow(dead_code)]
    #[must_use]
    pub const fn linear(x: u32) -> Self {
        Self { x, y: 1, z: 1 }
    }

    /// Create a 2-D workgroup (z=1).
    #[allow(dead_code)]
    #[must_use]
    pub const fn planar(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Create a full 3-D workgroup.
    #[allow(dead_code)]
    #[must_use]
    pub const fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Total number of threads per workgroup (`x * y * z`).
    ///
    /// The product saturates at `u32::MAX` instead of overflowing, so an
    /// oversized workgroup never panics (debug builds) or silently wraps to a
    /// small value (release builds). Use [`WorkgroupSize::is_valid`] to reject
    /// such sizes.
    #[allow(dead_code)]
    #[must_use]
    pub const fn thread_count(self) -> u32 {
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Returns `true` if the workgroup size is valid: every dimension is at
    /// least 1 and the total thread count is at most `max_threads`.
    ///
    /// A size whose thread count does not fit in a `u32` is always invalid.
    #[allow(dead_code)]
    #[must_use]
    pub fn is_valid(self, max_threads: u32) -> bool {
        let dims_non_zero = self.x >= 1 && self.y >= 1 && self.z >= 1;
        // Checked (not wrapping) multiplication: a wrapped product could
        // otherwise compare as `<= max_threads` and accept an absurd size.
        let fits = self
            .x
            .checked_mul(self.y)
            .and_then(|xy| xy.checked_mul(self.z))
            .is_some_and(|total| total <= max_threads);
        dims_non_zero && fits
    }
}

impl Default for WorkgroupSize {
    /// A 64-wide linear workgroup, `(64, 1, 1)`.
    fn default() -> Self {
        Self::linear(64)
    }
}
61
62/// A 3-D dispatch grid (number of workgroups in each dimension).
63#[allow(dead_code)]
64#[derive(Debug, Clone, Copy, PartialEq, Eq)]
65pub struct DispatchGrid {
66    pub x: u32,
67    pub y: u32,
68    pub z: u32,
69}
70
71impl DispatchGrid {
72    /// Create a new dispatch grid.
73    #[allow(dead_code)]
74    #[must_use]
75    pub const fn new(x: u32, y: u32, z: u32) -> Self {
76        Self { x, y, z }
77    }
78
79    /// Total workgroups dispatched.
80    #[allow(dead_code)]
81    #[must_use]
82    pub const fn total_workgroups(self) -> u64 {
83        self.x as u64 * self.y as u64 * self.z as u64
84    }
85
86    /// Total threads dispatched (grid × workgroup size).
87    #[allow(dead_code)]
88    #[must_use]
89    pub const fn total_threads(self, wg: WorkgroupSize) -> u64 {
90        self.total_workgroups() * wg.thread_count() as u64
91    }
92}
93
94/// Calculate the dispatch grid needed to cover `count` elements with
95/// threads of size `wg_size` in the X dimension.
96#[allow(dead_code)]
97#[must_use]
98pub fn dispatch_1d(count: u32, wg_size: u32) -> DispatchGrid {
99    assert!(wg_size > 0, "wg_size must be > 0");
100    let x = count.div_ceil(wg_size);
101    DispatchGrid::new(x, 1, 1)
102}
103
104/// Calculate the dispatch grid needed to cover a `width × height` image with
105/// a planar workgroup of size `(wg_x, wg_y)`.
106#[allow(dead_code)]
107#[must_use]
108pub fn dispatch_2d(width: u32, height: u32, wg_x: u32, wg_y: u32) -> DispatchGrid {
109    assert!(wg_x > 0 && wg_y > 0, "workgroup dims must be > 0");
110    let x = width.div_ceil(wg_x);
111    let y = height.div_ceil(wg_y);
112    DispatchGrid::new(x, y, 1)
113}
114
115/// Calculate the dispatch grid for a 3-D volume.
116#[allow(dead_code)]
117#[must_use]
118pub fn dispatch_3d(
119    width: u32,
120    height: u32,
121    depth: u32,
122    wg_x: u32,
123    wg_y: u32,
124    wg_z: u32,
125) -> DispatchGrid {
126    assert!(
127        wg_x > 0 && wg_y > 0 && wg_z > 0,
128        "workgroup dims must be > 0"
129    );
130    DispatchGrid::new(
131        width.div_ceil(wg_x),
132        height.div_ceil(wg_y),
133        depth.div_ceil(wg_z),
134    )
135}
136
137/// Recommend a square workgroup size that keeps total threads ≤ `max_threads`
138/// and is a power of two.
139#[allow(dead_code)]
140#[must_use]
141pub fn recommend_2d_workgroup(max_threads: u32) -> WorkgroupSize {
142    let mut side = 1u32;
143    while side * side * 4 <= max_threads {
144        side *= 2;
145    }
146    // side² ≤ max_threads
147    while side * side > max_threads {
148        side /= 2;
149    }
150    WorkgroupSize::planar(side.max(1), side.max(1))
151}
152
153// ---------------------------------------------------------------------------
154// Barrier tracking
155// ---------------------------------------------------------------------------
156
/// The category of a recorded pipeline barrier.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BarrierKind {
    /// Memory barrier: writes issued before it become visible to later reads.
    MemoryReadAfterWrite,
    /// Execution barrier: dispatches before it complete before new ones begin.
    ExecutionOnly,
    /// Full pipeline barrier — the most restrictive and most costly kind.
    Full,
}
168
169/// A recorded pipeline barrier.
170#[allow(dead_code)]
171#[derive(Debug, Clone)]
172pub struct BarrierRecord {
173    /// Sequential index in the command stream.
174    pub index: u32,
175    /// Kind of barrier.
176    pub kind: BarrierKind,
177    /// Optional label for debugging.
178    pub label: Option<String>,
179}
180
181/// Tracks barriers inserted during a compute pass.
182#[allow(dead_code)]
183#[derive(Debug, Default)]
184pub struct BarrierTracker {
185    records: Vec<BarrierRecord>,
186    next_index: u32,
187}
188
189impl BarrierTracker {
190    /// Create a new tracker.
191    #[allow(dead_code)]
192    #[must_use]
193    pub fn new() -> Self {
194        Self::default()
195    }
196
197    /// Record a barrier with the given kind and optional label.
198    #[allow(dead_code)]
199    pub fn push(&mut self, kind: BarrierKind, label: Option<&str>) {
200        self.records.push(BarrierRecord {
201            index: self.next_index,
202            kind,
203            label: label.map(String::from),
204        });
205        self.next_index += 1;
206    }
207
208    /// Number of barriers recorded.
209    #[allow(dead_code)]
210    #[must_use]
211    pub fn len(&self) -> usize {
212        self.records.len()
213    }
214
215    /// Returns true if no barriers have been recorded.
216    #[allow(dead_code)]
217    #[must_use]
218    pub fn is_empty(&self) -> bool {
219        self.records.is_empty()
220    }
221
222    /// All recorded barriers.
223    #[allow(dead_code)]
224    #[must_use]
225    pub fn records(&self) -> &[BarrierRecord] {
226        &self.records
227    }
228
229    /// Count barriers of a specific kind.
230    #[allow(dead_code)]
231    #[must_use]
232    pub fn count_of_kind(&self, kind: BarrierKind) -> usize {
233        self.records.iter().filter(|r| r.kind == kind).count()
234    }
235
236    /// Reset the tracker.
237    #[allow(dead_code)]
238    pub fn reset(&mut self) {
239        self.records.clear();
240        self.next_index = 0;
241    }
242}
243
244// ---------------------------------------------------------------------------
245// Dispatch record
246// ---------------------------------------------------------------------------
247
248/// A recorded compute dispatch (for replay / inspection).
249#[allow(dead_code)]
250#[derive(Debug, Clone)]
251pub struct DispatchRecord {
252    /// Sequential index.
253    pub index: u32,
254    /// The pipeline identifier (e.g. shader name).
255    pub pipeline_id: String,
256    /// The dispatch grid.
257    pub grid: DispatchGrid,
258    /// The workgroup size declared by the shader.
259    pub workgroup_size: WorkgroupSize,
260}
261
262/// Tracks dispatches in a compute pass.
263#[allow(dead_code)]
264#[derive(Debug, Default)]
265pub struct DispatchTracker {
266    records: Vec<DispatchRecord>,
267    next_index: u32,
268}
269
270impl DispatchTracker {
271    /// Create a new tracker.
272    #[allow(dead_code)]
273    #[must_use]
274    pub fn new() -> Self {
275        Self::default()
276    }
277
278    /// Record a dispatch.
279    #[allow(dead_code)]
280    pub fn push(
281        &mut self,
282        pipeline_id: impl Into<String>,
283        grid: DispatchGrid,
284        workgroup_size: WorkgroupSize,
285    ) {
286        self.records.push(DispatchRecord {
287            index: self.next_index,
288            pipeline_id: pipeline_id.into(),
289            grid,
290            workgroup_size,
291        });
292        self.next_index += 1;
293    }
294
295    /// Number of dispatches recorded.
296    #[allow(dead_code)]
297    #[must_use]
298    pub fn len(&self) -> usize {
299        self.records.len()
300    }
301
302    /// Returns true when no dispatches have been recorded.
303    #[allow(dead_code)]
304    #[must_use]
305    pub fn is_empty(&self) -> bool {
306        self.records.is_empty()
307    }
308
309    /// Total GPU threads dispatched.
310    #[allow(dead_code)]
311    #[must_use]
312    pub fn total_threads(&self) -> u64 {
313        self.records
314            .iter()
315            .map(|r| r.grid.total_threads(r.workgroup_size))
316            .sum()
317    }
318
319    /// All dispatch records.
320    #[allow(dead_code)]
321    #[must_use]
322    pub fn records(&self) -> &[DispatchRecord] {
323        &self.records
324    }
325
326    /// Reset the tracker.
327    #[allow(dead_code)]
328    pub fn reset(&mut self) {
329        self.records.clear();
330        self.next_index = 0;
331    }
332}
333
334// ---------------------------------------------------------------------------
335// Unit tests
336// ---------------------------------------------------------------------------
337
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_workgroup_thread_count() {
        // 8 × 8 × 1 invocations.
        assert_eq!(WorkgroupSize::new(8, 8, 1).thread_count(), 64);
    }

    #[test]
    fn test_workgroup_is_valid() {
        const LIMIT: u32 = 1024;
        assert!(WorkgroupSize::linear(64).is_valid(LIMIT));
        // 33 × 33 = 1089 threads, just over the limit.
        assert!(!WorkgroupSize::new(33, 33, 1).is_valid(LIMIT));
    }

    #[test]
    fn test_dispatch_1d_exact() {
        let DispatchGrid { x, y, z } = dispatch_1d(256, 64);
        assert_eq!((x, y, z), (4, 1, 1));
    }

    #[test]
    fn test_dispatch_1d_rounds_up() {
        // 257 elements need a fifth, partially filled, workgroup.
        assert_eq!(dispatch_1d(257, 64).x, 5);
    }

    #[test]
    fn test_dispatch_2d() {
        let DispatchGrid { x, y, .. } = dispatch_2d(1920, 1080, 16, 16);
        assert_eq!(x, 120); // 1920 / 16 divides exactly
        assert_eq!(y, 68); // 1080 / 16 = 67.5, rounded up
    }

    #[test]
    fn test_dispatch_3d() {
        assert_eq!(dispatch_3d(8, 8, 8, 4, 4, 4), DispatchGrid::new(2, 2, 2));
    }

    #[test]
    fn test_total_workgroups() {
        assert_eq!(DispatchGrid::new(4, 4, 1).total_workgroups(), 16);
    }

    #[test]
    fn test_total_threads() {
        let grid = DispatchGrid::new(2, 2, 1);
        // 4 workgroups × 64 threads each.
        assert_eq!(grid.total_threads(WorkgroupSize::planar(8, 8)), 256);
    }

    #[test]
    fn test_recommend_2d_workgroup_within_limit() {
        assert!(recommend_2d_workgroup(256).thread_count() <= 256);
    }

    #[test]
    fn test_recommend_2d_workgroup_square() {
        let WorkgroupSize { x, y, .. } = recommend_2d_workgroup(1024);
        assert_eq!(x, y);
    }

    #[test]
    fn test_barrier_tracker_push_and_count() {
        let mut tracker = BarrierTracker::new();
        tracker.push(BarrierKind::MemoryReadAfterWrite, Some("pre-blur"));
        tracker.push(BarrierKind::Full, None);
        assert_eq!(tracker.len(), 2);
        assert_eq!(tracker.count_of_kind(BarrierKind::Full), 1);
    }

    #[test]
    fn test_barrier_tracker_reset() {
        let mut tracker = BarrierTracker::new();
        tracker.push(BarrierKind::ExecutionOnly, None);
        tracker.reset();
        assert!(tracker.is_empty());
    }

    #[test]
    fn test_dispatch_tracker_total_threads() {
        let mut tracker = DispatchTracker::new();
        tracker.push(
            "blur",
            DispatchGrid::new(10, 10, 1),
            WorkgroupSize::planar(8, 8),
        );
        // 100 workgroups × 64 threads = 6400
        assert_eq!(tracker.total_threads(), 6400);
    }

    #[test]
    fn test_dispatch_tracker_records_sequential_indices() {
        let mut tracker = DispatchTracker::new();
        for name in ["a", "b"] {
            tracker.push(name, DispatchGrid::new(1, 1, 1), WorkgroupSize::linear(64));
        }
        let indices: Vec<u32> = tracker.records().iter().map(|r| r.index).collect();
        assert_eq!(indices, [0, 1]);
    }

    #[test]
    fn test_dispatch_tracker_reset() {
        let mut tracker = DispatchTracker::new();
        tracker.push("x", DispatchGrid::new(1, 1, 1), WorkgroupSize::linear(32));
        tracker.reset();
        assert!(tracker.is_empty());
        assert_eq!(tracker.total_threads(), 0);
    }
}