use std::{
borrow::Cow,
fmt::{self, Display},
future::Future,
pin::Pin,
task::{Context, Poll},
time::Duration,
};
use futures::{StreamExt, stream::FuturesUnordered};
use thiserror::Error;
use tokio::{
sync::mpsc,
task::{JoinError, JoinHandle},
};
use tracing::{Instrument, debug, error, info, warn};
use crate::notify_once::NotifyOnce;
#[derive(Debug, Error)]
pub enum Error {
#[error("Static task finished prematurely: {name}")]
PrematureFinish { name: Cow<'static, str> },
#[error("Some tasks failed to finish on time: {hung_tasks:?}")]
Hung { hung_tasks: Vec<String> },
}
/// Joins long-lived "static" tasks and short-lived "ephemeral" tasks until
/// shutdown, then waits up to `shutdown_timeout` for all of them to finish.
///
/// - `static_tasks` are expected to run until shutdown. If one of them
///   finishes before `shutdown` fires, `shutdown.send()` is called (presumably
///   notifying other listeners of the shutdown) and the join phase begins.
/// - Ephemeral tasks arrive over `eph_tasks_rx` and may finish at any time.
///   Tasks that are still buffered in the channel when shutdown begins are
///   joined as well. Tasks sent after that point are not joined; they are
///   detached when `eph_tasks_rx` is dropped.
///
/// # Errors
///
/// - [`Error::Hung`] lists the names of the tasks still running when
///   `shutdown_timeout` elapses. It takes precedence over `PrematureFinish`.
/// - [`Error::PrematureFinish`] names the static task that finished before
///   shutdown.
pub async fn try_join_tasks_and_shutdown(
    static_tasks: Vec<LxTask<()>>,
    mut eph_tasks_rx: mpsc::Receiver<LxTask<()>>,
    mut shutdown: NotifyOnce,
    shutdown_timeout: Duration,
) -> Result<(), Error> {
    let mut static_tasks = static_tasks
        .into_iter()
        .map(LxTask::logged)
        .collect::<FuturesUnordered<_>>();
    let mut ephemeral_tasks = FuturesUnordered::new();
    let mut result = Ok(());

    // Run until shutdown is signaled or a static task finishes prematurely.
    //
    // No special case is needed for empty task sets. An empty
    // `FuturesUnordered` yields `None`, which fails the `Some(..)` pattern,
    // so `select!` disables that branch for this iteration. The `shutdown`
    // branch is always enabled, so `select!` never runs out of branches.
    // Because there is no special case, `eph_tasks_rx` is always drained and
    // its senders never block on a full channel.
    loop {
        tokio::select! {
            biased;
            () = shutdown.recv() => break,
            Some(task) = eph_tasks_rx.recv() => {
                ephemeral_tasks.push(task.logged());
            }
            Some(_name) = ephemeral_tasks.next() => {}
            Some(name) = static_tasks.next() => {
                result = Err(Error::PrematureFinish { name });
                break shutdown.send();
            }
        }
    }

    // Tasks may still be buffered in the channel: `biased` checks `shutdown`
    // first, and a premature finish breaks out immediately. Drain them
    // without waiting, so they are joined instead of being silently detached
    // when `eph_tasks_rx` is dropped.
    while let Ok(task) = eph_tasks_rx.try_recv() {
        ephemeral_tasks.push(task.logged());
    }

    let mut all_tasks = static_tasks;
    all_tasks.extend(ephemeral_tasks);

    // The timeout starts counting only once shutdown has begun.
    let shutdown_timeout_fut = tokio::time::sleep(shutdown_timeout);
    tokio::pin!(shutdown_timeout_fut);
    while !all_tasks.is_empty() {
        tokio::select! {
            Some(_name) = all_tasks.next() => (),
            () = &mut shutdown_timeout_fut => {
                let hung_tasks = all_tasks
                    .iter()
                    .map(|task| task.name().to_owned())
                    .collect::<Vec<_>>();
                return Err(Error::Hung { hung_tasks });
            }
        }
    }

    result
}
pub async fn join_tasks_and_shutdown(
name: &str,
static_tasks: Vec<LxTask<()>>,
eph_tasks_rx: mpsc::Receiver<LxTask<()>>,
shutdown: NotifyOnce,
max_shutdown_delta: Duration,
) {
let result = try_join_tasks_and_shutdown(
static_tasks,
eph_tasks_rx,
shutdown,
max_shutdown_delta,
)
.await;
match result {
Ok(()) => info!("{name} tasks finished."),
Err(e) => error!("{name} tasks errored: {e:#}"),
}
}
#[must_use]
pub struct MaybeLxTask<T>(pub Option<LxTask<T>>);
impl<T> MaybeLxTask<T> {
pub fn detach(self) {
if let Some(task) = self.0 {
task.detach();
}
}
}
/// A named handle to a task spawned on the tokio runtime.
///
/// Awaiting an `LxTask` gives `Ok(output)` if the task completed, or
/// `Err(JoinError)` if the join failed for a reason other than a panic. A
/// panic in the task is resumed in the awaiting task. Dropping the handle,
/// or calling [`LxTask::detach`], drops the inner [`JoinHandle`] and leaves
/// the task running in the background.
#[must_use]
pub struct LxTask<T> {
    /// The underlying tokio join handle.
    task: JoinHandle<T>,
    /// Human-readable name, used in log messages and shutdown errors.
    name: Cow<'static, str>,
}
/// An [`LxTask`] that logs how the task finished and then resolves to the
/// task's name instead of its output.
///
/// Marked `#[must_use]` for consistency with [`LxTask`]. Dropping a
/// `LoggedLxTask` detaches the task and its outcome is never logged.
#[must_use = "dropping a LoggedLxTask detaches the task without logging its outcome"]
pub struct LoggedLxTask<T>(LxTask<T>);
struct TaskOutputDisplay<'a> {
name: &'a str,
result: Result<(), &'a tokio::task::JoinError>,
}
impl<T> LxTask<T> {
    /// Wraps an already-spawned tokio task so it can be named and joined like
    /// any other [`LxTask`]. Nothing is logged.
    pub fn from_tokio(
        handle: JoinHandle<T>,
        name: impl Into<Cow<'static, str>>,
    ) -> Self {
        Self {
            task: handle,
            name: name.into(),
        }
    }

    /// Spawns `future` as a named task, instrumented with the current
    /// [`tracing::Span`], and logs "Spawning task: {name}" at `debug` level.
    ///
    /// Only delegates to [`LxTask::spawn_with_span`], so no lint allowance is
    /// needed here.
    #[inline]
    pub fn spawn<F>(
        name: impl Into<Cow<'static, str>>,
        future: F,
    ) -> LxTask<F::Output>
    where
        F: Future<Output = T> + Send + 'static,
        F::Output: Send + 'static,
    {
        let span = tracing::Span::current();
        Self::spawn_with_span(name, span, future)
    }

    /// Same as [`LxTask::spawn`], but without the "Spawning task" debug log.
    #[inline]
    pub fn spawn_unlogged<F>(
        name: impl Into<Cow<'static, str>>,
        future: F,
    ) -> LxTask<F::Output>
    where
        F: Future<Output = T> + Send + 'static,
        F::Output: Send + 'static,
    {
        // `name` is converted inside `spawn_unlogged_with_span`, so it is not
        // converted here as well.
        let span = tracing::Span::current();
        Self::spawn_unlogged_with_span(name, span, future)
    }

    /// Spawns `future` as a named task instrumented with `span`, without the
    /// "Spawning task" debug log.
    #[inline]
    // NOTE(review): `tokio::spawn` is presumably listed in the project's
    // clippy `disallowed-methods` so that all spawns go through `LxTask`.
    // This function and `spawn_with_span` are the intended call sites.
    #[allow(clippy::disallowed_methods)]
    pub fn spawn_unlogged_with_span<F>(
        name: impl Into<Cow<'static, str>>,
        span: tracing::Span,
        future: F,
    ) -> LxTask<F::Output>
    where
        F: Future<Output = T> + Send + 'static,
        F::Output: Send + 'static,
    {
        let name = name.into();
        Self {
            task: tokio::spawn(future.instrument(span)),
            name,
        }
    }

    /// Spawns `future` under the name `"<unnamed>"`, without the
    /// "Spawning task" debug log.
    #[inline]
    pub fn spawn_unnamed<F>(future: F) -> LxTask<F::Output>
    where
        F: Future<Output = T> + Send + 'static,
        F::Output: Send + 'static,
    {
        Self::spawn_unlogged("<unnamed>", future)
    }

    /// Spawns `future` as a named task instrumented with `span`, and logs
    /// "Spawning task: {name}" at `debug` level.
    #[inline]
    // NOTE(review): see `spawn_unlogged_with_span`. This function is also an
    // intended direct caller of `tokio::spawn`.
    #[allow(clippy::disallowed_methods)]
    pub fn spawn_with_span<F>(
        name: impl Into<Cow<'static, str>>,
        span: tracing::Span,
        future: F,
    ) -> LxTask<F::Output>
    where
        F: Future<Output = T> + Send + 'static,
        F::Output: Send + 'static,
    {
        let name = name.into();
        debug!("Spawning task: {name}");
        Self {
            task: tokio::spawn(future.instrument(span)),
            name,
        }
    }

    /// Detaches the task by dropping its [`JoinHandle`]. The task keeps
    /// running, but its output can no longer be retrieved.
    #[inline]
    pub fn detach(self) {
        std::mem::drop(self)
    }

    /// Returns the task's name.
    #[inline]
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Returns `true` if the task has finished running
    /// (see [`JoinHandle::is_finished`]).
    #[inline]
    pub fn is_finished(&self) -> bool {
        self.task.is_finished()
    }

    /// Wraps this task in a [`LoggedLxTask`]. The wrapper logs how the task
    /// finished and resolves to the task's name.
    #[inline]
    pub fn logged(self) -> LoggedLxTask<T> {
        LoggedLxTask(self)
    }

    /// Requests cancellation of the task via [`JoinHandle::abort`].
    #[inline]
    pub fn abort(&self) {
        self.task.abort();
    }
}
impl<T> Future for LxTask<T> {
type Output = Result<T, JoinError>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Self::Output> {
let result = match Pin::new(&mut self.task).poll(cx) {
Poll::Ready(result) => result,
Poll::Pending => return Poll::Pending,
};
let result = match result {
Ok(val) => Ok(val),
Err(join_err) => {
{
let name = self.name();
eprintln!("FATAL TASK ERROR: {join_err:#} {name}");
tracing::error!(%name, "FATAL TASK ERROR: {join_err:#}");
}
match join_err.try_into_panic() {
Ok(panic_reason) => {
error!("Task '{name}' panicked!", name = self.name());
std::panic::resume_unwind(panic_reason)
}
Err(join_err) => Err(join_err),
}
}
};
Poll::Ready(result)
}
}
impl<T> LoggedLxTask<T> {
    /// Returns the name of the wrapped task.
    #[inline]
    pub fn name(&self) -> &str {
        &self.0.name
    }

    /// Returns `true` if the wrapped task has finished running.
    #[inline]
    pub fn is_finished(&self) -> bool {
        self.0.task.is_finished()
    }
}
impl<T> Future for LoggedLxTask<T> {
type Output = Cow<'static, str>;
fn poll(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
) -> Poll<Self::Output> {
Pin::new(&mut self.0).poll(cx).map(|result| {
let mut log_error = false;
let mut log_warn = false;
match &result {
Ok(_) => (),
Err(e) if e.is_cancelled() => log_warn = true,
Err(e) if e.is_panic() => log_error = true,
_ => log_warn = true,
};
let msg = TaskOutputDisplay {
name: self.name(),
result: result.as_ref().map(|_| ()),
};
if log_error {
error!("{msg}")
} else if log_warn {
warn!("{msg}")
} else {
debug!("{msg}")
}
self.0.name.clone()
})
}
}
impl Display for TaskOutputDisplay<'_> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let join_label = match &self.result {
Ok(_) => "finished",
Err(e) if e.is_cancelled() => "cancelled",
Err(e) if e.is_panic() => "panicked",
_ => "(unknown join error)",
};
let name = self.name;
write!(f, "Task '{name}' {join_label}")?;
if let Err(e) = self.result {
write!(f, ": {e:#}")?;
}
Ok(())
}
}