1#![allow(dead_code)]
2use std::fmt;
10
/// Workgroup counts for an indirect compute dispatch.
///
/// The struct is `#[repr(C)]` with three consecutive `u32` fields in `x`, `y`, `z`
/// order, which is also the order [`IndirectDispatchArgs::to_bytes`] writes them in
/// (each as a little-endian `u32`, 12 bytes total).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(C)]
pub struct IndirectDispatchArgs {
    /// Workgroup count along the X dimension.
    pub x: u32,
    /// Workgroup count along the Y dimension.
    pub y: u32,
    /// Workgroup count along the Z dimension.
    pub z: u32,
}
23
24impl IndirectDispatchArgs {
25 pub fn new(x: u32, y: u32, z: u32) -> Self {
27 Self { x, y, z }
28 }
29
30 pub fn one_d(x: u32) -> Self {
32 Self { x, y: 1, z: 1 }
33 }
34
35 pub fn two_d(x: u32, y: u32) -> Self {
37 Self { x, y, z: 1 }
38 }
39
40 pub fn total_workgroups(&self) -> u64 {
42 u64::from(self.x) * u64::from(self.y) * u64::from(self.z)
43 }
44
45 pub fn to_bytes(&self) -> [u8; 12] {
47 let mut buf = [0u8; 12];
48 buf[0..4].copy_from_slice(&self.x.to_le_bytes());
49 buf[4..8].copy_from_slice(&self.y.to_le_bytes());
50 buf[8..12].copy_from_slice(&self.z.to_le_bytes());
51 buf
52 }
53
54 pub fn from_bytes(bytes: &[u8]) -> Option<Self> {
58 if bytes.len() < 12 {
59 return None;
60 }
61 let x = u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
62 let y = u32::from_le_bytes([bytes[4], bytes[5], bytes[6], bytes[7]]);
63 let z = u32::from_le_bytes([bytes[8], bytes[9], bytes[10], bytes[11]]);
64 Some(Self { x, y, z })
65 }
66
67 pub fn is_valid(&self) -> bool {
69 self.x > 0 && self.y > 0 && self.z > 0
70 }
71}
72
73impl Default for IndirectDispatchArgs {
74 fn default() -> Self {
75 Self { x: 1, y: 1, z: 1 }
76 }
77}
78
79impl fmt::Display for IndirectDispatchArgs {
80 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81 write!(f, "Dispatch({}x{}x{})", self.x, self.y, self.z)
82 }
83}
84
/// Selects how [`compute_dispatch`] turns a problem size into workgroup counts.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DispatchStrategy {
    /// One-dimensional: the element count passed to `compute_dispatch` sets X.
    Linear,
    /// Two-dimensional: X and Y come from the extents stored in the variant.
    Tiled2D {
        /// Extent along X, divided (rounding up) by the workgroup size.
        tile_w: u32,
        /// Extent along Y, divided (rounding up) by the workgroup size.
        tile_h: u32,
    },
    /// Three-dimensional: X, Y and Z come from the extents stored in the variant.
    Volumetric {
        /// Extent along X, divided (rounding up) by the workgroup size.
        vol_w: u32,
        /// Extent along Y, divided (rounding up) by the workgroup size.
        vol_h: u32,
        /// Extent along Z, divided (rounding up) by the workgroup size.
        vol_d: u32,
    },
}
108
109#[allow(clippy::cast_precision_loss)]
111pub fn compute_dispatch(
112 element_count: u32,
113 workgroup_size: u32,
114 strategy: DispatchStrategy,
115) -> IndirectDispatchArgs {
116 match strategy {
117 DispatchStrategy::Linear => {
118 let groups = (element_count + workgroup_size - 1) / workgroup_size;
119 IndirectDispatchArgs::one_d(groups)
120 }
121 DispatchStrategy::Tiled2D { tile_w, tile_h } => {
122 let gx = (tile_w + workgroup_size - 1) / workgroup_size;
123 let gy = (tile_h + workgroup_size - 1) / workgroup_size;
124 IndirectDispatchArgs::two_d(gx, gy)
125 }
126 DispatchStrategy::Volumetric {
127 vol_w,
128 vol_h,
129 vol_d,
130 } => {
131 let gx = (vol_w + workgroup_size - 1) / workgroup_size;
132 let gy = (vol_h + workgroup_size - 1) / workgroup_size;
133 let gz = (vol_d + workgroup_size - 1) / workgroup_size;
134 IndirectDispatchArgs::new(gx, gy, gz)
135 }
136 }
137}
138
139pub struct IndirectBuffer {
145 args: IndirectDispatchArgs,
147 label: String,
149 generation: u64,
151}
152
153impl IndirectBuffer {
154 pub fn new(label: &str) -> Self {
156 Self {
157 args: IndirectDispatchArgs::default(),
158 label: label.to_string(),
159 generation: 0,
160 }
161 }
162
163 pub fn with_args(label: &str, args: IndirectDispatchArgs) -> Self {
165 Self {
166 args,
167 label: label.to_string(),
168 generation: 0,
169 }
170 }
171
172 pub fn update(&mut self, args: IndirectDispatchArgs) {
174 self.args = args;
175 self.generation += 1;
176 }
177
178 pub fn args(&self) -> IndirectDispatchArgs {
180 self.args
181 }
182
183 pub fn label(&self) -> &str {
185 &self.label
186 }
187
188 pub fn generation(&self) -> u64 {
190 self.generation
191 }
192
193 pub fn size_bytes(&self) -> usize {
195 12
196 }
197
198 pub fn to_bytes(&self) -> [u8; 12] {
200 self.args.to_bytes()
201 }
202}
203
204pub fn validate_dispatch_limits(
206 args: &IndirectDispatchArgs,
207 max_per_dimension: u32,
208) -> Result<(), String> {
209 if args.x > max_per_dimension {
210 return Err(format!(
211 "X workgroup count {} exceeds limit {}",
212 args.x, max_per_dimension
213 ));
214 }
215 if args.y > max_per_dimension {
216 return Err(format!(
217 "Y workgroup count {} exceeds limit {}",
218 args.y, max_per_dimension
219 ));
220 }
221 if args.z > max_per_dimension {
222 return Err(format!(
223 "Z workgroup count {} exceeds limit {}",
224 args.z, max_per_dimension
225 ));
226 }
227 Ok(())
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatch_args_new() {
        let args = IndirectDispatchArgs::new(4, 8, 2);
        assert_eq!(args.x, 4);
        assert_eq!(args.y, 8);
        assert_eq!(args.z, 2);
    }

    #[test]
    fn test_dispatch_args_one_d() {
        let args = IndirectDispatchArgs::one_d(16);
        assert_eq!(args.x, 16);
        assert_eq!(args.y, 1);
        assert_eq!(args.z, 1);
    }

    #[test]
    fn test_dispatch_args_two_d() {
        let args = IndirectDispatchArgs::two_d(3, 7);
        assert_eq!(args, IndirectDispatchArgs::new(3, 7, 1));
    }

    #[test]
    fn test_total_workgroups() {
        let args = IndirectDispatchArgs::new(4, 8, 2);
        assert_eq!(args.total_workgroups(), 64);
    }

    #[test]
    fn test_to_from_bytes_roundtrip() {
        let original = IndirectDispatchArgs::new(123, 456, 789);
        let bytes = original.to_bytes();
        let restored = IndirectDispatchArgs::from_bytes(&bytes)
            .expect("deserialization from bytes should succeed");
        assert_eq!(original, restored);
    }

    #[test]
    fn test_to_bytes_is_little_endian_xyz() {
        let bytes = IndirectDispatchArgs::new(1, 2, 0x0102_0304).to_bytes();
        assert_eq!(bytes, [1, 0, 0, 0, 2, 0, 0, 0, 4, 3, 2, 1]);
    }

    #[test]
    fn test_from_bytes_too_short() {
        assert!(IndirectDispatchArgs::from_bytes(&[0u8; 8]).is_none());
        assert!(IndirectDispatchArgs::from_bytes(&[0u8; 11]).is_none());
        assert!(IndirectDispatchArgs::from_bytes(&[]).is_none());
    }

    #[test]
    fn test_from_bytes_ignores_trailing_bytes() {
        let original = IndirectDispatchArgs::new(5, 6, 7);
        let mut bytes = original.to_bytes().to_vec();
        bytes.extend_from_slice(&[0xFF; 4]);
        assert_eq!(IndirectDispatchArgs::from_bytes(&bytes), Some(original));
    }

    #[test]
    fn test_is_valid() {
        assert!(IndirectDispatchArgs::new(1, 1, 1).is_valid());
        assert!(!IndirectDispatchArgs::new(0, 1, 1).is_valid());
        assert!(!IndirectDispatchArgs::new(1, 0, 1).is_valid());
        assert!(!IndirectDispatchArgs::new(1, 1, 0).is_valid());
    }

    #[test]
    fn test_display() {
        let args = IndirectDispatchArgs::new(4, 8, 2);
        assert_eq!(format!("{args}"), "Dispatch(4x8x2)");
    }

    #[test]
    fn test_compute_dispatch_linear() {
        let args = compute_dispatch(1000, 64, DispatchStrategy::Linear);
        assert_eq!(args.x, 16);
        assert_eq!(args.y, 1);
        assert_eq!(args.z, 1);
    }

    #[test]
    fn test_compute_dispatch_linear_boundaries() {
        assert_eq!(compute_dispatch(0, 64, DispatchStrategy::Linear).x, 0);
        assert_eq!(compute_dispatch(64, 64, DispatchStrategy::Linear).x, 1);
        assert_eq!(compute_dispatch(65, 64, DispatchStrategy::Linear).x, 2);
    }

    #[test]
    fn test_compute_dispatch_tiled() {
        let args = compute_dispatch(
            0,
            16,
            DispatchStrategy::Tiled2D {
                tile_w: 1920,
                tile_h: 1080,
            },
        );
        assert_eq!(args.x, 120);
        assert_eq!(args.y, 68);
        assert_eq!(args.z, 1);
    }

    #[test]
    fn test_compute_dispatch_volumetric() {
        let args = compute_dispatch(
            0,
            8,
            DispatchStrategy::Volumetric {
                vol_w: 64,
                vol_h: 64,
                vol_d: 32,
            },
        );
        assert_eq!(args.x, 8);
        assert_eq!(args.y, 8);
        assert_eq!(args.z, 4);
    }

    #[test]
    fn test_indirect_buffer_new() {
        let buf = IndirectBuffer::new("test_buf");
        assert_eq!(buf.label(), "test_buf");
        assert_eq!(buf.args(), IndirectDispatchArgs::default());
        assert_eq!(buf.generation(), 0);
        assert_eq!(buf.size_bytes(), 12);
    }

    #[test]
    fn test_indirect_buffer_with_args() {
        let args = IndirectDispatchArgs::new(2, 3, 4);
        let buf = IndirectBuffer::with_args("with_args", args);
        assert_eq!(buf.label(), "with_args");
        assert_eq!(buf.args(), args);
        assert_eq!(buf.generation(), 0);
        assert_eq!(buf.to_bytes(), args.to_bytes());
    }

    #[test]
    fn test_indirect_buffer_update() {
        let mut buf = IndirectBuffer::new("buf");
        buf.update(IndirectDispatchArgs::new(10, 20, 30));
        assert_eq!(buf.args().x, 10);
        assert_eq!(buf.generation(), 1);
        buf.update(IndirectDispatchArgs::one_d(5));
        assert_eq!(buf.generation(), 2);
    }

    #[test]
    fn test_validate_dispatch_limits_ok() {
        let args = IndirectDispatchArgs::new(100, 100, 100);
        assert!(validate_dispatch_limits(&args, 65535).is_ok());
    }

    #[test]
    fn test_validate_dispatch_limits_at_limit_is_ok() {
        let args = IndirectDispatchArgs::new(65535, 65535, 65535);
        assert!(validate_dispatch_limits(&args, 65535).is_ok());
    }

    #[test]
    fn test_validate_dispatch_limits_exceeded() {
        let args = IndirectDispatchArgs::new(70000, 1, 1);
        assert!(validate_dispatch_limits(&args, 65535).is_err());
    }

    #[test]
    fn test_validate_dispatch_limits_reports_failing_axis() {
        let err = validate_dispatch_limits(&IndirectDispatchArgs::new(1, 70000, 1), 65535)
            .unwrap_err();
        assert_eq!(err, "Y workgroup count 70000 exceeds limit 65535");
        let err = validate_dispatch_limits(&IndirectDispatchArgs::new(1, 1, 70000), 65535)
            .unwrap_err();
        assert_eq!(err, "Z workgroup count 70000 exceeds limit 65535");
    }

    #[test]
    fn test_default_dispatch_args() {
        let args = IndirectDispatchArgs::default();
        assert_eq!(args.x, 1);
        assert_eq!(args.y, 1);
        assert_eq!(args.z, 1);
        assert!(args.is_valid());
    }
}