graphile_worker_runtime 0.1.3

Async runtime compatibility package for graphile_worker
Documentation
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

use futures::future::{Abortable, Aborted};
use thiserror::Error;

#[derive(Debug, Error)]
pub enum JoinError {
    #[error("task was aborted")]
    Aborted,
    #[error("task failed: {0}")]
    Failed(String),
}

#[derive(Clone)]
pub struct AbortHandle(futures::future::AbortHandle);

impl AbortHandle {
    pub fn abort(&self) {
        self.0.abort();
    }
}

pub struct JoinHandle<T> {
    abort_handle: AbortHandle,
    inner: JoinHandleInner<T>,
}

enum JoinHandleInner<T> {
    #[cfg(feature = "runtime-tokio")]
    Tokio(tokio::task::JoinHandle<Result<T, Aborted>>),
    #[cfg(feature = "runtime-async-std")]
    AsyncStd(async_std::task::JoinHandle<Result<T, JoinError>>),
}

impl<T> JoinHandle<T> {
    pub fn abort_handle(&self) -> AbortHandle {
        self.abort_handle.clone()
    }
}

impl<T: Send + 'static> Future for JoinHandle<T> {
    type Output = Result<T, JoinError>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        match &mut self.inner {
            #[cfg(feature = "runtime-tokio")]
            JoinHandleInner::Tokio(handle) => {
                Pin::new(handle).poll(cx).map(|result| match result {
                    Ok(Ok(value)) => Ok(value),
                    Ok(Err(_)) => Err(JoinError::Aborted),
                    Err(error) => Err(JoinError::Failed(error.to_string())),
                })
            }
            #[cfg(feature = "runtime-async-std")]
            JoinHandleInner::AsyncStd(handle) => Pin::new(handle).poll(cx),
        }
    }
}

pub fn spawn<F>(future: F) -> JoinHandle<F::Output>
where
    F: Future + Send + 'static,
    F::Output: Send + 'static,
{
    let (abort_handle, abort_registration) = futures::future::AbortHandle::new_pair();
    let future = Abortable::new(future, abort_registration);
    let abort_handle = AbortHandle(abort_handle);

    #[cfg(all(feature = "runtime-tokio", not(feature = "runtime-async-std")))]
    {
        let handle =
            tokio::runtime::Handle::try_current().unwrap_or_else(|_| crate::missing_runtime());
        JoinHandle {
            abort_handle,
            inner: JoinHandleInner::Tokio(handle.spawn(future)),
        }
    }

    #[cfg(all(feature = "runtime-tokio", feature = "runtime-async-std"))]
    if let Ok(handle) = tokio::runtime::Handle::try_current() {
        return JoinHandle {
            abort_handle,
            inner: JoinHandleInner::Tokio(handle.spawn(future)),
        };
    }

    #[cfg(feature = "runtime-async-std")]
    {
        async_std_spawn(future, abort_handle)
    }

    #[cfg(not(any(feature = "runtime-tokio", feature = "runtime-async-std")))]
    {
        crate::missing_runtime()
    }
}

#[cfg(feature = "runtime-async-std")]
fn async_std_spawn<F, T>(future: F, abort_handle: AbortHandle) -> JoinHandle<T>
where
    F: Future<Output = Result<T, Aborted>> + Send + 'static,
    T: Send + 'static,
{
    let future = async move {
        match futures::FutureExt::catch_unwind(std::panic::AssertUnwindSafe(future)).await {
            Ok(Ok(value)) => Ok(value),
            Ok(Err(_)) => Err(JoinError::Aborted),
            Err(payload) => Err(JoinError::Failed(panic_payload_to_string(payload))),
        }
    };

    JoinHandle {
        abort_handle,
        inner: JoinHandleInner::AsyncStd(async_std::task::spawn(future)),
    }
}

#[cfg(feature = "runtime-async-std")]
fn panic_payload_to_string(payload: Box<dyn std::any::Any + Send>) -> String {
    if let Some(message) = payload.downcast_ref::<&str>() {
        return (*message).to_string();
    }

    if let Some(message) = payload.downcast_ref::<String>() {
        return message.clone();
    }

    "task panicked".to_string()
}