use std::future::Future;
#[derive(Debug, thiserror::Error)]
pub enum JoinError {
#[error("task was cancelled")]
Cancelled,
#[error("task panicked")]
Panicked,
}
/// Runtime-agnostic interface for spawning named tasks and cooperatively yielding.
///
/// Implementors must be `Clone + Send + Sync + 'static` so a provider can be
/// shared freely between tasks and threads.
pub trait TaskProvider: Clone + Send + Sync + 'static {
    /// Handle returned by [`TaskProvider::spawn_task`].
    ///
    /// Awaiting it resolves to `Ok(())` when the task finished normally, or to a
    /// [`JoinError`] describing why it did not.
    type JoinHandle: Future<Output = Result<(), JoinError>> + Send + Sync + 'static;

    /// Spawns `future` as a new task labelled `name` and returns its join handle.
    fn spawn_task<F: Future<Output = ()> + Send + 'static>(
        &self,
        name: &str,
        future: F,
    ) -> Self::JoinHandle;

    /// Returns a future that, presumably, hands control back to the runtime's
    /// scheduler once (exact semantics are implementation-defined).
    fn yield_now(&self) -> impl Future<Output = ()> + Send;
}
/// [`TaskProvider`] implementation that spawns onto the current Tokio runtime.
///
/// The type is stateless, so it is `Copy` and can be created with `Default`
/// in generic code as well as with the unit literal `TokioTaskProvider`.
#[cfg(feature = "tokio-providers")]
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct TokioTaskProvider;
/// Join handle for a task spawned by `TokioTaskProvider`.
///
/// Wraps `tokio::task::JoinHandle<()>` and translates tokio's join error into
/// this module's [`JoinError`]. Per tokio's documentation, dropping the handle
/// detaches the task rather than cancelling it.
#[cfg(feature = "tokio-providers")]
#[derive(Debug)]
pub struct TokioJoinHandle(tokio::task::JoinHandle<()>);

#[cfg(feature = "tokio-providers")]
impl Future for TokioJoinHandle {
    type Output = Result<(), JoinError>;

    /// Polls the wrapped tokio handle.
    ///
    /// A cancelled task maps to [`JoinError::Cancelled`]. Any other tokio join
    /// error maps to [`JoinError::Panicked`], and the panic payload is dropped.
    fn poll(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Self::Output> {
        // `ready!` returns `Poll::Pending` early; past this line the inner
        // handle has produced its final result.
        let outcome = std::task::ready!(std::pin::Pin::new(&mut self.0).poll(cx));
        std::task::Poll::Ready(outcome.map_err(|err| {
            if err.is_cancelled() {
                JoinError::Cancelled
            } else {
                JoinError::Panicked
            }
        }))
    }
}
/// Tokio-backed [`TaskProvider`]: named spawning via `tokio::task::Builder` and
/// yielding via `tokio::task::yield_now`.
///
/// NOTE(review): `tokio::task::Builder` is documented by tokio as an unstable API
/// (it needs `--cfg tokio_unstable`); confirm the build sets that flag.
#[cfg(feature = "tokio-providers")]
impl TaskProvider for TokioTaskProvider {
    type JoinHandle = TokioJoinHandle;

    /// Spawns `future` as a tokio task named `name`, wrapped with trace-level
    /// start/completion log lines.
    ///
    /// # Panics
    ///
    /// Panics with "Failed to spawn task" if the tokio builder returns an error.
    /// Tokio also documents a panic when spawning outside a runtime.
    fn spawn_task<F>(&self, name: &str, future: F) -> Self::JoinHandle
    where
        F: Future<Output = ()> + Send + 'static,
    {
        // The spawned body must be 'static, so it needs an owned copy of the name.
        let label = String::from(name);
        let traced = async move {
            tracing::trace!("Task {} starting", label);
            future.await;
            // Not reached if `future` panics; the panic unwinds past this point.
            tracing::trace!("Task {} completed", label);
        };

        let handle = tokio::task::Builder::new()
            .name(name)
            .spawn(traced)
            .expect("Failed to spawn task");
        TokioJoinHandle(handle)
    }

    /// Delegates to `tokio::task::yield_now`; the returned future does nothing
    /// until it is awaited.
    fn yield_now(&self) -> impl Future<Output = ()> + Send {
        tokio::task::yield_now()
    }
}