use std::collections::HashMap;
use std::future::Future;
use anyhow::Context;
use tokio::task::{Id, JoinError, JoinSet};
use crate::shutdown::CancellationToken;
/// A set of named background tasks backed by a tokio [`JoinSet`].
///
/// Every spawned task is registered under a human-readable name so that its
/// completion can be reported by name instead of by opaque task id.
///
/// `Default` is derived: an empty `JoinSet` plus an empty name map is exactly
/// what the field defaults produce, so no hand-written impl is needed.
#[derive(Default)]
pub struct Tasks {
    /// The spawned tasks; each one resolves to its own `anyhow::Result<()>`.
    handles: JoinSet<anyhow::Result<()>>,
    /// Names of the tasks that have been spawned but not yet joined, keyed by
    /// tokio task id.
    names: HashMap<Id, String>,
}
impl Tasks {
pub fn new() -> Self {
Self::default()
}
pub fn spawn(
&mut self,
name: impl Into<String>,
task: impl Future<Output = anyhow::Result<()>> + Send + 'static,
) -> Id {
let id = self.handles.spawn(task).id();
self.names.insert(id, name.into());
id
}
pub fn spawn_infallible(
&mut self,
name: impl Into<String>,
task: impl Future<Output = ()> + Send + 'static,
) -> Id {
self.spawn(name, async move {
task.await;
Ok(())
})
}
pub async fn join_next(&mut self) -> Option<(String, Result<anyhow::Result<()>, JoinError>)> {
let result = self.handles.join_next_with_id().await?;
let id = match &result {
Ok((id, _)) => *id,
Err(err) => err.id(),
};
let name = self.names.remove(&id).unwrap_or_else(|| "unknown".to_string());
let result = result.map(|(_, output)| output);
Some((name, result))
}
pub fn is_empty(&self) -> bool {
self.handles.is_empty()
}
pub fn len(&self) -> usize {
self.handles.len()
}
pub async fn join_next_as_error(&mut self) -> anyhow::Result<()> {
let Some((task, result)) = self.join_next().await else {
anyhow::bail!("task set is empty");
};
Self::unexpected_completion(&task, result)
}
pub async fn join_next_or_cancelled(&mut self, token: CancellationToken) -> anyhow::Result<()> {
while !token.is_cancelled() {
tokio::select! {
biased;
() = token.cancelled() => break,
result = self.join_next() => {
let Some((task, result)) = result else {
anyhow::bail!("task set is empty");
};
Self::unexpected_completion(&task, result)?;
},
}
}
while let Some((task, result)) = self.join_next().await {
Self::shutdown_completion(&task, result)?;
}
Ok(())
}
fn unexpected_completion(
task: &str,
result: Result<anyhow::Result<()>, JoinError>,
) -> anyhow::Result<()> {
match result {
Ok(Ok(())) => anyhow::bail!("task {task} completed unexpectedly"),
Ok(Err(err)) => Err(err).with_context(|| format!("task {task} failed")),
Err(err) => Err(err).with_context(|| format!("task {task} failed to join")),
}
}
fn shutdown_completion(
task: &str,
result: Result<anyhow::Result<()>, JoinError>,
) -> anyhow::Result<()> {
match result {
Ok(Ok(())) => Ok(()),
Ok(Err(err)) => Err(err).with_context(|| format!("task {task} failed during shutdown")),
Err(err) if err.is_cancelled() => Ok(()),
Err(err) => Err(err).with_context(|| format!("task {task} failed to join")),
}
}
}
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use super::*;

    /// Builds a worker that waits for `token` to be cancelled, optionally
    /// sleeps for `linger`, and then exits cleanly.
    fn exit_on_cancel(
        token: &crate::shutdown::CancellationToken,
        linger: Option<Duration>,
    ) -> impl std::future::Future<Output = anyhow::Result<()>> + Send + 'static {
        let token = token.clone();
        async move {
            token.cancelled().await;
            if let Some(delay) = linger {
                tokio::time::sleep(delay).await;
            }
            Ok(())
        }
    }

    /// A worker that exits cleanly once shutdown has been requested must not
    /// turn the shutdown into an error.
    #[tokio::test]
    async fn join_next_or_cancelled_accepts_clean_task_completion_after_cancellation() {
        let shutdown = crate::shutdown::CancellationToken::new();
        let mut set = Tasks::new();
        set.spawn("worker", exit_on_cancel(&shutdown, None));

        shutdown.cancel();

        let outcome = set.join_next_or_cancelled(shutdown).await;
        outcome.expect("clean shutdown should not be treated as an error");
    }

    /// A worker that finishes while the token is still live is reported as an
    /// unexpected completion, by name.
    #[tokio::test]
    async fn join_next_or_cancelled_treats_task_completion_before_cancellation_as_error() {
        let shutdown = crate::shutdown::CancellationToken::new();
        let mut set = Tasks::new();
        set.spawn("worker", async { Ok(()) });

        let outcome = set.join_next_or_cancelled(shutdown).await;

        let err = outcome.expect_err("unexpected task completion should fail before shutdown");
        assert_eq!(err.to_string(), "task worker completed unexpectedly");
    }

    /// After cancellation every worker is joined, including one that needs a
    /// little extra time to wind down.
    #[tokio::test]
    async fn join_next_or_cancelled_waits_for_all_tasks_to_complete_after_cancellation() {
        let shutdown = crate::shutdown::CancellationToken::new();
        let mut set = Tasks::new();
        set.spawn("worker-a", exit_on_cancel(&shutdown, None));
        set.spawn(
            "worker-b",
            exit_on_cancel(&shutdown, Some(Duration::from_millis(10))),
        );

        shutdown.cancel();

        let outcome = set.join_next_or_cancelled(shutdown).await;
        outcome.expect("shutdown should wait for all clean task exits");
        assert!(set.is_empty());
    }
}