1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
//! Compute pass management with ergonomic builder pattern.
//!
//! This module provides a `ComputePassBuilder` that mirrors the ergonomics of
//! `RenderPassBuilder` for compute shader operations.
use astrelis_core::profiling::profile_function;
use crate::context::GraphicsContext;
use crate::frame::Frame;
/// Builder for configuring and creating a [`ComputePass`] on a [`Frame`].
///
/// The builder borrows the frame for `'f`; `'w` is the frame's own lifetime
/// parameter. The only configurable option is an optional debug label.
///
/// # Example
///
/// ```ignore
/// let mut compute_pass = frame.compute_pass()
///     .label("My Compute Pass")
///     .build();
///
/// compute_pass.set_pipeline(&pipeline);
/// compute_pass.set_bind_group(0, &bind_group, &[]);
/// compute_pass.dispatch_workgroups(64, 64, 1);
/// ```
pub struct ComputePassBuilder<'f, 'w> {
/// Frame whose device creates the encoder and which the built pass keeps a
/// reference to.
frame: &'f Frame<'w>,
/// Optional debug label; `None` until [`ComputePassBuilder::label`] is called.
label: Option<String>,
}
impl<'f, 'w> ComputePassBuilder<'f, 'w> {
    /// Create a new compute pass builder for `frame` with no debug label.
    pub fn new(frame: &'f Frame<'w>) -> Self {
        Self { frame, label: None }
    }

    /// Set a debug label for the compute pass.
    ///
    /// The label is applied to both the command encoder and the compute pass
    /// created by [`build`](Self::build).
    pub fn label(mut self, label: impl Into<String>) -> Self {
        self.label = Some(label.into());
        self
    }

    /// Build the compute pass.
    ///
    /// Creates a dedicated command encoder on the frame's device and begins a
    /// compute pass on it. The returned [`ComputePass`] owns both; when it is
    /// dropped (or explicitly finished) the encoder is finished and the
    /// resulting command buffer is pushed to the frame.
    pub fn build(self) -> ComputePass<'f> {
        profile_function!();

        // `self` is consumed here, so the label can simply be borrowed for the
        // two descriptors below; no clone of the `String` is needed.
        let label = self.label.as_deref();

        // Each compute pass records into its own encoder.
        let mut encoder = self
            .frame
            .device()
            .create_command_encoder(&wgpu::CommandEncoderDescriptor { label });

        // `forget_lifetime` detaches the pass from the mutable borrow of
        // `encoder`, which lets both be stored side by side in `ComputePass`.
        // The `Drop` impl of `ComputePass` ends the pass before it finishes
        // the encoder.
        let pass = encoder
            .begin_compute_pass(&wgpu::ComputePassDescriptor {
                label,
                timestamp_writes: None,
            })
            .forget_lifetime();

        ComputePass {
            frame: self.frame,
            encoder: Some(encoder),
            pass: Some(pass),
        }
    }
}
/// Placeholder `Default` implementation that cannot produce a value.
///
/// A builder needs a [`Frame`] reference, which `Default` has no way to
/// supply; construct builders with [`ComputePassBuilder::new`] instead.
impl Default for ComputePassBuilder<'_, '_> {
/// Never returns a builder.
///
/// # Panics
///
/// Always panics (via `unimplemented!`), because no frame reference is
/// available here.
fn default() -> Self {
unimplemented!("ComputePassBuilder requires a Frame reference")
}
}
/// A compute pass wrapper that automatically hands its command buffer to the frame.
///
/// This struct mirrors `RenderPass` in its lifecycle management: it owns its
/// command encoder and, when dropped, ends the wgpu pass, finishes the encoder
/// and pushes the resulting command buffer to the frame.
///
/// Both `encoder` and `pass` are stored as `Option`s so that the `Drop`
/// implementation can take them out one after the other. The accessor methods
/// panic with "ComputePass already consumed" if `pass` is `None`.
pub struct ComputePass<'f> {
/// Reference to the frame (for pushing command buffer on drop).
frame: &'f Frame<'f>,
/// The command encoder (owned by this pass); finished when the pass is dropped.
pub(crate) encoder: Option<wgpu::CommandEncoder>,
/// The active wgpu compute pass. It has a `'static` lifetime because it was
/// detached from the encoder borrow with `forget_lifetime` when it was built.
pub(crate) pass: Option<wgpu::ComputePass<'static>>,
}
impl<'f> ComputePass<'f> {
    /// Mutable access to the wrapped wgpu compute pass.
    ///
    /// # Panics
    ///
    /// Panics if the inner pass has already been taken.
    pub fn wgpu_pass(&mut self) -> &mut wgpu::ComputePass<'static> {
        let inner = self.pass.as_mut();
        inner.expect("ComputePass already consumed")
    }

    /// Shared access to the wrapped wgpu compute pass.
    ///
    /// # Panics
    ///
    /// Panics if the inner pass has already been taken.
    pub fn wgpu_pass_ref(&self) -> &wgpu::ComputePass<'static> {
        let inner = self.pass.as_ref();
        inner.expect("ComputePass already consumed")
    }

    /// Raw access to the wrapped wgpu compute pass.
    ///
    /// Alias for [`wgpu_pass()`](Self::wgpu_pass), provided for consistency
    /// with `RenderPass::raw_pass()`.
    ///
    /// # Panics
    ///
    /// Panics if the inner pass has already been taken.
    pub fn raw_pass(&mut self) -> &mut wgpu::ComputePass<'static> {
        self.wgpu_pass()
    }

    /// The graphics context of the frame this pass belongs to.
    pub fn graphics(&self) -> &GraphicsContext {
        let frame = self.frame;
        frame.graphics()
    }

    /// Select the compute pipeline used by subsequent dispatches.
    pub fn set_pipeline(&mut self, pipeline: &wgpu::ComputePipeline) {
        let pass = self.wgpu_pass();
        pass.set_pipeline(pipeline);
    }

    /// Bind `bind_group` at slot `index`, forwarding `offsets` unchanged.
    pub fn set_bind_group(&mut self, index: u32, bind_group: &wgpu::BindGroup, offsets: &[u32]) {
        let pass = self.wgpu_pass();
        pass.set_bind_group(index, bind_group, offsets);
    }

    /// Dispatch workgroups.
    ///
    /// # Arguments
    ///
    /// * `x` - Number of workgroups in the X dimension
    /// * `y` - Number of workgroups in the Y dimension
    /// * `z` - Number of workgroups in the Z dimension
    pub fn dispatch_workgroups(&mut self, x: u32, y: u32, z: u32) {
        let pass = self.wgpu_pass();
        pass.dispatch_workgroups(x, y, z);
    }

    /// Dispatch a one-dimensional grid of workgroups.
    ///
    /// Equivalent to `dispatch_workgroups(x, 1, 1)`.
    pub fn dispatch_workgroups_1d(&mut self, x: u32) {
        self.wgpu_pass().dispatch_workgroups(x, 1, 1);
    }

    /// Dispatch a two-dimensional grid of workgroups.
    ///
    /// Equivalent to `dispatch_workgroups(x, y, 1)`.
    pub fn dispatch_workgroups_2d(&mut self, x: u32, y: u32) {
        self.wgpu_pass().dispatch_workgroups(x, y, 1);
    }

    /// Dispatch workgroups whose counts are read from `buffer` at `offset`.
    ///
    /// The buffer should contain a [`DispatchIndirect`]-shaped struct:
    /// ```ignore
    /// #[repr(C)]
    /// struct DispatchIndirect {
    ///     x: u32,
    ///     y: u32,
    ///     z: u32,
    /// }
    /// ```
    pub fn dispatch_workgroups_indirect(&mut self, buffer: &wgpu::Buffer, offset: u64) {
        let pass = self.wgpu_pass();
        pass.dispatch_workgroups_indirect(buffer, offset);
    }

    /// Insert a single debug marker into the pass.
    pub fn insert_debug_marker(&mut self, label: &str) {
        let pass = self.wgpu_pass();
        pass.insert_debug_marker(label);
    }

    /// Open a debug group named `label`.
    pub fn push_debug_group(&mut self, label: &str) {
        let pass = self.wgpu_pass();
        pass.push_debug_group(label);
    }

    /// Close the most recently opened debug group.
    pub fn pop_debug_group(&mut self) {
        let pass = self.wgpu_pass();
        pass.pop_debug_group();
    }

    /// Set push constants for the compute shader from a typed value.
    ///
    /// Push constants are a fast way to pass small amounts of data to shaders
    /// without using uniform buffers. They require the `PUSH_CONSTANTS` feature
    /// to be enabled on the device.
    ///
    /// # Arguments
    ///
    /// * `offset` - Byte offset into the push constant range
    /// * `data` - Data to set (must be `Pod` for safe byte casting)
    ///
    /// # Example
    ///
    /// ```ignore
    /// #[repr(C)]
    /// #[derive(Clone, Copy, bytemuck::Pod, bytemuck::Zeroable)]
    /// struct ComputeConstants {
    ///     workgroup_count: u32,
    ///     time: f32,
    /// }
    ///
    /// let constants = ComputeConstants {
    ///     workgroup_count: 64,
    ///     time: 1.5,
    /// };
    ///
    /// pass.set_push_constants(0, &constants);
    /// ```
    pub fn set_push_constants<T>(&mut self, offset: u32, data: &T)
    where
        T: bytemuck::Pod,
    {
        // View the value as its raw bytes and take the byte-slice path.
        let bytes: &[u8] = bytemuck::bytes_of(data);
        self.set_push_constants_raw(offset, bytes);
    }

    /// Set push constants from raw bytes.
    ///
    /// Use this when you need more control over the data layout.
    pub fn set_push_constants_raw(&mut self, offset: u32, data: &[u8]) {
        let pass = self.wgpu_pass();
        pass.set_push_constants(offset, data);
    }

    /// Finish the compute pass, pushing the command buffer to the frame.
    ///
    /// Consumes the pass; the work itself happens in the `Drop` implementation.
    pub fn finish(self) {
        std::mem::drop(self);
    }
}
impl Drop for ComputePass<'_> {
    /// Ends the compute pass, finishes its encoder and pushes the resulting
    /// command buffer onto the frame's command buffer list.
    fn drop(&mut self) {
        profile_function!();

        // Dropping the wgpu pass ends it; this happens before the encoder is
        // finished below.
        self.pass = None;

        // Nothing left to submit if the encoder was already taken.
        let Some(encoder) = self.encoder.take() else {
            return;
        };

        // Finish first, then borrow the frame's list only for the push.
        let finished = encoder.finish();
        self.frame.command_buffers.borrow_mut().push(finished);
    }
}
/// Indirect dispatch command: the workgroup counts for one compute dispatch.
///
/// This matches the layout expected by `wgpu::ComputePass::dispatch_workgroups_indirect`:
/// a `#[repr(C)]` struct of three consecutive `u32` values in `x`, `y`, `z`
/// order (12 bytes in total). The field order is part of that layout and must
/// not be changed.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Hash)]
pub struct DispatchIndirect {
/// Number of workgroups in the X dimension.
pub x: u32,
/// Number of workgroups in the Y dimension.
pub y: u32,
/// Number of workgroups in the Z dimension.
pub z: u32,
}
// SAFETY: `DispatchIndirect` is a `#[repr(C)]`, `Copy` struct made up solely of
// three `u32` fields, so it contains no padding bytes, no pointers or
// references, and every bit pattern is a valid value.
unsafe impl bytemuck::Pod for DispatchIndirect {}
// SAFETY: the all-zero bit pattern is a valid `DispatchIndirect` (each `u32`
// field is simply 0).
unsafe impl bytemuck::Zeroable for DispatchIndirect {}
impl DispatchIndirect {
/// Create a new dispatch command.
pub const fn new(x: u32, y: u32, z: u32) -> Self {
Self { x, y, z }
}
/// Create a 1D dispatch command.
pub const fn new_1d(x: u32) -> Self {
Self::new(x, 1, 1)
}
/// Create a 2D dispatch command.
pub const fn new_2d(x: u32, y: u32) -> Self {
Self::new(x, y, 1)
}
/// Size of the command in bytes.
pub const fn size() -> u64 {
std::mem::size_of::<Self>() as u64
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The command must stay three packed `u32`s (12 bytes) to match the
    /// layout wgpu expects for indirect dispatch.
    #[test]
    fn test_dispatch_indirect_size() {
        let bytes = DispatchIndirect::size();
        assert_eq!(bytes, 12);
    }

    /// `new_1d` keeps `x` and pins the other two dimensions to 1.
    #[test]
    fn test_dispatch_indirect_1d() {
        let DispatchIndirect { x, y, z } = DispatchIndirect::new_1d(64);
        assert_eq!(x, 64);
        assert_eq!(y, 1);
        assert_eq!(z, 1);
    }

    /// `new_2d` keeps `x` and `y` and pins `z` to 1.
    #[test]
    fn test_dispatch_indirect_2d() {
        let DispatchIndirect { x, y, z } = DispatchIndirect::new_2d(32, 32);
        assert_eq!(x, 32);
        assert_eq!(y, 32);
        assert_eq!(z, 1);
    }
}