use core::any::Any;
use core::future::Future;
use core::mem;
use core::pin::Pin;
use core::task::{Context, Poll};
use std::panic::{AssertUnwindSafe, catch_unwind};
use alloc::boxed::Box;
use alloc::vec::Vec;
/// Outcome of driving a future under [`CatchUnwind`]: `Ok` with the future's
/// output, or `Err` with the payload of a panic raised while polling it.
pub(crate) type CaughtUnwind<T> = Result<T, Box<dyn Any + Send + 'static>>;

/// Future adapter that turns a panic raised inside the wrapped future's `poll`
/// into an `Err` value instead of letting it unwind through the caller.
///
/// The wrapped future is kept in a `Pin<Box<_>>` so it can be polled without any
/// unsafe pin projection. As soon as the inner future finishes (normally or by
/// panicking), it is dropped inside a panic guard. A panic raised by its destructor
/// is discarded, so it cannot escape this adapter either.
///
/// Polling again after a result has been returned yields `Poll::Pending` without
/// registering the waker, so this future will never wake the task again.
///
/// Only unwinding panics are caught; with `panic = "abort"` the process still aborts.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) struct CatchUnwind<F> {
    /// The wrapped future; `None` once it has produced a result.
    inner: Option<Pin<Box<F>>>,
}

impl<F: Future> CatchUnwind<F> {
    /// Wraps `future` so that panics raised while polling it are captured.
    #[inline]
    pub(crate) fn new(future: F) -> Self {
        Self {
            inner: Some(Box::pin(future)),
        }
    }

    /// Drops the wrapped future, if still present, inside a panic guard.
    ///
    /// A future that has just panicked may be left half torn down, which makes a
    /// panicking destructor more likely. Such a secondary panic is swallowed,
    /// because the caller already receives the outcome of the poll itself.
    fn discard_inner(&mut self) {
        if let Some(finished) = self.inner.take() {
            let _ = catch_unwind(AssertUnwindSafe(move || drop(finished)));
        }
    }
}

impl<F: Future> Future for CatchUnwind<F> {
    type Output = CaughtUnwind<F::Output>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // A result was already produced: stay pending forever (see the type docs).
        let Some(inner) = self.inner.as_mut() else {
            return Poll::Pending;
        };
        let inner_pin = inner.as_mut();
        // `AssertUnwindSafe` is justified because a future that panicked is never
        // polled again: it is discarded immediately below.
        let polled = catch_unwind(AssertUnwindSafe(|| inner_pin.poll(cx)));
        let outcome: Self::Output = match polled {
            Ok(Poll::Pending) => return Poll::Pending,
            Ok(Poll::Ready(out)) => Ok(out),
            Err(payload) => Err(payload),
        };
        self.discard_inner();
        Poll::Ready(outcome)
    }
}
/// Progress of one future tracked by [`JoinAll`].
enum JoinSlot<F: Future> {
    /// Not finished yet. Boxed so the future keeps a stable address while the
    /// surrounding `Vec` is moved.
    Pending(Pin<Box<F>>),
    /// Finished; the output is held until every other future has finished too.
    Done(F::Output),
}

/// Future that polls a set of futures concurrently and resolves to all of their
/// outputs, in the order the futures were supplied.
///
/// Every still-pending future is polled on each call to `poll`. An empty set
/// resolves on the first poll to an empty `Vec`. After resolving, later polls
/// return `Poll::Ready` with an empty `Vec`.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub(crate) struct JoinAll<F: Future> {
    /// One entry per input future, in input order.
    slots: Vec<JoinSlot<F>>,
    /// Number of entries in `slots` that are still `JoinSlot::Pending`.
    remaining: usize,
}

// `JoinAll` never relies on its own data staying in place. Each future already lives
// behind its own `Pin<Box<_>>`, and outputs are only ever moved by value, never
// pinned. Moving a `JoinAll` therefore cannot move anything pinned, so it is `Unpin`
// whether or not `F::Output` is. This lets `poll` use `Pin::get_mut` without
// requiring `F::Output: Unpin`.
impl<F: Future> Unpin for JoinAll<F> {}

impl<F: Future> JoinAll<F> {
    /// Creates a `JoinAll` that drives every future yielded by `futures`.
    #[inline]
    pub(crate) fn new<I>(futures: I) -> Self
    where
        I: IntoIterator<Item = F>,
    {
        let slots: Vec<JoinSlot<F>> = futures
            .into_iter()
            .map(|f| JoinSlot::Pending(Box::pin(f)))
            .collect();
        let remaining = slots.len();
        Self { slots, remaining }
    }
}

impl<F: Future> Future for JoinAll<F> {
    type Output = Vec<F::Output>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Permitted for every `F` thanks to the unconditional `Unpin` impl above.
        let this = self.get_mut();
        for slot in &mut this.slots {
            if let JoinSlot::Pending(ref mut fut) = *slot {
                match fut.as_mut().poll(cx) {
                    Poll::Ready(value) => {
                        // Replacing the slot drops the completed future.
                        *slot = JoinSlot::Done(value);
                        this.remaining -= 1;
                    }
                    Poll::Pending => {}
                }
            }
        }
        if this.remaining > 0 {
            return Poll::Pending;
        }
        // Every slot is `Done`: move the outputs out, leaving `slots` empty.
        let drained: Vec<F::Output> = mem::take(&mut this.slots)
            .into_iter()
            .map(|slot| match slot {
                JoinSlot::Done(value) => value,
                JoinSlot::Pending(_) => {
                    unreachable!("JoinAll: slot was not Done despite remaining=0")
                }
            })
            .collect();
        Poll::Ready(drained)
    }
}
#[cfg(test)]
// Tests assert on values they constructed themselves, so a failed unwrap is the
// intended failure signal.
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::{CatchUnwind, JoinAll};
    use core::future::{Ready, ready};

    /// A future that returns normally comes back wrapped in `Ok`.
    #[tokio::test]
    async fn catch_unwind_passes_through_normal_completion() {
        let outcome = CatchUnwind::new(async { 7_u32 }).await;
        let value = outcome.unwrap();
        assert_eq!(value, 7);
    }

    /// A panic inside the wrapped future surfaces as `Err` carrying the payload.
    #[tokio::test]
    async fn catch_unwind_captures_panic_payload() {
        let panicking = async {
            panic!("boom");
        };
        let payload = CatchUnwind::new(panicking).await.unwrap_err();
        let message: &&'static str = payload.downcast_ref().unwrap();
        assert_eq!(*message, "boom");
    }

    /// Outputs come back in the order the futures were supplied.
    #[tokio::test]
    async fn join_all_completes_with_results_in_order() {
        let results = JoinAll::new((1..=3).map(ready)).await;
        assert_eq!(results, [1, 2, 3]);
    }

    /// With nothing to wait on, the join resolves to an empty `Vec`.
    #[tokio::test]
    async fn join_all_with_zero_futures_is_immediately_ready() {
        let results = JoinAll::new(core::iter::empty::<Ready<()>>()).await;
        assert!(results.is_empty());
    }
}