use crate::{
    binding_model::{
        BindError, BindGroup, BindGroupLayouts, LateMinBufferBindingSizeMismatch,
        PushConstantUploadError,
    },
    command::{
        bind::Binder,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        BasePass, BasePassRef, BindGroupStateChange, CommandBuffer, CommandEncoderError,
        CommandEncoderStatus, MapPassErr, PassErrorScope, QueryUseError, StateChange,
    },
    device::{MissingDownlevelFlags, MissingFeatures},
    error::{ErrorFormatter, PrettyError},
    global::Global,
    hal_api::HalApi,
    hal_label,
    hub::Token,
    id,
    id::DeviceId,
    identity::GlobalIdentityHandlerFactory,
    init_tracker::MemoryInitKind,
    pipeline,
    resource::{self, Buffer, Texture},
    storage::Storage,
    track::{Tracker, UsageConflict, UsageScope},
    validation::{check_buffer_usage, MissingBufferUsageError},
    Label,
};

use hal::CommandEncoder as _;
#[cfg(any(feature = "serial-pass", feature = "replay"))]
use serde::Deserialize;
#[cfg(any(feature = "serial-pass", feature = "trace"))]
use serde::Serialize;

use thiserror::Error;

use std::{fmt, mem, str};

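/// A single command recorded into a compute pass.
///
/// These are replayed against the HAL command encoder when the pass is
/// resolved by [`Global::command_encoder_run_compute_pass`].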
#[doc(hidden)]
#[derive(Clone, Copy, Debug)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "trace"),
    derive(serde::Serialize)
)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "replay"),
    derive(serde::Deserialize)
)]
pub enum ComputeCommand {
    SetBindGroup {
        index: u32,
        num_dynamic_offsets: u8,
        bind_group_id: id::BindGroupId,
    },
    SetPipeline(id::ComputePipelineId),

    SetPushConstant {
        /// The byte offset within the push constant storage to write to. Must
        /// be a multiple of four.
        offset: u32,
        /// The number of bytes to write. Must be a multiple of four.
        size_bytes: u32,
        /// Index in [`BasePass::push_constant_data`] of the start of the data
        /// to be written. Note: this is a count of `u32` elements, not a byte
        /// offset like `offset`.
        values_offset: u32,
    },

    Dispatch([u32; 3]),
    DispatchIndirect {
        buffer_id: id::BufferId,
        offset: wgt::BufferAddress,
    },
    PushDebugGroup {
        color: u32,
        len: usize,
    },
    PopDebugGroup,
    InsertDebugMarker {
        color: u32,
        len: usize,
    },
    WriteTimestamp {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },
    BeginPipelineStatisticsQuery {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },
    EndPipelineStatisticsQuery,
}

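/// The recorded form of a compute pass: a list of [`ComputeCommand`]s plus the
/// data (dynamic offsets, push constants, debug strings) they refer to.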
#[cfg_attr(feature = "serial-pass", derive(serde::Deserialize, serde::Serialize))]
pub struct ComputePass {
    base: BasePass<ComputeCommand>,
    parent_id: id::CommandEncoderId,
    timestamp_writes: Option<ComputePassTimestampWrites>,

    // Resource binding dedupe state.
    #[cfg_attr(feature = "serial-pass", serde(skip))]
    current_bind_groups: BindGroupStateChange,
    #[cfg_attr(feature = "serial-pass", serde(skip))]
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    pub fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self {
        Self {
            base: BasePass::new(&desc.label),
            parent_id,
            timestamp_writes: desc.timestamp_writes.cloned(),

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    pub fn parent_id(&self) -> id::CommandEncoderId {
        self.parent_id
    }

    #[cfg(feature = "trace")]
    pub fn into_command(self) -> crate::device::trace::Command {
        crate::device::trace::Command::RunComputePass {
            base: self.base,
            timestamp_writes: self.timestamp_writes,
        }
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "ComputePass {{ encoder_id: {:?}, data: {:?} commands and {:?} dynamic offsets }}",
            self.parent_id,
            self.base.commands.len(),
            self.base.dynamic_offsets.len()
        )
    }
}

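/// Describes the writing of timestamp values in a compute pass.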
#[repr(C)]
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(any(feature = "serial-pass", feature = "trace"), derive(Serialize))]
#[cfg_attr(any(feature = "serial-pass", feature = "replay"), derive(Deserialize))]
pub struct ComputePassTimestampWrites {
    /// The query set to write the timestamps to.
    pub query_set: id::QuerySetId,
    /// The index of the query set at which a start timestamp of this pass is written, if any.
    pub beginning_of_pass_write_index: Option<u32>,
    /// The index of the query set at which an end timestamp of this pass is written, if any.
    pub end_of_pass_write_index: Option<u32>,
}

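/// Describes a compute pass.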
#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<&'a ComputePassTimestampWrites>,
}

#[derive(Clone, Debug, Error, Eq, PartialEq)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline,
    #[error("The pipeline layout, associated with the current compute pipeline, contains a bind group layout at index {index} which is incompatible with the bind group layout associated with the bind group at index {index}")]
    IncompatibleBindGroup { index: u32 },
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

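/// An error encountered while executing a single compute pass command.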
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Encoder(#[from] CommandEncoderError),
    #[error("Bind group {0:?} is invalid")]
    InvalidBindGroup(id::BindGroupId),
    #[error("Device {0:?} is invalid")]
    InvalidDevice(DeviceId),
    #[error("Bind group index {index} is greater than the device's requested `max_bind_group` limit {max}")]
    BindGroupIndexOutOfRange { index: u32, max: u32 },
    #[error("Compute pipeline {0:?} is invalid")]
    InvalidPipeline(id::ComputePipelineId),
    #[error("QuerySet {0:?} is invalid")]
    InvalidQuerySet(id::QuerySetId),
    #[error("Indirect buffer {0:?} is invalid or destroyed")]
    InvalidIndirectBuffer(id::BufferId),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error("Buffer {0:?} is invalid or destroyed")]
    InvalidBuffer(id::BufferId),
    #[error(transparent)]
    ResourceUsageConflict(#[from] UsageConflict),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error("Cannot pop debug group, because the number of pushed debug groups is zero")]
    InvalidPopDebugGroup,
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
}

impl PrettyError for ComputePassErrorInner {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
        fmt.error(self);
        match *self {
            Self::InvalidBindGroup(id) => {
                fmt.bind_group_label(&id);
            }
            Self::InvalidPipeline(id) => {
                fmt.compute_pipeline_label(&id);
            }
            Self::InvalidIndirectBuffer(id) => {
                fmt.buffer_label(&id);
            }
            _ => {}
        };
    }
}

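/// Error encountered when performing a compute pass.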
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    inner: ComputePassErrorInner,
}
impl PrettyError for ComputePassError {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
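        // This error is a wrapper for the inner error, but the scope carries
        // useful labels of its own.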
        fmt.error(self);
        self.scope.fmt_pretty(fmt);
    }
}

impl<T, E> MapPassErr<T, ComputePassError> for Result<T, E>
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, ComputePassError> {
        self.map_err(|inner| ComputePassError {
            scope,
            inner: inner.into(),
        })
    }
}

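/// Mutable state tracked while replaying the commands of a compute pass.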
struct State<A: HalApi> {
    binder: Binder,
    pipeline: Option<id::ComputePipelineId>,
    scope: UsageScope<A>,
    debug_scope_depth: u32,
}

impl<A: HalApi> State<A> {
    fn is_ready(&self, bind_group_layouts: &BindGroupLayouts<A>) -> Result<(), DispatchError> {
        let bind_mask = self.binder.invalid_mask(bind_group_layouts);
        if bind_mask != 0 {
            return Err(DispatchError::IncompatibleBindGroup {
                index: bind_mask.trailing_zeros(),
            });
        }
        if self.pipeline.is_none() {
            return Err(DispatchError::MissingPipeline);
        }
        self.binder.check_late_buffer_bindings()?;

        Ok(())
    }

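    // Flush the bind groups' (and the optional indirect buffer's) resource
    // usages into the pass trackers, recording the barriers needed by the
    // next dispatch.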
    fn flush_states(
        &mut self,
        raw_encoder: &mut A::CommandEncoder,
        base_trackers: &mut Tracker<A>,
        bind_group_guard: &Storage<BindGroup<A>, id::BindGroupId>,
        buffer_guard: &Storage<Buffer<A>, id::BufferId>,
        texture_guard: &Storage<Texture<A>, id::TextureId>,
        indirect_buffer: Option<id::Valid<id::BufferId>>,
    ) -> Result<(), UsageConflict> {
        for id in self.binder.list_active() {
            unsafe {
                self.scope
                    .merge_bind_group(texture_guard, &bind_group_guard[id].used)?
            };
            // Note: stateless trackers are not merged: the lifetime reference
            // is held to the bind group itself.
        }

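        // Transfer each active bind group's resource states from the usage
        // scope into the pass's trackers.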
        for id in self.binder.list_active() {
            unsafe {
                base_trackers.set_and_remove_from_usage_scope_sparse(
                    texture_guard,
                    &mut self.scope,
                    &bind_group_guard[id].used,
                )
            }
        }

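        // Add the state of the indirect buffer, if present.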
        unsafe {
            base_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(&mut self.scope.buffers, indirect_buffer);
        }

        log::trace!("Encoding dispatch barriers");

        CommandBuffer::drain_barriers(raw_encoder, base_trackers, buffer_guard, texture_guard);
        Ok(())
    }
}

impl<G: GlobalIdentityHandlerFactory> Global<G> {
    pub fn command_encoder_run_compute_pass<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        pass: &ComputePass,
    ) -> Result<(), ComputePassError> {
        self.command_encoder_run_compute_pass_impl::<A>(
            encoder_id,
            pass.base.as_ref(),
            pass.timestamp_writes.as_ref(),
        )
    }

    #[doc(hidden)]
    pub fn command_encoder_run_compute_pass_impl<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePassRef<ComputeCommand>,
        timestamp_writes: Option<&ComputePassTimestampWrites>,
    ) -> Result<(), ComputePassError> {
        profiling::scope!("CommandEncoder::run_compute_pass");
        let init_scope = PassErrorScope::Pass(encoder_id);

        let hub = A::hub(self);
        let mut token = Token::root();

        let (device_guard, mut token) = hub.devices.read(&mut token);

        let (mut cmd_buf_guard, mut token) = hub.command_buffers.write(&mut token);
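        // `get_encoder_mut` fails if the encoder is not in the recording
        // state, so the pass can only run against an open encoder.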
        let cmd_buf: &mut CommandBuffer<A> =
            CommandBuffer::get_encoder_mut(&mut *cmd_buf_guard, encoder_id)
                .map_pass_err(init_scope)?;

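        // We automatically keep extending command buffers over time, and because
        // we want to insert a command buffer _before_ what we're about to record,
        // we need to make sure to close the previous one.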
        cmd_buf.encoder.close();
        // Will be reset to `Recording` if the pass completes without errors.
        cmd_buf.status = CommandEncoderStatus::Error;
        let raw = cmd_buf.encoder.open();

        let device = &device_guard[cmd_buf.device_id.value];
        if !device.is_valid() {
            return Err(ComputePassErrorInner::InvalidDevice(
                cmd_buf.device_id.value.0,
            ))
            .map_pass_err(init_scope);
        }

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf.commands {
            list.push(crate::device::trace::Command::RunComputePass {
                base: BasePass::from_ref(base),
                timestamp_writes: timestamp_writes.cloned(),
            });
        }

        let (_, mut token) = hub.render_bundles.read(&mut token);
        let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
        let (bind_group_guard, mut token) = hub.bind_groups.read(&mut token);
        let (pipeline_guard, mut token) = hub.compute_pipelines.read(&mut token);
        let (query_set_guard, mut token) = hub.query_sets.read(&mut token);
        let (bind_group_layout_guard, mut token) = hub.bind_group_layouts.read(&mut token);
        let (buffer_guard, mut token) = hub.buffers.read(&mut token);
        let (texture_guard, _) = hub.textures.read(&mut token);

        let mut state = State {
            binder: Binder::new(),
            pipeline: None,
            scope: UsageScope::new(&*buffer_guard, &*texture_guard),
            debug_scope_depth: 0,
        };
        let mut temp_offsets = Vec::new();
        let mut dynamic_offset_count = 0;
        let mut string_offset = 0;
        let mut active_query = None;

        let timestamp_writes = if let Some(tw) = timestamp_writes {
            let query_set: &resource::QuerySet<A> = cmd_buf
                .trackers
                .query_sets
                .add_single(&*query_set_guard, tw.query_set)
                .ok_or(ComputePassErrorInner::InvalidQuerySet(tw.query_set))
                .map_pass_err(init_scope)?;

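            // Unlike in render passes we can't delay resetting the query sets
            // since there is no auxiliary pass.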
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
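            // The range should always be `Some`: both indices being `None`
            // should be caught as a validation error. No point in erroring
            // over that nuance here, though.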
            if let Some(range) = range {
                unsafe {
                    raw.reset_queries(&query_set.raw, range);
                }
            }

            Some(hal::ComputePassTimestampWrites {
                query_set: &query_set.raw,
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

        cmd_buf.trackers.set_size(
            Some(&*buffer_guard),
            Some(&*texture_guard),
            None,
            None,
            Some(&*bind_group_guard),
            Some(&*pipeline_guard),
            None,
            None,
            Some(&*query_set_guard),
        );

        let discard_hal_labels = self
            .instance
            .flags
            .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS);
        let hal_desc = hal::ComputePassDescriptor {
            label: hal_label(base.label, self.instance.flags),
            timestamp_writes,
        };

        unsafe {
            raw.begin_compute_pass(&hal_desc);
        }

        let mut intermediate_trackers = Tracker::<A>::new();

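        // Immediate texture inits required because of prior discards. Need to
        // be inserted before texture reads.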
        let mut pending_discard_init_fixups = SurfacesInDiscardState::new();

        for command in base.commands {
            match *command {
                ComputeCommand::SetBindGroup {
                    index,
                    num_dynamic_offsets,
                    bind_group_id,
                } => {
                    let scope = PassErrorScope::SetBindGroup(bind_group_id);

                    let max_bind_groups = cmd_buf.limits.max_bind_groups;
                    if index >= max_bind_groups {
                        return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
                            index,
                            max: max_bind_groups,
                        })
                        .map_pass_err(scope);
                    }

                    temp_offsets.clear();
                    temp_offsets.extend_from_slice(
                        &base.dynamic_offsets[dynamic_offset_count
                            ..dynamic_offset_count + (num_dynamic_offsets as usize)],
                    );
                    dynamic_offset_count += num_dynamic_offsets as usize;

                    let bind_group: &BindGroup<A> = cmd_buf
                        .trackers
                        .bind_groups
                        .add_single(&*bind_group_guard, bind_group_id)
                        .ok_or(ComputePassErrorInner::InvalidBindGroup(bind_group_id))
                        .map_pass_err(scope)?;
                    bind_group
                        .validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits)
                        .map_pass_err(scope)?;

                    cmd_buf.buffer_memory_init_actions.extend(
                        bind_group.used_buffer_ranges.iter().filter_map(
                            |action| match buffer_guard.get(action.id) {
                                Ok(buffer) => buffer.initialization_status.check_action(action),
                                Err(_) => None,
                            },
                        ),
                    );

                    for action in bind_group.used_texture_ranges.iter() {
                        pending_discard_init_fixups.extend(
                            cmd_buf
                                .texture_memory_actions
                                .register_init_action(action, &texture_guard),
                        );
                    }

                    let pipeline_layout_id = state.binder.pipeline_layout_id;
                    let entries = state.binder.assign_group(
                        index as usize,
                        id::Valid(bind_group_id),
                        bind_group,
                        &temp_offsets,
                    );
                    if !entries.is_empty() {
                        let pipeline_layout =
                            &pipeline_layout_guard[pipeline_layout_id.unwrap()].raw;
                        for (i, e) in entries.iter().enumerate() {
                            let raw_bg = &bind_group_guard[e.group_id.as_ref().unwrap().value].raw;
                            unsafe {
                                raw.set_bind_group(
                                    pipeline_layout,
                                    index + i as u32,
                                    raw_bg,
                                    &e.dynamic_offsets,
                                );
                            }
                        }
                    }
                }
                ComputeCommand::SetPipeline(pipeline_id) => {
                    let scope = PassErrorScope::SetPipelineCompute(pipeline_id);

                    state.pipeline = Some(pipeline_id);

                    let pipeline: &pipeline::ComputePipeline<A> = cmd_buf
                        .trackers
                        .compute_pipelines
                        .add_single(&*pipeline_guard, pipeline_id)
                        .ok_or(ComputePassErrorInner::InvalidPipeline(pipeline_id))
                        .map_pass_err(scope)?;

                    unsafe {
                        raw.set_compute_pipeline(&pipeline.raw);
                    }

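                    // Rebind resources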
                    if state.binder.pipeline_layout_id != Some(pipeline.layout_id.value) {
                        let pipeline_layout = &pipeline_layout_guard[pipeline.layout_id.value];

                        let (start_index, entries) = state.binder.change_pipeline_layout(
                            &*pipeline_layout_guard,
                            pipeline.layout_id.value,
                            &pipeline.late_sized_buffer_groups,
                        );
                        if !entries.is_empty() {
                            for (i, e) in entries.iter().enumerate() {
                                let raw_bg =
                                    &bind_group_guard[e.group_id.as_ref().unwrap().value].raw;
                                unsafe {
                                    raw.set_bind_group(
                                        &pipeline_layout.raw,
                                        start_index as u32 + i as u32,
                                        raw_bg,
                                        &e.dynamic_offsets,
                                    );
                                }
                            }
                        }

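                        // Clear push constant ranges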
                        let non_overlapping = super::bind::compute_nonoverlapping_ranges(
                            &pipeline_layout.push_constant_ranges,
                        );
                        for range in non_overlapping {
                            let offset = range.range.start;
                            let size_bytes = range.range.end - offset;
                            super::push_constant_clear(
                                offset,
                                size_bytes,
                                |clear_offset, clear_data| unsafe {
                                    raw.set_push_constants(
                                        &pipeline_layout.raw,
                                        wgt::ShaderStages::COMPUTE,
                                        clear_offset,
                                        clear_data,
                                    );
                                },
                            );
                        }
                    }
                }
                ComputeCommand::SetPushConstant {
                    offset,
                    size_bytes,
                    values_offset,
                } => {
                    let scope = PassErrorScope::SetPushConstant;

                    let end_offset_bytes = offset + size_bytes;
                    let values_end_offset =
                        (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                    let data_slice =
                        &base.push_constant_data[(values_offset as usize)..values_end_offset];

                    let pipeline_layout_id = state
                        .binder
                        .pipeline_layout_id
                        .ok_or(ComputePassErrorInner::Dispatch(
                            DispatchError::MissingPipeline,
                        ))
                        .map_pass_err(scope)?;
                    let pipeline_layout = &pipeline_layout_guard[pipeline_layout_id];

                    pipeline_layout
                        .validate_push_constant_ranges(
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            end_offset_bytes,
                        )
                        .map_pass_err(scope)?;

                    unsafe {
                        raw.set_push_constants(
                            &pipeline_layout.raw,
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            data_slice,
                        );
                    }
                }
                ComputeCommand::Dispatch(groups) => {
                    let scope = PassErrorScope::Dispatch {
                        indirect: false,
                        pipeline: state.pipeline,
                    };

                    state
                        .is_ready(&*bind_group_layout_guard)
                        .map_pass_err(scope)?;
                    state
                        .flush_states(
                            raw,
                            &mut intermediate_trackers,
                            &*bind_group_guard,
                            &*buffer_guard,
                            &*texture_guard,
                            None,
                        )
                        .map_pass_err(scope)?;

                    let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension;

                    if groups[0] > groups_size_limit
                        || groups[1] > groups_size_limit
                        || groups[2] > groups_size_limit
                    {
                        return Err(ComputePassErrorInner::Dispatch(
                            DispatchError::InvalidGroupSize {
                                current: groups,
                                limit: groups_size_limit,
                            },
                        ))
                        .map_pass_err(scope);
                    }

                    unsafe {
                        raw.dispatch(groups);
                    }
                }
                ComputeCommand::DispatchIndirect { buffer_id, offset } => {
                    let scope = PassErrorScope::Dispatch {
                        indirect: true,
                        pipeline: state.pipeline,
                    };

                    state
                        .is_ready(&*bind_group_layout_guard)
                        .map_pass_err(scope)?;

                    device
                        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
                        .map_pass_err(scope)?;

                    let indirect_buffer: &Buffer<A> = state
                        .scope
                        .buffers
                        .merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDIRECT)
                        .map_pass_err(scope)?;
                    check_buffer_usage(indirect_buffer.usage, wgt::BufferUsages::INDIRECT)
                        .map_pass_err(scope)?;

                    let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
                    if end_offset > indirect_buffer.size {
                        return Err(ComputePassErrorInner::IndirectBufferOverrun {
                            offset,
                            end_offset,
                            buffer_size: indirect_buffer.size,
                        })
                        .map_pass_err(scope);
                    }

                    let buf_raw = indirect_buffer
                        .raw
                        .as_ref()
                        .ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
                        .map_pass_err(scope)?;

                    let stride = 3 * 4; // 3 integers: x/y/z group size

                    cmd_buf.buffer_memory_init_actions.extend(
                        indirect_buffer.initialization_status.create_action(
                            buffer_id,
                            offset..(offset + stride),
                            MemoryInitKind::NeedsInitializedMemory,
                        ),
                    );

                    state
                        .flush_states(
                            raw,
                            &mut intermediate_trackers,
                            &*bind_group_guard,
                            &*buffer_guard,
                            &*texture_guard,
                            Some(id::Valid(buffer_id)),
                        )
                        .map_pass_err(scope)?;
                    unsafe {
                        raw.dispatch_indirect(buf_raw, offset);
                    }
                }
                ComputeCommand::PushDebugGroup { color: _, len } => {
                    state.debug_scope_depth += 1;
                    if !discard_hal_labels {
                        let label =
                            str::from_utf8(&base.string_data[string_offset..string_offset + len])
                                .unwrap();
                        unsafe {
                            raw.begin_debug_marker(label);
                        }
                    }
                    string_offset += len;
                }
                ComputeCommand::PopDebugGroup => {
                    let scope = PassErrorScope::PopDebugGroup;

                    if state.debug_scope_depth == 0 {
                        return Err(ComputePassErrorInner::InvalidPopDebugGroup)
                            .map_pass_err(scope);
                    }
                    state.debug_scope_depth -= 1;
                    if !discard_hal_labels {
                        unsafe {
                            raw.end_debug_marker();
                        }
                    }
                }
                ComputeCommand::InsertDebugMarker { color: _, len } => {
                    if !discard_hal_labels {
                        let label =
                            str::from_utf8(&base.string_data[string_offset..string_offset + len])
                                .unwrap();
                        unsafe { raw.insert_debug_marker(label) }
                    }
                    string_offset += len;
                }
                ComputeCommand::WriteTimestamp {
                    query_set_id,
                    query_index,
                } => {
                    let scope = PassErrorScope::WriteTimestamp;

                    device
                        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
                        .map_pass_err(scope)?;

                    let query_set: &resource::QuerySet<A> = cmd_buf
                        .trackers
                        .query_sets
                        .add_single(&*query_set_guard, query_set_id)
                        .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
                        .map_pass_err(scope)?;

                    query_set
                        .validate_and_write_timestamp(raw, query_set_id, query_index, None)
                        .map_pass_err(scope)?;
                }
                ComputeCommand::BeginPipelineStatisticsQuery {
                    query_set_id,
                    query_index,
                } => {
                    let scope = PassErrorScope::BeginPipelineStatisticsQuery;

                    let query_set: &resource::QuerySet<A> = cmd_buf
                        .trackers
                        .query_sets
                        .add_single(&*query_set_guard, query_set_id)
                        .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
                        .map_pass_err(scope)?;

                    query_set
                        .validate_and_begin_pipeline_statistics_query(
                            raw,
                            query_set_id,
                            query_index,
                            None,
                            &mut active_query,
                        )
                        .map_pass_err(scope)?;
                }
                ComputeCommand::EndPipelineStatisticsQuery => {
                    let scope = PassErrorScope::EndPipelineStatisticsQuery;

                    end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
                        .map_pass_err(scope)?;
                }
            }
        }

        unsafe {
            raw.end_compute_pass();
        }
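
        // The pass recorded without errors; bring the encoder back into the
        // recording state.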
        cmd_buf.status = CommandEncoderStatus::Recording;

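        // Stop the current command buffer.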
        cmd_buf.encoder.close();

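        // Create a new command buffer, which we will insert _before_ the pass
        // we just recorded: it receives the pending discard fixups and the
        // barriers the pass requires.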
        let transit = cmd_buf.encoder.open();
        fixup_discarded_surfaces(
            pending_discard_init_fixups.into_iter(),
            transit,
            &texture_guard,
            &mut cmd_buf.trackers.textures,
            device,
        );
        CommandBuffer::insert_barriers_from_tracker(
            transit,
            &mut cmd_buf.trackers,
            &intermediate_trackers,
            &*buffer_guard,
            &*texture_guard,
        );
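        // Close the transit buffer and swap it in before the pass's command
        // buffer, so the barriers execute first.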
        cmd_buf.encoder.close_and_swap();

        Ok(())
    }
}

pub mod compute_ffi {
    use super::{ComputeCommand, ComputePass};
    use crate::{id, RawString};
    use std::{convert::TryInto, ffi, slice};
    use wgt::{BufferAddress, DynamicOffset};

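    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is
    /// valid for `offset_length` elements.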
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_set_bind_group(
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: id::BindGroupId,
        offsets: *const DynamicOffset,
        offset_length: usize,
    ) {
        let redundant = unsafe {
            pass.current_bind_groups.set_and_check_redundant(
                bind_group_id,
                index,
                &mut pass.base.dynamic_offsets,
                offsets,
                offset_length,
            )
        };

        if redundant {
            return;
        }

        pass.base.commands.push(ComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offset_length.try_into().unwrap(),
            bind_group_id,
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_set_pipeline(
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) {
        if pass.current_pipeline.set_and_check_redundant(pipeline_id) {
            return;
        }

        pass.base
            .commands
            .push(ComputeCommand::SetPipeline(pipeline_id));
    }

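    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is
    /// valid for `size_bytes` bytes.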
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_set_push_constant(
        pass: &mut ComputePass,
        offset: u32,
        size_bytes: u32,
        data: *const u8,
    ) {
        assert_eq!(
            offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant offset must be aligned to 4 bytes."
        );
        assert_eq!(
            size_bytes & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant size must be aligned to 4 bytes."
        );
        let data_slice = unsafe { slice::from_raw_parts(data, size_bytes as usize) };
        let value_offset = pass.base.push_constant_data.len().try_into().expect(
            "Ran out of push constant space. Don't set 4gb of push constants per ComputePass.",
        );

        pass.base.push_constant_data.extend(
            data_slice
                .chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        pass.base.commands.push(ComputeCommand::SetPushConstant {
            offset,
            size_bytes,
            values_offset: value_offset,
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_dispatch_workgroups(
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_dispatch_workgroups_indirect(
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::DispatchIndirect { buffer_id, offset });
    }

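    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given `label`
    /// is a valid null-terminated string.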
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_push_debug_group(
        pass: &mut ComputePass,
        label: RawString,
        color: u32,
    ) {
        let bytes = unsafe { ffi::CStr::from_ptr(label) }.to_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) {
        pass.base.commands.push(ComputeCommand::PopDebugGroup);
    }

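    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given `label`
    /// is a valid null-terminated string.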
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_insert_debug_marker(
        pass: &mut ComputePass,
        label: RawString,
        color: u32,
    ) {
        let bytes = unsafe { ffi::CStr::from_ptr(label) }.to_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_write_timestamp(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base.commands.push(ComputeCommand::WriteTimestamp {
            query_set_id,
            query_index,
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_begin_pipeline_statistics_query(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::BeginPipelineStatisticsQuery {
                query_set_id,
                query_index,
            });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_end_pipeline_statistics_query(pass: &mut ComputePass) {
        pass.base
            .commands
            .push(ComputeCommand::EndPipelineStatisticsQuery);
    }
}