astrelis_render/compute.rs
1//! Compute pass management with ergonomic builder pattern.
2//!
3//! This module provides a `ComputePassBuilder` that mirrors the ergonomics of
4//! `RenderPassBuilder` for compute shader operations.
5
6use astrelis_core::profiling::profile_function;
7
8use crate::frame::FrameContext;
9
/// Builder for creating compute passes.
///
/// Mirrors `RenderPassBuilder`: configure the pass with the chained setters,
/// then call [`ComputePassBuilder::build`] to take the frame's command encoder
/// and begin recording.
///
/// # Example
///
/// ```ignore
/// let mut compute_pass = ComputePassBuilder::new()
///     .label("My Compute Pass")
///     .build(&mut frame);
///
/// compute_pass.set_pipeline(&pipeline);
/// compute_pass.set_bind_group(0, &bind_group, &[]);
/// compute_pass.dispatch_workgroups(64, 64, 1);
/// // Dropping `compute_pass` (or calling `finish`) ends the pass and hands the
/// // encoder back to `frame`.
/// ```
#[derive(Debug, Clone, Copy)]
#[must_use = "a ComputePassBuilder does nothing until `build` is called"]
pub struct ComputePassBuilder<'a> {
    /// Optional debug label forwarded to `wgpu::ComputePassDescriptor::label`.
    label: Option<&'a str>,
}
26
27impl<'a> ComputePassBuilder<'a> {
28 /// Create a new compute pass builder.
29 pub fn new() -> Self {
30 Self { label: None }
31 }
32
33 /// Set a debug label for the compute pass.
34 pub fn label(mut self, label: &'a str) -> Self {
35 self.label = Some(label);
36 self
37 }
38
39 /// Build the compute pass.
40 ///
41 /// This takes the command encoder from the FrameContext and returns it
42 /// when the ComputePass is dropped.
43 pub fn build(self, frame_context: &'a mut FrameContext) -> ComputePass<'a> {
44 let mut encoder = frame_context.encoder.take().unwrap();
45
46 let compute_pass = encoder
47 .begin_compute_pass(&wgpu::ComputePassDescriptor {
48 label: self.label,
49 timestamp_writes: None,
50 })
51 .forget_lifetime();
52
53 ComputePass {
54 context: frame_context,
55 encoder: Some(encoder),
56 pass: Some(compute_pass),
57 }
58 }
59}
60
61impl Default for ComputePassBuilder<'_> {
62 fn default() -> Self {
63 Self::new()
64 }
65}
66
/// A compute pass wrapper that automatically returns the encoder to the frame context.
///
/// This struct mirrors `RenderPass` in its lifecycle management: it takes the
/// encoder from `FrameContext` and returns it when dropped.
///
/// `ComputePassBuilder::build` fills both `encoder` and `pass`. `Drop` takes
/// them again in a fixed order: it ends the wgpu pass first, then hands the
/// encoder back to `context`.
pub struct ComputePass<'a> {
    /// Frame whose `encoder` slot is refilled when this pass is dropped.
    pub(crate) context: &'a mut FrameContext,
    /// Encoder taken out of `context` for the duration of the pass.
    pub(crate) encoder: Option<wgpu::CommandEncoder>,
    /// The active wgpu pass. Its borrow of `encoder` was erased with
    /// `forget_lifetime`, so `Drop` is responsible for ending it before the
    /// encoder is returned.
    pub(crate) pass: Option<wgpu::ComputePass<'static>>,
}
76
77impl<'a> ComputePass<'a> {
78 /// Get the underlying wgpu compute pass.
79 pub fn pass(&mut self) -> &mut wgpu::ComputePass<'static> {
80 self.pass.as_mut().unwrap()
81 }
82
83 /// Set the compute pipeline to use.
84 pub fn set_pipeline(&mut self, pipeline: &'a wgpu::ComputePipeline) {
85 self.pass().set_pipeline(pipeline);
86 }
87
88 /// Set a bind group.
89 pub fn set_bind_group(
90 &mut self,
91 index: u32,
92 bind_group: &'a wgpu::BindGroup,
93 offsets: &[u32],
94 ) {
95 self.pass().set_bind_group(index, bind_group, offsets);
96 }
97
98 /// Dispatch workgroups.
99 ///
100 /// # Arguments
101 ///
102 /// * `x` - Number of workgroups in the X dimension
103 /// * `y` - Number of workgroups in the Y dimension
104 /// * `z` - Number of workgroups in the Z dimension
105 pub fn dispatch_workgroups(&mut self, x: u32, y: u32, z: u32) {
106 self.pass().dispatch_workgroups(x, y, z);
107 }
108
109 /// Dispatch workgroups with a 1D configuration.
110 ///
111 /// Equivalent to `dispatch_workgroups(x, 1, 1)`.
112 pub fn dispatch_workgroups_1d(&mut self, x: u32) {
113 self.dispatch_workgroups(x, 1, 1);
114 }
115
116 /// Dispatch workgroups with a 2D configuration.
117 ///
118 /// Equivalent to `dispatch_workgroups(x, y, 1)`.
119 pub fn dispatch_workgroups_2d(&mut self, x: u32, y: u32) {
120 self.dispatch_workgroups(x, y, 1);
121 }
122
123 /// Dispatch workgroups indirectly from a buffer.
124 ///
125 /// The buffer should contain a `DispatchIndirect` struct:
126 /// ```ignore
127 /// #[repr(C)]
128 /// struct DispatchIndirect {
129 /// x: u32,
130 /// y: u32,
131 /// z: u32,
132 /// }
133 /// ```
134 pub fn dispatch_workgroups_indirect(&mut self, buffer: &'a wgpu::Buffer, offset: u64) {
135 self.pass().dispatch_workgroups_indirect(buffer, offset);
136 }
137
138 /// Insert a debug marker.
139 pub fn insert_debug_marker(&mut self, label: &str) {
140 self.pass().insert_debug_marker(label);
141 }
142
143 /// Push a debug group.
144 pub fn push_debug_group(&mut self, label: &str) {
145 self.pass().push_debug_group(label);
146 }
147
148 /// Pop a debug group.
149 pub fn pop_debug_group(&mut self) {
150 self.pass().pop_debug_group();
151 }
152
153 /// Set push constants for the compute shader.
154 ///
155 /// Push constants are a fast way to pass small amounts of data to shaders
156 /// without using uniform buffers. They require the `PUSH_CONSTANTS` feature
157 /// to be enabled on the device.
158 ///
159 /// # Arguments
160 ///
161 /// * `offset` - Byte offset into the push constant range
162 /// * `data` - Data to set (must be `Pod` for safe byte casting)
163 ///
164 /// # Example
165 ///
166 /// ```ignore
167 /// #[repr(C)]
168 /// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
169 /// struct ComputeConstants {
170 /// workgroup_count: u32,
171 /// time: f32,
172 /// }
173 ///
174 /// let constants = ComputeConstants {
175 /// workgroup_count: 64,
176 /// time: 1.5,
177 /// };
178 ///
179 /// pass.set_push_constants(0, &constants);
180 /// ```
181 pub fn set_push_constants<T: bytemuck::Pod>(&mut self, offset: u32, data: &T) {
182 self.pass()
183 .set_push_constants(offset, bytemuck::bytes_of(data));
184 }
185
186 /// Set push constants from raw bytes.
187 ///
188 /// Use this when you need more control over the data layout.
189 pub fn set_push_constants_raw(&mut self, offset: u32, data: &[u8]) {
190 self.pass().set_push_constants(offset, data);
191 }
192
193 /// Finish the compute pass, returning the encoder to the frame context.
194 pub fn finish(self) {
195 drop(self);
196 }
197}
198
199impl Drop for ComputePass<'_> {
200 fn drop(&mut self) {
201 profile_function!();
202
203 // End the compute pass
204 drop(self.pass.take());
205
206 // Return the encoder to the frame context
207 self.context.encoder = self.encoder.take();
208 }
209}
210
/// Arguments for an indirect compute dispatch.
///
/// The struct holds three `u32` workgroup counts in `x`, `y`, `z` order and is
/// marked `#[repr(C)]`. This matches the layout expected by
/// `wgpu::ComputePass::dispatch_workgroups_indirect`.
#[derive(Clone, Copy, Debug, Default, Hash, PartialEq, Eq)]
#[repr(C)]
pub struct DispatchIndirect {
    /// Workgroup count along the X axis.
    pub x: u32,
    /// Workgroup count along the Y axis.
    pub y: u32,
    /// Workgroup count along the Z axis.
    pub z: u32,
}
224
225// SAFETY: DispatchIndirect is a repr(C) struct of u32s with no padding
226unsafe impl bytemuck::Pod for DispatchIndirect {}
227unsafe impl bytemuck::Zeroable for DispatchIndirect {}
228
229impl DispatchIndirect {
230 /// Create a new dispatch command.
231 pub const fn new(x: u32, y: u32, z: u32) -> Self {
232 Self { x, y, z }
233 }
234
235 /// Create a 1D dispatch command.
236 pub const fn new_1d(x: u32) -> Self {
237 Self::new(x, 1, 1)
238 }
239
240 /// Create a 2D dispatch command.
241 pub const fn new_2d(x: u32, y: u32) -> Self {
242 Self::new(x, y, 1)
243 }
244
245 /// Size of the command in bytes.
246 pub const fn size() -> u64 {
247 std::mem::size_of::<Self>() as u64
248 }
249}
250
251/// Helper trait for creating compute passes from FrameContext.
252pub trait ComputePassExt {
253 /// Create a compute pass with a label.
254 fn compute_pass<'a>(&'a mut self, label: &'a str) -> ComputePass<'a>;
255
256 /// Create a compute pass without a label.
257 fn compute_pass_unlabeled(&mut self) -> ComputePass<'_>;
258}
259
260impl ComputePassExt for FrameContext {
261 fn compute_pass<'a>(&'a mut self, label: &'a str) -> ComputePass<'a> {
262 ComputePassBuilder::new().label(label).build(self)
263 }
264
265 fn compute_pass_unlabeled(&mut self) -> ComputePass<'_> {
266 ComputePassBuilder::new().build(self)
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatch_indirect_size() {
        // Three tightly packed u32 counts: 3 * 4 bytes.
        assert_eq!(DispatchIndirect::size(), 3 * 4);
    }

    #[test]
    fn test_dispatch_indirect_1d() {
        let DispatchIndirect { x, y, z } = DispatchIndirect::new_1d(64);
        assert_eq!((x, y, z), (64, 1, 1));
    }

    #[test]
    fn test_dispatch_indirect_2d() {
        let DispatchIndirect { x, y, z } = DispatchIndirect::new_2d(32, 32);
        assert_eq!((x, y, z), (32, 32, 1));
    }
}
295}