commonware_runtime/
utils.rsuse crate::Error;
#[cfg(test)]
use crate::{Runner, Spawner};
#[cfg(test)]
use futures::stream::{FuturesUnordered, StreamExt};
use futures::{
channel::oneshot,
stream::{AbortHandle, Abortable},
FutureExt,
};
use prometheus_client::metrics::gauge::Gauge;
use std::{
any::Any,
future::Future,
panic::{resume_unwind, AssertUnwindSafe},
pin::Pin,
sync::{Arc, Once},
task::{Context, Poll},
};
use tracing::error;
/// Yields control back to the executor exactly once before completing.
///
/// On its first poll the inner future wakes the current task and returns
/// `Pending`. This gives the executor a chance to run other work. On the
/// next poll it completes.
pub async fn reschedule() {
    /// Future that reports `Pending` once and `Ready` on every later poll.
    struct YieldOnce {
        done: bool,
    }

    impl Future for YieldOnce {
        type Output = ();

        fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
            let this = self.get_mut();
            if this.done {
                return Poll::Ready(());
            }
            // Mark ourselves as having yielded, then ask to be polled again.
            this.done = true;
            cx.waker().wake_by_ref();
            Poll::Pending
        }
    }

    YieldOnce { done: false }.await
}
/// Best-effort conversion of a panic payload, such as one returned by
/// `catch_unwind`, into a string.
///
/// `&str` and `String` payloads are returned verbatim; these are what `panic!`
/// produces for literal and formatted messages. Any other payload falls back
/// to the `Debug` representation of the `dyn Any` value.
fn extract_panic_message(err: &(dyn Any + Send)) -> String {
    err.downcast_ref::<&str>()
        .map(|msg| (*msg).to_owned())
        .or_else(|| err.downcast_ref::<String>().cloned())
        .unwrap_or_else(|| format!("{:?}", err))
}
/// Handle to a task created by `Handle::init`.
///
/// Awaiting the handle yields the task's output. It yields `Error::Exited` if
/// the task panicked while panic catching was enabled. It yields `Error::Closed`
/// if the result channel closed without a value, e.g. because the task was
/// aborted or dropped before finishing.
pub struct Handle<T>
where
T: Send + 'static,
{
// Aborts the `Abortable` wrapper around the task future.
aborter: AbortHandle,
// Receives the task's result, which the wrapped future sends once it finishes.
receiver: oneshot::Receiver<Result<T, Error>>,
// Running-task gauge: incremented in `init`, decremented at most once (on
// completion or on `abort`).
running: Gauge,
// Shared with the wrapped future so completion and `abort` cannot both
// decrement `running`.
once: Arc<Once>,
}
impl<T> Handle<T>
where
T: Send + 'static,
{
pub(crate) fn init<F>(
f: F,
running: Gauge,
catch_panic: bool,
) -> (impl Future<Output = ()>, Self)
where
F: Future<Output = T> + Send + 'static,
{
running.inc();
let once = Arc::new(Once::new());
let (sender, receiver) = oneshot::channel();
let (aborter, abort_registration) = AbortHandle::new_pair();
let wrapped = {
let once = once.clone();
let running = running.clone();
async move {
let result = AssertUnwindSafe(f).catch_unwind().await;
once.call_once(|| {
running.dec();
});
let result = match result {
Ok(result) => Ok(result),
Err(err) => {
if !catch_panic {
resume_unwind(err);
}
let err = extract_panic_message(&*err);
error!(?err, "task panicked");
Err(Error::Exited)
}
};
let _ = sender.send(result);
}
};
let abortable = Abortable::new(wrapped, abort_registration);
(
abortable.map(|_| ()),
Self {
aborter,
receiver,
running,
once,
},
)
}
pub fn abort(&self) {
self.aborter.abort();
self.once.call_once(|| {
self.running.dec();
});
}
}
impl<T> Future for Handle<T>
where
T: Send + 'static,
{
type Output = Result<T, Error>;
fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
Pin::new(&mut self.receiver)
.poll(cx)
.map(|res| res.map_err(|_| Error::Closed).and_then(|r| r))
}
}
/// Test workload: yields to the executor five times, then returns `i`
/// unchanged so callers can check which tasks ran.
#[cfg(test)]
async fn task(i: usize) -> usize {
    let mut yields_left = 5;
    while yields_left > 0 {
        reschedule().await;
        yields_left -= 1;
    }
    i
}
/// Spawns `tasks` instances of `task` on `context` and drives them to
/// completion inside `runner`.
///
/// Returns their outputs in completion order. `tasks == 0` is accepted and
/// yields an empty vector.
///
/// # Panics
///
/// Panics if any task's handle resolves to an error, or if the number of
/// collected outputs differs from `tasks`.
#[cfg(test)]
pub fn run_tasks(tasks: usize, runner: impl Runner, context: impl Spawner) -> Vec<usize> {
    runner.start(async move {
        // Spawn every task up front so they run concurrently. A single
        // `0..tasks` loop avoids the `tasks - 1` underflow when `tasks == 0`.
        let mut handles = FuturesUnordered::new();
        for i in 0..tasks {
            handles.push(context.spawn("test", task(i)));
        }

        let mut outputs = Vec::with_capacity(tasks);
        while let Some(result) = handles.next().await {
            outputs.push(result.unwrap());
        }
        assert_eq!(outputs.len(), tasks);
        outputs
    })
}