use async_trait::async_trait;
use executor_trait::{BlockingExecutor, Executor, FullExecutor, LocalExecutorError, Task};
use std::{
future::Future,
pin::Pin,
task::{Context, Poll},
};
use tokio::runtime::Handle;
#[cfg(feature = "tracing")]
use {tracing::Span, tracing_futures::Instrument};
/// Executor adapter that submits work to a Tokio runtime.
///
/// `Default` is derived when the `tracing` feature is disabled; with the
/// feature enabled it is implemented by hand because of the extra `span`
/// field.
#[cfg_attr(not(feature = "tracing"), derive(Default))]
#[derive(Clone, Debug)]
pub struct Tokio {
    /// Explicitly configured runtime handle. It is only consulted when no
    /// runtime is current on the calling thread (see [`Tokio::handle`]).
    handle: Option<Handle>,
    /// Span attached to every future and blocking closure submitted through
    /// this executor.
    #[cfg(feature = "tracing")]
    span: Span,
}
#[cfg(feature = "tracing")]
impl Default for Tokio {
    /// Creates an executor with no stored runtime handle and the span
    /// returned by `Span::none()`.
    fn default() -> Self {
        let handle = None;
        let span = Span::none();
        Self { handle, span }
    }
}
impl Tokio {
    /// Returns this executor with `handle` stored as its runtime handle,
    /// replacing any handle stored previously.
    pub fn with_handle(self, handle: Handle) -> Self {
        Self {
            handle: Some(handle),
            ..self
        }
    }

    /// Builds a default executor that stores `Handle::current()`.
    ///
    /// # Panics
    ///
    /// Per tokio's documentation, `Handle::current` panics when called
    /// outside the context of a Tokio runtime.
    pub fn current() -> Self {
        let executor = Self::default();
        executor.with_handle(Handle::current())
    }

    /// Resolves the runtime handle to submit work to.
    ///
    /// The runtime that is current on the calling thread takes precedence;
    /// the stored handle is only used when `Handle::try_current()` fails.
    /// Returns `None` when neither is available.
    pub fn handle(&self) -> Option<Handle> {
        match Handle::try_current() {
            Ok(ambient) => Some(ambient),
            Err(_) => self.handle.clone(),
        }
    }
}
#[cfg(feature = "tracing")]
impl Tokio {
    /// Returns this executor with `span` as the span attached to submitted
    /// work, replacing the previous one.
    pub fn with_span(self, span: Span) -> Self {
        Self { span, ..self }
    }

    /// Returns this executor with its span set to `Span::current()`, captured
    /// at the time of this call.
    pub fn with_current_span(self) -> Self {
        self.with_span(Span::current())
    }

    /// Returns a clone of the span attached to submitted work.
    pub fn span(&self) -> Span {
        Span::clone(&self.span)
    }
}
/// Private wrapper around the `JoinHandle<()>` of a task spawned on Tokio; it is
/// what `Executor::spawn` boxes and returns as `Box<dyn Task>`.
struct TTask(tokio::task::JoinHandle<()>);
// Empty impl: no items are provided here. The `Executor` and `BlockingExecutor`
// implementations for `Tokio` are defined separately in this file.
impl FullExecutor for Tokio {}
impl Executor for Tokio {
    /// Drives `f` to completion with `Handle::block_on`.
    ///
    /// The handle comes from [`Tokio::handle`]; when that yields `None` the
    /// call falls back to `Handle::current()`. With the `tracing` feature the
    /// future is instrumented with this executor's span first.
    fn block_on(&self, f: Pin<Box<dyn Future<Output = ()>>>) {
        #[cfg(feature = "tracing")]
        let f = f.instrument(self.span.clone());
        match self.handle() {
            Some(runtime) => runtime.block_on(f),
            None => Handle::current().block_on(f),
        }
    }

    /// Spawns `f` as a Tokio task and returns it as a boxed [`Task`].
    ///
    /// The task is spawned through the handle from [`Tokio::handle`], or with
    /// `tokio::task::spawn` when no handle is available. With the `tracing`
    /// feature the future is instrumented with this executor's span first.
    fn spawn(&self, f: Pin<Box<dyn Future<Output = ()> + Send>>) -> Box<dyn Task> {
        #[cfg(feature = "tracing")]
        let f = f.instrument(self.span.clone());
        let join_handle = match self.handle() {
            Some(runtime) => runtime.spawn(f),
            None => tokio::task::spawn(f),
        };
        Box::new(TTask(join_handle))
    }

    /// Local (non-`Send`) spawning is not provided by this executor: the
    /// future is always handed back to the caller inside the error.
    fn spawn_local(
        &self,
        f: Pin<Box<dyn Future<Output = ()>>>,
    ) -> Result<Box<dyn Task>, LocalExecutorError> {
        Err(LocalExecutorError(f))
    }
}
#[async_trait]
impl BlockingExecutor for Tokio {
    /// Runs the closure `f` via Tokio's `spawn_blocking` and waits for it to
    /// finish.
    ///
    /// The closure is submitted through the handle from [`Tokio::handle`], or
    /// with `tokio::task::spawn_blocking` when no handle is available. With
    /// the `tracing` feature the closure runs with this executor's span
    /// entered.
    ///
    /// # Panics
    ///
    /// Panics with "blocking task failed" when awaiting the join handle
    /// returns an error.
    async fn spawn_blocking(&self, f: Box<dyn FnOnce() + Send + 'static>) {
        #[cfg(feature = "tracing")]
        let f = {
            let span = self.span.clone();
            move || {
                // Keep the span entered for the whole duration of the closure.
                let _guard = span.enter();
                f()
            }
        };
        let join_handle = match self.handle() {
            Some(runtime) => runtime.spawn_blocking(f),
            None => tokio::task::spawn_blocking(f),
        };
        join_handle.await.expect("blocking task failed");
    }
}
#[async_trait(?Send)]
impl Task for TTask {
async fn cancel(self: Box<Self>) -> Option<()> {
self.0.abort();
self.0.await.ok()
}
}
impl Future for TTask {
    type Output = ();

    /// Polls the wrapped join handle.
    ///
    /// # Panics
    ///
    /// Panics with "task has been canceled" when the join handle resolves to
    /// an error.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let join_handle = Pin::new(&mut self.0);
        if let Poll::Ready(outcome) = join_handle.poll(cx) {
            outcome.expect("task has been canceled");
            return Poll::Ready(());
        }
        Poll::Pending
    }
}
// Gives `Tokio` a `Deref` implementation whose target, `Dummy`, lives in this
// private module.
// NOTE(review): the motivation is not visible in this file. Presumably it
// relates to blanket impls over `Deref` types in `executor_trait`; confirm
// before removing or changing it.
mod sealed {
    /// Zero-sized placeholder used only as the `Deref` target of `Tokio`.
    pub struct Dummy;

    impl std::ops::Deref for super::Tokio {
        type Target = Dummy;

        /// Always yields a reference to the unit value `Dummy`.
        fn deref(&self) -> &Dummy {
            &Dummy
        }
    }
}