wgpu_core/command/compute.rs

use crate::{
    binding_model::{
        BindError, BindGroup, BindGroupLayouts, LateMinBufferBindingSizeMismatch,
        PushConstantUploadError,
    },
    command::{
        bind::Binder,
        end_pipeline_statistics_query,
        memory_init::{fixup_discarded_surfaces, SurfacesInDiscardState},
        BasePass, BasePassRef, BindGroupStateChange, CommandBuffer, CommandEncoderError,
        CommandEncoderStatus, MapPassErr, PassErrorScope, QueryUseError, StateChange,
    },
    device::{MissingDownlevelFlags, MissingFeatures},
    error::{ErrorFormatter, PrettyError},
    global::Global,
    hal_api::HalApi,
    hal_label,
    hub::Token,
    id,
    id::DeviceId,
    identity::GlobalIdentityHandlerFactory,
    init_tracker::MemoryInitKind,
    pipeline,
    resource::{self, Buffer, Texture},
    storage::Storage,
    track::{Tracker, UsageConflict, UsageScope},
    validation::{check_buffer_usage, MissingBufferUsageError},
    Label,
};

use hal::CommandEncoder as _;
#[cfg(any(feature = "serial-pass", feature = "replay"))]
use serde::Deserialize;
#[cfg(any(feature = "serial-pass", feature = "trace"))]
use serde::Serialize;

use thiserror::Error;

use std::{fmt, mem, str};

#[doc(hidden)]
#[derive(Clone, Copy, Debug)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "trace"),
    derive(serde::Serialize)
)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "replay"),
    derive(serde::Deserialize)
)]
pub enum ComputeCommand {
    SetBindGroup {
        index: u32,
        num_dynamic_offsets: u8,
        bind_group_id: id::BindGroupId,
    },
    SetPipeline(id::ComputePipelineId),

    /// Set a range of push constants to values stored in [`BasePass::push_constant_data`].
    SetPushConstant {
        /// The byte offset within the push constant storage to write to. This
        /// must be a multiple of four.
        offset: u32,

        /// The number of bytes to write. This must be a multiple of four.
        size_bytes: u32,

        /// Index in [`BasePass::push_constant_data`] of the start of the data
        /// to be written.
        ///
        /// Note: this is not a byte offset like `offset`. Rather, it is the
        /// index of the first `u32` element in `push_constant_data` to read.
        values_offset: u32,
    },

    Dispatch([u32; 3]),
    DispatchIndirect {
        buffer_id: id::BufferId,
        offset: wgt::BufferAddress,
    },
    PushDebugGroup {
        color: u32,
        len: usize,
    },
    PopDebugGroup,
    InsertDebugMarker {
        color: u32,
        len: usize,
    },
    WriteTimestamp {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },
    BeginPipelineStatisticsQuery {
        query_set_id: id::QuerySetId,
        query_index: u32,
    },
    EndPipelineStatisticsQuery,
}

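// Worked example (editorial note, not part of the original source): with
// `push_constant_data = [a, b, c, d, e]`, the command
//
//     ComputeCommand::SetPushConstant { offset: 8, size_bytes: 8, values_offset: 3 }
//
// writes elements `d` and `e` (`push_constant_data[3..5]`) to push-constant
// bytes 8..16. `offset` and `size_bytes` count bytes, while `values_offset`
// counts `u32` elements.
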
#[cfg_attr(feature = "serial-pass", derive(serde::Deserialize, serde::Serialize))]
pub struct ComputePass {
    base: BasePass<ComputeCommand>,
    parent_id: id::CommandEncoderId,
    timestamp_writes: Option<ComputePassTimestampWrites>,

    // Resource binding dedupe state.
    #[cfg_attr(feature = "serial-pass", serde(skip))]
    current_bind_groups: BindGroupStateChange,
    #[cfg_attr(feature = "serial-pass", serde(skip))]
    current_pipeline: StateChange<id::ComputePipelineId>,
}

impl ComputePass {
    pub fn new(parent_id: id::CommandEncoderId, desc: &ComputePassDescriptor) -> Self {
        Self {
            base: BasePass::new(&desc.label),
            parent_id,
            timestamp_writes: desc.timestamp_writes.cloned(),

            current_bind_groups: BindGroupStateChange::new(),
            current_pipeline: StateChange::new(),
        }
    }

    pub fn parent_id(&self) -> id::CommandEncoderId {
        self.parent_id
    }

    #[cfg(feature = "trace")]
    pub fn into_command(self) -> crate::device::trace::Command {
        crate::device::trace::Command::RunComputePass {
            base: self.base,
            timestamp_writes: self.timestamp_writes,
        }
    }
}

impl fmt::Debug for ComputePass {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(
            f,
            "ComputePass {{ encoder_id: {:?}, data: {:?} commands and {:?} dynamic offsets }}",
            self.parent_id,
            self.base.commands.len(),
            self.base.dynamic_offsets.len()
        )
    }
}

/// Describes the writing of timestamp values in a compute pass.
#[repr(C)]
#[derive(Clone, Debug, PartialEq, Eq)]
#[cfg_attr(any(feature = "serial-pass", feature = "trace"), derive(Serialize))]
#[cfg_attr(any(feature = "serial-pass", feature = "replay"), derive(Deserialize))]
pub struct ComputePassTimestampWrites {
    /// The query set to write the timestamps to.
    pub query_set: id::QuerySetId,
    /// The index of the query set at which a start timestamp of this pass is written, if any.
    pub beginning_of_pass_write_index: Option<u32>,
    /// The index of the query set at which an end timestamp of this pass is written, if any.
    pub end_of_pass_write_index: Option<u32>,
}

#[derive(Clone, Debug, Default)]
pub struct ComputePassDescriptor<'a> {
    pub label: Label<'a>,
    /// Defines where and when timestamp values will be written for this pass.
    pub timestamp_writes: Option<&'a ComputePassTimestampWrites>,
}

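// Usage sketch (editorial; `encoder_id` and `query_set_id` are assumed to have
// been created elsewhere):
//
//     let timestamp_writes = ComputePassTimestampWrites {
//         query_set: query_set_id,
//         beginning_of_pass_write_index: Some(0),
//         end_of_pass_write_index: Some(1),
//     };
//     let desc = ComputePassDescriptor {
//         label: Some(std::borrow::Cow::Borrowed("example pass")),
//         timestamp_writes: Some(&timestamp_writes),
//     };
//     let pass = ComputePass::new(encoder_id, &desc);
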
#[derive(Clone, Debug, Error, Eq, PartialEq)]
#[non_exhaustive]
pub enum DispatchError {
    #[error("Compute pipeline must be set")]
    MissingPipeline,
    #[error("The pipeline layout, associated with the current compute pipeline, contains a bind group layout at index {index} which is incompatible with the bind group layout associated with the bind group at index {index}")]
    IncompatibleBindGroup {
        index: u32,
        //expected: BindGroupLayoutId,
        //provided: Option<(BindGroupLayoutId, BindGroupId)>,
    },
    #[error(
        "Each current dispatch group size dimension ({current:?}) must be less than or equal to {limit}"
    )]
    InvalidGroupSize { current: [u32; 3], limit: u32 },
    #[error(transparent)]
    BindingSizeTooSmall(#[from] LateMinBufferBindingSizeMismatch),
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
pub enum ComputePassErrorInner {
    #[error(transparent)]
    Encoder(#[from] CommandEncoderError),
    #[error("Bind group {0:?} is invalid")]
    InvalidBindGroup(id::BindGroupId),
    #[error("Device {0:?} is invalid")]
    InvalidDevice(DeviceId),
    #[error("Bind group index {index} is greater than the device's requested `max_bind_groups` limit {max}")]
    BindGroupIndexOutOfRange { index: u32, max: u32 },
    #[error("Compute pipeline {0:?} is invalid")]
    InvalidPipeline(id::ComputePipelineId),
    #[error("QuerySet {0:?} is invalid")]
    InvalidQuerySet(id::QuerySetId),
    #[error("Indirect buffer {0:?} is invalid or destroyed")]
    InvalidIndirectBuffer(id::BufferId),
    #[error("Indirect buffer uses bytes {offset}..{end_offset} which overruns indirect buffer of size {buffer_size}")]
    IndirectBufferOverrun {
        offset: u64,
        end_offset: u64,
        buffer_size: u64,
    },
    #[error("Buffer {0:?} is invalid or destroyed")]
    InvalidBuffer(id::BufferId),
    #[error(transparent)]
    ResourceUsageConflict(#[from] UsageConflict),
    #[error(transparent)]
    MissingBufferUsage(#[from] MissingBufferUsageError),
    #[error("Cannot pop debug group, because the number of pushed debug groups is zero")]
    InvalidPopDebugGroup,
    #[error(transparent)]
    Dispatch(#[from] DispatchError),
    #[error(transparent)]
    Bind(#[from] BindError),
    #[error(transparent)]
    PushConstants(#[from] PushConstantUploadError),
    #[error(transparent)]
    QueryUse(#[from] QueryUseError),
    #[error(transparent)]
    MissingFeatures(#[from] MissingFeatures),
    #[error(transparent)]
    MissingDownlevelFlags(#[from] MissingDownlevelFlags),
}

impl PrettyError for ComputePassErrorInner {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
        fmt.error(self);
        match *self {
            Self::InvalidBindGroup(id) => {
                fmt.bind_group_label(&id);
            }
            Self::InvalidPipeline(id) => {
                fmt.compute_pipeline_label(&id);
            }
            Self::InvalidIndirectBuffer(id) => {
                fmt.buffer_label(&id);
            }
            _ => {}
        };
    }
}

/// Error encountered when performing a compute pass.
#[derive(Clone, Debug, Error)]
#[error("{scope}")]
pub struct ComputePassError {
    pub scope: PassErrorScope,
    #[source]
    inner: ComputePassErrorInner,
}
impl PrettyError for ComputePassError {
    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
        // This error is a wrapper for the inner error, but the scope carries
        // useful labels.
        fmt.error(self);
        self.scope.fmt_pretty(fmt);
    }
}

impl<T, E> MapPassErr<T, ComputePassError> for Result<T, E>
where
    E: Into<ComputePassErrorInner>,
{
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, ComputePassError> {
        self.map_err(|inner| ComputePassError {
            scope,
            inner: inner.into(),
        })
    }
}

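/// Mutable state accumulated while recording a compute pass: the bind group
/// `binder`, the currently set `pipeline`, the resource usage `scope`, and the
/// nesting `debug_scope_depth` of debug groups. (This doc comment is an
/// editorial addition; the field roles are inferred from their use below.)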
struct State<A: HalApi> {
    binder: Binder,
    pipeline: Option<id::ComputePipelineId>,
    scope: UsageScope<A>,
    debug_scope_depth: u32,
}

impl<A: HalApi> State<A> {
    fn is_ready(&self, bind_group_layouts: &BindGroupLayouts<A>) -> Result<(), DispatchError> {
        let bind_mask = self.binder.invalid_mask(bind_group_layouts);
        if bind_mask != 0 {
            //let (expected, provided) = self.binder.entries[index as usize].info();
            return Err(DispatchError::IncompatibleBindGroup {
                index: bind_mask.trailing_zeros(),
            });
        }
        if self.pipeline.is_none() {
            return Err(DispatchError::MissingPipeline);
        }
        self.binder.check_late_buffer_bindings()?;

        Ok(())
    }

    // `indirect_buffer` represents the indirect buffer that is also part of
    // the usage scope.
    fn flush_states(
        &mut self,
        raw_encoder: &mut A::CommandEncoder,
        base_trackers: &mut Tracker<A>,
        bind_group_guard: &Storage<BindGroup<A>, id::BindGroupId>,
        buffer_guard: &Storage<Buffer<A>, id::BufferId>,
        texture_guard: &Storage<Texture<A>, id::TextureId>,
        indirect_buffer: Option<id::Valid<id::BufferId>>,
    ) -> Result<(), UsageConflict> {
        for id in self.binder.list_active() {
            unsafe {
                self.scope
                    .merge_bind_group(texture_guard, &bind_group_guard[id].used)?
            };
            // Note: stateless trackers are not merged: the lifetime reference
            // is held to the bind group itself.
        }

        for id in self.binder.list_active() {
            unsafe {
                base_trackers.set_and_remove_from_usage_scope_sparse(
                    texture_guard,
                    &mut self.scope,
                    &bind_group_guard[id].used,
                )
            }
        }

        // Add the state of the indirect buffer if it hasn't been hit before.
        unsafe {
            base_trackers
                .buffers
                .set_and_remove_from_usage_scope_sparse(&mut self.scope.buffers, indirect_buffer);
        }

        log::trace!("Encoding dispatch barriers");

        CommandBuffer::drain_barriers(raw_encoder, base_trackers, buffer_guard, texture_guard);
        Ok(())
    }
}

// Common routines between render/compute

impl<G: GlobalIdentityHandlerFactory> Global<G> {
    pub fn command_encoder_run_compute_pass<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        pass: &ComputePass,
    ) -> Result<(), ComputePassError> {
        self.command_encoder_run_compute_pass_impl::<A>(
            encoder_id,
            pass.base.as_ref(),
            pass.timestamp_writes.as_ref(),
        )
    }

    #[doc(hidden)]
    pub fn command_encoder_run_compute_pass_impl<A: HalApi>(
        &self,
        encoder_id: id::CommandEncoderId,
        base: BasePassRef<ComputeCommand>,
        timestamp_writes: Option<&ComputePassTimestampWrites>,
    ) -> Result<(), ComputePassError> {
        profiling::scope!("CommandEncoder::run_compute_pass");
        let init_scope = PassErrorScope::Pass(encoder_id);

        let hub = A::hub(self);
        let mut token = Token::root();

        let (device_guard, mut token) = hub.devices.read(&mut token);

        let (mut cmd_buf_guard, mut token) = hub.command_buffers.write(&mut token);
        // Spell out the type, to placate rust-analyzer.
        // https://github.com/rust-lang/rust-analyzer/issues/12247
        let cmd_buf: &mut CommandBuffer<A> =
            CommandBuffer::get_encoder_mut(&mut *cmd_buf_guard, encoder_id)
                .map_pass_err(init_scope)?;

        // We automatically keep extending command buffers over time, and because
        // we want to insert a command buffer _before_ what we're about to record,
        // we need to make sure to close the previous one.
        cmd_buf.encoder.close();
        // We will reset this to `Recording` if we succeed; it acts as a fail-safe.
        cmd_buf.status = CommandEncoderStatus::Error;
        let raw = cmd_buf.encoder.open();

        let device = &device_guard[cmd_buf.device_id.value];
        if !device.is_valid() {
            return Err(ComputePassErrorInner::InvalidDevice(
                cmd_buf.device_id.value.0,
            ))
            .map_pass_err(init_scope);
        }

        #[cfg(feature = "trace")]
        if let Some(ref mut list) = cmd_buf.commands {
            list.push(crate::device::trace::Command::RunComputePass {
                base: BasePass::from_ref(base),
                timestamp_writes: timestamp_writes.cloned(),
            });
        }

        let (_, mut token) = hub.render_bundles.read(&mut token);
        let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
        let (bind_group_guard, mut token) = hub.bind_groups.read(&mut token);
        let (pipeline_guard, mut token) = hub.compute_pipelines.read(&mut token);
        let (query_set_guard, mut token) = hub.query_sets.read(&mut token);
        let (bind_group_layout_guard, mut token) = hub.bind_group_layouts.read(&mut token);
        let (buffer_guard, mut token) = hub.buffers.read(&mut token);
        let (texture_guard, _) = hub.textures.read(&mut token);

        let mut state = State {
            binder: Binder::new(),
            pipeline: None,
            scope: UsageScope::new(&*buffer_guard, &*texture_guard),
            debug_scope_depth: 0,
        };
        let mut temp_offsets = Vec::new();
        let mut dynamic_offset_count = 0;
        let mut string_offset = 0;
        let mut active_query = None;

        let timestamp_writes = if let Some(tw) = timestamp_writes {
            let query_set: &resource::QuerySet<A> = cmd_buf
                .trackers
                .query_sets
                .add_single(&*query_set_guard, tw.query_set)
                .ok_or(ComputePassErrorInner::InvalidQuerySet(tw.query_set))
                .map_pass_err(init_scope)?;

            // Unlike in render passes, we can't delay resetting the query sets
            // since there is no auxiliary pass.
            let range = if let (Some(index_a), Some(index_b)) =
                (tw.beginning_of_pass_write_index, tw.end_of_pass_write_index)
            {
                Some(index_a.min(index_b)..index_a.max(index_b) + 1)
            } else {
                tw.beginning_of_pass_write_index
                    .or(tw.end_of_pass_write_index)
                    .map(|i| i..i + 1)
            };
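            // Worked example (editorial): a begin index of 5 and an end index
            // of 2 reset queries 2..6; a begin index of 7 alone resets 7..8.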
            // `range` should always be `Some`; both indices being `None` should
            // lead to a validation error. But there is no point in erroring over
            // that nuance here!
            if let Some(range) = range {
                unsafe {
                    raw.reset_queries(&query_set.raw, range);
                }
            }

            Some(hal::ComputePassTimestampWrites {
                query_set: &query_set.raw,
                beginning_of_pass_write_index: tw.beginning_of_pass_write_index,
                end_of_pass_write_index: tw.end_of_pass_write_index,
            })
        } else {
            None
        };

        cmd_buf.trackers.set_size(
            Some(&*buffer_guard),
            Some(&*texture_guard),
            None,
            None,
            Some(&*bind_group_guard),
            Some(&*pipeline_guard),
            None,
            None,
            Some(&*query_set_guard),
        );

        let discard_hal_labels = self
            .instance
            .flags
            .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS);
        let hal_desc = hal::ComputePassDescriptor {
            label: hal_label(base.label, self.instance.flags),
            timestamp_writes,
        };

        unsafe {
            raw.begin_compute_pass(&hal_desc);
        }

        let mut intermediate_trackers = Tracker::<A>::new();

        // Immediate texture inits required because of prior discards. These
        // need to be inserted before texture reads.
        let mut pending_discard_init_fixups = SurfacesInDiscardState::new();

        for command in base.commands {
            match *command {
                ComputeCommand::SetBindGroup {
                    index,
                    num_dynamic_offsets,
                    bind_group_id,
                } => {
                    let scope = PassErrorScope::SetBindGroup(bind_group_id);

                    let max_bind_groups = cmd_buf.limits.max_bind_groups;
                    if index >= max_bind_groups {
                        return Err(ComputePassErrorInner::BindGroupIndexOutOfRange {
                            index,
                            max: max_bind_groups,
                        })
                        .map_pass_err(scope);
                    }

                    temp_offsets.clear();
                    temp_offsets.extend_from_slice(
                        &base.dynamic_offsets[dynamic_offset_count
                            ..dynamic_offset_count + (num_dynamic_offsets as usize)],
                    );
                    dynamic_offset_count += num_dynamic_offsets as usize;

                    let bind_group: &BindGroup<A> = cmd_buf
                        .trackers
                        .bind_groups
                        .add_single(&*bind_group_guard, bind_group_id)
                        .ok_or(ComputePassErrorInner::InvalidBindGroup(bind_group_id))
                        .map_pass_err(scope)?;
                    bind_group
                        .validate_dynamic_bindings(index, &temp_offsets, &cmd_buf.limits)
                        .map_pass_err(scope)?;

                    cmd_buf.buffer_memory_init_actions.extend(
                        bind_group.used_buffer_ranges.iter().filter_map(
                            |action| match buffer_guard.get(action.id) {
                                Ok(buffer) => buffer.initialization_status.check_action(action),
                                Err(_) => None,
                            },
                        ),
                    );

                    for action in bind_group.used_texture_ranges.iter() {
                        pending_discard_init_fixups.extend(
                            cmd_buf
                                .texture_memory_actions
                                .register_init_action(action, &texture_guard),
                        );
                    }

                    let pipeline_layout_id = state.binder.pipeline_layout_id;
                    let entries = state.binder.assign_group(
                        index as usize,
                        id::Valid(bind_group_id),
                        bind_group,
                        &temp_offsets,
                    );
                    if !entries.is_empty() {
                        let pipeline_layout =
                            &pipeline_layout_guard[pipeline_layout_id.unwrap()].raw;
                        for (i, e) in entries.iter().enumerate() {
                            let raw_bg = &bind_group_guard[e.group_id.as_ref().unwrap().value].raw;
                            unsafe {
                                raw.set_bind_group(
                                    pipeline_layout,
                                    index + i as u32,
                                    raw_bg,
                                    &e.dynamic_offsets,
                                );
                            }
                        }
                    }
                }
                ComputeCommand::SetPipeline(pipeline_id) => {
                    let scope = PassErrorScope::SetPipelineCompute(pipeline_id);

                    state.pipeline = Some(pipeline_id);

                    let pipeline: &pipeline::ComputePipeline<A> = cmd_buf
                        .trackers
                        .compute_pipelines
                        .add_single(&*pipeline_guard, pipeline_id)
                        .ok_or(ComputePassErrorInner::InvalidPipeline(pipeline_id))
                        .map_pass_err(scope)?;

                    unsafe {
                        raw.set_compute_pipeline(&pipeline.raw);
                    }

                    // Rebind resources
                    if state.binder.pipeline_layout_id != Some(pipeline.layout_id.value) {
                        let pipeline_layout = &pipeline_layout_guard[pipeline.layout_id.value];

                        let (start_index, entries) = state.binder.change_pipeline_layout(
                            &*pipeline_layout_guard,
                            pipeline.layout_id.value,
                            &pipeline.late_sized_buffer_groups,
                        );
                        if !entries.is_empty() {
                            for (i, e) in entries.iter().enumerate() {
                                let raw_bg =
                                    &bind_group_guard[e.group_id.as_ref().unwrap().value].raw;
                                unsafe {
                                    raw.set_bind_group(
                                        &pipeline_layout.raw,
                                        start_index as u32 + i as u32,
                                        raw_bg,
                                        &e.dynamic_offsets,
                                    );
                                }
                            }
                        }

                        // Clear push constant ranges
                        let non_overlapping = super::bind::compute_nonoverlapping_ranges(
                            &pipeline_layout.push_constant_ranges,
                        );
                        for range in non_overlapping {
                            let offset = range.range.start;
                            let size_bytes = range.range.end - offset;
                            super::push_constant_clear(
                                offset,
                                size_bytes,
                                |clear_offset, clear_data| unsafe {
                                    raw.set_push_constants(
                                        &pipeline_layout.raw,
                                        wgt::ShaderStages::COMPUTE,
                                        clear_offset,
                                        clear_data,
                                    );
                                },
                            );
                        }
                    }
                }
                ComputeCommand::SetPushConstant {
                    offset,
                    size_bytes,
                    values_offset,
                } => {
                    let scope = PassErrorScope::SetPushConstant;

                    let end_offset_bytes = offset + size_bytes;
                    let values_end_offset =
                        (values_offset + size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT) as usize;
                    let data_slice =
                        &base.push_constant_data[(values_offset as usize)..values_end_offset];

                    let pipeline_layout_id = state
                        .binder
                        .pipeline_layout_id
                        //TODO: don't error here, lazily update the push constants
                        .ok_or(ComputePassErrorInner::Dispatch(
                            DispatchError::MissingPipeline,
                        ))
                        .map_pass_err(scope)?;
                    let pipeline_layout = &pipeline_layout_guard[pipeline_layout_id];

                    pipeline_layout
                        .validate_push_constant_ranges(
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            end_offset_bytes,
                        )
                        .map_pass_err(scope)?;

                    unsafe {
                        raw.set_push_constants(
                            &pipeline_layout.raw,
                            wgt::ShaderStages::COMPUTE,
                            offset,
                            data_slice,
                        );
                    }
                }
                ComputeCommand::Dispatch(groups) => {
                    let scope = PassErrorScope::Dispatch {
                        indirect: false,
                        pipeline: state.pipeline,
                    };

                    state
                        .is_ready(&*bind_group_layout_guard)
                        .map_pass_err(scope)?;
                    state
                        .flush_states(
                            raw,
                            &mut intermediate_trackers,
                            &*bind_group_guard,
                            &*buffer_guard,
                            &*texture_guard,
                            None,
                        )
                        .map_pass_err(scope)?;

                    let groups_size_limit = cmd_buf.limits.max_compute_workgroups_per_dimension;

                    if groups[0] > groups_size_limit
                        || groups[1] > groups_size_limit
                        || groups[2] > groups_size_limit
                    {
                        return Err(ComputePassErrorInner::Dispatch(
                            DispatchError::InvalidGroupSize {
                                current: groups,
                                limit: groups_size_limit,
                            },
                        ))
                        .map_pass_err(scope);
                    }

                    unsafe {
                        raw.dispatch(groups);
                    }
                }
                ComputeCommand::DispatchIndirect { buffer_id, offset } => {
                    let scope = PassErrorScope::Dispatch {
                        indirect: true,
                        pipeline: state.pipeline,
                    };

                    state
                        .is_ready(&*bind_group_layout_guard)
                        .map_pass_err(scope)?;

                    device
                        .require_downlevel_flags(wgt::DownlevelFlags::INDIRECT_EXECUTION)
                        .map_pass_err(scope)?;

                    let indirect_buffer: &Buffer<A> = state
                        .scope
                        .buffers
                        .merge_single(&*buffer_guard, buffer_id, hal::BufferUses::INDIRECT)
                        .map_pass_err(scope)?;
                    check_buffer_usage(indirect_buffer.usage, wgt::BufferUsages::INDIRECT)
                        .map_pass_err(scope)?;

                    let end_offset = offset + mem::size_of::<wgt::DispatchIndirectArgs>() as u64;
                    if end_offset > indirect_buffer.size {
                        return Err(ComputePassErrorInner::IndirectBufferOverrun {
                            offset,
                            end_offset,
                            buffer_size: indirect_buffer.size,
                        })
                        .map_pass_err(scope);
                    }

                    let buf_raw = indirect_buffer
                        .raw
                        .as_ref()
                        .ok_or(ComputePassErrorInner::InvalidIndirectBuffer(buffer_id))
                        .map_pass_err(scope)?;

                    let stride = 3 * 4; // 3 integers, x/y/z group size

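                    // Editorial note: the 12-byte `stride` matches
                    // `mem::size_of::<wgt::DispatchIndirectArgs>()` (three `u32`
                    // workgroup counts) used for the bounds check above, so the
                    // init action below covers exactly the bytes the dispatch reads.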
                    cmd_buf.buffer_memory_init_actions.extend(
                        indirect_buffer.initialization_status.create_action(
                            buffer_id,
                            offset..(offset + stride),
                            MemoryInitKind::NeedsInitializedMemory,
                        ),
                    );

                    state
                        .flush_states(
                            raw,
                            &mut intermediate_trackers,
                            &*bind_group_guard,
                            &*buffer_guard,
                            &*texture_guard,
                            Some(id::Valid(buffer_id)),
                        )
                        .map_pass_err(scope)?;
                    unsafe {
                        raw.dispatch_indirect(buf_raw, offset);
                    }
                }
                ComputeCommand::PushDebugGroup { color: _, len } => {
                    state.debug_scope_depth += 1;
                    if !discard_hal_labels {
                        let label =
                            str::from_utf8(&base.string_data[string_offset..string_offset + len])
                                .unwrap();
                        unsafe {
                            raw.begin_debug_marker(label);
                        }
                    }
                    string_offset += len;
                }
                ComputeCommand::PopDebugGroup => {
                    let scope = PassErrorScope::PopDebugGroup;

                    if state.debug_scope_depth == 0 {
                        return Err(ComputePassErrorInner::InvalidPopDebugGroup)
                            .map_pass_err(scope);
                    }
                    state.debug_scope_depth -= 1;
                    if !discard_hal_labels {
                        unsafe {
                            raw.end_debug_marker();
                        }
                    }
                }
                ComputeCommand::InsertDebugMarker { color: _, len } => {
                    if !discard_hal_labels {
                        let label =
                            str::from_utf8(&base.string_data[string_offset..string_offset + len])
                                .unwrap();
                        unsafe { raw.insert_debug_marker(label) }
                    }
                    string_offset += len;
                }
                ComputeCommand::WriteTimestamp {
                    query_set_id,
                    query_index,
                } => {
                    let scope = PassErrorScope::WriteTimestamp;

                    device
                        .require_features(wgt::Features::TIMESTAMP_QUERY_INSIDE_PASSES)
                        .map_pass_err(scope)?;

                    let query_set: &resource::QuerySet<A> = cmd_buf
                        .trackers
                        .query_sets
                        .add_single(&*query_set_guard, query_set_id)
                        .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
                        .map_pass_err(scope)?;

                    query_set
                        .validate_and_write_timestamp(raw, query_set_id, query_index, None)
                        .map_pass_err(scope)?;
                }
                ComputeCommand::BeginPipelineStatisticsQuery {
                    query_set_id,
                    query_index,
                } => {
                    let scope = PassErrorScope::BeginPipelineStatisticsQuery;

                    let query_set: &resource::QuerySet<A> = cmd_buf
                        .trackers
                        .query_sets
                        .add_single(&*query_set_guard, query_set_id)
                        .ok_or(ComputePassErrorInner::InvalidQuerySet(query_set_id))
                        .map_pass_err(scope)?;

                    query_set
                        .validate_and_begin_pipeline_statistics_query(
                            raw,
                            query_set_id,
                            query_index,
                            None,
                            &mut active_query,
                        )
                        .map_pass_err(scope)?;
                }
                ComputeCommand::EndPipelineStatisticsQuery => {
                    let scope = PassErrorScope::EndPipelineStatisticsQuery;

                    end_pipeline_statistics_query(raw, &*query_set_guard, &mut active_query)
                        .map_pass_err(scope)?;
                }
            }
        }

        unsafe {
            raw.end_compute_pass();
        }
        // We've successfully recorded the compute pass; bring the
        // command buffer out of the error state.
        cmd_buf.status = CommandEncoderStatus::Recording;

        // Stop the current command buffer.
        cmd_buf.encoder.close();

        // Create a new command buffer, which we will insert _before_ the body of the compute pass.
        //
        // Use that buffer to insert barriers and clear discarded images.
        let transit = cmd_buf.encoder.open();
        fixup_discarded_surfaces(
            pending_discard_init_fixups.into_iter(),
            transit,
            &texture_guard,
            &mut cmd_buf.trackers.textures,
            device,
        );
        CommandBuffer::insert_barriers_from_tracker(
            transit,
            &mut cmd_buf.trackers,
            &intermediate_trackers,
            &*buffer_guard,
            &*texture_guard,
        );
        // Close the command buffer, and swap it with the previous one.
        cmd_buf.encoder.close_and_swap();

        Ok(())
    }
}

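// End-to-end sketch (editorial; `global`, `encoder_id`, and `pipeline_id` are
// assumed to exist, and `A` is whichever HAL backend is in use):
//
//     let mut pass = ComputePass::new(encoder_id, &ComputePassDescriptor::default());
//     compute_ffi::wgpu_compute_pass_set_pipeline(&mut pass, pipeline_id);
//     compute_ffi::wgpu_compute_pass_dispatch_workgroups(&mut pass, 64, 1, 1);
//     global.command_encoder_run_compute_pass::<A>(encoder_id, &pass)?;
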
pub mod compute_ffi {
    use super::{ComputeCommand, ComputePass};
    use crate::{id, RawString};
    use std::{convert::TryInto, ffi, slice};
    use wgt::{BufferAddress, DynamicOffset};

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is
    /// valid for `offset_length` elements.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_set_bind_group(
        pass: &mut ComputePass,
        index: u32,
        bind_group_id: id::BindGroupId,
        offsets: *const DynamicOffset,
        offset_length: usize,
    ) {
        let redundant = unsafe {
            pass.current_bind_groups.set_and_check_redundant(
                bind_group_id,
                index,
                &mut pass.base.dynamic_offsets,
                offsets,
                offset_length,
            )
        };

        if redundant {
            return;
        }

        pass.base.commands.push(ComputeCommand::SetBindGroup {
            index,
            num_dynamic_offsets: offset_length.try_into().unwrap(),
            bind_group_id,
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_set_pipeline(
        pass: &mut ComputePass,
        pipeline_id: id::ComputePipelineId,
    ) {
        if pass.current_pipeline.set_and_check_redundant(pipeline_id) {
            return;
        }

        pass.base
            .commands
            .push(ComputeCommand::SetPipeline(pipeline_id));
    }

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given pointer is
    /// valid for `size_bytes` bytes.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_set_push_constant(
        pass: &mut ComputePass,
        offset: u32,
        size_bytes: u32,
        data: *const u8,
    ) {
        assert_eq!(
            offset & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant offset must be aligned to 4 bytes."
        );
        assert_eq!(
            size_bytes & (wgt::PUSH_CONSTANT_ALIGNMENT - 1),
            0,
            "Push constant size must be aligned to 4 bytes."
        );
        let data_slice = unsafe { slice::from_raw_parts(data, size_bytes as usize) };
        let value_offset = pass.base.push_constant_data.len().try_into().expect(
            "Ran out of push constant space. Don't set 4 GiB of push constants per ComputePass.",
        );

        pass.base.push_constant_data.extend(
            data_slice
                .chunks_exact(wgt::PUSH_CONSTANT_ALIGNMENT as usize)
                .map(|arr| u32::from_ne_bytes([arr[0], arr[1], arr[2], arr[3]])),
        );

        pass.base.commands.push(ComputeCommand::SetPushConstant {
            offset,
            size_bytes,
            values_offset: value_offset,
        });
    }

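    // Usage sketch (editorial; `pass` is a `ComputePass`, and the data pointer
    // must remain valid for the duration of the call):
    //
    //     let data: [u8; 8] = 42u64.to_ne_bytes();
    //     unsafe {
    //         wgpu_compute_pass_set_push_constant(&mut pass, 0, 8, data.as_ptr());
    //     }
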
    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_dispatch_workgroups(
        pass: &mut ComputePass,
        groups_x: u32,
        groups_y: u32,
        groups_z: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_dispatch_workgroups_indirect(
        pass: &mut ComputePass,
        buffer_id: id::BufferId,
        offset: BufferAddress,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::DispatchIndirect { buffer_id, offset });
    }

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given `label`
    /// is a valid null-terminated string.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_push_debug_group(
        pass: &mut ComputePass,
        label: RawString,
        color: u32,
    ) {
        let bytes = unsafe { ffi::CStr::from_ptr(label) }.to_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::PushDebugGroup {
            color,
            len: bytes.len(),
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) {
        pass.base.commands.push(ComputeCommand::PopDebugGroup);
    }

    /// # Safety
    ///
    /// This function is unsafe as there is no guarantee that the given `label`
    /// is a valid null-terminated string.
    #[no_mangle]
    pub unsafe extern "C" fn wgpu_compute_pass_insert_debug_marker(
        pass: &mut ComputePass,
        label: RawString,
        color: u32,
    ) {
        let bytes = unsafe { ffi::CStr::from_ptr(label) }.to_bytes();
        pass.base.string_data.extend_from_slice(bytes);

        pass.base.commands.push(ComputeCommand::InsertDebugMarker {
            color,
            len: bytes.len(),
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_write_timestamp(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base.commands.push(ComputeCommand::WriteTimestamp {
            query_set_id,
            query_index,
        });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_begin_pipeline_statistics_query(
        pass: &mut ComputePass,
        query_set_id: id::QuerySetId,
        query_index: u32,
    ) {
        pass.base
            .commands
            .push(ComputeCommand::BeginPipelineStatisticsQuery {
                query_set_id,
                query_index,
            });
    }

    #[no_mangle]
    pub extern "C" fn wgpu_compute_pass_end_pipeline_statistics_query(pass: &mut ComputePass) {
        pass.base
            .commands
            .push(ComputeCommand::EndPipelineStatisticsQuery);
    }
}