use std::{fmt, sync::Arc};
use crate::utils::try_extract_panic_message;
use eyre::Context as _;
use futures::{FutureExt as _, future::BoxFuture};
use tokio::sync::Barrier;
use super::{StopReceiver, named_future::NamedFuture};
use crate::task::{Task, TaskKind};
/// Shorthand for a boxed `'static` future yielding `T`, wrapped in a [`NamedFuture`]
/// so that it carries an identifier alongside the future itself.
pub(crate) type NamedBoxFuture<T> = NamedFuture<BoxFuture<'static, T>>;
/// Everything registered to be run: not-yet-started tasks and shutdown hooks.
///
/// Both collections are emptied by `prepare_tasks`, which converts the contents
/// into a `TaskReprs`.
#[derive(Default)]
pub(super) struct Runnables {
/// Registered tasks, stored as trait objects until they are turned into futures.
pub(super) tasks: Vec<Box<dyn Task>>,
/// Named futures registered as shutdown hooks; moved into `TaskReprs` unchanged.
/// NOTE(review): presumably awaited by the caller once the tasks have stopped — confirm.
pub(super) shutdown_hooks: Vec<NamedBoxFuture<eyre::Result<()>>>,
}
impl fmt::Debug for Runnables {
    /// Formats the collection as a `Runnables { tasks, shutdown_hooks }` struct,
    /// delegating to the `Debug` output of each field.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Exhaustive destructuring: adding a field to the struct forces this impl
        // to be revisited.
        let Self {
            tasks,
            shutdown_hooks,
        } = self;

        let mut repr = f.debug_struct("Runnables");
        repr.field("tasks", tasks);
        repr.field("shutdown_hooks", shutdown_hooks);
        repr.finish()
    }
}
/// Result of `Runnables::prepare_tasks`: every runnable expressed as a named boxed future.
pub(super) struct TaskReprs {
/// Futures of the non-oneshot tasks, followed by the single `oneshot_runner` future
/// that `prepare_tasks` always appends.
pub(super) tasks: Vec<NamedBoxFuture<eyre::Result<()>>>,
/// Shutdown hooks moved over from `Runnables` as-is.
pub(super) shutdown_hooks: Vec<NamedBoxFuture<eyre::Result<()>>>,
}
impl fmt::Debug for TaskReprs {
    /// Prints only the number of futures in each collection rather than the
    /// futures themselves.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self {
            tasks,
            shutdown_hooks,
        } = self;

        let mut repr = f.debug_struct("TaskReprs");
        // The `tasks` field is reported under the `long_running_tasks` label.
        repr.field("long_running_tasks", &tasks.len());
        repr.field("shutdown_hooks", &shutdown_hooks.len());
        repr.finish()
    }
}
impl Runnables {
    /// Returns `true` when no tasks are registered.
    ///
    /// Shutdown hooks are not considered: a collection holding only hooks is
    /// still reported as empty.
    pub(super) fn is_empty(&self) -> bool {
        self.tasks.is_empty()
    }

    /// Creates the barrier to be shared by the tasks that synchronize on startup.
    ///
    /// The barrier size is the number of registered tasks whose kind is
    /// `Precondition`, `OneshotTask` or `Task`; tasks of any other kind are not
    /// counted as participants.
    pub(super) fn task_barrier(&self) -> Arc<Barrier> {
        let mut participants = 0_usize;
        for task in &self.tasks {
            if matches!(
                task.kind(),
                TaskKind::Precondition | TaskKind::OneshotTask | TaskKind::Task
            ) {
                participants += 1;
            }
        }
        Arc::new(Barrier::new(participants))
    }

    /// Converts all registered tasks into named futures and hands them over
    /// together with the shutdown hooks.
    ///
    /// Every task receives its own clone of `stop_receiver` and `task_barrier`.
    /// Tasks whose kind reports `is_oneshot()` are not returned individually:
    /// they are bundled into a single `oneshot_runner` future, which is always
    /// appended to the returned task list (even when there are no oneshot tasks).
    ///
    /// After this call both `self.tasks` and `self.shutdown_hooks` are left empty.
    pub(super) fn prepare_tasks(
        &mut self,
        task_barrier: Arc<Barrier>,
        stop_receiver: StopReceiver,
    ) -> TaskReprs {
        let mut persistent = Vec::new();
        let mut oneshots = Vec::new();

        for task in std::mem::take(&mut self.tasks) {
            // `run_internal` consumes the task, so its metadata is read up front.
            let id = task.id();
            let is_oneshot = task.kind().is_oneshot();

            let future: BoxFuture<'static, _> =
                Box::pin(task.run_internal(stop_receiver.clone(), Arc::clone(&task_barrier)));

            let bucket = if is_oneshot {
                &mut oneshots
            } else {
                &mut persistent
            };
            bucket.push(NamedFuture::new(future, id));
        }

        // Must be evaluated before the runner itself is added to the list.
        let only_oneshot_tasks = persistent.is_empty();
        persistent.push(oneshot_runner_task(
            oneshots,
            stop_receiver,
            only_oneshot_tasks,
        ));

        TaskReprs {
            shutdown_hooks: std::mem::take(&mut self.shutdown_hooks),
            tasks: persistent,
        }
    }
}
/// Builds the `oneshot_runner` future, which drives all oneshot tasks to completion.
///
/// Each oneshot future is spawned on the current Tokio runtime and awaited through
/// its join handle, so that a failure to join (reported here as a panic) becomes a
/// regular error instead of tearing down the runner.
///
/// The returned future resolves as follows:
/// - with the first error, as soon as any oneshot task fails or cannot be joined;
/// - with `Ok(())` right after all oneshot tasks succeed, if `only_oneshot_tasks` is set;
/// - otherwise with `Ok(())` once `stop_receiver` reports a change (an error from
///   `changed()` is ignored and ends the wait as well).
///
/// NOTE(review): on early error the remaining join handles are dropped; per Tokio
/// semantics that detaches the spawned tasks rather than cancelling them — confirm
/// this is the intended behavior.
fn oneshot_runner_task(
    oneshot_tasks: Vec<NamedBoxFuture<eyre::Result<()>>>,
    mut stop_receiver: StopReceiver,
    only_oneshot_tasks: bool,
) -> NamedBoxFuture<eyre::Result<()>> {
    let runner = async move {
        let supervised = oneshot_tasks.into_iter().map(|task| async move {
            let runtime = tokio::runtime::Handle::current();
            // The name has to be captured before the future is moved into the spawned task.
            let name = task.id().to_string();
            let joined = runtime.spawn(task).await;

            match joined {
                // The task ran to completion; annotate its own error (if any) with the task name.
                Ok(outcome) => outcome.with_context(|| format!("Oneshot task {name} failed")),
                Err(join_err) => {
                    let panic_msg = try_extract_panic_message(join_err);
                    Err(eyre::format_err!(
                        "Oneshot task {name} panicked: {panic_msg}"
                    ))
                }
            }
        });

        let all_finished = futures::future::try_join_all(supervised).await;
        // With long-running tasks present, the runner must stay alive after a
        // successful run until the stop signal arrives.
        if all_finished.is_ok() && !only_oneshot_tasks {
            stop_receiver.0.changed().await.ok();
        }
        all_finished.map(|_| ())
    };

    NamedBoxFuture::new(runner.boxed(), "oneshot_runner".into())
}