scry_gpu/dispatch.rs
1// SPDX-License-Identifier: MIT OR Apache-2.0
2//! Shader dispatch configuration and execution.
3
/// Configuration for a compute dispatch.
///
/// The simple path ([`Device::dispatch`]) covers most cases.
/// Use `DispatchConfig` when you need control over workgroup sizes
/// or push constants.
///
/// Every field is either a plain value or a borrow tied to `'a`, so the
/// config owns no heap data and can be copied, compared and debug-printed
/// freely.
///
/// [`Device::dispatch`]: crate::Device::dispatch
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct DispatchConfig<'a> {
    /// Shader source (WGSL).
    pub shader: &'a str,

    /// Entry point name. Defaults to `"main"` if `None`.
    pub entry_point: Option<&'a str>,

    /// Workgroup dimensions `[x, y, z]`.
    ///
    /// If `None`, the crate auto-calculates from `invocations` and the
    /// shader's declared `@workgroup_size`.
    pub workgroups: Option<[u32; 3]>,

    /// Total invocations requested. Used to auto-calculate workgroup
    /// dispatch count when `workgroups` is `None`.
    pub invocations: u32,

    /// Optional push constant data (raw bytes, must match shader layout).
    pub push_constants: Option<&'a [u8]>,
}
31
32impl<'a> DispatchConfig<'a> {
33 /// Create a minimal dispatch config.
34 pub const fn new(shader: &'a str, invocations: u32) -> Self {
35 Self {
36 shader,
37 entry_point: None,
38 workgroups: None,
39 invocations,
40 push_constants: None,
41 }
42 }
43
44 /// Override the entry point name (default: `"main"`).
45 pub const fn entry_point(mut self, name: &'a str) -> Self {
46 self.entry_point = Some(name);
47 self
48 }
49
50 /// Set explicit workgroup dispatch dimensions.
51 pub const fn workgroups(mut self, dims: [u32; 3]) -> Self {
52 self.workgroups = Some(dims);
53 self
54 }
55
56 /// Attach push constant data.
57 pub const fn push_constants(mut self, data: &'a [u8]) -> Self {
58 self.push_constants = Some(data);
59 self
60 }
61}
62
63/// Extract `@workgroup_size` from a parsed naga module's entry point.
64///
65/// Returns `[x, y, z]` or a default of `[64, 1, 1]` if the shader
66/// doesn't declare one.
67pub fn extract_workgroup_size(module: &naga::Module, entry: &str) -> [u32; 3] {
68 for ep in &module.entry_points {
69 if ep.name == entry {
70 let s = ep.workgroup_size;
71 return [s[0], s[1], s[2]];
72 }
73 }
74 [64, 1, 1]
75}
76
/// Calculate dispatch dimensions given total invocations and per-workgroup size.
///
/// Computes `ceil(invocations / workgroup_size[0])` for the x axis and clamps
/// the result to the Vulkan `maxComputeWorkGroupCount` limit (65535 per axis).
/// The y and z dispatch counts are always `1`; the y and z components of
/// `workgroup_size` are not consulted.
///
/// A zero x component in `workgroup_size` is treated as `1`, so the division
/// can never panic. Zero `invocations` yields an x count of `0`.
///
/// Because of the clamp, the returned count may cover fewer invocations than
/// requested when `invocations` is very large.
pub fn calc_dispatch(invocations: u32, workgroup_size: [u32; 3]) -> [u32; 3] {
    const MAX_WORKGROUPS_PER_AXIS: u32 = 65_535;

    let [size_x, _, _] = workgroup_size;
    // `div_ceil` panics on a zero divisor; a degenerate size of 0 is handled
    // as one invocation per workgroup instead.
    let groups_x = invocations.div_ceil(size_x.max(1));

    [groups_x.min(MAX_WORKGROUPS_PER_AXIS), 1, 1]
}