#![allow(dead_code)]
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::Arc;
use tokio::sync::Notify;
/// Shared state behind every clone of a [`CancellationFlag`].
struct Inner {
    /// Wakes tasks that are waiting in `CancellationFlag::cancelled()`.
    notify: Notify,
    /// Flipped from `false` to `true` by `CancellationFlag::cancel()`; never reset.
    cancelled: AtomicBool,
}
/// A cheaply clonable cancellation signal.
///
/// Every clone shares the same underlying state through an `Arc`, so cancelling
/// any one clone is observed by all the others — both through
/// [`is_cancelled`](Self::is_cancelled) and by tasks awaiting
/// [`cancelled`](Self::cancelled).
#[derive(Clone)]
pub struct CancellationFlag {
    /// Reference-counted shared state; cloning the flag only bumps this count.
    inner: Arc<Inner>,
}
impl std::fmt::Debug for CancellationFlag {
    /// Renders as `CancellationFlag { cancelled: <bool> }`, using a snapshot of
    /// `is_cancelled()` taken at formatting time.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cancelled = self.is_cancelled();
        let mut builder = f.debug_struct("CancellationFlag");
        builder.field("cancelled", &cancelled);
        builder.finish()
    }
}
impl Default for CancellationFlag {
fn default() -> Self {
Self::new()
}
}
impl CancellationFlag {
    /// Creates a new flag in the not-cancelled state.
    pub fn new() -> Self {
        let shared = Inner {
            cancelled: AtomicBool::new(false),
            notify: Notify::new(),
        };
        Self {
            inner: Arc::new(shared),
        }
    }

    /// Marks this flag, and every clone sharing its state, as cancelled.
    ///
    /// Idempotent: only the call that flips the flag from `false` to `true`
    /// wakes tasks waiting in [`cancelled`](Self::cancelled); later calls do
    /// nothing.
    pub fn cancel(&self) {
        let previously_cancelled = self.inner.cancelled.swap(true, Ordering::Release);
        if previously_cancelled {
            return;
        }
        self.inner.notify.notify_waiters();
    }

    /// Returns `true` once [`cancel`](Self::cancel) has been called on this
    /// flag or on any clone of it.
    pub fn is_cancelled(&self) -> bool {
        self.inner.cancelled.load(Ordering::Acquire)
    }

    /// Completes once the flag is cancelled; resolves immediately if it
    /// already is.
    pub async fn cancelled(&self) {
        // Fast path: already cancelled, nothing to wait for.
        if self.is_cancelled() {
            return;
        }
        // Create the `Notified` future *before* re-checking the flag. Per tokio's
        // `Notify::notified` docs, the future observes `notify_waiters()` calls
        // made after its creation even before it is first polled, so a `cancel()`
        // landing between this line and the re-check below cannot be missed.
        let wakeup = self.inner.notify.notified();
        if !self.is_cancelled() {
            wakeup.await;
        }
    }
}
/// Guard that cancels its [`CancellationFlag`] when it is dropped.
///
/// The flag is cancelled on every way the guard's scope can end: normal exit,
/// an early return, or panic unwinding (when built with `panic = "unwind"`).
///
/// The type is `#[must_use]` because a guard that is not bound to a variable is
/// dropped at the end of its statement. `CancelOnDrop::new(flag);` therefore
/// cancels the flag on the spot, which is almost never intended. Note that
/// `let _ = CancelOnDrop::new(flag);` also drops immediately, and `must_use`
/// does not warn about that form. Bind the guard to a named variable such as
/// `_guard` instead.
#[must_use = "dropping a CancelOnDrop immediately cancels its flag; bind it to a named variable like `_guard`"]
#[derive(Debug)]
pub struct CancelOnDrop {
    /// The flag cancelled by this guard's `Drop` impl.
    flag: CancellationFlag,
}
impl CancelOnDrop {
pub fn new(flag: CancellationFlag) -> Self {
Self { flag }
}
pub fn flag(&self) -> &CancellationFlag {
&self.flag
}
}
impl Drop for CancelOnDrop {
fn drop(&mut self) {
self.flag.cancel();
}
}
#[cfg(test)]
mod tests {
    // Covers the synchronous flag API, the async `cancelled()` wait, and the
    // `CancelOnDrop` guard on normal exit, explicit drop, and panic unwinding.
    use super::*;
    use std::time::Duration;

    #[test]
    fn flag_starts_not_cancelled() {
        let flag = CancellationFlag::new();
        assert!(!flag.is_cancelled());
    }

    #[test]
    fn cancel_marks_flag_cancelled() {
        let flag = CancellationFlag::new();
        flag.cancel();
        assert!(flag.is_cancelled());
    }

    #[test]
    fn cancel_is_idempotent() {
        let flag = CancellationFlag::new();
        (0..3).for_each(|_| flag.cancel());
        assert!(flag.is_cancelled());
    }

    #[test]
    fn clones_share_state() {
        let original = CancellationFlag::new();
        let clones = [original.clone(), original.clone()];
        assert!(clones.iter().all(|clone| !clone.is_cancelled()));
        original.cancel();
        assert!(clones.iter().all(CancellationFlag::is_cancelled));
    }

    #[tokio::test]
    async fn cancelled_future_resolves_on_cancel() {
        let flag = CancellationFlag::new();
        let waiter = {
            let flag = flag.clone();
            tokio::spawn(async move {
                flag.cancelled().await;
                true
            })
        };
        // Give the waiter time to park before cancelling.
        tokio::time::sleep(Duration::from_millis(5)).await;
        assert!(!waiter.is_finished());
        flag.cancel();
        assert!(waiter.await.unwrap());
    }

    #[tokio::test]
    async fn cancelled_future_resolves_immediately_when_already_cancelled() {
        let flag = CancellationFlag::new();
        flag.cancel();
        let outcome = tokio::time::timeout(Duration::from_millis(50), flag.cancelled()).await;
        outcome.expect("cancelled() future must resolve immediately");
    }

    #[tokio::test]
    async fn multiple_cancelled_awaits_all_wake_on_cancel() {
        const WAITERS: usize = 4;
        let flag = CancellationFlag::new();
        let handles: Vec<_> = (0..WAITERS)
            .map(|_| {
                let flag = flag.clone();
                tokio::spawn(async move { flag.cancelled().await })
            })
            .collect();
        tokio::time::sleep(Duration::from_millis(5)).await;
        flag.cancel();
        for handle in handles {
            tokio::time::timeout(Duration::from_millis(100), handle)
                .await
                .expect("await completes within budget")
                .expect("join ok");
        }
    }

    #[test]
    fn cancel_on_drop_fires_on_scope_exit() {
        let flag = CancellationFlag::new();
        {
            let _guard = CancelOnDrop::new(flag.clone());
            assert!(!flag.is_cancelled());
        }
        assert!(flag.is_cancelled());
    }

    #[test]
    fn cancel_on_drop_fires_on_panic() {
        let flag = CancellationFlag::new();
        let moved = flag.clone();
        let _ = std::panic::catch_unwind(std::panic::AssertUnwindSafe(move || {
            let _guard = CancelOnDrop::new(moved);
            panic!("synthetic panic");
        }));
        assert!(flag.is_cancelled());
    }

    #[test]
    fn cancel_on_drop_explicit_drop_fires() {
        let flag = CancellationFlag::new();
        let guard = CancelOnDrop::new(flag.clone());
        assert!(!flag.is_cancelled());
        std::mem::drop(guard);
        assert!(flag.is_cancelled());
    }

    #[test]
    fn cancel_on_drop_flag_borrowable_during_guard_lifetime() {
        let flag = CancellationFlag::new();
        let guard = CancelOnDrop::new(flag.clone());
        let borrowed = guard.flag();
        assert!(!borrowed.is_cancelled());
        borrowed.cancel();
        assert!(flag.is_cancelled());
    }

    #[tokio::test]
    async fn cancel_on_drop_async_consumer_sees_cancellation() {
        let flag = CancellationFlag::new();
        let listener = flag.clone();
        let consumer = tokio::spawn(async move {
            listener.cancelled().await;
            "cancelled"
        });
        let producer = tokio::spawn(async move {
            let _guard = CancelOnDrop::new(flag);
            tokio::time::sleep(Duration::from_millis(10)).await;
        });
        producer.await.unwrap();
        let joined = tokio::time::timeout(Duration::from_millis(100), consumer)
            .await
            .expect("consumer completes after cancel");
        assert_eq!(joined.expect("join ok"), "cancelled");
    }
}