use futures::future::FusedFuture;
use futures_intrusive::channel::shared::{state_broadcast_channel, StateReceiver, StateSender};
use futures_intrusive::sync::ManualResetEvent;
use pin_project::pin_project;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use crate::with_token::WithToken;
use crate::{cancellation, Canceled, Cancellation, Join, Token};
/// Builds a fresh [`Join`] together with a second handle to the event it wraps.
///
/// The event starts unset (`ManualResetEvent::new(false)`). `scope` stores the returned
/// `Arc<Join>` in the scope state and waits on the event handle in its joiner.
fn make_join() -> (Arc<Join>, Arc<ManualResetEvent>) {
    // NOTE(review): presumably `Join` sets this event when dropped — confirm in `Join`'s
    // `Drop` impl.
    let event = Arc::new(ManualResetEvent::new(false));
    let join = Arc::new(Join {
        event: Arc::clone(&event),
    });
    (join, event)
}
/// Lifecycle stage of a [`ScopeFuture`].
enum State<Fut: Future> {
    /// The inner future has not completed yet.
    ///
    /// `cancel` is the receiving half of the scope's cancellation broadcast; a clone of it
    /// goes into every token handed to the inner future. `join` is the scope's own strong
    /// handle on its [`Join`]; tokens only get a weak reference to it.
    Running {
        cancel: Arc<StateReceiver<bool>>,
        join: Arc<Join>,
    },
    /// The inner future finished with this output; the scope is waiting for its joiner.
    Joining(Fut::Output),
    /// The output was returned, or the scope was forcibly canceled.
    Done,
}
/// Future returned by [`scope`]: drives an inner future, then waits for the scope's joiner.
///
/// Resolves to `Ok(output)` once the inner future has finished and the joiner has completed.
/// If forced cancellation is observed before the inner future finished, it resolves to
/// `Err(Canceled::Forced)` instead.
#[must_use = "futures do nothing unless polled"]
#[pin_project]
pub struct ScopeFuture<Fut, Joiner>
where
    Fut: Future,
{
    // NOTE(review): fields drop in declaration order, so `state` (which may still hold the
    // scope's `Arc<Join>`) is dropped before `inner`. Keep that in mind before reordering.
    /// Current lifecycle stage; see [`State`].
    state: State<Fut>,
    /// Cancellation signal of the surrounding context, checked at the start of every poll.
    #[pin]
    cancellation: Cancellation,
    /// The wrapped user future, polled with a freshly built token each time.
    #[pin]
    inner: WithToken<Fut>,
    /// Polled only after `inner` has completed; the scope resolves once this completes.
    #[pin]
    joiner: Joiner,
    /// Sending half of the scope's cancellation broadcast. `cancel` sends `true` on it;
    /// `force_cancel` drops it, leaving `None`.
    cancel_sender: Option<StateSender<bool>>,
}
impl<Fut, Joiner> ScopeFuture<Fut, Joiner>
where
    Fut: Future,
{
    /// Requests graceful cancellation of this scope.
    ///
    /// Broadcasts `true` on the scope's cancellation channel. Does nothing once the sender
    /// has been dropped by [`force_cancel`](Self::force_cancel). The result of the send is
    /// deliberately ignored.
    pub fn cancel(self: Pin<&mut Self>) {
        if let Some(sender) = self.cancel_sender.as_ref() {
            let _ = sender.send(true);
        }
    }

    /// Forcibly cancels this scope by dropping the cancellation sender.
    ///
    /// No value is sent; afterwards [`cancel`](Self::cancel) becomes a no-op.
    pub fn force_cancel(self: Pin<&mut Self>) {
        let projected = self.project();
        // Overwriting with `None` drops the previous sender right here.
        *projected.cancel_sender = None;
    }
}
impl<Fut, Joiner> Future for ScopeFuture<Fut, Joiner>
where
    Fut: Future,
    Joiner: Future<Output = ()>,
{
    /// `Ok` carries the inner future's output; `Err` reports forced cancellation.
    type Output = Result<Fut::Output, Canceled>;

    /// Advances the scope in three steps:
    ///
    /// 1. Poll the surrounding cancellation signal. Graceful cancellation is forwarded into
    ///    this scope via [`ScopeFuture::cancel`]; forced cancellation resolves immediately.
    /// 2. While `Running`, poll the inner future with a fresh token. When it finishes, store
    ///    its output and move to `Joining`.
    /// 3. In `Joining`, poll the joiner and return the stored output once it completes.
    ///
    /// # Panics
    ///
    /// Panics if polled in the `Done` state, unless the cancellation signal reports forced
    /// cancellation on that poll (then `Err(Canceled::Forced)` is returned again).
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let this = self.as_mut().project();
        match this.cancellation.poll(cx) {
            // Forward graceful cancellation into this scope's own broadcast channel.
            Poll::Ready(Some(Canceled::Graceful)) => self.as_mut().cancel(),
            Poll::Ready(Some(Canceled::Forced)) => {
                // Resolve now. If the inner future already finished (`Joining`), its output
                // is returned instead of an error, without waiting for the joiner.
                let ret = match std::mem::replace(this.state, State::Done) {
                    State::Joining(v) => Poll::Ready(Ok(v)),
                    _ => Poll::Ready(Err(Canceled::Forced)),
                };
                // Drop the cancellation sender; `cancel` is a no-op from here on.
                self.force_cancel();
                return ret;
            }
            _ => {}
        }
        // Re-project: the first projection is no longer used after the match above.
        let this = self.project();
        match this.state {
            State::Done => panic!("poll after completion or forced cancellation"),
            State::Running { cancel, join } => {
                // Each poll builds a new token sharing this scope's cancellation receiver
                // and holding only a weak (non-owning) reference to its `Join`.
                let token = Token {
                    cancel: cancel.clone(),
                    join: Arc::downgrade(join),
                };
                match this.inner.poll(cx, token) {
                    Poll::Pending => return Poll::Pending,
                    Poll::Ready(v) => {
                        // Overwriting `Running` drops the scope's own `Arc<Join>`.
                        *this.state = State::Joining(v);
                    }
                }
            }
            _ => {}
        }
        // Only `Joining` reaches this point: `Done` panicked above, and `Running` either
        // returned `Pending` or switched to `Joining`.
        match this.joiner.poll(cx) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(()) => match std::mem::replace(this.state, State::Done) {
                State::Joining(v) => Poll::Ready(Ok(v)),
                _ => unreachable!(),
            },
        }
    }
}
impl<Fut, Joiner> FusedFuture for ScopeFuture<Fut, Joiner>
where
    Fut: Future,
    Joiner: Future<Output = ()>,
{
    /// Reports `true` once the scope is in its `Done` state.
    ///
    /// `Done` is reached after the output has been returned or after forced cancellation.
    fn is_terminated(&self) -> bool {
        matches!(self.state, State::Done)
    }
}
/// Wraps `inner` in a new scope.
///
/// The returned [`ScopeFuture`] polls `inner` with a fresh token on every poll. Once `inner`
/// finishes, it waits for the scope's join event before resolving to `Ok(output)`. It also
/// watches the surrounding context's [`cancellation`] signal. The scope can be canceled
/// directly with [`ScopeFuture::cancel`] or [`ScopeFuture::force_cancel`].
pub fn scope<Fut>(
    inner: Fut,
) -> ScopeFuture<impl Future<Output = Fut::Output>, impl Future<Output = ()>>
where
    Fut: Future,
{
    let (join, join_event) = make_join();
    let (cancel_sender, cancel_receiver) = state_broadcast_channel();
    // Completes once the join event has been set.
    let joiner = async move {
        join_event.wait().await;
    };
    let state = State::Running {
        cancel: Arc::new(cancel_receiver),
        join,
    };
    ScopeFuture {
        state,
        cancellation: cancellation(),
        inner: WithToken::new(inner),
        joiner,
        cancel_sender: Some(cancel_sender),
    }
}