use crate::task::AtomicWaker;
use futures_core::future::Future;
use futures_core::task::{Context, Poll};
use pin_utils::unsafe_pinned;
use core::fmt;
use core::pin::Pin;
use core::sync::atomic::{AtomicBool, Ordering};
use alloc::sync::Arc;
/// A future wrapper that can be remotely aborted through an [`AbortHandle`].
///
/// When polled, it resolves to:
/// - `Ok(output)` once the wrapped future completes, or
/// - `Err(Aborted)` once the shared abort flag has been observed as set.
///
/// Cloning an `Abortable` clones the wrapped future. It shares the same abort
/// state with the original, because the `Arc` is cloned rather than the state.
#[derive(Clone, Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Abortable<Fut> {
    // The wrapped future; structurally pinned (projected via `unsafe_pinned!`).
    future: Fut,
    // Abort state shared with the `AbortHandle`(s) and the registration.
    inner: Arc<AbortInner>,
}

// `future` is structurally pinned, so the wrapper may only be `Unpin` when the
// wrapped future itself is `Unpin`.
impl<Fut> Unpin for Abortable<Fut> where Fut: Unpin {}
impl<Fut> Abortable<Fut> where Fut: Future {
unsafe_pinned!(future: Fut);
pub fn new(future: Fut, reg: AbortRegistration) -> Self {
Abortable {
future,
inner: reg.inner,
}
}
}
/// The registration half of an abort pair.
///
/// It is consumed by [`Abortable::new`] to link a future to the
/// [`AbortHandle`] created alongside it.
#[derive(Debug)]
pub struct AbortRegistration {
    // Abort state shared with the matching `AbortHandle`.
    inner: Arc<AbortInner>,
}
/// A handle that aborts an [`Abortable`] future.
///
/// Clones share the same underlying abort state, so aborting through any clone
/// aborts the linked future.
#[derive(Clone, Debug)]
pub struct AbortHandle {
    // Abort state shared with the registration and the `Abortable`.
    inner: Arc<AbortInner>,
}
impl AbortHandle {
    /// Creates a linked `(AbortHandle, AbortRegistration)` pair.
    ///
    /// Both halves share one abort state, which starts out *not* aborted.
    /// Pass the registration to [`Abortable::new`], and keep the handle to
    /// trigger the abort later.
    pub fn new_pair() -> (Self, AbortRegistration) {
        let shared = Arc::new(AbortInner {
            waker: AtomicWaker::new(),
            cancel: AtomicBool::new(false),
        });

        let handle = AbortHandle {
            inner: Arc::clone(&shared),
        };
        let registration = AbortRegistration { inner: shared };

        (handle, registration)
    }
}
/// State shared between an [`AbortHandle`], its [`AbortRegistration`] and the
/// resulting [`Abortable`].
#[derive(Debug)]
struct AbortInner {
    // Waker registered by `Abortable::poll`; woken by `AbortHandle::abort`.
    waker: AtomicWaker,
    // Set to `true` once `AbortHandle::abort` has been called; never reset.
    cancel: AtomicBool,
}
/// Wraps `future` so that it can be aborted.
///
/// Returns the [`Abortable`] wrapper together with the [`AbortHandle`] that
/// controls it. This is shorthand for [`AbortHandle::new_pair`] followed by
/// [`Abortable::new`].
pub fn abortable<Fut>(future: Fut) -> (Abortable<Fut>, AbortHandle)
where
    Fut: Future,
{
    let (handle, registration) = AbortHandle::new_pair();
    let wrapped = Abortable::new(future, registration);
    (wrapped, handle)
}
/// Error produced by an [`Abortable`] future whose [`AbortHandle`] has aborted it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Aborted;

impl fmt::Display for Aborted {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("`Abortable` future has been aborted")
    }
}

// `std::error::Error` is only available when the `std` feature is enabled.
#[cfg(feature = "std")]
impl std::error::Error for Aborted {}
impl<Fut> Future for Abortable<Fut>
where
    Fut: Future,
{
    type Output = Result<Fut::Output, Aborted>;

    /// Polls the wrapped future unless an abort has already been requested.
    ///
    /// Sequence:
    /// 1. If the abort flag is already set, resolve to `Err(Aborted)` without
    ///    polling the inner future.
    /// 2. Poll the inner future; if it is ready, resolve to `Ok(output)`.
    /// 3. Otherwise, register the current waker and re-check the flag. The
    ///    re-check catches an `abort()` that raced with the registration;
    ///    without it, the wakeup could be lost.
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // NOTE(review): `Relaxed` loads rely on `AtomicWaker::register`/`wake`
        // for cross-thread ordering, following the AtomicWaker usage pattern.
        // Confirm against the AtomicWaker docs before changing.
        if self.inner.cancel.load(Ordering::Relaxed) {
            return Poll::Ready(Err(Aborted));
        }

        match self.as_mut().future().poll(cx) {
            Poll::Ready(output) => Poll::Ready(Ok(output)),
            Poll::Pending => {
                self.inner.waker.register(cx.waker());
                if self.inner.cancel.load(Ordering::Relaxed) {
                    Poll::Ready(Err(Aborted))
                } else {
                    Poll::Pending
                }
            }
        }
    }
}
impl AbortHandle {
    /// Aborts the linked [`Abortable`] future.
    ///
    /// This sets the shared abort flag and then wakes any waker registered by
    /// the `Abortable`. The next poll of the future (or the one in progress,
    /// via its post-registration re-check) resolves to `Err(Aborted)`.
    /// Calling this more than once just sets the flag and wakes again.
    pub fn abort(&self) {
        let shared = &*self.inner;
        // The flag must be stored before waking, so that the woken task
        // observes it when it polls again.
        shared.cancel.store(true, Ordering::Relaxed);
        shared.waker.wake();
    }
}