use crate::{
device::{
queue::{EncoderInFlight, SubmittedWorkDoneClosure, TempResource},
DeviceError, DeviceLostClosure,
},
hal_api::HalApi,
id,
resource::{self, Buffer, Labeled, Trackable},
snatch::SnatchGuard,
SubmissionIndex,
};
use smallvec::SmallVec;
use std::sync::Arc;
use thiserror::Error;
/// Bookkeeping for a single queue submission that has not yet been reported done.
///
/// Entries are appended to `LifetimeTracker::active` by `track_submission` and are
/// removed by `triage_submissions` once it is called with a `last_done` of at least
/// `index`.
///
/// NOTE: when a whole entry is dropped, its fields drop in declaration order, so
/// reordering the fields changes destruction order.
struct ActiveSubmission<A: HalApi> {
/// Submission index this entry tracks.
index: SubmissionIndex,
/// Resources kept alive until this submission is done. Filled by `track_submission`
/// and `schedule_resource_destruction`; dropped by `triage_submissions`.
temp_resources: Vec<TempResource<A>>,
/// Buffers whose mapping must wait for this submission (assigned by `triage_mapped`).
/// Moved to `LifetimeTracker::ready_to_map` when the submission is done.
mapped: Vec<Arc<Buffer<A>>>,
/// Encoders used by this submission. When it is done they are landed and handed
/// back to the command allocator.
encoders: Vec<EncoderInFlight<A>>,
/// Closures registered by `add_work_done_closure` while this was the most recent
/// submission. Returned by `triage_submissions` when it is done.
work_done_closures: SmallVec<[SubmittedWorkDoneClosure; 1]>,
}
/// Error returned when waiting on (polling) a device for submitted work fails.
#[derive(Clone, Debug, Error)]
#[non_exhaustive]
pub enum WaitIdleError {
/// A device-level error, displayed transparently. `#[from]` lets `?` convert a
/// `DeviceError` into this variant.
#[error(transparent)]
Device(#[from] DeviceError),
/// The submission index used for the wait came from `.0`, while poll was called on
/// device `.1`.
/// NOTE(review): field 0 is an `id::QueueId`, although the message text calls it a
/// device. Confirm the wording is intended.
#[error("Tried to wait using a submission index from the wrong device. Submission index is from device {0:?}. Called poll on device {1:?}.")]
WrongSubmissionIndex(id::QueueId, id::DeviceId),
/// The wait did not finish.
/// NOTE(review): presumably produced when waiting times out; confirm at the
/// construction site.
#[error("GPU got stuck :(")]
StuckGpu,
}
/// Tracks everything whose release must wait for queue submissions to be reported
/// done: temporary resources, in-flight encoders, pending buffer mappings and
/// work-done closures.
pub(crate) struct LifetimeTracker<A: HalApi> {
/// Buffers queued by `map` that `triage_mapped` has not yet assigned to a submission.
mapped: Vec<Arc<Buffer<A>>>,
/// Submissions not yet retired, in the order `track_submission` appended them.
/// `triage_submissions` retires a prefix of this list, so it presumably relies on
/// indices increasing along it.
active: Vec<ActiveSubmission<A>>,
/// Buffers whose mapping no longer waits on any tracked submission.
/// Consumed by `handle_mapping`.
ready_to_map: Vec<Arc<Buffer<A>>>,
/// Closures registered while no submission was active. Returned by the next
/// `triage_submissions` call.
work_done_closures: SmallVec<[SubmittedWorkDoneClosure; 1]>,
/// Device-lost callback. In the code shown here it is only initialized to `None`
/// by `new`; presumably it is set and invoked elsewhere.
pub device_lost_closure: Option<DeviceLostClosure>,
}
impl<A: HalApi> LifetimeTracker<A> {
    /// Builds an empty tracker.
    ///
    /// The new tracker has nothing in flight, no buffers waiting to be mapped, no
    /// queued closures and no device-lost callback.
    pub fn new() -> Self {
        Self {
            mapped: Vec::new(),
            active: Vec::new(),
            ready_to_map: Vec::new(),
            work_done_closures: SmallVec::new(),
            device_lost_closure: None,
        }
    }

    /// Returns `true` when no submission is currently being tracked.
    pub fn queue_empty(&self) -> bool {
        self.active.is_empty()
    }

    /// Starts tracking submission `index`.
    ///
    /// `temp_resources` are kept alive, and `encoders` stay in flight, until
    /// `triage_submissions` observes that this submission is done.
    pub fn track_submission(
        &mut self,
        index: SubmissionIndex,
        temp_resources: impl Iterator<Item = TempResource<A>>,
        encoders: Vec<EncoderInFlight<A>>,
    ) {
        let submission = ActiveSubmission {
            index,
            temp_resources: temp_resources.collect::<Vec<_>>(),
            mapped: Vec::new(),
            encoders,
            work_done_closures: SmallVec::new(),
        };
        self.active.push(submission);
    }

    /// Queues `value` for mapping.
    ///
    /// `triage_mapped` later decides which tracked submission, if any, the mapping
    /// must wait for.
    pub(crate) fn map(&mut self, value: &Arc<Buffer<A>>) {
        let buffer = Arc::clone(value);
        self.mapped.push(buffer);
    }

    /// Retires every tracked submission at the front of `active` whose index is
    /// `<= last_done`.
    ///
    /// Each retired submission is processed in this order:
    /// 1. its buffers awaiting mapping move to `ready_to_map`;
    /// 2. its encoders are landed and handed back to `command_allocator`;
    /// 3. its temporary resources are dropped;
    /// 4. its work-done closures are collected.
    ///
    /// Returns the collected closures, preceded by any closures that were queued
    /// while no submission was active. The closures are not invoked here.
    #[must_use]
    pub fn triage_submissions(
        &mut self,
        last_done: SubmissionIndex,
        command_allocator: &crate::command::CommandAllocator<A>,
    ) -> SmallVec<[SubmittedWorkDoneClosure; 1]> {
        profiling::scope!("triage_submissions");

        // Only a prefix of `active` is retired: everything before the first entry
        // newer than `last_done`. This assumes entries appear in increasing index
        // order, as appended by `track_submission`.
        let retired = match self.active.iter().position(|sub| sub.index > last_done) {
            Some(first_pending) => first_pending,
            None => self.active.len(),
        };

        // Closures registered while nothing was in flight come first.
        let mut closures: SmallVec<[SubmittedWorkDoneClosure; 1]> =
            std::mem::take(&mut self.work_done_closures);

        for submission in self.active.drain(..retired) {
            let ActiveSubmission {
                index,
                temp_resources,
                mapped,
                encoders,
                work_done_closures: finished,
            } = submission;
            log::debug!("Active submission {} is done", index);

            self.ready_to_map.extend(mapped);

            for encoder in encoders {
                // NOTE(review): `land` is `unsafe`. Calling it here relies on this
                // submission being done (`index <= last_done`); confirm against the
                // safety contract of `EncoderInFlight::land`.
                let raw = unsafe { encoder.land() };
                command_allocator.release_encoder(raw);
            }

            // Explicit drop: temporaries are released here, after the encoders.
            drop(temp_resources);
            closures.extend(finished);
        }

        closures
    }

    /// Keeps `temp_resource` alive until submission `last_submit_index` is done.
    ///
    /// If no tracked submission has that index, `temp_resource` is simply dropped
    /// when this function returns.
    pub fn schedule_resource_destruction(
        &mut self,
        temp_resource: TempResource<A>,
        last_submit_index: SubmissionIndex,
    ) {
        if let Some(submission) = self
            .active
            .iter_mut()
            .find(|sub| sub.index == last_submit_index)
        {
            submission.temp_resources.push(temp_resource);
        }
    }

    /// Registers `closure` to be returned by `triage_submissions` once all work
    /// submitted so far is done.
    ///
    /// The closure is attached to the most recent tracked submission. With nothing
    /// in flight, it is held until the next `triage_submissions` call.
    pub fn add_work_done_closure(&mut self, closure: SubmittedWorkDoneClosure) {
        self.active
            .last_mut()
            .map_or(&mut self.work_done_closures, |latest| {
                &mut latest.work_done_closures
            })
            .push(closure);
    }
}
impl<A: HalApi> LifetimeTracker<A> {
/// Moves every buffer out of `mapped` to wherever its mapping must wait.
///
/// If a tracked submission has the index returned by `buffer.submission_index()`,
/// the buffer is parked on that submission's `mapped` list. `triage_submissions`
/// moves it to `ready_to_map` once that submission is done. Otherwise the buffer
/// goes straight to `ready_to_map`.
pub(crate) fn triage_mapped(&mut self) {
if self.mapped.is_empty() {
return;
}
for buffer in self.mapped.drain(..) {
let submit_index = buffer.submission_index();
log::trace!(
"Mapping of {} at submission {:?} gets assigned to active {:?}",
buffer.error_ident(),
submit_index,
self.active.iter().position(|a| a.index == submit_index)
);
// Disjoint field borrows: `mapped` is being drained while `active` /
// `ready_to_map` receive the buffer.
self.active
.iter_mut()
.find(|a| a.index == submit_index)
.map_or(&mut self.ready_to_map, |a| &mut a.mapped)
.push(buffer);
}
}
/// Performs the pending map operation of every buffer in `ready_to_map`, draining it.
///
/// For each buffer, the current map state is swapped out for `Idle`, then:
/// - `Waiting`: the request is serviced.
///   - A non-empty range is mapped through `super::map_buffer`, and the state becomes
///     `Active` with the returned pointer.
///   - If that mapping fails, the error is logged and the state stays `Idle`.
///   - An empty range is not mapped at all; the state becomes `Active` with a
///     dangling pointer.
/// - `Idle`: skipped, and no callback is produced.
///   NOTE(review): presumably the request was cancelled before it could be serviced;
///   confirm.
/// - `Active`: the previous state is put back and no callback is produced.
/// - any other state: panics.
///
/// `raw` is the HAL device passed to `map_buffer`; `snatch_guard` is forwarded to it.
///
/// Returns one `(operation, status)` pair per serviced `Waiting` request. These are
/// presumably invoked by the caller, hence `#[must_use]`.
///
/// # Panics
/// Panics with "No pending mapping." if a buffer is in a map state other than
/// `Waiting`, `Idle` or `Active`.
#[must_use]
pub(crate) fn handle_mapping(
&mut self,
raw: &A::Device,
snatch_guard: &SnatchGuard,
) -> Vec<super::BufferMapPendingClosure> {
if self.ready_to_map.is_empty() {
return Vec::new();
}
let mut pending_callbacks: Vec<super::BufferMapPendingClosure> =
Vec::with_capacity(self.ready_to_map.len());
for buffer in self.ready_to_map.drain(..) {
let tracker_index = buffer.tracker_index();
// Keep this as its own statement. The lock guard is a temporary, so it is
// released at the end of this `let`. If the call were inlined as the `match`
// scrutinee, the guard would live until the end of the match, and the arms
// below that lock `map_state` again would deadlock (assuming the lock is not
// re-entrant).
let mapping = std::mem::replace(
&mut *buffer.map_state.lock(),
resource::BufferMapState::Idle,
);
let pending_mapping = match mapping {
resource::BufferMapState::Waiting(pending_mapping) => pending_mapping,
// Nothing to do; the state stays `Idle`.
resource::BufferMapState::Idle => continue,
// Already mapped: restore the state we just took out.
resource::BufferMapState::Active { .. } => {
*buffer.map_state.lock() = mapping;
continue;
}
_ => panic!("No pending mapping."),
};
let status = if pending_mapping.range.start != pending_mapping.range.end {
log::debug!("Buffer {tracker_index:?} map state -> Active");
let host = pending_mapping.op.host;
let size = pending_mapping.range.end - pending_mapping.range.start;
match super::map_buffer(
raw,
&buffer,
pending_mapping.range.start,
size,
host,
snatch_guard,
) {
Ok(ptr) => {
*buffer.map_state.lock() = resource::BufferMapState::Active {
ptr,
range: pending_mapping.range.start..pending_mapping.range.start + size,
host,
};
Ok(())
}
// The state was left as `Idle` by the `mem::replace` above.
Err(e) => {
log::error!("Mapping failed: {e}");
Err(e)
}
}
} else {
// Zero-sized range: nothing to map, so a dangling (non-null) pointer is
// recorded instead of calling `map_buffer`.
*buffer.map_state.lock() = resource::BufferMapState::Active {
ptr: std::ptr::NonNull::dangling(),
range: pending_mapping.range,
host: pending_mapping.op.host,
};
Ok(())
};
pending_callbacks.push((pending_mapping.op, status));
}
pending_callbacks
}
}