local_spawn_pool/
lib.rs

1//! See [`LocalSpawnPool`] for documentation.
2
3mod task;
4pub use task::JoinHandle;
5use task::Task;
6mod tasks_to_add;
7use tasks_to_add::TasksToAdd;
8
9use std::cell::RefCell;
10use std::future::Future;
11use std::mem;
12use std::pin::Pin;
13use std::task::{Poll, Waker};
14
15/// A pool of tasks to spawn futures and wait for them on a single thread.
16///
17/// It is inspired by and has almost the same functionality as [`tokio::task::LocalSet`](https://docs.rs/tokio/latest/tokio/task/struct.LocalSet.html),
18/// but this standalone crate allows you to avoid importing the whole [tokio crate](https://docs.rs/tokio) if you don't need it.
19/// Unlike the [`tokio::task::LocalSet`](https://docs.rs/tokio/latest/tokio/task/struct.LocalSet.html), [`LocalSpawnPool`] doesn't
20/// handle panics.
21///
22/// In some cases, it is necessary to run one or more futures that do not implement `Send` and thus are unsafe to send between
23/// threads. In these cases, a [`LocalSpawnPool`] may be used to schedule one or more `!Send` futures to run together on the same
24/// thread.
25///
26/// You can use the [`LocalSpawnPool::run_until`] function to run a future to completion on the [`LocalSpawnPool`], returning its
27/// output (see [`LocalSpawnPool::run_until`] for more details). And you can use the [`LocalSpawnPool::spawn`] and [`spawn`]
28/// functions to spawn futures on the [`LocalSpawnPool`]. To wait for all the spawned futures to complete, `await` the
29/// [`LocalSpawnPool`] itself:
30///
31/// ## Awaiting the [`LocalSpawnPool`]
32///
33/// Example:
34///
35/// ```
36/// use local_spawn_pool::LocalSpawnPool;
37///
38/// async fn run() {
39///     let pool = LocalSpawnPool::new();
40///     
41///     pool.spawn(async {
42///         // This future will be spawned inside `pool`
43///         
44///         local_spawn_pool::spawn(async {
45///             // This future will be spawned inside `pool`
46///             
47///             local_spawn_pool::spawn(async {
48///                 // This future will be spawned inside `pool`
49///             });
50///         });
51///
52///         local_spawn_pool::spawn(async {
53///             // This future will be spawned inside `pool`
54///         });
55///     });
56///
57///     pool.await; // Will wait for all the futures inside the local_spawn_pool to complete
58/// }
59/// ```
60///
61/// Awaiting a [`LocalSpawnPool`] is `!Send`.
62pub struct LocalSpawnPool(RefCell<Pin<Box<LocalSpawnPoolInner>>>);
63
64#[cfg(not(test))]
65impl Default for LocalSpawnPool {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl LocalSpawnPool {
72    /// Returns a new [`LocalSpawnPool`].
73    pub fn new(#[cfg(test)] name: &'static str) -> Self {
74        Self(RefCell::new(Box::pin(LocalSpawnPoolInner::new(
75            #[cfg(test)]
76            name,
77        ))))
78    }
79
80    /// Runs a future to completion on the [`LocalSpawnPool`], returning its output.
81    ///
82    /// This returns a future that runs the given future in a [`LocalSpawnPool`], allowing it to call [`spawn`] to spawn additional
83    /// `!Send` futures. Any futures spawned on the [`LocalSpawnPool`] will be driven in the background until the future passed to
84    /// `run_until` completes. When the future passed to `run_until` finishes, any futures which have not completed will remain
85    /// on the [`LocalSpawnPool`], and will be driven on subsequent calls to `run_until` or when
86    /// [awaiting the LocalSpawnPool](#awaiting-the-localspawnpool) itself.
87    pub async fn run_until<F>(&self, future: F) -> F::Output
88    where
89        F: Future + 'static,
90    {
91        let join_handle = self.spawn(future);
92        RunUntil::new(&self.0, join_handle).await
93    }
94
95    /// Spawns a `!Send` task onto the [`LocalSpawnPool`].
96    ///
97    /// This task is guaranteed to be run on the current thread.
98    ///
99    /// Unlike the free function [`spawn`], this method may be used to spawn local tasks when the [`LocalSpawnPool`] is not running.
100    /// The provided future will start running once the [`LocalSpawnPool`] is next started, even if you don’t `await` the returned
101    /// [`JoinHandle`].
102    pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
103    where
104        F: Future + 'static,
105    {
106        self.0.borrow_mut().spawn(future)
107    }
108}
109
110/// See [Awaiting the LocalSpawnPool](#awaiting-the-localspawnpool).
111impl Future for LocalSpawnPool {
112    type Output = ();
113
114    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
115        Future::poll(self.0.borrow_mut().as_mut(), cx)
116    }
117}
118
119struct LocalSpawnPoolInner {
120    #[cfg(test)]
121    name: &'static str,
122    tasks: Vec<Task>,
123    waker: Option<Waker>,
124}
125
126impl LocalSpawnPoolInner {
127    fn new(#[cfg(test)] name: &'static str) -> Self {
128        Self {
129            #[cfg(test)]
130            name,
131            tasks: Vec::new(),
132            waker: None,
133        }
134    }
135
136    fn spawn<F>(&mut self, future: F) -> JoinHandle<F::Output>
137    where
138        F: Future + 'static,
139    {
140        let (task, join_handle) = task::create_task(future);
141        self.tasks.push(task);
142
143        if let Some(waker) = &self.waker {
144            waker.wake_by_ref();
145        }
146
147        join_handle
148    }
149}
150
/// One round of the pool: every task present at the start of the call is polled once, then the tasks spawned
/// during that round through the free [`spawn`] are adopted. Resolves once the pool holds no tasks.
impl Future for LocalSpawnPoolInner {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        // Keep the latest waker so that `LocalSpawnPoolInner::spawn` can wake whoever is driving the pool.
        self.waker = Some(cx.waker().clone());
        let tasks_snapshot = mem::take::<Vec<_>>(&mut self.tasks); // `tasks` is now empty

        if tasks_snapshot.is_empty() {
            // No tasks at all: the pool is done (until something new is spawned onto it).
            Poll::Ready(())
        } else {
            // Collects the tasks spawned with the free `spawn` while the tasks below are polled. It is reached
            // through a thread-local, so `spawn` needs no reference to the pool.
            let tasks_to_add = TasksToAdd::new();

            for mut task in tasks_snapshot {
                // Installed before *every* task rather than once before the loop: a task may drive a nested
                // `LocalSpawnPool`, whose own poll installs its thread-local and then calls `unset_thread_local`.
                // NOTE(review): assuming `unset_thread_local` clears the slot rather than restoring the previous
                // value (confirm in `tasks_to_add`), an outer task that calls `spawn` after a nested pool finished
                // within the same poll (e.g. right after `inner_pool.await` or `inner_pool.run_until(..).await`
                // returns) hits the "outside the context" panic. Saving and restoring the previous value would fix it.
                tasks_to_add::set_thread_local(
                    &tasks_to_add,
                    #[cfg(test)]
                    self.name,
                );

                // Unfinished tasks go back into the pool; finished ones are dropped here.
                if Future::poll(task.as_mut(), cx).is_pending() {
                    self.tasks.push(task);
                }
            }

            tasks_to_add::unset_thread_local();

            tasks_to_add.access_mut(|tasks_to_add_vec| {
                // Freshly spawned tasks have never been polled: ask to be polled again right away instead of
                // waiting for an unrelated wake-up.
                if !tasks_to_add_vec.is_empty() {
                    cx.waker().wake_by_ref();
                }

                // `append` moves every queued task into the pool, leaving `tasks_to_add_vec` empty.
                self.tasks.append(tasks_to_add_vec);
            });

            if self.tasks.is_empty() {
                Poll::Ready(())
            } else {
                Poll::Pending
            }
        }
    }
}
193
194/// Spawns a `!Send` future on the current [`LocalSpawnPool`].
195///
196/// The spawned future will run on the same thread that called [`spawn`].
197///
198/// The provided future will start running in the background immediately when [`spawn`] is called, even if you don’t `await` the
199/// returned [`JoinHandle`].
200#[track_caller]
201pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
202where
203    F: Future + 'static,
204{
205    let (task, join_handle) = task::create_task(future);
206    tasks_to_add::access_thread_local(|tasks_to_add| match tasks_to_add {
207        #[cfg(not(test))]
208        Some(tasks_to_add) => tasks_to_add.add(task),
209        #[cfg(test)]
210        Some((tasks_to_add, _)) => tasks_to_add.add(task),
211        None => {
212            panic!("`local_spawn_pool::spawn` was called outside the context of a `LocalSpawnPool`")
213        }
214    });
215    join_handle
216}
217
218struct RunUntil<'a, T> {
219    local_spawn_pool: Option<&'a RefCell<Pin<Box<LocalSpawnPoolInner>>>>,
220    join_handle: Pin<Box<JoinHandle<T>>>,
221}
222
223impl<'a, T> RunUntil<'a, T> {
224    fn new(
225        local_spawn_pool: &'a RefCell<Pin<Box<LocalSpawnPoolInner>>>,
226        join_handle: JoinHandle<T>,
227    ) -> Self {
228        RunUntil {
229            local_spawn_pool: Some(local_spawn_pool),
230            join_handle: Box::pin(join_handle),
231        }
232    }
233}
234
235impl<'a, T> Future for RunUntil<'a, T> {
236    type Output = T;
237
238    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
239        if let Some(local_spawn_pool) = self.local_spawn_pool {
240            if let Poll::Ready(()) = Future::poll(local_spawn_pool.borrow_mut().as_mut(), cx) {
241                self.local_spawn_pool = None;
242            }
243        }
244
245        match Future::poll(self.join_handle.as_mut(), cx) {
246            Poll::Ready(output) => {
247                /*
248                 * It's fine to unwrap, because `output` can be `None` only if the task:
249                 * - was aborted via `JoinHandle::abort`, which is impossible because the this `JoinHandle` is never made
250                 *   accessible to the outside
251                 * - was aborted by the runtime, in which case this code would never be runned
252                 */
253                Poll::Ready(output.unwrap())
254            }
255
256            Poll::Pending => Poll::Pending,
257        }
258    }
259}
260
// End-to-end check: the free `spawn` inside a pool, a second pool nested inside a task of the first,
// `JoinHandle` outputs and `JoinHandle::abort`, and the order in which tasks finish (driven by their sleeps).
#[cfg(test)]
#[tokio::test]
async fn test() {
    use std::rc::Rc;
    use std::time::Duration;
    use tokio::time;

    // Every `push_result` call appends `(value, name of the pool whose thread-local is installed)`.
    let results: Rc<RefCell<Vec<(u8, &'static str)>>> = Rc::new(RefCell::new(Vec::new()));

    // Records `result` together with the name of the pool currently polling the caller, read from the
    // `tasks_to_add` thread-local.
    // NOTE(review): `#[track_caller]` has no effect on the `panic!` below because it sits inside a closure, and the
    // panic message still mentions `spawn_pool_name()`, which looks like a leftover name for this helper.
    #[track_caller]
    fn push_result(results: &Rc<RefCell<Vec<(u8, &'static str)>>>, result: u8) {
        results.borrow_mut().push((
            result,
            tasks_to_add::access_thread_local(|tasks_to_add_and_name| match tasks_to_add_and_name {
                Some(&(_, name)) => name,
                None => {
                    panic!("`spawn_pool_name()` was called outside the context of a `LocalSpawnPool`")
                }
            })
        ));
    }

    let local_spawn_pool_a = LocalSpawnPool::new("a");
    // The `run_until` future only spawns three tasks onto pool "a" and then returns immediately.
    let output = local_spawn_pool_a
        .run_until({
            let results = Rc::clone(&results);
            async move {
                // Pool "a": finishes after ~500 ms -> (3, "a").
                spawn({
                    let results = Rc::clone(&results);
                    async move {
                        time::sleep(Duration::from_millis(500)).await;
                        push_result(&results, 3);
                    }
                });

                // Pool "a": creates pool "b", records (0, "a") after 50 ms, then awaits pool "b".
                spawn({
                    let results = Rc::clone(&results);
                    async move {
                        let local_spawn_pool_b = LocalSpawnPool::new("b");
                        // Spawned with the method, so it only starts once `local_spawn_pool_b` is awaited below
                        // (after the 50 ms sleep) — which is why (0, "a") precedes (1, "b").
                        local_spawn_pool_b.spawn({
                            let results = Rc::clone(&results);
                            async move {
                                // The free `spawn` here targets pool "b": ~50 + 20 ms -> (1, "b").
                                let join_handle = spawn({
                                    let results = Rc::clone(&results);
                                    async move {
                                        time::sleep(Duration::from_millis(20)).await;
                                        push_result(&results, 1);
                                        "this is another output"
                                    }
                                });

                                assert_eq!(join_handle.await, Some("this is another output"));

                                // Pool "b": ~70 + 510 ms -> (4, "b").
                                spawn({
                                    let results = Rc::clone(&results);
                                    async move {
                                        time::sleep(Duration::from_millis(510)).await;
                                        push_result(&results, 4);
                                    }
                                });

                                // Aborted right away: 100 must never be recorded and the handle resolves to `None`.
                                let join_handle = spawn({
                                    let results = Rc::clone(&results);
                                    async move {
                                        time::sleep(Duration::from_millis(515)).await;
                                        push_result(&results, 100);
                                    }
                                });

                                join_handle.abort();
                                assert_eq!(join_handle.await, None);
                            }
                        });

                        time::sleep(Duration::from_millis(50)).await;
                        push_result(&results, 0);
                        local_spawn_pool_b.await;
                    }
                });

                // Pool "a": ~150 ms -> (2, "a").
                spawn({
                    let results = Rc::clone(&results);
                    async move {
                        time::sleep(Duration::from_millis(150)).await;
                        push_result(&results, 2);
                    }
                });

                "this is the output"
            }
        })
        .await;
    assert_eq!(output, "this is the output");
    // The spawned tasks are still pending on pool "a"; none of them has recorded a result yet.
    assert_eq!(&*results.borrow(), &[]);
    // Drive everything left on pool "a" (and, through one of its tasks, pool "b") to completion.
    local_spawn_pool_a.await;
    assert_eq!(
        &*results.borrow(),
        &[(0, "a"), (1, "b"), (2, "a"), (3, "a"), (4, "b")]
    );
}