1use crate::util::remove_system;
2use crate::world::AsyncWorld;
3use crate::{die, recv};
4use bevy_ecs::prelude::*;
5use bevy_ecs::system::{BoxedSystem, SystemId};
6use bevy_platform::sync::Arc;
7use std::any::Any;
8use std::marker::PhantomData;
9
10type BoxedAnySend = Box<dyn Any + Send>;
11type SystemIdWithIO = SystemId<In<BoxedAnySend>, BoxedAnySend>;
12type BoxedSystemWithIO = BoxedSystem<In<BoxedAnySend>, BoxedAnySend>;
13
14#[derive(Debug, Clone)]
21pub struct AsyncSystem {
22 id: Arc<SystemId>,
23 world: AsyncWorld,
24}
25
26impl AsyncSystem {
27 pub(crate) async fn new(system: BoxedSystem, world: AsyncWorld) -> Self {
28 let (id_tx, id_rx) = async_channel::bounded(1);
29 world
30 .apply(move |world: &mut World| {
31 let id = world.register_boxed_system(system);
32 id_tx.try_send(id).unwrap_or_else(die);
33 })
34 .await;
35 let id = recv(id_rx).await;
36 let id = Arc::new(id);
37 Self { id, world }
38 }
39
40 pub async fn run(&self) {
42 let id = *self.id;
43 self.world
44 .apply(move |world: &mut World| {
45 world.run_system(id).unwrap_or_else(die);
46 })
47 .await;
48 }
49
50 pub async fn unregister(self) {
56 let Self { id, world } = self;
57 if let Some(id) = Arc::into_inner(id) {
58 world.apply(remove_system(id)).await;
59 }
60 }
61}
62
63#[derive(Debug)]
71pub struct AsyncIOSystem<I: Send, O: Send> {
72 id: Arc<SystemIdWithIO>,
73 world: AsyncWorld,
74 _pd: PhantomData<fn(I) -> O>,
75}
76
77impl<I: Send, O: Send> Clone for AsyncIOSystem<I, O> {
78 fn clone(&self) -> Self {
79 Self {
80 id: Arc::clone(&self.id),
81 world: self.world.clone(),
82 _pd: PhantomData,
83 }
84 }
85}
86
87impl<I: Send + 'static, O: Send + 'static> AsyncIOSystem<I, O> {
88 pub(crate) async fn new<M>(
89 system: impl IntoSystem<In<I>, O, M> + Send,
90 world: AsyncWorld,
91 ) -> Self {
92 fn unbox_input<I: Send + 'static>(In(boxed): In<BoxedAnySend>) -> I {
93 let concrete = boxed.downcast().unwrap_or_else(die);
94 *concrete
95 }
96
97 fn box_output<O: Send + 'static>(In(output): In<O>) -> BoxedAnySend {
98 Box::new(output)
99 }
100
101 let system = unbox_input.pipe(system).pipe(box_output);
102 let system: BoxedSystemWithIO = Box::new(IntoSystem::into_system(system));
103
104 let (id_tx, id_rx) = async_channel::bounded(1);
105 world
106 .apply(move |world: &mut World| {
107 let id = world.register_boxed_system(system);
108 id_tx.try_send(id).unwrap_or_else(die);
109 })
110 .await;
111
112 let id = recv(id_rx).await;
113 let id = Arc::new(id);
114
115 Self {
116 id,
117 world,
118 _pd: PhantomData,
119 }
120 }
121
122 pub async fn run(&self, input: I) -> O {
124 let (input_tx, input_rx) = async_channel::bounded(1);
125 let (output_tx, output_rx) = async_channel::bounded(1);
126
127 let input: BoxedAnySend = Box::new(input);
128 input_tx.send(input).await.unwrap_or_else(die);
129
130 let id = *self.id;
131 self.world
132 .apply(move |world: &mut World| {
133 let input = input_rx.try_recv().unwrap_or_else(die);
134 let output = world.run_system_with(id, input).unwrap_or_else(die);
135 output_tx.try_send(output).unwrap_or_else(die);
136 })
137 .await;
138
139 let boxed: BoxedAnySend = recv(output_rx).await;
140 let concrete = boxed.downcast().unwrap_or_else(die);
141 *concrete
142 }
143
144 pub async fn unregister(self) {
150 let Self { id, world, _pd } = self;
151 if let Some(id) = Arc::into_inner(id) {
152 world.apply(remove_system(id)).await
153 }
154 }
155}
156
#[cfg(test)]
mod tests {
    use crate::world::AsyncWorld;
    use crate::AsyncEcsPlugin;
    use bevy::ecs::system::RegisteredSystemError;
    use bevy::prelude::*;
    use bevy::tasks::AsyncComputeTaskPool;

    #[derive(Component)]
    struct Counter(u8);

    impl Counter {
        fn go_up(&mut self) {
            self.0 += 1;
        }
    }

    macro_rules! assert_counter {
        ($id:expr, $value:expr, $world:expr) => {
            assert_eq!($value, $world.entity($id).get::<Counter>().unwrap().0);
        };
    }

    /// Calls `app.update()` until a message arrives on `receiver`, then runs one extra
    /// `app.update()` and returns the message.
    ///
    /// # Panics
    ///
    /// Panics if the channel closes without delivering a message, e.g. because the spawned
    /// task ended early and dropped its sender. Treating "closed" like "empty" would make the
    /// test spin forever instead of failing.
    fn wait_for<T>(app: &mut App, receiver: &async_channel::Receiver<T>) -> T {
        let value = loop {
            match receiver.try_recv() {
                Ok(value) => break value,
                Err(err) if err.is_closed() => {
                    panic!("async task dropped its sender without sending a result")
                }
                Err(_) => app.update(),
            }
        };
        app.update();
        value
    }

    fn increase_counter_all(mut query: Query<&mut Counter>) {
        for mut counter in query.iter_mut() {
            counter.go_up();
        }
    }

    fn increase_counter(In(id): In<Entity>, mut query: Query<&mut Counter>) {
        let mut counter = query.get_mut(id).unwrap();
        counter.go_up();
    }

    fn get_counter_value(In(id): In<Entity>, query: Query<&Counter>) -> u8 {
        query.get(id).unwrap().0
    }

    #[test]
    fn smoke() {
        let mut app = App::new();
        app.add_plugins((MinimalPlugins, AsyncEcsPlugin));
        let id = app.world_mut().spawn(Counter(0)).id();
        assert_counter!(id, 0, app.world_mut());

        let (barrier_tx, barrier_rx) = async_channel::bounded(1);
        let async_world = AsyncWorld::from_world(app.world_mut());

        AsyncComputeTaskPool::get()
            .spawn(async move {
                let increase_counter_all = async_world.register_system(increase_counter_all).await;
                increase_counter_all.run().await;
                barrier_tx.send(()).await.unwrap();
            })
            .detach();

        wait_for(&mut app, &barrier_rx);

        assert_counter!(id, 1, app.world_mut());
    }

    #[test]
    fn normal_unregister() {
        let mut app = App::new();
        app.add_plugins((MinimalPlugins, AsyncEcsPlugin));
        let id = app.world_mut().spawn(Counter(0)).id();
        assert_counter!(id, 0, app.world_mut());

        let (sender, receiver) = async_channel::bounded(1);
        let async_world = AsyncWorld::from_world(app.world_mut());

        AsyncComputeTaskPool::get()
            .spawn(async move {
                let increase_counter_all = async_world.register_system(increase_counter_all).await;
                let ica2 = increase_counter_all.clone();
                increase_counter_all.unregister().await;

                ica2.run().await;

                let id = *ica2.id;
                ica2.unregister().await;
                sender.send(id).await.unwrap();
            })
            .detach();

        let system_id = wait_for(&mut app, &receiver);

        let err = app.world_mut().unregister_system(system_id);
        assert_counter!(id, 1, app.world_mut());
        assert!(matches!(
            err,
            Err(RegisteredSystemError::SystemIdNotRegistered(_))
        ));
    }

    #[test]
    fn io() {
        let mut app = App::new();
        app.add_plugins((MinimalPlugins, AsyncEcsPlugin));
        let id = app.world_mut().spawn(Counter(4)).id();
        assert_counter!(id, 4, app.world_mut());

        let (sender, receiver) = async_channel::bounded(1);
        let async_world = AsyncWorld::from_world(app.world_mut());

        AsyncComputeTaskPool::get()
            .spawn(async move {
                let increase_counter = async_world.register_io_system(increase_counter).await;
                let get_counter_value = async_world.register_io_system(get_counter_value).await;

                increase_counter.run(id).await;
                let value = get_counter_value.run(id).await;
                sender.send(value).await.unwrap();
            })
            .detach();

        let value = wait_for(&mut app, &receiver);

        assert_eq!(5, value);
        assert_counter!(id, 5, app.world_mut());
    }

    #[test]
    fn io_unregister() {
        let mut app = App::new();
        app.add_plugins((MinimalPlugins, AsyncEcsPlugin));
        let id = app.world_mut().spawn(Counter(4)).id();
        assert_counter!(id, 4, app.world_mut());

        let (sender, receiver) = async_channel::bounded(1);
        let async_world = AsyncWorld::from_world(app.world_mut());

        AsyncComputeTaskPool::get()
            .spawn(async move {
                let increase_counter = async_world.register_io_system(increase_counter).await;
                let get_counter_value = async_world.register_io_system(get_counter_value).await;

                let gcv2 = get_counter_value.clone();
                get_counter_value.unregister().await;

                increase_counter.run(id).await;
                let value = gcv2.run(id).await;

                // Unregister before reporting back, as `normal_unregister` does. The removal
                // then goes through `world.apply` before the main thread receives the id and
                // runs its final update. Sending first raced with that update.
                let system_id = *gcv2.id;
                gcv2.unregister().await;
                sender.send((value, system_id)).await.unwrap();
            })
            .detach();

        let (value, system_id) = wait_for(&mut app, &receiver);

        let err = app.world_mut().unregister_system(system_id);
        assert_eq!(5, value);
        assert_counter!(id, 5, app.world_mut());
        assert!(matches!(
            err,
            Err(RegisteredSystemError::SystemIdNotRegistered(_))
        ));
    }
}