#![feature(alloc_layout_extra)]
#![feature(coerce_unsized)]
#![feature(downcast_unchecked)]
#![feature(non_null_convenience)]
#![feature(ptr_metadata)]
#![feature(sync_unsafe_cell)]
#![feature(unsize)]
#![allow(private_interfaces)]
#![deny(missing_docs)]
#![deny(clippy::missing_docs_in_private_items)]
use crate::dyn_vec::*;
use crate::graph::*;
use crate::unique_id::*;
pub use mutability_marker::*;
use private::*;
use slab::*;
use std::any::*;
use std::cell::*;
use std::hint::*;
use std::marker::*;
use std::mem::*;
use std::ops::*;
use std::pin::*;
use std::sync::atomic::*;
use std::sync::mpsc::*;
use std::sync::*;
use sync_rw_cell::*;
use task_pool::*;
#[allow(unused_imports)]
use wasm_sync::{Condvar, Mutex};
mod dyn_vec;
mod graph;
mod unique_id;
/// A category of data. All formats of one kind share a format descriptor type and a usage
/// descriptor type.
pub trait Kind: 'static + Send + Sync {
    /// Descriptor passed to [`Format::allocate`] when an object of this kind is created.
    type FormatDescriptor: Send + Sync;
    /// Descriptor supplied when a view is used mutably (see [`View::as_mut`]) and later handed to
    /// [`DerivedDescriptor::update`].
    type UsageDescriptor: Send + Sync + Sized;
}
/// A concrete in-memory representation of data belonging to a [`Kind`].
pub trait Format: 'static + Send + Sync + Sized {
    /// The kind this format belongs to.
    type Kind: Kind;
    /// Creates a new instance of this format from the kind's format descriptor.
    fn allocate(descriptor: &<Self::Kind as Kind>::FormatDescriptor) -> Self;
}
/// Describes how a derived format is kept in sync with its parent format `F`.
pub trait DerivedDescriptor<F: Format>: 'static + Send + Sync + Sized {
    /// The derived format type; it must belong to the same kind as `F`.
    type Format: Format<Kind = F::Kind>;
    /// Brings `data` up to date with `parent`.
    ///
    /// `usages` presumably holds the usage descriptors of the parent's mutable views covered
    /// by this update (the scheduler collects them per pending update) — confirm against
    /// `TypedDerivedFormatUpdater`.
    fn update(
        &self,
        data: &mut Self::Format,
        parent: &F,
        usages: &[&<F::Kind as Kind>::UsageDescriptor],
    );
}
/// A view declaration that can be attached to a command (implemented by [`ViewDescriptor`]).
pub trait ViewUsage: Send + Sync + ViewUsageInner {}
/// Parameters for [`DataFrostContext::allocate`].
pub struct AllocationDescriptor<'a, F: Format> {
    /// Descriptor used to allocate the base object and each of its derived formats.
    pub descriptor: <F::Kind as Kind>::FormatDescriptor,
    /// Optional label used in diagnostic messages.
    pub label: Option<&'static str>,
    /// Derived formats to allocate alongside the base object.
    pub derived_formats: &'a [Derived<F>],
}
/// A list of recorded commands that can be submitted to a [`DataFrostContext`].
pub struct CommandBuffer {
    /// Storage for commands, view holders and the linked-list entries connecting them.
    command_list: DynVec,
    /// First recorded command, if any.
    first_command_entry: Option<DynEntry<CommandEntry>>,
    /// Optional label used in diagnostic messages.
    label: Option<&'static str>,
    /// Most recently recorded command; new commands are linked after it.
    last_command_entry: Option<DynEntry<CommandEntry>>,
}
impl CommandBuffer {
pub fn new(descriptor: CommandBufferDescriptor) -> Self {
const DEFAULT_ALLOCATION_SIZE: usize = 2048;
Self {
command_list: DynVec::with_capacity(DEFAULT_ALLOCATION_SIZE),
label: descriptor.label,
first_command_entry: None,
last_command_entry: None,
}
}
pub fn fence(&mut self, views: &[&dyn ViewUsage]) {
unsafe {
let computation = SyncUnsafeCell::new(Some(Computation::Execute {
command: self.command_list.push(()),
}));
let first_view_entry = self.push_views(views);
let next_command = self.command_list.push(CommandEntry {
computation,
first_view_entry,
label: Some("Fence"),
next_instance: None,
});
self.update_first_last_command_entries(next_command);
}
}
pub fn map<M: UsageMutability, F: Format>(
&mut self,
view: &ViewDescriptor<M, F>,
) -> Mapped<M, F> {
unsafe {
assert!(
TypeId::of::<M>() == TypeId::of::<Const>() || !view.view.derived,
"Attempted to mutably map derived view of object{} in command buffer{}",
FormattedLabel(" ", view.view.inner.inner.label, ""),
FormattedLabel(" ", self.label, "")
);
let inner = Arc::new(MappedInner {
context_id: view.view.inner.inner.context_id,
command_context: UnsafeCell::new(MaybeUninit::uninit()),
label: view.view.inner.inner.label,
map_state: MapObjectState::default(),
});
let computation = SyncUnsafeCell::new(Some(Computation::Map {
inner: Some(inner.clone()),
}));
let first_view_entry = self.push_views(&[view]);
let next_command = self.command_list.push(CommandEntry {
computation,
first_view_entry,
label: Some("Map format"),
next_instance: None,
});
self.update_first_last_command_entries(next_command);
Mapped {
inner,
view: view.view.clone(),
marker: PhantomData,
}
}
}
pub fn schedule(
&mut self,
descriptor: CommandDescriptor<impl Send + Sync + FnOnce(CommandContext)>,
) {
unsafe {
let computation = SyncUnsafeCell::new(Some(Computation::Execute {
command: self
.command_list
.push(SyncUnsafeCell::new(Some(descriptor.command))),
}));
let first_view_entry = self.push_views(descriptor.views);
let next_command = self.command_list.push(CommandEntry {
computation,
first_view_entry,
label: descriptor.label,
next_instance: None,
});
self.update_first_last_command_entries(next_command);
}
}
fn push_views(&mut self, list: &[&dyn ViewUsage]) -> Option<DynEntry<ViewEntry>> {
unsafe {
let mut view_iter = list.iter();
if let Some(first) = view_iter.next() {
let view = first.add_to_list(&mut self.command_list);
let first_entry = self.command_list.push(ViewEntry {
next_instance: None,
view,
});
let mut previous_entry = first_entry;
for to_add in view_iter {
let view = to_add.add_to_list(&mut self.command_list);
let next_entry = self.command_list.push(ViewEntry {
next_instance: None,
view,
});
self.command_list
.get_unchecked_mut(previous_entry)
.next_instance = Some(next_entry);
previous_entry = next_entry;
}
Some(first_entry)
} else {
None
}
}
}
unsafe fn update_first_last_command_entries(&mut self, next_command: DynEntry<CommandEntry>) {
unsafe {
if self.first_command_entry.is_none() {
self.first_command_entry = Some(next_command);
} else if let Some(entry) = self.last_command_entry {
self.command_list.get_unchecked_mut(entry).next_instance = Some(next_command);
}
self.last_command_entry = Some(next_command);
}
}
}
/// Options for creating a [`CommandBuffer`].
#[derive(Copy, Clone, Debug, Default)]
pub struct CommandBufferDescriptor {
    /// Optional label used in diagnostic messages.
    pub label: Option<&'static str>,
}
/// Progress of a submitted [`CommandBuffer`], as reported by [`DataFrostContext::query`].
#[derive(Copy, Clone, Debug, Default)]
pub struct CommandBufferStatus {
    /// Number of scheduled operations from the buffer (including derived-format updates) that
    /// have not completed yet.
    pub incomplete_commands: u32,
}
impl CommandBufferStatus {
    /// Returns `true` once nothing from the buffer remains outstanding.
    pub fn complete(&self) -> bool {
        matches!(self.incomplete_commands, 0)
    }
}
/// Handle to a submitted command buffer; pass it to [`DataFrostContext::query`].
#[derive(Copy, Clone, Debug, PartialEq, Eq, PartialOrd, Ord)]
pub struct CommandBufferSubmission {
    /// Distinguishes this submission from later ones that reuse the same slot.
    unique_id: u64,
    /// Identifier of the context that accepted the submission.
    context_id: u64,
    /// Slot of the submission in the context's active command buffer table.
    command_buffer_id: u16,
}
/// Access to the objects a running command declared. Dropping it marks the command complete.
pub struct CommandContext {
    /// Command state; wrapped in `ManuallyDrop` so code that completes the command itself can
    /// release this state without running `CommandContext`'s `Drop`.
    inner: ManuallyDrop<CommandContextInner>,
}
impl CommandContext {
    /// Immutably borrows the object behind `view`.
    ///
    /// Works for views the command declared either immutably or mutably, since a mutable
    /// declaration also grants read access.
    ///
    /// # Panics
    ///
    /// Panics if the command did not declare `view`. Conflicting borrows are presumably
    /// rejected by the underlying `RwCell` (NOTE(review): confirm in `sync_rw_cell`).
    pub fn get<F: Format>(&self, view: &View<F>) -> ViewRef<Const, F> {
        ViewRef {
            reference: self.find_view::<Const, _>(view).borrow(),
            marker: PhantomData,
        }
    }

    /// Mutably borrows the object behind `view`.
    ///
    /// # Panics
    ///
    /// Panics if the command did not declare `view` mutably.
    pub fn get_mut<F: Format>(&self, view: &View<F>) -> ViewRef<Mut, F> {
        ViewRef {
            reference: self.find_view::<Mut, _>(view).borrow_mut(),
            marker: PhantomData,
        }
    }

    /// Finds the cell holding the object pointer for `view`.
    ///
    /// Mutable lookups (`M = Mut`) only match mutable declarations. Immutable lookups match any
    /// declaration of the object.
    fn find_view<M: Mutability, F: Format>(&self, view: &View<F>) -> &RwCell<*mut ()> {
        let mutable = TypeId::of::<Mut>() == TypeId::of::<M>();
        // A mutable declaration implies read access. Requiring an exact match made
        // `DataFrostContext::get` on a `Mapped<Mut, _>` (and `get` on a mutably declared view
        // inside a command) always panic. Exclusivity is still enforced by the single `RwCell`.
        if let Some(command_view) = self
            .inner
            .views
            .iter()
            .find(|x| x.id == view.id && (x.mutable || !mutable))
        {
            &command_view.value
        } else {
            panic!(
                "View{} was not referenced by command{}{}",
                FormattedLabel(" ", view.inner.inner.label, ""),
                FormattedLabel(" ", self.inner.label, ""),
                FormattedLabel(
                    " (from command buffer ",
                    self.inner.command_buffer_label,
                    ")"
                )
            );
        }
    }
}
impl std::fmt::Debug for CommandContext {
    /// Prints only the type name; the contents are raw pointers and are not shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CommandContext")
    }
}
impl Drop for CommandContext {
    /// Reports the command as complete to the owning context, then releases the command state.
    fn drop(&mut self) {
        unsafe {
            // The lock guard is a temporary of this statement, so it is released before `inner`
            // (and with it this context's `Arc<ContextHolder>`) is dropped below.
            self.inner
                .context
                .inner
                .lock()
                .expect("Failed to lock context.")
                .complete_command(self.inner.command_id, &self.inner.context);
            ManuallyDrop::drop(&mut self.inner);
        }
    }
}
/// A command to record with [`CommandBuffer::schedule`].
pub struct CommandDescriptor<'a, C: 'static + Send + Sync + FnOnce(CommandContext)> {
    /// Optional label used in diagnostic messages.
    pub label: Option<&'static str>,
    /// Closure run with a [`CommandContext`] exposing the declared `views`.
    pub command: C,
    /// Views the command accesses; they determine its ordering relative to other commands.
    pub views: &'a [&'a dyn ViewUsage],
}
impl<'a, C: 'static + Send + Sync + FnOnce(CommandContext)> std::fmt::Debug
    for CommandDescriptor<'a, C>
{
    /// Prints the label only; closures and view declarations are not printable.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("CommandDescriptor");
        tuple.field(&self.label);
        tuple.finish()
    }
}
/// Options for creating a [`DataFrostContext`].
#[derive(Copy, Clone, Debug, Default)]
pub struct ContextDescriptor {
    /// Optional label; `DataFrostContext::new` currently ignores the descriptor.
    pub label: Option<&'static str>,
}
/// Reference-counted handle to an object allocated by a [`DataFrostContext`].
///
/// Dereferences to the object's format descriptor. Equality is handle identity.
pub struct Data<K: Kind> {
    /// Shared metadata; dropping the last handle notifies the owning context.
    inner: Arc<DataInner<K>>,
}
impl<K: Kind> Data<K> {
    /// Returns a view of this object in format `F`, which may be the base format or one of the
    /// derived formats requested at allocation.
    ///
    /// # Panics
    ///
    /// Panics if `F` is neither the base format nor a derived format of this object.
    pub fn view<F: Format<Kind = K>>(&self) -> View<F> {
        let requested = TypeId::of::<F>();
        let derived = requested != self.inner.format_id;
        let id = if derived {
            self.inner
                .derived_formats
                .iter()
                .find_map(|&(format, id)| (format == requested).then_some(id))
                .unwrap_or_else(|| {
                    panic!(
                        "Derived format {} of object{} did not exist",
                        type_name::<F>(),
                        FormattedLabel(" ", self.inner.label, "")
                    )
                })
        } else {
            self.inner.id
        };
        View {
            inner: self.clone(),
            id,
            derived,
            marker: PhantomData,
        }
    }
}
impl<K: Kind> Clone for Data<K> {
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<K: Kind> std::fmt::Debug for Data<K>
where
    K::FormatDescriptor: std::fmt::Debug,
{
    /// Prints the label (when present) followed by the format descriptor.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut tuple = f.debug_tuple("Data");
        if let Some(label) = &self.inner.label {
            tuple.field(label);
        }
        tuple.field(&self.inner.descriptor).finish()
    }
}
impl<K: Kind> Deref for Data<K> {
    type Target = K::FormatDescriptor;
    fn deref(&self) -> &Self::Target {
        let inner: &DataInner<K> = &self.inner;
        &inner.descriptor
    }
}
impl<K: Kind> PartialEq for Data<K> {
    /// Two handles are equal when they refer to the same allocation.
    fn eq(&self, other: &Self) -> bool {
        Arc::as_ptr(&self.inner) == Arc::as_ptr(&other.inner)
    }
}
impl<K: Kind> Eq for Data<K> {}
/// Owns allocated objects and schedules submitted command buffers.
///
/// Cloning yields another handle to the same context.
#[derive(Clone)]
pub struct DataFrostContext {
    /// Shared state; every clone points at the same holder.
    holder: Arc<ContextHolder>,
}
impl DataFrostContext {
    /// Creates an empty context. The descriptor is currently unused.
    pub fn new(_: ContextDescriptor) -> Self {
        let (object_update_sender, object_updates) = channel();
        let change_notifier = ChangeNotifier::default();
        let change_listener = Condvar::new();
        let context_id = unique_id();
        let inner = Mutex::new(ContextInner {
            active_command_buffers: Slab::new(),
            context_id,
            compute_graph: DirectedAcyclicGraph::new(),
            critical_nodes: DirectedAcyclicGraphFlags::new(),
            critical_top_level_nodes: DirectedAcyclicGraphFlags::new(),
            objects: Slab::new(),
            object_update_sender,
            object_updates,
            // No work has been submitted yet, so the first submission must wake the workers.
            stalled: true,
            temporary_node_buffer: Vec::new(),
            top_level_nodes: DirectedAcyclicGraphFlags::new(),
        });
        let holder = Arc::new(ContextHolder {
            change_notifier,
            change_listener,
            context_id,
            inner,
        });
        Self { holder }
    }
    /// Allocates a new object (and any derived formats) described by `descriptor`.
    pub fn allocate<F: Format>(&self, descriptor: AllocationDescriptor<F>) -> Data<F::Kind> {
        self.inner().allocate(descriptor)
    }
    /// Immutably accesses the object behind `mapping`. Blocks until the mapping command has been
    /// reached, running critical work on this thread while waiting.
    ///
    /// # Panics
    ///
    /// Panics if `mapping` belongs to another context or its command buffer was not submitted.
    pub fn get<'a, M: Mutability, F: Format>(
        &self,
        mapping: &'a Mapped<M, F>,
    ) -> ViewRef<'a, Const, F> {
        unsafe {
            self.wait_for_mapping(mapping);
            // `wait_for_mapping` returns only once the command context has been written.
            return (*mapping.inner.command_context.get())
                .assume_init_ref()
                .get(&mapping.view);
        }
    }
    /// Mutably accesses the object behind `mapping`. Blocks like [`Self::get`].
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as [`Self::get`].
    pub fn get_mut<'a, F: Format>(&self, mapping: &'a mut Mapped<Mut, F>) -> ViewRef<'a, Mut, F> {
        unsafe {
            self.wait_for_mapping(mapping);
            return (*mapping.inner.command_context.get())
                .assume_init_mut()
                .get_mut(&mapping.view);
        }
    }
    /// Reports how many operations of a submitted buffer are still outstanding. Unknown or
    /// already-retired submissions report zero.
    ///
    /// # Panics
    ///
    /// Panics if `submission` was produced by a different context.
    pub fn query(&self, submission: &CommandBufferSubmission) -> CommandBufferStatus {
        assert!(
            submission.context_id == self.holder.context_id,
            "Submission was not owned by this context."
        );
        self.inner()
            .active_command_buffers
            .get(submission.command_buffer_id as usize)
            // The slot may have been reused by a later submission.
            .filter(|x| x.unique_id == submission.unique_id)
            .map(|x| CommandBufferStatus {
                incomplete_commands: x.remaining_commands,
            })
            .unwrap_or_default()
    }
    /// Schedules every command in `buffer` and returns a handle for querying its progress.
    pub fn submit(&self, buffer: CommandBuffer) -> CommandBufferSubmission {
        self.inner().submit(buffer, &self.holder)
    }
    /// Locks and returns the inner context state.
    fn inner(&self) -> MutexGuard<ContextInner> {
        self.holder
            .inner
            .lock()
            .expect("Failed to obtain inner context.")
    }
    /// Returns once `mapping`'s command context has been written, executing critical commands
    /// on the calling thread and otherwise waiting on `change_listener`.
    fn wait_for_mapping<M: Mutability, F: Format>(&self, mapping: &Mapped<M, F>) {
        unsafe {
            assert!(
                mapping.inner.context_id == self.holder.context_id,
                "Mapping was not from this context."
            );
            let query = mapping.inner.map_state.get();
            if query.queued {
                if !query.complete {
                    let mut inner = self.inner();
                    while !mapping.inner.map_state.get().complete {
                        // The map node has no unfinished dependencies: complete the mapping here
                        // instead of waiting for a worker to pop it.
                        if inner.top_level_nodes.get_unchecked(query.node) {
                            inner
                                .critical_top_level_nodes
                                .set_unchecked(query.node, false);
                            inner.top_level_nodes.set_unchecked(query.node, false);
                            // Release the graph's reference to the mapping. The node itself is
                            // completed when the stored `CommandContext` is dropped.
                            inner
                                .compute_graph
                                .get_unchecked_mut(query.node)
                                .computation = Computation::Map { inner: None };
                            *mapping.inner.command_context.get() = MaybeUninit::new(
                                inner.create_command_context(&self.holder, query.node),
                            );
                            mapping.inner.map_state.set_complete();
                            return;
                        }
                        match inner.prepare_next_command::<true>(&self.holder) {
                            Some(Some(command)) => {
                                // Run the command without holding the context lock.
                                drop(inner);
                                command.execute();
                                inner = self.inner();
                            }
                            // Handled inline (e.g. another mapping); look again.
                            Some(None) => continue,
                            // No critical work is ready; sleep until the condvar is notified.
                            None => {
                                inner = self
                                    .holder
                                    .change_listener
                                    .wait(inner)
                                    .expect("Failed to lock mutex.")
                            }
                        }
                    }
                }
            } else {
                panic!(
                    "Attempted to map object{} before submitting the associated command buffer.",
                    FormattedLabel(" ", mapping.inner.label, "")
                )
            }
        }
    }
}
impl std::fmt::Debug for DataFrostContext {
    /// Prints only the type name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("DataFrostContext")
    }
}
impl Default for DataFrostContext {
    /// Creates a context from a default [`ContextDescriptor`].
    fn default() -> Self {
        Self::new(Default::default())
    }
}
impl WorkProvider for DataFrostContext {
    fn change_notifier(&self) -> &ChangeNotifier {
        &self.holder.change_notifier
    }
    /// Returns the next unit of work, or `None` when nothing is ready. Steps that complete
    /// inline (and so produce no work unit) are skipped.
    fn next_task(&self) -> Option<Box<dyn '_ + WorkUnit>> {
        let mut inner = self.inner();
        loop {
            let Some(prepared) = inner.prepare_next_command::<false>(&self.holder) else {
                return None;
            };
            if let Some(task) = prepared {
                return Some(task);
            }
        }
    }
}
/// A derived format to allocate alongside a parent object of format `F`.
pub struct Derived<F: Format> {
    /// Type-erased updater shared by every object created from this declaration.
    inner: Arc<dyn DerivedFormatUpdater>,
    /// Ties the declaration to its parent format.
    marker: PhantomData<fn() -> F>,
}
impl<F: Format> Derived<F> {
    /// Wraps `descriptor`, which defines how the derived format is refreshed from `F`.
    pub fn new<D: DerivedDescriptor<F>>(descriptor: D) -> Self {
        Self {
            marker: PhantomData,
            inner: Arc::new(TypedDerivedFormatUpdater {
                marker: PhantomData,
                descriptor,
            }),
        }
    }
}
impl<F: Format> Clone for Derived<F> {
    fn clone(&self) -> Self {
        Self {
            marker: PhantomData,
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<F: Format> std::fmt::Debug for Derived<F> {
    /// Prints the parent format's type name.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let parent = type_name::<F>();
        f.debug_tuple("Derived").field(&parent).finish()
    }
}
/// A mapped object, created by [`CommandBuffer::map`] and read through
/// [`DataFrostContext::get`] / [`DataFrostContext::get_mut`].
pub struct Mapped<M: Mutability, F: Format> {
    /// State shared with the context until the mapping completes.
    inner: Arc<MappedInner>,
    /// The view that was mapped.
    view: View<F>,
    /// Records the mapping's mutability.
    marker: PhantomData<fn() -> M>,
}
/// A typed view of an object in format `F` (its base format or a derived one).
pub struct View<F: Format> {
    /// Handle to the underlying object.
    inner: Data<F::Kind>,
    /// Id of the object that stores format `F`.
    id: u32,
    /// Whether `F` is a derived format rather than the base format.
    derived: bool,
    /// Ties the view to its format.
    marker: PhantomData<fn() -> F>,
}
impl<F: Format> View<F> {
    /// Declares an immutable use of this view, for attaching to a command.
    pub fn as_const(&self) -> ViewDescriptor<Const, F> {
        ViewDescriptor {
            taken: AtomicBool::new(false),
            descriptor: SyncUnsafeCell::new(Some(())),
            view: self,
        }
    }
    /// Declares a mutable use of this view described by `usage`, for attaching to a command.
    pub fn as_mut(&self, usage: <F::Kind as Kind>::UsageDescriptor) -> ViewDescriptor<Mut, F> {
        ViewDescriptor {
            taken: AtomicBool::new(false),
            descriptor: SyncUnsafeCell::new(Some(usage)),
            view: self,
        }
    }
    /// Returns the object this view refers to.
    pub fn data(&self) -> &Data<F::Kind> {
        &self.inner
    }
}
impl<F: Format> Clone for View<F> {
    fn clone(&self) -> Self {
        let View { inner, id, derived, .. } = self;
        Self {
            inner: inner.clone(),
            id: *id,
            derived: *derived,
            marker: PhantomData,
        }
    }
}
impl<F: Format> std::fmt::Debug for View<F>
where
    <F::Kind as Kind>::FormatDescriptor: std::fmt::Debug,
{
    /// Prints the format's type name followed by the underlying object.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let format = type_name::<F>();
        f.debug_tuple("View")
            .field(&format)
            .field(self.data())
            .finish()
    }
}
impl<F: Format> Deref for View<F> {
    type Target = <F::Kind as Kind>::FormatDescriptor;
    fn deref(&self) -> &Self::Target {
        &self.data().inner.descriptor
    }
}
impl<F: Format> PartialEq for View<F> {
    /// Views are equal when they refer to the same object.
    fn eq(&self, other: &Self) -> bool {
        self.data() == other.data()
    }
}
impl<F: Format> Eq for View<F> {}
/// A one-shot declaration of how a command uses a [`View`], created by [`View::as_const`] or
/// [`View::as_mut`]. It can be recorded into a command buffer exactly once.
pub struct ViewDescriptor<'a, M: UsageMutability, F: Format> {
    /// The view being declared.
    view: &'a View<F>,
    /// Usage descriptor; moved out when the declaration is recorded.
    descriptor: SyncUnsafeCell<Option<M::Descriptor<F>>>,
    /// Set once the descriptor has been moved out, to reject reuse.
    taken: AtomicBool,
}
impl<'a, M: UsageMutability, F: Format> ViewUsage for ViewDescriptor<'a, M, F> {}
impl<'a, M: UsageMutability, F: Format> ViewUsageInner for ViewDescriptor<'a, M, F> {
    /// Moves this declaration into `command_list` as a type-erased view holder.
    ///
    /// # Panics
    ///
    /// Panics if the declaration was already recorded.
    fn add_to_list(&self, command_list: &mut DynVec) -> DynEntry<dyn ViewHolder> {
        let already_taken = self.taken.swap(true, Ordering::Relaxed);
        assert!(
            !already_taken,
            "Attempted to reuse view descriptor{}",
            FormattedLabel(" ", self.view.inner.inner.label, "")
        );
        unsafe {
            // SAFETY: the atomic swap above admits exactly one caller, and the cell starts out
            // as `Some`, so the descriptor is present and nobody else touches the cell.
            let descriptor = take(&mut *self.descriptor.get()).unwrap_unchecked();
            command_list.push(TypedViewHolder::<M, F> {
                view: self.view.clone(),
                descriptor,
            })
        }
    }
}
/// A borrow of an object of format `F`, obtained from a [`CommandContext`] or a [`Mapped`].
pub struct ViewRef<'a, M: Mutability, F: Format> {
    /// Guard over the pointer to the object's storage.
    reference: RwCellGuard<'a, M, *mut ()>,
    /// Ties the borrow to the object's format and lifetime.
    marker: PhantomData<&'a F>,
}
impl<'a, M: Mutability, F: Format + std::fmt::Debug> std::fmt::Debug for ViewRef<'a, M, F> {
    /// Formats the borrowed object itself.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(&**self, f)
    }
}
impl<'a, M: Mutability, F: Format> Deref for ViewRef<'a, M, F> {
    type Target = F;
    fn deref(&self) -> &Self::Target {
        let object: *const F = self.reference.cast_const().cast();
        // SAFETY: presumably the pointer was resolved from an object stored in format `F`, and
        // the guard keeps the borrow alive for `'a`.
        unsafe { &*object }
    }
}
impl<'a, F: Format> DerefMut for ViewRef<'a, Mut, F> {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let object: *mut F = self.reference.cast();
        // SAFETY: as for `deref`; the `Mut` guard grants exclusive access.
        unsafe { &mut *object }
    }
}
/// Type-erased interface to a [`DerivedDescriptor`], stored alongside each derived object.
trait DerivedFormatUpdater: 'static + Send + Sync {
    /// Allocates the derived object. `descriptor` is a type-erased pointer to the kind's format
    /// descriptor (the caller passes `&descriptor.descriptor` cast to `*const ()`).
    unsafe fn allocate(&self, descriptor: *const ()) -> Box<UnsafeCell<dyn Any + Send + Sync>>;
    /// Returns the `TypeId` of the derived format.
    fn format_type_id(&self) -> TypeId;
    /// Runs the update. `context` exposes the parent (immutably) and the derived object
    /// (mutably); `usages` holds type-erased pointers to the recorded parent usage descriptors.
    unsafe fn update(&self, context: CommandContext, usages: *const [*const ()]);
}
/// Type-erased body of a recorded command.
trait ExecutableCommand: 'static + Send + Sync {
    /// Runs the command. Must be called at most once: the closure implementation moves its
    /// `FnOnce` out of the cell without checking.
    unsafe fn execute(&self, ctx: CommandContext);
}
/// Type-erased view declaration recorded in a command buffer.
trait ViewHolder: 'static + Send + Sync {
    /// Identifier of the context that owns the viewed object.
    fn context_id(&self) -> u64;
    /// Identifier of the viewed object.
    fn id(&self) -> u32;
    /// Whether the view was declared mutably.
    fn mutable(&self) -> bool;
    /// Type-erased pointer to the view's usage descriptor (forwarded to derived updates).
    fn usage(&self) -> *const ();
}
impl ExecutableCommand for () {
    /// Does nothing; used by fence commands.
    unsafe fn execute(&self, _: CommandContext) {}
}
impl<F: 'static + Send + Sync + FnOnce(CommandContext)> ExecutableCommand
    for SyncUnsafeCell<Option<F>>
{
    /// Moves the closure out of the cell and calls it with `ctx`.
    unsafe fn execute(&self, ctx: CommandContext) {
        let command = (*self.get()).take().unwrap_unchecked();
        command(ctx);
    }
}
/// A submitted command buffer whose operations have not all completed.
struct ActiveCommandBuffer {
    /// Recorded commands; shared with work units that execute them.
    command_list: Arc<DynVec>,
    /// Optional label used in diagnostic messages.
    label: Option<&'static str>,
    /// Graph nodes from this buffer (including derived-format updates) not yet completed.
    remaining_commands: u32,
    /// Distinguishes this submission from later ones reusing the same slot.
    unique_id: u64,
}
/// State behind a [`CommandContext`].
struct CommandContextInner {
    /// Graph node being executed.
    pub command_id: NodeId,
    /// Context to notify when the command completes.
    pub context: Arc<ContextHolder>,
    /// Label of the command.
    pub label: Option<&'static str>,
    /// Label of the command buffer the command came from.
    pub command_buffer_label: Option<&'static str>,
    /// One entry per view the command declared.
    pub views: Vec<CommandContextView>,
}
/// A declared view as exposed to a running command.
struct CommandContextView {
    /// Identifier of the viewed object.
    pub id: u32,
    /// Whether the view was declared mutably.
    pub mutable: bool,
    /// Pointer to the object's storage, guarded by a reader/writer cell.
    pub value: RwCell<*mut ()>,
}
/// A recorded command in a [`CommandBuffer`]'s linked list.
struct CommandEntry {
    /// Work to perform; taken out when the buffer is submitted.
    pub computation: SyncUnsafeCell<Option<Computation>>,
    /// Head of the command's view list.
    pub first_view_entry: Option<DynEntry<ViewEntry>>,
    /// Optional label used in diagnostic messages.
    pub label: Option<&'static str>,
    /// Next command in recording order.
    pub next_instance: Option<DynEntry<CommandEntry>>,
}
/// The work carried by a graph node.
#[derive(Clone)]
enum Computation {
    /// Run a recorded command.
    Execute {
        /// Command stored in the owning buffer's `command_list`.
        command: DynEntry<dyn ExecutableCommand>,
    },
    /// Complete a mapping.
    Map {
        /// Mapping state; `None` once it has been handed out.
        inner: Option<Arc<MappedInner>>,
    },
    /// Refresh a derived object from its parent.
    Update {
        /// Identifier of the derived object.
        object: u32,
        /// Parent usages to apply, with the buffers that own their view holders.
        updates: Vec<(Arc<DynVec>, DynEntry<dyn ViewHolder>)>,
    },
}
/// A node of the compute graph.
struct ComputationNode {
    /// Work performed by the node.
    pub computation: Computation,
    /// Slot of the owning [`ActiveCommandBuffer`].
    pub command_buffer: u16,
    /// Set for derived-format update nodes.
    pub derived_update: Option<DerivedFormatUpdate>,
    /// Optional label used in diagnostic messages.
    pub label: Option<&'static str>,
    /// Objects the node accesses.
    pub views: Vec<ComputationViewReference>,
}
/// An object access declared by a graph node.
struct ComputationViewReference {
    /// Identifier of the object.
    pub id: u32,
    /// Whether the access is mutable.
    pub mutable: bool,
    /// View holder that declared the access.
    pub view_holder: DynEntry<dyn ViewHolder>,
}
/// Shared, reference-counted context state.
struct ContextHolder {
    /// Wakes the worker pool when new work appears after a stall.
    change_notifier: ChangeNotifier,
    /// Wakes threads waiting for a mapping.
    change_listener: Condvar,
    /// Identifier of this context.
    context_id: u64,
    /// Mutable scheduling state.
    inner: Mutex<ContextInner>,
}
/// Mutable scheduling state of a context, protected by [`ContextHolder::inner`].
struct ContextInner {
    /// Submitted buffers with outstanding operations.
    active_command_buffers: Slab<ActiveCommandBuffer>,
    /// Pending operations and their dependencies.
    compute_graph: DirectedAcyclicGraph<ComputationNode>,
    /// Identifier of this context.
    context_id: u64,
    /// Nodes that must complete before some pending map can complete.
    critical_nodes: DirectedAcyclicGraphFlags,
    /// Nodes that are both critical and top-level.
    critical_top_level_nodes: DirectedAcyclicGraphFlags,
    /// Storage for every allocated object, indexed by object id.
    objects: Slab<DataHolder>,
    /// Handed to each [`DataInner`] so drops can be reported.
    object_update_sender: Sender<ObjectUpdate>,
    /// Queued object drops, processed on submission.
    object_updates: std::sync::mpsc::Receiver<ObjectUpdate>,
    /// Set when no ready node was found while nodes remained pending.
    stalled: bool,
    /// Scratch buffer reused by scheduling and completion.
    temporary_node_buffer: Vec<NodeId>,
    /// Nodes with no unfinished parents.
    top_level_nodes: DirectedAcyclicGraphFlags,
}
impl ContextInner {
pub fn allocate<F: Format>(&mut self, descriptor: AllocationDescriptor<F>) -> Data<F::Kind> {
unsafe {
let mut derived_formats = Vec::with_capacity(descriptor.derived_formats.len());
let mut derived_states: Vec<DerivedFormatState> =
Vec::with_capacity(descriptor.derived_formats.len());
let object = F::allocate(&descriptor.descriptor);
let id = self.objects.insert(DataHolder {
immutable_references: Vec::new(),
mutable_references: Vec::new(),
label: descriptor.label,
derive_state: FormatDeriveState::Base {
derived_formats: Vec::with_capacity(descriptor.derived_formats.len()),
},
value: Box::pin(UnsafeCell::new(object)),
}) as u32;
for (index, derived) in descriptor.derived_formats.iter().enumerate() {
let derived_object = derived
.inner
.allocate(&descriptor.descriptor as *const _ as *const _);
let id = self.objects.insert(DataHolder {
immutable_references: Vec::new(),
mutable_references: Vec::new(),
label: descriptor.label,
derive_state: FormatDeriveState::Derived {
index: index as u8,
parent: id,
updater: derived.inner.clone(),
},
value: Box::into_pin(derived_object),
}) as u32;
assert!(
derived.inner.format_type_id() != TypeId::of::<F>(),
"Derived format cannot be the same type as parent."
);
assert!(
derived_states
.iter()
.all(|x| x.format_id != derived.inner.format_type_id()),
"Duplicate derived formats."
);
let format_id = derived.inner.format_type_id();
derived_formats.push((format_id, id));
derived_states.push(DerivedFormatState {
format_id,
id,
next_update: None,
});
}
if let FormatDeriveState::Base { derived_formats } =
&mut self.objects.get_unchecked_mut(id as usize).derive_state
{
*derived_formats = derived_states;
} else {
unreachable_unchecked()
}
let inner = Arc::new(DataInner {
context_id: self.context_id,
derived_formats,
descriptor: descriptor.descriptor,
format_id: TypeId::of::<F>(),
id,
label: descriptor.label,
object_updater: self.object_update_sender.clone(),
});
Data { inner }
}
}
pub fn submit(
&mut self,
buffer: CommandBuffer,
context: &ContextHolder,
) -> CommandBufferSubmission {
self.update_objects();
let (submission, added_top_level_node) = self.submit_buffer(buffer);
self.critical_top_level_nodes
.and(&self.critical_nodes, &self.top_level_nodes);
if added_top_level_node {
self.notify_new_top_level_commands(context);
}
submission
}
fn create_command_context(
&self,
context: &Arc<ContextHolder>,
command_id: NodeId,
) -> CommandContext {
unsafe {
let computation = self.compute_graph.get_unchecked(command_id);
CommandContext {
inner: ManuallyDrop::new(CommandContextInner {
command_id,
command_buffer_label: self
.active_command_buffers
.get_unchecked(computation.command_buffer as usize)
.label,
context: context.clone(),
label: computation.label,
views: computation
.views
.iter()
.map(|x| CommandContextView {
id: x.id,
mutable: x.mutable,
value: RwCell::new(
self.objects.get_unchecked(x.id as usize).value.get().cast(),
),
})
.collect(),
}),
}
}
}
fn prepare_next_command<const CRITICAL_ONLY: bool>(
&mut self,
context: &Arc<ContextHolder>,
) -> Option<Option<Box<dyn WorkUnit>>> {
unsafe {
if let Some(node) = self.pop_command::<CRITICAL_ONLY>() {
let ctx = self.create_command_context(context, node);
let computation = self.compute_graph.get_unchecked_mut(node);
match &mut computation.computation {
Computation::Execute { command } => {
let command = *command;
let command_buffer = self
.active_command_buffers
.get_unchecked(computation.command_buffer as usize)
.command_list
.clone();
Some(Some(Box::new(move || {
command_buffer.get_unchecked(command).execute(ctx)
})))
}
Computation::Map { inner } => {
let value = take(inner).unwrap_unchecked();
*value.command_context.get() = MaybeUninit::new(ctx);
value.map_state.set_complete();
if let Some(mut value) = Arc::into_inner(value) {
self.complete_command(node, context);
let command_context = value.command_context.get_mut().assume_init_mut();
ManuallyDrop::drop(&mut command_context.inner);
forget(value);
}
Some(None)
}
Computation::Update { object, updates } => {
let derived = *object;
let value = self.objects.get_unchecked(derived as usize);
let FormatDeriveState::Derived {
updater,
parent,
index,
..
} = &value.derive_state
else {
unreachable_unchecked()
};
let parent = *parent as usize;
let index = *index as usize;
let updater = updater.clone();
let format_state = if let FormatDeriveState::Base { derived_formats } =
&mut self.objects.get_unchecked_mut(parent).derive_state
{
derived_formats.get_unchecked_mut(index)
} else {
unreachable_unchecked()
};
if format_state.next_update == Some(node) {
format_state.next_update = None;
}
let updates = take(updates);
Some(Some(Box::new(move || {
let mut update_list = Vec::with_capacity(updates.len());
update_list.extend(
updates
.iter()
.map(|(buffer, entry)| buffer.get_unchecked(*entry).usage()),
);
updater.update(ctx, &update_list[..] as *const _);
})))
}
}
} else {
if !self.compute_graph.is_empty() {
self.stalled = true;
}
None
}
}
}
fn pop_command<const CRITICAL_ONLY: bool>(&mut self) -> Option<NodeId> {
unsafe {
if let Some(node) = self.critical_top_level_nodes.first_set_node() {
debug_assert!(
self.compute_graph.parents(node).next().is_none(),
"Attempted to pop non-parent node."
);
self.critical_top_level_nodes.set_unchecked(node, false);
self.top_level_nodes.set_unchecked(node, false);
Some(node)
} else if let Some(node) = (!CRITICAL_ONLY)
.then(|| self.top_level_nodes.first_set_node())
.flatten()
{
debug_assert!(
self.compute_graph.parents(node).next().is_none(),
"Attempted to pop non-parent node."
);
self.top_level_nodes.set_unchecked(node, false);
Some(node)
} else {
None
}
}
}
fn submit_buffer(&mut self, buffer: CommandBuffer) -> (CommandBufferSubmission, bool) {
unsafe {
let mut added_top_level_node = false;
let unique_id = unique_id();
let command_buffer_id = if let Some(first_entry) = buffer.first_command_entry {
let command_buffer_id = self.active_command_buffers.insert(ActiveCommandBuffer {
command_list: Arc::new(buffer.command_list),
label: buffer.label,
remaining_commands: 0,
unique_id,
}) as u16;
let mut command_entry = Some(first_entry);
while let Some(entry) = command_entry {
let next = self
.active_command_buffers
.get_unchecked(command_buffer_id as usize)
.command_list
.get_unchecked(entry);
command_entry = next.next_instance;
added_top_level_node |= self.schedule_command(
command_buffer_id,
buffer.label,
take(&mut *next.computation.get()).unwrap_unchecked(),
next.label,
next.first_view_entry,
);
}
command_buffer_id
} else {
0
};
let submission = CommandBufferSubmission {
command_buffer_id,
context_id: self.context_id,
unique_id,
};
(submission, added_top_level_node)
}
}
fn schedule_command(
&mut self,
command_buffer_id: u16,
command_buffer_label: Option<&'static str>,
computation: Computation,
label: Option<&'static str>,
first_view_entry: Option<DynEntry<ViewEntry>>,
) -> bool {
unsafe {
let command_buffer = self
.active_command_buffers
.get_unchecked(command_buffer_id as usize)
.command_list
.clone();
let node = self.compute_graph.vacant_node();
self.temporary_node_buffer.clear();
let mut views = Vec::new();
let mut view_entry = first_view_entry;
while let Some(entry) = view_entry {
let next = command_buffer.get_unchecked(entry);
let next_view = command_buffer.get_unchecked(next.view);
assert!(
next_view.context_id() == self.context_id,
"View did not belong to this context."
);
views.push(ComputationViewReference {
id: next_view.id(),
view_holder: next.view,
mutable: next_view.mutable(),
});
view_entry = next.next_instance;
let object = self.objects.get_unchecked_mut(next_view.id() as usize);
for computation in object.mutable_references.iter().copied() {
assert!(!next_view.mutable() || computation != node,
"Attempted to use two conflicting views of object{} with command{} in buffer{}",
FormattedLabel(" ", object.label, ""),
FormattedLabel(" ", label, ""),
FormattedLabel(" ", command_buffer_label, ""));
self.temporary_node_buffer.push(computation);
}
if next_view.mutable() {
for computation in object.immutable_references.iter().copied() {
assert!(computation != node,
"Attempted to use two conflicting views of object{} with command{} in buffer{}",
FormattedLabel(" ", object.label, ""),
FormattedLabel(" ", label, ""),
FormattedLabel(" ", command_buffer_label, ""));
if let Some(derived) =
&self.compute_graph.get_unchecked(computation).derived_update
{
if derived.parent == next_view.id() {
let FormatDeriveState::Base { derived_formats } =
&mut object.derive_state
else {
unreachable_unchecked()
};
if derived_formats
.get_unchecked(derived.index as usize)
.next_update
== Some(computation)
{
continue;
}
}
}
self.temporary_node_buffer.push(computation);
}
}
if next_view.mutable() {
&mut object.mutable_references
} else {
&mut object.immutable_references
}
.push(node);
}
self.compute_graph.insert_unchecked(
ComputationNode {
computation: computation.clone(),
command_buffer: command_buffer_id,
derived_update: None,
label,
views,
},
&self.temporary_node_buffer,
);
for i in 0..self.compute_graph.get_unchecked(node).views.len() {
let view = &self
.compute_graph
.get_unchecked(node)
.views
.get_unchecked(i);
let view_id = view.id;
let view_holder = view.view_holder;
let mutable = view.mutable;
let object = self.objects.get_unchecked_mut(view.id as usize);
let mut derived_nodes_to_add = Vec::new();
match &mut object.derive_state {
FormatDeriveState::Base { derived_formats } => {
if mutable {
for (index, format) in derived_formats.iter_mut().enumerate() {
let derived = if let Some(derived) = format.next_update {
self.compute_graph.add_parent_unchecked(node, derived);
self.top_level_nodes.set_unchecked(derived, false);
derived
} else {
let derived_computation = self.compute_graph.insert_unchecked(
ComputationNode {
computation: Computation::Update {
object: format.id,
updates: Vec::with_capacity(1),
},
command_buffer: command_buffer_id,
derived_update: Some(DerivedFormatUpdate {
parent: view_id,
index: index as u32,
}),
label: None,
views: vec![
ComputationViewReference {
id: view_id,
view_holder,
mutable: false,
},
ComputationViewReference {
id: format.id,
view_holder,
mutable: true,
},
],
},
&[node],
);
format.next_update = Some(derived_computation);
object.immutable_references.push(derived_computation);
derived_nodes_to_add.push((format.id, derived_computation));
self.active_command_buffers
.get_unchecked_mut(command_buffer_id as usize)
.remaining_commands += 1;
derived_computation
};
let Computation::Update { updates, .. } =
&mut self.compute_graph.get_unchecked_mut(derived).computation
else {
unreachable_unchecked()
};
updates.push((command_buffer.clone(), view_holder));
}
}
}
&mut FormatDeriveState::Derived { parent, index, .. } => {
if let FormatDeriveState::Base { derived_formats } =
&mut self.objects.get_unchecked_mut(parent as usize).derive_state
{
derived_formats
.get_unchecked_mut(index as usize)
.next_update = None;
} else {
unreachable_unchecked()
}
}
}
for (id, node) in derived_nodes_to_add {
self.objects
.get_unchecked_mut(id as usize)
.mutable_references
.push(node);
}
}
self.top_level_nodes.resize_for(&self.compute_graph);
self.critical_nodes.resize_for(&self.compute_graph);
let top_level = if self.temporary_node_buffer.is_empty() {
self.top_level_nodes.set_unchecked(node, true);
true
} else {
false
};
if let Computation::Map { inner } = computation {
let inner_ref = inner.as_ref().unwrap_unchecked();
assert!(
inner_ref.context_id == self.context_id,
"Attempted to map object in incorrect context."
);
inner_ref.map_state.set_queued(node);
self.mark_critical(node);
}
self.active_command_buffers
.get_unchecked_mut(command_buffer_id as usize)
.remaining_commands += 1;
top_level
}
}
unsafe fn mark_critical(&mut self, node: NodeId) {
self.critical_nodes.set_unchecked(node, true);
while let Some(parent) = self.temporary_node_buffer.pop() {
if !self.critical_nodes.get_unchecked(parent) {
self.temporary_node_buffer
.extend(self.compute_graph.parents(parent));
self.critical_nodes.set_unchecked(parent, true);
}
}
}
unsafe fn complete_command(&mut self, id: NodeId, context: &ContextHolder) {
self.temporary_node_buffer.clear();
self.temporary_node_buffer
.extend(self.compute_graph.children_unchecked(id));
self.critical_nodes.set_unchecked(id, false);
let computation = self.compute_graph.pop_unchecked(id);
for child in self.temporary_node_buffer.iter().copied() {
if self.compute_graph.parents_unchecked(child).next().is_none() {
self.top_level_nodes.set_unchecked(child, true);
}
}
for view in computation.views {
let object = self.objects.get_unchecked_mut(view.id as usize);
let view_vec = if view.mutable {
&mut object.mutable_references
} else {
&mut object.immutable_references
};
view_vec.swap_remove(view_vec.iter().position(|x| *x == id).unwrap_unchecked());
}
self.critical_top_level_nodes
.and(&self.top_level_nodes, &self.critical_nodes);
let command_list = self
.active_command_buffers
.get_unchecked_mut(computation.command_buffer as usize);
command_list.remaining_commands -= 1;
if command_list.remaining_commands == 0 {
self.active_command_buffers
.remove(computation.command_buffer as usize);
}
if !self.temporary_node_buffer.is_empty() {
self.notify_new_top_level_commands(context);
}
}
fn update_objects(&mut self) {
unsafe {
while let Ok(update) = self.object_updates.try_recv() {
match update {
ObjectUpdate::DropData { id } => {
let FormatDeriveState::Base { derived_formats } =
self.objects.remove(id as usize).derive_state
else {
unreachable_unchecked()
};
for format in derived_formats {
self.objects.remove(format.id as usize);
}
}
}
}
}
}
fn notify_new_top_level_commands(&mut self, context: &ContextHolder) {
if self.stalled {
self.stalled = false;
context.change_notifier.notify();
}
if self.critical_top_level_nodes.first_set_node().is_some() {
context.change_listener.notify_all();
}
}
}
/// An entry in the context's object table. `update_objects` removes both base
/// objects and their derived formats from that table, so each is stored as one
/// of these.
struct DataHolder {
/// Whether this object is a base format (listing its derived formats) or a
/// derived format (recording its parent and updater).
pub derive_state: FormatDeriveState,
/// Optional user-supplied label.
pub label: Option<&'static str>,
/// Command nodes currently holding an immutable view of this object;
/// `complete_command` removes a node from here when it finishes.
pub immutable_references: Vec<NodeId>,
/// Command nodes currently holding a mutable view of this object;
/// `complete_command` removes a node from here when it finishes.
pub mutable_references: Vec<NodeId>,
/// The type-erased format value, pinned on the heap. `UnsafeCell` permits
/// mutation through shared access.
pub value: Pin<Box<UnsafeCell<dyn Any + Send + Sync>>>,
}
/// Shared state behind a data allocation's handles.
///
/// Its `Drop` impl sends `ObjectUpdate::DropData { id }` through
/// `object_updater`. The context drains that message in `update_objects` to
/// free the object and its derived formats.
struct DataInner<K: Kind> {
/// Id of the context this allocation belongs to.
pub context_id: u64,
/// NOTE(review): presumably `(format TypeId, object id)` for each derived
/// format of this allocation; confirm against the allocation code.
pub derived_formats: Vec<(TypeId, u32)>,
/// Format descriptor of this allocation's kind (presumably the one it was
/// allocated with; confirm).
pub descriptor: K::FormatDescriptor,
/// NOTE(review): presumably the `TypeId` of the base format; confirm.
pub format_id: TypeId,
/// Object id sent in `ObjectUpdate::DropData` when this value is dropped.
pub id: u32,
/// Optional user-supplied label.
pub label: Option<&'static str>,
/// Sending half of the context's object-update channel.
pub object_updater: Sender<ObjectUpdate>,
}
impl<K: Kind> Drop for DataInner<K> {
    /// Queues an `ObjectUpdate::DropData` for this object's `id` so the context
    /// can free it.
    fn drop(&mut self) {
        let message = ObjectUpdate::DropData { id: self.id };
        // Best-effort: `Sender::send` only fails once the receiving side has been
        // dropped, and then there is nothing left to notify.
        let _ = self.object_updater.send(message);
    }
}
/// Identifies a derived format to refresh.
/// NOTE(review): semantics inferred from field names only; confirm against users.
struct DerivedFormatUpdate {
/// Presumably the base object's id.
pub parent: u32,
/// Presumably the position of the derived format within the parent's
/// derived-format list.
pub index: u32,
}
/// Bookkeeping for one derived format, stored in a base object's
/// `FormatDeriveState::Base { derived_formats }`.
struct DerivedFormatState {
/// NOTE(review): presumably the `TypeId` of the derived format (compare
/// `DerivedFormatUpdater::format_type_id`); confirm.
pub format_id: TypeId,
/// Object id of the derived format; `update_objects` removes it together with
/// its base object.
pub id: u32,
/// NOTE(review): presumably the pending graph node that will refresh this
/// format, if any; confirm.
pub next_update: Option<NodeId>,
}
/// Position of an object in the base/derived format hierarchy.
enum FormatDeriveState {
/// A base object. Removing it also removes every listed derived format
/// (see `update_objects`).
Base {
/// State of each derived format that belongs to this base object.
derived_formats: Vec<DerivedFormatState>,
},
/// A derived-format object.
Derived {
/// NOTE(review): presumably the base object's id; confirm.
parent: u32,
/// NOTE(review): presumably this format's position in the parent's
/// `derived_formats`; confirm.
index: u8,
/// Type-erased updater that can allocate and refresh this format
/// (`DerivedFormatUpdater::allocate` / `update`).
updater: Arc<dyn DerivedFormatUpdater>,
},
}
/// Optional label wrapped between a prefix (field 0) and a suffix (field 2).
///
/// Displays as `{prefix}'{label}'{suffix}` when the label (field 1) is present,
/// and as the empty string when it is `None`.
struct FormattedLabel(pub &'static str, pub Option<&'static str>, pub &'static str);
impl std::fmt::Display for FormattedLabel {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self(prefix, label, suffix) = self;
        match label {
            Some(text) => write!(f, "{prefix}'{text}'{suffix}"),
            // No label: emit nothing at all, not even the prefix or suffix.
            None => Ok(()),
        }
    }
}
/// Shared state of a mapping handle.
struct MappedInner {
/// Id of the context the mapping was created in.
pub context_id: u64,
/// Command context for the mapped view. The `Drop` impl drops it
/// unconditionally, so it must be initialized before this struct is dropped.
pub command_context: UnsafeCell<MaybeUninit<CommandContext>>,
/// Optional user-supplied label.
pub label: Option<&'static str>,
/// Packed queued/complete/node status of the mapping.
pub map_state: MapObjectState,
}
impl Drop for MappedInner {
/// Drops the stored `CommandContext` in place.
fn drop(&mut self) {
// SAFETY: assumes `command_context` was initialized before this value is
// dropped. NOTE(review): confirm every construction path writes it, or this
// is UB.
unsafe {
self.command_context.get_mut().assume_init_drop();
}
}
}
// SAFETY: NOTE(review): the `UnsafeCell` makes `MappedInner` `!Sync`, and
// `CommandContext` may not be `Send`/`Sync` on its own. These impls assert that
// all access to `command_context` is externally synchronized (presumably gated
// by `map_state`); confirm before relying on it.
unsafe impl Send for MappedInner {}
unsafe impl Sync for MappedInner {}
/// Decoded snapshot of a [`MapObjectState`].
struct MapObjectStateQuery {
    /// Set by `set_complete`; cleared again by a later `set_queued`.
    pub complete: bool,
    /// Node id decoded from the low 16 bits of the packed word.
    pub node: NodeId,
    /// Set once `set_queued` has been called.
    pub queued: bool,
}
/// Mapping status packed into a single `AtomicU32`.
///
/// Bits 0..16 hold the queued node id. Bit 16 is the "queued" flag and bit 17
/// is the "complete" flag. The default value has every bit clear.
#[derive(Default)]
struct MapObjectState(AtomicU32);
impl MapObjectState {
    /// Flag bit recording that a mapping has been queued on a node.
    const MAPPING_QUEUED: u32 = 1 << 16;
    /// Flag bit recording that the queued mapping has completed.
    const MAPPING_COMPLETE: u32 = 1 << 17;

    /// Records that a mapping was queued on `node`.
    ///
    /// The whole word is overwritten, which also clears the complete flag.
    pub fn set_queued(&self, node: NodeId) {
        let packed = u32::from(u16::from(node)) | Self::MAPPING_QUEUED;
        self.0.store(packed, Ordering::Release);
    }

    /// Sets the complete flag, leaving the queued flag and node id untouched.
    pub fn set_complete(&self) {
        self.0.fetch_or(Self::MAPPING_COMPLETE, Ordering::Release);
    }

    /// Loads the packed word (with `Acquire`) and decodes it.
    pub fn get(&self) -> MapObjectStateQuery {
        let bits = self.0.load(Ordering::Acquire);
        MapObjectStateQuery {
            complete: (bits & Self::MAPPING_COMPLETE) == Self::MAPPING_COMPLETE,
            node: (bits as u16).into(),
            queued: (bits & Self::MAPPING_QUEUED) == Self::MAPPING_QUEUED,
        }
    }
}
/// Messages sent to the context over the object-update channel and applied by
/// `update_objects`.
enum ObjectUpdate {
/// A data handle (`DataInner`) with this object id was dropped. The object and
/// its derived formats are removed from the object table.
DropData {
id: u32,
},
}
/// A view recorded in a `DynVec`.
/// NOTE(review): `next_instance` suggests a singly linked list of entries;
/// confirm against the code that builds it.
struct ViewEntry {
/// Next entry in the list, if any.
pub next_instance: Option<DynEntry<ViewEntry>>,
/// The type-erased view (as returned by `ViewUsageInner::add_to_list`).
pub view: DynEntry<dyn ViewHolder>,
}
/// Concrete `DerivedFormatUpdater` that wraps a user `DerivedDescriptor` `D`
/// deriving `D::Format` from parent format `F`.
struct TypedDerivedFormatUpdater<F: Format, D: DerivedDescriptor<F>> {
/// The user-supplied derivation logic.
pub descriptor: D,
/// Mentions `F` and `D` in the type. The `fn() -> _` form owns no values and
/// does not affect the struct's auto traits.
pub marker: PhantomData<fn() -> (F, D)>,
}
impl<F: Format, D: DerivedDescriptor<F>> DerivedFormatUpdater for TypedDerivedFormatUpdater<F, D> {
/// Allocates a fresh `D::Format` from a type-erased format descriptor.
///
/// # Safety
/// `descriptor` must point to a valid `<F::Kind as Kind>::FormatDescriptor`
/// (the same kind as `D::Format`).
unsafe fn allocate(&self, descriptor: *const ()) -> Box<UnsafeCell<dyn Any + Send + Sync>> {
Box::new(UnsafeCell::new(<D::Format as Format>::allocate(
&*(descriptor as *const _),
)))
}
/// Returns the `TypeId` of the derived format this updater produces.
fn format_type_id(&self) -> TypeId {
TypeId::of::<D::Format>()
}
/// Runs `D::update`, writing the derived format from its parent.
///
/// # Safety
/// `context.inner.views` must hold at least two entries. Index 0 must be the
/// parent `F` and index 1 the derived `D::Format`; they are read unchecked and
/// cast. Each pointer in `usages` must point to a valid
/// `<F::Kind as Kind>::UsageDescriptor`; the slice is reinterpreted as a slice
/// of references.
unsafe fn update(&self, context: CommandContext, usages: *const [*const ()]) {
self.descriptor.update(
&mut *context.inner.views.get_unchecked(1).value.borrow().cast(),
&*context
.inner
.views
.get_unchecked(0)
.value
.borrow()
.cast_const()
.cast(),
&*transmute::<_, *const [_]>(usages),
);
}
}
/// A view of format `F` paired with its usage descriptor. With `M = Const` the
/// descriptor is `()`; with `M = Mut` it is the kind's `UsageDescriptor`.
struct TypedViewHolder<M: UsageMutability, F: Format> {
/// The viewed data.
view: View<F>,
/// Per-usage descriptor chosen by the mutability marker `M`.
descriptor: M::Descriptor<F>,
}
impl<M: UsageMutability, F: Format> ViewHolder for TypedViewHolder<M, F> {
    /// Context id stored in the view's shared data.
    fn context_id(&self) -> u64 {
        let shared = &self.view.inner.inner;
        shared.context_id
    }
    /// Object id targeted by the view.
    fn id(&self) -> u32 {
        self.view.id
    }
    /// `true` exactly when the mutability marker `M` is `Mut`.
    fn mutable(&self) -> bool {
        TypeId::of::<Mut>() == TypeId::of::<M>()
    }
    /// Type-erased pointer to this holder's usage descriptor.
    fn usage(&self) -> *const () {
        std::ptr::addr_of!(self.descriptor).cast()
    }
}
/// Traits that must be nameable in public bounds but not implementable
/// outside this crate (sealed-trait pattern: the module itself is private).
mod private {
use super::*;
/// Sealed supertrait of `ViewUsage`: stores a view into a command list.
pub trait ViewUsageInner {
/// Appends this view to `command_list` and returns its type-erased entry.
fn add_to_list(&self, command_list: &mut DynVec) -> DynEntry<dyn ViewHolder>;
}
/// Maps a mutability marker to the descriptor carried by a view usage.
pub trait UsageMutability: Mutability {
/// Descriptor type attached to a usage of format `F`.
type Descriptor<F: Format>: Send + Sync;
}
// Immutable usages carry no descriptor.
impl UsageMutability for Const {
type Descriptor<F: Format> = ();
}
// Mutable usages carry the kind's `UsageDescriptor`.
impl UsageMutability for Mut {
type Descriptor<F: Format> = <F::Kind as Kind>::UsageDescriptor;
}
}
#[cfg(test)]
mod tests {
    use super::*;
    /// Descriptor used to allocate the test formats.
    #[derive(Debug)]
    pub struct DataDescriptor {
        /// Starting value for a freshly allocated format.
        pub initial_value: u32,
    }
    /// Test data kind: formats are allocated from a [`DataDescriptor`] and
    /// mutable usages carry a `u32` descriptor.
    pub struct MyData;
    impl Kind for MyData {
        type FormatDescriptor = DataDescriptor;
        type UsageDescriptor = u32;
    }
    /// Base test format holding a single integer.
    #[derive(Debug)]
    pub struct Primary(pub i32);
    impl Format for Primary {
        type Kind = MyData;
        fn allocate(descriptor: &<Self::Kind as Kind>::FormatDescriptor) -> Self {
            Self(descriptor.initial_value as i32)
        }
    }
    /// Derived test format; starts at twice the initial value.
    #[derive(Debug)]
    pub struct DerivedAccelerationStructure(pub i32);
    impl Format for DerivedAccelerationStructure {
        type Kind = MyData;
        fn allocate(descriptor: &<Self::Kind as Kind>::FormatDescriptor) -> Self {
            Self(2 * descriptor.initial_value as i32)
        }
    }
    /// Keeps a [`DerivedAccelerationStructure`] equal to twice its [`Primary`] parent.
    pub struct UpdateAccelerationFromPrimary;
    impl DerivedDescriptor<Primary> for UpdateAccelerationFromPrimary {
        type Format = DerivedAccelerationStructure;
        fn update(&self, data: &mut Self::Format, parent: &Primary, _usage: &[&u32]) {
            data.0 = 2 * parent.0;
        }
    }
    /// A single command may not use the same data both immutably and mutably.
    #[test]
    #[should_panic]
    fn test_panic_on_conflicting_usage() {
        let ctx = DataFrostContext::new(ContextDescriptor {
            label: Some("my context"),
        });
        let data = ctx.allocate::<Primary>(AllocationDescriptor {
            descriptor: DataDescriptor { initial_value: 23 },
            label: Some("my int"),
            derived_formats: &[],
        });
        let mut command_buffer = CommandBuffer::new(CommandBufferDescriptor {
            label: Some("my command buffer"),
        });
        let view = data.view::<Primary>();
        command_buffer.schedule(CommandDescriptor {
            label: Some("Test command"),
            views: &[&view.as_const(), &view.as_mut(25)],
            command: |_| {},
        });
        ctx.submit(command_buffer);
    }
    /// Multiple immutable usages of the same data in one command are allowed.
    #[test]
    fn test_allow_multiple_const_usage() {
        let ctx = DataFrostContext::new(ContextDescriptor {
            label: Some("my context"),
        });
        let data = ctx.allocate::<Primary>(AllocationDescriptor {
            descriptor: DataDescriptor { initial_value: 23 },
            label: Some("my int"),
            derived_formats: &[],
        });
        let mut command_buffer = CommandBuffer::new(CommandBufferDescriptor {
            label: Some("my command buffer"),
        });
        let view_a = data.view::<Primary>();
        let view_b = data.view::<Primary>();
        command_buffer.schedule(CommandDescriptor {
            label: Some("Test command"),
            command: |_| {},
            views: &[&view_a.as_const(), &view_b.as_const()],
        });
        ctx.submit(command_buffer);
    }
    /// Each mapping of a derived format observes exactly the writes scheduled
    /// before it: (23 + 4) * 2 = 54, then (23 + 4 + 2) * 2 = 58.
    #[test]
    fn test_single_mappings() {
        let ctx = DataFrostContext::new(ContextDescriptor {
            label: Some("my context"),
        });
        let data = ctx.allocate::<Primary>(AllocationDescriptor {
            descriptor: DataDescriptor { initial_value: 23 },
            label: Some("my int"),
            derived_formats: &[Derived::new(UpdateAccelerationFromPrimary)],
        });
        let mut command_buffer = CommandBuffer::new(CommandBufferDescriptor {
            label: Some("my command buffer"),
        });
        let view = data.view::<Primary>();
        let view_clone = view.clone();
        command_buffer.schedule(CommandDescriptor {
            label: Some("Test command"),
            command: move |ctx| {
                let mut vc = ctx.get_mut(&view_clone);
                vc.0 += 4;
            },
            views: &[&view.as_mut(4)],
        });
        let mapping1 = command_buffer.map(&data.view::<DerivedAccelerationStructure>().as_const());
        let view_clone = view.clone();
        command_buffer.schedule(CommandDescriptor {
            label: Some("Test command"),
            command: move |ctx| {
                let mut vc = ctx.get_mut(&view_clone);
                vc.0 += 2;
            },
            views: &[&view.as_mut(2)],
        });
        let mapping2 = command_buffer.map(&data.view::<DerivedAccelerationStructure>().as_const());
        ctx.submit(command_buffer);
        let value = ctx.get(&mapping1);
        assert_eq!(value.0, 54);
        drop(value);
        drop(mapping1);
        let value = ctx.get(&mapping2);
        assert_eq!(value.0, 58);
        drop(value);
        drop(mapping2);
    }
    /// A command whose result no mapping depends on is never executed, and the
    /// mapped derived value still reflects the relevant write:
    /// (23 + 4) * 2 = 54.
    #[test]
    fn test_skip_irrelevant_command() {
        let execution_count = Arc::new(AtomicU32::new(0));
        let ctx = DataFrostContext::new(ContextDescriptor {
            label: Some("my context"),
        });
        let data = ctx.allocate::<Primary>(AllocationDescriptor {
            descriptor: DataDescriptor { initial_value: 23 },
            label: Some("my int"),
            derived_formats: &[Derived::new(UpdateAccelerationFromPrimary)],
        });
        let mut command_buffer = CommandBuffer::new(CommandBufferDescriptor {
            label: Some("my command buffer"),
        });
        let view = data.view::<Primary>();
        let ex_clone = execution_count.clone();
        let view_clone = view.clone();
        command_buffer.schedule(CommandDescriptor {
            label: Some("Test command"),
            command: move |ctx| {
                let mut vc = ctx.get_mut(&view_clone);
                vc.0 += 4;
                ex_clone.fetch_add(1, Ordering::Relaxed);
            },
            views: &[&view.as_mut(4)],
        });
        let ex_clone = execution_count.clone();
        command_buffer.schedule(CommandDescriptor {
            label: Some("Test command"),
            command: move |_| {
                ex_clone.fetch_add(1, Ordering::Relaxed);
            },
            views: &[&view.as_const()],
        });
        let mapping = command_buffer.map(&data.view::<DerivedAccelerationStructure>().as_const());
        ctx.submit(command_buffer);
        // Waiting on the mapping also checks that skipping the irrelevant
        // command did not disturb the derived value.
        let value = ctx.get(&mapping);
        assert_eq!(value.0, 54);
        drop(value);
        assert_eq!(execution_count.load(Ordering::Relaxed), 1);
    }
}