use std::{
future::Future,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use futures::future::BoxFuture;
use parking_lot::Mutex;
use tokio::task::JoinHandle;
/// Wrapper around a future that keeps the future running even if the wrapper is cancelled.
///
/// "Cancelled" means dropped before the inner future returned [`Poll::Ready`]. In that case
/// the [`Drop`] impl checks whether some other party still holds a clone of `receiver`:
/// - If so, it moves the inner future onto a new Tokio task and stores the task's
///   [`JoinHandle`] in `receiver`.
/// - If not, the inner future is simply dropped together with the wrapper.
pub struct CancellationSafeFuture<F>
where
    F: Future + Send + 'static,
    F::Output: Send,
{
    /// Set once the inner future returned `Poll::Ready`. A finished future is never spawned on
    /// drop.
    done: bool,

    /// The inner future, type-erased into a pinned heap allocation.
    ///
    /// It is wrapped in an `Option` so that `drop` can take it out and hand it to
    /// `tokio::task::spawn`. The pinned `Box` keeps the future itself at a stable heap address.
    /// The box can therefore be moved out of this wrapper after being polled, without requiring
    /// `F: Unpin`.
    inner: Option<BoxFuture<'static, F::Output>>,

    /// Shared slot that receives the [`JoinHandle`] of the spawned task when the wrapper is
    /// dropped early.
    receiver: Arc<Mutex<Option<JoinHandle<F::Output>>>>,
}
impl<F> Drop for CancellationSafeFuture<F>
where
    F: Future + Send + 'static,
    F::Output: Send,
{
    /// Hands an unfinished inner future to a Tokio task if anyone else still wants its result.
    ///
    /// If the inner future has not completed and at least one other clone of `receiver` exists,
    /// the inner future is spawned and the resulting [`JoinHandle`] is stored in `receiver`.
    /// Otherwise nothing is spawned, and the inner future is dropped along with `self`.
    ///
    /// # Panics
    ///
    /// - Panics if the `receiver` slot already contains a handle.
    /// - `tokio::task::spawn` panics when it is called outside of a Tokio runtime.
    fn drop(&mut self) {
        // A completed future has nothing left to drive.
        // NOTE(review): `done` is only set on `Poll::Ready`. If the inner future panicked inside
        // `poll`, it is still stored in `inner` and may be spawned (i.e. polled again) below.
        // Confirm this is acceptable.
        if !self.done {
            // Take the lock BEFORE inspecting the reference count and hold it until the handle is
            // stored. Other holders that lock `receiver` then see either an empty slot or the
            // final handle, never a state in between.
            let mut receiver = self.receiver.lock();
            // The slot is expected to be dedicated to this wrapper and still unused.
            assert!(receiver.is_none());
            // A count of 1 means only this wrapper references the slot, so nobody can observe a
            // result: skip spawning and let the future be cancelled.
            // NOTE(review): a count > 1 can become stale right after this check (another holder
            // may drop its clone); the task would then run with nobody reading its handle.
            if Arc::strong_count(&self.receiver) > 1 {
                let inner = self.inner.take().expect("Double-drop?");
                let handle = tokio::task::spawn(inner);
                *receiver = Some(handle);
            }
        }
    }
}
impl<F> CancellationSafeFuture<F>
where
F: Future + Send,
F::Output: Send,
{
pub fn new(fut: F, receiver: Arc<Mutex<Option<JoinHandle<F::Output>>>>) -> Self {
Self {
done: false,
inner: Some(Box::pin(fut)),
receiver,
}
}
}
impl<F> Future for CancellationSafeFuture<F>
where
    F: Future + Send,
    F::Output: Send,
{
    type Output = F::Output;

    /// Polls the wrapped future and records when it has produced its output.
    ///
    /// # Panics
    ///
    /// - Panics if polled again after having returned [`Poll::Ready`].
    /// - Panics if the inner future is missing (`inner` is `None`).
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        assert!(!self.done, "Polling future that already returned");

        let pinned = self.inner.as_mut().expect("not dropped");
        let outcome = pinned.as_mut().poll(cx);

        // Once the inner future has finished, `Drop` must not try to spawn it.
        if outcome.is_ready() {
            self.done = true;
        }

        outcome
    }
}
#[cfg(test)]
mod tests {
    use std::{
        sync::atomic::{AtomicBool, Ordering},
        time::Duration,
    };

    use tokio::sync::Barrier;

    use super::*;

    /// A future that runs to completion resolves normally. It never hands work to a background
    /// task, even while another party still holds the receiver.
    #[tokio::test]
    async fn test_happy_path() {
        let done = Arc::new(AtomicBool::new(false));
        let done_captured = Arc::clone(&done);
        let receiver = Default::default();
        let fut = CancellationSafeFuture::new(
            async move {
                done_captured.store(true, Ordering::SeqCst);
            },
            Arc::clone(&receiver),
        );
        fut.await;
        assert!(done.load(Ordering::SeqCst));
        // The inner future completed before the wrapper was dropped, so nothing was spawned.
        assert!(receiver.lock().is_none());
    }

    /// Dropping an unfinished future while the receiver is shared must spawn the remaining work
    /// and publish the task's handle through the receiver.
    #[tokio::test]
    async fn test_cancel_future() {
        // Both this test and the spawned inner future must reach the barrier.
        let done = Arc::new(Barrier::new(2));
        let done_captured = Arc::clone(&done);
        let receiver = Default::default();
        let fut = CancellationSafeFuture::new(
            async move {
                done_captured.wait().await;
            },
            Arc::clone(&receiver),
        );
        drop(fut);
        // `drop` must have stored the handle of the task now driving the inner future.
        assert!(receiver.lock().is_some());
        tokio::time::timeout(Duration::from_secs(5), done.wait())
            .await
            .unwrap();
        // The spawned task must finish without panicking.
        let handle = receiver.lock().take().expect("handle stored on drop");
        tokio::time::timeout(Duration::from_secs(5), handle)
            .await
            .unwrap()
            .unwrap();
    }

    /// Without any other holder of the receiver, dropping the wrapper cancels the inner future:
    /// its captured state is released instead of being moved into a task.
    #[tokio::test]
    async fn test_receiver_gone() {
        let done = Arc::new(Barrier::new(2));
        let done_captured = Arc::clone(&done);
        let receiver = Default::default();
        let fut = CancellationSafeFuture::new(
            async move {
                done_captured.wait().await;
            },
            receiver,
        );
        drop(fut);
        assert_eq!(Arc::strong_count(&done), 1);
    }
}