1use crate::keyed_queues::KeyedQueues;
2use bevy_ecs::change_detection::Tick;
3use bevy_ecs::error::ErrorContext;
4use bevy_ecs::prelude::NonSend;
5use bevy_ecs::schedule::{InternedScheduleLabel, ScheduleLabel};
6use bevy_ecs::system::SystemParamValidationError;
7use bevy_ecs::world::FromWorld;
8use bevy_ecs::world::unsafe_world_cell::UnsafeWorldCell;
9use bevy_ecs::world::{Mut, WorldId};
10use bevy_ecs::{
11 system::{SystemParam, SystemState},
12 world::World,
13};
14use bevy_platform::collections::HashMap;
15use bevy_platform::sync::{Arc, Mutex, OnceLock, RwLock};
16use concurrent_queue::ConcurrentQueue;
17use core::any::TypeId;
18use core::marker::PhantomData;
19use core::pin::Pin;
20use core::sync::atomic::{AtomicU64, Ordering};
21use core::task::{Context, Poll, Waker};
22use crossbeam::sync::WaitGroup;
23
24pub struct AsyncEcsPlugin;
25
26impl bevy_app::Plugin for AsyncEcsPlugin {
27 fn build(&self, app: &mut bevy_app::App) {
28 use bevy_app::prelude::{
29 First, FixedFirst, FixedLast, FixedPostUpdate, FixedPreUpdate, FixedUpdate, Last,
30 PostStartup, PostUpdate, PreStartup, PreUpdate, Startup, Update,
31 };
32 for awa in vec![
33 PreStartup.intern(),
34 Startup.intern(),
35 PostStartup.intern(),
36 PreUpdate.intern(),
37 Update.intern(),
38 PostUpdate.intern(),
39 FixedPostUpdate.intern(),
40 FixedPreUpdate.intern(),
41 FixedUpdate.intern(),
42 First.intern(),
43 Last.intern(),
44 FixedFirst.intern(),
45 FixedLast.intern(),
46 ] {
47 app.add_systems(awa, move |world: &mut World| {
48 run_async_ecs_on_schedule(awa, world);
49 });
50 }
51 }
52}
53
54pub fn run_async_ecs_on_schedule(schedule: InternedScheduleLabel, world: &mut World) {
55 GLOBAL_WAKE_REGISTRY.wait(schedule, world);
56}
57
58mod keyed_queues {
65 use bevy_platform::collections::HashMap;
66 use bevy_platform::sync::{Arc, RwLock};
67 use concurrent_queue::ConcurrentQueue;
68 use core::hash::Hash;
69 pub struct KeyedQueues<K, V> {
73 inner: RwLock<HashMap<K, Arc<ConcurrentQueue<V>>>>,
74 }
75
76 impl<K, V> KeyedQueues<K, V>
77 where
78 K: Eq + Hash + Clone,
79 V: Send + 'static,
80 {
81 pub fn new() -> Self {
82 Self {
83 inner: RwLock::new(HashMap::new()),
84 }
85 }
86
87 #[inline]
88 pub fn get_or_create(&self, key: &K) -> Arc<ConcurrentQueue<V>> {
89 if let Some(q) = self.inner.read().unwrap().get(key).cloned() {
91 return q;
92 }
93 let mut write = self.inner.write().unwrap();
95 if let Some(q) = write.get(key).cloned() {
97 return q;
98 }
99 let q = Arc::new(ConcurrentQueue::unbounded());
100 write.insert(key.clone(), q.clone());
101 q
102 }
103
104 #[inline]
107 pub fn try_send(&self, key: &K, val: V) -> Result<(), concurrent_queue::PushError<V>> {
108 let q = self.get_or_create(key);
109 q.push(val)
110 }
111 }
112}
113
/// Global per-world slots through which a world is lent to async tasks while
/// `WakeRegistry::wait` is blocked inside `WorldAccessRegistry::set`.
static GLOBAL_WORLD_ACCESS: WorldAccessRegistry = WorldAccessRegistry(OnceLock::new());

/// Global queues of wakers for pending async ECS calls, keyed by `(WorldId, schedule)`.
pub(crate) static GLOBAL_WAKE_REGISTRY: WakeRegistry = WakeRegistry(OnceLock::new());

/// Maps each task woken for the current run to a clone of that run's `WaitGroup`.
///
/// `WakeRegistry::wait` fills the map before lending the world. A task removes (and
/// later drops) its clone when it takes its turn in `WorldAccessRegistry::get`, and
/// `wait` blocks until all clones are gone.
#[derive(bevy_ecs::prelude::Resource, Default, Clone)]
pub(crate) struct WakeParkBarrier(Arc<Mutex<HashMap<AsyncTaskId, WaitGroup>>>);

/// Cached `SystemState<T>` for each async task using params `T`.
///
/// Each state sits alone in a `ConcurrentQueue` (created as `bounded(1)` by
/// `system_state_init`), so a task can pop it out, use it, and push it back.
#[derive(bevy_ecs::prelude::Resource)]
pub(crate) struct SystemStatePool<T: SystemParam + 'static>(
    RwLock<HashMap<AsyncTaskId, ConcurrentQueue<SystemState<T>>>>,
);
134
135#[derive(bevy_ecs::prelude::Resource, Default)]
137pub(crate) struct SystemParamAppliers(HashMap<TypeId, fn(&mut World)>);
138impl SystemParamAppliers {
139 fn run(&mut self, world: &mut World) {
140 for closure in self.0.values_mut() {
141 closure(world);
142 }
143 }
144}
145impl<T: SystemParam + 'static> FromWorld for SystemStatePool<T> {
146 fn from_world(world: &mut World) -> Self {
147 let this = Self(RwLock::new(HashMap::default()));
148 world.init_resource::<SystemParamAppliers>();
149 let mut appliers = world.get_resource_mut::<SystemParamAppliers>().unwrap();
150 if !appliers.0.contains_key(&TypeId::of::<T>()) {
151 appliers.0.insert(TypeId::of::<T>(), |world: &mut World| {
152 world.try_resource_scope(|world, param_pool: Mut<SystemStatePool<T>>| {
153 for concurrent_queue in param_pool.0.read().unwrap().values() {
154 let Ok(mut system_state) = concurrent_queue.pop() else {
155 unreachable!()
156 };
157 system_state.apply(world);
158 match concurrent_queue.push(system_state) {
159 Ok(_) => {}
160 Err(_) => panic!(),
161 }
162 }
163 });
164 });
165 }
166 this
167 }
168}
169
/// Process-wide unique identifier of one async ECS task (one `EcsTask` handle family).
#[derive(Clone, Copy, Hash, PartialOrd, PartialEq, Eq, Debug)]
struct AsyncTaskId(u64);

/// Next id to hand out; ids are never reused.
static MAX_TASK_ID: AtomicU64 = AtomicU64::new(0);

impl AsyncTaskId {
    /// Allocates the next id, or returns `None` once the `u64` counter is exhausted.
    pub fn new() -> Option<Self> {
        // `fetch_update` yields the value *before* the increment. Relaxed ordering is
        // enough because only uniqueness matters, not ordering with other memory.
        let previous = MAX_TASK_ID
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| {
                current.checked_add(1)
            })
            .ok()?;
        Some(Self(previous))
    }
}
194
195pub(crate) struct WakeRegistry(
197 OnceLock<
198 KeyedQueues<
199 (WorldId, InternedScheduleLabel),
200 (Waker, fn(&mut World, AsyncTaskId), AsyncTaskId),
201 >,
202 >,
203);
204
205impl WakeRegistry {
206 pub fn wait(&self, schedule: InternedScheduleLabel, world: &mut World) -> Option<()> {
215 let world_id = world.id();
216 if GLOBAL_WAKE_REGISTRY
217 .0
218 .get_or_init(KeyedQueues::new)
219 .get_or_create(&(world_id, schedule))
220 .is_empty()
221 {
222 return None;
223 }
224 for (cleanup_function, task_to_cleanup) in TASKS_TO_CLEANUP
226 .get_or_init(KeyedQueues::new)
227 .get_or_create(&world_id)
228 .try_iter()
229 {
230 cleanup_function(world, task_to_cleanup);
231 }
232 let mut waker_list = bevy_platform::prelude::vec![];
233 let mut task_id_list = bevy_platform::prelude::vec![];
234 while let Ok((waker, system_init, task_id)) = GLOBAL_WAKE_REGISTRY
235 .0
236 .get_or_init(KeyedQueues::new)
237 .get_or_create(&(world_id, schedule))
238 .pop()
239 {
240 system_init(world, task_id);
242 waker_list.push(waker);
243 task_id_list.push(task_id);
244 }
245 let wait_group = WaitGroup::new();
246 world.init_resource::<WakeParkBarrier>();
247 for task in task_id_list {
248 world
249 .resource_mut::<WakeParkBarrier>()
250 .0
251 .lock()
252 .unwrap()
253 .insert(task, wait_group.clone());
254 }
255 GLOBAL_WORLD_ACCESS.set(world, || {
256 for waker in waker_list {
257 waker.wake();
258 }
259 wait_group.wait();
262 })?;
263 world.try_resource_scope(|world, mut appliers: Mut<SystemParamAppliers>| {
265 appliers.run(world);
266 });
267 Some(())
268 }
269}
270
271pub(crate) struct WorldAccessRegistry(
276 OnceLock<
277 RwLock<
278 HashMap<
279 WorldId,
280 RwLock<
281 Option<(
282 UnsafeWorldCell<'static>,
283 Mutex<PhantomData<UnsafeWorldCell<'static>>>,
284 )>,
285 >,
286 >,
287 >,
288 >,
289);
290
291impl WorldAccessRegistry {
292 fn set(&self, world: &mut World, func: impl FnOnce()) -> Option<()> {
294 let this = self.0.get_or_init(|| RwLock::new(HashMap::new()));
295 let world_id = world.id();
296 if !this.read().unwrap().contains_key(&world_id) {
297 let _ = this.write().unwrap().insert(world_id, RwLock::new(None));
299 }
300
301 struct ClearOnDropGuard<'a> {
302 slot: &'a RwLock<
303 Option<(
304 UnsafeWorldCell<'static>,
305 Mutex<PhantomData<UnsafeWorldCell<'static>>>,
306 )>,
307 >,
308 }
309 impl<'a> Drop for ClearOnDropGuard<'a> {
310 fn drop(&mut self) {
311 match self.slot.write() {
314 Ok(mut slot) => {
315 let _ = slot.take();
316 }
317 Err(_) => {
318 }
321 }
322 }
323 }
324 unsafe {
329 let binding = this.read().unwrap();
330 let world_container = binding.get(&world_id).unwrap();
331 let _clear = ClearOnDropGuard {
333 slot: world_container,
334 };
335 world_container.write().unwrap().replace((
340 core::mem::transmute::<UnsafeWorldCell, UnsafeWorldCell<'static>>(
341 world.as_unsafe_world_cell(),
342 ),
343 Mutex::new(PhantomData),
344 ));
345 func();
346 }
347 Some(())
348 }
349 fn get<T>(
350 &self,
351 world_id: WorldId,
352 task_id: AsyncTaskId,
353 func: impl FnOnce(UnsafeWorldCell) -> Poll<T>,
354 ) -> Option<Poll<T>> {
355 let a = self.0.get()?.read().unwrap();
358 let b = a.get(&world_id)?.read().unwrap();
359 let our_thing = b.as_ref()?;
360 let _async_barrier = unsafe {
366 our_thing
367 .0
368 .get_resource::<WakeParkBarrier>()
369 .unwrap()
370 .0
371 .lock()
372 .unwrap()
373 .remove(&task_id)?
374 };
375 let _world = our_thing.1.try_lock().ok()?;
377 Some(func(our_thing.0))
379 }
380}
381
382impl<P: bevy_ecs::system::SystemParam + 'static> EcsTask<P> {
383 pub async fn run_system<Func, Out>(self, schedule: impl ScheduleLabel, ecs_access: Func) -> Out
384 where
385 for<'w, 's> Func: FnOnce(P::Item<'w, 's>) -> Out,
386 {
387 async_access(self, schedule, ecs_access).await
388 }
389}
390
391pub trait CreateEcsTask {
392 fn ecs_task<P: SystemParam + 'static>(self) -> EcsTask<P>;
393}
394
395impl CreateEcsTask for WorldId {
396 fn ecs_task<P: SystemParam + 'static>(self) -> EcsTask<P> {
397 EcsTask::new(self)
398 }
399}
400
401async fn async_access<P, Func, Out>(
406 task_identifier: impl Into<EcsTask<P>>,
407 schedule: impl ScheduleLabel,
408 ecs_access: Func,
409) -> Out
410where
411 P: SystemParam + 'static,
412 for<'w, 's> Func: FnOnce(P::Item<'w, 's>) -> Out,
413{
414 let task_identifier = task_identifier.into();
415 PendingEcsCall::<P, Func, Out>(
416 PhantomData::<P>,
417 PhantomData,
418 Some(ecs_access),
419 (task_identifier.0.1, schedule.intern()),
420 task_identifier.0.0,
421 )
422 .await
423}
424
425static TASKS_TO_CLEANUP: OnceLock<
426 KeyedQueues<WorldId, (fn(&mut World, AsyncTaskId), AsyncTaskId)>,
427> = OnceLock::new();
428
429fn cleanup_ecs_task<P: SystemParam + 'static>(task: &InternalEcsTask<P>) {
432 fn cleanup_task<P: SystemParam + 'static>(world: &mut World, task_id: AsyncTaskId) {
433 world.try_resource_scope(|_world, param_pool: Mut<SystemStatePool<P>>| {
434 let mut pool = param_pool.0.write().unwrap();
435 pool.remove(&task_id);
436 if pool.len() * 2 < pool.capacity() {
437 pool.shrink_to_fit();
438 }
439 });
440 }
441 match TASKS_TO_CLEANUP
443 .get_or_init(KeyedQueues::new)
444 .try_send(&task.1, (cleanup_task::<P>, task.0))
445 {
446 Ok(_) => {}
447 Err(_) => unreachable!(),
448 }
449}
450
451impl<P: SystemParam + 'static> From<WorldId> for EcsTask<P> {
452 fn from(value: WorldId) -> Self {
453 EcsTask(Arc::new(InternalEcsTask(
454 AsyncTaskId::new().unwrap(),
455 value,
456 PhantomData,
457 )))
458 }
459}
460
461pub struct EcsTask<P: SystemParam + 'static>(Arc<InternalEcsTask<P>>);
463
464struct InternalEcsTask<P: SystemParam + 'static>(AsyncTaskId, WorldId, PhantomData<P>);
465
466impl<T: SystemParam + 'static> Drop for InternalEcsTask<T> {
467 fn drop(&mut self) {
468 cleanup_ecs_task(self);
469 }
470}
471
472impl<P: SystemParam + 'static> Clone for EcsTask<P> {
473 fn clone(&self) -> Self {
474 EcsTask(self.0.clone())
475 }
476}
477impl<P: SystemParam + 'static> EcsTask<P> {
478 pub fn new(world_id: WorldId) -> Self {
481 Self(Arc::new(InternalEcsTask(
482 AsyncTaskId::new().unwrap(),
483 world_id,
484 PhantomData,
485 )))
486 }
487}
488
489struct PendingEcsCall<P: SystemParam + 'static, Func, Out>(
490 PhantomData<P>,
491 PhantomData<Out>,
492 Option<Func>,
493 (WorldId, InternedScheduleLabel),
494 AsyncTaskId,
495);
496
497impl<P: SystemParam + 'static, Func, Out> Unpin for PendingEcsCall<P, Func, Out> {}
498
499impl<P, Func, Out> Future for PendingEcsCall<P, Func, Out>
500where
501 P: SystemParam + 'static,
502 for<'w, 's> Func: FnOnce(P::Item<'w, 's>) -> Out,
503{
504 type Output = Out;
505
506 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
507 fn system_state_init<P: SystemParam + 'static>(world: &mut World, task_id: AsyncTaskId) {
508 world.init_resource::<SystemStatePool<P>>();
509 if !world
510 .get_resource::<SystemStatePool<P>>()
511 .unwrap()
512 .0
513 .read()
514 .unwrap()
515 .contains_key(&task_id)
516 {
517 let system_state = SystemState::<P>::new(world);
518 let cq = ConcurrentQueue::bounded(1);
519 match cq.push(system_state) {
520 Ok(_) => {}
521 Err(_) => {
522 panic!()
523 }
524 }
525 world
526 .get_resource::<SystemStatePool<P>>()
527 .unwrap()
528 .0
529 .write()
530 .unwrap()
531 .insert(task_id, cq);
532 }
533 }
534
535 let task_id = self.4;
536 let world_id = self.3.0;
537
538 match GLOBAL_WORLD_ACCESS.get(world_id, task_id,|world: UnsafeWorldCell| {
539 let Some(system_param_queue) = (unsafe { world.get_resource::<SystemStatePool<P>>() }) else { return Poll::Pending };
541 let mut system_state = match system_param_queue.0.read().unwrap().get(&task_id) {
542 None => return Poll::Pending,
543 Some(cq) => cq.pop().unwrap(),
544 };
545 let out;
546 unsafe {
548 let default_error_handler = world.default_error_handler();
549 if let Err(err) = SystemState::validate_param(&mut system_state, world) {
552 default_error_handler(err.into(), ErrorContext::System {
553 name: system_state.meta().name().clone(),
554 last_run: Tick::new(0),
555 });
556 }
557 if !system_state.meta().is_send() {
558 default_error_handler(SystemParamValidationError::invalid::<NonSend<()>>(
559 "Cannot have your system be non-send / exclusive",
560 ).into(), ErrorContext::System {
561 name: system_state.meta().name().clone(),
562 last_run: Tick::new(0),
563 });
564 }
565 let state = system_state.get_unchecked(world);
566 out = self.as_mut().2.take().unwrap()(state);
567 }
568 unsafe {
570 match world
571 .get_resource::<SystemStatePool<P>>()
572 .unwrap()
573 .0
574 .read()
575 .unwrap()
576 .get(&task_id)
577 .unwrap()
578 .push(system_state)
579 {
580 Ok(_) => {}
581 Err(_) => unreachable!("SystemStatePool should not be able to be removed if it previously existed, otherwise an invariant was violated"),
582 }
583 }
584 Poll::Ready(out)
585 }) {
586 Some(awa) => awa,
587 _ => {
588 match GLOBAL_WAKE_REGISTRY
593 .0
594 .get_or_init(KeyedQueues::new)
595 .try_send(
596 &self.3,
597 (cx.waker().clone(), system_state_init::<P>, task_id),
598 ) {
599 Ok(_) => {}
600 Err(_) => unreachable!(),
603 }
604 Poll::Pending
605 }
606 }
607 }
608}