bevy_async_ecs/
system.rs

1use crate::util::remove_system;
2use crate::world::AsyncWorld;
3use crate::{die, recv};
4use bevy_ecs::prelude::*;
5use bevy_ecs::system::{BoxedSystem, SystemId};
6use bevy_platform::sync::Arc;
7use std::any::Any;
8use std::marker::PhantomData;
9
10type BoxedAnySend = Box<dyn Any + Send>;
11type SystemIdWithIO = SystemId<In<BoxedAnySend>, BoxedAnySend>;
12type BoxedSystemWithIO = BoxedSystem<In<BoxedAnySend>, BoxedAnySend>;
13
14/// Represents a registered `System` that can be run asynchronously.
15///
16/// Dropping an `AsyncSystem` will not unregister it. Use `AsyncSystem::unregister()`
17/// to clean up an `AsyncSystem` from the main bevy `World`.
18///
19/// The easiest way to get an `AsyncSystem` is with `AsyncWorld::register_system()`.
20#[derive(Debug, Clone)]
21pub struct AsyncSystem {
22	id: Arc<SystemId>,
23	world: AsyncWorld,
24}
25
26impl AsyncSystem {
27	pub(crate) async fn new(system: BoxedSystem, world: AsyncWorld) -> Self {
28		let (id_tx, id_rx) = async_channel::bounded(1);
29		world
30			.apply(move |world: &mut World| {
31				let id = world.register_boxed_system(system);
32				id_tx.try_send(id).unwrap_or_else(die);
33			})
34			.await;
35		let id = recv(id_rx).await;
36		let id = Arc::new(id);
37		Self { id, world }
38	}
39
40	/// Run the system.
41	pub async fn run(&self) {
42		let id = *self.id;
43		self.world
44			.apply(move |world: &mut World| {
45				world.run_system(id).unwrap_or_else(die);
46			})
47			.await;
48	}
49
50	/// Unregister the system.
51	///
52	/// If multiple clones of the AsyncSystem exist, a reference counter will be
53	/// decremented instead. The system will be unregistered when the counter
54	/// decrements to zero.
55	pub async fn unregister(self) {
56		let Self { id, world } = self;
57		if let Some(id) = Arc::into_inner(id) {
58			world.apply(remove_system(id)).await;
59		}
60	}
61}
62
63/// Represents a registered `System` that accepts input and returns output, and can be run
64/// asynchronously.
65///
66/// Dropping an `AsyncIOSystem` will not unregister it. Use `AsyncSystemIO::unregister()`
67/// to clean up an `AsyncSystemIO` from the main bevy `World`.
68///
69/// The easiest way to get an `AsyncIOSystem` is with `AsyncWorld::register_io_system()`.
70#[derive(Debug)]
71pub struct AsyncIOSystem<I: Send, O: Send> {
72	id: Arc<SystemIdWithIO>,
73	world: AsyncWorld,
74	_pd: PhantomData<fn(I) -> O>,
75}
76
77impl<I: Send, O: Send> Clone for AsyncIOSystem<I, O> {
78	fn clone(&self) -> Self {
79		Self {
80			id: Arc::clone(&self.id),
81			world: self.world.clone(),
82			_pd: PhantomData,
83		}
84	}
85}
86
87impl<I: Send + 'static, O: Send + 'static> AsyncIOSystem<I, O> {
88	pub(crate) async fn new<M>(
89		system: impl IntoSystem<In<I>, O, M> + Send,
90		world: AsyncWorld,
91	) -> Self {
92		fn unbox_input<I: Send + 'static>(In(boxed): In<BoxedAnySend>) -> I {
93			let concrete = boxed.downcast().unwrap_or_else(die);
94			*concrete
95		}
96
97		fn box_output<O: Send + 'static>(In(output): In<O>) -> BoxedAnySend {
98			Box::new(output)
99		}
100
101		let system = unbox_input.pipe(system).pipe(box_output);
102		let system: BoxedSystemWithIO = Box::new(IntoSystem::into_system(system));
103
104		let (id_tx, id_rx) = async_channel::bounded(1);
105		world
106			.apply(move |world: &mut World| {
107				let id = world.register_boxed_system(system);
108				id_tx.try_send(id).unwrap_or_else(die);
109			})
110			.await;
111
112		let id = recv(id_rx).await;
113		let id = Arc::new(id);
114
115		Self {
116			id,
117			world,
118			_pd: PhantomData,
119		}
120	}
121
122	/// Run the system.
123	pub async fn run(&self, input: I) -> O {
124		let (input_tx, input_rx) = async_channel::bounded(1);
125		let (output_tx, output_rx) = async_channel::bounded(1);
126
127		let input: BoxedAnySend = Box::new(input);
128		input_tx.send(input).await.unwrap_or_else(die);
129
130		let id = *self.id;
131		self.world
132			.apply(move |world: &mut World| {
133				let input = input_rx.try_recv().unwrap_or_else(die);
134				let output = world.run_system_with(id, input).unwrap_or_else(die);
135				output_tx.try_send(output).unwrap_or_else(die);
136			})
137			.await;
138
139		let boxed: BoxedAnySend = recv(output_rx).await;
140		let concrete = boxed.downcast().unwrap_or_else(die);
141		*concrete
142	}
143
144	/// Unregister the system.
145	///
146	/// If multiple clones of the AsyncIOSystem exist, a reference counter will be
147	/// decremented instead. The system will be unregistered when the counter
148	/// decrements to zero.
149	pub async fn unregister(self) {
150		let Self { id, world, _pd } = self;
151		if let Some(id) = Arc::into_inner(id) {
152			world.apply(remove_system(id)).await
153		}
154	}
155}
156
#[cfg(test)]
mod tests {
	use crate::world::AsyncWorld;
	use crate::AsyncEcsPlugin;
	use bevy::ecs::system::RegisteredSystemError;
	use bevy::prelude::*;
	use bevy::tasks::AsyncComputeTaskPool;

	#[derive(Component)]
	struct Counter(u8);

	impl Counter {
		fn go_up(&mut self) {
			self.0 += 1;
		}
	}

	macro_rules! assert_counter {
		($id:expr, $value:expr, $world:expr) => {
			assert_eq!($value, $world.entity($id).get::<Counter>().unwrap().0);
		};
	}

	/// Keeps updating `app` until the spawned task delivers a value on `receiver`,
	/// then returns it.
	///
	/// Panics if the channel closes without a value, e.g. because the task panicked
	/// or otherwise ended before sending. Without this check the test would spin
	/// forever.
	fn update_until_received<T>(app: &mut App, receiver: &async_channel::Receiver<T>) -> T {
		loop {
			match receiver.try_recv() {
				Ok(value) => return value,
				Err(async_channel::TryRecvError::Empty) => app.update(),
				Err(async_channel::TryRecvError::Closed) => {
					panic!("async task finished without sending a value")
				}
			}
		}
	}

	fn increase_counter_all(mut query: Query<&mut Counter>) {
		for mut counter in query.iter_mut() {
			counter.go_up();
		}
	}

	fn increase_counter(In(id): In<Entity>, mut query: Query<&mut Counter>) {
		let mut counter = query.get_mut(id).unwrap();
		counter.go_up();
	}

	fn get_counter_value(In(id): In<Entity>, query: Query<&Counter>) -> u8 {
		query.get(id).unwrap().0
	}

	#[test]
	fn smoke() {
		let mut app = App::new();
		app.add_plugins((MinimalPlugins, AsyncEcsPlugin));
		let id = app.world_mut().spawn(Counter(0)).id();
		assert_counter!(id, 0, app.world_mut());

		let (barrier_tx, barrier_rx) = async_channel::bounded(1);
		let async_world = AsyncWorld::from_world(app.world_mut());

		AsyncComputeTaskPool::get()
			.spawn(async move {
				let increase_counter_all = async_world.register_system(increase_counter_all).await;
				increase_counter_all.run().await;
				barrier_tx.send(()).await.unwrap();
			})
			.detach();

		update_until_received(&mut app, &barrier_rx);
		app.update();

		assert_counter!(id, 1, app.world_mut());
	}

	#[test]
	fn normal_unregister() {
		let mut app = App::new();
		app.add_plugins((MinimalPlugins, AsyncEcsPlugin));
		let id = app.world_mut().spawn(Counter(0)).id();
		assert_counter!(id, 0, app.world_mut());

		let (sender, receiver) = async_channel::bounded(1);
		let async_world = AsyncWorld::from_world(app.world_mut());

		AsyncComputeTaskPool::get()
			.spawn(async move {
				let increase_counter_all = async_world.register_system(increase_counter_all).await;
				let ica2 = increase_counter_all.clone();
				increase_counter_all.unregister().await;

				ica2.run().await;

				let id = *ica2.id;
				ica2.unregister().await;
				sender.send(id).await.unwrap();
			})
			.detach();

		let system_id = update_until_received(&mut app, &receiver);
		app.update();

		let err = app.world_mut().unregister_system(system_id);
		assert_counter!(id, 1, app.world_mut());
		assert!(matches!(
			err,
			Err(RegisteredSystemError::SystemIdNotRegistered(_))
		));
	}

	#[test]
	fn io() {
		let mut app = App::new();
		app.add_plugins((MinimalPlugins, AsyncEcsPlugin));
		let id = app.world_mut().spawn(Counter(4)).id();
		assert_counter!(id, 4, app.world_mut());

		let (sender, receiver) = async_channel::bounded(1);
		let async_world = AsyncWorld::from_world(app.world_mut());

		AsyncComputeTaskPool::get()
			.spawn(async move {
				let increase_counter = async_world.register_io_system(increase_counter).await;
				let get_counter_value = async_world.register_io_system(get_counter_value).await;

				increase_counter.run(id).await;
				let value = get_counter_value.run(id).await;
				sender.send(value).await.unwrap();
			})
			.detach();

		let value = update_until_received(&mut app, &receiver);
		app.update();

		assert_eq!(5, value);
		assert_counter!(id, 5, app.world_mut());
	}

	#[test]
	fn io_unregister() {
		let mut app = App::new();
		app.add_plugins((MinimalPlugins, AsyncEcsPlugin));
		let id = app.world_mut().spawn(Counter(4)).id();
		assert_counter!(id, 4, app.world_mut());

		let (sender, receiver) = async_channel::bounded(1);
		let async_world = AsyncWorld::from_world(app.world_mut());

		AsyncComputeTaskPool::get()
			.spawn(async move {
				let increase_counter = async_world.register_io_system(increase_counter).await;
				let get_counter_value = async_world.register_io_system(get_counter_value).await;

				let gcv2 = get_counter_value.clone();
				get_counter_value.unregister().await;

				increase_counter.run(id).await;
				let value = gcv2.run(id).await;

				// Queue the removal *before* reporting back. The main thread applies
				// only one more update after receiving, so sending first would race
				// the unregister.
				let system_id = *gcv2.id;
				gcv2.unregister().await;
				sender.send((value, system_id)).await.unwrap();
			})
			.detach();

		let (value, system_id) = update_until_received(&mut app, &receiver);
		app.update();

		let err = app.world_mut().unregister_system(system_id);
		assert_eq!(5, value);
		assert_counter!(id, 5, app.world_mut());
		assert!(matches!(
			err,
			Err(RegisteredSystemError::SystemIdNotRegistered(_))
		));
	}
}