bevy_async_ecs/
system.rs

1use crate::die;
2use crate::recv;
3use crate::util::remove_system;
4use crate::world::AsyncWorld;
5use bevy_ecs::prelude::*;
6use bevy_ecs::system::BoxedSystem;
7use bevy_ecs::system::SystemId;
8use bevy_platform::sync::Arc;
9use std::any::Any;
10use std::marker::PhantomData;
11
/// Type-erased, sendable value used to carry IO-system inputs and outputs
/// between async code and the `World`.
type BoxedAnySend = Box<dyn Any + Send>;
/// Id of a registered system whose input and output are both [`BoxedAnySend`].
type SystemIdWithIO = SystemId<In<BoxedAnySend>, BoxedAnySend>;
/// Boxed system with the same type-erased input/output as [`SystemIdWithIO`].
type BoxedSystemWithIO = BoxedSystem<In<BoxedAnySend>, BoxedAnySend>;
15
16/// Represents a registered `System` that can be run asynchronously.
17///
18/// Dropping an `AsyncSystem` will not unregister it. Use `AsyncSystem::unregister()`
19/// to clean up an `AsyncSystem` from the main bevy `World`.
20///
21/// The easiest way to get an `AsyncSystem` is with `AsyncWorld::register_system()`.
22#[derive(Debug, Clone)]
23pub struct AsyncSystem {
24	id: Arc<SystemId>,
25	world: AsyncWorld,
26}
27
28impl AsyncSystem {
29	pub(crate) async fn new(system: BoxedSystem, world: AsyncWorld) -> Self {
30		let (id_tx, id_rx) = async_channel::bounded(1);
31		world
32			.apply(move |world: &mut World| {
33				let id = world.register_boxed_system(system);
34				id_tx.try_send(id).unwrap_or_else(die);
35			})
36			.await;
37		let id = recv(id_rx).await;
38		let id = Arc::new(id);
39		Self { id, world }
40	}
41
42	/// Run the system.
43	pub async fn run(&self) {
44		let id = *self.id;
45		self.world
46			.apply(move |world: &mut World| {
47				world.run_system(id).unwrap_or_else(die);
48			})
49			.await;
50	}
51
52	/// Unregister the system.
53	///
54	/// If multiple clones of the AsyncSystem exist, a reference counter will be
55	/// decremented instead. The system will be unregistered when the counter
56	/// decrements to zero.
57	pub async fn unregister(self) {
58		let Self { id, world } = self;
59		if let Some(id) = Arc::into_inner(id) {
60			world.apply(remove_system(id)).await;
61		}
62	}
63}
64
65/// Represents a registered `System` that accepts input and returns output, and can be run
66/// asynchronously.
67///
68/// Dropping an `AsyncIOSystem` will not unregister it. Use `AsyncSystemIO::unregister()`
69/// to clean up an `AsyncSystemIO` from the main bevy `World`.
70///
71/// The easiest way to get an `AsyncIOSystem` is with `AsyncWorld::register_io_system()`.
72#[derive(Debug)]
73pub struct AsyncIOSystem<I: Send, O: Send> {
74	id: Arc<SystemIdWithIO>,
75	world: AsyncWorld,
76	_pd: PhantomData<fn(I) -> O>,
77}
78
79impl<I: Send, O: Send> Clone for AsyncIOSystem<I, O> {
80	fn clone(&self) -> Self {
81		Self {
82			id: Arc::clone(&self.id),
83			world: self.world.clone(),
84			_pd: PhantomData,
85		}
86	}
87}
88
89impl<I: Send + 'static, O: Send + 'static> AsyncIOSystem<I, O> {
90	pub(crate) async fn new<M>(
91		system: impl IntoSystem<In<I>, O, M> + Send,
92		world: AsyncWorld,
93	) -> Self {
94		fn unbox_input<I: Send + 'static>(In(boxed): In<BoxedAnySend>) -> I {
95			let concrete = boxed.downcast().unwrap_or_else(die);
96			*concrete
97		}
98
99		fn box_output<O: Send + 'static>(In(output): In<O>) -> BoxedAnySend {
100			Box::new(output)
101		}
102
103		let system = unbox_input::<I>.pipe(system).pipe(box_output);
104		let system: BoxedSystemWithIO = Box::new(IntoSystem::into_system(system));
105
106		let (id_tx, id_rx) = async_channel::bounded(1);
107		world
108			.apply(move |world: &mut World| {
109				let id = world.register_boxed_system(system);
110				id_tx.try_send(id).unwrap_or_else(die);
111			})
112			.await;
113
114		let id = recv(id_rx).await;
115		let id = Arc::new(id);
116
117		Self {
118			id,
119			world,
120			_pd: PhantomData,
121		}
122	}
123
124	/// Run the system.
125	pub async fn run(&self, input: I) -> O {
126		let (input_tx, input_rx) = async_channel::bounded(1);
127		let (output_tx, output_rx) = async_channel::bounded(1);
128
129		let input: BoxedAnySend = Box::new(input);
130		input_tx.send(input).await.unwrap_or_else(die);
131
132		let id = *self.id;
133		self.world
134			.apply(move |world: &mut World| {
135				let input = input_rx.try_recv().unwrap_or_else(die);
136				let output = world.run_system_with(id, input).unwrap_or_else(die);
137				output_tx.try_send(output).unwrap_or_else(die);
138			})
139			.await;
140
141		let boxed: BoxedAnySend = recv(output_rx).await;
142		let concrete = boxed.downcast().unwrap_or_else(die);
143		*concrete
144	}
145
146	/// Unregister the system.
147	///
148	/// If multiple clones of the AsyncIOSystem exist, a reference counter will be
149	/// decremented instead. The system will be unregistered when the counter
150	/// decrements to zero.
151	pub async fn unregister(self) {
152		let Self { id, world, _pd } = self;
153		if let Some(id) = Arc::into_inner(id) {
154			world.apply(remove_system(id)).await
155		}
156	}
157}
158
#[cfg(test)]
mod tests {
	use crate::AsyncEcsPlugin;
	use crate::world::AsyncWorld;
	use bevy::ecs::system::RegisteredSystemError;
	use bevy::prelude::*;
	use bevy::tasks::AsyncComputeTaskPool;

	#[derive(Component)]
	struct Counter(u8);

	impl Counter {
		fn go_up(&mut self) {
			self.0 += 1;
		}
	}

	// Asserts that entity `$id` in `$world` has a `Counter` holding `$value`.
	macro_rules! assert_counter {
		($id:expr, $value:expr, $world:expr) => {
			assert_eq!($value, $world.entity($id).get::<Counter>().unwrap().0);
		};
	}

	/// Builds an app with the async ECS plugin and one entity holding `Counter(initial)`.
	fn setup(initial: u8) -> (App, Entity) {
		let mut app = App::new();
		app.add_plugins((MinimalPlugins, AsyncEcsPlugin));
		let id = app.world_mut().spawn(Counter(initial)).id();
		assert_counter!(id, initial, app.world_mut());
		(app, id)
	}

	/// Repeatedly calls `app.update()` until `receiver` yields a value, then returns it.
	///
	/// Panics if the channel is closed before any value arrives (e.g. the spawned
	/// task panicked and dropped its sender), so a failing task fails the test
	/// instead of leaving it spinning forever.
	fn update_until_received<T>(app: &mut App, receiver: &async_channel::Receiver<T>) -> T {
		loop {
			match receiver.try_recv() {
				Ok(value) => return value,
				Err(async_channel::TryRecvError::Empty) => app.update(),
				Err(async_channel::TryRecvError::Closed) => {
					panic!("async task ended without sending a result")
				}
			}
		}
	}

	fn increase_counter_all(mut query: Query<&mut Counter>) {
		for mut counter in query.iter_mut() {
			counter.go_up();
		}
	}

	fn increase_counter(In(id): In<Entity>, mut query: Query<&mut Counter>) {
		let mut counter = query.get_mut(id).unwrap();
		counter.go_up();
	}

	fn get_counter_value(In(id): In<Entity>, query: Query<&Counter>) -> u8 {
		query.get(id).unwrap().0
	}

	#[test]
	fn smoke() {
		let (mut app, id) = setup(0);

		let (barrier_tx, barrier_rx) = async_channel::bounded(1);
		let async_world = AsyncWorld::from_world(app.world_mut());

		AsyncComputeTaskPool::get()
			.spawn(async move {
				let increase_counter_all = async_world.register_system(increase_counter_all).await;
				increase_counter_all.run().await;
				barrier_tx.send(()).await.unwrap();
			})
			.detach();

		update_until_received(&mut app, &barrier_rx);
		app.update();

		assert_counter!(id, 1, app.world_mut());
	}

	#[test]
	fn normal_unregister() {
		let (mut app, id) = setup(0);

		let (sender, receiver) = async_channel::bounded(1);
		let async_world = AsyncWorld::from_world(app.world_mut());

		AsyncComputeTaskPool::get()
			.spawn(async move {
				let increase_counter_all = async_world.register_system(increase_counter_all).await;
				let ica2 = increase_counter_all.clone();
				// `ica2` still shares the id, so this only decrements the count.
				increase_counter_all.unregister().await;

				ica2.run().await;

				let id = *ica2.id;
				// Last handle: this call actually removes the system.
				ica2.unregister().await;
				sender.send(id).await.unwrap();
			})
			.detach();

		let system_id = update_until_received(&mut app, &receiver);
		app.update();

		let err = app.world_mut().unregister_system(system_id);
		assert_counter!(id, 1, app.world_mut());
		assert!(matches!(
			err,
			Err(RegisteredSystemError::SystemIdNotRegistered(_))
		));
	}

	#[test]
	fn io() {
		let (mut app, id) = setup(4);

		let (sender, receiver) = async_channel::bounded(1);
		let async_world = AsyncWorld::from_world(app.world_mut());

		AsyncComputeTaskPool::get()
			.spawn(async move {
				let increase_counter = async_world.register_io_system(increase_counter).await;
				let get_counter_value = async_world.register_io_system(get_counter_value).await;

				increase_counter.run(id).await;
				let value = get_counter_value.run(id).await;
				sender.send(value).await.unwrap();
			})
			.detach();

		let value = update_until_received(&mut app, &receiver);
		app.update();

		assert_eq!(5, value);
		assert_counter!(id, 5, app.world_mut());
	}

	#[test]
	fn io_unregister() {
		let (mut app, id) = setup(4);

		let (sender, receiver) = async_channel::bounded(1);
		let async_world = AsyncWorld::from_world(app.world_mut());

		AsyncComputeTaskPool::get()
			.spawn(async move {
				let increase_counter = async_world.register_io_system(increase_counter).await;
				let get_counter_value = async_world.register_io_system(get_counter_value).await;

				let gcv2 = get_counter_value.clone();
				// `gcv2` still shares the id, so this only decrements the count.
				get_counter_value.unregister().await;

				increase_counter.run(id).await;
				let value = gcv2.run(id).await;
				sender.send((value, *gcv2.id)).await.unwrap();
				gcv2.unregister().await;
			})
			.detach();

		let (value, system_id) = update_until_received(&mut app, &receiver);
		app.update();

		let err = app.world_mut().unregister_system(system_id);
		assert_eq!(5, value);
		assert_counter!(id, 5, app.world_mut());
		assert!(matches!(
			err,
			Err(RegisteredSystemError::SystemIdNotRegistered(_))
		));
	}
}