use std::{
pin::Pin,
task::{Context, Poll},
};
use futures_util::FutureExt;
use pin_project_lite::pin_project;
use synchrony::unsync::event::EventListener;
use crate::{
CancelToken,
future::Ext,
waker::{ExtWaker, with_ext},
};
pin_project! {
    /// Future adapter that polls an inner future with a [`CancelToken`]
    /// installed into the waker extension on every poll.
    ///
    /// Created with [`WithCancel::new`]. Its output is the inner future's
    /// output (`F::Output`). Use [`WithCancel::fail_fast`] to get a variant
    /// that resolves to `Err(Cancelled)` when the token's listener is ready.
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub struct WithCancel<F: ?Sized> {
        // Token handed to `Ext::with_cancel` each time the adapter is polled.
        cancel: CancelToken,
        // The wrapped future; structurally pinned so it can be polled in place.
        #[pin]
        future: F,
    }
}
pin_project! {
    /// Fail-fast variant of [`WithCancel`].
    ///
    /// Created with [`WithCancel::fail_fast`]. On each poll it first checks
    /// the cancellation listener and resolves to `Err(Cancelled)` if it is
    /// ready; otherwise it polls the inner [`WithCancel`] and wraps its output
    /// in `Ok`. Use [`WithCancelFailFast::fail_slow`] to drop the listener and
    /// recover the plain [`WithCancel`].
    #[must_use = "futures do nothing unless you `.await` or poll them"]
    pub struct WithCancelFailFast<F: ?Sized> {
        // Listener obtained from the token when the adapter was created;
        // stored unpinned and polled via `poll_unpin`.
        listen: EventListener,
        // The wrapped cancel-aware future; structurally pinned.
        #[pin]
        future: WithCancel<F>,
    }
}
/// Error value produced by [`WithCancelFailFast`] when its cancellation
/// listener is ready at the time the adapter is polled.
///
/// A zero-sized marker: it carries no further information.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)]
pub struct Cancelled;
impl<F: ?Sized> WithCancel<F> {
pub fn new(future: F, cancel: CancelToken) -> Self
where
F: Sized,
{
Self { cancel, future }
}
}
impl<F> WithCancel<F> {
    /// Converts this adapter into a [`WithCancelFailFast`].
    ///
    /// The cancellation listener is obtained from the token right here, before
    /// the returned future is first polled. The result resolves to
    /// `Err(Cancelled)` on any poll at which that listener is ready, without
    /// polling the inner future on that poll.
    pub fn fail_fast(self) -> WithCancelFailFast<F> {
        // `listen` must be evaluated before `self` is moved into `future`.
        WithCancelFailFast {
            listen: self.cancel.listen(),
            future: self,
        }
    }
}
impl<F> WithCancelFailFast<F> {
pub fn fail_slow(self) -> WithCancel<F> {
self.future
}
}
impl std::fmt::Display for Cancelled {
    // Emits the fixed text `Cancelled`; width/fill flags are not applied.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        out.write_str("Cancelled")
    }
}

// No underlying cause: the default `source()` (returning `None`) is kept.
impl std::error::Error for Cancelled {}
impl<F> Future for WithCancel<F>
where
    F: Future + ?Sized,
{
    type Output = F::Output;

    /// Polls the inner future with the cancel token installed.
    ///
    /// Inside `with_ext`, a child extension is derived from the `Ext` supplied
    /// for this waker via `with_cancel(cancel)`. The inner future is then
    /// polled through an `ExtWaker` built from the supplied waker and that
    /// child extension.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let proj = self.project();
        let outer = cx.waker();
        with_ext(outer, |raw_waker, parent: &Ext| {
            let scoped = parent.with_cancel(proj.cancel);
            ExtWaker::new(raw_waker, &scoped).poll(proj.future)
        })
    }
}
impl<F: ?Sized> Future for WithCancelFailFast<F>
where
    F: Future,
{
    type Output = Result<F::Output, Cancelled>;

    /// Resolves to `Err(Cancelled)` if the cancellation listener is ready,
    /// otherwise polls the inner [`WithCancel`] and wraps its output in `Ok`.
    ///
    /// The listener is checked first on every poll. If cancellation and inner
    /// completion are both ready at once, cancellation wins and the inner
    /// future is not polled.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        // `listen` is stored unpinned, so it is polled through `poll_unpin`.
        if this.listen.poll_unpin(cx).is_ready() {
            return Poll::Ready(Err(Cancelled));
        }
        // `future` is a `#[pin]` field: the projection already yields a
        // `Pin<&mut WithCancel<F>>`, so poll it directly instead of going
        // through `poll_unpin` on the `Pin` wrapper.
        this.future.poll(cx).map(Ok)
    }
}