use crate::{utils::extract_panic_message, Error};
use futures::{
channel::oneshot,
stream::{AbortHandle, Abortable},
FutureExt as _,
};
use prometheus_client::metrics::gauge::Gauge;
use std::{
future::Future,
panic::{catch_unwind, resume_unwind, AssertUnwindSafe},
pin::Pin,
sync::{Arc, Mutex, Once},
task::{Context, Poll},
};
use tracing::error;
/// Handle to a spawned task.
///
/// Awaiting the handle yields the task's output (or an [`Error`]). Tasks created from a
/// future can additionally be aborted through the handle.
pub struct Handle<T: Send + 'static> {
    /// Abort switch for the task; `None` for tasks that cannot be aborted (blocking tasks).
    aborter: Option<AbortHandle>,
    /// Receives the task's result once it has finished.
    receiver: oneshot::Receiver<Result<T, Error>>,
    /// Gauge counting tasks that are currently running.
    running: Gauge,
    /// Shared with the task so the `running` gauge is decremented exactly once.
    once: Arc<Once>,
}
impl<T> Handle<T>
where
    T: Send + 'static,
{
    /// Wraps `f` into a spawnable future and returns it together with a [`Handle`] to it.
    ///
    /// `running` is incremented immediately. It is decremented exactly once, on whichever
    /// of these happens first:
    /// - `f` completes;
    /// - `f` panics;
    /// - [`Handle::abort`] is called;
    /// - the task is aborted through a cloned [`AbortHandle`] (see [`Handle::abort_handle`]);
    /// - the handle observes the task's result.
    ///
    /// When the returned future resolves, every [`AbortHandle`] registered in `children` is
    /// drained and aborted. The future resolves either because `f` finished or because the
    /// task was aborted.
    ///
    /// If `f` panics:
    /// - with `catch_panic == false`, the panic is resumed (re-raised) from the wrapper;
    /// - otherwise, the panic message is logged and the handle resolves to [`Error::Exited`].
    pub(crate) fn init_future<F>(
        f: F,
        running: Gauge,
        catch_panic: bool,
        children: Arc<Mutex<Vec<AbortHandle>>>,
    ) -> (impl Future<Output = ()>, Self)
    where
        F: Future<Output = T> + Send + 'static,
    {
        running.inc();
        let once = Arc::new(Once::new());
        let (sender, receiver) = oneshot::channel();
        let (aborter, abort_registration) = AbortHandle::new_pair();
        let wrapped = {
            let once = once.clone();
            let running = running.clone();
            async move {
                let result = AssertUnwindSafe(f).catch_unwind().await;
                // Decrement before (possibly) resuming a panic so the gauge never leaks.
                once.call_once(|| {
                    running.dec();
                });
                let result = match result {
                    Ok(result) => Ok(result),
                    Err(err) => {
                        if !catch_panic {
                            resume_unwind(err);
                        }
                        let err = extract_panic_message(&*err);
                        error!(?err, "task panicked");
                        Err(Error::Exited)
                    }
                };
                // The receiver may already be gone (handle dropped); that is not an error.
                let _ = sender.send(result);
            }
        };
        let abortable = Abortable::new(wrapped, abort_registration);
        // Separate clones for the completion closure; the originals move into `Self`.
        let finished = once.clone();
        let gauge = running.clone();
        (
            abortable.map(move |_| {
                // When the task is aborted through a cloned `AbortHandle` (e.g. by a parent
                // draining its `children`), `wrapped` is dropped before its own decrement
                // runs. Decrement here so the gauge does not leak when the `Handle` is never
                // polled. On the normal path this is a no-op thanks to `Once`.
                finished.call_once(|| {
                    gauge.dec();
                });
                for handle in children.lock().unwrap().drain(..) {
                    handle.abort();
                }
            }),
            Self {
                aborter: Some(aborter),
                receiver,
                running,
                once,
            },
        )
    }

    /// Wraps the blocking closure `f` and returns it together with a [`Handle`] to it.
    ///
    /// `running` is incremented immediately. It is decremented exactly once, when either:
    /// - the returned closure finishes running `f` (normally or by panic); or
    /// - the handle observes the result.
    ///
    /// Blocking tasks carry no abort switch, so [`Handle::abort`] is a no-op for them.
    ///
    /// If `f` panics:
    /// - with `catch_panic == false`, the panic is resumed;
    /// - otherwise, it is logged and the handle resolves to [`Error::Exited`].
    ///
    /// NOTE(review): if the returned closure is dropped without ever being run, the gauge
    /// is only decremented when the handle is polled.
    pub(crate) fn init_blocking<F>(f: F, running: Gauge, catch_panic: bool) -> (impl FnOnce(), Self)
    where
        F: FnOnce() -> T + Send + 'static,
    {
        running.inc();
        let once = Arc::new(Once::new());
        let (sender, receiver) = oneshot::channel();
        let f = {
            let once = once.clone();
            let running = running.clone();
            move || {
                let result = catch_unwind(AssertUnwindSafe(f));
                // Decrement before (possibly) resuming a panic so the gauge never leaks.
                once.call_once(|| {
                    running.dec();
                });
                let result = match result {
                    Ok(value) => Ok(value),
                    Err(err) => {
                        if !catch_panic {
                            resume_unwind(err);
                        }
                        let err = extract_panic_message(&*err);
                        error!(?err, "task panicked");
                        Err(Error::Exited)
                    }
                };
                // The receiver may already be gone (handle dropped); that is not an error.
                let _ = sender.send(result);
            }
        };
        (
            f,
            Self {
                aborter: None,
                receiver,
                running,
                once,
            },
        )
    }

    /// Aborts the task and decrements the `running` gauge (at most once overall).
    ///
    /// Does nothing for handles created by [`Handle::init_blocking`], which have no abort
    /// switch. After aborting, awaiting the handle yields [`Error::Closed`] unless the task
    /// had already sent its result.
    pub fn abort(&self) {
        let Some(aborter) = &self.aborter else {
            return;
        };
        aborter.abort();
        self.once.call_once(|| {
            self.running.dec();
        });
    }

    /// Returns a clone of the task's [`AbortHandle`], or `None` for blocking tasks.
    pub(crate) fn abort_handle(&self) -> Option<AbortHandle> {
        self.aborter.clone()
    }
}
impl<T> Future for Handle<T>
where
T: Send + 'static,
{
type Output = Result<T, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
match Pin::new(&mut self.receiver).poll(cx) {
Poll::Ready(Ok(Ok(value))) => {
self.once.call_once(|| {
self.running.dec();
});
Poll::Ready(Ok(value))
}
Poll::Ready(Ok(Err(err))) => {
self.once.call_once(|| {
self.running.dec();
});
Poll::Ready(Err(err))
}
Poll::Ready(Err(_)) => {
self.once.call_once(|| {
self.running.dec();
});
Poll::Ready(Err(Error::Closed))
}
Poll::Pending => Poll::Pending,
}
}
}