use crate::error::{Error, Result};
use crate::task_map::TaskMap;
use std::collections::HashMap;
use std::time::Duration;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument, instrument};
/// Type-erased error a task may fail with.
pub type BoxedError = Box<dyn std::error::Error + Sync + Send>;

/// Result produced by a single task.
type TaskResult<T> = std::result::Result<T, BoxedError>;

/// What a spawned task hands back: its name paired with its result.
type TaskOutcome<T> = (&'static str, TaskResult<T>);

/// Per-task results keyed by task name, as returned in partial-results mode.
pub type PartialResults<T> = HashMap<&'static str, TaskResult<T>>;
/// Builder returned by [`execute_concurrently`]; awaiting it runs every task.
///
/// By default the values of all tasks are collected and the first task error
/// ends the run. Use [`Executor::with_partial_results`] to keep each task's
/// individual result instead.
pub struct Executor<T> {
    /// Named tasks spawned when the executor is awaited.
    tasks: TaskMap<T>,
    /// Optional caller-supplied cancellation token.
    cancellation: Option<CancellationToken>,
    /// Optional timeout applied while waiting for tasks.
    timeout: Option<Duration>,
}
impl<T> Executor<T>
where
    T: Send + 'static,
{
    /// Attaches a cancellation token.
    ///
    /// Tasks receive a token linked to this one, so cancelling it signals every
    /// running task.
    #[must_use]
    pub fn with_cancellation(self, token: CancellationToken) -> Self {
        Self {
            cancellation: Some(token),
            ..self
        }
    }

    /// Sets a timeout for waiting on tasks.
    ///
    /// When it elapses, outstanding tasks are cancelled and the executor
    /// resolves to [`Error::Timeout`].
    #[must_use]
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }

    /// Switches to partial-results mode.
    ///
    /// The returned executor waits for every task and reports each task's own
    /// result instead of stopping at the first task error. The token and the
    /// timeout configured so far are carried over.
    #[must_use]
    pub fn with_partial_results(self) -> PartialExecutor<T> {
        let Self {
            tasks,
            cancellation,
            timeout,
        } = self;
        PartialExecutor {
            tasks,
            cancellation,
            timeout,
        }
    }
}
/// Awaiting an [`Executor`] runs all tasks in fail-fast mode.
///
/// The future resolves to a map from task name to value, or to the first
/// error encountered.
impl<T> std::future::IntoFuture for Executor<T>
where
    T: Send + 'static,
{
    type Output = Result<HashMap<&'static str, T>>;
    type IntoFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send>>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(run_fail_fast(self))
    }
}
/// Executor variant that records every task's individual result instead of
/// failing on the first task error.
///
/// Created by [`Executor::with_partial_results`].
pub struct PartialExecutor<T> {
    /// Named tasks spawned when the executor is awaited.
    tasks: TaskMap<T>,
    /// Optional caller-supplied cancellation token.
    cancellation: Option<CancellationToken>,
    /// Optional timeout applied while waiting for tasks.
    timeout: Option<Duration>,
}
impl<T> PartialExecutor<T>
where
    T: Send + 'static,
{
    /// Attaches a cancellation token.
    ///
    /// Tasks receive a token linked to this one, so cancelling it signals every
    /// running task.
    #[must_use]
    pub fn with_cancellation(self, token: CancellationToken) -> Self {
        Self {
            cancellation: Some(token),
            ..self
        }
    }

    /// Sets a timeout for waiting on tasks.
    ///
    /// When it elapses, outstanding tasks are cancelled and the whole run
    /// resolves to [`Error::Timeout`], even in partial-results mode.
    #[must_use]
    pub fn with_timeout(self, timeout: Duration) -> Self {
        Self {
            timeout: Some(timeout),
            ..self
        }
    }
}
/// Awaiting a [`PartialExecutor`] waits for every task.
///
/// The future resolves to each task's own result, keyed by name. A timeout or
/// a join error (e.g. a task panic) still fails the whole run.
impl<T> std::future::IntoFuture for PartialExecutor<T>
where
    T: Send + 'static,
{
    type Output = Result<PartialResults<T>>;
    type IntoFuture = std::pin::Pin<Box<dyn std::future::Future<Output = Self::Output> + Send>>;

    fn into_future(self) -> Self::IntoFuture {
        Box::pin(run_partial(self))
    }
}
/// Spawns every task in `tasks` onto a new [`JoinSet`].
///
/// Each task is called with its own clone of `token` and runs inside a
/// `concurrent.task` tracing span tagged with its name. Each task yields its
/// name paired with its result, so the caller can tell completions apart.
fn spawn_tasks<T>(tasks: TaskMap<T>, token: &CancellationToken) -> JoinSet<TaskOutcome<T>>
where
    T: Send + 'static,
{
    let mut join_set = JoinSet::new();
    tasks.tasks.into_iter().for_each(|(name, task_fn)| {
        let task_token = token.clone();
        let span = tracing::info_span!("concurrent.task", task.name = name);
        let tagged = async move { (name, task_fn(task_token).await) };
        join_set.spawn(tagged.instrument(span));
    });
    join_set
}
#[instrument(skip(executor), fields(task_count = executor.tasks.len()))]
async fn run_fail_fast<T>(executor: Executor<T>) -> Result<HashMap<&'static str, T>>
where
T: Send + 'static,
{
let token = executor.cancellation.unwrap_or_default();
let mut set = spawn_tasks(executor.tasks, &token);
let mut results: HashMap<&'static str, T> = HashMap::new();
let timeout = executor.timeout;
loop {
let outcome = next_outcome(&mut set, &token, timeout).await?;
match outcome {
None => break,
Some((name, Ok(v))) => {
results.insert(name, v);
}
Some((name, Err(e))) => {
token.cancel();
set.shutdown().await;
return Err(Error::TaskFailed { name, source: e });
}
}
if token.is_cancelled() && set.is_empty() {
return Err(Error::Cancelled);
}
}
Ok(results)
}
#[instrument(skip(executor), fields(task_count = executor.tasks.len()))]
async fn run_partial<T>(executor: PartialExecutor<T>) -> Result<PartialResults<T>>
where
T: Send + 'static,
{
let token = executor.cancellation.unwrap_or_default();
let mut set = spawn_tasks(executor.tasks, &token);
let mut results: PartialResults<T> = HashMap::new();
let timeout = executor.timeout;
loop {
let outcome = next_outcome(&mut set, &token, timeout).await?;
match outcome {
None => break,
Some((name, result)) => {
results.insert(name, result);
}
}
}
Ok(results)
}
async fn next_outcome<T>(
set: &mut JoinSet<TaskOutcome<T>>,
token: &CancellationToken,
timeout: Option<Duration>,
) -> Result<Option<TaskOutcome<T>>>
where
T: Send + 'static,
{
let next = async { set.join_next().await };
let raw = if let Some(d) = timeout {
if let Ok(v) = tokio::time::timeout(d, next).await {
v
} else {
token.cancel();
set.shutdown().await;
return Err(Error::Timeout);
}
} else {
next.await
};
match raw {
None => Ok(None),
Some(Ok(outcome)) => Ok(Some(outcome)),
Some(Err(e)) => Err(Error::Join(e)),
}
}
/// Entry point: wraps `tasks` in an [`Executor`] with no cancellation token and
/// no timeout.
///
/// Nothing is spawned until the returned executor is awaited; configure it
/// first with the `with_*` builder methods.
#[must_use]
pub fn execute_concurrently<T>(tasks: TaskMap<T>) -> Executor<T> {
    Executor {
        timeout: None,
        cancellation: None,
        tasks,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;

    /// One task named "slow" that sleeps far longer than any timeout used here.
    fn slow_task_map() -> TaskMap<u32> {
        TaskMap::new().insert("slow", |_| async {
            tokio::time::sleep(Duration::from_secs(10)).await;
            Ok::<_, std::io::Error>(1)
        })
    }

    #[tokio::test]
    async fn empty_map_resolves_to_empty_results() {
        let tasks: TaskMap<u32> = TaskMap::new();
        let results = execute_concurrently(tasks).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn two_tasks_complete() {
        let tasks: TaskMap<u32> = TaskMap::new()
            .insert("a", |_| async { Ok::<_, std::io::Error>(1) })
            .insert("b", |_| async { Ok::<_, std::io::Error>(2) });
        let results = execute_concurrently(tasks).await.unwrap();
        assert_eq!((results["a"], results["b"]), (1, 2));
    }

    #[tokio::test]
    async fn failing_task_returns_task_failed_error() {
        let tasks: TaskMap<u32> = TaskMap::new()
            .insert("ok", |_| async { Ok::<_, std::io::Error>(1) })
            .insert("bad", |_| async {
                Err::<u32, std::io::Error>(std::io::Error::other("boom"))
            });
        let err = execute_concurrently(tasks).await.unwrap_err();
        let Error::TaskFailed { name, .. } = err else {
            panic!("expected TaskFailed, got {err:?}");
        };
        assert_eq!(name, "bad");
    }

    #[tokio::test]
    async fn timeout_returns_timeout_error() {
        let err = execute_concurrently(slow_task_map())
            .with_timeout(Duration::from_millis(50))
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Timeout));
    }

    #[tokio::test]
    async fn external_cancellation_causes_cancelled_error() {
        let parent = CancellationToken::new();
        let trigger = parent.clone();
        let tasks: TaskMap<u32> = TaskMap::new().insert("waiter", move |ct| async move {
            ct.cancelled().await;
            Err::<u32, std::io::Error>(std::io::Error::other("cancelled"))
        });
        let run =
            tokio::spawn(async move { execute_concurrently(tasks).with_cancellation(parent).await });
        tokio::time::sleep(Duration::from_millis(20)).await;
        trigger.cancel();
        let err = run.await.unwrap().unwrap_err();
        assert!(matches!(err, Error::TaskFailed { .. } | Error::Cancelled));
    }

    #[tokio::test]
    async fn partial_results_returns_per_task_results() {
        let tasks: TaskMap<u32> = TaskMap::new()
            .insert("ok", |_| async { Ok::<_, std::io::Error>(1) })
            .insert("bad", |_| async {
                Err::<u32, std::io::Error>(std::io::Error::other("boom"))
            })
            .insert("also_ok", |_| async { Ok::<_, std::io::Error>(2) });
        let results = execute_concurrently(tasks)
            .with_partial_results()
            .await
            .unwrap();
        assert_eq!(results.len(), 3);
        for name in ["ok", "also_ok"] {
            assert!(results[name].is_ok());
        }
        assert!(results["bad"].is_err());
    }

    #[tokio::test]
    async fn partial_timeout_still_propagates() {
        let err = execute_concurrently(slow_task_map())
            .with_partial_results()
            .with_timeout(Duration::from_millis(20))
            .await
            .unwrap_err();
        assert!(matches!(err, Error::Timeout));
    }
}