// Lifecycle states of an `ExceptionContext`, stored in its `status` atomic.
// Transitions only ever move forward: NO_EXCEPTION -> THROWING -> THROWN -> MOVED.

/// Initial state: nothing has been thrown and the exception slot is uninitialised.
const NO_EXCEPTION: u8 = 0;
/// `throw` has claimed the slot and is in the middle of writing the exception.
const THROWING: u8 = NO_EXCEPTION + 1;
/// The exception is written and published; it is waiting to be taken by the catcher.
const THROWN: u8 = THROWING + 1;
/// The exception has been moved out of the slot by the catcher.
const MOVED: u8 = THROWN + 1;
/// Shared slot through which async code can "throw" a value of type `E`, to be picked up
/// by the [`Catching`] future that wraps it.
///
/// A context accepts at most one exception over its whole lifetime.
pub struct ExceptionContext<E> {
    /// Lifecycle state: one of `NO_EXCEPTION`, `THROWING`, `THROWN`, `MOVED`.
    status: core::sync::atomic::AtomicU8,
    /// The thrown exception; holds a live value only while `status` is `THROWN`.
    exception: core::cell::UnsafeCell<core::mem::MaybeUninit<E>>,
}

// No manual `Send` impl is needed: `AtomicU8` is `Send` and
// `UnsafeCell<MaybeUninit<E>>` is `Send` exactly when `E: Send`, so the compiler's auto
// impl already provides `ExceptionContext<E>: Send` under the same bound.

// SAFETY: shared access to the `UnsafeCell` is coordinated through `status`. Only the
// caller that wins NO_EXCEPTION -> THROWING writes the slot, and only the caller that wins
// THROWN -> MOVED reads it, with a Release store / Acquire CAS ordering the two. The value
// is moved between threads but never shared by reference, so `E: Send` is sufficient
// (the same reasoning that makes `Mutex<E>: Sync` for `E: Send`).
unsafe impl<E: Send> Sync for ExceptionContext<E> {}
impl<E> Drop for ExceptionContext<E> {
    /// Drops an exception that was thrown but never taken out of the context.
    fn drop(&mut self) {
        // Any state other than THROWN means the slot is empty (never written, or moved out).
        if *self.status.get_mut() != THROWN {
            return;
        }
        // SAFETY: THROWN is stored only after `throw` has finished writing the slot, and
        // taking the value switches the state to MOVED, so the slot is initialised here.
        // `&mut self` rules out concurrent access, and the value is read exactly once.
        let leftover: E = unsafe { self.exception.get().read().assume_init() };
        core::mem::drop(leftover);
    }
}
impl<E> Default for ExceptionContext<E> {
fn default() -> Self {
Self {
status: 0.into(),
exception: core::cell::UnsafeCell::new(core::mem::MaybeUninit::uninit()),
}
}
}
impl<E> ExceptionContext<E> {
    /// Creates a context in which no exception has been thrown yet.
    ///
    /// Equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Stores `exception` in this context, then suspends the calling future forever.
    ///
    /// The value is picked up by a [`Catching`] future over this context (see
    /// [`catching`](Self::catching)), which checks for it after every poll of the future it
    /// wraps. This future never completes, so code after awaiting it never runs.
    ///
    /// # Panics
    ///
    /// Panics if `throw` has already been called on this context, including after the
    /// earlier exception has been caught.
    pub async fn throw(&self, exception: E) -> ! {
        // Claim the slot. The state never returns to NO_EXCEPTION, so only one caller can
        // ever win this transition.
        if self
            .status
            .compare_exchange(
                NO_EXCEPTION,
                THROWING,
                core::sync::atomic::Ordering::Relaxed,
                core::sync::atomic::Ordering::Relaxed,
            )
            .is_err()
        {
            panic!("`throw` calls more than once")
        }
        // SAFETY: winning NO_EXCEPTION -> THROWING makes this the only writer, and no reader
        // touches the slot until it observes THROWN, which is stored only after this write.
        unsafe { (*self.exception.get()).write(exception) };
        // Release pairs with the Acquire in `try_take_exception`, publishing the write above.
        self.status
            .store(THROWN, core::sync::atomic::Ordering::Release);
        core::future::pending().await
    }

    /// Wraps the future produced by `f` so that a `throw` on this context ends it with `Err`.
    ///
    /// `f` is called exactly once, immediately, with `self`. The returned [`Catching`]
    /// future resolves to `Ok(output)` when the inner future completes, or to
    /// `Err(exception)` when an exception is found on this context after polling it.
    ///
    /// `f` only needs to be `FnOnce`, so closures that move captured values into the
    /// future they build are accepted; every `Fn` closure is still accepted as well.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub fn catching<'a, Fu: core::future::Future, F: FnOnce(&'a Self) -> Fu>(
        &'a self,
        f: F,
    ) -> Catching<'a, E, Fu> {
        Catching {
            ctx: self,
            future: f(self),
        }
    }

    /// Moves the thrown exception out of the context, if one has been published.
    ///
    /// Returns `Some` at most once per context: success moves the state from THROWN to
    /// MOVED, after which the slot is treated as empty (`Drop` only drops it in THROWN).
    fn try_take_exception(&self) -> Option<E> {
        if self
            .status
            .compare_exchange(
                THROWN,
                MOVED,
                core::sync::atomic::Ordering::Acquire,
                core::sync::atomic::Ordering::Relaxed,
            )
            .is_ok()
        {
            // SAFETY: the state was THROWN, so `throw` fully initialised the slot and its
            // Release store synchronises with this Acquire; winning THROWN -> MOVED makes
            // this the only read, so the value is moved out exactly once.
            Some(unsafe { self.exception.get().read().assume_init() })
        } else {
            None
        }
    }
}
pin_project_lite::pin_project! {
    /// Future returned by [`ExceptionContext::catching`].
    ///
    /// After each poll of the wrapped future it checks the context: if an exception has
    /// been thrown it resolves to `Err(exception)`, otherwise it forwards the wrapped
    /// future's result as `Ok(output)` once that future completes.
    pub struct Catching<'a, E, F> {
        // Context whose exception slot is inspected after every poll of `future`.
        ctx: &'a ExceptionContext<E>,
        // The wrapped future; structurally pinned so it can be polled in place.
        #[pin]
        future: F,
    }
}
impl<'a, E, F: core::future::Future> core::future::Future for Catching<'a, E, F> {
    type Output = Result<F::Output, E>;

    /// Polls the wrapped future, then checks the context for a thrown exception.
    ///
    /// A thrown exception wins over the wrapped future's result, even if both become
    /// available during the same poll.
    fn poll(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        let projected = self.project();
        // Drive the inner future first so that a `throw` it performs during this poll is
        // observed by the check below.
        let inner_poll = projected.future.poll(cx);
        match projected.ctx.try_take_exception() {
            Some(thrown) => core::task::Poll::Ready(Err(thrown)),
            None => inner_poll.map(Ok),
        }
    }
}