1use crate::die;
2use crate::recv;
3use crate::util::remove_system;
4use crate::world::AsyncWorld;
5use bevy_ecs::prelude::*;
6use bevy_ecs::system::BoxedSystem;
7use bevy_ecs::system::SystemId;
8use bevy_platform::sync::Arc;
9use std::any::Any;
10use std::marker::PhantomData;
11
/// Type-erased value that can be sent to another thread. [`AsyncIOSystem`] boxes
/// its typed input and output into this so the registered system has a single,
/// non-generic signature.
type BoxedAnySend = Box<dyn Any + Send>;
/// ID of a registered system whose input and output are both [`BoxedAnySend`].
type SystemIdWithIO = SystemId<In<BoxedAnySend>, BoxedAnySend>;
/// Boxed system that takes and returns [`BoxedAnySend`]; this is the form in
/// which [`AsyncIOSystem`] registers its piped system.
type BoxedSystemWithIO = BoxedSystem<In<BoxedAnySend>, BoxedAnySend>;
15
/// Handle to a system (no input, no output) registered in the world behind an
/// [`AsyncWorld`], usable from async code.
///
/// Cloning is cheap and every clone refers to the same registered system: the
/// [`SystemId`] is shared behind an [`Arc`], and [`AsyncSystem::unregister`] only
/// removes the system when it is called on the last handle still alive.
#[derive(Debug, Clone)]
pub struct AsyncSystem {
    /// Shared so each clone can tell whether it is the last owner when unregistering.
    id: Arc<SystemId>,
    /// Handle through which work is scheduled on the ECS world.
    world: AsyncWorld,
}
27
28impl AsyncSystem {
29 pub(crate) async fn new(system: BoxedSystem, world: AsyncWorld) -> Self {
30 let (id_tx, id_rx) = async_channel::bounded(1);
31 world
32 .apply(move |world: &mut World| {
33 let id = world.register_boxed_system(system);
34 id_tx.try_send(id).unwrap_or_else(die);
35 })
36 .await;
37 let id = recv(id_rx).await;
38 let id = Arc::new(id);
39 Self { id, world }
40 }
41
42 pub async fn run(&self) {
44 let id = *self.id;
45 self.world
46 .apply(move |world: &mut World| {
47 world.run_system(id).unwrap_or_else(die);
48 })
49 .await;
50 }
51
52 pub async fn unregister(self) {
58 let Self { id, world } = self;
59 if let Some(id) = Arc::into_inner(id) {
60 world.apply(remove_system(id)).await;
61 }
62 }
63}
64
65#[derive(Debug)]
73pub struct AsyncIOSystem<I: Send, O: Send> {
74 id: Arc<SystemIdWithIO>,
75 world: AsyncWorld,
76 _pd: PhantomData<fn(I) -> O>,
77}
78
79impl<I: Send, O: Send> Clone for AsyncIOSystem<I, O> {
80 fn clone(&self) -> Self {
81 Self {
82 id: Arc::clone(&self.id),
83 world: self.world.clone(),
84 _pd: PhantomData,
85 }
86 }
87}
88
89impl<I: Send + 'static, O: Send + 'static> AsyncIOSystem<I, O> {
90 pub(crate) async fn new<M>(
91 system: impl IntoSystem<In<I>, O, M> + Send,
92 world: AsyncWorld,
93 ) -> Self {
94 fn unbox_input<I: Send + 'static>(In(boxed): In<BoxedAnySend>) -> I {
95 let concrete = boxed.downcast().unwrap_or_else(die);
96 *concrete
97 }
98
99 fn box_output<O: Send + 'static>(In(output): In<O>) -> BoxedAnySend {
100 Box::new(output)
101 }
102
103 let system = unbox_input::<I>.pipe(system).pipe(box_output);
104 let system: BoxedSystemWithIO = Box::new(IntoSystem::into_system(system));
105
106 let (id_tx, id_rx) = async_channel::bounded(1);
107 world
108 .apply(move |world: &mut World| {
109 let id = world.register_boxed_system(system);
110 id_tx.try_send(id).unwrap_or_else(die);
111 })
112 .await;
113
114 let id = recv(id_rx).await;
115 let id = Arc::new(id);
116
117 Self {
118 id,
119 world,
120 _pd: PhantomData,
121 }
122 }
123
124 pub async fn run(&self, input: I) -> O {
126 let (input_tx, input_rx) = async_channel::bounded(1);
127 let (output_tx, output_rx) = async_channel::bounded(1);
128
129 let input: BoxedAnySend = Box::new(input);
130 input_tx.send(input).await.unwrap_or_else(die);
131
132 let id = *self.id;
133 self.world
134 .apply(move |world: &mut World| {
135 let input = input_rx.try_recv().unwrap_or_else(die);
136 let output = world.run_system_with(id, input).unwrap_or_else(die);
137 output_tx.try_send(output).unwrap_or_else(die);
138 })
139 .await;
140
141 let boxed: BoxedAnySend = recv(output_rx).await;
142 let concrete = boxed.downcast().unwrap_or_else(die);
143 *concrete
144 }
145
146 pub async fn unregister(self) {
152 let Self { id, world, _pd } = self;
153 if let Some(id) = Arc::into_inner(id) {
154 world.apply(remove_system(id)).await
155 }
156 }
157}
158
#[cfg(test)]
mod tests {
    use crate::AsyncEcsPlugin;
    use crate::world::AsyncWorld;
    use async_channel::{Receiver, TryRecvError};
    use bevy::ecs::system::RegisteredSystemError;
    use bevy::prelude::*;
    use bevy::tasks::AsyncComputeTaskPool;

    #[derive(Component)]
    struct Counter(u8);

    impl Counter {
        fn go_up(&mut self) {
            self.0 += 1;
        }
    }

    /// Asserts that entity `$id` in `$world` has a `Counter` equal to `$value`.
    macro_rules! assert_counter {
        ($id:expr, $value:expr, $world:expr) => {
            assert_eq!($value, $world.entity($id).get::<Counter>().unwrap().0);
        };
    }

    /// Calls `app.update()` until `receiver` yields a value, then returns it.
    ///
    /// Panics if the channel is closed without a value. That means the spawned
    /// task dropped its sender, most likely because it panicked. Without this
    /// check the loop would spin forever and hang the test instead of failing it.
    fn update_until_received<T>(app: &mut App, receiver: &Receiver<T>) -> T {
        loop {
            match receiver.try_recv() {
                Ok(value) => return value,
                Err(TryRecvError::Empty) => app.update(),
                Err(TryRecvError::Closed) => {
                    panic!("async task dropped its sender without sending (did it panic?)")
                }
            }
        }
    }

    fn increase_counter_all(mut query: Query<&mut Counter>) {
        for mut counter in query.iter_mut() {
            counter.go_up();
        }
    }

    fn increase_counter(In(id): In<Entity>, mut query: Query<&mut Counter>) {
        let mut counter = query.get_mut(id).unwrap();
        counter.go_up();
    }

    fn get_counter_value(In(id): In<Entity>, query: Query<&Counter>) -> u8 {
        query.get(id).unwrap().0
    }

    #[test]
    fn smoke() {
        let mut app = App::new();
        app.add_plugins((MinimalPlugins, AsyncEcsPlugin));
        let id = app.world_mut().spawn(Counter(0)).id();
        assert_counter!(id, 0, app.world_mut());

        let (barrier_tx, barrier_rx) = async_channel::bounded(1);
        let async_world = AsyncWorld::from_world(app.world_mut());

        AsyncComputeTaskPool::get()
            .spawn(async move {
                let increase_counter_all = async_world.register_system(increase_counter_all).await;
                increase_counter_all.run().await;
                barrier_tx.send(()).await.unwrap();
            })
            .detach();

        update_until_received(&mut app, &barrier_rx);
        // One more update to apply whatever the task queued before signalling.
        app.update();

        assert_counter!(id, 1, app.world_mut());
    }

    #[test]
    fn normal_unregister() {
        let mut app = App::new();
        app.add_plugins((MinimalPlugins, AsyncEcsPlugin));
        let id = app.world_mut().spawn(Counter(0)).id();
        assert_counter!(id, 0, app.world_mut());

        let (sender, receiver) = async_channel::bounded(1);
        let async_world = AsyncWorld::from_world(app.world_mut());

        AsyncComputeTaskPool::get()
            .spawn(async move {
                let increase_counter_all = async_world.register_system(increase_counter_all).await;
                let ica2 = increase_counter_all.clone();
                increase_counter_all.unregister().await;

                ica2.run().await;

                let id = *ica2.id;
                ica2.unregister().await;
                sender.send(id).await.unwrap();
            })
            .detach();

        let system_id = update_until_received(&mut app, &receiver);
        // One more update to apply the removal queued before signalling.
        app.update();

        let err = app.world_mut().unregister_system(system_id);
        assert_counter!(id, 1, app.world_mut());
        assert!(matches!(
            err,
            Err(RegisteredSystemError::SystemIdNotRegistered(_))
        ));
    }

    #[test]
    fn io() {
        let mut app = App::new();
        app.add_plugins((MinimalPlugins, AsyncEcsPlugin));
        let id = app.world_mut().spawn(Counter(4)).id();
        assert_counter!(id, 4, app.world_mut());

        let (sender, receiver) = async_channel::bounded(1);
        let async_world = AsyncWorld::from_world(app.world_mut());

        AsyncComputeTaskPool::get()
            .spawn(async move {
                let increase_counter = async_world.register_io_system(increase_counter).await;
                let get_counter_value = async_world.register_io_system(get_counter_value).await;

                increase_counter.run(id).await;
                let value = get_counter_value.run(id).await;
                sender.send(value).await.unwrap();
            })
            .detach();

        let value = update_until_received(&mut app, &receiver);
        app.update();

        assert_eq!(5, value);
        assert_counter!(id, 5, app.world_mut());
    }

    #[test]
    fn io_unregister() {
        let mut app = App::new();
        app.add_plugins((MinimalPlugins, AsyncEcsPlugin));
        let id = app.world_mut().spawn(Counter(4)).id();
        assert_counter!(id, 4, app.world_mut());

        let (sender, receiver) = async_channel::bounded(1);
        let async_world = AsyncWorld::from_world(app.world_mut());

        AsyncComputeTaskPool::get()
            .spawn(async move {
                let increase_counter = async_world.register_io_system(increase_counter).await;
                let get_counter_value = async_world.register_io_system(get_counter_value).await;

                let gcv2 = get_counter_value.clone();
                get_counter_value.unregister().await;

                increase_counter.run(id).await;
                let value = gcv2.run(id).await;

                // Unregister *before* signalling, as `normal_unregister` does: the
                // main thread performs only one more `app.update()` after receiving,
                // so the removal must already be queued by then or the test races.
                let gcv_id = *gcv2.id;
                gcv2.unregister().await;
                sender.send((value, gcv_id)).await.unwrap();
            })
            .detach();

        let (value, system_id) = update_until_received(&mut app, &receiver);
        // One more update to apply the removal queued before signalling.
        app.update();

        let err = app.world_mut().unregister_system(system_id);
        assert_eq!(5, value);
        assert_counter!(id, 5, app.world_mut());
        assert!(matches!(
            err,
            Err(RegisteredSystemError::SystemIdNotRegistered(_))
        ));
    }
}