bevy_tokio_tasks/
lib.rs

1use std::future::Future;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4
5use bevy_app::{App, Plugin, Update};
6use bevy_ecs::schedule::{InternedScheduleLabel, ScheduleLabel};
7use bevy_ecs::{prelude::World, resource::Resource};
8
9use tokio::{runtime::Runtime, task::JoinHandle};
10
11/// A re-export of the tokio version used by this crate.
12pub use tokio;
13
14/// An internal struct keeping track of how many ticks have elapsed since the start of the program.
15#[derive(Resource)]
16struct UpdateTicks {
17    ticks: Arc<AtomicUsize>,
18    update_watch_tx: tokio::sync::watch::Sender<()>,
19}
20
21impl UpdateTicks {
22    fn increment_ticks(&self) -> usize {
23        let new_ticks = self.ticks.fetch_add(1, Ordering::SeqCst).wrapping_add(1);
24        self.update_watch_tx
25            .send(())
26            .expect("Failed to send update_watch channel message");
27        new_ticks
28    }
29}
30
31/// The Bevy [`Plugin`] which sets up the [`TokioTasksRuntime`] Bevy resource and registers
32/// the [`tick_runtime_update`] exclusive system.
33pub struct TokioTasksPlugin {
34    /// Callback which is used to create a Tokio runtime when the plugin is installed. The
35    /// default value for this field configures a multi-threaded [`Runtime`] with IO and timer
36    /// functionality enabled if building for non-wasm32 architectures. On wasm32 the current-thread
37    /// scheduler is used instead.
38    pub make_runtime: Box<dyn Fn() -> Runtime + Send + Sync + 'static>,
39    /// The [`ScheduleLabel`] during which the [`tick_runtime_update`] function will be executed.
40    /// The default value for this field is [`Update`].
41    pub schedule_label: InternedScheduleLabel,
42}
43
44impl Default for TokioTasksPlugin {
45    /// Configures the plugin to build a new Tokio [`Runtime`] with both IO and timer functionality
46    /// enabled. On the wasm32 architecture, the [`Runtime`] will be the current-thread runtime, on all other
47    /// architectures the [`Runtime`] will be the multi-thread runtime.
48    /// 
49    /// The default schedule label is [`Update`].
50    fn default() -> Self {
51        Self {
52            make_runtime: Box::new(|| {
53                #[cfg(not(target_arch = "wasm32"))]
54                let mut runtime = tokio::runtime::Builder::new_multi_thread();
55                #[cfg(target_arch = "wasm32")]
56                let mut runtime = tokio::runtime::Builder::new_current_thread();
57                runtime.enable_all();
58                runtime
59                    .build()
60                    .expect("Failed to create Tokio runtime for background tasks")
61            }),
62            schedule_label: Update.intern()
63        }
64    }
65}
66
67impl Plugin for TokioTasksPlugin {
68    fn build(&self, app: &mut App) {
69        let ticks = Arc::new(AtomicUsize::new(0));
70        let (update_watch_tx, update_watch_rx) = tokio::sync::watch::channel(());
71        let runtime = (self.make_runtime)();
72        app.insert_resource(UpdateTicks {
73            ticks: ticks.clone(),
74            update_watch_tx,
75        });
76        app.insert_resource(TokioTasksRuntime::new(ticks, runtime, update_watch_rx));
77        app.add_systems(self.schedule_label, tick_runtime_update);
78    }
79}
80
81/// The Bevy exclusive system which executes the main thread callbacks that background
82/// tasks requested using [`run_on_main_thread`](TaskContext::run_on_main_thread). You
83/// can control which Bevy schedule stage this system executes in by specifying a custom
84/// [`schedule_label`](TokioTasksPlugin::schedule_label) value.
85pub fn tick_runtime_update(world: &mut World) {
86    let current_tick = {
87        let tick_counter = match world.get_resource::<UpdateTicks>() {
88            Some(counter) => counter,
89            None => return,
90        };
91
92        // Increment update ticks and notify watchers of update tick.
93        tick_counter.increment_ticks()
94    };
95
96    if let Some(mut runtime) = world.remove_resource::<TokioTasksRuntime>() {
97        runtime.execute_main_thread_work(world, current_tick);
98        world.insert_resource(runtime);
99    }
100}
101
102type MainThreadCallback = Box<dyn FnOnce(MainThreadContext) + Send + 'static>;
103
104/// The Bevy [`Resource`] which stores the Tokio [`Runtime`] and allows for spawning new
105/// background tasks.
106#[derive(Resource)]
107pub struct TokioTasksRuntime(Box<TokioTasksRuntimeInner>);
108
109/// The inner fields are boxed to reduce the cost of the every-frame move out of and back into
110/// the world in [`tick_runtime_update`].
111struct TokioTasksRuntimeInner {
112    runtime: Runtime,
113    ticks: Arc<AtomicUsize>,
114    update_watch_rx: tokio::sync::watch::Receiver<()>,
115    update_run_tx: tokio::sync::mpsc::UnboundedSender<MainThreadCallback>,
116    update_run_rx: tokio::sync::mpsc::UnboundedReceiver<MainThreadCallback>,
117}
118
119impl TokioTasksRuntime {
120    fn new(
121        ticks: Arc<AtomicUsize>,
122        runtime: Runtime,
123        update_watch_rx: tokio::sync::watch::Receiver<()>,
124    ) -> Self {
125        let (update_run_tx, update_run_rx) = tokio::sync::mpsc::unbounded_channel();
126
127        Self(Box::new(TokioTasksRuntimeInner {
128            runtime,
129            ticks,
130            update_watch_rx,
131            update_run_tx,
132            update_run_rx,
133        }))
134    }
135
136    /// Returns the Tokio [`Runtime`] on which background tasks are executed. You can specify
137    /// how this is created by providing a custom [`make_runtime`](TokioTasksPlugin::make_runtime).
138    pub fn runtime(&self) -> &Runtime {
139        &self.0.runtime
140    }
141
142    /// Spawn a task which will run on the background Tokio [`Runtime`] managed by this [`TokioTasksRuntime`]. The
143    /// background task is provided a [`TaskContext`] which allows it to do things like
144    /// [sleep for a given number of main thread updates](TaskContext::sleep_updates) or
145    /// [invoke callbacks on the main Bevy thread](TaskContext::run_on_main_thread).
146    pub fn spawn_background_task<Task, Output, Spawnable>(
147        &self,
148        spawnable_task: Spawnable,
149    ) -> JoinHandle<Output>
150    where
151        Task: Future<Output = Output> + Send + 'static,
152        Output: Send + 'static,
153        Spawnable: FnOnce(TaskContext) -> Task + Send + 'static,
154    {
155        let inner = &self.0;
156        let context = TaskContext {
157            update_watch_rx: inner.update_watch_rx.clone(),
158            ticks: inner.ticks.clone(),
159            update_run_tx: inner.update_run_tx.clone(),
160        };
161        let future = spawnable_task(context);
162        inner.runtime.spawn(future)
163    }
164
165    /// Execute all of the requested runnables on the main thread.
166    pub(crate) fn execute_main_thread_work(&mut self, world: &mut World, current_tick: usize) {
167        // Running this single future which yields once allows the runtime to process tasks
168        // if the runtime is a current_thread runtime. If its a multi-thread runtime then
169        // this isn't necessary but is harmless.
170        self.0.runtime.block_on(async {
171            tokio::task::yield_now().await;
172        });
173        while let Ok(runnable) = self.0.update_run_rx.try_recv() {
174            let context = MainThreadContext {
175                world,
176                current_tick,
177            };
178            runnable(context);
179        }
180    }
181}
182
183/// The context arguments which are available to main thread callbacks requested using
184/// [`run_on_main_thread`](TaskContext::run_on_main_thread).
185pub struct MainThreadContext<'a> {
186    /// A mutable reference to the main Bevy [World].
187    pub world: &'a mut World,
188    /// The current update tick in which the current main thread callback is executing.
189    pub current_tick: usize,
190}
191
192/// The context arguments which are available to background tasks spawned onto the
193/// [`TokioTasksRuntime`].
194#[derive(Clone)]
195pub struct TaskContext {
196    update_watch_rx: tokio::sync::watch::Receiver<()>,
197    update_run_tx: tokio::sync::mpsc::UnboundedSender<MainThreadCallback>,
198    ticks: Arc<AtomicUsize>,
199}
200
201impl TaskContext {
202    /// Returns the current value of the ticket count from the main thread - how many updates
203    /// have occurred since the start of the program. Because the tick count is updated from the
204    /// main thread, the tick count may change any time after this function call returns.
205    pub fn current_tick(&self) -> usize {
206        self.ticks.load(Ordering::SeqCst)
207    }
208
209    /// Sleeps the background task until a given number of main thread updates have occurred. If
210    /// you instead want to sleep for a given length of wall-clock time, call the normal Tokio sleep
211    /// function.
212    pub async fn sleep_updates(&mut self, updates_to_sleep: usize) {
213        let target_tick = self
214            .ticks
215            .load(Ordering::SeqCst)
216            .wrapping_add(updates_to_sleep);
217        while self.ticks.load(Ordering::SeqCst) < target_tick {
218            if self.update_watch_rx.changed().await.is_err() {
219                return;
220            }
221        }
222    }
223
224    /// Invokes a synchronous callback on the main Bevy thread. The callback will have mutable access to the
225    /// main Bevy [`World`], allowing it to update any resources or entities that it wants. The callback can
226    /// report results back to the background thread by returning an output value, which will then be returned from
227    /// this async function once the callback runs.
228    pub async fn run_on_main_thread<Runnable, Output>(&mut self, runnable: Runnable) -> Output
229    where
230        Runnable: FnOnce(MainThreadContext) -> Output + Send + 'static,
231        Output: Send + 'static,
232    {
233        let (output_tx, output_rx) = tokio::sync::oneshot::channel();
234        if self.update_run_tx.send(Box::new(move |ctx| {
235            if output_tx.send(runnable(ctx)).is_err() {
236                panic!("Failed to sent output from operation run on main thread back to waiting task");
237            }
238        })).is_err() {
239            panic!("Failed to send operation to be run on main thread");
240        }
241        output_rx
242            .await
243            .expect("Failed to receive output from operation on main thread")
244    }
245}