use std::sync::{Arc, Weak};
use super::{PendingTransition, TrackerIndex};
use crate::{
resource::{Buffer, Trackable},
snatch::SnatchGuard,
track::{
invalid_resource_state, skip_barrier, ResourceMetadata, ResourceMetadataProvider,
ResourceUsageCompatibilityError, ResourceUses,
},
};
use hal::{BufferBarrier, BufferUses};
use wgt::{strict_assert, strict_assert_eq};
impl ResourceUses for BufferUses {
    // Resolves to the inherent bitflags constant of the same name.
    const EXCLUSIVE: Self = Self::EXCLUSIVE;

    /// Buffers are tracked as a whole, so no sub-resource selector is needed.
    type Selector = ();

    /// Raw bits of this usage set.
    fn bits(self) -> u16 {
        // Path syntax selects the inherent bitflags accessor. A method call
        // `self.bits()` would match this by-value trait method first and recurse.
        Self::bits(&self)
    }

    /// `true` when every usage in `self` belongs to the `ORDERED` set.
    fn all_ordered(self) -> bool {
        let ordered = Self::ORDERED;
        ordered.contains(self)
    }

    /// `true` when `self` shares at least one flag with `EXCLUSIVE`.
    fn any_exclusive(self) -> bool {
        Self::EXCLUSIVE.intersects(self)
    }
}
/// Buffers referenced by one bind group, each paired with the usage it is
/// bound with.
#[derive(Debug)]
pub(crate) struct BufferBindGroupState {
    buffers: Vec<(Arc<Buffer>, BufferUses)>,
}

impl BufferBindGroupState {
    /// Creates an empty set of bind-group buffers.
    pub fn new() -> Self {
        Self {
            buffers: Vec::new(),
        }
    }

    /// Sorts the entries by tracker index.
    ///
    /// NOTE(review): presumably done so later merges visit tracker indices in
    /// ascending order — confirm against callers.
    pub(crate) fn optimize(&mut self) {
        self.buffers
            .sort_unstable_by_key(|(b, _)| b.tracker_index());
    }

    /// Tracker indices of every buffer in this bind group, in storage order.
    ///
    /// The iterator borrows `self` lazily. No intermediate `Vec` is allocated,
    /// because the `'_` bound already ties the iterator to `self`.
    pub fn used_tracker_indices(&self) -> impl Iterator<Item = TrackerIndex> + '_ {
        self.buffers.iter().map(|(b, _)| b.tracker_index())
    }

    /// Records `buffer` as used with `state`. Duplicates are not coalesced here.
    pub fn insert_single(&mut self, buffer: Arc<Buffer>, state: BufferUses) {
        self.buffers.push((buffer, state));
    }
}
#[derive(Debug)]
pub(crate) struct BufferUsageScope {
state: Vec<BufferUses>,
metadata: ResourceMetadata<Arc<Buffer>>,
}
impl Default for BufferUsageScope {
fn default() -> Self {
Self {
state: Vec::new(),
metadata: ResourceMetadata::new(),
}
}
}
impl BufferUsageScope {
fn tracker_assert_in_bounds(&self, index: usize) {
strict_assert!(index < self.state.len());
self.metadata.tracker_assert_in_bounds(index);
}
pub fn clear(&mut self) {
self.state.clear();
self.metadata.clear();
}
pub fn set_size(&mut self, size: usize) {
self.state.resize(size, BufferUses::empty());
self.metadata.set_size(size);
}
fn allow_index(&mut self, index: usize) {
if index >= self.state.len() {
self.set_size(index + 1);
}
}
pub unsafe fn merge_bind_group(
&mut self,
bind_group: &BufferBindGroupState,
) -> Result<(), ResourceUsageCompatibilityError> {
for &(ref resource, state) in bind_group.buffers.iter() {
let index = resource.tracker_index().as_usize();
unsafe {
insert_or_merge(
None,
&mut self.state,
&mut self.metadata,
index as _,
index,
BufferStateProvider::Direct { state },
ResourceMetadataProvider::Direct { resource },
)?
};
}
Ok(())
}
pub fn merge_usage_scope(
&mut self,
scope: &Self,
) -> Result<(), ResourceUsageCompatibilityError> {
let incoming_size = scope.state.len();
if incoming_size > self.state.len() {
self.set_size(incoming_size);
}
for index in scope.metadata.owned_indices() {
self.tracker_assert_in_bounds(index);
scope.tracker_assert_in_bounds(index);
unsafe {
insert_or_merge(
None,
&mut self.state,
&mut self.metadata,
index as u32,
index,
BufferStateProvider::Indirect {
state: &scope.state,
},
ResourceMetadataProvider::Indirect {
metadata: &scope.metadata,
},
)?;
};
}
Ok(())
}
pub fn merge_single(
&mut self,
buffer: &Arc<Buffer>,
new_state: BufferUses,
) -> Result<(), ResourceUsageCompatibilityError> {
let index = buffer.tracker_index().as_usize();
self.allow_index(index);
self.tracker_assert_in_bounds(index);
unsafe {
insert_or_merge(
None,
&mut self.state,
&mut self.metadata,
index as _,
index,
BufferStateProvider::Direct { state: new_state },
ResourceMetadataProvider::Direct { resource: buffer },
)?;
}
Ok(())
}
}
pub(crate) struct BufferTracker {
start: Vec<BufferUses>,
end: Vec<BufferUses>,
metadata: ResourceMetadata<Arc<Buffer>>,
temp: Vec<PendingTransition<BufferUses>>,
}
impl BufferTracker {
pub fn new() -> Self {
Self {
start: Vec::new(),
end: Vec::new(),
metadata: ResourceMetadata::new(),
temp: Vec::new(),
}
}
fn tracker_assert_in_bounds(&self, index: usize) {
strict_assert!(index < self.start.len());
strict_assert!(index < self.end.len());
self.metadata.tracker_assert_in_bounds(index);
}
pub fn set_size(&mut self, size: usize) {
self.start.resize(size, BufferUses::empty());
self.end.resize(size, BufferUses::empty());
self.metadata.set_size(size);
}
fn allow_index(&mut self, index: usize) {
if index >= self.start.len() {
self.set_size(index + 1);
}
}
pub fn contains(&self, buffer: &Buffer) -> bool {
self.metadata.contains(buffer.tracker_index().as_usize())
}
pub fn used_resources(&self) -> impl Iterator<Item = Arc<Buffer>> + '_ {
self.metadata.owned_resources()
}
pub fn drain_transitions<'a, 'b: 'a>(
&'b mut self,
snatch_guard: &'a SnatchGuard<'a>,
) -> impl Iterator<Item = BufferBarrier<'a, dyn hal::DynBuffer>> {
let buffer_barriers = self.temp.drain(..).map(|pending| {
let buf = unsafe { self.metadata.get_resource_unchecked(pending.id as _) };
pending.into_hal(buf, snatch_guard)
});
buffer_barriers
}
pub fn set_single(
&mut self,
buffer: &Arc<Buffer>,
state: BufferUses,
) -> Option<PendingTransition<BufferUses>> {
let index: usize = buffer.tracker_index().as_usize();
self.allow_index(index);
self.tracker_assert_in_bounds(index);
unsafe {
insert_or_barrier_update(
Some(&mut self.start),
&mut self.end,
&mut self.metadata,
index,
BufferStateProvider::Direct { state },
None,
ResourceMetadataProvider::Direct { resource: buffer },
&mut self.temp,
)
};
strict_assert!(self.temp.len() <= 1);
self.temp.pop()
}
pub fn set_from_tracker(&mut self, tracker: &Self) {
let incoming_size = tracker.start.len();
if incoming_size > self.start.len() {
self.set_size(incoming_size);
}
for index in tracker.metadata.owned_indices() {
self.tracker_assert_in_bounds(index);
tracker.tracker_assert_in_bounds(index);
unsafe {
insert_or_barrier_update(
Some(&mut self.start),
&mut self.end,
&mut self.metadata,
index,
BufferStateProvider::Indirect {
state: &tracker.start,
},
Some(BufferStateProvider::Indirect {
state: &tracker.end,
}),
ResourceMetadataProvider::Indirect {
metadata: &tracker.metadata,
},
&mut self.temp,
)
}
}
}
pub fn set_from_usage_scope(&mut self, scope: &BufferUsageScope) {
let incoming_size = scope.state.len();
if incoming_size > self.start.len() {
self.set_size(incoming_size);
}
for index in scope.metadata.owned_indices() {
self.tracker_assert_in_bounds(index);
scope.tracker_assert_in_bounds(index);
unsafe {
insert_or_barrier_update(
Some(&mut self.start),
&mut self.end,
&mut self.metadata,
index,
BufferStateProvider::Indirect {
state: &scope.state,
},
None,
ResourceMetadataProvider::Indirect {
metadata: &scope.metadata,
},
&mut self.temp,
)
}
}
}
pub unsafe fn set_and_remove_from_usage_scope_sparse(
&mut self,
scope: &mut BufferUsageScope,
index_source: impl IntoIterator<Item = TrackerIndex>,
) {
let incoming_size = scope.state.len();
if incoming_size > self.start.len() {
self.set_size(incoming_size);
}
for index in index_source {
let index = index.as_usize();
scope.tracker_assert_in_bounds(index);
if unsafe { !scope.metadata.contains_unchecked(index) } {
continue;
}
unsafe {
insert_or_barrier_update(
Some(&mut self.start),
&mut self.end,
&mut self.metadata,
index,
BufferStateProvider::Indirect {
state: &scope.state,
},
None,
ResourceMetadataProvider::Indirect {
metadata: &scope.metadata,
},
&mut self.temp,
)
};
unsafe { scope.metadata.remove(index) };
}
}
}
/// Device-level record of the current state of every buffer.
///
/// Buffers are held as `Weak<Buffer>`, so this tracker does not keep them
/// alive on its own.
pub(crate) struct DeviceBufferTracker {
    /// Current state for each tracker index.
    current_states: Vec<BufferUses>,
    /// Occupancy and weak buffer reference per tracker index.
    metadata: ResourceMetadata<Weak<Buffer>>,
    /// Scratch list of transitions generated while updating.
    temp: Vec<PendingTransition<BufferUses>>,
}

impl DeviceBufferTracker {
    /// Builds a tracker that holds no buffers.
    pub fn new() -> Self {
        let current_states = Vec::new();
        let metadata = ResourceMetadata::new();
        let temp = Vec::new();
        Self {
            current_states,
            metadata,
            temp,
        }
    }

    /// Asserts (via `strict_assert!`) that `index` is addressable in both the
    /// state vector and the metadata.
    fn tracker_assert_in_bounds(&self, index: usize) {
        strict_assert!(index < self.current_states.len());
        self.metadata.tracker_assert_in_bounds(index);
    }

    /// Grows storage so that `index` is a valid slot; new slots start empty.
    fn allow_index(&mut self, index: usize) {
        if index < self.current_states.len() {
            return;
        }
        let len = index + 1;
        self.current_states.resize(len, BufferUses::empty());
        self.metadata.set_size(len);
    }

    /// Iterates over weak references to every tracked buffer.
    pub fn used_resources(&self) -> impl Iterator<Item = Weak<Buffer>> + '_ {
        self.metadata.owned_resources()
    }

    /// Registers `buffer` with initial `state`, growing storage as needed.
    pub fn insert_single(&mut self, buffer: &Arc<Buffer>, state: BufferUses) {
        let slot = buffer.tracker_index().as_usize();
        self.allow_index(slot);
        self.tracker_assert_in_bounds(slot);
        let weak = Arc::downgrade(buffer);
        // SAFETY: `allow_index` just ensured `slot` is in bounds.
        unsafe {
            insert(
                None,
                &mut self.current_states,
                &mut self.metadata,
                slot,
                BufferStateProvider::Direct { state },
                None,
                ResourceMetadataProvider::Direct { resource: &weak },
            );
        }
    }

    /// Moves the already-registered `buffer` to `state` and returns the
    /// required transition, if any.
    ///
    /// NOTE(review): there is no `allow_index` here, so this relies on the
    /// buffer having been registered through `insert_single` first. Only
    /// `strict_assert!` guards that.
    pub fn set_single(
        &mut self,
        buffer: &Arc<Buffer>,
        state: BufferUses,
    ) -> Option<PendingTransition<BufferUses>> {
        let slot: usize = buffer.tracker_index().as_usize();
        self.tracker_assert_in_bounds(slot);
        let provider = BufferStateProvider::Direct { state };
        // SAFETY: `slot` is assumed in bounds (see note above); a `Direct`
        // provider accepts any index.
        unsafe {
            barrier(
                &mut self.current_states,
                slot,
                provider.clone(),
                &mut self.temp,
            );
            update(&mut self.current_states, slot, provider);
        }
        strict_assert!(self.temp.len() <= 1);
        self.temp.pop()
    }

    /// For every buffer in `tracker`, records a transition to its start state
    /// and then adopts its end state. Returns the resulting HAL barriers.
    pub fn set_from_tracker_and_drain_transitions<'a, 'b: 'a>(
        &'a mut self,
        tracker: &'a BufferTracker,
        snatch_guard: &'b SnatchGuard<'b>,
    ) -> impl Iterator<Item = BufferBarrier<'a, dyn hal::DynBuffer>> {
        for slot in tracker.metadata.owned_indices() {
            self.tracker_assert_in_bounds(slot);
            // SAFETY: `slot` is owned by `tracker` (so in bounds for its state
            // slices) and is assumed registered in this tracker.
            unsafe {
                barrier(
                    &mut self.current_states,
                    slot,
                    BufferStateProvider::Indirect {
                        state: &tracker.start,
                    },
                    &mut self.temp,
                );
                update(
                    &mut self.current_states,
                    slot,
                    BufferStateProvider::Indirect {
                        state: &tracker.end,
                    },
                );
            }
        }

        self.temp.drain(..).map(|pending| {
            let buf = unsafe { tracker.metadata.get_resource_unchecked(pending.id as _) };
            pending.into_hal(buf, snatch_guard)
        })
    }
}
/// Source of buffer states for the tracking helpers.
#[derive(Debug, Clone)]
enum BufferStateProvider<'a> {
    /// The same state is returned for every index.
    Direct { state: BufferUses },
    /// The state is read from the slice at the requested index.
    Indirect { state: &'a [BufferUses] },
}

impl BufferStateProvider<'_> {
    /// Returns the state for `index`.
    ///
    /// # Safety
    ///
    /// For `Indirect`, `index` must be in bounds of the slice.
    #[inline(always)]
    unsafe fn get_state(&self, index: usize) -> BufferUses {
        match self {
            Self::Direct { state } => *state,
            Self::Indirect { state: states } => {
                strict_assert!(index < states.len());
                // SAFETY: the caller guarantees `index` is in bounds.
                unsafe { *states.get_unchecked(index) }
            }
        }
    }
}
/// Inserts the incoming state at `index` if that slot is untracked; otherwise
/// ORs it into the existing state via `merge`.
///
/// # Safety
///
/// `index` must be in bounds for `current_states`, `resource_metadata`,
/// `start_states` (when present), and both providers.
///
/// # Errors
///
/// Propagates the conflict reported by `merge` when the combined usage is
/// invalid.
#[inline(always)]
unsafe fn insert_or_merge(
    start_states: Option<&mut [BufferUses]>,
    current_states: &mut [BufferUses],
    resource_metadata: &mut ResourceMetadata<Arc<Buffer>>,
    index32: u32,
    index: usize,
    state_provider: BufferStateProvider<'_>,
    metadata_provider: ResourceMetadataProvider<'_, Arc<Buffer>>,
) -> Result<(), ResourceUsageCompatibilityError> {
    // SAFETY: the caller guarantees `index` is in bounds.
    let already_tracked = unsafe { resource_metadata.contains_unchecked(index) };
    if already_tracked {
        // SAFETY: forwarded caller guarantee.
        return unsafe {
            merge(
                current_states,
                index32,
                index,
                state_provider,
                metadata_provider,
            )
        };
    }

    // SAFETY: forwarded caller guarantee.
    unsafe {
        insert(
            start_states,
            current_states,
            resource_metadata,
            index,
            state_provider,
            None,
            metadata_provider,
        );
    }
    Ok(())
}
/// Inserts the buffer at `index` if untracked. Otherwise records a transition
/// to the start state and then stores the end state (or the start state when
/// no end provider is given).
///
/// # Safety
///
/// `index` must be in bounds for `current_states`, `resource_metadata`,
/// `start_states` (when present), and every provider.
#[inline(always)]
unsafe fn insert_or_barrier_update(
    start_states: Option<&mut [BufferUses]>,
    current_states: &mut [BufferUses],
    resource_metadata: &mut ResourceMetadata<Arc<Buffer>>,
    index: usize,
    start_state_provider: BufferStateProvider<'_>,
    end_state_provider: Option<BufferStateProvider<'_>>,
    metadata_provider: ResourceMetadataProvider<'_, Arc<Buffer>>,
    barriers: &mut Vec<PendingTransition<BufferUses>>,
) {
    // SAFETY: the caller guarantees `index` is in bounds.
    let already_tracked = unsafe { resource_metadata.contains_unchecked(index) };
    if already_tracked {
        // Take the clone before `barrier` consumes the start provider.
        let final_provider = match end_state_provider {
            Some(provider) => provider,
            None => start_state_provider.clone(),
        };
        // SAFETY: forwarded caller guarantee.
        unsafe {
            barrier(current_states, index, start_state_provider, barriers);
            update(current_states, index, final_provider);
        }
    } else {
        // SAFETY: forwarded caller guarantee.
        unsafe {
            insert(
                start_states,
                current_states,
                resource_metadata,
                index,
                start_state_provider,
                end_state_provider,
                metadata_provider,
            );
        }
    }
}
/// Starts tracking a resource at `index`.
///
/// Writes the start state to `start_states` (when present) and the end state
/// (defaulting to the start state) to `current_states`. It then stores a clone
/// of the provider's resource in `resource_metadata`.
///
/// # Safety
///
/// `index` must be in bounds for every slice, the metadata, and the providers.
#[inline(always)]
unsafe fn insert<T: Clone>(
    start_states: Option<&mut [BufferUses]>,
    current_states: &mut [BufferUses],
    resource_metadata: &mut ResourceMetadata<T>,
    index: usize,
    start_state_provider: BufferStateProvider<'_>,
    end_state_provider: Option<BufferStateProvider<'_>>,
    metadata_provider: ResourceMetadataProvider<'_, T>,
) {
    // SAFETY: the caller guarantees `index` is valid for the providers.
    let first = unsafe { start_state_provider.get_state(index) };
    let last = match end_state_provider {
        Some(provider) => unsafe { provider.get_state(index) },
        None => first,
    };

    // Sanity check: neither state may itself be an invalid combination.
    strict_assert_eq!(invalid_resource_state(first), false);
    strict_assert_eq!(invalid_resource_state(last), false);

    // SAFETY: the caller guarantees `index` is in bounds for all storage.
    unsafe {
        if let Some(starts) = start_states {
            *starts.get_unchecked_mut(index) = first;
        }
        *current_states.get_unchecked_mut(index) = last;

        let resource = metadata_provider.get(index);
        resource_metadata.insert(index, resource.clone());
    }
}
/// ORs the incoming state into the tracked state at `index`.
///
/// # Safety
///
/// `index` must be in bounds for `current_states` and both providers.
///
/// # Errors
///
/// If the combined usage is invalid, the tracked state is left unchanged and
/// an error describing the existing and incoming usages is returned.
#[inline(always)]
unsafe fn merge(
    current_states: &mut [BufferUses],
    _index32: u32,
    index: usize,
    state_provider: BufferStateProvider<'_>,
    metadata_provider: ResourceMetadataProvider<'_, Arc<Buffer>>,
) -> Result<(), ResourceUsageCompatibilityError> {
    // SAFETY: the caller guarantees `index` is in bounds.
    let slot = unsafe { current_states.get_unchecked_mut(index) };
    let existing = *slot;
    let incoming = unsafe { state_provider.get_state(index) };
    let combined = existing | incoming;

    if !invalid_resource_state(combined) {
        *slot = combined;
        return Ok(());
    }

    Err(ResourceUsageCompatibilityError::from_buffer(
        unsafe { metadata_provider.get(index) },
        existing,
        incoming,
    ))
}
/// Pushes a transition from the current state at `index` to the provider's
/// state, unless `skip_barrier` says none is needed. Does not modify
/// `current_states`.
///
/// # Safety
///
/// `index` must be in bounds for `current_states` and the provider.
#[inline(always)]
unsafe fn barrier(
    current_states: &mut [BufferUses],
    index: usize,
    state_provider: BufferStateProvider<'_>,
    barriers: &mut Vec<PendingTransition<BufferUses>>,
) {
    // SAFETY: the caller guarantees `index` is in bounds.
    let from = unsafe { *current_states.get_unchecked(index) };
    let to = unsafe { state_provider.get_state(index) };

    if !skip_barrier(from, to) {
        barriers.push(PendingTransition {
            id: index as _,
            selector: (),
            usage: from..to,
        });
    }
}
/// Overwrites the tracked state at `index` with the provider's state.
///
/// # Safety
///
/// `index` must be in bounds for `current_states` and the provider.
#[inline(always)]
unsafe fn update(
    current_states: &mut [BufferUses],
    index: usize,
    state_provider: BufferStateProvider<'_>,
) {
    // SAFETY: the caller guarantees `index` is in bounds.
    let slot = unsafe { current_states.get_unchecked_mut(index) };
    *slot = unsafe { state_provider.get_state(index) };
}