/// Upper limit used for workgroup dimensions (256).
///
/// NOTE(review): nothing in the visible part of this file reads this constant,
/// so whether it bounds a single axis or something else should be confirmed
/// against its callers.
pub const MAX_WORKGROUP_DIM: u32 = 1 << 8; // 256
8
/// Extent of a single workgroup along the x, y and z axes.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct WorkgroupSize {
    /// Extent along the x axis.
    pub x: u32,
    /// Extent along the y axis.
    pub y: u32,
    /// Extent along the z axis.
    pub z: u32,
}

impl WorkgroupSize {
    /// Creates a one-dimensional workgroup: `x` along x, with y and z set to 1.
    #[allow(dead_code)]
    #[must_use]
    pub const fn linear(x: u32) -> Self {
        Self { x, y: 1, z: 1 }
    }

    /// Creates a two-dimensional workgroup: `x` by `y`, with z set to 1.
    #[allow(dead_code)]
    #[must_use]
    pub const fn planar(x: u32, y: u32) -> Self {
        Self { x, y, z: 1 }
    }

    /// Creates a workgroup with explicit extents on all three axes.
    ///
    /// No validation is performed here; use [`WorkgroupSize::is_valid`] to
    /// check the result against a thread limit.
    #[allow(dead_code)]
    #[must_use]
    pub const fn new(x: u32, y: u32, z: u32) -> Self {
        Self { x, y, z }
    }

    /// Returns `x * y * z`, the number of threads in one workgroup.
    ///
    /// The product saturates at `u32::MAX` instead of overflowing, so the
    /// result is the same in debug and release builds and never panics.
    #[allow(dead_code)]
    #[must_use]
    pub const fn thread_count(self) -> u32 {
        self.x.saturating_mul(self.y).saturating_mul(self.z)
    }

    /// Returns `true` when every axis is at least 1 and the total thread
    /// count does not exceed `max_threads`.
    ///
    /// Sizes whose thread count does not even fit in a `u32` are reported as
    /// invalid rather than panicking or wrapping around.
    #[allow(dead_code)]
    #[must_use]
    pub fn is_valid(self, max_threads: u32) -> bool {
        if self.x == 0 || self.y == 0 || self.z == 0 {
            return false;
        }
        // Checked (not saturating) arithmetic: an overflowing product must be
        // rejected even when `max_threads` is `u32::MAX`.
        self.x
            .checked_mul(self.y)
            .and_then(|xy| xy.checked_mul(self.z))
            .is_some_and(|threads| threads <= max_threads)
    }
}

impl Default for WorkgroupSize {
    /// Returns a linear workgroup of 64 threads.
    fn default() -> Self {
        Self::linear(64)
    }
}
61
62#[allow(dead_code)]
64#[derive(Debug, Clone, Copy, PartialEq, Eq)]
65pub struct DispatchGrid {
66 pub x: u32,
67 pub y: u32,
68 pub z: u32,
69}
70
71impl DispatchGrid {
72 #[allow(dead_code)]
74 #[must_use]
75 pub const fn new(x: u32, y: u32, z: u32) -> Self {
76 Self { x, y, z }
77 }
78
79 #[allow(dead_code)]
81 #[must_use]
82 pub const fn total_workgroups(self) -> u64 {
83 self.x as u64 * self.y as u64 * self.z as u64
84 }
85
86 #[allow(dead_code)]
88 #[must_use]
89 pub const fn total_threads(self, wg: WorkgroupSize) -> u64 {
90 self.total_workgroups() * wg.thread_count() as u64
91 }
92}
93
94#[allow(dead_code)]
97#[must_use]
98pub fn dispatch_1d(count: u32, wg_size: u32) -> DispatchGrid {
99 assert!(wg_size > 0, "wg_size must be > 0");
100 let x = count.div_ceil(wg_size);
101 DispatchGrid::new(x, 1, 1)
102}
103
104#[allow(dead_code)]
107#[must_use]
108pub fn dispatch_2d(width: u32, height: u32, wg_x: u32, wg_y: u32) -> DispatchGrid {
109 assert!(wg_x > 0 && wg_y > 0, "workgroup dims must be > 0");
110 let x = width.div_ceil(wg_x);
111 let y = height.div_ceil(wg_y);
112 DispatchGrid::new(x, y, 1)
113}
114
115#[allow(dead_code)]
117#[must_use]
118pub fn dispatch_3d(
119 width: u32,
120 height: u32,
121 depth: u32,
122 wg_x: u32,
123 wg_y: u32,
124 wg_z: u32,
125) -> DispatchGrid {
126 assert!(
127 wg_x > 0 && wg_y > 0 && wg_z > 0,
128 "workgroup dims must be > 0"
129 );
130 DispatchGrid::new(
131 width.div_ceil(wg_x),
132 height.div_ceil(wg_y),
133 depth.div_ceil(wg_z),
134 )
135}
136
137#[allow(dead_code)]
140#[must_use]
141pub fn recommend_2d_workgroup(max_threads: u32) -> WorkgroupSize {
142 let mut side = 1u32;
143 while side * side * 4 <= max_threads {
144 side *= 2;
145 }
146 while side * side > max_threads {
148 side /= 2;
149 }
150 WorkgroupSize::planar(side.max(1), side.max(1))
151}
152
/// Category of a barrier recorded by [`BarrierTracker`].
///
/// NOTE(review): the precise synchronization semantics of each variant are
/// not established anywhere in this file; the variant notes below are read
/// off the names and should be confirmed.
#[allow(dead_code)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BarrierKind {
    /// Presumably a memory barrier ordering a read after an earlier write.
    MemoryReadAfterWrite,
    /// Presumably an execution-only barrier with no memory ordering.
    ExecutionOnly,
    /// Presumably a combined execution and memory barrier.
    Full,
}

/// One barrier entry stored by [`BarrierTracker`].
#[allow(dead_code)]
#[derive(Debug, Clone)]
pub struct BarrierRecord {
    /// Sequence number assigned by the tracker; starts at 0 and restarts
    /// after [`BarrierTracker::reset`].
    pub index: u32,
    /// Kind of barrier that was recorded.
    pub kind: BarrierKind,
    /// Optional caller-supplied label.
    pub label: Option<String>,
}

/// Append-only log of barriers, each stamped with a sequential index.
#[allow(dead_code)]
#[derive(Debug, Default)]
pub struct BarrierTracker {
    records: Vec<BarrierRecord>,
    next_index: u32,
}

impl BarrierTracker {
    /// Creates an empty tracker whose first record will get index 0.
    #[allow(dead_code)]
    #[must_use]
    pub fn new() -> Self {
        Self {
            records: Vec::new(),
            next_index: 0,
        }
    }

    /// Appends a barrier of the given `kind`, copying `label` if present,
    /// and advances the running index.
    #[allow(dead_code)]
    pub fn push(&mut self, kind: BarrierKind, label: Option<&str>) {
        let index = self.next_index;
        let record = BarrierRecord {
            index,
            kind,
            label: label.map(str::to_owned),
        };
        self.records.push(record);
        self.next_index += 1;
    }

    /// Number of barriers recorded since creation or the last reset.
    #[allow(dead_code)]
    #[must_use]
    pub fn len(&self) -> usize {
        self.records.len()
    }

    /// Returns `true` when no barriers are currently recorded.
    #[allow(dead_code)]
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.records.is_empty()
    }

    /// Recorded barriers in insertion order.
    #[allow(dead_code)]
    #[must_use]
    pub fn records(&self) -> &[BarrierRecord] {
        self.records.as_slice()
    }

    /// Counts the recorded barriers whose kind equals `kind`.
    #[allow(dead_code)]
    #[must_use]
    pub fn count_of_kind(&self, kind: BarrierKind) -> usize {
        self.records
            .iter()
            .map(|record| record.kind)
            .filter(|&recorded| recorded == kind)
            .count()
    }

    /// Removes every record and restarts indexing at 0.
    #[allow(dead_code)]
    pub fn reset(&mut self) {
        self.next_index = 0;
        self.records.clear();
    }
}
243
244#[allow(dead_code)]
250#[derive(Debug, Clone)]
251pub struct DispatchRecord {
252 pub index: u32,
254 pub pipeline_id: String,
256 pub grid: DispatchGrid,
258 pub workgroup_size: WorkgroupSize,
260}
261
262#[allow(dead_code)]
264#[derive(Debug, Default)]
265pub struct DispatchTracker {
266 records: Vec<DispatchRecord>,
267 next_index: u32,
268}
269
270impl DispatchTracker {
271 #[allow(dead_code)]
273 #[must_use]
274 pub fn new() -> Self {
275 Self::default()
276 }
277
278 #[allow(dead_code)]
280 pub fn push(
281 &mut self,
282 pipeline_id: impl Into<String>,
283 grid: DispatchGrid,
284 workgroup_size: WorkgroupSize,
285 ) {
286 self.records.push(DispatchRecord {
287 index: self.next_index,
288 pipeline_id: pipeline_id.into(),
289 grid,
290 workgroup_size,
291 });
292 self.next_index += 1;
293 }
294
295 #[allow(dead_code)]
297 #[must_use]
298 pub fn len(&self) -> usize {
299 self.records.len()
300 }
301
302 #[allow(dead_code)]
304 #[must_use]
305 pub fn is_empty(&self) -> bool {
306 self.records.is_empty()
307 }
308
309 #[allow(dead_code)]
311 #[must_use]
312 pub fn total_threads(&self) -> u64 {
313 self.records
314 .iter()
315 .map(|r| r.grid.total_threads(r.workgroup_size))
316 .sum()
317 }
318
319 #[allow(dead_code)]
321 #[must_use]
322 pub fn records(&self) -> &[DispatchRecord] {
323 &self.records
324 }
325
326 #[allow(dead_code)]
328 pub fn reset(&mut self) {
329 self.records.clear();
330 self.next_index = 0;
331 }
332}
333
/// Unit tests for the workgroup, dispatch-grid and tracker helpers.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_workgroup_thread_count() {
        // 8 * 8 * 1
        assert_eq!(WorkgroupSize::new(8, 8, 1).thread_count(), 64);
    }

    #[test]
    fn test_workgroup_is_valid() {
        let fits = WorkgroupSize::linear(64);
        let too_large = WorkgroupSize::new(33, 33, 1); // 1089 threads
        assert!(fits.is_valid(1024));
        assert!(!too_large.is_valid(1024));
    }

    #[test]
    fn test_dispatch_1d_exact() {
        let grid = dispatch_1d(256, 64);
        assert_eq!((grid.x, grid.y, grid.z), (4, 1, 1));
    }

    #[test]
    fn test_dispatch_1d_rounds_up() {
        // 257 items need a fifth, partially filled workgroup.
        assert_eq!(dispatch_1d(257, 64).x, 5);
    }

    #[test]
    fn test_dispatch_2d() {
        let grid = dispatch_2d(1920, 1080, 16, 16);
        // 1920 / 16 = 120 exactly; 1080 / 16 = 67.5, rounded up to 68.
        assert_eq!((grid.x, grid.y), (120, 68));
    }

    #[test]
    fn test_dispatch_3d() {
        let grid = dispatch_3d(8, 8, 8, 4, 4, 4);
        assert_eq!((grid.x, grid.y, grid.z), (2, 2, 2));
    }

    #[test]
    fn test_total_workgroups() {
        assert_eq!(DispatchGrid::new(4, 4, 1).total_workgroups(), 16);
    }

    #[test]
    fn test_total_threads() {
        let grid = DispatchGrid::new(2, 2, 1);
        let size = WorkgroupSize::planar(8, 8);
        // 4 workgroups * 64 threads each.
        assert_eq!(grid.total_threads(size), 256);
    }

    #[test]
    fn test_recommend_2d_workgroup_within_limit() {
        let size = recommend_2d_workgroup(256);
        assert!(size.thread_count() <= 256);
    }

    #[test]
    fn test_recommend_2d_workgroup_square() {
        let size = recommend_2d_workgroup(1024);
        assert_eq!(size.x, size.y);
    }

    #[test]
    fn test_barrier_tracker_push_and_count() {
        let mut tracker = BarrierTracker::new();
        tracker.push(BarrierKind::MemoryReadAfterWrite, Some("pre-blur"));
        tracker.push(BarrierKind::Full, None);
        assert_eq!(tracker.len(), 2);
        assert_eq!(tracker.count_of_kind(BarrierKind::Full), 1);
    }

    #[test]
    fn test_barrier_tracker_reset() {
        let mut tracker = BarrierTracker::new();
        tracker.push(BarrierKind::ExecutionOnly, None);
        tracker.reset();
        assert!(tracker.is_empty());
    }

    #[test]
    fn test_dispatch_tracker_total_threads() {
        let mut tracker = DispatchTracker::new();
        tracker.push(
            "blur",
            DispatchGrid::new(10, 10, 1),
            WorkgroupSize::planar(8, 8),
        );
        // 100 workgroups * 64 threads each.
        assert_eq!(tracker.total_threads(), 6400);
    }

    #[test]
    fn test_dispatch_tracker_records_sequential_indices() {
        let mut tracker = DispatchTracker::new();
        let grid = DispatchGrid::new(1, 1, 1);
        let size = WorkgroupSize::linear(64);
        tracker.push("a", grid, size);
        tracker.push("b", grid, size);
        let recorded = tracker.records();
        assert_eq!(recorded[0].index, 0);
        assert_eq!(recorded[1].index, 1);
    }

    #[test]
    fn test_dispatch_tracker_reset() {
        let mut tracker = DispatchTracker::new();
        tracker.push("x", DispatchGrid::new(1, 1, 1), WorkgroupSize::linear(32));
        tracker.reset();
        assert!(tracker.is_empty());
        assert_eq!(tracker.total_threads(), 0);
    }
}