local-spawn-pool 0.1.0

Spawn `!Send` futures in a pool and wait for all of them to finish. A standalone alternative to `tokio::task::LocalSet`.
Documentation
//! See [`LocalSpawnPool`] for documentation.

mod task;
pub use task::JoinHandle;
use task::Task;
mod tasks_to_add;
use tasks_to_add::TasksToAdd;

use std::cell::RefCell;
use std::future::Future;
use std::mem;
use std::pin::Pin;
use std::task::{Poll, Waker};

/// A pool of tasks to spawn futures and wait for them on a single thread.
///
/// It is inspired by and has almost the same functionality as [`tokio::task::LocalSet`](https://docs.rs/tokio/latest/tokio/task/struct.LocalSet.html),
/// but this standalone crate allows you to avoid importing the whole [tokio crate](https://docs.rs/tokio) if you don't need it.
/// Unlike the [`tokio::task::LocalSet`](https://docs.rs/tokio/latest/tokio/task/struct.LocalSet.html), [`LocalSpawnPool`] doesn't
/// handle panics: the pool itself does not catch a panic raised by one of its spawned futures.
///
/// In some cases, it is necessary to run one or more futures that do not implement `Send` and thus are unsafe to send between
/// threads. In these cases, a [`LocalSpawnPool`] may be used to schedule one or more `!Send` futures to run together on the same
/// thread.
///
/// You can use the [`LocalSpawnPool::run_until`] function to run a future to completion on the [`LocalSpawnPool`], returning its
/// output (see [`LocalSpawnPool::run_until`] for more details). You can use the [`LocalSpawnPool::spawn`] method and the free
/// [`spawn`] function to spawn futures on the [`LocalSpawnPool`]. To wait for all the spawned futures to complete, `await` the
/// [`LocalSpawnPool`] itself:
///
/// ## Awaiting the [`LocalSpawnPool`]
///
/// Example:
///
/// ```
/// use local_spawn_pool::LocalSpawnPool;
///
/// async fn run() {
///     let pool = LocalSpawnPool::new();
///     
///     pool.spawn(async {
///         // This future will be spawned inside `pool`
///         
///         local_spawn_pool::spawn(async {
///             // This future will be spawned inside `pool`
///             
///             local_spawn_pool::spawn(async {
///                 // This future will be spawned inside `pool`
///             });
///         });
///
///         local_spawn_pool::spawn(async {
///             // This future will be spawned inside `pool`
///         });
///     });
///
///     pool.await; // Will wait for all the futures inside the local_spawn_pool to complete
/// }
/// ```
///
/// Awaiting a [`LocalSpawnPool`] is `!Send`.
// The inner state is behind a `RefCell` so that `spawn` and `run_until` can take `&self`. It is a
// `Pin<Box<_>>` because `LocalSpawnPoolInner` is polled as a future through `Pin<&mut _>`.
// The `RefCell` is mutably borrowed for the whole duration of a poll. While the pool is being polled,
// any other `borrow_mut` on the same pool (e.g. `LocalSpawnPool::spawn`) panics.
pub struct LocalSpawnPool(RefCell<Pin<Box<LocalSpawnPoolInner>>>);

#[cfg(not(test))]
impl Default for LocalSpawnPool {
    fn default() -> Self {
        Self::new()
    }
}

impl LocalSpawnPool {
    /// Returns a new, empty [`LocalSpawnPool`].
    ///
    /// Nothing runs until the pool is driven, either through [`LocalSpawnPool::run_until`] or by `await`ing the pool
    /// itself. In test builds the pool also carries a `name`, which the crate's tests use to tell pools apart.
    pub fn new(#[cfg(test)] name: &'static str) -> Self {
        Self(RefCell::new(Box::pin(LocalSpawnPoolInner::new(
            #[cfg(test)]
            name,
        ))))
    }

    /// Runs a future to completion on the [`LocalSpawnPool`], returning its output.
    ///
    /// This returns a future that runs the given future in a [`LocalSpawnPool`], allowing it to call [`spawn`] to spawn additional
    /// `!Send` futures. Any futures spawned on the [`LocalSpawnPool`] will be driven in the background until the future passed to
    /// `run_until` completes. When the future passed to `run_until` finishes, any futures which have not completed will remain
    /// on the [`LocalSpawnPool`], and will be driven on subsequent calls to `run_until` or when
    /// [awaiting the LocalSpawnPool](#awaiting-the-localspawnpool) itself.
    ///
    /// # Panics
    ///
    /// Polling the returned future panics if this pool is already being polled at that moment (for example when
    /// `run_until` is awaited from inside one of this pool's own tasks). In that case the pool's internal `RefCell`
    /// is already mutably borrowed.
    pub async fn run_until<F>(&self, future: F) -> F::Output
    where
        F: Future + 'static,
    {
        let join_handle = self.spawn(future);
        RunUntil::new(&self.0, join_handle).await
    }

    /// Spawns a `!Send` task onto the [`LocalSpawnPool`].
    ///
    /// This task is guaranteed to be run on the current thread.
    ///
    /// Unlike the free function [`spawn`], this method may be used to spawn local tasks when the [`LocalSpawnPool`] is not running.
    /// The provided future will start running once the [`LocalSpawnPool`] is next started, even if you don’t `await` the returned
    /// [`JoinHandle`].
    ///
    /// # Panics
    ///
    /// Panics if called while this pool is being polled, for example from inside one of this pool's own tasks reached
    /// through an `Rc`. In that case the pool's internal `RefCell` is already mutably borrowed. From inside a task that
    /// is running on this pool, use the free function [`spawn`] instead.
    pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
    where
        F: Future + 'static,
    {
        self.0.borrow_mut().spawn(future)
    }
}

/// See [Awaiting the LocalSpawnPool](#awaiting-the-localspawnpool).
impl Future for LocalSpawnPool {
    type Output = ();

    /// Polls the tasks currently in the pool; resolves once the pool has no tasks left.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        // Delegate to the pinned inner pool. The `RefCell` stays mutably borrowed for the whole poll.
        let mut inner = self.0.borrow_mut();
        inner.as_mut().poll(cx)
    }
}

/// The actual task pool, which [`LocalSpawnPool`] keeps pinned and behind a `RefCell`.
struct LocalSpawnPoolInner {
    /// Pool name, present only in test builds. It is passed to the thread-local context so the crate's
    /// tests can tell which pool ran a task.
    #[cfg(test)]
    name: &'static str,
    /// Spawned tasks that have not completed yet (temporarily emptied while a poll is in progress).
    tasks: Vec<Task>,
    /// Waker from the most recent poll; woken when a task is added through [`LocalSpawnPoolInner::spawn`].
    waker: Option<Waker>,
}

impl LocalSpawnPoolInner {
    fn new(#[cfg(test)] name: &'static str) -> Self {
        Self {
            #[cfg(test)]
            name,
            tasks: Vec::new(),
            waker: None,
        }
    }

    fn spawn<F>(&mut self, future: F) -> JoinHandle<F::Output>
    where
        F: Future + 'static,
    {
        let (task, join_handle) = task::create_task(future);
        self.tasks.push(task);

        if let Some(waker) = &self.waker {
            waker.wake_by_ref();
        }

        join_handle
    }
}

/// Polls, once each, every task that is in the pool at the start of the call.
///
/// Resolves to `()` when the pool has no tasks left; otherwise returns `Poll::Pending`.
/// Tasks added with the free function [`spawn`] during this poll are queued and polled on the next poll.
impl Future for LocalSpawnPoolInner {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        // Remember the current waker so `LocalSpawnPoolInner::spawn` can wake the pool when a task is added later.
        // NOTE(review): this clones the waker on every poll; `Waker::will_wake` could skip redundant clones.
        self.waker = Some(cx.waker().clone());
        let tasks_snapshot = mem::take::<Vec<_>>(&mut self.tasks); // `tasks` is now empty

        if tasks_snapshot.is_empty() {
            Poll::Ready(())
        } else {
            // Receives the futures passed to the free function `spawn` while this pool's tasks are running.
            let tasks_to_add = TasksToAdd::new();

            for mut task in tasks_snapshot {
                // Re-installed before every task rather than once before the loop. A task may itself poll a
                // nested `LocalSpawnPool` (as the crate's test does), whose poll installs its own context and
                // calls `unset_thread_local` afterwards.
                tasks_to_add::set_thread_local(
                    &tasks_to_add,
                    #[cfg(test)]
                    self.name,
                );

                // Every task is polled with the pool's own waker. Unfinished tasks go back into `tasks`;
                // finished ones are dropped at the end of this iteration.
                if Future::poll(task.as_mut(), cx).is_pending() {
                    self.tasks.push(task);
                }
            }

            // NOTE(review): a nested pool polled from inside one of the tasks above also runs this call.
            // If `unset_thread_local` simply clears the slot (its definition is not visible here), an outer
            // task that keeps running after the nested pool's poll returns and then calls the free `spawn`
            // would hit the "called outside the context" panic. Confirm against `tasks_to_add`; saving and
            // restoring the previous context would avoid it.
            tasks_to_add::unset_thread_local();

            tasks_to_add.access_mut(|tasks_to_add_vec| {
                // Newly spawned tasks have not been polled yet, so ask to be polled again right away.
                if !tasks_to_add_vec.is_empty() {
                    cx.waker().wake_by_ref();
                }

                self.tasks.append(tasks_to_add_vec);
            });

            if self.tasks.is_empty() {
                Poll::Ready(())
            } else {
                Poll::Pending
            }
        }
    }
}

/// Spawns a `!Send` future on the current [`LocalSpawnPool`].
///
/// The "current" pool is the one whose task is being polled on this thread when [`spawn`] is called. This function
/// is therefore meant to be called from inside a future running on a [`LocalSpawnPool`]. The spawned future will run
/// on the same thread that called [`spawn`].
///
/// The future is queued on that pool right away and is first polled on the pool's next poll, even if you don’t
/// `await` the returned [`JoinHandle`].
///
/// # Panics
///
/// Panics if called outside the context of a [`LocalSpawnPool`], i.e. when no pool is currently polling one of its
/// tasks on this thread. The panic location reported is the caller's.
#[track_caller]
pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
where
    F: Future + 'static,
{
    let (task, join_handle) = task::create_task(future);
    // The closure only reports whether a pool context exists. The panic is raised below, directly in
    // this function's body, because `#[track_caller]` does not extend into closures. A `panic!` inside
    // the closure would point at this file instead of at the caller.
    let queued = tasks_to_add::access_thread_local(|tasks_to_add| match tasks_to_add {
        #[cfg(not(test))]
        Some(tasks_to_add) => {
            tasks_to_add.add(task);
            true
        }
        #[cfg(test)]
        Some((tasks_to_add, _)) => {
            tasks_to_add.add(task);
            true
        }
        None => false,
    });
    if !queued {
        panic!("`local_spawn_pool::spawn` was called outside the context of a `LocalSpawnPool`");
    }
    join_handle
}

/// Future awaited by [`LocalSpawnPool::run_until`]: drives the pool until the task behind `join_handle`
/// has produced its output.
struct RunUntil<'a, T> {
    /// The pool to drive; set to `None` once the pool has reported that it has no tasks left.
    local_spawn_pool: Option<&'a RefCell<Pin<Box<LocalSpawnPoolInner>>>>,
    /// Handle to the task spawned for the future passed to `run_until`.
    join_handle: Pin<Box<JoinHandle<T>>>,
}

impl<'a, T> RunUntil<'a, T> {
    fn new(
        local_spawn_pool: &'a RefCell<Pin<Box<LocalSpawnPoolInner>>>,
        join_handle: JoinHandle<T>,
    ) -> Self {
        RunUntil {
            local_spawn_pool: Some(local_spawn_pool),
            join_handle: Box::pin(join_handle),
        }
    }
}

impl<T> Future for RunUntil<'_, T> {
    type Output = T;

    /// Polls the pool (until it reports that it has no tasks left), then the handle of the `run_until` task.
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        // Drive the pool first so the `run_until` task, and anything it spawned, makes progress in this poll.
        // Once the pool has no tasks left there is nothing more to drive, so it is no longer polled.
        if let Some(local_spawn_pool) = self.local_spawn_pool {
            if Future::poll(local_spawn_pool.borrow_mut().as_mut(), cx).is_ready() {
                self.local_spawn_pool = None;
            }
        }

        self.join_handle.as_mut().poll(cx).map(|output| {
            /*
             * `output` can be `None` only if the task:
             * - was aborted via `JoinHandle::abort`, which is impossible because this `JoinHandle` is never made
             *   accessible to the outside
             * - was aborted by the runtime, in which case this code would never run
             */
            output.expect("the `run_until` task was aborted, which should be impossible")
        })
    }
}

#[cfg(test)]
#[tokio::test]
async fn test() {
    use std::rc::Rc;
    use std::time::Duration;
    use tokio::time;

    let results: Rc<RefCell<Vec<(u8, &'static str)>>> = Rc::new(RefCell::new(Vec::new()));

    // Records `result` together with the name of the pool whose task is currently running.
    // Panics if no pool context is installed on this thread. The panic is raised outside the
    // `access_thread_local` closure so that `#[track_caller]` reports the calling line.
    #[track_caller]
    fn push_result(results: &Rc<RefCell<Vec<(u8, &'static str)>>>, result: u8) {
        let pool_name = tasks_to_add::access_thread_local(|tasks_to_add_and_name| {
            tasks_to_add_and_name.map(|&(_, name)| name)
        });
        let pool_name = match pool_name {
            Some(name) => name,
            None => {
                panic!("`push_result` was called outside the context of a `LocalSpawnPool`")
            }
        };
        results.borrow_mut().push((result, pool_name));
    }

    let local_spawn_pool_a = LocalSpawnPool::new("a");
    let output = local_spawn_pool_a
        .run_until({
            let results = Rc::clone(&results);
            async move {
                spawn({
                    let results = Rc::clone(&results);
                    async move {
                        time::sleep(Duration::from_millis(500)).await;
                        push_result(&results, 3);
                    }
                });

                spawn({
                    let results = Rc::clone(&results);
                    async move {
                        let local_spawn_pool_b = LocalSpawnPool::new("b");
                        local_spawn_pool_b.spawn({
                            let results = Rc::clone(&results);
                            async move {
                                let join_handle = spawn({
                                    let results = Rc::clone(&results);
                                    async move {
                                        time::sleep(Duration::from_millis(20)).await;
                                        push_result(&results, 1);
                                        "this is another output"
                                    }
                                });

                                assert_eq!(join_handle.await, Some("this is another output"));

                                spawn({
                                    let results = Rc::clone(&results);
                                    async move {
                                        time::sleep(Duration::from_millis(510)).await;
                                        push_result(&results, 4);
                                    }
                                });

                                let join_handle = spawn({
                                    let results = Rc::clone(&results);
                                    async move {
                                        time::sleep(Duration::from_millis(515)).await;
                                        push_result(&results, 100);
                                    }
                                });

                                join_handle.abort();
                                assert_eq!(join_handle.await, None);
                            }
                        });

                        time::sleep(Duration::from_millis(50)).await;
                        push_result(&results, 0);
                        local_spawn_pool_b.await;
                    }
                });

                spawn({
                    let results = Rc::clone(&results);
                    async move {
                        time::sleep(Duration::from_millis(150)).await;
                        push_result(&results, 2);
                    }
                });

                "this is the output"
            }
        })
        .await;
    assert_eq!(output, "this is the output");
    assert_eq!(&*results.borrow(), &[]);
    local_spawn_pool_a.await;
    assert_eq!(
        &*results.borrow(),
        &[(0, "a"), (1, "b"), (2, "a"), (3, "a"), (4, "b")]
    );
}