bevy_wasm_tasks/
lib.rs

1use bevy_app::{App, Plugin, Update};
2use bevy_ecs::{prelude::World, system::Resource};
3use futures_util::FutureExt;
4use std::future::Future;
5use std::sync::atomic::{AtomicUsize, Ordering};
6use std::sync::Arc;
7use futures_util::future::RemoteHandle;
8
9/// An internal struct keeping track of how many ticks have elapsed since the start of the program.
10#[derive(Resource)]
11struct UpdateTicks {
12    ticks: Arc<AtomicUsize>,
13    update_watch_tx: tokio::sync::watch::Sender<()>,
14}
15
16impl UpdateTicks {
17    fn increment_ticks(&self) -> usize {
18        let new_ticks = self.ticks.fetch_add(1, Ordering::SeqCst).wrapping_add(1);
19        self.update_watch_tx
20            .send(())
21            .expect("Failed to send update_watch channel message");
22        new_ticks
23    }
24}
25
26/// The Bevy [`Plugin`] which sets up the [`WASMTasksRuntime`] Bevy resource and registers
27/// the [`tick_runtime_update`] exclusive system.
28pub struct WASMTasksPlugin;
29
30impl Plugin for WASMTasksPlugin {
31    fn build(&self, app: &mut App) {
32        let ticks = Arc::new(AtomicUsize::new(0));
33        let (update_watch_tx, update_watch_rx) = tokio::sync::watch::channel(());
34        app.insert_resource(UpdateTicks {
35            ticks: ticks.clone(),
36            update_watch_tx,
37        });
38        app.insert_resource(WASMTasksRuntime::new(ticks, update_watch_rx));
39        app.add_systems(Update, tick_runtime_update);
40    }
41}
42
43/// The Bevy exclusive system which executes the main thread callbacks that background
44/// tasks requested using [`run_on_main_thread`](TaskContext::run_on_main_thread). You
45/// can control which [`CoreStage`] this system executes in by specifying a custom
46/// [`tick_stage`](WASMTasksPlugin::tick_stage) value.
47pub fn tick_runtime_update(world: &mut World) {
48    let current_tick = {
49        let tick_counter = match world.get_resource::<UpdateTicks>() {
50            Some(counter) => counter,
51            None => return,
52        };
53
54        // Increment update ticks and notify watchers of update tick.
55        tick_counter.increment_ticks()
56    };
57
58    if let Some(mut runtime) = world.remove_resource::<WASMTasksRuntime>() {
59        runtime.execute_main_thread_work(world, current_tick);
60        world.insert_resource(runtime);
61    }
62}
63
64type MainThreadCallback = Box<dyn FnOnce(MainThreadContext) + Send + 'static>;
65
66/// The Bevy [`Resource`] which allows for spawning new background tasks.
67#[derive(Resource)]
68pub struct WASMTasksRuntime(Box<WASMTasksRuntimeInner>);
69
70/// The inner fields are boxed to reduce the cost of the every-frame move out of and back into
71/// the world in [`tick_runtime_update`].
72struct WASMTasksRuntimeInner {
73    ticks: Arc<AtomicUsize>,
74    update_watch_rx: tokio::sync::watch::Receiver<()>,
75    update_run_tx: tokio::sync::mpsc::UnboundedSender<MainThreadCallback>,
76    update_run_rx: tokio::sync::mpsc::UnboundedReceiver<MainThreadCallback>,
77}
78
79pub struct JoinHandle<T> {
80    handle: Option<RemoteHandle<T>>,
81}
82
83impl <T> JoinHandle<T> {
84    pub fn take(&mut self) -> Option<RemoteHandle<T>> {
85        self.handle.take()
86    }
87}
88
89impl<T> Drop for JoinHandle<T> {
90    /// To match Tokio behavior and make it easier to handle throwaway tasks,
91    /// if a JoinHandle is dropped without the inner RemoteHandle being taken,
92    /// we simply forget it so that it's able to continue to completion.
93    fn drop(&mut self) {
94        if let Some(handle) = self.take() {
95            handle.forget();
96        }
97    }
98}
99
100impl WASMTasksRuntime {
101    fn new(ticks: Arc<AtomicUsize>, update_watch_rx: tokio::sync::watch::Receiver<()>) -> Self {
102        let (update_run_tx, update_run_rx) = tokio::sync::mpsc::unbounded_channel();
103
104        Self(Box::new(WASMTasksRuntimeInner {
105            ticks,
106            update_watch_rx,
107            update_run_tx,
108            update_run_rx,
109        }))
110    }
111
112    /// Spawn a task which will run using WASM futures. The background task is provided a
113    /// [`TaskContext`] which allows it to do things like [sleep for a given number of main thread updates](TaskContext::sleep_updates)
114    /// or [invoke callbacks on the main Bevy thread](TaskContext::run_on_main_thread).
115    pub fn spawn_background_task<Task, Output, Spawnable>(&self, spawnable_task: Spawnable) -> JoinHandle<Output>
116    where
117        Task: Future<Output = Output> + 'static,
118        Spawnable: FnOnce(TaskContext) -> Task + 'static,
119    {
120        let inner = &self.0;
121        let context = TaskContext {
122            update_watch_rx: inner.update_watch_rx.clone(),
123            ticks: inner.ticks.clone(),
124            update_run_tx: inner.update_run_tx.clone(),
125        };
126        let future = spawnable_task(context);
127        let (future, handle) = future.remote_handle();
128        wasm_bindgen_futures::spawn_local(future);
129        JoinHandle {
130            handle: Some(handle),
131        }
132    }
133
134    /// Execute all of the requested runnables on the main thread.
135    pub(crate) fn execute_main_thread_work(&mut self, world: &mut World, current_tick: usize) {
136        while let Ok(runnable) = self.0.update_run_rx.try_recv() {
137            let context = MainThreadContext {
138                world,
139                current_tick,
140            };
141            runnable(context);
142        }
143    }
144}
145
146/// The context arguments which are available to main thread callbacks requested using
147/// [`run_on_main_thread`](TaskContext::run_on_main_thread).
148pub struct MainThreadContext<'a> {
149    /// A mutable reference to the main Bevy [World].
150    pub world: &'a mut World,
151    /// The current update tick in which the current main thread callback is executing.
152    pub current_tick: usize,
153}
154
155/// The context arguments which are available to background tasks spawned onto the
156/// [`WASMTasksRuntime`].
157#[derive(Clone)]
158pub struct TaskContext {
159    update_watch_rx: tokio::sync::watch::Receiver<()>,
160    update_run_tx: tokio::sync::mpsc::UnboundedSender<MainThreadCallback>,
161    ticks: Arc<AtomicUsize>,
162}
163
164impl TaskContext {
165    /// Returns the current value of the ticket count from the main thread - how many updates
166    /// have occurred since the start of the program. Because the tick count is updated from the
167    /// main thread, the tick count may change any time after this function call returns.
168    pub fn current_tick(&self) -> usize {
169        self.ticks.load(Ordering::SeqCst)
170    }
171
172    /// Sleeps the background task until a given number of main thread updates have occurred. If
173    /// you instead want to sleep for a given length of wall-clock time, sleep using wasmtimer or similar
174    /// function.
175    pub async fn sleep_updates(&mut self, updates_to_sleep: usize) {
176        let target_tick = self
177            .ticks
178            .load(Ordering::SeqCst)
179            .wrapping_add(updates_to_sleep);
180        while self.ticks.load(Ordering::SeqCst) < target_tick {
181            if self.update_watch_rx.changed().await.is_err() {
182                return;
183            }
184        }
185    }
186
187    /// Invokes a synchronous callback on the main Bevy thread. The callback will have mutable access to the
188    /// main Bevy [`World`], allowing it to update any resources or entities that it wants. The callback can
189    /// report results back to the background thread by returning an output value, which will then be returned from
190    /// this async function once the callback runs.
191    pub async fn run_on_main_thread<Runnable, Output>(&mut self, runnable: Runnable) -> Output
192    where
193        Runnable: FnOnce(MainThreadContext) -> Output + Send + 'static,
194        Output: Send + 'static,
195    {
196        let (output_tx, output_rx) = tokio::sync::oneshot::channel();
197        if self.update_run_tx.send(Box::new(move |ctx| {
198            if output_tx.send(runnable(ctx)).is_err() {
199                panic!("Failed to sent output from operation run on main thread back to waiting task");
200            }
201        })).is_err() {
202            panic!("Failed to send operation to be run on main thread");
203        }
204        output_rx
205            .await
206            .expect("Failed to receive output from operation on main thread")
207    }
208}
209