use std::{
pin::Pin,
sync::{Arc, Mutex},
task::{Context, Poll, Waker},
};
use futures::Future;
use pin_project::pin_project;
/// A future wrapper that can be cancelled from elsewhere through a
/// [`CancelHandle`].
///
/// Polling a `Cancel<F>` polls the wrapped future `F` unless the shared
/// status has been set to cancelled, in which case it resolves to
/// `Err(Cancelled)` without polling `F`.
#[pin_project]
pub(crate) struct Cancel<F> {
/// State shared with every [`CancelHandle`] created for this future.
inner: Arc<Mutex<Inner>>,
/// The wrapped future; structurally pinned so that it can be polled
/// through a `Pin<&mut Self>`.
#[pin]
fut: F,
}
/// Lifecycle state of a [`Cancel`] future, stored in the shared [`Inner`].
#[derive(Clone, Copy, Debug)]
enum Status {
/// Initial state: the future has neither completed nor been cancelled.
Pending,
/// The wrapped future returned `Poll::Ready`; cancellation is refused.
Finished,
/// `CancelHandle::cancel` succeeded; polling yields `Err(Cancelled)`.
Cancelled,
}
/// State shared (behind an `Arc<Mutex<_>>`) between a [`Cancel`] future and
/// its [`CancelHandle`]s.
struct Inner {
/// Current lifecycle state of the future.
status: Status,
/// Waker recorded the last time the wrapped future returned
/// `Poll::Pending`; taken and woken when the future is cancelled.
waker: Option<Waker>,
}
/// Handle used to cancel the [`Cancel`] future it was created alongside.
///
/// Cloning the handle clones the `Arc`, so all clones refer to the same
/// shared state.
#[derive(Clone)]
pub(crate) struct CancelHandle {
/// State shared with the associated [`Cancel`] future.
inner: Arc<Mutex<Inner>>,
}
impl<F> Cancel<F> {
    /// Wrap `fut` so that it can be cancelled.
    ///
    /// Returns the [`CancelHandle`] that controls cancellation together with
    /// the wrapping future. Both share one freshly created state that starts
    /// out as `Status::Pending` with no waker recorded.
    pub(crate) fn new(fut: F) -> (CancelHandle, Cancel<F>) {
        let shared = Arc::new(Mutex::new(Inner {
            status: Status::Pending,
            waker: None,
        }));
        // The handle gets its own reference to the shared state; the future
        // takes ownership of the original one.
        let handle = CancelHandle {
            inner: Arc::clone(&shared),
        };
        (handle, Self { inner: shared, fut })
    }
}
impl CancelHandle {
    /// Cancel the associated future if it is still pending.
    ///
    /// On success the shared status becomes `Status::Cancelled` and any waker
    /// recorded by the future is woken.
    ///
    /// # Errors
    ///
    /// Returns [`CannotCancel::Finished`] if the future has already completed,
    /// or [`CannotCancel::Cancelled`] if it was already cancelled.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned.
    pub(crate) fn cancel(&self) -> Result<(), CannotCancel> {
        let mut state = self.inner.lock().expect("poisoned lock");

        // Reject the request unless the future is still pending.
        match state.status {
            Status::Finished => return Err(CannotCancel::Finished),
            Status::Cancelled => return Err(CannotCancel::Cancelled),
            Status::Pending => {}
        }
        state.status = Status::Cancelled;

        // Take the waker while locked, but release the lock before waking.
        let pending_waker = state.waker.take();
        drop(state);
        if let Some(waker) = pending_waker {
            waker.wake();
        }
        Ok(())
    }
}
/// Error produced by a [`Cancel`] future that was cancelled before its
/// wrapped future completed.
#[derive(thiserror::Error, Clone, Debug)]
#[error("Future was cancelled")]
pub(crate) struct Cancelled;
/// Reason why `CancelHandle::cancel` refused to cancel a future.
#[derive(thiserror::Error, Clone, Debug)]
pub(crate) enum CannotCancel {
/// The future had already been cancelled by an earlier call.
#[error("Already cancelled")]
Cancelled,
/// The wrapped future had already run to completion.
#[error("Already finished")]
Finished,
}
impl<F: Future> Future for Cancel<F> {
    /// `Ok` with the wrapped future's output, or `Err(Cancelled)` if the
    /// future was cancelled first.
    type Output = Result<F::Output, Cancelled>;

    /// Poll the wrapped future unless it has been cancelled.
    ///
    /// The shared mutex stays locked for the whole call, including while the
    /// wrapped future is polled, so a concurrent `CancelHandle::cancel` has to
    /// wait until this poll returns before it can change the status.
    ///
    /// # Panics
    ///
    /// Panics if the shared mutex is poisoned.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let projected = self.project();
        let mut shared = projected.inner.lock().expect("lock poisoned");

        match shared.status {
            Status::Cancelled => return Poll::Ready(Err(Cancelled)),
            // A `Finished` status is deliberately not special-cased: the
            // wrapped future is simply polled again, like a pending one.
            Status::Pending | Status::Finished => {}
        }

        if let Poll::Ready(output) = projected.fut.poll(cx) {
            shared.status = Status::Finished;
            return Poll::Ready(Ok(output));
        }

        // Still pending: remember the current waker so that `cancel` can wake
        // this task, reusing the stored waker where one already exists.
        match shared.waker.as_mut() {
            Some(stored) => stored.clone_from(cx.waker()),
            None => shared.waker = Some(cx.waker().clone()),
        }
        Poll::Pending
    }
}
// Unit tests for `Cancel` and `CancelHandle`.
#[cfg(test)]
mod test {
// Clippy lints relaxed for test code in this module.
#![allow(clippy::bool_assert_comparison)]
#![allow(clippy::clone_on_copy)]
#![allow(clippy::dbg_macro)]
#![allow(clippy::mixed_attributes_style)]
#![allow(clippy::print_stderr)]
#![allow(clippy::print_stdout)]
#![allow(clippy::single_char_pattern)]
#![allow(clippy::unwrap_used)]
#![allow(clippy::unchecked_time_subtraction)]
#![allow(clippy::useless_vec)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::string_slice)]
use std::{future, time::Duration};
use super::*;
use futures::{FutureExt as _, StreamExt as _, stream::FuturesUnordered};
use futures_await_test::async_test;
use oneshot_fused_workaround as oneshot;
use tor_basic_utils::RngExt;
use tor_rtcompat::SleepProvider as _;
// A wrapped future whose handle is never used resolves to `Ok` with the
// inner future's output.
#[async_test]
async fn not_cancelled() {
let f = futures::future::ready("hello");
let (_h, f) = Cancel::new(f);
assert_eq!(f.await.unwrap(), "hello");
}
// Cancelling a wrapped future whose inner future does not complete makes it
// resolve to `Err(Cancelled)`.
#[async_test]
async fn cancelled() {
// Case 1: the inner future is `pending()`, which never completes.
let f = futures::future::pending::<()>();
let (h, f) = Cancel::new(f);
// Drive the wrapped future together with a future that cancels it.
let (r, ()) = futures::join!(f, async {
h.cancel().unwrap();
});
assert!(matches!(r, Err(Cancelled)));
// Case 2: the inner future is a oneshot receiver; the sender `_tx` is kept
// in scope but nothing is ever sent on it.
let (_tx, rx) = oneshot::channel::<()>();
let (h, f) = Cancel::new(rx);
let (r, ()) = futures::join!(f, async {
h.cancel().unwrap();
});
assert!(matches!(r, Err(Cancelled)));
}
// Race a cancellation against awaiting the wrapped future many times, and
// check that in every trial exactly one of the two succeeded.
#[test]
fn cancelled_or_not() {
tor_rtmock::MockRuntime::test_with_various(|rt| async move {
// `allow(deprecated)` silences the deprecation warning on this statement.
#[allow(deprecated)] let rt = tor_rtmock::MockSleepRuntime::new(rt);
// Number of independent cancel-versus-await races to run.
const N_TRIES: usize = 1024;
// Upper bound (inclusive) on the random delay before each side acts.
const SLEEP_CEIL: Duration = Duration::from_millis(1);
// Per-trial outcomes; an entry stays `None` until that side has run.
let work_succeeded = Arc::new(Mutex::new([None; N_TRIES]));
let cancel_succeeded = Arc::new(Mutex::new([None; N_TRIES]));
let mut futs = FuturesUnordered::new();
for idx in 0..N_TRIES {
let work_succeeded = Arc::clone(&work_succeeded);
let cancel_succeeded = Arc::clone(&cancel_succeeded);
let rt1 = rt.clone();
let rt2 = rt.clone();
// Independent random delays for the cancelling side and the awaiting side.
let t1 = rand::rng().gen_range_infallible(..=SLEEP_CEIL);
let t2 = rand::rng().gen_range_infallible(..=SLEEP_CEIL);
// The inner future is immediately ready, so the wrapped future completes on
// its first poll unless it has been cancelled by then.
let work = future::ready(());
let (handle, work) = Cancel::new(work);
// Side 1: sleep, then try to cancel and record whether that succeeded.
let f1 = async move {
rt1.sleep(t1).await;
let r = handle.cancel();
cancel_succeeded.lock().unwrap()[idx] = Some(r.is_ok());
};
// Side 2: sleep, then await the wrapped future and record whether it
// produced `Ok`.
let f2 = async move {
rt2.sleep(t2).await;
let r = work.await;
work_succeeded.lock().unwrap()[idx] = Some(r.is_ok());
};
futs.push(f1.boxed());
futs.push(f2.boxed());
}
// Run every queued future to completion.
// NOTE(review): `wait_for` presumably advances the mock clock while driving
// the future - confirm against tor_rtmock.
rt.wait_for(async { while let Some(()) = futs.next().await {} })
.await;
// Both sides must have run in every trial, and exactly one must have won:
// either the work finished and the cancel failed, or the reverse.
for idx in 0..N_TRIES {
let ws = work_succeeded.lock().unwrap()[idx];
let cs = cancel_succeeded.lock().unwrap()[idx];
match (ws, cs) {
(Some(true), Some(false)) => {}
(Some(false), Some(true)) => {}
_ => panic!("incorrect values {:?}", (idx, ws, cs)),
}
}
});
}
}