use futures::{future::BoxFuture, Future};
use crate::DeltaResult;
/// Abstraction over an async runtime that can drive futures and run blocking work.
///
/// Implementations must be shareable across threads (`Send + Sync + 'static`).
pub trait TaskExecutor: Send + Sync + 'static {
/// Runs `task` to completion and returns its output, blocking the calling thread until the
/// output is available.
///
/// NOTE(review): the Tokio implementations in this module park the caller on a std channel
/// while the runtime drives `task`, so avoid calling this from a thread the same runtime
/// needs in order to make progress.
fn block_on<T>(&self, task: T) -> T::Output
where
T: Future + Send + 'static,
T::Output: Send + 'static;
/// Starts `task` running in the background without waiting for it to finish.
fn spawn<F>(&self, task: F)
where
F: Future<Output = ()> + Send + 'static;
/// Runs the synchronous closure `task` somewhere blocking is permitted.
///
/// Returns a future that resolves to the closure's return value, or to an error if the
/// blocking task could not be joined. The Tokio implementations in this module map that
/// failure through `crate::Error::join_failure`.
fn spawn_blocking<T, R>(&self, task: T) -> BoxFuture<'_, DeltaResult<R>>
where
T: FnOnce() -> R + Send + 'static,
R: Send + 'static;
}
#[cfg(any(feature = "tokio", test))]
pub mod tokio {
use super::TaskExecutor;
use futures::TryFutureExt;
use futures::{future::BoxFuture, Future};
use std::sync::mpsc::channel;
use tokio::runtime::RuntimeFlavor;
use crate::DeltaResult;
/// Executor that drives futures on a dedicated OS thread running its own current-thread
/// Tokio runtime, so callers do not need to be inside a Tokio runtime themselves.
#[derive(Debug)]
pub struct TokioBackgroundExecutor {
/// Bounded queue of futures; the background thread spawns each one onto its runtime.
sender: tokio::sync::mpsc::Sender<BoxFuture<'static, ()>>,
/// Handle to the background thread. It is held but never joined in this module.
/// NOTE(review): presumably kept only to make ownership of the thread explicit — confirm.
_thread: std::thread::JoinHandle<()>,
}
impl Default for TokioBackgroundExecutor {
fn default() -> Self {
Self::new()
}
}
impl TokioBackgroundExecutor {
    /// Creates an executor backed by a dedicated OS thread that drives its own
    /// current-thread Tokio runtime (with all drivers enabled).
    ///
    /// Submitted futures travel over a bounded channel (capacity 50) to that thread, which
    /// spawns each one onto the runtime. Dropping the executor drops the only sender. The
    /// thread's receive loop then ends, and the runtime is dropped when the thread finishes.
    ///
    /// # Panics
    ///
    /// Panics if the Tokio runtime cannot be built.
    pub fn new() -> Self {
        // Build the runtime here, on the caller's thread, so a failure is reported at
        // construction time with a clear message. Building it inside the worker thread would
        // make that thread die silently. Every later submission would then panic with an
        // unrelated "channel closed" / "has crashed" message. `Runtime` is `Send`, so it can
        // be moved into the worker thread afterwards.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("TokioBackgroundExecutor failed to build its Tokio runtime");
        let (sender, mut receiver) = tokio::sync::mpsc::channel::<BoxFuture<'static, ()>>(50);
        let thread = std::thread::spawn(move || {
            rt.block_on(async move {
                // `recv` yields `None` once the executor (the only sender) is dropped.
                while let Some(task) = receiver.recv().await {
                    tokio::task::spawn(task);
                }
            });
        });
        Self {
            sender,
            _thread: thread,
        }
    }
}
impl TokioBackgroundExecutor {
fn send_future(&self, fut: BoxFuture<'static, ()>) {
let mut fut = Some(fut);
loop {
match self.sender.try_send(fut.take().unwrap()) {
Ok(()) => break,
Err(tokio::sync::mpsc::error::TrySendError::Full(original)) => {
std::thread::yield_now();
fut.replace(original);
}
Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
panic!("TokioBackgroundExecutor channel closed")
}
};
}
}
}
impl TaskExecutor for TokioBackgroundExecutor {
fn block_on<T>(&self, task: T) -> T::Output
where
T: Future + Send + 'static,
T::Output: Send + 'static,
{
let (sender, receiver) = channel::<T::Output>();
let fut = Box::pin(async move {
let task_output = task.await;
tokio::task::spawn_blocking(move || {
sender.send(task_output).ok();
})
.await
.unwrap();
});
self.send_future(fut);
receiver
.recv()
.expect("TokioBackgroundExecutor has crashed")
}
fn spawn<F>(&self, task: F)
where
F: Future<Output = ()> + Send + 'static,
{
self.send_future(Box::pin(task));
}
fn spawn_blocking<T, R>(&self, task: T) -> BoxFuture<'_, DeltaResult<R>>
where
T: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
Box::pin(tokio::task::spawn_blocking(task).map_err(crate::Error::join_failure))
}
}
/// Executor that submits work to an existing multi-threaded Tokio runtime through its handle.
#[derive(Debug)]
pub struct TokioMultiThreadExecutor {
/// Handle to the runtime that `block_on` and `spawn` submit futures to; `new` asserts it
/// belongs to a multi-threaded runtime.
handle: tokio::runtime::Handle,
}
impl TokioMultiThreadExecutor {
pub fn new(handle: tokio::runtime::Handle) -> Self {
assert_eq!(
handle.runtime_flavor(),
RuntimeFlavor::MultiThread,
"TokioExecutor must be created with a multi-threaded runtime"
);
Self { handle }
}
}
impl TaskExecutor for TokioMultiThreadExecutor {
fn block_on<T>(&self, task: T) -> T::Output
where
T: Future + Send + 'static,
T::Output: Send + 'static,
{
let (sender, receiver) = channel::<T::Output>();
let fut = Box::pin(async move {
let task_output = task.await;
tokio::task::spawn_blocking(move || {
sender.send(task_output).ok();
})
.await
.unwrap();
});
self.handle.spawn(fut);
receiver
.recv()
.expect("TokioMultiThreadExecutor has crashed")
}
fn spawn<F>(&self, task: F)
where
F: Future<Output = ()> + Send + 'static,
{
self.handle.spawn(task);
}
fn spawn_blocking<T, R>(&self, task: T) -> BoxFuture<'_, DeltaResult<R>>
where
T: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
Box::pin(tokio::task::spawn_blocking(task).map_err(crate::Error::join_failure))
}
}
#[cfg(test)]
mod test {
    use super::*;

    /// Exercises every `TaskExecutor` method (`block_on`, `spawn` and `spawn_blocking`) on
    /// `executor`, each computing `2 + 2`.
    async fn test_executor(executor: impl TaskExecutor) {
        // block_on: the future sleeps (needs the runtime's time driver) and then yields 4.
        let task = async {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            2 + 2
        };
        let result = executor.block_on(task);
        assert_eq!(result, 4);

        // spawn: the detached task reports its result over a std channel.
        let (sender, receiver) = channel::<i32>();
        executor.spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            sender.send(2 + 2).unwrap();
        });
        let result = receiver.recv().unwrap();
        assert_eq!(result, 4);

        // spawn_blocking: the closure's return value comes back through the future.
        let result = executor.spawn_blocking(|| 2 + 2).await;
        assert!(matches!(result, Ok(4)));
    }

    #[tokio::test]
    async fn test_tokio_background_executor() {
        let executor = TokioBackgroundExecutor::new();
        test_executor(executor).await;
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 1)]
    async fn test_tokio_multi_thread_executor() {
        let executor = TokioMultiThreadExecutor::new(tokio::runtime::Handle::current());
        test_executor(executor).await;
    }
}
}