Skip to main content

oximedia_gpu/
compute_dispatch.rs

1//! Compute shader dispatch helpers.
2//!
3//! Provides workgroup sizing utilities, dispatch grid calculation, and
4//! basic barrier / dependency tracking for GPU compute passes.
5
/// Maximum recommended workgroup size per dimension on most GPUs.
///
/// NOTE(review): none of the helpers visible in this module read or enforce
/// this value; confirm where (or whether) it is consumed elsewhere.
pub const MAX_WORKGROUP_DIM: u32 = 256;
8
/// A 3-D workgroup size: the number of threads per workgroup along each axis.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WorkgroupSize {
    /// Threads along the X axis.
    pub x: u32,
    /// Threads along the Y axis.
    pub y: u32,
    /// Threads along the Z axis.
    pub z: u32,
}

impl WorkgroupSize {
    /// Create a 1-D workgroup (y=1, z=1).
    #[allow(dead_code)]
    #[must_use]
    pub const fn linear(x: u32) -> Self {
        Self { x, y: 1, z: 1 }
    }

    /// Create a 2-D workgroup (z=1).
    #[allow(dead_code)]
    #[must_use]
    pub const fn planar(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Create a full 3-D workgroup.
    #[allow(dead_code)]
    #[must_use]
    pub const fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Total number of threads per workgroup (`x * y * z`).
    ///
    /// The product saturates at `u32::MAX` instead of overflowing, so this
    /// never panics in debug builds and never wraps to a misleadingly small
    /// value in release builds.
    #[allow(dead_code)]
    #[must_use]
    pub const fn thread_count(self) -> u32 {
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Returns `true` if the workgroup size is valid (all dims ≥ 1 and total
    /// threads ≤ `max_threads`).
    ///
    /// A size whose thread product does not fit in `u32` is always reported
    /// as invalid, regardless of `max_threads`.
    #[allow(dead_code)]
    #[must_use]
    pub fn is_valid(self, max_threads: u32) -> bool {
        if self.x == 0 || self.y == 0 || self.z == 0 {
            return false;
        }
        // Checked arithmetic: an overflowing product necessarily exceeds any
        // `u32` limit, so `None` maps to "invalid". The saturated value from
        // `thread_count` cannot be used here because it would compare equal
        // to a limit of `u32::MAX`.
        self.x
            .checked_mul(self.y)
            .and_then(|xy| xy.checked_mul(self.z))
            .is_some_and(|total| total <= max_threads)
    }
}

impl Default for WorkgroupSize {
    /// A linear workgroup of 64 threads.
    fn default() -> Self {
        Self::linear(64)
    }
}
61
62/// A 3-D dispatch grid (number of workgroups in each dimension).
63#[allow(dead_code)]
64#[derive(Debug, Clone, Copy, PartialEq, Eq)]
65pub struct DispatchGrid {
66    pub x: u32,
67    pub y: u32,
68    pub z: u32,
69}
70
71impl DispatchGrid {
72    /// Create a new dispatch grid.
73    #[allow(dead_code)]
74    #[must_use]
75    pub const fn new(x: u32, y: u32, z: u32) -> Self {
76        Self { x, y, z }
77    }
78
79    /// Total workgroups dispatched.
80    #[allow(dead_code)]
81    #[must_use]
82    pub const fn total_workgroups(self) -> u64 {
83        self.x as u64 * self.y as u64 * self.z as u64
84    }
85
86    /// Total threads dispatched (grid × workgroup size).
87    #[allow(dead_code)]
88    #[must_use]
89    pub const fn total_threads(self, wg: WorkgroupSize) -> u64 {
90        self.total_workgroups() * wg.thread_count() as u64
91    }
92}
93
94/// Calculate the dispatch grid needed to cover `count` elements with
95/// threads of size `wg_size` in the X dimension.
96#[allow(dead_code)]
97#[must_use]
98pub fn dispatch_1d(count: u32, wg_size: u32) -> DispatchGrid {
99    assert!(wg_size > 0, "wg_size must be > 0");
100    let x = count.div_ceil(wg_size);
101    DispatchGrid::new(x, 1, 1)
102}
103
104/// Calculate the dispatch grid needed to cover a `width × height` image with
105/// a planar workgroup of size `(wg_x, wg_y)`.
106#[allow(dead_code)]
107#[must_use]
108pub fn dispatch_2d(width: u32, height: u32, wg_x: u32, wg_y: u32) -> DispatchGrid {
109    assert!(wg_x > 0 && wg_y > 0, "workgroup dims must be > 0");
110    let x = width.div_ceil(wg_x);
111    let y = height.div_ceil(wg_y);
112    DispatchGrid::new(x, y, 1)
113}
114
115/// Calculate the dispatch grid for a 3-D volume.
116#[allow(dead_code)]
117#[must_use]
118pub fn dispatch_3d(
119    width: u32,
120    height: u32,
121    depth: u32,
122    wg_x: u32,
123    wg_y: u32,
124    wg_z: u32,
125) -> DispatchGrid {
126    assert!(
127        wg_x > 0 && wg_y > 0 && wg_z > 0,
128        "workgroup dims must be > 0"
129    );
130    DispatchGrid::new(
131        width.div_ceil(wg_x),
132        height.div_ceil(wg_y),
133        depth.div_ceil(wg_z),
134    )
135}
136
137/// Recommend a square workgroup size that keeps total threads ≤ `max_threads`
138/// and is a power of two.
139#[allow(dead_code)]
140#[must_use]
141pub fn recommend_2d_workgroup(max_threads: u32) -> WorkgroupSize {
142    let mut side = 1u32;
143    while side * side * 4 <= max_threads {
144        side *= 2;
145    }
146    // side² ≤ max_threads
147    while side * side > max_threads {
148        side /= 2;
149    }
150    WorkgroupSize::planar(side.max(1), side.max(1))
151}
152
153// ---------------------------------------------------------------------------
154// Barrier tracking
155// ---------------------------------------------------------------------------
156
/// Type of pipeline barrier.
///
/// NOTE(review): in the code visible in this module the kind is only stored
/// in a [`BarrierRecord`] and compared for equality when counting; no GPU
/// synchronization is issued here. The per-variant descriptions below state
/// the intended meaning — confirm against the backend that consumes them.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BarrierKind {
    /// Ensures all memory writes are visible to subsequent reads.
    MemoryReadAfterWrite,
    /// Ensures all dispatches before the barrier complete before new ones begin.
    ExecutionOnly,
    /// Full pipeline barrier (most restrictive, highest cost).
    Full,
}
168
/// A recorded pipeline barrier.
///
/// Instances are created by [`BarrierTracker::push`], which fills `index`
/// from the tracker's running counter and copies the optional label into an
/// owned `String`.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct BarrierRecord {
    /// Sequential index in the command stream.
    pub index: u32,
    /// Kind of barrier.
    pub kind: BarrierKind,
    /// Optional label for debugging.
    pub label: Option<String>,
}
180
181/// Tracks barriers inserted during a compute pass.
182#[allow(dead_code)]
183#[derive(Debug, Default)]
184pub struct BarrierTracker {
185    records: Vec<BarrierRecord>,
186    next_index: u32,
187}
188
189impl BarrierTracker {
190    /// Create a new tracker.
191    #[allow(dead_code)]
192    #[must_use]
193    pub fn new() -> Self {
194        Self::default()
195    }
196
197    /// Record a barrier with the given kind and optional label.
198    #[allow(dead_code)]
199    pub fn push(&mut self, kind: BarrierKind, label: Option<&str>) {
200        self.records.push(BarrierRecord {
201            index: self.next_index,
202            kind,
203            label: label.map(String::from),
204        });
205        self.next_index += 1;
206    }
207
208    /// Number of barriers recorded.
209    #[allow(dead_code)]
210    #[must_use]
211    pub fn len(&self) -> usize {
212        self.records.len()
213    }
214
215    /// Returns true if no barriers have been recorded.
216    #[allow(dead_code)]
217    #[must_use]
218    pub fn is_empty(&self) -> bool {
219        self.records.is_empty()
220    }
221
222    /// All recorded barriers.
223    #[allow(dead_code)]
224    #[must_use]
225    pub fn records(&self) -> &[BarrierRecord] {
226        &self.records
227    }
228
229    /// Count barriers of a specific kind.
230    #[allow(dead_code)]
231    #[must_use]
232    pub fn count_of_kind(&self, kind: BarrierKind) -> usize {
233        self.records.iter().filter(|r| r.kind == kind).count()
234    }
235
236    /// Reset the tracker.
237    #[allow(dead_code)]
238    pub fn reset(&mut self) {
239        self.records.clear();
240        self.next_index = 0;
241    }
242}
243
244// ---------------------------------------------------------------------------
245// Dispatch record
246// ---------------------------------------------------------------------------
247
/// A recorded compute dispatch (for replay / inspection).
///
/// Instances are created by [`DispatchTracker::push`], which fills `index`
/// from the tracker's running counter and stores the remaining fields
/// exactly as supplied by the caller.
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct DispatchRecord {
    /// Sequential index.
    pub index: u32,
    /// The pipeline identifier (e.g. shader name).
    pub pipeline_id: String,
    /// The dispatch grid.
    pub grid: DispatchGrid,
    /// The workgroup size declared by the shader.
    pub workgroup_size: WorkgroupSize,
}
261
262/// Tracks dispatches in a compute pass.
263#[allow(dead_code)]
264#[derive(Debug, Default)]
265pub struct DispatchTracker {
266    records: Vec<DispatchRecord>,
267    next_index: u32,
268}
269
270impl DispatchTracker {
271    /// Create a new tracker.
272    #[allow(dead_code)]
273    #[must_use]
274    pub fn new() -> Self {
275        Self::default()
276    }
277
278    /// Record a dispatch.
279    #[allow(dead_code)]
280    pub fn push(
281        &mut self,
282        pipeline_id: impl Into<String>,
283        grid: DispatchGrid,
284        workgroup_size: WorkgroupSize,
285    ) {
286        self.records.push(DispatchRecord {
287            index: self.next_index,
288            pipeline_id: pipeline_id.into(),
289            grid,
290            workgroup_size,
291        });
292        self.next_index += 1;
293    }
294
295    /// Number of dispatches recorded.
296    #[allow(dead_code)]
297    #[must_use]
298    pub fn len(&self) -> usize {
299        self.records.len()
300    }
301
302    /// Returns true when no dispatches have been recorded.
303    #[allow(dead_code)]
304    #[must_use]
305    pub fn is_empty(&self) -> bool {
306        self.records.is_empty()
307    }
308
309    /// Total GPU threads dispatched.
310    #[allow(dead_code)]
311    #[must_use]
312    pub fn total_threads(&self) -> u64 {
313        self.records
314            .iter()
315            .map(|r| r.grid.total_threads(r.workgroup_size))
316            .sum()
317    }
318
319    /// All dispatch records.
320    #[allow(dead_code)]
321    #[must_use]
322    pub fn records(&self) -> &[DispatchRecord] {
323        &self.records
324    }
325
326    /// Reset the tracker.
327    #[allow(dead_code)]
328    pub fn reset(&mut self) {
329        self.records.clear();
330        self.next_index = 0;
331    }
332}
333
334// ---------------------------------------------------------------------------
335// Data-driven (indirect) dispatch support
336// ---------------------------------------------------------------------------
337
/// Strategy used to derive workgroup counts from a data-dependent element
/// count at dispatch preparation time.
///
/// In the formulas below `n` is the element count passed to
/// [`DataDrivenDispatch::prepare`] and `wg_x` / `wg_y` are that helper's
/// workgroup dimensions. Every strategy yields at least one workgroup.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DataDispatchStrategy {
    /// All elements are processed in a single 1D strip:
    /// `(max(ceil(n / wg_x), 1), 1, 1)`.
    Linear1D,
    /// Elements are spread over a square 2D grid `(side, side, 1)` where
    /// `side = ceil(sqrt(max(ceil(n / (wg_x * wg_y)), 1)))`.
    Square2D,
    /// Fixed number of rows (a value of 0 is treated as 1); columns are
    /// `ceil(max(ceil(n / (wg_x * wg_y)), 1) / rows)`.
    FixedRowCount {
        /// Number of rows in the Y dimension.
        rows: u32,
    },
}
352
353/// Computes and stores dispatch parameters that depend on the number of data
354/// elements only known at dispatch-preparation time (e.g., after a GPU
355/// readback or a CPU-side counter).
356///
357/// In a real GPU pipeline this feeds an *indirect dispatch buffer*; here we
358/// compute the [`DispatchGrid`] on the CPU side for portability and testing.
359pub struct DataDrivenDispatch {
360    /// Workgroup size in X.
361    wg_x: u32,
362    /// Workgroup size in Y.
363    wg_y: u32,
364    strategy: DataDispatchStrategy,
365    /// Grid computed from the last call to [`Self::prepare`].
366    grid: Option<DispatchGrid>,
367    /// Element count from the last call to [`Self::prepare`].
368    last_element_count: u64,
369}
370
371impl DataDrivenDispatch {
372    /// Create a new data-driven dispatch helper.
373    ///
374    /// * `wg_x` / `wg_y` — workgroup size dimensions (must be ≥ 1).
375    /// * `strategy` — how to map element counts to workgroup grids.
376    #[must_use]
377    pub fn new(wg_x: u32, wg_y: u32, strategy: DataDispatchStrategy) -> Self {
378        let wg_x = wg_x.max(1);
379        let wg_y = wg_y.max(1);
380        Self {
381            wg_x,
382            wg_y,
383            strategy,
384            grid: None,
385            last_element_count: 0,
386        }
387    }
388
389    /// Convenience constructor for a 1D strip with `wg_size` threads per
390    /// workgroup.
391    #[must_use]
392    pub fn linear(wg_size: u32) -> Self {
393        Self::new(wg_size, 1, DataDispatchStrategy::Linear1D)
394    }
395
396    /// Convenience constructor for a 2D square grid with `wg_x × wg_y`
397    /// threads per workgroup.
398    #[must_use]
399    pub fn square(wg_x: u32, wg_y: u32) -> Self {
400        Self::new(wg_x, wg_y, DataDispatchStrategy::Square2D)
401    }
402
403    /// Prepare the dispatch grid for `element_count` data elements.
404    ///
405    /// Returns the resulting [`DispatchGrid`]; the value is also stored
406    /// internally and accessible via [`Self::grid`].
407    pub fn prepare(&mut self, element_count: u64) -> DispatchGrid {
408        self.last_element_count = element_count;
409        let n = element_count as u32;
410        let grid = match self.strategy {
411            DataDispatchStrategy::Linear1D => {
412                let x = n.div_ceil(self.wg_x);
413                DispatchGrid::new(x.max(1), 1, 1)
414            }
415            DataDispatchStrategy::Square2D => {
416                let threads_per_wg = self.wg_x * self.wg_y;
417                let total_wgs = n.div_ceil(threads_per_wg).max(1);
418                let side = (total_wgs as f64).sqrt().ceil() as u32;
419                let side = side.max(1);
420                DispatchGrid::new(side, side, 1)
421            }
422            DataDispatchStrategy::FixedRowCount { rows } => {
423                let rows = rows.max(1);
424                // Each row handles `cols` workgroups; each workgroup covers
425                // `wg_x` elements in X and implicitly one row in Y.
426                let total_wgs = n.div_ceil(self.wg_x * self.wg_y).max(1);
427                let cols = total_wgs.div_ceil(rows);
428                DispatchGrid::new(cols, rows, 1)
429            }
430        };
431        self.grid = Some(grid);
432        grid
433    }
434
435    /// The grid computed by the last [`Self::prepare`] call, or `None` if
436    /// [`Self::prepare`] has not yet been called.
437    #[must_use]
438    pub fn grid(&self) -> Option<DispatchGrid> {
439        self.grid
440    }
441
442    /// The element count supplied to the last [`Self::prepare`] call.
443    #[must_use]
444    pub fn last_element_count(&self) -> u64 {
445        self.last_element_count
446    }
447
448    /// Minimum elements coverable by the last computed grid.
449    ///
450    /// Returns 0 if [`Self::prepare`] has not been called.
451    #[must_use]
452    pub fn covered_elements(&self) -> u64 {
453        match self.grid {
454            None => 0,
455            Some(g) => {
456                u64::from(g.total_workgroups()) * u64::from(self.wg_x) * u64::from(self.wg_y)
457            }
458        }
459    }
460}
461
462// ---------------------------------------------------------------------------
463// Unit tests
464// ---------------------------------------------------------------------------
465
// Unit tests covering, in order: workgroup sizing, the free-standing
// dispatch grid helpers, the barrier and dispatch trackers, and the
// data-driven dispatch helper.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_workgroup_thread_count() {
        let wg = WorkgroupSize::new(8, 8, 1);
        assert_eq!(wg.thread_count(), 64);
    }

    #[test]
    fn test_workgroup_is_valid() {
        assert!(WorkgroupSize::linear(64).is_valid(1024));
        // 33 × 33 = 1089 threads, which exceeds the 1024 limit.
        assert!(!WorkgroupSize::new(33, 33, 1).is_valid(1024));
    }

    #[test]
    fn test_dispatch_1d_exact() {
        let g = dispatch_1d(256, 64);
        assert_eq!(g.x, 4);
        assert_eq!(g.y, 1);
        assert_eq!(g.z, 1);
    }

    #[test]
    fn test_dispatch_1d_rounds_up() {
        // 257 elements need a fifth, partially filled workgroup.
        let g = dispatch_1d(257, 64);
        assert_eq!(g.x, 5);
    }

    #[test]
    fn test_dispatch_2d() {
        let g = dispatch_2d(1920, 1080, 16, 16);
        assert_eq!(g.x, 120); // 1920 / 16
        assert_eq!(g.y, 68); // ceil(1080 / 16)
    }

    #[test]
    fn test_dispatch_3d() {
        let g = dispatch_3d(8, 8, 8, 4, 4, 4);
        assert_eq!(g.x, 2);
        assert_eq!(g.y, 2);
        assert_eq!(g.z, 2);
    }

    #[test]
    fn test_total_workgroups() {
        let g = DispatchGrid::new(4, 4, 1);
        assert_eq!(g.total_workgroups(), 16);
    }

    #[test]
    fn test_total_threads() {
        // 4 workgroups × 64 threads each.
        let g = DispatchGrid::new(2, 2, 1);
        let wg = WorkgroupSize::planar(8, 8);
        assert_eq!(g.total_threads(wg), 256);
    }

    #[test]
    fn test_recommend_2d_workgroup_within_limit() {
        let wg = recommend_2d_workgroup(256);
        assert!(wg.thread_count() <= 256);
    }

    #[test]
    fn test_recommend_2d_workgroup_square() {
        let wg = recommend_2d_workgroup(1024);
        assert_eq!(wg.x, wg.y);
    }

    #[test]
    fn test_barrier_tracker_push_and_count() {
        let mut bt = BarrierTracker::new();
        bt.push(BarrierKind::MemoryReadAfterWrite, Some("pre-blur"));
        bt.push(BarrierKind::Full, None);
        assert_eq!(bt.len(), 2);
        assert_eq!(bt.count_of_kind(BarrierKind::Full), 1);
    }

    #[test]
    fn test_barrier_tracker_reset() {
        let mut bt = BarrierTracker::new();
        bt.push(BarrierKind::ExecutionOnly, None);
        bt.reset();
        assert!(bt.is_empty());
    }

    #[test]
    fn test_dispatch_tracker_total_threads() {
        let mut dt = DispatchTracker::new();
        dt.push(
            "blur",
            DispatchGrid::new(10, 10, 1),
            WorkgroupSize::planar(8, 8),
        );
        // 100 workgroups × 64 threads = 6400
        assert_eq!(dt.total_threads(), 6400);
    }

    #[test]
    fn test_dispatch_tracker_records_sequential_indices() {
        let mut dt = DispatchTracker::new();
        dt.push("a", DispatchGrid::new(1, 1, 1), WorkgroupSize::linear(64));
        dt.push("b", DispatchGrid::new(1, 1, 1), WorkgroupSize::linear(64));
        assert_eq!(dt.records()[0].index, 0);
        assert_eq!(dt.records()[1].index, 1);
    }

    #[test]
    fn test_dispatch_tracker_reset() {
        let mut dt = DispatchTracker::new();
        dt.push("x", DispatchGrid::new(1, 1, 1), WorkgroupSize::linear(32));
        dt.reset();
        assert!(dt.is_empty());
        assert_eq!(dt.total_threads(), 0);
    }

    // --- DataDrivenDispatch tests ---

    #[test]
    fn test_data_driven_linear_exact() {
        let mut dd = DataDrivenDispatch::linear(64);
        let g = dd.prepare(128);
        assert_eq!(g.x, 2);
        assert_eq!(g.y, 1);
        assert_eq!(g.z, 1);
    }

    #[test]
    fn test_data_driven_linear_rounds_up() {
        let mut dd = DataDrivenDispatch::linear(64);
        let g = dd.prepare(65);
        assert_eq!(g.x, 2);
    }

    #[test]
    fn test_data_driven_linear_zero_elements() {
        let mut dd = DataDrivenDispatch::linear(64);
        let g = dd.prepare(0);
        // Must produce at least 1 workgroup
        assert_eq!(g.x, 1);
    }

    #[test]
    fn test_data_driven_square_covers_all_elements() {
        let mut dd = DataDrivenDispatch::square(8, 8);
        dd.prepare(500);
        // covered_elements ≥ 500
        assert!(dd.covered_elements() >= 500);
    }

    #[test]
    fn test_data_driven_square_grid_is_square() {
        let mut dd = DataDrivenDispatch::square(8, 8);
        let g = dd.prepare(1024);
        assert_eq!(g.x, g.y);
    }

    #[test]
    fn test_data_driven_fixed_row_count() {
        let mut dd = DataDrivenDispatch::new(8, 1, DataDispatchStrategy::FixedRowCount { rows: 4 });
        let g = dd.prepare(256);
        // 256 elements / 8 per wg = 32 workgroups; 32 / 4 rows = 8 cols
        assert_eq!(g.y, 4);
        assert_eq!(g.x, 8);
    }

    #[test]
    fn test_data_driven_grid_none_before_prepare() {
        let dd = DataDrivenDispatch::linear(32);
        assert!(dd.grid().is_none());
        assert_eq!(dd.covered_elements(), 0);
    }

    #[test]
    fn test_data_driven_last_element_count_stored() {
        let mut dd = DataDrivenDispatch::linear(16);
        dd.prepare(999);
        assert_eq!(dd.last_element_count(), 999);
    }

    #[test]
    fn test_data_driven_covered_elements_gte_last_count() {
        // 137 elements / 16 per wg → 9 workgroups → 3 × 3 grid.
        let mut dd = DataDrivenDispatch::square(4, 4);
        let count = 137_u64;
        dd.prepare(count);
        assert!(dd.covered_elements() >= count);
    }
}