use crate::{
    command::{
        bind::{Binder, LayoutChange},
        BasePass, BasePassRef, CommandBuffer,
    },
    device::all_buffer_stages,
    hub::{GfxBackend, Global, GlobalIdentityHandlerFactory, Token},
    id,
    resource::BufferUse,
    span,
};

use hal::command::CommandBuffer as _;
use wgt::{BufferAddress, BufferUsage};

use std::{fmt, iter, str};

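// A single command recorded into a compute pass. Debug-marker strings are not
// stored inline; they live in the pass's `string_data` buffer and are referenced
// here only by their byte length (`len`).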
#[doc(hidden)]
#[derive(Clone, Copy, Debug)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "trace"),
    derive(serde::Serialize)
)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "replay"),
    derive(serde::Deserialize)
)]
pub enum ComputeCommand {
    SetBindGroup {
        index: u8,
        num_dynamic_offsets: u8,
        bind_group_id: id::BindGroupId,
    },
    SetPipeline(id::ComputePipelineId),
    Dispatch([u32; 3]),
    DispatchIndirect {
        buffer_id: id::BufferId,
        offset: BufferAddress,
    },
    PushDebugGroup {
        color: u32,
        len: usize,
    },
    PopDebugGroup,
    InsertDebugMarker {
        color: u32,
        len: usize,
    },
}

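/// A recorded compute pass: a CPU-side list of `ComputeCommand`s plus the id of
/// the command encoder it was opened on. Nothing is validated or encoded until
/// the pass is run through `Global::command_encoder_run_compute_pass`.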
#[cfg_attr(feature = "serial-pass", derive(serde::Deserialize, serde::Serialize))]
pub struct ComputePass {
    base: BasePass<ComputeCommand>,
    parent_id: id::CommandEncoderId,
}

impl ComputePass {
    pub fn new(parent_id: id::CommandEncoderId) -> Self {
        ComputePass {
            base: BasePass::new(),
            parent_id,
        }
    }

    pub fn parent_id(&self) -> id::CommandEncoderId {
        self.parent_id
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "ComputePass {{ encoder_id: {:?}, data: {:?} commands and {:?} dynamic offsets }}",
            self.parent_id,
            self.base.commands.len(),
            self.base.dynamic_offsets.len()
        )
    }
}

#[repr(C)]
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor {
    pub todo: u32,
}

#[derive(Debug, PartialEq)]
enum PipelineState {
    Required,
    Set,
}

#[derive(Debug)]
struct State {
    binder: Binder,
    pipeline: PipelineState,
    debug_scope_depth: u32,
}

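// Running a recorded pass: the commands stored in the `BasePass` are validated
// and translated into raw gfx-hal calls on the encoder's current command buffer.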
impl<G: GlobalIdentityHandlerFactory> Global<G> {
    pub fn command_encoder_run_compute_pass<B: GfxBackend>(
        &self,
        encoder_id: id::CommandEncoderId,
        pass: &ComputePass,
    ) {
        self.command_encoder_run_compute_pass_impl::<B>(encoder_id, pass.base.as_ref())
    }

    #[doc(hidden)]
    pub fn command_encoder_run_compute_pass_impl<B: GfxBackend>(
        &self,
        encoder_id: id::CommandEncoderId,
        mut base: BasePassRef<ComputeCommand>,
    ) {
        span!(_guard, INFO, "CommandEncoder::run_compute_pass");
        let hub = B::hub(self);
        let mut token = Token::root();

        let (mut cmb_guard, mut token) = hub.command_buffers.write(&mut token);
        let cmb = &mut cmb_guard[encoder_id];
        let raw = cmb.raw.last_mut().unwrap();

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmb.commands {
            list.push(crate::device::trace::Command::RunComputePass {
                base: BasePass::from_ref(base),
            });
        }

        let (_, mut token) = hub.render_bundles.read(&mut token);
        let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
        let (bind_group_guard, mut token) = hub.bind_groups.read(&mut token);
        let (pipeline_guard, mut token) = hub.compute_pipelines.read(&mut token);
        let (buffer_guard, mut token) = hub.buffers.read(&mut token);
        let (texture_guard, _) = hub.textures.read(&mut token);

        let mut state = State {
            binder: Binder::new(cmb.limits.max_bind_groups),
            pipeline: PipelineState::Required,
            debug_scope_depth: 0,
        };

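        // Replay every recorded command in order, tracking binder and pipeline
        // state as we go.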
        for command in base.commands {
            match *command {
                ComputeCommand::SetBindGroup {
                    index,
                    num_dynamic_offsets,
                    bind_group_id,
                } => {
                    assert!(
                        (index as u32) < cmb.limits.max_bind_groups,
                        "Bind group index {} is out of the range 0..{} allowed by the max_bind_groups limit",
                        index,
                        cmb.limits.max_bind_groups
                    );

                    let offsets = &base.dynamic_offsets[..num_dynamic_offsets as usize];
                    base.dynamic_offsets = &base.dynamic_offsets[num_dynamic_offsets as usize..];

                    let bind_group = cmb
                        .trackers
                        .bind_groups
                        .use_extend(&*bind_group_guard, bind_group_id, (), ())
                        .unwrap();
                    bind_group.validate_dynamic_bindings(offsets).unwrap();

                    log::trace!(
                        "Encoding barriers on binding of {:?} to {:?}",
                        bind_group_id,
                        encoder_id
                    );
                    CommandBuffer::insert_barriers(
                        raw,
                        &mut cmb.trackers,
                        &bind_group.used,
                        &*buffer_guard,
                        &*texture_guard,
                    );

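                    // If the binder can already resolve a pipeline layout for this
                    // slot, bind the group now together with any follow-up groups
                    // it reports for higher indices; otherwise binding is deferred
                    // (the `SetPipeline` path re-binds matching groups).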
                    if let Some((pipeline_layout_id, follow_ups)) = state.binder.provide_entry(
                        index as usize,
                        bind_group_id,
                        bind_group,
                        offsets,
                    ) {
                        let bind_groups = iter::once(bind_group.raw.raw()).chain(
                            follow_ups
                                .clone()
                                .map(|(bg_id, _)| bind_group_guard[bg_id].raw.raw()),
                        );
                        unsafe {
                            raw.bind_compute_descriptor_sets(
                                &pipeline_layout_guard[pipeline_layout_id].raw,
                                index as usize,
                                bind_groups,
                                offsets
                                    .iter()
                                    .chain(follow_ups.flat_map(|(_, offsets)| offsets))
                                    .cloned(),
                            );
                        }
                    }
                }
                ComputeCommand::SetPipeline(pipeline_id) => {
                    state.pipeline = PipelineState::Set;
                    let pipeline = cmb
                        .trackers
                        .compute_pipes
                        .use_extend(&*pipeline_guard, pipeline_id, (), ())
                        .unwrap();

                    unsafe {
                        raw.bind_compute_pipeline(&pipeline.raw);
                    }

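                    // If the pipeline layout changed, re-issue the already-provided
                    // bind groups that still match the new layout; after the first
                    // mismatch, the remaining slots are left for the application
                    // to re-set.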
                    if state.binder.pipeline_layout_id != Some(pipeline.layout_id.value) {
                        let pipeline_layout = &pipeline_layout_guard[pipeline.layout_id.value];
                        state.binder.pipeline_layout_id = Some(pipeline.layout_id.value);
                        state
                            .binder
                            .reset_expectations(pipeline_layout.bind_group_layout_ids.len());
                        let mut is_compatible = true;

                        for (index, (entry, bgl_id)) in state
                            .binder
                            .entries
                            .iter_mut()
                            .zip(&pipeline_layout.bind_group_layout_ids)
                            .enumerate()
                        {
                            match entry.expect_layout(bgl_id.value) {
                                LayoutChange::Match(bg_id, offsets) if is_compatible => {
                                    let desc_set = bind_group_guard[bg_id].raw.raw();
                                    unsafe {
                                        raw.bind_compute_descriptor_sets(
                                            &pipeline_layout.raw,
                                            index,
                                            iter::once(desc_set),
                                            offsets.iter().cloned(),
                                        );
                                    }
                                }
                                LayoutChange::Match(..) | LayoutChange::Unchanged => {}
                                LayoutChange::Mismatch => {
                                    is_compatible = false;
                                }
                            }
                        }
                    }
                }
                ComputeCommand::Dispatch(groups) => {
                    assert_eq!(
                        state.pipeline,
                        PipelineState::Set,
                        "Dispatch error: a compute pipeline must be set first"
                    );
                    unsafe {
                        raw.dispatch(groups);
                    }
                }
                ComputeCommand::DispatchIndirect { buffer_id, offset } => {
                    assert_eq!(
                        state.pipeline,
                        PipelineState::Set,
                        "DispatchIndirect error: a compute pipeline must be set first"
                    );
                    let (src_buffer, src_pending) = cmb.trackers.buffers.use_replace(
                        &*buffer_guard,
                        buffer_id,
                        (),
                        BufferUse::INDIRECT,
                    );
                    assert!(
                        src_buffer.usage.contains(BufferUsage::INDIRECT),
                        "Indirect dispatch buffer {:?} is missing the INDIRECT usage flag",
                        buffer_id
                    );

                    let barriers = src_pending.map(|pending| pending.into_hal(src_buffer));

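                    // Transition the source buffer into INDIRECT use before the
                    // GPU reads the dispatch arguments from it.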
                    unsafe {
                        raw.pipeline_barrier(
                            all_buffer_stages()..all_buffer_stages(),
                            hal::memory::Dependencies::empty(),
                            barriers,
                        );
                        raw.dispatch_indirect(&src_buffer.raw, offset);
                    }
                }
                ComputeCommand::PushDebugGroup { color, len } => {
                    state.debug_scope_depth += 1;

                    let label = str::from_utf8(&base.string_data[..len]).unwrap();
                    unsafe {
                        raw.begin_debug_marker(label, color);
                    }
                    base.string_data = &base.string_data[len..];
                }
                ComputeCommand::PopDebugGroup => {
                    assert_ne!(
                        state.debug_scope_depth, 0,
                        "Cannot pop a debug group: no debug group is currently open"
                    );
                    state.debug_scope_depth -= 1;
                    unsafe {
                        raw.end_debug_marker();
                    }
                }
                ComputeCommand::InsertDebugMarker { color, len } => {
                    let label = str::from_utf8(&base.string_data[..len]).unwrap();
                    unsafe { raw.insert_debug_marker(label, color) }
                    base.string_data = &base.string_data[len..];
                }
            }
        }
    }
}

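// C-callable recording functions. They only append commands to the pass's
// command stream; all validation happens later, when the pass is replayed
// through `command_encoder_run_compute_pass`.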
pub mod compute_ffi {
    use super::{ComputeCommand, ComputePass};
    use crate::{id, span, RawString};
    use std::{convert::TryInto, ffi, slice};
    use wgt::{BufferAddress, DynamicOffset};

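    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given `offsets`
    /// pointer is valid for `offset_length` elements.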
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_set_bind_group(
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: id::BindGroupId,
        offsets: *const DynamicOffset,
        offset_length: usize,
    ) {
        span!(_guard, DEBUG, "ComputePass::set_bind_group");
        pass.base.commands.push(ComputeCommand::SetBindGroup {
            index: index.try_into().unwrap(),
            num_dynamic_offsets: offset_length.try_into().unwrap(),
            bind_group_id,
        });
        pass.base
            .dynamic_offsets
            .extend_from_slice(slice::from_raw_parts(offsets, offset_length));
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_set_pipeline(
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) {
        span!(_guard, DEBUG, "ComputePass::set_pipeline");
        pass.base
            .commands
            .push(ComputeCommand::SetPipeline(pipeline_id));
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_dispatch(
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) {
        span!(_guard, DEBUG, "ComputePass::dispatch");
        pass.base
            .commands
            .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_dispatch_indirect(
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) {
        span!(_guard, DEBUG, "ComputePass::dispatch_indirect");
        pass.base
            .commands
            .push(ComputeCommand::DispatchIndirect { buffer_id, offset });
    }

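    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that `label` points to a
    /// valid null-terminated string.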
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_push_debug_group(
        pass: &mut ComputePass,
        label: RawString,
        color: u32,
    ) {
        span!(_guard, DEBUG, "ComputePass::push_debug_group");
        let bytes = ffi::CStr::from_ptr(label).to_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) {
        span!(_guard, DEBUG, "ComputePass::pop_debug_group");
        pass.base.commands.push(ComputeCommand::PopDebugGroup);
    }

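    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that `label` points to a
    /// valid null-terminated string.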
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_insert_debug_marker(
        pass: &mut ComputePass,
        label: RawString,
        color: u32,
    ) {
        span!(_guard, DEBUG, "ComputePass::insert_debug_marker");
        let bytes = ffi::CStr::from_ptr(label).to_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });
    }
}