local_spawn_pool/lib.rs
1//! See [`LocalSpawnPool`] for documentation.
2
3mod task;
4pub use task::JoinHandle;
5use task::Task;
6mod tasks_to_add;
7use tasks_to_add::TasksToAdd;
8
9use std::cell::RefCell;
10use std::future::Future;
11use std::mem;
12use std::pin::Pin;
13use std::task::{Poll, Waker};
14
15/// A pool of tasks to spawn futures and wait for them on a single thread.
16///
17/// It is inspired by and has almost the same functionality as [`tokio::task::LocalSet`](https://docs.rs/tokio/latest/tokio/task/struct.LocalSet.html),
18/// but this standalone crate allows you to avoid importing the whole [tokio crate](https://docs.rs/tokio) if you don't need it.
19/// Unlike the [`tokio::task::LocalSet`](https://docs.rs/tokio/latest/tokio/task/struct.LocalSet.html), [`LocalSpawnPool`] doesn't
20/// handle panics.
21///
22/// In some cases, it is necessary to run one or more futures that do not implement `Send` and thus are unsafe to send between
23/// threads. In these cases, a [`LocalSpawnPool`] may be used to schedule one or more `!Send` futures to run together on the same
24/// thread.
25///
26/// You can use the [`LocalSpawnPool::run_until`] function to run a future to completion on the [`LocalSpawnPool`], returning its
27/// output (see [`LocalSpawnPool::run_until`] for more details). And you can use the [`LocalSpawnPool::spawn`] and [`spawn`]
28/// functions to spawn futures on the [`LocalSpawnPool`]. To wait for all the spawned futures to complete, `await` the
29/// [`LocalSpawnPool`] itself:
30///
31/// ## Awaiting the [`LocalSpawnPool`]
32///
33/// Example:
34///
35/// ```
36/// use local_spawn_pool::LocalSpawnPool;
37///
38/// async fn run() {
39/// let pool = LocalSpawnPool::new();
40///
41/// pool.spawn(async {
42/// // This future will be spawned inside `pool`
43///
44/// local_spawn_pool::spawn(async {
45/// // This future will be spawned inside `pool`
46///
47/// local_spawn_pool::spawn(async {
48/// // This future will be spawned inside `pool`
49/// });
50/// });
51///
52/// local_spawn_pool::spawn(async {
53/// // This future will be spawned inside `pool`
54/// });
55/// });
56///
57/// pool.await; // Will wait for all the futures inside the local_spawn_pool to complete
58/// }
59/// ```
60///
61/// Awaiting a [`LocalSpawnPool`] is `!Send`.
62pub struct LocalSpawnPool(RefCell<Pin<Box<LocalSpawnPoolInner>>>);
63
64#[cfg(not(test))]
65impl Default for LocalSpawnPool {
66 fn default() -> Self {
67 Self::new()
68 }
69}
70
71impl LocalSpawnPool {
72 /// Returns a new [`LocalSpawnPool`].
73 pub fn new(#[cfg(test)] name: &'static str) -> Self {
74 Self(RefCell::new(Box::pin(LocalSpawnPoolInner::new(
75 #[cfg(test)]
76 name,
77 ))))
78 }
79
80 /// Runs a future to completion on the [`LocalSpawnPool`], returning its output.
81 ///
82 /// This returns a future that runs the given future in a [`LocalSpawnPool`], allowing it to call [`spawn`] to spawn additional
83 /// `!Send` futures. Any futures spawned on the [`LocalSpawnPool`] will be driven in the background until the future passed to
84 /// `run_until` completes. When the future passed to `run_until` finishes, any futures which have not completed will remain
85 /// on the [`LocalSpawnPool`], and will be driven on subsequent calls to `run_until` or when
86 /// [awaiting the LocalSpawnPool](#awaiting-the-localspawnpool) itself.
87 pub async fn run_until<F>(&self, future: F) -> F::Output
88 where
89 F: Future + 'static,
90 {
91 let join_handle = self.spawn(future);
92 RunUntil::new(&self.0, join_handle).await
93 }
94
95 /// Spawns a `!Send` task onto the [`LocalSpawnPool`].
96 ///
97 /// This task is guaranteed to be run on the current thread.
98 ///
99 /// Unlike the free function [`spawn`], this method may be used to spawn local tasks when the [`LocalSpawnPool`] is not running.
100 /// The provided future will start running once the [`LocalSpawnPool`] is next started, even if you don’t `await` the returned
101 /// [`JoinHandle`].
102 pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
103 where
104 F: Future + 'static,
105 {
106 self.0.borrow_mut().spawn(future)
107 }
108}
109
110/// See [Awaiting the LocalSpawnPool](#awaiting-the-localspawnpool).
111impl Future for LocalSpawnPool {
112 type Output = ();
113
114 fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
115 Future::poll(self.0.borrow_mut().as_mut(), cx)
116 }
117}
118
119struct LocalSpawnPoolInner {
120 #[cfg(test)]
121 name: &'static str,
122 tasks: Vec<Task>,
123 waker: Option<Waker>,
124}
125
126impl LocalSpawnPoolInner {
127 fn new(#[cfg(test)] name: &'static str) -> Self {
128 Self {
129 #[cfg(test)]
130 name,
131 tasks: Vec::new(),
132 waker: None,
133 }
134 }
135
136 fn spawn<F>(&mut self, future: F) -> JoinHandle<F::Output>
137 where
138 F: Future + 'static,
139 {
140 let (task, join_handle) = task::create_task(future);
141 self.tasks.push(task);
142
143 if let Some(waker) = &self.waker {
144 waker.wake_by_ref();
145 }
146
147 join_handle
148 }
149}
150
151impl Future for LocalSpawnPoolInner {
152 type Output = ();
153
154 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
155 self.waker = Some(cx.waker().clone());
156 let tasks_snapshot = mem::take::<Vec<_>>(&mut self.tasks); // `tasks` is now empty
157
158 if tasks_snapshot.is_empty() {
159 Poll::Ready(())
160 } else {
161 let tasks_to_add = TasksToAdd::new();
162
163 for mut task in tasks_snapshot {
164 tasks_to_add::set_thread_local(
165 &tasks_to_add,
166 #[cfg(test)]
167 self.name,
168 );
169
170 if Future::poll(task.as_mut(), cx).is_pending() {
171 self.tasks.push(task);
172 }
173 }
174
175 tasks_to_add::unset_thread_local();
176
177 tasks_to_add.access_mut(|tasks_to_add_vec| {
178 if !tasks_to_add_vec.is_empty() {
179 cx.waker().wake_by_ref();
180 }
181
182 self.tasks.append(tasks_to_add_vec);
183 });
184
185 if self.tasks.is_empty() {
186 Poll::Ready(())
187 } else {
188 Poll::Pending
189 }
190 }
191 }
192}
193
194/// Spawns a `!Send` future on the current [`LocalSpawnPool`].
195///
196/// The spawned future will run on the same thread that called [`spawn`].
197///
198/// The provided future will start running in the background immediately when [`spawn`] is called, even if you don’t `await` the
199/// returned [`JoinHandle`].
200#[track_caller]
201pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
202where
203 F: Future + 'static,
204{
205 let (task, join_handle) = task::create_task(future);
206 tasks_to_add::access_thread_local(|tasks_to_add| match tasks_to_add {
207 #[cfg(not(test))]
208 Some(tasks_to_add) => tasks_to_add.add(task),
209 #[cfg(test)]
210 Some((tasks_to_add, _)) => tasks_to_add.add(task),
211 None => {
212 panic!("`local_spawn_pool::spawn` was called outside the context of a `LocalSpawnPool`")
213 }
214 });
215 join_handle
216}
217
218struct RunUntil<'a, T> {
219 local_spawn_pool: Option<&'a RefCell<Pin<Box<LocalSpawnPoolInner>>>>,
220 join_handle: Pin<Box<JoinHandle<T>>>,
221}
222
223impl<'a, T> RunUntil<'a, T> {
224 fn new(
225 local_spawn_pool: &'a RefCell<Pin<Box<LocalSpawnPoolInner>>>,
226 join_handle: JoinHandle<T>,
227 ) -> Self {
228 RunUntil {
229 local_spawn_pool: Some(local_spawn_pool),
230 join_handle: Box::pin(join_handle),
231 }
232 }
233}
234
235impl<'a, T> Future for RunUntil<'a, T> {
236 type Output = T;
237
238 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
239 if let Some(local_spawn_pool) = self.local_spawn_pool {
240 if let Poll::Ready(()) = Future::poll(local_spawn_pool.borrow_mut().as_mut(), cx) {
241 self.local_spawn_pool = None;
242 }
243 }
244
245 match Future::poll(self.join_handle.as_mut(), cx) {
246 Poll::Ready(output) => {
247 /*
248 * It's fine to unwrap, because `output` can be `None` only if the task:
249 * - was aborted via `JoinHandle::abort`, which is impossible because the this `JoinHandle` is never made
250 * accessible to the outside
251 * - was aborted by the runtime, in which case this code would never be runned
252 */
253 Poll::Ready(output.unwrap())
254 }
255
256 Poll::Pending => Poll::Pending,
257 }
258 }
259}
260
#[cfg(test)]
#[tokio::test]
async fn test() {
    use std::rc::Rc;
    use std::time::Duration;
    use tokio::time;

    // Log of `(value, name of the pool that was polling when the value was pushed)`.
    let results: Rc<RefCell<Vec<(u8, &'static str)>>> = Rc::new(RefCell::new(Vec::new()));

    // Pushes `result`, tagged with the name of the pool currently polling the caller.
    // Panics, at the caller's location, if no pool is polling.
    #[track_caller]
    fn push_result(results: &Rc<RefCell<Vec<(u8, &'static str)>>>, result: u8) {
        // Read the name inside the closure but panic outside it. Closures do not inherit
        // `#[track_caller]`, so a panic inside would point here instead of at the caller.
        let name = tasks_to_add::access_thread_local(|tasks_to_add_and_name| {
            tasks_to_add_and_name.map(|&(_, name)| name)
        });
        match name {
            Some(name) => results.borrow_mut().push((result, name)),
            None => {
                panic!("`push_result()` was called outside the context of a `LocalSpawnPool`")
            }
        }
    }

    // Expected timeline, in ms after `local_spawn_pool_a` starts being awaited:
    //   ~50  pool "a" pushes 0, then its task starts driving pool "b"
    //   ~70  pool "b" pushes 1
    //   ~150 pool "a" pushes 2
    //   ~500 pool "a" pushes 3
    //   ~580 pool "b" pushes 4 (50 + 20 + 510)
    // The task that would push 100 is aborted right after being spawned.
    let local_spawn_pool_a = LocalSpawnPool::new("a");
    let output = local_spawn_pool_a
        .run_until({
            let results = Rc::clone(&results);
            async move {
                spawn({
                    let results = Rc::clone(&results);
                    async move {
                        time::sleep(Duration::from_millis(500)).await;
                        push_result(&results, 3);
                    }
                });

                spawn({
                    let results = Rc::clone(&results);
                    async move {
                        let local_spawn_pool_b = LocalSpawnPool::new("b");
                        // Spawned before "b" is running; it only starts once "b" is awaited.
                        local_spawn_pool_b.spawn({
                            let results = Rc::clone(&results);
                            async move {
                                let join_handle = spawn({
                                    let results = Rc::clone(&results);
                                    async move {
                                        time::sleep(Duration::from_millis(20)).await;
                                        push_result(&results, 1);
                                        "this is another output"
                                    }
                                });

                                assert_eq!(join_handle.await, Some("this is another output"));

                                spawn({
                                    let results = Rc::clone(&results);
                                    async move {
                                        time::sleep(Duration::from_millis(510)).await;
                                        push_result(&results, 4);
                                    }
                                });

                                let join_handle = spawn({
                                    let results = Rc::clone(&results);
                                    async move {
                                        time::sleep(Duration::from_millis(515)).await;
                                        push_result(&results, 100);
                                    }
                                });

                                join_handle.abort();
                                assert_eq!(join_handle.await, None);
                            }
                        });

                        time::sleep(Duration::from_millis(50)).await;
                        push_result(&results, 0);
                        local_spawn_pool_b.await;
                    }
                });

                spawn({
                    let results = Rc::clone(&results);
                    async move {
                        time::sleep(Duration::from_millis(150)).await;
                        push_result(&results, 2);
                    }
                });

                "this is the output"
            }
        })
        .await;
    assert_eq!(output, "this is the output");
    // `run_until` returned as soon as its own future finished; the spawned tasks have not run.
    assert_eq!(&*results.borrow(), &[]);
    local_spawn_pool_a.await;
    assert_eq!(
        &*results.borrow(),
        &[(0, "a"), (1, "b"), (2, "a"), (3, "a"), (4, "b")]
    );
}