async-rt 0.1.10

A small library designed to use async executors through a common API while extending their features.
Documentation
use crate::{Executor, ExecutorBlocking, InnerJoinHandle, JoinHandle};
use std::future::Future;
use std::sync::Arc;
use tokio::runtime::Runtime;

/// Stateless Tokio executor.
///
/// Work is submitted through the free functions `tokio::task::spawn` and
/// `tokio::task::spawn_blocking`, so it lands on whichever Tokio runtime is current
/// at the call site.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub struct TokioExecutor;

impl Executor for TokioExecutor {
    /// Spawns `future` with `tokio::task::spawn` and wraps tokio's handle in the
    /// crate-level [`JoinHandle`].
    ///
    /// # Panics
    /// Per tokio's documentation, `tokio::task::spawn` panics when called outside of
    /// a Tokio runtime context.
    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        JoinHandle {
            inner: InnerJoinHandle::TokioHandle(tokio::task::spawn(future)),
        }
    }
}

impl ExecutorBlocking for TokioExecutor {
    /// Runs the closure `f` through `tokio::task::spawn_blocking` and wraps tokio's
    /// handle in the crate-level [`JoinHandle`].
    ///
    /// # Panics
    /// Per tokio's documentation, `tokio::task::spawn_blocking` requires a current
    /// Tokio runtime and panics without one.
    fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        JoinHandle {
            inner: InnerJoinHandle::TokioHandle(tokio::task::spawn_blocking(f)),
        }
    }
}

/// Tokio executor that owns an explicit [`Runtime`].
///
/// Clones share one runtime through an [`Arc`], so cloning is cheap.
///
/// NOTE(review): the runtime is dropped together with the last clone. Tokio documents
/// that dropping a `Runtime` from inside an asynchronous context panics, so confirm the
/// final clone is never dropped from within a task.
#[derive(Debug, Clone)]
pub struct TokioRuntimeExecutor {
    /// Shared runtime that `spawn` / `spawn_blocking` submit work to.
    runtime: Arc<Runtime>,
}

impl TokioRuntimeExecutor {
    /// Builds an executor around a fresh runtime that uses tokio's current-thread
    /// scheduler, with all drivers switched on (`enable_all`).
    ///
    /// NOTE(review): tokio only polls tasks on a current-thread runtime while some
    /// thread is inside `Runtime::block_on`, and this type does not expose that.
    /// Confirm how work spawned through this executor is expected to make progress.
    ///
    /// # Errors
    /// Propagates the I/O error returned by tokio's runtime builder.
    pub fn with_single_thread() -> std::io::Result<Self> {
        tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map(Self::with_runtime)
    }

    /// Builds an executor around a fresh runtime that uses tokio's multi-thread
    /// scheduler, with all drivers switched on (`enable_all`).
    ///
    /// # Errors
    /// Propagates the I/O error returned by tokio's runtime builder.
    pub fn with_multi_thread() -> std::io::Result<Self> {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .map(Self::with_runtime)
    }

    /// Wraps an already-built [`Runtime`]. The executor takes ownership of it and
    /// shares it between clones.
    pub fn with_runtime(runtime: Runtime) -> Self {
        Self {
            runtime: Arc::new(runtime),
        }
    }
}

impl Executor for TokioRuntimeExecutor {
    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        let handle = self.runtime.spawn(future);
        let inner = InnerJoinHandle::TokioHandle(handle);
        JoinHandle { inner }
    }
}

impl ExecutorBlocking for TokioRuntimeExecutor {
    fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let handle = self.runtime.spawn_blocking(f);
        let inner = InnerJoinHandle::TokioHandle(handle);
        JoinHandle { inner }
    }
}

#[cfg(test)]
mod tests {
    use super::{TokioExecutor, TokioRuntimeExecutor};
    use crate::{Executor, ExecutorBlocking};
    use futures::channel::mpsc::{Receiver, UnboundedReceiver};

    #[tokio::test]
    async fn default_abortable_task() {
        let executor = TokioExecutor;

        // If the task were not aborted, it would send after the delay (then hit
        // `unreachable!`). Dropping the handle must abort it, so `tx` is dropped unsent
        // and `rx` resolves to an error.
        async fn task(tx: futures::channel::oneshot::Sender<()>) {
            futures_timer::Delay::new(std::time::Duration::from_secs(5)).await;
            let _ = tx.send(());
            unreachable!();
        }

        let (tx, rx) = futures::channel::oneshot::channel::<()>();

        let handle = executor.spawn_abortable(task(tx));

        drop(handle);
        let result = rx.await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn task_coroutine() {
        use futures::stream::StreamExt;
        let executor = TokioExecutor;

        enum Message {
            Send(String, futures::channel::oneshot::Sender<String>),
        }

        let mut task = executor.spawn_coroutine(|mut rx: Receiver<Message>| async move {
            while let Some(msg) = rx.next().await {
                match msg {
                    Message::Send(msg, sender) => {
                        sender.send(msg).unwrap();
                    }
                }
            }
        });

        let (tx, rx) = futures::channel::oneshot::channel::<String>();
        let msg = Message::Send("Hello".into(), tx);

        task.send(msg).await.unwrap();
        let resp = rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    #[tokio::test]
    async fn task_coroutine_with_context() {
        use futures::stream::StreamExt;
        let executor = TokioExecutor;

        #[derive(Default)]
        struct State {
            message: String,
        }

        enum Message {
            Set(String),
            Get(futures::channel::oneshot::Sender<String>),
        }

        let mut task = executor.spawn_coroutine_with_context(
            State::default(),
            |mut state, mut rx: Receiver<Message>| async move {
                while let Some(msg) = rx.next().await {
                    match msg {
                        Message::Set(msg) => {
                            state.message = msg;
                        }
                        Message::Get(resp) => {
                            resp.send(state.message.clone()).unwrap();
                        }
                    }
                }
            },
        );

        let msg = Message::Set("Hello".into());

        task.send(msg).await.unwrap();
        let (tx, rx) = futures::channel::oneshot::channel::<String>();
        let msg = Message::Get(tx);
        task.send(msg).await.unwrap();
        let resp = rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    #[tokio::test]
    async fn task_unbounded_coroutine() {
        use futures::stream::StreamExt;
        let executor = TokioExecutor;

        enum Message {
            Send(String, futures::channel::oneshot::Sender<String>),
        }

        let mut task =
            executor.spawn_unbounded_coroutine(|mut rx: UnboundedReceiver<Message>| async move {
                while let Some(msg) = rx.next().await {
                    match msg {
                        Message::Send(msg, sender) => {
                            sender.send(msg).unwrap();
                        }
                    }
                }
            });

        let (tx, rx) = futures::channel::oneshot::channel::<String>();
        let msg = Message::Send("Hello".into(), tx);

        task.send(msg).unwrap();
        let resp = rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    #[tokio::test]
    async fn task_unbounded_coroutine_with_context() {
        use futures::stream::StreamExt;
        let executor = TokioExecutor;

        #[derive(Default)]
        struct State {
            message: String,
        }

        enum Message {
            Set(String),
            Get(futures::channel::oneshot::Sender<String>),
        }

        let mut task = executor.spawn_unbounded_coroutine_with_context(
            State::default(),
            |mut state, mut rx: UnboundedReceiver<Message>| async move {
                while let Some(msg) = rx.next().await {
                    match msg {
                        Message::Set(msg) => {
                            state.message = msg;
                        }
                        Message::Get(resp) => {
                            resp.send(state.message.clone()).unwrap();
                        }
                    }
                }
            },
        );

        let msg = Message::Set("Hello".into());

        task.send(msg).unwrap();
        let (tx, rx) = futures::channel::oneshot::channel::<String>();
        let msg = Message::Get(tx);
        task.send(msg).unwrap();
        let resp = rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    #[tokio::test]
    async fn blocking_task() {
        let executor = TokioExecutor;

        let task = executor.spawn_blocking(|| {
            std::thread::sleep(std::time::Duration::from_millis(100));
            "Hello"
        });
        let resp = task.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    // The `TokioRuntimeExecutor` tests are plain `#[test]`s rather than
    // `#[tokio::test]`s: the executor owns its runtime, and the runtime must be
    // dropped outside of any async context. Results come back over a std channel
    // with a timeout, so a task that never runs fails the test instead of hanging it.

    #[test]
    fn runtime_executor_spawn() {
        let executor = TokioRuntimeExecutor::with_multi_thread().unwrap();
        let (tx, rx) = std::sync::mpsc::channel();

        let _task = executor.spawn(async move {
            tx.send("Hello").unwrap();
        });

        let resp = rx
            .recv_timeout(std::time::Duration::from_secs(5))
            .unwrap();
        assert_eq!(resp, "Hello");
    }

    #[test]
    fn runtime_executor_blocking_task() {
        let executor = TokioRuntimeExecutor::with_multi_thread().unwrap();
        let (tx, rx) = std::sync::mpsc::channel();

        let _task = executor.spawn_blocking(move || {
            tx.send("Hello").unwrap();
        });

        let resp = rx
            .recv_timeout(std::time::Duration::from_secs(5))
            .unwrap();
        assert_eq!(resp, "Hello");
    }
}