1use std::future::Future;
2use std::sync::atomic::{AtomicUsize, Ordering};
3use std::sync::Arc;
4
5use bevy_app::{App, Plugin, Update};
6use bevy_ecs::schedule::{InternedScheduleLabel, ScheduleLabel};
7use bevy_ecs::{prelude::World, resource::Resource};
8
9use tokio::{runtime::Runtime, task::JoinHandle};
10
11pub use tokio;
13
14#[derive(Resource)]
16struct UpdateTicks {
17 ticks: Arc<AtomicUsize>,
18 update_watch_tx: tokio::sync::watch::Sender<()>,
19}
20
21impl UpdateTicks {
22 fn increment_ticks(&self) -> usize {
23 let new_ticks = self.ticks.fetch_add(1, Ordering::SeqCst).wrapping_add(1);
24 self.update_watch_tx
25 .send(())
26 .expect("Failed to send update_watch channel message");
27 new_ticks
28 }
29}
30
31pub struct TokioTasksPlugin {
34 pub make_runtime: Box<dyn Fn() -> Runtime + Send + Sync + 'static>,
39 pub schedule_label: InternedScheduleLabel,
42}
43
44impl Default for TokioTasksPlugin {
45 fn default() -> Self {
51 Self {
52 make_runtime: Box::new(|| {
53 #[cfg(not(target_arch = "wasm32"))]
54 let mut runtime = tokio::runtime::Builder::new_multi_thread();
55 #[cfg(target_arch = "wasm32")]
56 let mut runtime = tokio::runtime::Builder::new_current_thread();
57 runtime.enable_all();
58 runtime
59 .build()
60 .expect("Failed to create Tokio runtime for background tasks")
61 }),
62 schedule_label: Update.intern()
63 }
64 }
65}
66
67impl Plugin for TokioTasksPlugin {
68 fn build(&self, app: &mut App) {
69 let ticks = Arc::new(AtomicUsize::new(0));
70 let (update_watch_tx, update_watch_rx) = tokio::sync::watch::channel(());
71 let runtime = (self.make_runtime)();
72 app.insert_resource(UpdateTicks {
73 ticks: ticks.clone(),
74 update_watch_tx,
75 });
76 app.insert_resource(TokioTasksRuntime::new(ticks, runtime, update_watch_rx));
77 app.add_systems(self.schedule_label, tick_runtime_update);
78 }
79}
80
81pub fn tick_runtime_update(world: &mut World) {
86 let current_tick = {
87 let tick_counter = match world.get_resource::<UpdateTicks>() {
88 Some(counter) => counter,
89 None => return,
90 };
91
92 tick_counter.increment_ticks()
94 };
95
96 if let Some(mut runtime) = world.remove_resource::<TokioTasksRuntime>() {
97 runtime.execute_main_thread_work(world, current_tick);
98 world.insert_resource(runtime);
99 }
100}
101
102type MainThreadCallback = Box<dyn FnOnce(MainThreadContext) + Send + 'static>;
103
104#[derive(Resource)]
107pub struct TokioTasksRuntime(Box<TokioTasksRuntimeInner>);
108
109struct TokioTasksRuntimeInner {
112 runtime: Runtime,
113 ticks: Arc<AtomicUsize>,
114 update_watch_rx: tokio::sync::watch::Receiver<()>,
115 update_run_tx: tokio::sync::mpsc::UnboundedSender<MainThreadCallback>,
116 update_run_rx: tokio::sync::mpsc::UnboundedReceiver<MainThreadCallback>,
117}
118
119impl TokioTasksRuntime {
120 fn new(
121 ticks: Arc<AtomicUsize>,
122 runtime: Runtime,
123 update_watch_rx: tokio::sync::watch::Receiver<()>,
124 ) -> Self {
125 let (update_run_tx, update_run_rx) = tokio::sync::mpsc::unbounded_channel();
126
127 Self(Box::new(TokioTasksRuntimeInner {
128 runtime,
129 ticks,
130 update_watch_rx,
131 update_run_tx,
132 update_run_rx,
133 }))
134 }
135
136 pub fn runtime(&self) -> &Runtime {
139 &self.0.runtime
140 }
141
142 pub fn spawn_background_task<Task, Output, Spawnable>(
147 &self,
148 spawnable_task: Spawnable,
149 ) -> JoinHandle<Output>
150 where
151 Task: Future<Output = Output> + Send + 'static,
152 Output: Send + 'static,
153 Spawnable: FnOnce(TaskContext) -> Task + Send + 'static,
154 {
155 let inner = &self.0;
156 let context = TaskContext {
157 update_watch_rx: inner.update_watch_rx.clone(),
158 ticks: inner.ticks.clone(),
159 update_run_tx: inner.update_run_tx.clone(),
160 };
161 let future = spawnable_task(context);
162 inner.runtime.spawn(future)
163 }
164
165 pub(crate) fn execute_main_thread_work(&mut self, world: &mut World, current_tick: usize) {
167 self.0.runtime.block_on(async {
171 tokio::task::yield_now().await;
172 });
173 while let Ok(runnable) = self.0.update_run_rx.try_recv() {
174 let context = MainThreadContext {
175 world,
176 current_tick,
177 };
178 runnable(context);
179 }
180 }
181}
182
183pub struct MainThreadContext<'a> {
186 pub world: &'a mut World,
188 pub current_tick: usize,
190}
191
192#[derive(Clone)]
195pub struct TaskContext {
196 update_watch_rx: tokio::sync::watch::Receiver<()>,
197 update_run_tx: tokio::sync::mpsc::UnboundedSender<MainThreadCallback>,
198 ticks: Arc<AtomicUsize>,
199}
200
201impl TaskContext {
202 pub fn current_tick(&self) -> usize {
206 self.ticks.load(Ordering::SeqCst)
207 }
208
209 pub async fn sleep_updates(&mut self, updates_to_sleep: usize) {
213 let target_tick = self
214 .ticks
215 .load(Ordering::SeqCst)
216 .wrapping_add(updates_to_sleep);
217 while self.ticks.load(Ordering::SeqCst) < target_tick {
218 if self.update_watch_rx.changed().await.is_err() {
219 return;
220 }
221 }
222 }
223
224 pub async fn run_on_main_thread<Runnable, Output>(&mut self, runnable: Runnable) -> Output
229 where
230 Runnable: FnOnce(MainThreadContext) -> Output + Send + 'static,
231 Output: Send + 'static,
232 {
233 let (output_tx, output_rx) = tokio::sync::oneshot::channel();
234 if self.update_run_tx.send(Box::new(move |ctx| {
235 if output_tx.send(runnable(ctx)).is_err() {
236 panic!("Failed to sent output from operation run on main thread back to waiting task");
237 }
238 })).is_err() {
239 panic!("Failed to send operation to be run on main thread");
240 }
241 output_rx
242 .await
243 .expect("Failed to receive output from operation on main thread")
244 }
245}