astrelis_render/compute.rs
1//! Compute pass management with ergonomic builder pattern.
2//!
3//! This module provides a `ComputePassBuilder` that mirrors the ergonomics of
4//! `RenderPassBuilder` for compute shader operations.
5
6use astrelis_core::profiling::profile_function;
7
8use crate::context::GraphicsContext;
9use crate::frame::Frame;
10
11/// Builder for creating compute passes.
12///
13/// # Example
14///
15/// ```ignore
16/// let mut compute_pass = frame.compute_pass()
17/// .label("My Compute Pass")
18/// .build();
19///
20/// compute_pass.set_pipeline(&pipeline);
21/// compute_pass.set_bind_group(0, &bind_group, &[]);
22/// compute_pass.dispatch_workgroups(64, 64, 1);
23/// ```
24pub struct ComputePassBuilder<'f, 'w> {
25 frame: &'f Frame<'w>,
26 label: Option<String>,
27}
28
29impl<'f, 'w> ComputePassBuilder<'f, 'w> {
30 /// Create a new compute pass builder.
31 pub fn new(frame: &'f Frame<'w>) -> Self {
32 Self { frame, label: None }
33 }
34
35 /// Set a debug label for the compute pass.
36 pub fn label(mut self, label: impl Into<String>) -> Self {
37 self.label = Some(label.into());
38 self
39 }
40
41 /// Build the compute pass.
42 ///
43 /// This creates a new command encoder for this compute pass and returns it
44 /// when the ComputePass is dropped.
45 pub fn build(self) -> ComputePass<'f> {
46 profile_function!();
47
48 let label = self.label.clone();
49 let label_str = label.as_deref();
50
51 // Create encoder for this pass
52 let encoder = self
53 .frame
54 .device()
55 .create_command_encoder(&wgpu::CommandEncoderDescriptor { label: label_str });
56
57 let mut encoder = encoder;
58 let compute_pass = encoder
59 .begin_compute_pass(&wgpu::ComputePassDescriptor {
60 label: label_str,
61 timestamp_writes: None,
62 })
63 .forget_lifetime();
64
65 ComputePass {
66 frame: self.frame,
67 encoder: Some(encoder),
68 pass: Some(compute_pass),
69 }
70 }
71}
72
impl Default for ComputePassBuilder<'_, '_> {
    /// This impl can never produce a value.
    ///
    /// A `ComputePassBuilder` must borrow a [`Frame`], and there is no frame
    /// available here. Use [`ComputePassBuilder::new`] instead.
    ///
    /// # Panics
    ///
    /// Always panics (via `unimplemented!`).
    fn default() -> Self {
        unimplemented!("ComputePassBuilder requires a Frame reference")
    }
}
78
79/// A compute pass wrapper that automatically returns the command buffer to the frame.
80///
81/// This struct mirrors `RenderPass` in its lifecycle management - it owns its
82/// encoder and pushes the command buffer to the frame when dropped.
83pub struct ComputePass<'f> {
84 /// Reference to the frame (for pushing command buffer on drop).
85 frame: &'f Frame<'f>,
86 /// The command encoder (owned by this pass).
87 pub(crate) encoder: Option<wgpu::CommandEncoder>,
88 /// The active wgpu compute pass.
89 pub(crate) pass: Option<wgpu::ComputePass<'static>>,
90}
91
92impl<'f> ComputePass<'f> {
93 /// Get the underlying wgpu compute pass (mutable).
94 pub fn wgpu_pass(&mut self) -> &mut wgpu::ComputePass<'static> {
95 self.pass.as_mut().expect("ComputePass already consumed")
96 }
97
98 /// Get the underlying wgpu compute pass (immutable).
99 pub fn wgpu_pass_ref(&self) -> &wgpu::ComputePass<'static> {
100 self.pass.as_ref().expect("ComputePass already consumed")
101 }
102
103 /// Get raw access to the underlying wgpu compute pass.
104 ///
105 /// This is an alias for [`wgpu_pass()`](Self::wgpu_pass) for consistency with `RenderPass::raw_pass()`.
106 pub fn raw_pass(&mut self) -> &mut wgpu::ComputePass<'static> {
107 self.pass.as_mut().expect("ComputePass already consumed")
108 }
109
110 /// Get the graphics context.
111 pub fn graphics(&self) -> &GraphicsContext {
112 self.frame.graphics()
113 }
114
115 /// Set the compute pipeline to use.
116 pub fn set_pipeline(&mut self, pipeline: &wgpu::ComputePipeline) {
117 self.wgpu_pass().set_pipeline(pipeline);
118 }
119
120 /// Set a bind group.
121 pub fn set_bind_group(&mut self, index: u32, bind_group: &wgpu::BindGroup, offsets: &[u32]) {
122 self.wgpu_pass().set_bind_group(index, bind_group, offsets);
123 }
124
125 /// Dispatch workgroups.
126 ///
127 /// # Arguments
128 ///
129 /// * `x` - Number of workgroups in the X dimension
130 /// * `y` - Number of workgroups in the Y dimension
131 /// * `z` - Number of workgroups in the Z dimension
132 pub fn dispatch_workgroups(&mut self, x: u32, y: u32, z: u32) {
133 self.wgpu_pass().dispatch_workgroups(x, y, z);
134 }
135
136 /// Dispatch workgroups with a 1D configuration.
137 ///
138 /// Equivalent to `dispatch_workgroups(x, 1, 1)`.
139 pub fn dispatch_workgroups_1d(&mut self, x: u32) {
140 self.dispatch_workgroups(x, 1, 1);
141 }
142
143 /// Dispatch workgroups with a 2D configuration.
144 ///
145 /// Equivalent to `dispatch_workgroups(x, y, 1)`.
146 pub fn dispatch_workgroups_2d(&mut self, x: u32, y: u32) {
147 self.dispatch_workgroups(x, y, 1);
148 }
149
150 /// Dispatch workgroups indirectly from a buffer.
151 ///
152 /// The buffer should contain a `DispatchIndirect` struct:
153 /// ```ignore
154 /// #[repr(C)]
155 /// struct DispatchIndirect {
156 /// x: u32,
157 /// y: u32,
158 /// z: u32,
159 /// }
160 /// ```
161 pub fn dispatch_workgroups_indirect(&mut self, buffer: &wgpu::Buffer, offset: u64) {
162 self.wgpu_pass()
163 .dispatch_workgroups_indirect(buffer, offset);
164 }
165
166 /// Insert a debug marker.
167 pub fn insert_debug_marker(&mut self, label: &str) {
168 self.wgpu_pass().insert_debug_marker(label);
169 }
170
171 /// Push a debug group.
172 pub fn push_debug_group(&mut self, label: &str) {
173 self.wgpu_pass().push_debug_group(label);
174 }
175
176 /// Pop a debug group.
177 pub fn pop_debug_group(&mut self) {
178 self.wgpu_pass().pop_debug_group();
179 }
180
181 /// Set push constants for the compute shader.
182 ///
183 /// Push constants are a fast way to pass small amounts of data to shaders
184 /// without using uniform buffers. They require the `PUSH_CONSTANTS` feature
185 /// to be enabled on the device.
186 ///
187 /// # Arguments
188 ///
189 /// * `offset` - Byte offset into the push constant range
190 /// * `data` - Data to set (must be `Pod` for safe byte casting)
191 ///
192 /// # Example
193 ///
194 /// ```ignore
195 /// #[repr(C)]
196 /// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
197 /// struct ComputeConstants {
198 /// workgroup_count: u32,
199 /// time: f32,
200 /// }
201 ///
202 /// let constants = ComputeConstants {
203 /// workgroup_count: 64,
204 /// time: 1.5,
205 /// };
206 ///
207 /// pass.set_push_constants(0, &constants);
208 /// ```
209 pub fn set_push_constants<T: bytemuck::Pod>(&mut self, offset: u32, data: &T) {
210 self.wgpu_pass()
211 .set_push_constants(offset, bytemuck::bytes_of(data));
212 }
213
214 /// Set push constants from raw bytes.
215 ///
216 /// Use this when you need more control over the data layout.
217 pub fn set_push_constants_raw(&mut self, offset: u32, data: &[u8]) {
218 self.wgpu_pass().set_push_constants(offset, data);
219 }
220
221 /// Finish the compute pass, pushing the command buffer to the frame.
222 pub fn finish(self) {
223 drop(self);
224 }
225}
226
227impl Drop for ComputePass<'_> {
228 fn drop(&mut self) {
229 profile_function!();
230
231 // End the compute pass
232 drop(self.pass.take());
233
234 // Finish encoder and push command buffer to frame
235 if let Some(encoder) = self.encoder.take() {
236 let command_buffer = encoder.finish();
237 self.frame.command_buffers.borrow_mut().push(command_buffer);
238 }
239 }
240}
241
242/// Indirect dispatch command.
243///
244/// This matches the layout expected by `wgpu::ComputePass::dispatch_workgroups_indirect`.
245#[repr(C)]
246#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
247pub struct DispatchIndirect {
248 /// Number of workgroups in the X dimension.
249 pub x: u32,
250 /// Number of workgroups in the Y dimension.
251 pub y: u32,
252 /// Number of workgroups in the Z dimension.
253 pub z: u32,
254}
255
256// SAFETY: DispatchIndirect is a repr(C) struct of u32s with no padding
257unsafe impl bytemuck::Pod for DispatchIndirect {}
258unsafe impl bytemuck::Zeroable for DispatchIndirect {}
259
260impl DispatchIndirect {
261 /// Create a new dispatch command.
262 pub const fn new(x: u32, y: u32, z: u32) -> Self {
263 Self { x, y, z }
264 }
265
266 /// Create a 1D dispatch command.
267 pub const fn new_1d(x: u32) -> Self {
268 Self::new(x, 1, 1)
269 }
270
271 /// Create a 2D dispatch command.
272 pub const fn new_2d(x: u32, y: u32) -> Self {
273 Self::new(x, y, 1)
274 }
275
276 /// Size of the command in bytes.
277 pub const fn size() -> u64 {
278 std::mem::size_of::<Self>() as u64
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dispatch_indirect_size() {
        // Verify the struct matches wgpu's expected layout
        assert_eq!(DispatchIndirect::size(), 12); // 3 u32s = 12 bytes
    }

    #[test]
    fn test_dispatch_indirect_field_order() {
        // The hand-written `Pod` impl assumes three consecutive u32s. Casting to
        // `[u32; 3]` checks that, and also pins the x, y, z field order.
        let cmd = DispatchIndirect::new(1, 2, 3);
        assert_eq!(bytemuck::bytes_of(&cmd).len(), 12);
        let words: [u32; 3] = bytemuck::cast(cmd);
        assert_eq!(words, [1, 2, 3]);
    }

    #[test]
    fn test_dispatch_indirect_new() {
        let cmd = DispatchIndirect::new(4, 5, 6);
        assert_eq!(cmd.x, 4);
        assert_eq!(cmd.y, 5);
        assert_eq!(cmd.z, 6);
    }

    #[test]
    fn test_dispatch_indirect_default_is_zeroed() {
        // `Default` (derived) and `Zeroable` (manual impl) must agree.
        let zeroed: DispatchIndirect = bytemuck::Zeroable::zeroed();
        assert_eq!(DispatchIndirect::default(), zeroed);
        assert_eq!(zeroed, DispatchIndirect::new(0, 0, 0));
    }

    #[test]
    fn test_dispatch_indirect_1d() {
        let cmd = DispatchIndirect::new_1d(64);
        assert_eq!(cmd.x, 64);
        assert_eq!(cmd.y, 1);
        assert_eq!(cmd.z, 1);
    }

    #[test]
    fn test_dispatch_indirect_2d() {
        let cmd = DispatchIndirect::new_2d(32, 32);
        assert_eq!(cmd.x, 32);
        assert_eq!(cmd.y, 32);
        assert_eq!(cmd.z, 1);
    }
}