uasync 0.1.1

fast, safe, async executor
Documentation
use super::{scheduler::Scheduler, thread::Thread, waker::AtomicWaker};
use std::{
    any::Any,
    fmt,
    future::Future,
    io,
    mem::{drop, replace},
    panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
    pin::Pin,
    sync::atomic::{AtomicBool, AtomicU8, Ordering},
    sync::{Arc, Mutex},
    task::{Context, Poll, Wake, Waker},
    thread,
};

pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let future = Box::pin(future);
    let thread = Thread::current();
    Task::spawn(&thread.scheduler, Some(&thread), future)
}

pub struct JoinHandle<T> {
    joinable: Option<Arc<dyn Joinable<T>>>,
}

impl<T> JoinHandle<T> {
    pub fn abort(&self) {
        if let Some(joinable) = self.joinable.as_ref() {
            if joinable.poll_abort() {
                joinable.clone().abort();
            }
        }
    }
}

impl<T: fmt::Debug> fmt::Debug for JoinHandle<T> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("JoinHandle").finish()
    }
}

impl<T> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;

    fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
        let joinable = self
            .joinable
            .take()
            .expect("JoinHandle polled after completion");

        if let Poll::Ready(result) = joinable.poll_join(ctx) {
            return Poll::Ready(result);
        }

        self.joinable = Some(joinable);
        Poll::Pending
    }
}

pub struct JoinError {
    panic: Option<Box<dyn Any + Send + 'static>>,
}

impl std::error::Error for JoinError {}

impl From<JoinError> for io::Error {
    fn from(this: JoinError) -> io::Error {
        io::Error::new(
            io::ErrorKind::Other,
            match &this.panic {
                Some(_) => "task panicked",
                None => "task was cancelled",
            },
        )
    }
}

impl fmt::Display for JoinError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.panic {
            Some(_) => write!(f, "panic"),
            None => write!(f, "cancelled"),
        }
    }
}

impl fmt::Debug for JoinError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match &self.panic {
            Some(_) => write!(f, "JoinError::Panic(_)"),
            None => write!(f, "JoinError::Cancelled"),
        }
    }
}

impl JoinError {
    pub fn is_cancelled(&self) -> bool {
        self.panic.is_none()
    }

    pub fn is_panic(&self) -> bool {
        self.panic.is_some()
    }

    pub fn try_into_panic(self) -> Result<Box<dyn Any + Send + 'static>, JoinError> {
        match self.panic {
            Some(panic) => Ok(panic),
            None => Err(self),
        }
    }

    pub fn into_panic(self) -> Box<dyn Any + Send + 'static> {
        self.try_into_panic().unwrap()
    }
}

pub async fn yield_now() {
    struct YieldFuture {
        yielded: bool,
    }

    impl Future for YieldFuture {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, ctx: &mut Context<'_>) -> Poll<Self::Output> {
            if self.yielded {
                return Poll::Ready(());
            }

            ctx.waker().wake_by_ref();
            self.yielded = true;
            Poll::Pending
        }
    }

    YieldFuture { yielded: false }.await
}

#[derive(Default)]
struct TaskState {
    state: AtomicU8,
}

impl TaskState {
    const IDLE: u8 = 0;
    const SCHEDULED: u8 = 1;
    const RUNNING: u8 = 2;
    const NOTIFIED: u8 = 3;
    const ABORTED: u8 = 4;

    fn transition_to_scheduled(&self) -> bool {
        self.state
            .fetch_update(Ordering::Release, Ordering::Relaxed, |state| match state {
                Self::RUNNING => Some(Self::NOTIFIED),
                Self::IDLE => Some(Self::SCHEDULED),
                _ => None,
            })
            .map(|state| state == Self::IDLE)
            .unwrap_or(false)
    }

    fn transition_to_running(&self) -> bool {
        match self.state.swap(Self::RUNNING, Ordering::Acquire) {
            Self::SCHEDULED => true,
            Self::ABORTED => false,
            _ => unreachable!(),
        }
    }

    fn transition_to_idle(&self) -> bool {
        self.state
            .fetch_update(Ordering::AcqRel, Ordering::Acquire, |state| match state {
                Self::RUNNING => Some(Self::IDLE),
                Self::NOTIFIED => Some(Self::SCHEDULED),
                Self::ABORTED => None,
                _ => unreachable!(),
            })
            .map(|state| state == Self::RUNNING)
            .unwrap_or(false)
    }

    fn transition_to_aborted(&self) -> bool {
        self.state.swap(Self::ABORTED, Ordering::Release) == Self::IDLE
    }
}

enum TaskData<F: Future> {
    Empty,
    Polling(Pin<Box<F>>, Waker),
    Ready(F::Output),
    Panic(Box<dyn Any + Send + 'static>),
    Aborted,
}

pub(super) struct Task<F: Future> {
    state: TaskState,
    waker: AtomicWaker,
    data: Mutex<TaskData<F>>,
    scheduler: Arc<Scheduler>,
}

impl<F: Future> Task<F> {
    pub(super) fn block_on(scheduler: &Arc<Scheduler>, mut future: Pin<Box<F>>) -> F::Output {
        struct Parker {
            thread: thread::Thread,
            unparked: AtomicBool,
        }

        impl Parker {
            fn new() -> Self {
                Self {
                    thread: thread::current(),
                    unparked: AtomicBool::new(false),
                }
            }

            fn park(&self) {
                while !self.unparked.swap(false, Ordering::Acquire) {
                    thread::park();
                }
            }
        }

        impl Wake for Parker {
            fn wake(self: Arc<Self>) {
                self.wake_by_ref()
            }

            fn wake_by_ref(self: &Arc<Self>) {
                if !self.unparked.swap(true, Ordering::Release) {
                    self.thread.unpark();
                }
            }
        }

        let result = {
            let parker = Arc::new(Parker::new());
            let waker = Waker::from(parker.clone());
            let mut ctx = Context::from_waker(&waker);

            let thread = Thread::enter(scheduler, None);
            if !Arc::ptr_eq(&thread.scheduler, scheduler) {
                unreachable!("nested block_on() is not supported");
            }

            scheduler.on_task_begin();
            catch_unwind(AssertUnwindSafe(|| loop {
                match future.as_mut().poll(&mut ctx) {
                    Poll::Ready(output) => break output,
                    Poll::Pending => parker.park(),
                }
            }))
        };

        scheduler.on_task_complete();
        match result {
            Ok(output) => output,
            Err(error) => resume_unwind(error),
        }
    }
}

impl<F> Task<F>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    pub(super) fn spawn(
        scheduler: &Arc<Scheduler>,
        thread: Option<&Thread>,
        future: Pin<Box<F>>,
    ) -> JoinHandle<F::Output> {
        let task = Arc::new(Task {
            state: TaskState::default(),
            waker: AtomicWaker::default(),
            data: Mutex::new(TaskData::Empty),
            scheduler: scheduler.clone(),
        });

        let waker = Waker::from(task.clone());
        *task.data.try_lock().unwrap() = TaskData::Polling(future, waker);

        scheduler.on_task_begin();
        assert!(task.state.transition_to_scheduled());
        scheduler.schedule(task.clone(), thread, false);

        JoinHandle {
            joinable: Some(task),
        }
    }
}

impl<F> Wake for Task<F>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    fn wake(self: Arc<Self>) {
        if self.state.transition_to_scheduled() {
            match Thread::try_current().ok().flatten() {
                Some(thread) => thread.scheduler.schedule(self, Some(&thread), false),
                None => self.scheduler.schedule(self.clone(), None, false),
            }
        }
    }

    fn wake_by_ref(self: &Arc<Self>) {
        if self.state.transition_to_scheduled() {
            let task = self.clone();
            match Thread::try_current().ok().flatten() {
                Some(thread) => thread.scheduler.schedule(task, Some(&thread), false),
                None => self.scheduler.schedule(task, None, false),
            }
        }
    }
}

pub(super) trait Runnable: Send + Sync {
    fn run(self: Arc<Self>, thread: &Thread);
}

impl<F> Runnable for Task<F>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    fn run(self: Arc<Self>, thread: &Thread) {
        let mut data = self.data.try_lock().unwrap();
        *data = match self.state.transition_to_running() {
            false => TaskData::Aborted,
            _ => {
                let poll_result = match &mut *data {
                    TaskData::Polling(future, waker) => {
                        let future = future.as_mut();
                        let mut ctx = Context::from_waker(waker);
                        catch_unwind(AssertUnwindSafe(|| future.poll(&mut ctx)))
                    }
                    _ => unreachable!(),
                };

                match poll_result {
                    Err(error) => TaskData::Panic(error),
                    Ok(Poll::Ready(output)) => TaskData::Ready(output),
                    Ok(Poll::Pending) => {
                        drop(data);
                        if self.state.transition_to_idle() {
                            return;
                        }

                        thread.scheduler.schedule(self, Some(thread), true);
                        return;
                    }
                }
            }
        };

        drop(data);
        self.waker.wake();
        thread.scheduler.on_task_complete();
    }
}

trait Joinable<T>: Send + Sync {
    fn poll_join(&self, ctx: &mut Context<'_>) -> Poll<Result<T, JoinError>>;
    fn poll_abort(&self) -> bool;
    fn abort(self: Arc<Self>);
}

impl<F> Joinable<F::Output> for Task<F>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    fn poll_join(&self, ctx: &mut Context<'_>) -> Poll<Result<F::Output, JoinError>> {
        if self.waker.poll(ctx).is_pending() {
            return Poll::Pending;
        }

        match replace(&mut *self.data.try_lock().unwrap(), TaskData::Empty) {
            TaskData::Ready(output) => Poll::Ready(Ok(output)),
            TaskData::Aborted => Poll::Ready(Err(JoinError { panic: None })),
            TaskData::Panic(error) => Poll::Ready(Err(JoinError { panic: Some(error) })),
            _ => unreachable!(),
        }
    }

    fn poll_abort(&self) -> bool {
        self.state.transition_to_aborted()
    }

    fn abort(self: Arc<Self>) {
        self.wake()
    }
}