use std::{
any::Any,
sync::{Arc, Mutex},
};
use bevy_tasks::{ComputeTaskPool, Scope, TaskPool, ThreadExecutor};
use bevy_utils::default;
use bevy_utils::syncunsafecell::SyncUnsafeCell;
#[cfg(feature = "trace")]
use bevy_utils::tracing::{info_span, Instrument};
use std::panic::AssertUnwindSafe;
use async_channel::{Receiver, Sender};
use fixedbitset::FixedBitSet;
use crate::{
archetype::ArchetypeComponentId,
prelude::Resource,
query::Access,
schedule::{is_apply_deferred, BoxedCondition, ExecutorKind, SystemExecutor, SystemSchedule},
system::BoxedSystem,
world::{unsafe_world_cell::UnsafeWorldCell, World},
};
use crate as bevy_ecs;
/// Borrowed view of a [`SystemSchedule`] used for a single executor run.
///
/// The systems are exposed as a slice of [`SyncUnsafeCell`]s so that individual systems can be
/// handed to concurrently running tasks; the executor must never hand out two mutable borrows of
/// the same cell at once.
struct SyncUnsafeSchedule<'a> {
    /// The schedule's systems, indexed by system index.
    systems: &'a [SyncUnsafeCell<BoxedSystem>],
    /// The schedule's run conditions and set-membership tables.
    conditions: Conditions<'a>,
}
/// Run conditions of a schedule, plus the tables linking systems to the sets whose conditions
/// guard them.
struct Conditions<'a> {
    /// Run conditions attached directly to each system, indexed by system index.
    system_conditions: &'a mut [Vec<BoxedCondition>],
    /// Run conditions attached to each system set, indexed by set index.
    set_conditions: &'a mut [Vec<BoxedCondition>],
    /// For each system index, a bitset of the set indices whose conditions guard that system.
    sets_with_conditions_of_systems: &'a [FixedBitSet],
    /// For each set index, a bitset of the system indices that belong to that set.
    systems_in_sets_with_conditions: &'a [FixedBitSet],
}
impl SyncUnsafeSchedule<'_> {
    /// Splits `schedule` into disjoint borrows for one run.
    ///
    /// The run conditions stay behind ordinary `&mut` / `&` references owned by the executor,
    /// while the systems are re-exposed as a slice of [`SyncUnsafeCell`]s so each spawned task
    /// can access its own system.
    fn new(schedule: &mut SystemSchedule) -> SyncUnsafeSchedule<'_> {
        // Disjoint field borrows: the condition tables and the systems never alias.
        let conditions = Conditions {
            system_conditions: &mut schedule.system_conditions,
            set_conditions: &mut schedule.set_conditions,
            sets_with_conditions_of_systems: &schedule.sets_with_conditions_of_systems,
            systems_in_sets_with_conditions: &schedule.systems_in_sets_with_conditions,
        };
        let systems = SyncUnsafeCell::from_mut(&mut schedule.systems[..]).as_slice_of_cells();

        SyncUnsafeSchedule {
            systems,
            conditions,
        }
    }
}
/// Per-system data the executor keeps while scheduling; filled in `init` and refreshed in
/// `can_run`.
struct SystemTaskMetadata {
    /// The system's archetype-component access, cached by `can_run` before the system is spawned.
    archetype_component_access: Access<ArchetypeComponentId>,
    /// Indices of the systems that depend on this one.
    dependents: Vec<usize>,
    /// Whether the system may run on any thread; `!Send` systems are spawned with
    /// `spawn_on_external`.
    is_send: bool,
    /// Whether the system needs `&mut World` and therefore must run with no other system.
    is_exclusive: bool,
}
/// Message a system task sends back to the executor when it finishes.
struct SystemResult {
    /// Index of the system that finished.
    system_index: usize,
    /// `false` if the system (or the deferred-buffer application it performed) panicked.
    success: bool,
}
/// Runs the systems of a [`SystemSchedule`] concurrently on the [`ComputeTaskPool`], respecting
/// system dependencies, run conditions and archetype-component access conflicts.
pub struct MultiThreadedExecutor {
    /// Cloned into every spawned task so it can report a [`SystemResult`].
    sender: Sender<SystemResult>,
    /// Receives the [`SystemResult`]s of finished tasks.
    receiver: Receiver<SystemResult>,
    /// Static metadata and cached access of each system, indexed by system index.
    system_task_metadata: Vec<SystemTaskMetadata>,
    /// Union of the archetype-component access of the systems currently running.
    active_access: Access<ArchetypeComponentId>,
    /// `true` while a task spawned for a `!Send` or exclusive system occupies the local thread.
    local_thread_running: bool,
    /// `true` while an exclusive system is running; nothing else is spawned meanwhile.
    exclusive_running: bool,
    /// Number of systems the run loop waits for; lowered by `stop_spawning_systems` after a
    /// failure.
    num_systems: usize,
    /// Number of spawned systems whose result has not been processed yet.
    num_running_systems: usize,
    /// Number of systems that finished or were skipped in the current run.
    num_completed_systems: usize,
    /// For each system, how many of its dependencies have not completed yet.
    num_dependencies_remaining: Vec<usize>,
    /// System sets whose run conditions have already been evaluated in the current run.
    evaluated_sets: FixedBitSet,
    /// Systems whose dependencies have all completed but which were not spawned or skipped yet.
    ready_systems: FixedBitSet,
    /// Scratch buffer reused by `spawn_system_tasks` to iterate over a snapshot of
    /// `ready_systems`.
    ready_systems_copy: FixedBitSet,
    /// Systems that were spawned and whose result has not been processed yet.
    running_systems: FixedBitSet,
    /// Systems whose own or set run conditions evaluated to `false` in the current run.
    skipped_systems: FixedBitSet,
    /// Systems that finished or were skipped in the current run.
    completed_systems: FixedBitSet,
    /// Systems that finished running but whose deferred buffers have not been applied yet.
    unapplied_systems: FixedBitSet,
    /// Whether `run` applies the remaining deferred buffers once every system has finished.
    apply_final_deferred: bool,
    /// Payload of the most recently caught panic (from a system or from applying deferred
    /// buffers), re-raised by `run`.
    panic_payload: Arc<Mutex<Option<Box<dyn Any + Send>>>>,
    /// Set once `stop_spawning_systems` has lowered `num_systems`, so that happens only once.
    stop_spawning: bool,
}
impl Default for MultiThreadedExecutor {
fn default() -> Self {
Self::new()
}
}
impl SystemExecutor for MultiThreadedExecutor {
    fn kind(&self) -> ExecutorKind {
        ExecutorKind::MultiThreaded
    }

    fn set_apply_final_deferred(&mut self, value: bool) {
        self.apply_final_deferred = value;
    }

    /// Sizes the per-system and per-set bookkeeping for `schedule` and caches each system's
    /// dependents, `Send`-ness and exclusivity.
    fn init(&mut self, schedule: &SystemSchedule) {
        let sys_count = schedule.system_ids.len();
        let set_count = schedule.set_ids.len();

        // Each task sends exactly one result with `try_send`, so one slot per system is enough.
        // `max(1)` avoids requesting a zero-capacity bounded channel.
        let (tx, rx) = async_channel::bounded(sys_count.max(1));

        self.sender = tx;
        self.receiver = rx;
        self.evaluated_sets = FixedBitSet::with_capacity(set_count);
        self.ready_systems = FixedBitSet::with_capacity(sys_count);
        self.ready_systems_copy = FixedBitSet::with_capacity(sys_count);
        self.running_systems = FixedBitSet::with_capacity(sys_count);
        self.completed_systems = FixedBitSet::with_capacity(sys_count);
        self.skipped_systems = FixedBitSet::with_capacity(sys_count);
        self.unapplied_systems = FixedBitSet::with_capacity(sys_count);

        self.system_task_metadata = Vec::with_capacity(sys_count);
        for index in 0..sys_count {
            self.system_task_metadata.push(SystemTaskMetadata {
                archetype_component_access: default(),
                dependents: schedule.system_dependents[index].clone(),
                is_send: schedule.systems[index].is_send(),
                is_exclusive: schedule.systems[index].is_exclusive(),
            });
        }

        self.num_dependencies_remaining = Vec::with_capacity(sys_count);
    }

    /// Runs every system of `schedule` against `world`, in parallel wherever dependencies,
    /// run conditions and data access allow.
    ///
    /// # Panics
    ///
    /// Re-raises (via [`std::panic::resume_unwind`]) a panic caught in a system or while
    /// applying deferred buffers. The per-run state is reset before unwinding, and the payload
    /// mutex is not poisoned, so the executor can run again if the caller catches the panic.
    fn run(&mut self, schedule: &mut SystemSchedule, world: &mut World) {
        // reset counts
        self.num_systems = schedule.systems.len();
        if self.num_systems == 0 {
            return;
        }
        self.num_running_systems = 0;
        self.num_completed_systems = 0;
        self.num_dependencies_remaining.clear();
        self.num_dependencies_remaining
            .extend_from_slice(&schedule.system_dependencies);

        // Systems without dependencies are ready right away.
        for (system_index, dependencies) in self.num_dependencies_remaining.iter_mut().enumerate() {
            if *dependencies == 0 {
                self.ready_systems.insert(system_index);
            }
        }

        let thread_executor = world
            .get_resource::<MainThreadExecutor>()
            .map(|e| e.0.clone());
        let thread_executor = thread_executor.as_deref();

        let SyncUnsafeSchedule {
            systems,
            mut conditions,
        } = SyncUnsafeSchedule::new(schedule);

        ComputeTaskPool::init(TaskPool::default).scope_with_executor(
            false,
            thread_executor,
            |scope| {
                let executor = async {
                    let world_cell = world.as_unsafe_world_cell();
                    while self.num_completed_systems < self.num_systems {
                        // SAFETY:
                        // - `ready_systems` never contains running systems.
                        // - `world_cell` was created from `&mut World`, so it may access the
                        //   whole world except what running systems already claim.
                        unsafe {
                            self.spawn_system_tasks(scope, systems, &mut conditions, world_cell);
                        }

                        if self.num_running_systems > 0 {
                            // Wait for at least one system, then drain whatever else finished.
                            if let Ok(result) = self.receiver.recv().await {
                                self.finish_system_and_handle_dependents(result);
                            } else {
                                panic!("Channel closed unexpectedly!");
                            }

                            while let Ok(result) = self.receiver.try_recv() {
                                self.finish_system_and_handle_dependents(result);
                            }

                            self.rebuild_active_access();
                        }
                    }
                };

                #[cfg(feature = "trace")]
                let executor_span = info_span!("multithreaded executor");
                #[cfg(feature = "trace")]
                let executor = executor.instrument(executor_span);
                scope.spawn(executor);
            },
        );

        if self.apply_final_deferred {
            let res = apply_deferred(&self.unapplied_systems, systems, world);
            if let Err(payload) = res {
                let mut panic_payload = self.panic_payload.lock().unwrap();
                *panic_payload = Some(payload);
            }
            self.unapplied_systems.clear();
            debug_assert!(self.unapplied_systems.is_clear());
        }

        // Take the payload in its own statement so the guard is dropped before we unwind.
        // Unwinding while holding the guard would poison `panic_payload`, and every later
        // `lock().unwrap()` on this executor would then panic with a `PoisonError`.
        let panic_payload = self.panic_payload.lock().unwrap().take();

        if panic_payload.is_none() {
            debug_assert!(self.ready_systems.is_clear());
            debug_assert!(self.running_systems.is_clear());
        }

        // Reset per-run state on both the success and the panic path. Otherwise a caught panic
        // leaves stale bits behind. For example, stale `completed_systems` entries keep
        // dependents from ever becoming ready on the next run.
        self.ready_systems.clear();
        self.running_systems.clear();
        self.active_access.clear();
        self.evaluated_sets.clear();
        self.skipped_systems.clear();
        self.completed_systems.clear();
        self.exclusive_running = false;
        self.local_thread_running = false;
        self.stop_spawning = false;
        // Every task spawned on the scope has finished by now, so anything still queued is a
        // result the loop above never consumed. This can happen after a panic cut the loop short.
        while self.receiver.try_recv().is_ok() {}

        if let Some(payload) = panic_payload {
            std::panic::resume_unwind(payload);
        }
    }
}
impl MultiThreadedExecutor {
    /// Creates an executor with empty bookkeeping.
    ///
    /// All per-system buffers start empty and are sized by `SystemExecutor::init`. The
    /// unbounded channel created here is replaced by the bounded one created in `init`.
    pub fn new() -> Self {
        let (sender, receiver) = async_channel::unbounded();
        Self {
            sender,
            receiver,
            system_task_metadata: Vec::new(),
            num_systems: 0,
            num_running_systems: 0,
            num_completed_systems: 0,
            num_dependencies_remaining: Vec::new(),
            active_access: default(),
            local_thread_running: false,
            exclusive_running: false,
            evaluated_sets: FixedBitSet::new(),
            ready_systems: FixedBitSet::new(),
            ready_systems_copy: FixedBitSet::new(),
            running_systems: FixedBitSet::new(),
            skipped_systems: FixedBitSet::new(),
            completed_systems: FixedBitSet::new(),
            unapplied_systems: FixedBitSet::new(),
            apply_final_deferred: true,
            panic_payload: Arc::new(Mutex::new(None)),
            stop_spawning: false,
        }
    }

    /// Spawns every ready system that can currently run.
    ///
    /// - Walks a snapshot of `ready_systems`.
    /// - Systems whose access (or whose run conditions' access) conflicts with running systems
    ///   stay ready for a later call.
    /// - Systems whose run conditions fail are marked completed without running.
    /// - At most one exclusive system is spawned, and the walk stops right after it.
    /// - Does nothing while an exclusive system is running.
    ///
    /// # Safety
    ///
    /// - No system in `self.ready_systems` may be borrowed elsewhere (for example, by a running
    ///   task), because their cells are dereferenced mutably here.
    /// - `world_cell` must be allowed to access all world data not claimed by systems already
    ///   running on this executor.
    unsafe fn spawn_system_tasks<'scope>(
        &mut self,
        scope: &Scope<'_, 'scope, ()>,
        systems: &'scope [SyncUnsafeCell<BoxedSystem>],
        conditions: &mut Conditions,
        world_cell: UnsafeWorldCell<'scope>,
    ) {
        if self.exclusive_running {
            return;
        }

        // `self.ready_systems` is mutated inside the loop, so iterate over a copy. Reusing the
        // `ready_systems_copy` allocation avoids cloning the bitset on every call.
        let mut ready_systems = std::mem::take(&mut self.ready_systems_copy);
        ready_systems.clear();
        ready_systems.union_with(&self.ready_systems);

        for system_index in ready_systems.ones() {
            assert!(!self.running_systems.contains(system_index));
            // SAFETY: the caller guarantees ready systems are not borrowed elsewhere, and the
            // assert above rules out running systems.
            let system = unsafe { &mut *systems[system_index].get() };

            if !self.can_run(system_index, system, conditions, world_cell) {
                // Leave it in `ready_systems`; it is retried on the next call.
                continue;
            }

            self.ready_systems.set(system_index, false);

            // SAFETY: `can_run` returned true, so it refreshed the access of the relevant run
            // conditions and that access does not conflict with any running system.
            if !self.should_run(system_index, system, conditions, world_cell) {
                self.skip_system_and_signal_dependents(system_index);
                continue;
            }

            self.running_systems.insert(system_index);
            self.num_running_systems += 1;

            if self.system_task_metadata[system_index].is_exclusive {
                // SAFETY: `can_run` only accepts an exclusive system while no other system is
                // running, so nothing else holds world access.
                let world = unsafe { world_cell.world_mut() };
                // SAFETY: no other system is running, so no system cell is borrowed.
                unsafe {
                    self.spawn_exclusive_system_task(scope, system_index, systems, world);
                }
                // Nothing else may be spawned while the exclusive system owns the world.
                break;
            }

            // SAFETY:
            // - No other reference to this system exists.
            // - `can_run` called `update_archetype_component_access` on it.
            // - `can_run` returned true, so its access does not conflict with running systems.
            unsafe {
                self.spawn_system_task(scope, system_index, systems, world_cell);
            }
        }

        // Hand the scratch buffer back for reuse.
        self.ready_systems_copy = ready_systems;
    }

    /// Returns `true` if the system at `system_index` may be evaluated and spawned right now.
    ///
    /// Checks, in order:
    /// 1. An exclusive system requires that no other system is running.
    /// 2. A `!Send` system requires the local thread to be free.
    /// 3. The access of the run conditions of not-yet-evaluated sets guarding the system must
    ///    be compatible with `active_access`.
    /// 4. The access of the system's own run conditions must be compatible with
    ///    `active_access`.
    /// 5. Unless the system is already skipped, its own access must be compatible as well.
    ///
    /// Side effects: refreshes the archetype-component access of every inspected condition and
    /// of the system. When the system passes, its access is cached in its `SystemTaskMetadata`.
    fn can_run(
        &mut self,
        system_index: usize,
        system: &mut BoxedSystem,
        conditions: &mut Conditions,
        world: UnsafeWorldCell,
    ) -> bool {
        let system_meta = &self.system_task_metadata[system_index];
        if system_meta.is_exclusive && self.num_running_systems > 0 {
            return false;
        }

        if !system_meta.is_send && self.local_thread_running {
            return false;
        }

        // Only sets whose conditions have not been evaluated yet this run matter here.
        for set_idx in conditions.sets_with_conditions_of_systems[system_index]
            .difference(&self.evaluated_sets)
        {
            for condition in &mut conditions.set_conditions[set_idx] {
                condition.update_archetype_component_access(world);
                if !condition
                    .archetype_component_access()
                    .is_compatible(&self.active_access)
                {
                    return false;
                }
            }
        }

        for condition in &mut conditions.system_conditions[system_index] {
            condition.update_archetype_component_access(world);
            if !condition
                .archetype_component_access()
                .is_compatible(&self.active_access)
            {
                return false;
            }
        }

        // A skipped system never runs, so its own access is irrelevant.
        if !self.skipped_systems.contains(system_index) {
            system.update_archetype_component_access(world);
            if !system
                .archetype_component_access()
                .is_compatible(&self.active_access)
            {
                return false;
            }

            let meta_access =
                &mut self.system_task_metadata[system_index].archetype_component_access;
            meta_access.clear();
            meta_access.extend(system.archetype_component_access());
        }

        true
    }

    /// Evaluates the run conditions guarding the system at `system_index` and returns whether
    /// the system should run.
    ///
    /// Conditions of each guarding set are evaluated at most once per run; `evaluated_sets`
    /// tracks which sets are done. When a set's conditions fail, every system in that set is
    /// marked skipped. The system's own conditions are always evaluated, even when the system
    /// was already skipped, and a failure marks it skipped.
    ///
    /// # Safety
    ///
    /// The archetype-component access of the evaluated conditions must be up to date and must
    /// not conflict with any running system (as established by `can_run`).
    unsafe fn should_run(
        &mut self,
        system_index: usize,
        _system: &BoxedSystem,
        conditions: &mut Conditions,
        world: UnsafeWorldCell,
    ) -> bool {
        let mut should_run = !self.skipped_systems.contains(system_index);
        for set_idx in conditions.sets_with_conditions_of_systems[system_index].ones() {
            if self.evaluated_sets.contains(set_idx) {
                continue;
            }

            // SAFETY: guaranteed by this function's caller contract.
            let set_conditions_met =
                evaluate_and_fold_conditions(&mut conditions.set_conditions[set_idx], world);

            if !set_conditions_met {
                self.skipped_systems
                    .union_with(&conditions.systems_in_sets_with_conditions[set_idx]);
            }

            should_run &= set_conditions_met;
            self.evaluated_sets.insert(set_idx);
        }

        // SAFETY: guaranteed by this function's caller contract.
        let system_conditions_met =
            evaluate_and_fold_conditions(&mut conditions.system_conditions[system_index], world);

        if !system_conditions_met {
            self.skipped_systems.insert(system_index);
        }

        should_run &= system_conditions_met;

        should_run
    }

    /// Spawns a task that runs a non-exclusive system.
    ///
    /// The task catches a panic from the system, reports a [`SystemResult`] on the executor's
    /// channel, and stores the panic payload (if any). The system's cached access is added to
    /// `active_access`. `!Send` systems are spawned with `spawn_on_external` and mark the local
    /// thread as busy.
    ///
    /// # Safety
    ///
    /// - The system's cell must not be borrowed elsewhere.
    /// - `update_archetype_component_access` must have been called on the system, and its
    ///   access must not conflict with any running system.
    unsafe fn spawn_system_task<'scope>(
        &mut self,
        scope: &Scope<'_, 'scope, ()>,
        system_index: usize,
        systems: &'scope [SyncUnsafeCell<BoxedSystem>],
        world: UnsafeWorldCell<'scope>,
    ) {
        // SAFETY: the caller guarantees this cell is not borrowed elsewhere.
        let system = unsafe { &mut *systems[system_index].get() };

        #[cfg(feature = "trace")]
        let task_span = info_span!("system_task", name = &*system.name());
        #[cfg(feature = "trace")]
        let system_span = info_span!("system", name = &*system.name());

        let sender = self.sender.clone();
        let panic_payload = self.panic_payload.clone();
        let task = async move {
            #[cfg(feature = "trace")]
            let system_guard = system_span.enter();
            let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
                // SAFETY: access was updated and is non-conflicting (caller contract).
                unsafe { system.run_unsafe((), world) };
            }));
            #[cfg(feature = "trace")]
            drop(system_guard);
            // The channel has one slot per system and each system reports once per run.
            sender
                .try_send(SystemResult {
                    system_index,
                    success: res.is_ok(),
                })
                .unwrap_or_else(|error| unreachable!("{}", error));
            if let Err(payload) = res {
                eprintln!("Encountered a panic in system `{}`!", &*system.name());
                // Scope the guard so the lock is released right away.
                {
                    let mut panic_payload = panic_payload.lock().unwrap();
                    *panic_payload = Some(payload);
                }
            }
        };

        #[cfg(feature = "trace")]
        let task = task.instrument(task_span);

        let system_meta = &self.system_task_metadata[system_index];
        self.active_access
            .extend(&system_meta.archetype_component_access);

        if system_meta.is_send {
            scope.spawn(task);
        } else {
            self.local_thread_running = true;
            scope.spawn_on_external(task);
        }
    }

    /// Spawns a task (via `spawn_on_scope`) that runs an exclusive system with `&mut World`.
    ///
    /// If `is_apply_deferred(system)` is true, the task does not run the system. Instead it
    /// applies the deferred buffers of every system in `unapplied_systems`, which is taken
    /// (copied, then cleared) here. Either way the task sends a [`SystemResult`] and stores any
    /// panic payload. Marks both the exclusive slot and the local thread as busy.
    ///
    /// # Safety
    ///
    /// No system cell may be borrowed elsewhere, and nothing else may access `world` while the
    /// task runs.
    unsafe fn spawn_exclusive_system_task<'scope>(
        &mut self,
        scope: &Scope<'_, 'scope, ()>,
        system_index: usize,
        systems: &'scope [SyncUnsafeCell<BoxedSystem>],
        world: &'scope mut World,
    ) {
        // SAFETY: the caller guarantees no system cell is borrowed elsewhere.
        let system = unsafe { &mut *systems[system_index].get() };

        #[cfg(feature = "trace")]
        let task_span = info_span!("system_task", name = &*system.name());
        #[cfg(feature = "trace")]
        let system_span = info_span!("system", name = &*system.name());

        let sender = self.sender.clone();
        let panic_payload = self.panic_payload.clone();
        if is_apply_deferred(system) {
            // Snapshot the systems whose buffers must be applied by this task.
            let unapplied_systems = self.unapplied_systems.clone();
            self.unapplied_systems.clear();
            let task = async move {
                #[cfg(feature = "trace")]
                let system_guard = system_span.enter();
                let res = apply_deferred(&unapplied_systems, systems, world);
                #[cfg(feature = "trace")]
                drop(system_guard);
                sender
                    .try_send(SystemResult {
                        system_index,
                        success: res.is_ok(),
                    })
                    .unwrap_or_else(|error| unreachable!("{}", error));
                if let Err(payload) = res {
                    let mut panic_payload = panic_payload.lock().unwrap();
                    *panic_payload = Some(payload);
                }
            };

            #[cfg(feature = "trace")]
            let task = task.instrument(task_span);
            scope.spawn_on_scope(task);
        } else {
            let task = async move {
                #[cfg(feature = "trace")]
                let system_guard = system_span.enter();
                let res = std::panic::catch_unwind(AssertUnwindSafe(|| {
                    system.run((), world);
                }));
                #[cfg(feature = "trace")]
                drop(system_guard);
                sender
                    .try_send(SystemResult {
                        system_index,
                        success: res.is_ok(),
                    })
                    .unwrap_or_else(|error| unreachable!("{}", error));
                if let Err(payload) = res {
                    eprintln!(
                        "Encountered a panic in exclusive system `{}`!",
                        &*system.name()
                    );
                    let mut panic_payload = panic_payload.lock().unwrap();
                    *panic_payload = Some(payload);
                }
            };

            #[cfg(feature = "trace")]
            let task = task.instrument(task_span);
            scope.spawn_on_scope(task);
        }

        self.exclusive_running = true;
        self.local_thread_running = true;
    }

    /// Records that a spawned system finished.
    ///
    /// Releases the exclusive / local-thread slots it held, updates the counters and bitsets,
    /// queues the system for deferred-buffer application, and signals its dependents. A failed
    /// (panicked) system triggers `stop_spawning_systems`.
    fn finish_system_and_handle_dependents(&mut self, result: SystemResult) {
        let SystemResult {
            system_index,
            success,
        } = result;

        let system_meta = &self.system_task_metadata[system_index];
        let was_exclusive = system_meta.is_exclusive;
        // `spawn_exclusive_system_task` claims the local thread whatever `is_send` reports, so
        // an exclusive system must release it too. Otherwise a `Send` exclusive system would
        // block every `!Send` system for the rest of the run. Nothing else runs alongside an
        // exclusive system, so releasing here cannot free a slot another task still holds.
        let held_local_thread = system_meta.is_exclusive || !system_meta.is_send;

        if was_exclusive {
            self.exclusive_running = false;
        }

        if held_local_thread {
            self.local_thread_running = false;
        }

        debug_assert!(self.num_running_systems >= 1);
        self.num_running_systems -= 1;
        self.num_completed_systems += 1;
        self.running_systems.set(system_index, false);
        self.completed_systems.insert(system_index);
        self.unapplied_systems.insert(system_index);

        self.signal_dependents(system_index);

        if !success {
            self.stop_spawning_systems();
        }
    }

    /// Marks a system whose run conditions failed as completed without running it, and
    /// signals its dependents.
    fn skip_system_and_signal_dependents(&mut self, system_index: usize) {
        self.num_completed_systems += 1;
        self.completed_systems.insert(system_index);
        self.signal_dependents(system_index);
    }

    /// Decrements the remaining-dependency count of each dependent of `system_index`. A
    /// dependent whose count reaches zero becomes ready, unless it already completed (for
    /// example, was skipped).
    fn signal_dependents(&mut self, system_index: usize) {
        for &dep_idx in &self.system_task_metadata[system_index].dependents {
            let remaining = &mut self.num_dependencies_remaining[dep_idx];
            debug_assert!(*remaining >= 1);
            *remaining -= 1;
            if *remaining == 0 && !self.completed_systems.contains(dep_idx) {
                self.ready_systems.insert(dep_idx);
            }
        }
    }

    /// Lowers `num_systems` to the systems already completed or running, so the run loop ends
    /// once those have been processed. This happens only once per `stop_spawning` flag.
    ///
    /// NOTE(review): despite the name, `spawn_system_tasks` does not consult `stop_spawning`.
    /// Ready systems can still be spawned after this call; the run loop simply no longer waits
    /// for them.
    fn stop_spawning_systems(&mut self) {
        if !self.stop_spawning {
            self.num_systems = self.num_completed_systems + self.num_running_systems;
            self.stop_spawning = true;
        }
    }

    /// Recomputes `active_access` as the union of the cached access of all running systems.
    fn rebuild_active_access(&mut self) {
        self.active_access.clear();
        for index in self.running_systems.ones() {
            let system_meta = &self.system_task_metadata[index];
            self.active_access
                .extend(&system_meta.archetype_component_access);
        }
    }
}
/// Applies the deferred buffers of every system in `unapplied_systems` to `world`, in index
/// order.
///
/// Stops at the first system whose `apply_deferred` panics: the panic is caught, reported on
/// stderr, and its payload returned as `Err`. Systems after it are not applied.
fn apply_deferred(
    unapplied_systems: &FixedBitSet,
    systems: &[SyncUnsafeCell<BoxedSystem>],
    world: &mut World,
) -> Result<(), Box<dyn std::any::Any + Send>> {
    unapplied_systems.ones().try_for_each(|index| {
        // SAFETY: callers only pass systems that are not running, so nothing else borrows
        // this cell.
        let system = unsafe { &mut *systems[index].get() };
        std::panic::catch_unwind(AssertUnwindSafe(|| system.apply_deferred(world))).map_err(
            |payload| {
                eprintln!(
                    "Encountered a panic when applying buffers for system `{}`!",
                    &*system.name()
                );
                payload
            },
        )
    })
}
/// Runs every condition in `conditions` and returns `true` only if all of them returned
/// `true`.
///
/// Every condition is run even after one has returned `false`; this does not short-circuit.
///
/// # Safety
///
/// Each condition's archetype-component access must be up to date and must not conflict with
/// any system currently running on `world`.
unsafe fn evaluate_and_fold_conditions(
    conditions: &mut [BoxedCondition],
    world: UnsafeWorldCell,
) -> bool {
    let mut all_met = true;
    for condition in conditions.iter_mut() {
        #[cfg(feature = "trace")]
        let _condition_span = info_span!("condition", name = &*condition.name()).entered();
        // SAFETY: guaranteed by this function's caller contract.
        let met = unsafe { condition.run_unsafe((), world) };
        // `&=` on `bool` evaluates both sides, so no later condition is skipped.
        all_met &= met;
    }
    all_met
}
/// Resource wrapping a [`ThreadExecutor`] obtained from [`TaskPool::get_thread_executor`].
///
/// When this resource is present, `MultiThreadedExecutor::run` passes the wrapped executor to
/// `scope_with_executor`.
#[derive(Resource, Clone)]
pub struct MainThreadExecutor(pub Arc<ThreadExecutor<'static>>);

impl Default for MainThreadExecutor {
    /// Equivalent to [`MainThreadExecutor::new`].
    fn default() -> Self {
        MainThreadExecutor::new()
    }
}

impl MainThreadExecutor {
    /// Wraps the result of [`TaskPool::get_thread_executor`] (presumably the calling thread's
    /// executor — confirm in `bevy_tasks`).
    pub fn new() -> Self {
        Self(TaskPool::get_thread_executor())
    }
}