local-spawn-pool 0.1.0

Spawn `!Send` futures in a pool and wait for all of them to finish. Standalone alternative to `tokio::task::LocalSet`.
Documentation
use futures::channel::oneshot;
use futures::channel::oneshot::{Receiver, Sender};
use std::boxed::Box;
use std::cell::{Cell, RefCell};
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::rc::Rc;
use std::task::Poll;

/// Splits `future` into a type-erased [`Task`] for an executor to drive, and a [`JoinHandle`] for observing it.
///
/// The two halves are linked in two ways:
/// - a oneshot channel carries the future's output from the task to the handle;
/// - a shared abort flag is raised by [`JoinHandle::abort`] and checked by the task whenever it is polled.
pub fn create_task<F>(future: F) -> (Task, JoinHandle<F::Output>)
where
    F: Future + 'static,
{
    let abort_flag = Rc::new(Cell::new(false));
    let (sender, receiver) = oneshot::channel::<F::Output>();

    let handle_state = JoinHandleInner::Pending {
        output_rx: Box::pin(receiver),
        abort: Rc::clone(&abort_flag),
    };

    let task = Task::from(GenericTask {
        future: Box::pin(future),
        output_tx: Some(sender),
        abort: abort_flag,
    });

    (task, JoinHandle(RefCell::new(handle_state)))
}

/// A type-erased, heap-pinned future that produces no output of its own.
///
/// Tasks built by [`create_task`] deliver the wrapped future's real output through the paired [`JoinHandle`].
/// `Task` dereferences to the pinned future, so an executor can poll it directly.
pub struct Task(Pin<Box<dyn Future<Output = ()>>>);

impl Deref for Task {
    type Target = Pin<Box<dyn Future<Output = ()>>>;

    fn deref(&self) -> &Self::Target {
        let Task(pinned) = self;
        pinned
    }
}

impl DerefMut for Task {
    fn deref_mut(&mut self) -> &mut Self::Target {
        let Task(pinned) = self;
        pinned
    }
}

#[cfg(test)]
impl Future for Task {
    type Output = ();

    /// Test-only convenience: forwards to the wrapped future.
    ///
    /// This lets a `Task` be passed straight to executors that expect a `Future`, such as `tokio::task::spawn_local`.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        // `Task` only holds a `Pin<Box<_>>`, which is `Unpin`, so the outer pin can be unwrapped safely.
        let this = Pin::into_inner(self);
        this.0.as_mut().poll(cx)
    }
}

impl<F> From<GenericTask<F>> for Task
where
    F: Future + 'static,
{
    fn from(generic_task: GenericTask<F>) -> Self {
        Self(Box::pin(generic_task))
    }
}

/// Adapter that drives a user future and forwards its output through a oneshot channel, unless it is aborted first.
struct GenericTask<F: Future + 'static> {
    /// The wrapped user future, boxed so it stays pinned wherever the adapter itself is moved.
    future: Pin<Box<F>>,
    /// Sending half of the output channel.
    ///
    /// It is wrapped in an `Option` only so that `Future::poll` can take ownership of the `Sender` once the future
    /// resolves, because sending consumes it.
    output_tx: Option<Sender<F::Output>>,
    /// Abort flag shared with the matching `JoinHandle`.
    ///
    /// Once the flag is set, polling resolves immediately without touching `future`.
    abort: Rc<Cell<bool>>,
}

impl<F> Future for GenericTask<F>
where
    F: Future + 'static,
{
    type Output = ();

    /// Polls the wrapped future and, once it completes, sends its output towards the `JoinHandle`.
    ///
    /// If the abort flag is set, this resolves immediately without polling the wrapped future. The flag is only
    /// observed when the executor polls this task, so an abort takes effect on the next poll.
    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        if self.abort.get() {
            return Poll::Ready(());
        }

        if let Poll::Ready(output) = self.future.as_mut().poll(cx) {
            let sender = self.output_tx.take().unwrap();
            // Sending can fail if the receiving side is gone (for example, a dropped, detached handle).
            // Nobody is waiting for the output in that case, so the error is deliberately ignored.
            let _ = sender.send(output);
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}

/// An owned permission to join on a task (await its termination).
///
/// This is the [`crate::LocalSpawnPool`] counterpart of [`std::thread::JoinHandle`]: it refers to a task rather than a
/// thread. You do not need to `.await` the [`JoinHandle`] for the task to make progress. The task is driven by
/// whatever executes it, independently of the handle.
///
/// Awaiting a `JoinHandle<T>` yields an `Option<T>`, where `T` is the output of the spawned future:
/// - `Some(output)` if the task ran to completion;
/// - `None` if the task was aborted, or was dropped before it completed;
/// - `None` on any await after the output has already been handed out.
///
/// Dropping a [`JoinHandle`] detaches the associated task. The task keeps running in the background, but its output is
/// lost and there is no longer any way to `join` on it.
///
/// This `struct` is created by the [`crate::LocalSpawnPool::spawn`] and [`crate::spawn`] functions.
///
/// ## Cancel safety
///
/// `JoinHandle<T>` is cancel safe. Suppose it is used as a branch of a `tokio::select!` statement and another branch
/// completes first. The task's output is not lost and can still be obtained by awaiting the handle again.
pub struct JoinHandle<T>(RefCell<JoinHandleInner<T>>);

/// State machine behind a [`JoinHandle`].
enum JoinHandleInner<T> {
    /// The task has not yet been observed to finish.
    Pending {
        /// Receiving half of the channel on which the task delivers its output.
        output_rx: Pin<Box<Receiver<T>>>,
        /// Abort flag shared with the task; raising it tells the task to resolve without output.
        abort: Rc<Cell<bool>>,
    },
    /// The task completed and its output was received.
    ///
    /// The payload is `Some(output)` while the output is waiting to be handed out by `Future::poll`. It becomes `None`
    /// once the output has been taken, after which awaiting the handle yields `None`.
    Finished(Option<T>),
    /// The handle was aborted, or the output channel closed before any value was sent.
    Aborted,
}

impl<T> JoinHandle<T> {
    /// Synchronizes the handle's state with the task, without blocking and without registering a waker.
    ///
    /// Only a `Pending` handle is affected. The output channel is checked once:
    /// - a received value moves the handle to `Finished(Some(value))`;
    /// - an empty channel leaves it `Pending`;
    /// - a closed channel (the sender was dropped without sending) moves it to `Aborted`.
    ///
    /// In method-call syntax on `&self`, this inherent method is chosen over `Future::poll`.
    fn poll(&self) {
        let mut inner = self.0.borrow_mut();

        if let JoinHandleInner::Pending { output_rx, .. } = &mut *inner {
            match output_rx.try_recv() {
                Ok(Some(value)) => *inner = JoinHandleInner::Finished(Some(value)),
                Ok(None) => { /* Still pending */ }
                Err(_) => *inner = JoinHandleInner::Aborted,
            }
        }
    }

    /// Aborts the task associated with the handle.
    ///
    /// This does nothing if the task has already been aborted. It also does nothing if the task has already finished,
    /// even when its output has not been awaited yet; that output stays available.
    ///
    /// Otherwise it raises the shared abort flag and switches the handle to the aborted state, so awaiting the handle
    /// yields `None`. The task sees the flag the next time it is polled and resolves without producing output.
    pub fn abort(&self) {
        // First pick up any output that is already sitting in the channel. Without this step, a task that completed
        // before its handle was ever polled would be reported as aborted, and its output would be silently discarded.
        self.poll();

        let mut inner = self.0.borrow_mut();

        if let JoinHandleInner::Pending { abort, .. } = &*inner {
            abort.set(true);
            *inner = JoinHandleInner::Aborted;
        }
    }

    /// Returns `true` if the task has finished executing. If the task was aborted before finishing execution, it returns `false`.
    ///
    /// This keeps returning `true` after the output has been taken by awaiting the handle.
    pub fn is_finished(&self) -> bool {
        self.poll();
        matches!(&*self.0.borrow(), JoinHandleInner::Finished(_))
    }

    /// Returns `true` if the task has been aborted. If [`JoinHandle::abort`] was called after the task finished executing, it still
    /// returns `false`.
    pub fn is_aborted(&self) -> bool {
        self.poll();
        matches!(&*self.0.borrow(), JoinHandleInner::Aborted)
    }
}

impl<T> Future for JoinHandle<T> {
    type Output = Option<T>;

    /// Resolves to `Some(output)` the first time the task's output is available.
    ///
    /// It resolves to `None` if the task was aborted, if its output channel closed without a value, or if the output
    /// has already been taken. A `Pending` result leaves the state untouched, which keeps the handle cancel safe.
    fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Self::Output> {
        let mut state = self.0.borrow_mut();

        let received = match &mut *state {
            JoinHandleInner::Finished(slot) => return Poll::Ready(slot.take()),
            JoinHandleInner::Aborted => return Poll::Ready(None),
            JoinHandleInner::Pending { output_rx, .. } => match output_rx.as_mut().poll(cx) {
                Poll::Pending => return Poll::Pending,
                // `Err` means the sender went away without sending, which is reported as an abort.
                Poll::Ready(result) => result.ok(),
            },
        };

        // The output (if any) is handed out right now, so a `Finished` handle keeps an empty slot.
        *state = if received.is_some() {
            JoinHandleInner::Finished(None)
        } else {
            JoinHandleInner::Aborted
        };

        Poll::Ready(received)
    }
}

// End-to-end check of `create_task` + `JoinHandle` on a tokio `LocalSet`.
// Covered: normal completion, abort after completion, abort before completion, cancellation from the tokio side, and
// detaching by dropping the handle. The scenarios rely on real timers: tasks sleep 50 ms while checks wait 100 ms, or
// the task sleeps 500 ms while the check waits 100 ms, which leaves some slack between "still running" and "finished".
#[cfg(test)]
#[tokio::test]
async fn test() {
    use std::time::Duration;
    use tokio::task::LocalSet;
    use tokio::time;

    let local_set = LocalSet::new();

    local_set
        .run_until(async {
            // Scenario 1: the handle is queried while the task is still sleeping, then awaited until the output arrives.
            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(50)).await;
                "test"
            });
            tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            assert_eq!(join_handle.await, Some("test"));

            // Scenario 2: the task has finished before the handle is queried, so `abort()` is a no-op.
            // The output is still delivered.

            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(50)).await;
                "test"
            });
            tokio::task::spawn_local(task);
            time::sleep(Duration::from_millis(100)).await;
            assert!(join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            join_handle.abort();
            assert!(join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            assert_eq!(join_handle.await, Some("test"));

            // Scenario 3: aborting while the task is still sleeping marks the handle aborted immediately.
            // Awaiting the handle then yields `None`.

            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(50)).await;
                "test"
            });
            tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            join_handle.abort();
            assert!(!join_handle.is_finished());
            assert!(join_handle.is_aborted());
            assert_eq!(join_handle.await, None);

            // Scenario 4: cancelling the tokio task drops the `Task`, and with it the output sender, without sending.
            // The handle reports this as an abort.

            let (task, join_handle) = create_task(async {
                time::sleep(Duration::from_millis(500)).await;
                "test"
            });
            let tokio_join_handle = tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            tokio_join_handle.abort();
            time::sleep(Duration::from_millis(100)).await;
            assert!(!join_handle.is_finished());
            assert!(join_handle.is_aborted());
            assert_eq!(join_handle.await, None);

            // Scenario 5: dropping the handle detaches the task rather than cancelling it.
            // The task still runs and its side effect on `value` becomes visible later.

            let value = Rc::new(Cell::new(0i32));
            let (task, join_handle) = create_task({
                let value = Rc::clone(&value);
                async move {
                    time::sleep(Duration::from_millis(50)).await;
                    value.set(1);
                    "test"
                }
            });
            tokio::task::spawn_local(task);
            assert!(!join_handle.is_finished());
            assert!(!join_handle.is_aborted());
            drop(join_handle);
            assert_eq!(value.get(), 0);
            time::sleep(Duration::from_millis(100)).await;
            assert_eq!(value.get(), 1);
        })
        .await;

    // Let the `LocalSet` drive any spawned tasks that are still pending to completion.
    local_set.await;
}