/// Workgroup dimension limit constant: 256 (`1 << 8`).
///
/// NOTE(review): not referenced by the code in this section; confirm whether it
/// bounds a single axis or the total thread count of a workgroup.
pub const MAX_WORKGROUP_DIM: u32 = 1 << 8;
8
/// Size of a compute workgroup: number of threads along x, y and z.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WorkgroupSize {
    /// Threads along the x axis.
    pub x: u32,
    /// Threads along the y axis.
    pub y: u32,
    /// Threads along the z axis.
    pub z: u32,
}

impl WorkgroupSize {
    /// One-dimensional workgroup `(x, 1, 1)`.
    #[allow(dead_code)]
    #[must_use]
    pub const fn linear(x: u32) -> Self {
        Self { x, y: 1, z: 1 }
    }

    /// Two-dimensional workgroup `(x, y, 1)`.
    #[allow(dead_code)]
    #[must_use]
    pub const fn planar(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Workgroup with all three dimensions given explicitly.
    #[allow(dead_code)]
    #[must_use]
    pub const fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Total number of threads in one workgroup (`x * y * z`).
    ///
    /// Saturates at `u32::MAX` instead of overflowing. A plain `u32` product
    /// would panic in debug builds and silently wrap in release builds.
    #[allow(dead_code)]
    #[must_use]
    pub const fn thread_count(self) -> u32 {
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Returns `true` when every dimension is at least 1 and the exact
    /// thread count `x * y * z` does not exceed `max_threads`.
    ///
    /// The product is computed with checked arithmetic. Dimensions whose
    /// product does not fit in `u32` are therefore always rejected, instead
    /// of wrapping around to a small (apparently valid) value.
    #[allow(dead_code)]
    #[must_use]
    pub fn is_valid(self, max_threads: u32) -> bool {
        if self.x == 0 || self.y == 0 || self.z == 0 {
            return false;
        }
        self.x
            .checked_mul(self.y)
            .and_then(|xy| xy.checked_mul(self.z))
            .is_some_and(|total| total <= max_threads)
    }
}

impl Default for WorkgroupSize {
    /// A 64-thread linear workgroup, `(64, 1, 1)`.
    fn default() -> Self {
        Self::linear(64)
    }
}
61
62#[allow(dead_code)]
64#[derive(Debug, Clone, Copy, PartialEq, Eq)]
65pub struct DispatchGrid {
66 pub x: u32,
67 pub y: u32,
68 pub z: u32,
69}
70
71impl DispatchGrid {
72 #[allow(dead_code)]
74 #[must_use]
75 pub const fn new(x: u32, y: u32, z: u32) -> Self {
76 Self { x, y, z }
77 }
78
79 #[allow(dead_code)]
81 #[must_use]
82 pub const fn total_workgroups(self) -> u64 {
83 self.x as u64 * self.y as u64 * self.z as u64
84 }
85
86 #[allow(dead_code)]
88 #[must_use]
89 pub const fn total_threads(self, wg: WorkgroupSize) -> u64 {
90 self.total_workgroups() * wg.thread_count() as u64
91 }
92}
93
94#[allow(dead_code)]
97#[must_use]
98pub fn dispatch_1d(count: u32, wg_size: u32) -> DispatchGrid {
99 assert!(wg_size > 0, "wg_size must be > 0");
100 let x = count.div_ceil(wg_size);
101 DispatchGrid::new(x, 1, 1)
102}
103
104#[allow(dead_code)]
107#[must_use]
108pub fn dispatch_2d(width: u32, height: u32, wg_x: u32, wg_y: u32) -> DispatchGrid {
109 assert!(wg_x > 0 && wg_y > 0, "workgroup dims must be > 0");
110 let x = width.div_ceil(wg_x);
111 let y = height.div_ceil(wg_y);
112 DispatchGrid::new(x, y, 1)
113}
114
115#[allow(dead_code)]
117#[must_use]
118pub fn dispatch_3d(
119 width: u32,
120 height: u32,
121 depth: u32,
122 wg_x: u32,
123 wg_y: u32,
124 wg_z: u32,
125) -> DispatchGrid {
126 assert!(
127 wg_x > 0 && wg_y > 0 && wg_z > 0,
128 "workgroup dims must be > 0"
129 );
130 DispatchGrid::new(
131 width.div_ceil(wg_x),
132 height.div_ceil(wg_y),
133 depth.div_ceil(wg_z),
134 )
135}
136
137#[allow(dead_code)]
140#[must_use]
141pub fn recommend_2d_workgroup(max_threads: u32) -> WorkgroupSize {
142 let mut side = 1u32;
143 while side * side * 4 <= max_threads {
144 side *= 2;
145 }
146 while side * side > max_threads {
148 side /= 2;
149 }
150 WorkgroupSize::planar(side.max(1), side.max(1))
151}
152
/// Category of a recorded barrier.
///
/// NOTE(review): the exact memory/execution scope each variant maps to is
/// decided by the backend, not by this file; confirm before relying on it.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BarrierKind {
    MemoryReadAfterWrite,
    ExecutionOnly,
    Full,
}

/// A single barrier entry stored by [`BarrierTracker`].
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct BarrierRecord {
    /// 0-based position in recording order; restarts at 0 after `reset`.
    pub index: u32,
    /// Which kind of barrier was recorded.
    pub kind: BarrierKind,
    /// Optional caller-supplied label (owned copy).
    pub label: Option<String>,
}

/// Append-only log of barriers with sequential indices.
#[allow(dead_code)]
#[derive(Debug, Default)]
pub struct BarrierTracker {
    records: Vec<BarrierRecord>,
    next_index: u32,
}

impl BarrierTracker {
    /// Empty tracker whose first record will get index 0.
    #[allow(dead_code)]
    #[must_use]
    pub fn new() -> Self {
        Self {
            records: Vec::new(),
            next_index: 0,
        }
    }

    /// Appends a barrier of `kind`, copying `label` if present.
    #[allow(dead_code)]
    pub fn push(&mut self, kind: BarrierKind, label: Option<&str>) {
        let entry = BarrierRecord {
            index: self.next_index,
            kind,
            label: label.map(str::to_owned),
        };
        self.records.push(entry);
        self.next_index += 1;
    }

    /// Number of recorded barriers.
    #[allow(dead_code)]
    #[must_use]
    pub fn len(&self) -> usize {
        self.records.len()
    }

    /// `true` when no barrier has been recorded since creation or the last `reset`.
    #[allow(dead_code)]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.records.is_empty()
    }

    /// All recorded barriers, in recording order.
    #[allow(dead_code)]
    #[must_use]
    pub fn records(&self) -> &[BarrierRecord] {
        self.records.as_slice()
    }

    /// How many recorded barriers have exactly the given `kind`.
    #[allow(dead_code)]
    #[must_use]
    pub fn count_of_kind(&self, kind: BarrierKind) -> usize {
        self.records
            .iter()
            .fold(0, |hits, rec| if rec.kind == kind { hits + 1 } else { hits })
    }

    /// Drops every record and restarts index numbering at 0.
    #[allow(dead_code)]
    pub fn reset(&mut self) {
        self.records.clear();
        self.next_index = 0;
    }
}
243
244#[allow(dead_code)]
250#[derive(Debug, Clone)]
251pub struct DispatchRecord {
252 pub index: u32,
254 pub pipeline_id: String,
256 pub grid: DispatchGrid,
258 pub workgroup_size: WorkgroupSize,
260}
261
262#[allow(dead_code)]
264#[derive(Debug, Default)]
265pub struct DispatchTracker {
266 records: Vec<DispatchRecord>,
267 next_index: u32,
268}
269
270impl DispatchTracker {
271 #[allow(dead_code)]
273 #[must_use]
274 pub fn new() -> Self {
275 Self::default()
276 }
277
278 #[allow(dead_code)]
280 pub fn push(
281 &mut self,
282 pipeline_id: impl Into<String>,
283 grid: DispatchGrid,
284 workgroup_size: WorkgroupSize,
285 ) {
286 self.records.push(DispatchRecord {
287 index: self.next_index,
288 pipeline_id: pipeline_id.into(),
289 grid,
290 workgroup_size,
291 });
292 self.next_index += 1;
293 }
294
295 #[allow(dead_code)]
297 #[must_use]
298 pub fn len(&self) -> usize {
299 self.records.len()
300 }
301
302 #[allow(dead_code)]
304 #[must_use]
305 pub fn is_empty(&self) -> bool {
306 self.records.is_empty()
307 }
308
309 #[allow(dead_code)]
311 #[must_use]
312 pub fn total_threads(&self) -> u64 {
313 self.records
314 .iter()
315 .map(|r| r.grid.total_threads(r.workgroup_size))
316 .sum()
317 }
318
319 #[allow(dead_code)]
321 #[must_use]
322 pub fn records(&self) -> &[DispatchRecord] {
323 &self.records
324 }
325
326 #[allow(dead_code)]
328 pub fn reset(&mut self) {
329 self.records.clear();
330 self.next_index = 0;
331 }
332}
333
334#[derive(Debug, Clone, Copy, PartialEq, Eq)]
341pub enum DataDispatchStrategy {
342 Linear1D,
344 Square2D,
346 FixedRowCount {
348 rows: u32,
350 },
351}
352
353pub struct DataDrivenDispatch {
360 wg_x: u32,
362 wg_y: u32,
364 strategy: DataDispatchStrategy,
365 grid: Option<DispatchGrid>,
367 last_element_count: u64,
369}
370
371impl DataDrivenDispatch {
372 #[must_use]
377 pub fn new(wg_x: u32, wg_y: u32, strategy: DataDispatchStrategy) -> Self {
378 let wg_x = wg_x.max(1);
379 let wg_y = wg_y.max(1);
380 Self {
381 wg_x,
382 wg_y,
383 strategy,
384 grid: None,
385 last_element_count: 0,
386 }
387 }
388
389 #[must_use]
392 pub fn linear(wg_size: u32) -> Self {
393 Self::new(wg_size, 1, DataDispatchStrategy::Linear1D)
394 }
395
396 #[must_use]
399 pub fn square(wg_x: u32, wg_y: u32) -> Self {
400 Self::new(wg_x, wg_y, DataDispatchStrategy::Square2D)
401 }
402
403 pub fn prepare(&mut self, element_count: u64) -> DispatchGrid {
408 self.last_element_count = element_count;
409 let n = element_count as u32;
410 let grid = match self.strategy {
411 DataDispatchStrategy::Linear1D => {
412 let x = n.div_ceil(self.wg_x);
413 DispatchGrid::new(x.max(1), 1, 1)
414 }
415 DataDispatchStrategy::Square2D => {
416 let threads_per_wg = self.wg_x * self.wg_y;
417 let total_wgs = n.div_ceil(threads_per_wg).max(1);
418 let side = (total_wgs as f64).sqrt().ceil() as u32;
419 let side = side.max(1);
420 DispatchGrid::new(side, side, 1)
421 }
422 DataDispatchStrategy::FixedRowCount { rows } => {
423 let rows = rows.max(1);
424 let total_wgs = n.div_ceil(self.wg_x * self.wg_y).max(1);
427 let cols = total_wgs.div_ceil(rows);
428 DispatchGrid::new(cols, rows, 1)
429 }
430 };
431 self.grid = Some(grid);
432 grid
433 }
434
435 #[must_use]
438 pub fn grid(&self) -> Option<DispatchGrid> {
439 self.grid
440 }
441
442 #[must_use]
444 pub fn last_element_count(&self) -> u64 {
445 self.last_element_count
446 }
447
448 #[must_use]
452 pub fn covered_elements(&self) -> u64 {
453 match self.grid {
454 None => 0,
455 Some(g) => {
456 u64::from(g.total_workgroups()) * u64::from(self.wg_x) * u64::from(self.wg_y)
457 }
458 }
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_workgroup_thread_count() {
        assert_eq!(WorkgroupSize::new(8, 8, 1).thread_count(), 64);
    }

    #[test]
    fn test_workgroup_is_valid() {
        let limit = 1024;
        assert!(WorkgroupSize::linear(64).is_valid(limit));
        // 33 * 33 = 1089 threads, just over the limit.
        assert!(!WorkgroupSize::new(33, 33, 1).is_valid(limit));
    }

    #[test]
    fn test_dispatch_1d_exact() {
        assert_eq!(dispatch_1d(256, 64), DispatchGrid::new(4, 1, 1));
    }

    #[test]
    fn test_dispatch_1d_rounds_up() {
        assert_eq!(dispatch_1d(257, 64).x, 5);
    }

    #[test]
    fn test_dispatch_2d() {
        // 1080 / 16 = 67.5, so the y axis rounds up to 68.
        let grid = dispatch_2d(1920, 1080, 16, 16);
        assert_eq!((grid.x, grid.y), (120, 68));
    }

    #[test]
    fn test_dispatch_3d() {
        assert_eq!(dispatch_3d(8, 8, 8, 4, 4, 4), DispatchGrid::new(2, 2, 2));
    }

    #[test]
    fn test_total_workgroups() {
        assert_eq!(DispatchGrid::new(4, 4, 1).total_workgroups(), 16);
    }

    #[test]
    fn test_total_threads() {
        let threads = DispatchGrid::new(2, 2, 1).total_threads(WorkgroupSize::planar(8, 8));
        assert_eq!(threads, 256);
    }

    #[test]
    fn test_recommend_2d_workgroup_within_limit() {
        assert!(recommend_2d_workgroup(256).thread_count() <= 256);
    }

    #[test]
    fn test_recommend_2d_workgroup_square() {
        let WorkgroupSize { x, y, .. } = recommend_2d_workgroup(1024);
        assert_eq!(x, y);
    }

    #[test]
    fn test_barrier_tracker_push_and_count() {
        let mut tracker = BarrierTracker::new();
        tracker.push(BarrierKind::MemoryReadAfterWrite, Some("pre-blur"));
        tracker.push(BarrierKind::Full, None);
        assert_eq!(tracker.len(), 2);
        assert_eq!(tracker.count_of_kind(BarrierKind::Full), 1);
    }

    #[test]
    fn test_barrier_tracker_reset() {
        let mut tracker = BarrierTracker::new();
        tracker.push(BarrierKind::ExecutionOnly, None);
        tracker.reset();
        assert!(tracker.is_empty());
    }

    #[test]
    fn test_dispatch_tracker_total_threads() {
        let mut tracker = DispatchTracker::new();
        tracker.push("blur", DispatchGrid::new(10, 10, 1), WorkgroupSize::planar(8, 8));
        // 100 workgroups * 64 threads each.
        assert_eq!(tracker.total_threads(), 6400);
    }

    #[test]
    fn test_dispatch_tracker_records_sequential_indices() {
        let mut tracker = DispatchTracker::new();
        for id in ["a", "b"] {
            tracker.push(id, DispatchGrid::new(1, 1, 1), WorkgroupSize::linear(64));
        }
        let indices: Vec<u32> = tracker.records().iter().map(|r| r.index).collect();
        assert_eq!(indices, [0, 1]);
    }

    #[test]
    fn test_dispatch_tracker_reset() {
        let mut tracker = DispatchTracker::new();
        tracker.push("x", DispatchGrid::new(1, 1, 1), WorkgroupSize::linear(32));
        tracker.reset();
        assert!(tracker.is_empty());
        assert_eq!(tracker.total_threads(), 0);
    }

    #[test]
    fn test_data_driven_linear_exact() {
        let mut dispatch = DataDrivenDispatch::linear(64);
        assert_eq!(dispatch.prepare(128), DispatchGrid::new(2, 1, 1));
    }

    #[test]
    fn test_data_driven_linear_rounds_up() {
        assert_eq!(DataDrivenDispatch::linear(64).prepare(65).x, 2);
    }

    #[test]
    fn test_data_driven_linear_zero_elements() {
        // Even an empty input dispatches one workgroup.
        assert_eq!(DataDrivenDispatch::linear(64).prepare(0).x, 1);
    }

    #[test]
    fn test_data_driven_square_covers_all_elements() {
        let mut dispatch = DataDrivenDispatch::square(8, 8);
        let _ = dispatch.prepare(500);
        assert!(dispatch.covered_elements() >= 500);
    }

    #[test]
    fn test_data_driven_square_grid_is_square() {
        let grid = DataDrivenDispatch::square(8, 8).prepare(1024);
        assert_eq!(grid.x, grid.y);
    }

    #[test]
    fn test_data_driven_fixed_row_count() {
        let strategy = DataDispatchStrategy::FixedRowCount { rows: 4 };
        let grid = DataDrivenDispatch::new(8, 1, strategy).prepare(256);
        // 256 elements / 8 per workgroup = 32 workgroups spread over 4 rows.
        assert_eq!((grid.x, grid.y), (8, 4));
    }

    #[test]
    fn test_data_driven_grid_none_before_prepare() {
        let dispatch = DataDrivenDispatch::linear(32);
        assert_eq!(dispatch.grid(), None);
        assert_eq!(dispatch.covered_elements(), 0);
    }

    #[test]
    fn test_data_driven_last_element_count_stored() {
        let mut dispatch = DataDrivenDispatch::linear(16);
        let _ = dispatch.prepare(999);
        assert_eq!(dispatch.last_element_count(), 999);
    }

    #[test]
    fn test_data_driven_covered_elements_gte_last_count() {
        const COUNT: u64 = 137;
        let mut dispatch = DataDrivenDispatch::square(4, 4);
        let _ = dispatch.prepare(COUNT);
        assert!(dispatch.covered_elements() >= COUNT);
    }
}