li_wgpu_core/command/
mod.rs

1mod bind;
2mod bundle;
3mod clear;
4mod compute;
5mod draw;
6mod memory_init;
7mod query;
8mod render;
9mod transfer;
10
11use std::slice;
12
13pub(crate) use self::clear::clear_texture;
14pub use self::{
15    bundle::*, clear::ClearError, compute::*, draw::*, query::*, render::*, transfer::*,
16};
17
18use self::memory_init::CommandBufferTextureMemoryActions;
19
20use crate::error::{ErrorFormatter, PrettyError};
21use crate::init_tracker::BufferInitTrackerAction;
22use crate::track::{Tracker, UsageScope};
23use crate::{
24    global::Global,
25    hal_api::HalApi,
26    hub::Token,
27    id,
28    identity::GlobalIdentityHandlerFactory,
29    resource::{Buffer, Texture},
30    storage::Storage,
31    Label, Stored,
32};
33
34use hal::CommandEncoder as _;
35use thiserror::Error;
36
37#[cfg(feature = "trace")]
38use crate::device::trace::Command as TraceCommand;
39
/// A run of zeroed words used as the source data when clearing push constants.
///
/// `push_constant_clear` hands out sub-slices of this (at most 64 words each).
const PUSH_CONSTANT_CLEAR_ARRAY: &[u32] = &[0; 64];
41
/// Lifecycle state of a command buffer's encoder.
///
/// Command buffers start out `Recording`; `command_encoder_finish` moves a
/// recording encoder to `Finished`. An encoder in the `Error` state is reported
/// as invalid by both `get_encoder_mut` and `command_encoder_finish`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CommandEncoderStatus {
    /// Commands may still be recorded.
    Recording,
    /// `finish` has already been called; further recording is rejected.
    Finished,
    /// The encoder is unusable and is treated as invalid.
    Error,
}
48
/// Wrapper around a HAL command encoder that remembers whether it is open
/// and collects the command buffers it produces.
struct CommandEncoder<A: hal::Api> {
    /// The HAL encoder that commands are recorded into.
    raw: A::CommandEncoder,
    /// Command buffers produced by ending encodings on `raw`: `close` appends
    /// to the end, `close_and_swap` inserts before the last entry.
    list: Vec<A::CommandBuffer>,
    /// Whether `raw` is currently between `begin_encoding` and
    /// `end_encoding`/`discard_encoding`.
    is_open: bool,
    /// Label passed to `begin_encoding` by `open`.
    label: Option<String>,
}
55
56//TODO: handle errors better
57impl<A: hal::Api> CommandEncoder<A> {
58    /// Closes the live encoder
59    fn close_and_swap(&mut self) {
60        if self.is_open {
61            self.is_open = false;
62            let new = unsafe { self.raw.end_encoding().unwrap() };
63            self.list.insert(self.list.len() - 1, new);
64        }
65    }
66
67    fn close(&mut self) {
68        if self.is_open {
69            self.is_open = false;
70            let cmd_buf = unsafe { self.raw.end_encoding().unwrap() };
71            self.list.push(cmd_buf);
72        }
73    }
74
75    fn discard(&mut self) {
76        if self.is_open {
77            self.is_open = false;
78            unsafe { self.raw.discard_encoding() };
79        }
80    }
81
82    fn open(&mut self) -> &mut A::CommandEncoder {
83        if !self.is_open {
84            self.is_open = true;
85            let label = self.label.as_deref();
86            unsafe { self.raw.begin_encoding(label).unwrap() };
87        }
88        &mut self.raw
89    }
90
91    fn open_pass(&mut self, label: Option<&str>) {
92        self.is_open = true;
93        unsafe { self.raw.begin_encoding(label).unwrap() };
94    }
95}
96
/// The parts of a finished [`CommandBuffer`] that survive
/// `CommandBuffer::into_baked`.
pub struct BakedCommands<A: HalApi> {
    /// The underlying HAL encoder.
    pub(crate) encoder: A::CommandEncoder,
    /// Command buffers recorded on `encoder`, in list order.
    pub(crate) list: Vec<A::CommandBuffer>,
    /// Resource tracking state carried over from the command buffer.
    pub(crate) trackers: Tracker<A>,
    /// Buffer memory-initialization actions carried over from the command buffer.
    buffer_memory_init_actions: Vec<BufferInitTrackerAction>,
    /// Texture memory actions carried over from the command buffer.
    texture_memory_actions: CommandBufferTextureMemoryActions,
}
104
/// Carries the id of a buffer that was found to be destroyed.
/// NOTE(review): constructed outside this file — presumably when a destroyed
/// buffer is used by recorded commands; confirm.
pub(crate) struct DestroyedBufferError(pub id::BufferId);
/// Carries the id of a texture that was found to be destroyed.
/// NOTE(review): constructed outside this file — see `DestroyedBufferError`'s note.
pub(crate) struct DestroyedTextureError(pub id::TextureId);
107
/// Core-side state of a command encoder / command buffer.
///
/// The same object is looked up by `CommandEncoderId` while recording and is
/// handed back as a `CommandBufferId` by `command_encoder_finish`.
pub struct CommandBuffer<A: HalApi> {
    /// The HAL encoder plus the command buffers it has produced.
    encoder: CommandEncoder<A>,
    /// Lifecycle state; checked before any new command is recorded.
    status: CommandEncoderStatus,
    /// Device id passed to `new`.
    pub(crate) device_id: Stored<id::DeviceId>,
    /// Resource tracking state; moved into `BakedCommands` by `into_baked`.
    pub(crate) trackers: Tracker<A>,
    /// Buffer memory-initialization actions; moved into `BakedCommands`.
    buffer_memory_init_actions: Vec<BufferInitTrackerAction>,
    /// Texture memory actions; moved into `BakedCommands`.
    texture_memory_actions: CommandBufferTextureMemoryActions,
    /// Query resets that still have to be performed.
    pub(crate) pending_query_resets: QueryResetMap<A>,
    /// Limits passed to `new`.
    limits: wgt::Limits,
    /// Whether `wgt::Features::CLEAR_TEXTURE` was enabled at creation.
    support_clear_texture: bool,
    /// Recorded trace commands; `Some` only when tracing was enabled.
    #[cfg(feature = "trace")]
    pub(crate) commands: Option<Vec<TraceCommand>>,
}
121
122impl<A: HalApi> CommandBuffer<A> {
123    pub(crate) fn new(
124        encoder: A::CommandEncoder,
125        device_id: Stored<id::DeviceId>,
126        limits: wgt::Limits,
127        _downlevel: wgt::DownlevelCapabilities,
128        features: wgt::Features,
129        #[cfg(feature = "trace")] enable_tracing: bool,
130        label: Option<String>,
131    ) -> Self {
132        CommandBuffer {
133            encoder: CommandEncoder {
134                raw: encoder,
135                is_open: false,
136                list: Vec::new(),
137                label,
138            },
139            status: CommandEncoderStatus::Recording,
140            device_id,
141            trackers: Tracker::new(),
142            buffer_memory_init_actions: Default::default(),
143            texture_memory_actions: Default::default(),
144            pending_query_resets: QueryResetMap::new(),
145            limits,
146            support_clear_texture: features.contains(wgt::Features::CLEAR_TEXTURE),
147            #[cfg(feature = "trace")]
148            commands: if enable_tracing {
149                Some(Vec::new())
150            } else {
151                None
152            },
153        }
154    }
155
156    pub(crate) fn insert_barriers_from_tracker(
157        raw: &mut A::CommandEncoder,
158        base: &mut Tracker<A>,
159        head: &Tracker<A>,
160        buffer_guard: &Storage<Buffer<A>, id::BufferId>,
161        texture_guard: &Storage<Texture<A>, id::TextureId>,
162    ) {
163        profiling::scope!("insert_barriers");
164
165        base.buffers.set_from_tracker(&head.buffers);
166        base.textures
167            .set_from_tracker(texture_guard, &head.textures);
168
169        Self::drain_barriers(raw, base, buffer_guard, texture_guard);
170    }
171
172    pub(crate) fn insert_barriers_from_scope(
173        raw: &mut A::CommandEncoder,
174        base: &mut Tracker<A>,
175        head: &UsageScope<A>,
176        buffer_guard: &Storage<Buffer<A>, id::BufferId>,
177        texture_guard: &Storage<Texture<A>, id::TextureId>,
178    ) {
179        profiling::scope!("insert_barriers");
180
181        base.buffers.set_from_usage_scope(&head.buffers);
182        base.textures
183            .set_from_usage_scope(texture_guard, &head.textures);
184
185        Self::drain_barriers(raw, base, buffer_guard, texture_guard);
186    }
187
188    pub(crate) fn drain_barriers(
189        raw: &mut A::CommandEncoder,
190        base: &mut Tracker<A>,
191        buffer_guard: &Storage<Buffer<A>, id::BufferId>,
192        texture_guard: &Storage<Texture<A>, id::TextureId>,
193    ) {
194        profiling::scope!("drain_barriers");
195
196        let buffer_barriers = base.buffers.drain().map(|pending| {
197            let buf = unsafe { &buffer_guard.get_unchecked(pending.id) };
198            pending.into_hal(buf)
199        });
200        let texture_barriers = base.textures.drain().map(|pending| {
201            let tex = unsafe { texture_guard.get_unchecked(pending.id) };
202            pending.into_hal(tex)
203        });
204
205        unsafe {
206            raw.transition_buffers(buffer_barriers);
207            raw.transition_textures(texture_barriers);
208        }
209    }
210}
211
212impl<A: HalApi> CommandBuffer<A> {
213    fn get_encoder_mut(
214        storage: &mut Storage<Self, id::CommandEncoderId>,
215        id: id::CommandEncoderId,
216    ) -> Result<&mut Self, CommandEncoderError> {
217        match storage.get_mut(id) {
218            Ok(cmd_buf) => match cmd_buf.status {
219                CommandEncoderStatus::Recording => Ok(cmd_buf),
220                CommandEncoderStatus::Finished => Err(CommandEncoderError::NotRecording),
221                CommandEncoderStatus::Error => Err(CommandEncoderError::Invalid),
222            },
223            Err(_) => Err(CommandEncoderError::Invalid),
224        }
225    }
226
227    pub fn is_finished(&self) -> bool {
228        match self.status {
229            CommandEncoderStatus::Finished => true,
230            _ => false,
231        }
232    }
233
234    pub(crate) fn into_baked(self) -> BakedCommands<A> {
235        BakedCommands {
236            encoder: self.encoder.raw,
237            list: self.encoder.list,
238            trackers: self.trackers,
239            buffer_memory_init_actions: self.buffer_memory_init_actions,
240            texture_memory_actions: self.texture_memory_actions,
241        }
242    }
243}
244
245impl<A: HalApi> crate::resource::Resource for CommandBuffer<A> {
246    const TYPE: &'static str = "CommandBuffer";
247
248    fn life_guard(&self) -> &crate::LifeGuard {
249        unreachable!()
250    }
251
252    fn label(&self) -> &str {
253        self.encoder.label.as_ref().map_or("", |s| s.as_str())
254    }
255}
256
/// Borrowed view of a [`BasePass`]: the same fields, as `&str` and slices.
///
/// Produced without copying by `BasePass::as_ref`.
#[derive(Copy, Clone, Debug)]
pub struct BasePassRef<'a, C> {
    /// The pass label, if any.
    pub label: Option<&'a str>,
    /// The stream of commands.
    pub commands: &'a [C],
    /// Dynamic offsets side table (see [`BasePass::dynamic_offsets`]).
    pub dynamic_offsets: &'a [wgt::DynamicOffset],
    /// Debug-string side table (see [`BasePass::string_data`]).
    pub string_data: &'a [u8],
    /// Push-constant side table (see [`BasePass::push_constant_data`]).
    pub push_constant_data: &'a [u32],
}
265
/// A stream of commands for a render pass or compute pass.
///
/// This also contains side tables referred to by certain commands,
/// like dynamic offsets for [`SetBindGroup`] or string data for
/// [`InsertDebugMarker`].
///
/// Render passes use `BasePass<RenderCommand>`, whereas compute
/// passes use `BasePass<ComputeCommand>`.
///
/// [`SetBindGroup`]: RenderCommand::SetBindGroup
/// [`InsertDebugMarker`]: RenderCommand::InsertDebugMarker
#[doc(hidden)]
#[derive(Debug)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "trace"),
    derive(serde::Serialize)
)]
#[cfg_attr(
    any(feature = "serial-pass", feature = "replay"),
    derive(serde::Deserialize)
)]
pub struct BasePass<C> {
    /// The pass label, if any.
    pub label: Option<String>,

    /// The stream of commands.
    pub commands: Vec<C>,

    /// Dynamic offsets consumed by [`SetBindGroup`] commands in `commands`.
    ///
    /// Each successive `SetBindGroup` consumes the next
    /// `num_dynamic_offsets` values from this list.
    ///
    /// [`SetBindGroup`]: RenderCommand::SetBindGroup
    pub dynamic_offsets: Vec<wgt::DynamicOffset>,

    /// Strings used by debug instructions.
    ///
    /// Each successive `PushDebugGroup` or [`InsertDebugMarker`]
    /// instruction consumes the next `len` bytes from this vector.
    ///
    /// [`InsertDebugMarker`]: RenderCommand::InsertDebugMarker
    pub string_data: Vec<u8>,

    /// Data used by `SetPushConstant` instructions.
    ///
    /// See the documentation for [`RenderCommand::SetPushConstant`]
    /// and [`ComputeCommand::SetPushConstant`] for details.
    pub push_constant_data: Vec<u32>,
}
311
312impl<C: Clone> BasePass<C> {
313    fn new(label: &Label) -> Self {
314        Self {
315            label: label.as_ref().map(|cow| cow.to_string()),
316            commands: Vec::new(),
317            dynamic_offsets: Vec::new(),
318            string_data: Vec::new(),
319            push_constant_data: Vec::new(),
320        }
321    }
322
323    #[cfg(feature = "trace")]
324    fn from_ref(base: BasePassRef<C>) -> Self {
325        Self {
326            label: base.label.map(str::to_string),
327            commands: base.commands.to_vec(),
328            dynamic_offsets: base.dynamic_offsets.to_vec(),
329            string_data: base.string_data.to_vec(),
330            push_constant_data: base.push_constant_data.to_vec(),
331        }
332    }
333
334    pub fn as_ref(&self) -> BasePassRef<C> {
335        BasePassRef {
336            label: self.label.as_deref(),
337            commands: &self.commands,
338            dynamic_offsets: &self.dynamic_offsets,
339            string_data: &self.string_data,
340            push_constant_data: &self.push_constant_data,
341        }
342    }
343}
344
/// Reasons a command encoder cannot be used, reported e.g. by
/// `command_encoder_finish` and the debug-marker entry points.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum CommandEncoderError {
    /// The encoder id is unknown, or the encoder is in the error state.
    #[error("Command encoder is invalid")]
    Invalid,
    /// The encoder has already been finished and can no longer record.
    #[error("Command encoder must be active")]
    NotRecording,
}
353
354impl<G: GlobalIdentityHandlerFactory> Global<G> {
355    pub fn command_encoder_finish<A: HalApi>(
356        &self,
357        encoder_id: id::CommandEncoderId,
358        _desc: &wgt::CommandBufferDescriptor<Label>,
359    ) -> (id::CommandBufferId, Option<CommandEncoderError>) {
360        profiling::scope!("CommandEncoder::finish");
361
362        let hub = A::hub(self);
363        let mut token = Token::root();
364        let (mut cmd_buf_guard, _) = hub.command_buffers.write(&mut token);
365
366        let error = match cmd_buf_guard.get_mut(encoder_id) {
367            Ok(cmd_buf) => match cmd_buf.status {
368                CommandEncoderStatus::Recording => {
369                    cmd_buf.encoder.close();
370                    cmd_buf.status = CommandEncoderStatus::Finished;
371                    //Note: if we want to stop tracking the swapchain texture view,
372                    // this is the place to do it.
373                    log::trace!("Command buffer {:?}", encoder_id);
374                    None
375                }
376                CommandEncoderStatus::Finished => Some(CommandEncoderError::NotRecording),
377                CommandEncoderStatus::Error => {
378                    cmd_buf.encoder.discard();
379                    Some(CommandEncoderError::Invalid)
380                }
381            },
382            Err(_) => Some(CommandEncoderError::Invalid),
383        };
384
385        (encoder_id, error)
386    }
387
388    pub fn command_encoder_push_debug_group<A: HalApi>(
389        &self,
390        encoder_id: id::CommandEncoderId,
391        label: &str,
392    ) -> Result<(), CommandEncoderError> {
393        profiling::scope!("CommandEncoder::push_debug_group");
394        log::trace!("CommandEncoder::push_debug_group {label}");
395
396        let hub = A::hub(self);
397        let mut token = Token::root();
398
399        let (mut cmd_buf_guard, _) = hub.command_buffers.write(&mut token);
400        let cmd_buf = CommandBuffer::get_encoder_mut(&mut *cmd_buf_guard, encoder_id)?;
401
402        #[cfg(feature = "trace")]
403        if let Some(ref mut list) = cmd_buf.commands {
404            list.push(TraceCommand::PushDebugGroup(label.to_string()));
405        }
406
407        let cmd_buf_raw = cmd_buf.encoder.open();
408        if !self
409            .instance
410            .flags
411            .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
412        {
413            unsafe {
414                cmd_buf_raw.begin_debug_marker(label);
415            }
416        }
417        Ok(())
418    }
419
420    pub fn command_encoder_insert_debug_marker<A: HalApi>(
421        &self,
422        encoder_id: id::CommandEncoderId,
423        label: &str,
424    ) -> Result<(), CommandEncoderError> {
425        profiling::scope!("CommandEncoder::insert_debug_marker");
426        log::trace!("CommandEncoder::insert_debug_marker {label}");
427
428        let hub = A::hub(self);
429        let mut token = Token::root();
430
431        let (mut cmd_buf_guard, _) = hub.command_buffers.write(&mut token);
432        let cmd_buf = CommandBuffer::get_encoder_mut(&mut *cmd_buf_guard, encoder_id)?;
433
434        #[cfg(feature = "trace")]
435        if let Some(ref mut list) = cmd_buf.commands {
436            list.push(TraceCommand::InsertDebugMarker(label.to_string()));
437        }
438
439        if !self
440            .instance
441            .flags
442            .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
443        {
444            let cmd_buf_raw = cmd_buf.encoder.open();
445            unsafe {
446                cmd_buf_raw.insert_debug_marker(label);
447            }
448        }
449        Ok(())
450    }
451
452    pub fn command_encoder_pop_debug_group<A: HalApi>(
453        &self,
454        encoder_id: id::CommandEncoderId,
455    ) -> Result<(), CommandEncoderError> {
456        profiling::scope!("CommandEncoder::pop_debug_marker");
457        log::trace!("CommandEncoder::pop_debug_group");
458
459        let hub = A::hub(self);
460        let mut token = Token::root();
461
462        let (mut cmd_buf_guard, _) = hub.command_buffers.write(&mut token);
463        let cmd_buf = CommandBuffer::get_encoder_mut(&mut *cmd_buf_guard, encoder_id)?;
464
465        #[cfg(feature = "trace")]
466        if let Some(ref mut list) = cmd_buf.commands {
467            list.push(TraceCommand::PopDebugGroup);
468        }
469
470        let cmd_buf_raw = cmd_buf.encoder.open();
471        if !self
472            .instance
473            .flags
474            .contains(wgt::InstanceFlags::DISCARD_HAL_LABELS)
475        {
476            unsafe {
477                cmd_buf_raw.end_debug_marker();
478            }
479        }
480        Ok(())
481    }
482}
483
484fn push_constant_clear<PushFn>(offset: u32, size_bytes: u32, mut push_fn: PushFn)
485where
486    PushFn: FnMut(u32, &[u32]),
487{
488    let mut count_words = 0_u32;
489    let size_words = size_bytes / wgt::PUSH_CONSTANT_ALIGNMENT;
490    while count_words < size_words {
491        let count_bytes = count_words * wgt::PUSH_CONSTANT_ALIGNMENT;
492        let size_to_write_words =
493            (size_words - count_words).min(PUSH_CONSTANT_CLEAR_ARRAY.len() as u32);
494
495        push_fn(
496            offset + count_bytes,
497            &PUSH_CONSTANT_CLEAR_ARRAY[0..size_to_write_words as usize],
498        );
499
500        count_words += size_to_write_words;
501    }
502}
503
/// Remembers the most recently set value so repeated identical sets can be skipped.
#[derive(Debug, Copy, Clone)]
struct StateChange<T> {
    last_state: Option<T>,
}

impl<T: Copy + PartialEq> StateChange<T> {
    /// Creates a tracker with no remembered state.
    fn new() -> Self {
        StateChange { last_state: None }
    }

    /// Stores `new_state` and reports whether it equals the previously stored state.
    ///
    /// Returns `true` when the set is redundant. The new state is recorded either way.
    fn set_and_check_redundant(&mut self, new_state: T) -> bool {
        let previous = self.last_state.replace(new_state);
        previous == Some(new_state)
    }

    /// Forgets the remembered state, so the next set is never redundant.
    fn reset(&mut self) {
        self.last_state = None;
    }
}

impl<T: Copy + PartialEq> Default for StateChange<T> {
    fn default() -> Self {
        StateChange::new()
    }
}
528
529#[derive(Debug)]
530struct BindGroupStateChange {
531    last_states: [StateChange<id::BindGroupId>; hal::MAX_BIND_GROUPS],
532}
533
534impl BindGroupStateChange {
535    fn new() -> Self {
536        Self {
537            last_states: [StateChange::new(); hal::MAX_BIND_GROUPS],
538        }
539    }
540
541    unsafe fn set_and_check_redundant(
542        &mut self,
543        bind_group_id: id::BindGroupId,
544        index: u32,
545        dynamic_offsets: &mut Vec<u32>,
546        offsets: *const wgt::DynamicOffset,
547        offset_length: usize,
548    ) -> bool {
549        // For now never deduplicate bind groups with dynamic offsets.
550        if offset_length == 0 {
551            // If this get returns None, that means we're well over the limit,
552            // so let the call through to get a proper error
553            if let Some(current_bind_group) = self.last_states.get_mut(index as usize) {
554                // Bail out if we're binding the same bind group.
555                if current_bind_group.set_and_check_redundant(bind_group_id) {
556                    return true;
557                }
558            }
559        } else {
560            // We intentionally remove the memory of this bind group if we have dynamic offsets,
561            // such that if you try to bind this bind group later with _no_ dynamic offsets it
562            // tries to bind it again and gives a proper validation error.
563            if let Some(current_bind_group) = self.last_states.get_mut(index as usize) {
564                current_bind_group.reset();
565            }
566            dynamic_offsets
567                .extend_from_slice(unsafe { slice::from_raw_parts(offsets, offset_length) });
568        }
569        false
570    }
571    fn reset(&mut self) {
572        self.last_states = [StateChange::new(); hal::MAX_BIND_GROUPS];
573    }
574}
575
576impl Default for BindGroupStateChange {
577    fn default() -> Self {
578        Self::new()
579    }
580}
581
/// Helper for attaching a [`PassErrorScope`] to an error raised inside a pass.
///
/// NOTE(review): implementations live outside this file; presumably `O` wraps
/// the original error together with `scope` — confirm against the impls.
trait MapPassErr<T, O> {
    /// Consumes `self` and returns a `Result<T, O>` tagged with `scope`.
    fn map_pass_err(self, scope: PassErrorScope) -> Result<T, O>;
}
585
586#[derive(Clone, Copy, Debug, Error)]
587pub enum PassErrorScope {
588    #[error("In a bundle parameter")]
589    Bundle,
590    #[error("In a pass parameter")]
591    Pass(id::CommandEncoderId),
592    #[error("In a set_bind_group command")]
593    SetBindGroup(id::BindGroupId),
594    #[error("In a set_pipeline command")]
595    SetPipelineRender(id::RenderPipelineId),
596    #[error("In a set_pipeline command")]
597    SetPipelineCompute(id::ComputePipelineId),
598    #[error("In a set_push_constant command")]
599    SetPushConstant,
600    #[error("In a set_vertex_buffer command")]
601    SetVertexBuffer(id::BufferId),
602    #[error("In a set_index_buffer command")]
603    SetIndexBuffer(id::BufferId),
604    #[error("In a set_viewport command")]
605    SetViewport,
606    #[error("In a set_scissor_rect command")]
607    SetScissorRect,
608    #[error("In a draw command, indexed:{indexed} indirect:{indirect}")]
609    Draw {
610        indexed: bool,
611        indirect: bool,
612        pipeline: Option<id::RenderPipelineId>,
613    },
614    #[error("While resetting queries after the renderpass was ran")]
615    QueryReset,
616    #[error("In a write_timestamp command")]
617    WriteTimestamp,
618    #[error("In a begin_occlusion_query command")]
619    BeginOcclusionQuery,
620    #[error("In a end_occlusion_query command")]
621    EndOcclusionQuery,
622    #[error("In a begin_pipeline_statistics_query command")]
623    BeginPipelineStatisticsQuery,
624    #[error("In a end_pipeline_statistics_query command")]
625    EndPipelineStatisticsQuery,
626    #[error("In a execute_bundle command")]
627    ExecuteBundle,
628    #[error("In a dispatch command, indirect:{indirect}")]
629    Dispatch {
630        indirect: bool,
631        pipeline: Option<id::ComputePipelineId>,
632    },
633    #[error("In a pop_debug_group command")]
634    PopDebugGroup,
635}
636
637impl PrettyError for PassErrorScope {
638    fn fmt_pretty(&self, fmt: &mut ErrorFormatter) {
639        // This error is not in the error chain, only notes are needed
640        match *self {
641            Self::Pass(id) => {
642                fmt.command_buffer_label(&id);
643            }
644            Self::SetBindGroup(id) => {
645                fmt.bind_group_label(&id);
646            }
647            Self::SetPipelineRender(id) => {
648                fmt.render_pipeline_label(&id);
649            }
650            Self::SetPipelineCompute(id) => {
651                fmt.compute_pipeline_label(&id);
652            }
653            Self::SetVertexBuffer(id) => {
654                fmt.buffer_label(&id);
655            }
656            Self::SetIndexBuffer(id) => {
657                fmt.buffer_label(&id);
658            }
659            Self::Draw {
660                pipeline: Some(id), ..
661            } => {
662                fmt.render_pipeline_label(&id);
663            }
664            Self::Dispatch {
665                pipeline: Some(id), ..
666            } => {
667                fmt.compute_pipeline_label(&id);
668            }
669            _ => {}
670        }
671    }
672}