Skip to main content

bevy_malek_async/
lib.rs

1use crate::keyed_queues::KeyedQueues;
2use bevy_ecs::change_detection::Tick;
3use bevy_ecs::error::ErrorContext;
4use bevy_ecs::prelude::NonSend;
5use bevy_ecs::schedule::{InternedScheduleLabel, ScheduleLabel};
6use bevy_ecs::system::SystemParamValidationError;
7use bevy_ecs::world::FromWorld;
8use bevy_ecs::world::unsafe_world_cell::UnsafeWorldCell;
9use bevy_ecs::world::{Mut, WorldId};
10use bevy_ecs::{
11    system::{SystemParam, SystemState},
12    world::World,
13};
14use bevy_platform::collections::HashMap;
15use bevy_platform::sync::{Arc, Mutex, OnceLock, RwLock};
16use concurrent_queue::ConcurrentQueue;
17use core::any::TypeId;
18use core::marker::PhantomData;
19use core::pin::Pin;
20use core::sync::atomic::{AtomicU64, Ordering};
21use core::task::{Context, Poll, Waker};
22use crossbeam::sync::WaitGroup;
23
24pub struct AsyncEcsPlugin;
25
26impl bevy_app::Plugin for AsyncEcsPlugin {
27    fn build(&self, app: &mut bevy_app::App) {
28        use bevy_app::prelude::{
29            First, FixedFirst, FixedLast, FixedPostUpdate, FixedPreUpdate, FixedUpdate, Last,
30            PostStartup, PostUpdate, PreStartup, PreUpdate, Startup, Update,
31        };
32        for awa in vec![
33            PreStartup.intern(),
34            Startup.intern(),
35            PostStartup.intern(),
36            PreUpdate.intern(),
37            Update.intern(),
38            PostUpdate.intern(),
39            FixedPostUpdate.intern(),
40            FixedPreUpdate.intern(),
41            FixedUpdate.intern(),
42            First.intern(),
43            Last.intern(),
44            FixedFirst.intern(),
45            FixedLast.intern(),
46        ] {
47            app.add_systems(awa, move |world: &mut World| {
48                run_async_ecs_on_schedule(awa, world);
49            });
50        }
51    }
52}
53
54pub fn run_async_ecs_on_schedule(schedule: InternedScheduleLabel, world: &mut World) {
55    GLOBAL_WAKE_REGISTRY.wait(schedule, world);
56}
57
58/// Keyed queues is a combination of a hashmap and a concurrent queue which is useful because it
59/// allows for non-blocking keyed queues.
60/// We want every World's async machinery to be as independent as possible, and this allows us
61/// to key our Queues on `(WorldId, Schedule)` so that there is 0 contention on the fast path and
62/// arbitrary N number of worlds running in parallel on the same process do not interfere at all
63/// except the very first time a new world initializes it's key.
64mod keyed_queues {
65    use bevy_platform::collections::HashMap;
66    use bevy_platform::sync::{Arc, RwLock};
67    use concurrent_queue::ConcurrentQueue;
68    use core::hash::Hash;
69    /// `HashMap<K, Arc<ConcurrentQueue<V>>>` behind a single `RwLock`.
70    /// - Writers only contend when creating a new key.
71    /// - `push` is almost always non-blocking (unbounded queue).
72    pub struct KeyedQueues<K, V> {
73        inner: RwLock<HashMap<K, Arc<ConcurrentQueue<V>>>>,
74    }
75
76    impl<K, V> KeyedQueues<K, V>
77    where
78        K: Eq + Hash + Clone,
79        V: Send + 'static,
80    {
81        pub fn new() -> Self {
82            Self {
83                inner: RwLock::new(HashMap::new()),
84            }
85        }
86
87        #[inline]
88        pub fn get_or_create(&self, key: &K) -> Arc<ConcurrentQueue<V>> {
89            // Fast path: try read lock first
90            if let Some(q) = self.inner.read().unwrap().get(key).cloned() {
91                return q;
92            }
93            // Slow path: create under write lock if still absent
94            let mut write = self.inner.write().unwrap();
95            // We intentionally check a second time because of synchronization
96            if let Some(q) = write.get(key).cloned() {
97                return q;
98            }
99            let q = Arc::new(ConcurrentQueue::unbounded());
100            write.insert(key.clone(), q.clone());
101            q
102        }
103
104        /// Potentially-blocking send but almost never blocking (unbounded queue => `push` never fails).
105        /// ( Only blocks when the `(WorldId, Schedule)` has never been used before
106        #[inline]
107        pub fn try_send(&self, key: &K, val: V) -> Result<(), concurrent_queue::PushError<V>> {
108            let q = self.get_or_create(key);
109            q.push(val)
110        }
111    }
112}
113
114/// This is an abstraction that temporarily and soundly stores the `UnsafeWorldCell` in a static so we can access
115/// it from any async task, runtime, and thread.
116static GLOBAL_WORLD_ACCESS: WorldAccessRegistry = WorldAccessRegistry(OnceLock::new());
117
118/// The entrypoint, stores `Waker`s from `async_access`'s that wish to be polled with world access
119/// also stores the generic function pointer to the concrete function that initializes the
120/// system state for any set of `SystemParams`
121pub(crate) static GLOBAL_WAKE_REGISTRY: WakeRegistry = WakeRegistry(OnceLock::new());
122
123/// Acts as a barrier that is waited on in the `wait` call, and once the `AtomicI64` reaches 0 the
124/// thread that `wait` was called on gets woken up and resumes.
125#[derive(bevy_ecs::prelude::Resource, Default, Clone)]
126pub(crate) struct WakeParkBarrier(Arc<Mutex<HashMap<AsyncTaskId, WaitGroup>>>);
127
128/// Stores the previous system state per task id which allows `Local`, `Changed` and other filters
129/// that depend on persistent state to work.
130#[derive(bevy_ecs::prelude::Resource)]
131pub(crate) struct SystemStatePool<T: SystemParam + 'static>(
132    RwLock<HashMap<AsyncTaskId, ConcurrentQueue<SystemState<T>>>>,
133);
134
135/// Function pointer to a concrete version of a genericized system state being applied to the world.
136#[derive(bevy_ecs::prelude::Resource, Default)]
137pub(crate) struct SystemParamAppliers(HashMap<TypeId, fn(&mut World)>);
138impl SystemParamAppliers {
139    fn run(&mut self, world: &mut World) {
140        for closure in self.0.values_mut() {
141            closure(world);
142        }
143    }
144}
145impl<T: SystemParam + 'static> FromWorld for SystemStatePool<T> {
146    fn from_world(world: &mut World) -> Self {
147        let this = Self(RwLock::new(HashMap::default()));
148        world.init_resource::<SystemParamAppliers>();
149        let mut appliers = world.get_resource_mut::<SystemParamAppliers>().unwrap();
150        if !appliers.0.contains_key(&TypeId::of::<T>()) {
151            appliers.0.insert(TypeId::of::<T>(), |world: &mut World| {
152                world.try_resource_scope(|world, param_pool: Mut<SystemStatePool<T>>| {
153                    for concurrent_queue in param_pool.0.read().unwrap().values() {
154                        let Ok(mut system_state) = concurrent_queue.pop() else {
155                            unreachable!()
156                        };
157                        system_state.apply(world);
158                        match concurrent_queue.push(system_state) {
159                            Ok(_) => {}
160                            Err(_) => panic!(),
161                        }
162                    }
163                });
164            });
165        }
166        this
167    }
168}
169
/// A process-wide identifier for an async task. Ids are handed out in increasing order.
/// This is an internal implementation detail and thus not generally accessible.
#[derive(Clone, Copy, Hash, PartialOrd, PartialEq, Eq, Debug)]
struct AsyncTaskId(u64);

/// The value the next [`AsyncTaskId`] will receive.
static MAX_TASK_ID: AtomicU64 = AtomicU64::new(0);

impl AsyncTaskId {
    /// Allocates a fresh, unique [`AsyncTaskId`]. Returns [`None`] once the counter can no
    /// longer be advanced (it has reached `u64::MAX`).
    ///
    /// Ids are never recycled: dropping an id does not make its value available again.
    pub fn new() -> Option<Self> {
        // `Relaxed` is enough: the counter only has to be consistent with itself and does not
        // order any other memory accesses.
        let issued = MAX_TASK_ID
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |next| {
                next.checked_add(1)
            })
            .ok()?;
        Some(Self(issued))
    }
}
194
195/// Is the `GLOBAL_WAKE_REGISTRY`
196pub(crate) struct WakeRegistry(
197    OnceLock<
198        KeyedQueues<
199            (WorldId, InternedScheduleLabel),
200            (Waker, fn(&mut World, AsyncTaskId), AsyncTaskId),
201        >,
202    >,
203);
204
205impl WakeRegistry {
206    /// This function finds all pending `async_access` calls for a particular `Schedule` and a particular
207    /// `WorldId`. It wakes all of them, temporarily and soundly stores a `UnsafeWorldCell` in the
208    /// `GLOBAL_WORLD_ACCESS` and parks until the tasks it has awoken either complete their `async_access`
209    /// or have returned `Poll::Pending` for a variety of reasons.
210    /// The performance implications of this call are entirely dependent on the async runtime
211    /// you are using it with, certain poor implementations *could* cause this to take longer
212    /// than expect to resolve.
213    /// Returns `Some` as long as the last call processed any number of waiting `async_access` calls.
214    pub fn wait(&self, schedule: InternedScheduleLabel, world: &mut World) -> Option<()> {
215        let world_id = world.id();
216        if GLOBAL_WAKE_REGISTRY
217            .0
218            .get_or_init(KeyedQueues::new)
219            .get_or_create(&(world_id, schedule))
220            .is_empty()
221        {
222            return None;
223        }
224        // Cleanups the garbage first.
225        for (cleanup_function, task_to_cleanup) in TASKS_TO_CLEANUP
226            .get_or_init(KeyedQueues::new)
227            .get_or_create(&world_id)
228            .try_iter()
229        {
230            cleanup_function(world, task_to_cleanup);
231        }
232        let mut waker_list = bevy_platform::prelude::vec![];
233        let mut task_id_list = bevy_platform::prelude::vec![];
234        while let Ok((waker, system_init, task_id)) = GLOBAL_WAKE_REGISTRY
235            .0
236            .get_or_init(KeyedQueues::new)
237            .get_or_create(&(world_id, schedule))
238            .pop()
239        {
240            // It's okay to call this every time, because it only *actually* inits the system if the task id is new
241            system_init(world, task_id);
242            waker_list.push(waker);
243            task_id_list.push(task_id);
244        }
245        let wait_group = WaitGroup::new();
246        world.init_resource::<WakeParkBarrier>();
247        for task in task_id_list {
248            world
249                .resource_mut::<WakeParkBarrier>()
250                .0
251                .lock()
252                .unwrap()
253                .insert(task, wait_group.clone());
254        }
255        GLOBAL_WORLD_ACCESS.set(world, || {
256            for waker in waker_list {
257                waker.wake();
258            }
259            // We do this because we can get spurious wakes, but we wanna ensure that
260            // we stay parked until we have at least given every poll a chance to happen.
261            wait_group.wait();
262        })?;
263        // Applies all the commands stored up to the world
264        world.try_resource_scope(|world, mut appliers: Mut<SystemParamAppliers>| {
265            appliers.run(world);
266        });
267        Some(())
268    }
269}
270
271/// This is a very low contention, no contention in the normal execution path, way of storing and
272/// using a `UnsafeWorldCell` from any thread/async task/async runtime.
273/// The `Mutex<PhantomData<>>` is used to return `Poll::Pending` early from an `async_access` if
274/// another `async_access` is currently using it.
275pub(crate) struct WorldAccessRegistry(
276    OnceLock<
277        RwLock<
278            HashMap<
279                WorldId,
280                RwLock<
281                    Option<(
282                        UnsafeWorldCell<'static>,
283                        Mutex<PhantomData<UnsafeWorldCell<'static>>>,
284                    )>,
285                >,
286            >,
287        >,
288    >,
289);
290
291impl WorldAccessRegistry {
292    /// During this `func: FnOnce()` call, calling `get` will access the stored `UnsafeWorldCell`
293    fn set(&self, world: &mut World, func: impl FnOnce()) -> Option<()> {
294        let this = self.0.get_or_init(|| RwLock::new(HashMap::new()));
295        let world_id = world.id();
296        if !this.read().unwrap().contains_key(&world_id) {
297            // VERY rare only happens the first time we try to do anything async in a new World
298            let _ = this.write().unwrap().insert(world_id, RwLock::new(None));
299        }
300
301        struct ClearOnDropGuard<'a> {
302            slot: &'a RwLock<
303                Option<(
304                    UnsafeWorldCell<'static>,
305                    Mutex<PhantomData<UnsafeWorldCell<'static>>>,
306                )>,
307            >,
308        }
309        impl<'a> Drop for ClearOnDropGuard<'a> {
310            fn drop(&mut self) {
311                // clear it on the way out
312                // we can't actually panic here because panicking in a drop is bad
313                match self.slot.write() {
314                    Ok(mut slot) => {
315                        let _ = slot.take();
316                    }
317                    Err(_) => {
318                        // This is okay because the mutex is poisoned so nothing can access the
319                        // UnsafeWorldCell now.
320                    }
321                }
322            }
323        }
324        // SAFETY: This mem transmute is safe only because we drop it after, and our GLOBAL_WORLD_ACCESS is private, and we don't clone it
325        // where we do use it, so the lifetime doesn't get propagated anywhere.
326        // Lifetimes are not used in any actual code optimization, so turning it into a static does not violate any of rust's rules
327        // As *LONG* as we keep it within it's lifetime, which we do here, manually, with our `ClearOnDrop` struct.
328        unsafe {
329            let binding = this.read().unwrap();
330            let world_container = binding.get(&world_id).unwrap();
331            // SAFETY this is required in order to make sure that even in the event of a panic, this can't get accessed
332            let _clear = ClearOnDropGuard {
333                slot: world_container,
334            };
335            // SAFETY: This mem transmute is safe only because we drop it after, and our GLOBAL_WORLD_ACCESS is private, and we don't clone it
336            // where we do use it, so the lifetime doesn't get propagated anywhere.
337            // Lifetimes are not used in any actual code optimization, so turning it into a static does not violate any of rust's rules
338            // As *LONG* as we keep it within it's lifetime, which we do here, manually, with our `ClearOnDrop` struct.
339            world_container.write().unwrap().replace((
340                core::mem::transmute::<UnsafeWorldCell, UnsafeWorldCell<'static>>(
341                    world.as_unsafe_world_cell(),
342                ),
343                Mutex::new(PhantomData),
344            ));
345            func();
346        }
347        Some(())
348    }
349    fn get<T>(
350        &self,
351        world_id: WorldId,
352        task_id: AsyncTaskId,
353        func: impl FnOnce(UnsafeWorldCell) -> Poll<T>,
354    ) -> Option<Poll<T>> {
355        // it's okay to *not* do the RaiiThing on these early returns, because that means we aren't in a state
356        // where a thread is parked because of our world.
357        let a = self.0.get()?.read().unwrap();
358        let b = a.get(&world_id)?.read().unwrap();
359        let our_thing = b.as_ref()?;
360        // SAFETY: WakeParkBarrier is only *read* during this section per world, so reading it
361        // without an associated mutex is okay.
362        // Furthermore the WakeParkBarrier cannot be queried by `async_access` because it's type
363        // is not public, `async_access` cannot access `&mut World` to do a dynamic resource
364        // modification.
365        let _async_barrier = unsafe {
366            our_thing
367                .0
368                .get_resource::<WakeParkBarrier>()
369                .unwrap()
370                .0
371                .lock()
372                .unwrap()
373                .remove(&task_id)?
374        };
375        // this allows us to effectively yield as if pending if the world doesn't exist rn.
376        let _world = our_thing.1.try_lock().ok()?;
377        // SAFETY: this is safe because we ensure no one else has access to the world.
378        Some(func(our_thing.0))
379    }
380}
381
382impl<P: bevy_ecs::system::SystemParam + 'static> EcsTask<P> {
383    pub async fn run_system<Func, Out>(self, schedule: impl ScheduleLabel, ecs_access: Func) -> Out
384    where
385        for<'w, 's> Func: FnOnce(P::Item<'w, 's>) -> Out,
386    {
387        async_access(self, schedule, ecs_access).await
388    }
389}
390
391pub trait CreateEcsTask {
392    fn ecs_task<P: SystemParam + 'static>(self) -> EcsTask<P>;
393}
394
395impl CreateEcsTask for WorldId {
396    fn ecs_task<P: SystemParam + 'static>(self) -> EcsTask<P> {
397        EcsTask::new(self)
398    }
399}
400
401/// Allows you to access the ECS from any arbitrary async runtime.
402/// Calls will never return immediately and will always start Pending at least once.
403/// Call this with the same `EcsTask` to persist `SystemParams` like `Local` or `Changed`
404/// Just use `world_id` if you do not mind a new `SystemParam` being initialized every time.
405async fn async_access<P, Func, Out>(
406    task_identifier: impl Into<EcsTask<P>>,
407    schedule: impl ScheduleLabel,
408    ecs_access: Func,
409) -> Out
410where
411    P: SystemParam + 'static,
412    for<'w, 's> Func: FnOnce(P::Item<'w, 's>) -> Out,
413{
414    let task_identifier = task_identifier.into();
415    PendingEcsCall::<P, Func, Out>(
416        PhantomData::<P>,
417        PhantomData,
418        Some(ecs_access),
419        (task_identifier.0.1, schedule.intern()),
420        task_identifier.0.0,
421    )
422    .await
423}
424
425static TASKS_TO_CLEANUP: OnceLock<
426    KeyedQueues<WorldId, (fn(&mut World, AsyncTaskId), AsyncTaskId)>,
427> = OnceLock::new();
428
429/// Pass the `EcsTask` into here after you're done using it
430/// This function will mark the `SystemState` for that task for cleanup.
431fn cleanup_ecs_task<P: SystemParam + 'static>(task: &InternalEcsTask<P>) {
432    fn cleanup_task<P: SystemParam + 'static>(world: &mut World, task_id: AsyncTaskId) {
433        world.try_resource_scope(|_world, param_pool: Mut<SystemStatePool<P>>| {
434            let mut pool = param_pool.0.write().unwrap();
435            pool.remove(&task_id);
436            if pool.len() * 2 < pool.capacity() {
437                pool.shrink_to_fit();
438            }
439        });
440    }
441    // Should never panic cause this is an unbounded queue
442    match TASKS_TO_CLEANUP
443        .get_or_init(KeyedQueues::new)
444        .try_send(&task.1, (cleanup_task::<P>, task.0))
445    {
446        Ok(_) => {}
447        Err(_) => unreachable!(),
448    }
449}
450
451impl<P: SystemParam + 'static> From<WorldId> for EcsTask<P> {
452    fn from(value: WorldId) -> Self {
453        EcsTask(Arc::new(InternalEcsTask(
454            AsyncTaskId::new().unwrap(),
455            value,
456            PhantomData,
457        )))
458    }
459}
460
461/// An `EcsTask` can be re-used in order to persist `SystemParams` like `Local`, `Changed`, or `Added`
462pub struct EcsTask<P: SystemParam + 'static>(Arc<InternalEcsTask<P>>);
463
464struct InternalEcsTask<P: SystemParam + 'static>(AsyncTaskId, WorldId, PhantomData<P>);
465
466impl<T: SystemParam + 'static> Drop for InternalEcsTask<T> {
467    fn drop(&mut self) {
468        cleanup_ecs_task(self);
469    }
470}
471
472impl<P: SystemParam + 'static> Clone for EcsTask<P> {
473    fn clone(&self) -> Self {
474        EcsTask(self.0.clone())
475    }
476}
477impl<P: SystemParam + 'static> EcsTask<P> {
478    /// Generates a new unique `EcsTask` that can be re-used in order to persist `SystemParams`
479    /// like `Local`, `Changed`, or `Added`
480    pub fn new(world_id: WorldId) -> Self {
481        Self(Arc::new(InternalEcsTask(
482            AsyncTaskId::new().unwrap(),
483            world_id,
484            PhantomData,
485        )))
486    }
487}
488
489struct PendingEcsCall<P: SystemParam + 'static, Func, Out>(
490    PhantomData<P>,
491    PhantomData<Out>,
492    Option<Func>,
493    (WorldId, InternedScheduleLabel),
494    AsyncTaskId,
495);
496
497impl<P: SystemParam + 'static, Func, Out> Unpin for PendingEcsCall<P, Func, Out> {}
498
499impl<P, Func, Out> Future for PendingEcsCall<P, Func, Out>
500where
501    P: SystemParam + 'static,
502    for<'w, 's> Func: FnOnce(P::Item<'w, 's>) -> Out,
503{
504    type Output = Out;
505
506    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
507        fn system_state_init<P: SystemParam + 'static>(world: &mut World, task_id: AsyncTaskId) {
508            world.init_resource::<SystemStatePool<P>>();
509            if !world
510                .get_resource::<SystemStatePool<P>>()
511                .unwrap()
512                .0
513                .read()
514                .unwrap()
515                .contains_key(&task_id)
516            {
517                let system_state = SystemState::<P>::new(world);
518                let cq = ConcurrentQueue::bounded(1);
519                match cq.push(system_state) {
520                    Ok(_) => {}
521                    Err(_) => {
522                        panic!()
523                    }
524                }
525                world
526                    .get_resource::<SystemStatePool<P>>()
527                    .unwrap()
528                    .0
529                    .write()
530                    .unwrap()
531                    .insert(task_id, cq);
532            }
533        }
534
535        let task_id = self.4;
536        let world_id = self.3.0;
537
538        match GLOBAL_WORLD_ACCESS.get(world_id, task_id,|world: UnsafeWorldCell| {
539            // SAFETY: We have a fake-mutex around our world, so no one else can do mutable access to it.
540            let Some(system_param_queue) = (unsafe { world.get_resource::<SystemStatePool<P>>() }) else { return Poll::Pending };
541            let mut system_state = match system_param_queue.0.read().unwrap().get(&task_id) {
542                None => return Poll::Pending,
543                Some(cq) => cq.pop().unwrap(),
544            };
545            let out;
546            // SAFETY: This is safe because we have a fake-mutex around our world cell, so only one thing can have access to it at a time.
547            unsafe {
548                let default_error_handler = world.default_error_handler();
549                // Obtain params and immediately consume them with the closure,
550                // ensuring the borrow ends before `apply`.
551                if let Err(err) = SystemState::validate_param(&mut system_state, world) {
552                    default_error_handler(err.into(), ErrorContext::System {
553                        name: system_state.meta().name().clone(),
554                        last_run: /*system_state.meta().last_run*/ Tick::new(0),
555                    });
556                }
557                if !system_state.meta().is_send() {
558                    default_error_handler(SystemParamValidationError::invalid::<NonSend<()>>(
559                        "Cannot have your system be non-send / exclusive",
560                    ).into(), ErrorContext::System {
561                        name: system_state.meta().name().clone(),
562                        last_run: /*system_state.meta.last_run */Tick::new(0),
563                    });
564                }
565                let state = system_state.get_unchecked(world);
566                out = self.as_mut().2.take().unwrap()(state);
567            }
568            // SAFETY: We have a fake-mutex around our world, so no one else can do mutable access to it.
569            unsafe {
570                match world
571                    .get_resource::<SystemStatePool<P>>()
572                    .unwrap()
573                    .0
574                    .read()
575                    .unwrap()
576                    .get(&task_id)
577                    .unwrap()
578                    .push(system_state)
579                {
580                    Ok(_) => {}
581                    Err(_) => unreachable!("SystemStatePool should not be able to be removed if it previously existed, otherwise an invariant was violated"),
582                }
583            }
584            Poll::Ready(out)
585        }) {
586            Some(awa) => awa,
587            _ => {
588                // This must be a static, sadly, because we must always make sure that we can store
589                // our pending wakers no matter what. Everything else that we care about can be
590                // stored on the world itself, but this must always be accessible, even if another
591                // `async_access` is currently running.
592                match GLOBAL_WAKE_REGISTRY
593                    .0
594                    .get_or_init(KeyedQueues::new)
595                    .try_send(
596                        &self.3,
597                        (cx.waker().clone(), system_state_init::<P>, task_id),
598                    ) {
599                    Ok(_) => {}
600                    // This should never panic because we never `close` our concurrent queues and
601                    // the concurrent queue here is unbounded.
602                    Err(_) => unreachable!(),
603                }
604                Poll::Pending
605            }
606        }
607    }
608}