use std::future::Future;
use std::pin::Pin;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::task::{Context, Poll, Waker};
/// State shared by every clone of a [`CancellationToken`].
struct CancelInner {
    /// Set once, by the first call to [`CancellationToken::cancel`]; never cleared.
    cancelled: AtomicBool,
    /// Wakers of pending [`Cancelled`] futures. Each registered future owns exactly one
    /// entry, which it removes when it re-registers with a different waker or is dropped,
    /// so the list never outgrows the number of live, pending futures.
    wakers: Mutex<Vec<Waker>>,
}

/// A cloneable, thread-safe cancellation flag.
///
/// All clones share one underlying state: cancelling any clone cancels all of them and
/// wakes every task awaiting [`CancellationToken::cancelled`] on any clone.
#[derive(Clone)]
pub struct CancellationToken(Arc<CancelInner>);

impl Default for CancellationToken {
    fn default() -> Self {
        Self::new()
    }
}

/// Locks a waker list, recovering the data if a previous holder panicked.
///
/// Recovery is sound: every critical section only pushes, `swap_remove`s, or takes the
/// whole `Vec`, so an interrupted section cannot leave it invalid. Recovering instead of
/// panicking also keeps the `Drop` impl of [`Cancelled`] from panicking.
fn lock_wakers(wakers: &Mutex<Vec<Waker>>) -> std::sync::MutexGuard<'_, Vec<Waker>> {
    wakers.lock().unwrap_or_else(std::sync::PoisonError::into_inner)
}

/// Removes one entry of `wakers` that would wake the same task as `target`, if present.
///
/// Entries for which `will_wake` is true wake the same task, so they are interchangeable;
/// removing any one of them keeps the one-entry-per-registered-future accounting exact.
fn remove_one(wakers: &mut Vec<Waker>, target: &Waker) {
    if let Some(index) = wakers.iter().position(|w| w.will_wake(target)) {
        wakers.swap_remove(index);
    }
}

impl CancellationToken {
    /// Creates a token that is not yet cancelled.
    pub fn new() -> Self {
        Self(Arc::new(CancelInner {
            cancelled: AtomicBool::new(false),
            wakers: Mutex::new(Vec::new()),
        }))
    }

    /// Cancels the token and wakes every task awaiting [`cancelled`](Self::cancelled).
    ///
    /// Idempotent: only the first call (on any clone) wakes waiters; later calls do nothing.
    pub fn cancel(&self) {
        if self.0.cancelled.swap(true, Ordering::SeqCst) {
            return;
        }
        // Move the wakers out and release the lock *before* waking: `Waker::wake` runs
        // executor-defined code, which must not be able to deadlock on or poison our lock.
        let wakers = std::mem::take(&mut *lock_wakers(&self.0.wakers));
        for waker in wakers {
            waker.wake();
        }
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on any clone.
    pub fn is_cancelled(&self) -> bool {
        self.0.cancelled.load(Ordering::SeqCst)
    }

    /// Returns a future that resolves once the token is cancelled (immediately if it
    /// already is).
    ///
    /// Dropping the future before it completes unregisters its waker, so abandoned waits
    /// do not accumulate in the token.
    pub fn cancelled(&self) -> Cancelled<'_> {
        Cancelled {
            token: self,
            registered: None,
        }
    }
}

impl std::fmt::Debug for CancellationToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CancellationToken")
            .field("cancelled", &self.is_cancelled())
            .finish()
    }
}

/// Future returned by [`CancellationToken::cancelled`]; resolves to `()` once the token
/// is cancelled.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Cancelled<'a> {
    token: &'a CancellationToken,
    /// Clone of the waker this future currently has in the token's list, if any.
    registered: Option<Waker>,
}

impl Future for Cancelled<'_> {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // `Cancelled` holds only a reference and an `Option<Waker>`, so it is `Unpin`.
        let this = self.get_mut();
        let token = this.token;
        // Fast path: no locking once cancelled.
        if token.is_cancelled() {
            return Poll::Ready(());
        }
        let mut wakers = lock_wakers(&token.0.wakers);
        // Re-check under the lock. `cancel` sets the flag before taking the list, so either
        // we observe the flag here or the entry we push below is in the list it takes.
        if token.is_cancelled() {
            return Poll::Ready(());
        }
        let up_to_date = matches!(&this.registered, Some(w) if w.will_wake(cx.waker()));
        if !up_to_date {
            // First poll, or polled with a waker for a different task: replace our entry.
            if let Some(stale) = this.registered.take() {
                remove_one(&mut wakers, &stale);
            }
            let waker = cx.waker().clone();
            wakers.push(waker.clone());
            this.registered = Some(waker);
        }
        Poll::Pending
    }
}

impl Drop for Cancelled<'_> {
    fn drop(&mut self) {
        // Only futures that registered own an entry. If `cancel` already took the list,
        // `remove_one` simply finds nothing.
        if let Some(waker) = self.registered.take() {
            remove_one(&mut lock_wakers(&self.token.0.wakers), &waker);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::future::Future;
    use std::pin::Pin;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::task::{Context, Wake, Waker};

    /// Waker that counts how many times it has been woken, so wake-ups can be asserted
    /// without an async runtime.
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    #[test]
    fn cancel_flips_flag() {
        let t = CancellationToken::new();
        assert!(!t.is_cancelled());
        t.cancel();
        assert!(t.is_cancelled());
    }

    #[test]
    fn cancel_is_idempotent() {
        let t = CancellationToken::new();
        t.cancel();
        t.cancel();
        assert!(t.is_cancelled());
    }

    #[test]
    fn clones_share_state() {
        let t = CancellationToken::new();
        let t2 = t.clone();
        t.cancel();
        assert!(t2.is_cancelled());
    }

    #[test]
    fn debug_reports_cancellation_state() {
        let t = CancellationToken::new();
        assert_eq!(format!("{:?}", t), "CancellationToken { cancelled: false }");
        t.cancel();
        assert_eq!(format!("{:?}", t), "CancellationToken { cancelled: true }");
    }

    #[test]
    fn cancel_wakes_pending_poller() {
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        let waker = Waker::from(Arc::clone(&counter));
        let mut cx = Context::from_waker(&waker);
        let t = CancellationToken::new();
        let mut fut = t.cancelled();

        assert!(Pin::new(&mut fut).poll(&mut cx).is_pending());
        assert_eq!(counter.0.load(Ordering::SeqCst), 0);

        t.cancel();
        assert_eq!(counter.0.load(Ordering::SeqCst), 1);
        assert!(Pin::new(&mut fut).poll(&mut cx).is_ready());
    }

    #[tokio::test]
    async fn cancelled_future_resolves_after_cancel() {
        let t = CancellationToken::new();
        let t2 = t.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(10)).await;
            t2.cancel();
        });
        t.cancelled().await;
        assert!(t.is_cancelled());
    }

    #[tokio::test]
    async fn cancelled_returns_immediately_if_already_cancelled() {
        let t = CancellationToken::new();
        t.cancel();
        t.cancelled().await;
    }
}