use crate::keyed_queues::KeyedQueues;
use bevy_ecs::change_detection::Tick;
use bevy_ecs::error::ErrorContext;
use bevy_ecs::prelude::NonSend;
use bevy_ecs::schedule::{InternedScheduleLabel, ScheduleLabel};
use bevy_ecs::system::SystemParamValidationError;
use bevy_ecs::world::FromWorld;
use bevy_ecs::world::unsafe_world_cell::UnsafeWorldCell;
use bevy_ecs::world::{Mut, WorldId};
use bevy_ecs::{
system::{SystemParam, SystemState},
world::World,
};
use bevy_platform::collections::HashMap;
use bevy_platform::sync::{Arc, Mutex, OnceLock, RwLock};
use concurrent_queue::ConcurrentQueue;
use core::any::TypeId;
use core::marker::PhantomData;
use core::pin::Pin;
use core::sync::atomic::{AtomicU64, Ordering};
use core::task::{Context, Poll, Waker};
use crossbeam::sync::WaitGroup;
/// Plugin that drives async ECS calls ([`EcsTask::run_system`]) from the standard
/// `bevy_app` schedules.
///
/// For every schedule listed in `build`, an exclusive system is added that calls
/// [`run_async_ecs_on_schedule`] with that schedule's interned label.
pub struct AsyncEcsPlugin;
impl bevy_app::Plugin for AsyncEcsPlugin {
    fn build(&self, app: &mut bevy_app::App) {
        use bevy_app::prelude::{
            First, FixedFirst, FixedLast, FixedPostUpdate, FixedPreUpdate, FixedUpdate, Last,
            PostStartup, PostUpdate, PreStartup, PreUpdate, Startup, Update,
        };
        // A fixed-size array: no heap allocation and no reliance on `std`'s `vec!` macro.
        let schedules = [
            PreStartup.intern(),
            Startup.intern(),
            PostStartup.intern(),
            PreUpdate.intern(),
            Update.intern(),
            PostUpdate.intern(),
            FixedPostUpdate.intern(),
            FixedPreUpdate.intern(),
            FixedUpdate.intern(),
            First.intern(),
            Last.intern(),
            FixedFirst.intern(),
            FixedLast.intern(),
        ];
        for schedule in schedules {
            // Each system captures its own label so it wakes only tasks waiting on it.
            app.add_systems(schedule, move |world: &mut World| {
                run_async_ecs_on_schedule(schedule, world);
            });
        }
    }
}
/// Drives the async ECS tasks that are waiting on `schedule` in `world`.
///
/// Meant to be called from an exclusive system running inside `schedule`
/// (this is what [`AsyncEcsPlugin`] sets up).
pub fn run_async_ecs_on_schedule(schedule: InternedScheduleLabel, world: &mut World) {
    // The result only reports whether any task was waiting; there is nothing to do with it.
    let _ = GLOBAL_WAKE_REGISTRY.wait(schedule, world);
}
mod keyed_queues {
    use bevy_platform::collections::HashMap;
    use bevy_platform::sync::{Arc, RwLock};
    use concurrent_queue::ConcurrentQueue;
    use core::hash::Hash;

    /// A map from keys to lazily created, shared, unbounded `ConcurrentQueue`s.
    ///
    /// Queues are handed out as `Arc`s, so callers can keep using a queue after the
    /// map's lock has been released.
    pub struct KeyedQueues<K, V> {
        inner: RwLock<HashMap<K, Arc<ConcurrentQueue<V>>>>,
    }

    impl<K, V> KeyedQueues<K, V>
    where
        K: Eq + Hash + Clone,
        V: Send + 'static,
    {
        /// Creates an empty collection with no queues.
        pub fn new() -> Self {
            let inner = RwLock::new(HashMap::new());
            Self { inner }
        }

        /// Returns the queue for `key`, creating an empty unbounded one if none exists.
        ///
        /// # Panics
        /// Panics if the internal lock is poisoned.
        #[inline]
        pub fn get_or_create(&self, key: &K) -> Arc<ConcurrentQueue<V>> {
            // Fast path: only a shared lock is needed when the queue already exists.
            let existing = self.inner.read().unwrap().get(key).map(Arc::clone);
            if let Some(queue) = existing {
                return queue;
            }
            // Slow path: another thread may have inserted the queue between releasing the
            // shared lock and acquiring the exclusive one, so look again before inserting.
            let mut map = self.inner.write().unwrap();
            if let Some(queue) = map.get(key) {
                return Arc::clone(queue);
            }
            let created = Arc::new(ConcurrentQueue::unbounded());
            map.insert(key.clone(), Arc::clone(&created));
            created
        }

        /// Pushes `val` onto the queue for `key`, creating that queue if necessary.
        ///
        /// # Errors
        /// Returns the value back if the queue rejects it (see `ConcurrentQueue::push`).
        #[inline]
        pub fn try_send(&self, key: &K, val: V) -> Result<(), concurrent_queue::PushError<V>> {
            self.get_or_create(key).push(val)
        }
    }
}
/// Process-wide table of per-world slots through which `WorldAccessRegistry::set`
/// temporarily lends a world to async tasks (read back via `WorldAccessRegistry::get`).
static GLOBAL_WORLD_ACCESS: WorldAccessRegistry = WorldAccessRegistry(OnceLock::new());
/// Process-wide wake queues, keyed by `(WorldId, schedule)`, of tasks waiting for that
/// schedule to run on that world.
pub(crate) static GLOBAL_WAKE_REGISTRY: WakeRegistry = WakeRegistry(OnceLock::new());
/// Per-task barrier handles for the current lending window.
///
/// `WakeRegistry::wait` inserts one `WaitGroup` clone per woken task and then waits on the
/// group, so it stays blocked until every clone is dropped. A task removes its clone in
/// `WorldAccessRegistry::get` and drops it once its world access is finished.
#[derive(bevy_ecs::prelude::Resource, Default, Clone)]
pub(crate) struct WakeParkBarrier(Arc<Mutex<HashMap<AsyncTaskId, WaitGroup>>>);
/// Cached `SystemState<T>` for each async task using params `T`.
///
/// Each task's state sits in its own single-slot queue (created with `bounded(1)` when the
/// task is first initialized). A call pops the state while using it and pushes it back
/// afterwards.
#[derive(bevy_ecs::prelude::Resource)]
pub(crate) struct SystemStatePool<T: SystemParam + 'static>(
RwLock<HashMap<AsyncTaskId, ConcurrentQueue<SystemState<T>>>>,
);
/// Type-erased "apply deferred" callbacks, one per `SystemStatePool<T>` param type,
/// keyed by the `TypeId` of `T`.
#[derive(bevy_ecs::prelude::Resource, Default)]
pub(crate) struct SystemParamAppliers(HashMap<TypeId, fn(&mut World)>);
impl SystemParamAppliers {
    /// Invokes every registered applier once against `world`.
    fn run(&mut self, world: &mut World) {
        self.0.values().for_each(|apply| apply(world));
    }
}
impl<T: SystemParam + 'static> FromWorld for SystemStatePool<T> {
    /// Creates an empty pool and, once per param type `T`, registers an applier in
    /// [`SystemParamAppliers`] that flushes the deferred buffers of every cached state.
    fn from_world(world: &mut World) -> Self {
        /// Applies the deferred buffers (`SystemState::apply`) of every cached
        /// `SystemState<P>` to `world`, returning each state to its queue afterwards.
        fn apply_pooled_states<P: SystemParam + 'static>(world: &mut World) {
            world.try_resource_scope(|world, pool: Mut<SystemStatePool<P>>| {
                for queue in pool.0.read().unwrap().values() {
                    // The queue is empty if a task lost its state, e.g. because its closure
                    // panicked after taking the state out. Skip it instead of panicking
                    // on the world's own thread.
                    let Ok(mut system_state) = queue.pop() else {
                        continue;
                    };
                    system_state.apply(world);
                    if queue.push(system_state).is_err() {
                        unreachable!("a single-slot queue that was just popped has room for one state");
                    }
                }
            });
        }

        world.init_resource::<SystemParamAppliers>();
        let mut appliers = world.resource_mut::<SystemParamAppliers>();
        let key = TypeId::of::<T>();
        if !appliers.0.contains_key(&key) {
            appliers.0.insert(key, apply_pooled_states::<T>);
        }
        Self(RwLock::new(HashMap::default()))
    }
}
/// Process-unique identifier of an async ECS task; shared by all clones of an `EcsTask`.
#[derive(Clone, Copy, Hash, PartialOrd, PartialEq, Eq, Debug)]
struct AsyncTaskId(u64);
/// The next id to hand out. Ids start at 0 and increase by one per allocation.
static MAX_TASK_ID: AtomicU64 = AtomicU64::new(0);
impl AsyncTaskId {
    /// Allocates a fresh id, or returns `None` once the `u64` id space is exhausted.
    pub fn new() -> Option<Self> {
        let previous = MAX_TASK_ID.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
            current.checked_add(1)
        });
        match previous {
            Ok(id) => Some(Self(id)),
            Err(_) => None,
        }
    }
}
/// Wake queues keyed by `(WorldId, schedule)`.
///
/// Each entry is `(waker, system-state initializer, task id)`. `WakeRegistry::wait` runs
/// the initializer with `&mut World` before waking the task, so the task's cached
/// `SystemState` exists by the time it polls.
pub(crate) struct WakeRegistry(
OnceLock<
KeyedQueues<
(WorldId, InternedScheduleLabel),
(Waker, fn(&mut World, AsyncTaskId), AsyncTaskId),
>,
>,
);
impl WakeRegistry {
    /// Runs one async-ECS step for `schedule` on `world`.
    ///
    /// If no task is waiting on `(world, schedule)`, returns `None` and does nothing else.
    /// Otherwise it:
    /// 1. runs the queued cleanups of dropped tasks for this world;
    /// 2. drains the wake queue. For each live task it initializes the cached `SystemState`
    ///    and inserts a barrier entry. Tasks whose cleanup just ran are skipped: their
    ///    handles are gone, so they can never poll again. They would otherwise recreate a
    ///    leaked pool entry and hold the barrier forever;
    /// 3. lends the world out, wakes the tasks, and blocks until every barrier entry is dropped;
    /// 4. applies the deferred buffers of all cached system states.
    ///
    /// Returns `Some(())` after a completed step.
    ///
    /// NOTE(review): two cases are still not detected here, and either can leave the
    /// barrier waiting forever:
    /// - a dropped task whose cleanup was drained by a *different* schedule's `wait`;
    /// - a future cancelled while another `EcsTask` clone keeps its id alive.
    pub fn wait(&self, schedule: InternedScheduleLabel, world: &mut World) -> Option<()> {
        let world_id = world.id();
        // Look the queue up once; it is an `Arc`, so it stays valid without the map lock.
        let wake_queue = self
            .0
            .get_or_init(KeyedQueues::new)
            .get_or_create(&(world_id, schedule));
        if wake_queue.is_empty() {
            return None;
        }

        // Free cached state of dropped tasks and remember which ids are dead.
        let mut cleaned_up = bevy_platform::prelude::vec![];
        for (cleanup_function, task_to_cleanup) in TASKS_TO_CLEANUP
            .get_or_init(KeyedQueues::new)
            .get_or_create(&world_id)
            .try_iter()
        {
            cleanup_function(world, task_to_cleanup);
            cleaned_up.push(task_to_cleanup);
        }

        let mut wakers = bevy_platform::prelude::vec![];
        let mut woken_tasks = bevy_platform::prelude::vec![];
        while let Ok((waker, system_init, task_id)) = wake_queue.pop() {
            if cleaned_up.contains(&task_id) {
                // Stale registration of a task whose handle was dropped.
                continue;
            }
            system_init(world, task_id);
            wakers.push(waker);
            woken_tasks.push(task_id);
        }

        let wait_group = WaitGroup::new();
        world.init_resource::<WakeParkBarrier>();
        {
            let barrier = world.resource::<WakeParkBarrier>();
            let mut parked = barrier.0.lock().unwrap();
            for task in woken_tasks {
                parked.insert(task, wait_group.clone());
            }
        }

        GLOBAL_WORLD_ACCESS.set(world, || {
            for waker in wakers {
                waker.wake();
            }
            // Blocks until every task given a barrier entry above has dropped it.
            wait_group.wait();
        })?;
        world.try_resource_scope(|world, mut appliers: Mut<SystemParamAppliers>| {
            appliers.run(world);
        });
        Some(())
    }
}
/// Per-world slot holding a lifetime-erased `UnsafeWorldCell` while that world is lent out.
///
/// A slot is `Some` only while `WorldAccessRegistry::set` is running for that world. The
/// `Mutex` stored next to the cell lets `WorldAccessRegistry::get` (via `try_lock`) admit
/// at most one task into the world at a time.
pub(crate) struct WorldAccessRegistry(
OnceLock<
RwLock<
HashMap<
WorldId,
RwLock<
Option<(
UnsafeWorldCell<'static>,
Mutex<PhantomData<UnsafeWorldCell<'static>>>,
)>,
>,
>,
>,
>,
);
impl WorldAccessRegistry {
    /// Lends `world` to async tasks for the duration of `func`.
    ///
    /// While `func` runs, the world's slot holds a lifetime-erased `UnsafeWorldCell` that
    /// [`WorldAccessRegistry::get`] hands to woken tasks. The slot is emptied again before
    /// this function returns *or unwinds*, so the erased `'static` lifetime never outlives
    /// the `&mut World` borrow. Always returns `Some(())`.
    fn set(&self, world: &mut World, func: impl FnOnce()) -> Option<()> {
        /// Empties a world slot when dropped, including during unwinding.
        struct ClearOnDropGuard<'a> {
            slot: &'a RwLock<
                Option<(
                    UnsafeWorldCell<'static>,
                    Mutex<PhantomData<UnsafeWorldCell<'static>>>,
                )>,
            >,
        }
        impl Drop for ClearOnDropGuard<'_> {
            fn drop(&mut self) {
                // Clear even through poisoning. Leaving the erased cell behind would let a
                // later `get` hand out a pointer to a world that is no longer borrowed.
                let mut slot = self
                    .slot
                    .write()
                    .unwrap_or_else(|poisoned| poisoned.into_inner());
                let _ = slot.take();
            }
        }

        let registry = self.0.get_or_init(|| RwLock::new(HashMap::new()));
        let world_id = world.id();
        if !registry.read().unwrap().contains_key(&world_id) {
            let mut worlds = registry.write().unwrap();
            // Re-check under the exclusive lock; never replace an existing slot.
            if !worlds.contains_key(&world_id) {
                worlds.insert(world_id, RwLock::new(None));
            }
        }

        let worlds = registry.read().unwrap();
        let slot = worlds
            .get(&world_id)
            .expect("the world slot is inserted above and never removed");
        // Armed before the slot is filled, so the slot is cleared on every exit path.
        let _clear = ClearOnDropGuard { slot };
        // SAFETY: the `'static` lifetime is nominal only. `_clear` empties the slot before
        // `world`'s borrow ends (it drops at the end of this function or during unwinding).
        // Clearing needs the slot's write lock, which waits for every `get` that is still
        // holding the read lock while using the cell.
        let cell = unsafe {
            core::mem::transmute::<UnsafeWorldCell<'_>, UnsafeWorldCell<'static>>(
                world.as_unsafe_world_cell(),
            )
        };
        slot.write()
            .unwrap()
            .replace((cell, Mutex::new(PhantomData)));
        func();
        Some(())
    }

    /// Runs `func` on the world lent out for `world_id`, provided task `task_id` was given
    /// a barrier entry for the current lending window.
    ///
    /// Returns `None` without calling `func` in three cases:
    /// - the world is not lent out;
    /// - the task has no barrier entry;
    /// - another task currently holds the access mutex. The entry is consumed before the
    ///   mutex is tried, so a task that loses this race does not block the lending thread
    ///   and must register again.
    fn get<T>(
        &self,
        world_id: WorldId,
        task_id: AsyncTaskId,
        func: impl FnOnce(UnsafeWorldCell) -> Poll<T>,
    ) -> Option<Poll<T>> {
        let worlds = self.0.get()?.read().unwrap();
        let slot = worlds.get(&world_id)?.read().unwrap();
        let (cell, access) = slot.as_ref()?;
        // SAFETY: the cell is present only while `set` is lending the world out. The
        // barrier resource is initialized before lending and only mutated through its `Mutex`.
        let barrier = unsafe { cell.get_resource::<WakeParkBarrier>() }.unwrap();
        // Kept alive until `func` has returned, so the lending thread stays parked meanwhile.
        let _async_barrier = barrier.0.lock().unwrap().remove(&task_id)?;
        // Only one task may use the world at a time.
        let _world = access.try_lock().ok()?;
        Some(func(*cell))
    }
}
impl<P: bevy_ecs::system::SystemParam + 'static> EcsTask<P> {
    /// Runs `ecs_access` with this task's system params `P` on the task's world and
    /// resolves to the closure's result.
    ///
    /// The closure runs during a later run of `schedule` on that world, after this future
    /// has been polled. Consumes the handle; clone it first to issue more calls on the
    /// same task.
    pub async fn run_system<Func, Out>(self, schedule: impl ScheduleLabel, ecs_access: Func) -> Out
    where
        for<'w, 's> Func: FnOnce(P::Item<'w, 's>) -> Out,
    {
        let call = async_access(self, schedule, ecs_access);
        call.await
    }
}
/// Creates [`EcsTask`] handles from a world identifier.
pub trait CreateEcsTask {
    /// Creates a new task handle, with a fresh task id, targeting the world `self` identifies.
    fn ecs_task<P: SystemParam + 'static>(self) -> EcsTask<P>;
}
impl CreateEcsTask for WorldId {
    fn ecs_task<P: SystemParam + 'static>(self) -> EcsTask<P> {
        EcsTask::from(self)
    }
}
async fn async_access<P, Func, Out>(
task_identifier: impl Into<EcsTask<P>>,
schedule: impl ScheduleLabel,
ecs_access: Func,
) -> Out
where
P: SystemParam + 'static,
for<'w, 's> Func: FnOnce(P::Item<'w, 's>) -> Out,
{
let task_identifier = task_identifier.into();
PendingEcsCall::<P, Func, Out>(
PhantomData::<P>,
PhantomData,
Some(ecs_access),
(task_identifier.0.1, schedule.intern()),
task_identifier.0.0,
)
.await
}
/// Per-world queue of `(cleanup fn, task id)` pairs for dropped tasks.
///
/// Entries are pushed from `Drop`, which has no world access, and drained with
/// `&mut World` in `WakeRegistry::wait`.
static TASKS_TO_CLEANUP: OnceLock<
KeyedQueues<WorldId, (fn(&mut World, AsyncTaskId), AsyncTaskId)>,
> = OnceLock::new();
fn cleanup_ecs_task<P: SystemParam + 'static>(task: &InternalEcsTask<P>) {
fn cleanup_task<P: SystemParam + 'static>(world: &mut World, task_id: AsyncTaskId) {
world.try_resource_scope(|_world, param_pool: Mut<SystemStatePool<P>>| {
let mut pool = param_pool.0.write().unwrap();
pool.remove(&task_id);
if pool.len() * 2 < pool.capacity() {
pool.shrink_to_fit();
}
});
}
match TASKS_TO_CLEANUP
.get_or_init(KeyedQueues::new)
.try_send(&task.1, (cleanup_task::<P>, task.0))
{
Ok(_) => {}
Err(_) => unreachable!(),
}
}
/// Equivalent to [`EcsTask::new`]: every conversion allocates a fresh task id.
impl<P: SystemParam + 'static> From<WorldId> for EcsTask<P> {
    fn from(value: WorldId) -> Self {
        // Delegate so the construction logic lives in exactly one place.
        Self::new(value)
    }
}
/// Handle for issuing ECS calls from async code against one world, using system params `P`.
///
/// Clones share one task id and therefore one cached `SystemState<P>`. The cached state is
/// queued for cleanup when the last clone is dropped.
pub struct EcsTask<P: SystemParam + 'static>(Arc<InternalEcsTask<P>>);
/// Shared payload of an [`EcsTask`]: `(task id, target world, param-type marker)`.
struct InternalEcsTask<P: SystemParam + 'static>(AsyncTaskId, WorldId, PhantomData<P>);
impl<T: SystemParam + 'static> Drop for InternalEcsTask<T> {
    /// Runs when the last `EcsTask` clone is dropped; queues removal of the task's
    /// cached system state.
    fn drop(&mut self) {
        let this: &Self = self;
        cleanup_ecs_task(this);
    }
}
impl<P: SystemParam + 'static> Clone for EcsTask<P> {
fn clone(&self) -> Self {
EcsTask(self.0.clone())
}
}
impl<P: SystemParam + 'static> EcsTask<P> {
    /// Creates a new task handle, with a fresh task id, targeting `world_id`.
    ///
    /// # Panics
    /// Panics if the process has exhausted the `u64` task-id space.
    pub fn new(world_id: WorldId) -> Self {
        let task_id = AsyncTaskId::new().expect("async ECS task ids exhausted (u64 overflow)");
        Self(Arc::new(InternalEcsTask(task_id, world_id, PhantomData)))
    }
}
/// Future behind `async_access`: waits until its task is admitted into the lent-out world,
/// then runs the stored closure once with the task's system params.
struct PendingEcsCall<P: SystemParam + 'static, Func, Out>(
/// System-param type marker.
PhantomData<P>,
/// Output type marker.
PhantomData<Out>,
/// The closure to run; taken out when it is executed.
Option<Func>,
/// `(world, schedule)` whose wake queue this future registers in.
(WorldId, InternedScheduleLabel),
/// The task id shared with the originating `EcsTask`.
AsyncTaskId,
);
// `Unpin` unconditionally: `poll` relies on no structural pinning and only moves the
// closure out of field `.2`.
impl<P: SystemParam + 'static, Func, Out> Unpin for PendingEcsCall<P, Func, Out> {}
impl<P, Func, Out> Future for PendingEcsCall<P, Func, Out>
where
P: SystemParam + 'static,
for<'w, 's> Func: FnOnce(P::Item<'w, 's>) -> Out,
{
type Output = Out;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
fn system_state_init<P: SystemParam + 'static>(world: &mut World, task_id: AsyncTaskId) {
world.init_resource::<SystemStatePool<P>>();
if !world
.get_resource::<SystemStatePool<P>>()
.unwrap()
.0
.read()
.unwrap()
.contains_key(&task_id)
{
let system_state = SystemState::<P>::new(world);
let cq = ConcurrentQueue::bounded(1);
match cq.push(system_state) {
Ok(_) => {}
Err(_) => {
panic!()
}
}
world
.get_resource::<SystemStatePool<P>>()
.unwrap()
.0
.write()
.unwrap()
.insert(task_id, cq);
}
}
let task_id = self.4;
let world_id = self.3.0;
match GLOBAL_WORLD_ACCESS.get(world_id, task_id,|world: UnsafeWorldCell| {
let Some(system_param_queue) = (unsafe { world.get_resource::<SystemStatePool<P>>() }) else { return Poll::Pending };
let mut system_state = match system_param_queue.0.read().unwrap().get(&task_id) {
None => return Poll::Pending,
Some(cq) => cq.pop().unwrap(),
};
let out;
unsafe {
let default_error_handler = world.default_error_handler();
if let Err(err) = SystemState::validate_param(&mut system_state, world) {
default_error_handler(err.into(), ErrorContext::System {
name: system_state.meta().name().clone(),
last_run: Tick::new(0),
});
}
if !system_state.meta().is_send() {
default_error_handler(SystemParamValidationError::invalid::<NonSend<()>>(
"Cannot have your system be non-send / exclusive",
).into(), ErrorContext::System {
name: system_state.meta().name().clone(),
last_run: Tick::new(0),
});
}
let state = system_state.get_unchecked(world);
out = self.as_mut().2.take().unwrap()(state);
}
unsafe {
match world
.get_resource::<SystemStatePool<P>>()
.unwrap()
.0
.read()
.unwrap()
.get(&task_id)
.unwrap()
.push(system_state)
{
Ok(_) => {}
Err(_) => unreachable!("SystemStatePool should not be able to be removed if it previously existed, otherwise an invariant was violated"),
}
}
Poll::Ready(out)
}) {
Some(awa) => awa,
_ => {
match GLOBAL_WAKE_REGISTRY
.0
.get_or_init(KeyedQueues::new)
.try_send(
&self.3,
(cx.waker().clone(), system_state_init::<P>, task_id),
) {
Ok(_) => {}
Err(_) => unreachable!(),
}
Poll::Pending
}
}
}
}