use crate::{supervision::Tree, utils::extract_panic_message, Error};
use commonware_utils::{
channel::oneshot,
sync::{Mutex, Once},
};
use futures::{
future::{select, Either},
pin_mut,
stream::{AbortHandle, Abortable, Aborted},
FutureExt as _,
};
use prometheus_client::metrics::gauge::Gauge;
use std::{
any::Any,
future::Future,
panic::{resume_unwind, AssertUnwindSafe},
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tracing::error;
/// Handle to a task wrapped by `Handle::init`, or an already-closed placeholder built by
/// `Handle::closed`.
///
/// Awaiting the handle resolves to the `Result<T, Error>` delivered over the internal
/// one-shot channel; if the sending half is dropped without delivering a value, the handle
/// resolves to `Err(Error::Closed)`.
pub struct Handle<T>
where
T: Send + 'static,
{
/// Abort switch for the wrapped task; `None` for handles built with `closed`.
abort_handle: Option<AbortHandle>,
/// Receiving half of the one-shot channel that carries the task's result.
receiver: oneshot::Receiver<Result<T, Error>>,
/// Metric guard for this task; finished by `closed` and by `abort`.
metric: MetricHandle,
}
impl<T> Handle<T>
where
T: Send + 'static,
{
#[inline(always)]
pub(crate) fn init<F>(
f: F,
metric: MetricHandle,
panicker: Panicker,
tree: Arc<Tree>,
) -> (impl Future<Output = ()>, Self)
where
F: Future<Output = T> + Send + 'static,
{
let (sender, receiver) = oneshot::channel();
let (abort_handle, abort_registration) = AbortHandle::new_pair();
let metric_handle = metric.clone();
let task = async move {
let result =
Abortable::new(AssertUnwindSafe(f).catch_unwind(), abort_registration).await;
match result {
Ok(Ok(result)) => {
let _ = sender.send(Ok(result));
}
Ok(Err(panic)) => {
panicker.notify(panic);
let _ = sender.send(Err(Error::Exited));
}
Err(Aborted) => {}
}
tree.abort();
metric_handle.finish();
};
(
task,
Self {
abort_handle: Some(abort_handle),
receiver,
metric,
},
)
}
pub(crate) fn closed(metric: MetricHandle) -> Self {
metric.finish();
let (sender, receiver) = oneshot::channel();
drop(sender);
Self {
abort_handle: None,
receiver,
metric,
}
}
pub fn abort(&self) {
let Some(abort_handle) = &self.abort_handle else {
return;
};
abort_handle.abort();
self.metric.finish();
}
pub(crate) fn aborter(&self) -> Option<Aborter> {
self.abort_handle
.clone()
.map(|inner| Aborter::new(inner, self.metric.clone()))
}
}
impl<T> Future for Handle<T>
where
T: Send + 'static,
{
type Output = Result<T, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.receiver)
.poll(cx)
.map(|result| result.unwrap_or_else(|_| Err(Error::Closed)))
}
}
/// Guard around a gauge: `new` increments it, `finish` decrements it.
///
/// Clones share the same `finished` guard, so the decrement is routed through a single
/// `Once` no matter which clone calls `finish` (relies on `Once::call_once` running its
/// closure only once).
#[derive(Clone)]
pub(crate) struct MetricHandle {
/// Gauge incremented in `new` and decremented in `finish`.
gauge: Gauge,
/// One-shot guard shared by every clone of this handle.
finished: Arc<Once>,
}
impl MetricHandle {
    /// Increments `gauge` and returns a handle that can later undo that increment.
    pub(crate) fn new(gauge: Gauge) -> Self {
        let finished = Arc::new(Once::new());
        gauge.inc();
        Self { gauge, finished }
    }

    /// Decrements the gauge through the shared `Once` guard.
    ///
    /// Safe to call from any clone and any number of times; the decrement closure is handed
    /// to `call_once` each time.
    pub(crate) fn finish(&self) {
        let decrement = {
            let gauge = self.gauge.clone();
            move || {
                gauge.dec();
            }
        };
        self.finished.call_once(decrement);
    }
}
/// Boxed panic payload, as forwarded by `Panicker::notify` and re-raised with
/// `resume_unwind` in `Panicked::interrupt`.
pub type Panic = Box<dyn Any + Send + 'static>;
/// Reports task panics to the paired `Panicked` listener created alongside it by `new`.
///
/// Clones share one sender slot, so at most one panic payload is ever forwarded.
#[derive(Clone)]
pub(crate) struct Panicker {
/// When `true`, `notify` only logs the panic and does not forward it.
catch: bool,
/// Shared slot holding the one-shot sender; emptied when a panic is forwarded.
sender: Arc<Mutex<Option<oneshot::Sender<Panic>>>>,
}
impl Panicker {
    /// Creates a `Panicker` together with the `Panicked` listener wired to it.
    ///
    /// `catch` selects whether panics are only logged (`true`) or also forwarded to the
    /// listener (`false`).
    pub(crate) fn new(catch: bool) -> (Self, Panicked) {
        let (tx, rx) = oneshot::channel();
        (
            Self {
                catch,
                sender: Arc::new(Mutex::new(Some(tx))),
            },
            Panicked { receiver: rx },
        )
    }

    /// Returns the `catch` flag this `Panicker` was created with.
    #[commonware_macros::stability(ALPHA)]
    pub(crate) const fn catch(&self) -> bool {
        self.catch
    }

    /// Records that a task panicked with payload `panic`.
    ///
    /// The panic is always logged. Unless `catch` is set, the payload is then forwarded to
    /// the listener, but only if no earlier panic has already taken the sender.
    pub(crate) fn notify(&self, panic: Box<dyn Any + Send + 'static>) {
        // The binding name doubles as the tracing field name, so it must stay `err`.
        let err = extract_panic_message(&*panic);
        error!(?err, "task panicked");

        if !self.catch {
            // Hold the lock while taking and using the sender.
            let mut slot = self.sender.lock();
            if let Some(tx) = slot.take() {
                // The listener may already be gone; ignoring the failure is intentional.
                let _ = tx.send(panic);
            }
        }
    }
}
/// Listening side of a `Panicker`: receives at most one forwarded panic payload.
pub(crate) struct Panicked {
/// Receiving half of the one-shot channel whose sender is held by the `Panicker`.
receiver: oneshot::Receiver<Panic>,
}
impl Panicked {
    /// Drives `task` while watching for a forwarded panic.
    ///
    /// * If a panic payload arrives first, unwinding is resumed on the caller with it.
    /// * If the panic channel closes without a payload, `task` is awaited to completion.
    /// * If `task` finishes first, its output is returned.
    pub(crate) async fn interrupt<Fut>(self, task: Fut) -> Fut::Output
    where
        Fut: Future,
    {
        let signal = self.receiver;
        pin_mut!(signal);
        pin_mut!(task);

        match select(signal, task).await {
            // A payload was forwarded: propagate the panic to the caller.
            Either::Left((Ok(payload), _)) => resume_unwind(payload),
            // No panic can arrive any more; keep running the task.
            Either::Left((Err(_), remaining)) => remaining.await,
            Either::Right((output, _)) => output,
        }
    }
}
/// Detached abort switch for a task, bundled with the metric guard to finish on abort.
pub(crate) struct Aborter {
/// Abort switch for the task.
inner: AbortHandle,
/// Metric guard finished when `abort` is called.
metric: MetricHandle,
}
impl Aborter {
pub(crate) const fn new(inner: AbortHandle, metric: MetricHandle) -> Self {
Self { inner, metric }
}
pub(crate) fn abort(self) {
self.inner.abort();
self.metric.finish();
}
}
#[cfg(test)]
mod tests {
// Tests for the per-label `runtime_tasks_running` gauge, read back from the encoded
// metrics text of a deterministic runner context.
use crate::{deterministic, Metrics, Runner, Spawner};
use futures::future;
// Prefix of the metric lines inspected by `running_tasks_for_label`.
const METRIC_PREFIX: &str = "runtime_tasks_running{";
/// Scans `metrics` line by line and returns the first value that parses as `u64` from a
/// line that starts with `METRIC_PREFIX` and contains `name="<label>"`.
///
/// The value is the text after the last space on the line. Returns `None` when no line
/// yields a value.
fn running_tasks_for_label(metrics: &str, label: &str) -> Option<u64> {
let label_fragment = format!("name=\"{label}\"");
metrics.lines().find_map(|line| {
if line.starts_with(METRIC_PREFIX) && line.contains(&label_fragment) {
line.rsplit_once(' ')
.and_then(|(_, value)| value.trim().parse::<u64>().ok())
} else {
None
}
})
}
/// The gauge reads 1 right after spawning and 0 after the handle has been awaited.
#[test]
fn tasks_running_decreased_after_completion() {
const LABEL: &str = "tasks_running_after_completion";
let runner = deterministic::Runner::default();
runner.start(|context| async move {
let context = context.with_label(LABEL);
let handle = context.clone().spawn(|_| async move { "done" });
// Read the metrics before yielding to the spawned task.
let metrics = context.encode();
assert_eq!(
running_tasks_for_label(&metrics, LABEL),
Some(1),
"expected tasks_running gauge to be 1 before completion: {metrics}",
);
let output = handle.await.expect("task failed");
assert_eq!(output, "done");
let metrics = context.encode();
assert_eq!(
running_tasks_for_label(&metrics, LABEL),
Some(0),
"expected tasks_running gauge to return to 0 after completion: {metrics}",
);
});
}
/// Dropping the handle of a never-completing task leaves the gauge at 1.
#[test]
fn tasks_running_unchanged_when_handle_dropped() {
const LABEL: &str = "tasks_running_unchanged";
let runner = deterministic::Runner::default();
runner.start(|context| async move {
let context = context.with_label(LABEL);
let handle = context.clone().spawn(|_| async move {
// Never resolves, so the task cannot complete on its own.
future::pending::<()>().await;
});
let metrics = context.encode();
assert_eq!(
running_tasks_for_label(&metrics, LABEL),
Some(1),
"expected tasks_running gauge to be 1 before dropping handle: {metrics}",
);
drop(handle);
let metrics = context.encode();
assert_eq!(
running_tasks_for_label(&metrics, LABEL),
Some(1),
"dropping handle should not finish metrics: {metrics}",
);
});
}
/// Calling `abort` on the handle drops the gauge to 0 without awaiting anything.
#[test]
fn tasks_running_decreased_immediately_on_abort_via_handle() {
const LABEL: &str = "tasks_running_abort_via_handle";
let runner = deterministic::Runner::default();
runner.start(|context| async move {
let context = context.with_label(LABEL);
let handle = context.clone().spawn(|_| async move {
// Never resolves, so only the abort can end the task.
future::pending::<()>().await;
});
let metrics = context.encode();
assert_eq!(
running_tasks_for_label(&metrics, LABEL),
Some(1),
"expected tasks_running gauge to be 1 before abort: {metrics}",
);
handle.abort();
// No await between the abort and this read.
let metrics = context.encode();
assert_eq!(
running_tasks_for_label(&metrics, LABEL),
Some(0),
"expected tasks_running gauge to return to 0 after abort: {metrics}",
);
});
}
/// Same lifecycle check as the completion test, for a task spawned via `shared(true)`.
// NOTE(review): `shared(true)` presumably selects a blocking task on a shared pool;
// confirm against the `Spawner` documentation.
#[test]
fn tasks_running_decreased_after_blocking_completion() {
const LABEL: &str = "tasks_running_after_blocking_completion";
let runner = deterministic::Runner::default();
runner.start(|context| async move {
let context = context.with_label(LABEL);
let blocking_handle = context.clone().shared(true).spawn(|_| async move {
42
});
let metrics = context.encode();
assert_eq!(
running_tasks_for_label(&metrics, LABEL),
Some(1),
"expected tasks_running gauge to be 1 while blocking task runs: {metrics}",
);
let result = blocking_handle.await.expect("blocking task failed");
assert_eq!(result, 42);
let metrics = context.encode();
assert_eq!(
running_tasks_for_label(&metrics, LABEL),
Some(0),
"expected tasks_running gauge to return to 0 after blocking task completes: {metrics}",
);
});
}
/// Aborting through a detached `Aborter` drops the gauge to 0 without awaiting anything.
#[test]
fn tasks_running_decreased_immediately_on_abort_via_aborter() {
const LABEL: &str = "tasks_running_abort_via_aborter";
let runner = deterministic::Runner::default();
runner.start(|context| async move {
let context = context.with_label(LABEL);
let handle = context.clone().spawn(|_| async move {
// Never resolves, so only the abort can end the task.
future::pending::<()>().await;
});
let metrics = context.encode();
assert_eq!(
running_tasks_for_label(&metrics, LABEL),
Some(1),
"expected tasks_running gauge to be 1 before abort: {metrics}",
);
// A spawned handle is expected to carry an abort switch, hence the unwrap.
let aborter = handle.aborter().unwrap();
aborter.abort();
let metrics = context.encode();
assert_eq!(
running_tasks_for_label(&metrics, LABEL),
Some(0),
"expected tasks_running gauge to return to 0 after abort: {metrics}",
);
});
}
}