use crate::{Executor, ExecutorBlocking, InnerJoinHandle, JoinHandle};
use std::future::Future;
use std::sync::Arc;
use tokio::runtime::Runtime;
/// Zero-sized executor that hands work to tokio through the free functions
/// [`tokio::task::spawn`] and [`tokio::task::spawn_blocking`].
///
/// It carries no runtime of its own, so it relies on whatever tokio runtime
/// context is active at the call site.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd)]
pub struct TokioExecutor;
impl Executor for TokioExecutor {
    /// Spawns `future` as a tokio task and wraps the resulting tokio handle
    /// in the crate's [`JoinHandle`].
    ///
    /// # Panics
    ///
    /// tokio's `spawn` panics when called outside of a tokio runtime context,
    /// so this must be invoked from within one.
    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        JoinHandle {
            inner: InnerJoinHandle::TokioHandle(tokio::task::spawn(future)),
        }
    }
}
impl ExecutorBlocking for TokioExecutor {
    /// Runs the closure `f` through [`tokio::task::spawn_blocking`] and wraps
    /// the resulting tokio handle in the crate's [`JoinHandle`].
    ///
    /// # Panics
    ///
    /// tokio's `spawn_blocking` panics when called outside of a tokio runtime
    /// context, so this must be invoked from within one.
    fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        JoinHandle {
            inner: InnerJoinHandle::TokioHandle(tokio::task::spawn_blocking(f)),
        }
    }
}
/// Executor that spawns work onto a dedicated tokio [`Runtime`].
///
/// The runtime is stored behind an [`Arc`], so cloning the executor yields
/// another handle to the same runtime rather than a new runtime.
#[derive(Debug, Clone)]
pub struct TokioRuntimeExecutor {
    /// Runtime that every spawned future and blocking closure is submitted to.
    runtime: Arc<Runtime>,
}
impl TokioRuntimeExecutor {
    /// Creates an executor backed by a runtime with exactly one worker thread.
    ///
    /// The runtime uses the multi-thread scheduler restricted to a single
    /// worker instead of the current-thread scheduler. A current-thread
    /// runtime only polls spawned tasks while some thread is inside
    /// `Runtime::block_on`; nothing in this impl drives the runtime that way,
    /// so futures handed to `spawn` would never make progress on it.
    ///
    /// # Errors
    ///
    /// Returns the [`std::io::Error`] produced by the tokio runtime builder
    /// if the runtime cannot be constructed.
    pub fn with_single_thread() -> std::io::Result<Self> {
        tokio::runtime::Builder::new_multi_thread()
            .worker_threads(1)
            .enable_all()
            .build()
            .map(Self::with_runtime)
    }

    /// Creates an executor backed by a multi-threaded runtime using tokio's
    /// default worker count, with all drivers enabled.
    ///
    /// # Errors
    ///
    /// Returns the [`std::io::Error`] produced by the tokio runtime builder
    /// if the runtime cannot be constructed.
    pub fn with_multi_thread() -> std::io::Result<Self> {
        tokio::runtime::Builder::new_multi_thread()
            .enable_all()
            .build()
            .map(Self::with_runtime)
    }

    /// Wraps an already-built [`Runtime`], taking ownership of it.
    ///
    /// NOTE: if `runtime` uses the current-thread scheduler, spawned futures
    /// are only polled while the runtime is being driven via `block_on`; this
    /// impl does not do that, so prefer a runtime with at least one worker
    /// thread.
    pub fn with_runtime(runtime: Runtime) -> Self {
        Self {
            runtime: Arc::new(runtime),
        }
    }
}
impl Executor for TokioRuntimeExecutor {
    /// Spawns `future` onto the executor's runtime and wraps the resulting
    /// tokio handle in the crate's [`JoinHandle`].
    ///
    /// The task is submitted through the stored runtime rather than the
    /// ambient tokio context.
    fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
    where
        F: Future + Send + 'static,
        F::Output: Send + 'static,
    {
        JoinHandle {
            inner: InnerJoinHandle::TokioHandle(self.runtime.spawn(future)),
        }
    }
}
impl ExecutorBlocking for TokioRuntimeExecutor {
    /// Runs the closure `f` via the stored runtime's `spawn_blocking` and
    /// wraps the resulting tokio handle in the crate's [`JoinHandle`].
    fn spawn_blocking<F, R>(&self, f: F) -> JoinHandle<R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        JoinHandle {
            inner: InnerJoinHandle::TokioHandle(self.runtime.spawn_blocking(f)),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::TokioExecutor;
    use crate::{Executor, ExecutorBlocking};
    use futures::channel::mpsc::{Receiver, UnboundedReceiver};
    use futures::channel::oneshot;
    use futures::stream::StreamExt;
    use std::time::Duration;

    /// Request understood by the echo coroutines: reply with the same text.
    enum EchoMsg {
        Send(String, oneshot::Sender<String>),
    }

    /// State owned by the "with context" coroutines.
    #[derive(Default)]
    struct State {
        message: String,
    }

    /// Requests understood by the stateful coroutines.
    enum StateMsg {
        Set(String),
        Get(oneshot::Sender<String>),
    }

    /// After the abortable handle is dropped, the task's sender is expected to
    /// be dropped without ever sending, so the receiver resolves to an error.
    #[tokio::test]
    async fn default_abortable_task() {
        async fn task(tx: oneshot::Sender<()>) {
            futures_timer::Delay::new(Duration::from_secs(5)).await;
            let _ = tx.send(());
            unreachable!();
        }

        let executor = TokioExecutor;
        let (done_tx, done_rx) = oneshot::channel::<()>();

        let handle = executor.spawn_abortable(task(done_tx));
        drop(handle);

        assert!(done_rx.await.is_err());
    }

    /// A bounded coroutine echoes back the text it receives.
    #[tokio::test]
    async fn task_coroutine() {
        let executor = TokioExecutor;

        let mut task = executor.spawn_coroutine(|mut rx: Receiver<EchoMsg>| async move {
            while let Some(request) = rx.next().await {
                let EchoMsg::Send(text, reply) = request;
                reply.send(text).unwrap();
            }
        });

        let (reply_tx, reply_rx) = oneshot::channel::<String>();
        task.send(EchoMsg::Send("Hello".into(), reply_tx))
            .await
            .unwrap();

        let resp = reply_rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    /// A bounded coroutine with context stores a value and returns it on request.
    #[tokio::test]
    async fn task_coroutine_with_context() {
        let executor = TokioExecutor;

        let mut task = executor.spawn_coroutine_with_context(
            State::default(),
            |mut state, mut rx: Receiver<StateMsg>| async move {
                while let Some(request) = rx.next().await {
                    match request {
                        StateMsg::Set(text) => {
                            state.message = text;
                        }
                        StateMsg::Get(reply) => {
                            reply.send(state.message.clone()).unwrap();
                        }
                    }
                }
            },
        );

        task.send(StateMsg::Set("Hello".into())).await.unwrap();

        let (reply_tx, reply_rx) = oneshot::channel::<String>();
        task.send(StateMsg::Get(reply_tx)).await.unwrap();

        let resp = reply_rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    /// An unbounded coroutine echoes back the text it receives.
    #[tokio::test]
    async fn task_unbounded_coroutine() {
        let executor = TokioExecutor;

        let mut task = executor.spawn_unbounded_coroutine(
            |mut rx: UnboundedReceiver<EchoMsg>| async move {
                while let Some(request) = rx.next().await {
                    let EchoMsg::Send(text, reply) = request;
                    reply.send(text).unwrap();
                }
            },
        );

        let (reply_tx, reply_rx) = oneshot::channel::<String>();
        task.send(EchoMsg::Send("Hello".into(), reply_tx)).unwrap();

        let resp = reply_rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    /// An unbounded coroutine with context stores a value and returns it on request.
    #[tokio::test]
    async fn task_unbounded_coroutine_with_context() {
        let executor = TokioExecutor;

        let mut task = executor.spawn_unbounded_coroutine_with_context(
            State::default(),
            |mut state, mut rx: UnboundedReceiver<StateMsg>| async move {
                while let Some(request) = rx.next().await {
                    match request {
                        StateMsg::Set(text) => {
                            state.message = text;
                        }
                        StateMsg::Get(reply) => {
                            reply.send(state.message.clone()).unwrap();
                        }
                    }
                }
            },
        );

        task.send(StateMsg::Set("Hello".into())).unwrap();

        let (reply_tx, reply_rx) = oneshot::channel::<String>();
        task.send(StateMsg::Get(reply_tx)).unwrap();

        let resp = reply_rx.await.unwrap();
        assert_eq!(resp, "Hello");
    }

    /// A blocking closure's return value is delivered through the join handle.
    #[tokio::test]
    async fn blocking_task() {
        let executor = TokioExecutor;

        let task = executor.spawn_blocking(|| {
            // Sleeping the thread is fine here: the closure is submitted via
            // `spawn_blocking`, not run as an async task.
            std::thread::sleep(Duration::from_millis(100));
            "Hello"
        });

        let resp = task.await.unwrap();
        assert_eq!(resp, "Hello");
    }
}