cart_tmp_wgc/command/
compute.rs

1/* This Source Code Form is subject to the terms of the Mozilla Public
2 * License, v. 2.0. If a copy of the MPL was not distributed with this
3 * file, You can obtain one at http://mozilla.org/MPL/2.0/. */
4
5use crate::{
6    command::{
7        bind::{Binder, LayoutChange},
8        BasePass, BasePassRef, CommandBuffer,
9    },
10    device::all_buffer_stages,
11    hub::{GfxBackend, Global, GlobalIdentityHandlerFactory, Token},
12    id,
13    resource::BufferUse,
14    span,
15};
16
17use hal::command::CommandBuffer as _;
18use wgt::{BufferAddress, BufferUsage};
19
20use std::{fmt, iter, str};
21
22#[doc(hidden)]
23#[derive(Clone, Copy, Debug)]
24#[cfg_attr(
25    any(feature = "serial-pass", feature = "trace"),
26    derive(serde::Serialize)
27)]
28#[cfg_attr(
29    any(feature = "serial-pass", feature = "replay"),
30    derive(serde::Deserialize)
31)]
32pub enum ComputeCommand {
33    SetBindGroup {
34        index: u8,
35        num_dynamic_offsets: u8,
36        bind_group_id: id::BindGroupId,
37    },
38    SetPipeline(id::ComputePipelineId),
39    Dispatch([u32; 3]),
40    DispatchIndirect {
41        buffer_id: id::BufferId,
42        offset: BufferAddress,
43    },
44    PushDebugGroup {
45        color: u32,
46        len: usize,
47    },
48    PopDebugGroup,
49    InsertDebugMarker {
50        color: u32,
51        len: usize,
52    },
53}
54
55#[cfg_attr(feature = "serial-pass", derive(serde::Deserialize, serde::Serialize))]
56pub struct ComputePass {
57    base: BasePass<ComputeCommand>,
58    parent_id: id::CommandEncoderId,
59}
60
61impl ComputePass {
62    pub fn new(parent_id: id::CommandEncoderId) -> Self {
63        ComputePass {
64            base: BasePass::new(),
65            parent_id,
66        }
67    }
68
69    pub fn parent_id(&self) -> id::CommandEncoderId {
70        self.parent_id
71    }
72}
73
74impl fmt::Debug for ComputePass {
75    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
76        write!(
77            f,
78            "ComputePass {{ encoder_id: {:?}, data: {:?} commands and {:?} dynamic offsets }}",
79            self.parent_id,
80            self.base.commands.len(),
81            self.base.dynamic_offsets.len()
82        )
83    }
84}
85
/// Descriptor for creating a compute pass.
///
/// `#[repr(C)]` keeps the layout C-compatible. The single field is a
/// placeholder; nothing in this file reads it.
#[repr(C)]
#[derive(Debug, Default, Clone)]
pub struct ComputePassDescriptor {
    /// Placeholder field with no assigned meaning in this file.
    pub todo: u32,
}
91
/// Tracks whether a compute pipeline has been bound while replaying a pass.
#[derive(PartialEq, Debug)]
enum PipelineState {
    /// No pipeline has been set yet; dispatching in this state is an error.
    Required,
    /// A pipeline has been set, so dispatches are allowed.
    Set,
}
97
98#[derive(Debug)]
99struct State {
100    binder: Binder,
101    pipeline: PipelineState,
102    debug_scope_depth: u32,
103}
104
105// Common routines between render/compute
106
107impl<G: GlobalIdentityHandlerFactory> Global<G> {
108    pub fn command_encoder_run_compute_pass<B: GfxBackend>(
109        &self,
110        encoder_id: id::CommandEncoderId,
111        pass: &ComputePass,
112    ) {
113        self.command_encoder_run_compute_pass_impl::<B>(encoder_id, pass.base.as_ref())
114    }
115
116    #[doc(hidden)]
117    pub fn command_encoder_run_compute_pass_impl<B: GfxBackend>(
118        &self,
119        encoder_id: id::CommandEncoderId,
120        mut base: BasePassRef<ComputeCommand>,
121    ) {
122        span!(_guard, INFO, "CommandEncoder::run_compute_pass");
123        let hub = B::hub(self);
124        let mut token = Token::root();
125
126        let (mut cmb_guard, mut token) = hub.command_buffers.write(&mut token);
127        let cmb = &mut cmb_guard[encoder_id];
128        let raw = cmb.raw.last_mut().unwrap();
129
130        #[cfg(feature = "trace")]
131        match cmb.commands {
132            Some(ref mut list) => {
133                list.push(crate::device::trace::Command::RunComputePass {
134                    base: BasePass::from_ref(base),
135                });
136            }
137            None => {}
138        }
139
140        let (_, mut token) = hub.render_bundles.read(&mut token);
141        let (pipeline_layout_guard, mut token) = hub.pipeline_layouts.read(&mut token);
142        let (bind_group_guard, mut token) = hub.bind_groups.read(&mut token);
143        let (pipeline_guard, mut token) = hub.compute_pipelines.read(&mut token);
144        let (buffer_guard, mut token) = hub.buffers.read(&mut token);
145        let (texture_guard, _) = hub.textures.read(&mut token);
146
147        let mut state = State {
148            binder: Binder::new(cmb.limits.max_bind_groups),
149            pipeline: PipelineState::Required,
150            debug_scope_depth: 0,
151        };
152
153        for command in base.commands {
154            match *command {
155                ComputeCommand::SetBindGroup {
156                    index,
157                    num_dynamic_offsets,
158                    bind_group_id,
159                } => {
160                    assert!(
161                        (index as u32) < cmb.limits.max_bind_groups,
162                        "Bind group index {0} is out of range 0..{1} provided by requested max_bind_group limit {1}",
163                        index,
164                        cmb.limits.max_bind_groups
165                    );
166
167                    let offsets = &base.dynamic_offsets[..num_dynamic_offsets as usize];
168                    base.dynamic_offsets = &base.dynamic_offsets[num_dynamic_offsets as usize..];
169
170                    let bind_group = cmb
171                        .trackers
172                        .bind_groups
173                        .use_extend(&*bind_group_guard, bind_group_id, (), ())
174                        .unwrap();
175                    bind_group.validate_dynamic_bindings(offsets).unwrap();
176
177                    log::trace!(
178                        "Encoding barriers on binding of {:?} to {:?}",
179                        bind_group_id,
180                        encoder_id
181                    );
182                    CommandBuffer::insert_barriers(
183                        raw,
184                        &mut cmb.trackers,
185                        &bind_group.used,
186                        &*buffer_guard,
187                        &*texture_guard,
188                    );
189
190                    if let Some((pipeline_layout_id, follow_ups)) = state.binder.provide_entry(
191                        index as usize,
192                        bind_group_id,
193                        bind_group,
194                        offsets,
195                    ) {
196                        let bind_groups = iter::once(bind_group.raw.raw()).chain(
197                            follow_ups
198                                .clone()
199                                .map(|(bg_id, _)| bind_group_guard[bg_id].raw.raw()),
200                        );
201                        unsafe {
202                            raw.bind_compute_descriptor_sets(
203                                &pipeline_layout_guard[pipeline_layout_id].raw,
204                                index as usize,
205                                bind_groups,
206                                offsets
207                                    .iter()
208                                    .chain(follow_ups.flat_map(|(_, offsets)| offsets))
209                                    .cloned(),
210                            );
211                        }
212                    }
213                }
214                ComputeCommand::SetPipeline(pipeline_id) => {
215                    state.pipeline = PipelineState::Set;
216                    let pipeline = cmb
217                        .trackers
218                        .compute_pipes
219                        .use_extend(&*pipeline_guard, pipeline_id, (), ())
220                        .unwrap();
221
222                    unsafe {
223                        raw.bind_compute_pipeline(&pipeline.raw);
224                    }
225
226                    // Rebind resources
227                    if state.binder.pipeline_layout_id != Some(pipeline.layout_id.value) {
228                        let pipeline_layout = &pipeline_layout_guard[pipeline.layout_id.value];
229                        state.binder.pipeline_layout_id = Some(pipeline.layout_id.value);
230                        state
231                            .binder
232                            .reset_expectations(pipeline_layout.bind_group_layout_ids.len());
233                        let mut is_compatible = true;
234
235                        for (index, (entry, bgl_id)) in state
236                            .binder
237                            .entries
238                            .iter_mut()
239                            .zip(&pipeline_layout.bind_group_layout_ids)
240                            .enumerate()
241                        {
242                            match entry.expect_layout(bgl_id.value) {
243                                LayoutChange::Match(bg_id, offsets) if is_compatible => {
244                                    let desc_set = bind_group_guard[bg_id].raw.raw();
245                                    unsafe {
246                                        raw.bind_compute_descriptor_sets(
247                                            &pipeline_layout.raw,
248                                            index,
249                                            iter::once(desc_set),
250                                            offsets.iter().cloned(),
251                                        );
252                                    }
253                                }
254                                LayoutChange::Match(..) | LayoutChange::Unchanged => {}
255                                LayoutChange::Mismatch => {
256                                    is_compatible = false;
257                                }
258                            }
259                        }
260                    }
261                }
262                ComputeCommand::Dispatch(groups) => {
263                    assert_eq!(
264                        state.pipeline,
265                        PipelineState::Set,
266                        "Dispatch DEBUG: Pipeline is missing"
267                    );
268                    unsafe {
269                        raw.dispatch(groups);
270                    }
271                }
272                ComputeCommand::DispatchIndirect { buffer_id, offset } => {
273                    assert_eq!(
274                        state.pipeline,
275                        PipelineState::Set,
276                        "Dispatch DEBUG: Pipeline is missing"
277                    );
278                    let (src_buffer, src_pending) = cmb.trackers.buffers.use_replace(
279                        &*buffer_guard,
280                        buffer_id,
281                        (),
282                        BufferUse::INDIRECT,
283                    );
284                    assert!(src_buffer.usage.contains(BufferUsage::INDIRECT));
285
286                    let barriers = src_pending.map(|pending| pending.into_hal(src_buffer));
287
288                    unsafe {
289                        raw.pipeline_barrier(
290                            all_buffer_stages()..all_buffer_stages(),
291                            hal::memory::Dependencies::empty(),
292                            barriers,
293                        );
294                        raw.dispatch_indirect(&src_buffer.raw, offset);
295                    }
296                }
297                ComputeCommand::PushDebugGroup { color, len } => {
298                    state.debug_scope_depth += 1;
299
300                    let label = str::from_utf8(&base.string_data[..len]).unwrap();
301                    unsafe {
302                        raw.begin_debug_marker(label, color);
303                    }
304                    base.string_data = &base.string_data[len..];
305                }
306                ComputeCommand::PopDebugGroup => {
307                    assert_ne!(
308                        state.debug_scope_depth, 0,
309                        "Can't pop debug group, because number of pushed debug groups is zero!"
310                    );
311                    state.debug_scope_depth -= 1;
312                    unsafe {
313                        raw.end_debug_marker();
314                    }
315                }
316                ComputeCommand::InsertDebugMarker { color, len } => {
317                    let label = str::from_utf8(&base.string_data[..len]).unwrap();
318                    unsafe { raw.insert_debug_marker(label, color) }
319                    base.string_data = &base.string_data[len..];
320                }
321            }
322        }
323    }
324}
325
326pub mod compute_ffi {
327    use super::{ComputeCommand, ComputePass};
328    use crate::{id, span, RawString};
329    use std::{convert::TryInto, ffi, slice};
330    use wgt::{BufferAddress, DynamicOffset};
331
332    /// # Safety
333    ///
334    /// This function is unsafe as there is no guarantee that the given pointer is
335    /// valid for `offset_length` elements.
336    // TODO: There might be other safety issues, such as using the unsafe
337    // `RawPass::encode` and `RawPass::encode_slice`.
338    #[no_mangle]
339    pub unsafe extern "C" fn wgpu_compute_pass_set_bind_group(
340        pass: &mut ComputePass,
341        index: u32,
342        bind_group_id: id::BindGroupId,
343        offsets: *const DynamicOffset,
344        offset_length: usize,
345    ) {
346        span!(_guard, DEBUG, "ComputePass::set_bind_group");
347        pass.base.commands.push(ComputeCommand::SetBindGroup {
348            index: index.try_into().unwrap(),
349            num_dynamic_offsets: offset_length.try_into().unwrap(),
350            bind_group_id,
351        });
352        pass.base
353            .dynamic_offsets
354            .extend_from_slice(slice::from_raw_parts(offsets, offset_length));
355    }
356
357    #[no_mangle]
358    pub extern "C" fn wgpu_compute_pass_set_pipeline(
359        pass: &mut ComputePass,
360        pipeline_id: id::ComputePipelineId,
361    ) {
362        span!(_guard, DEBUG, "ComputePass::set_pipeline");
363        pass.base
364            .commands
365            .push(ComputeCommand::SetPipeline(pipeline_id));
366    }
367
368    #[no_mangle]
369    pub extern "C" fn wgpu_compute_pass_dispatch(
370        pass: &mut ComputePass,
371        groups_x: u32,
372        groups_y: u32,
373        groups_z: u32,
374    ) {
375        span!(_guard, DEBUG, "ComputePass::dispatch");
376        pass.base
377            .commands
378            .push(ComputeCommand::Dispatch([groups_x, groups_y, groups_z]));
379    }
380
381    #[no_mangle]
382    pub extern "C" fn wgpu_compute_pass_dispatch_indirect(
383        pass: &mut ComputePass,
384        buffer_id: id::BufferId,
385        offset: BufferAddress,
386    ) {
387        span!(_guard, DEBUG, "ComputePass::dispatch_indirect");
388        pass.base
389            .commands
390            .push(ComputeCommand::DispatchIndirect { buffer_id, offset });
391    }
392
393    #[no_mangle]
394    pub unsafe extern "C" fn wgpu_compute_pass_push_debug_group(
395        pass: &mut ComputePass,
396        label: RawString,
397        color: u32,
398    ) {
399        span!(_guard, DEBUG, "ComputePass::push_debug_group");
400        let bytes = ffi::CStr::from_ptr(label).to_bytes();
401        pass.base.string_data.extend_from_slice(bytes);
402
403        pass.base.commands.push(ComputeCommand::PushDebugGroup {
404            color,
405            len: bytes.len(),
406        });
407    }
408
409    #[no_mangle]
410    pub extern "C" fn wgpu_compute_pass_pop_debug_group(pass: &mut ComputePass) {
411        span!(_guard, DEBUG, "ComputePass::pop_debug_group");
412        pass.base.commands.push(ComputeCommand::PopDebugGroup);
413    }
414
415    #[no_mangle]
416    pub unsafe extern "C" fn wgpu_compute_pass_insert_debug_marker(
417        pass: &mut ComputePass,
418        label: RawString,
419        color: u32,
420    ) {
421        span!(_guard, DEBUG, "ComputePass::insert_debug_marker");
422        let bytes = ffi::CStr::from_ptr(label).to_bytes();
423        pass.base.string_data.extend_from_slice(bytes);
424
425        pass.base.commands.push(ComputeCommand::InsertDebugMarker {
426            color,
427            len: bytes.len(),
428        });
429    }
430}