use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use std::sync::{Mutex, MutexGuard, PoisonError, Weak};
use std::task::{Context, Poll, Waker};
/// Shared cancellation state behind every clone of a [`CancellationToken`].
///
/// `cancelled` is atomic so `is_cancelled` never takes the lock, but every
/// transition to `true` happens while `state` is locked. A registration that
/// observes `false` under the lock is therefore guaranteed to be drained and
/// woken by the eventual `cancel`.
struct Inner {
    cancelled: AtomicBool,
    state: Mutex<State>,
    /// Strong link to the parent node, if any. It keeps an intermediate node
    /// alive for as long as a descendant exists, even after all of that
    /// node's own handles are dropped, so the parent's weak link to it stays
    /// upgradable and cancellation still reaches the descendants.
    parent: Option<Arc<Inner>>,
}

/// Lock-protected bookkeeping of an [`Inner`]; emptied on cancellation.
#[derive(Default)]
struct State {
    /// Waker slots of pending [`Cancelled`] futures. While the token is not
    /// cancelled this Vec never shrinks, so a future's slot index stays valid
    /// until the future is dropped.
    waiters: Vec<Option<Waker>>,
    /// Indices of vacated `waiters` slots available for reuse.
    free: Vec<usize>,
    /// Weak links to child nodes. They are weak so that dropped children can be
    /// reclaimed; dead links are pruned lazily in `Inner::add_child`.
    children: Vec<Weak<Inner>>,
}

impl State {
    /// Stores `waker` in a vacant slot (reusing a freed one when possible) and
    /// returns that slot's index.
    fn insert_waker(&mut self, waker: Waker) -> usize {
        match self.free.pop() {
            Some(index) => {
                // Freed slots always hold `None`, so nothing is dropped here.
                self.waiters[index] = Some(waker);
                index
            }
            None => {
                self.waiters.push(Some(waker));
                self.waiters.len() - 1
            }
        }
    }
}

impl Inner {
    /// Creates the state for a root token.
    fn new() -> Arc<Self> {
        Self::with_parent(None)
    }

    /// Creates the state for a token linked under `parent`. The new node is
    /// cancelled immediately if `parent` already is.
    fn new_child(parent: &Arc<Inner>) -> Arc<Self> {
        let child = Self::with_parent(Some(Arc::clone(parent)));
        parent.add_child(&child);
        child
    }

    fn with_parent(parent: Option<Arc<Inner>>) -> Arc<Self> {
        Arc::new(Self {
            cancelled: AtomicBool::new(false),
            state: Mutex::new(State::default()),
            parent,
        })
    }

    fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::Acquire)
    }

    /// Locks `state`, recovering from poisoning. No critical section in this
    /// file leaves `State` half-updated, so a panic raised while the lock was
    /// held (e.g. inside a waker's `clone`) does not invalidate it.
    fn lock_state(&self) -> MutexGuard<'_, State> {
        self.state.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Cancels this node and, transitively, every live descendant.
    ///
    /// Idempotent. Uses an explicit work list instead of recursion so that deep
    /// hierarchies cannot overflow the stack.
    fn cancel(&self) {
        let mut pending: Vec<Arc<Inner>> = self
            .cancel_local()
            .iter()
            .filter_map(Weak::upgrade)
            .collect();
        while let Some(node) = pending.pop() {
            pending.extend(node.cancel_local().iter().filter_map(Weak::upgrade));
        }
    }

    /// Marks only this node cancelled, wakes its waiters, and hands back its
    /// children for the caller to cancel. If the node was already cancelled,
    /// returns an empty list, because its children were handed out at that time.
    fn cancel_local(&self) -> Vec<Weak<Inner>> {
        let (waiters, children) = {
            let mut state = self.lock_state();
            if self.cancelled.swap(true, Ordering::AcqRel) {
                return Vec::new();
            }
            state.free = Vec::new();
            (
                std::mem::take(&mut state.waiters),
                std::mem::take(&mut state.children),
            )
        };
        // Wake outside the lock: waking can run arbitrary executor code.
        for waker in waiters.into_iter().flatten() {
            waker.wake();
        }
        children
    }

    /// Links `child` under this node, or cancels it right away if this node is
    /// already cancelled. The check and the link happen under the same lock
    /// that `cancel_local` takes, so the child cannot miss a concurrent cancel.
    fn add_child(&self, child: &Arc<Inner>) {
        {
            let mut state = self.lock_state();
            if !self.is_cancelled() {
                // Prune dead links only when the Vec is about to reallocate.
                // This keeps cleanup amortized O(1) per insertion and bounds
                // memory when children are created and dropped repeatedly.
                if state.children.len() == state.children.capacity() {
                    state.children.retain(|link| link.strong_count() > 0);
                }
                state.children.push(Arc::downgrade(child));
                return;
            }
        }
        child.cancel();
    }

    /// Registers or refreshes the waker of the `Cancelled` future that owns
    /// `slot`. Returns `true` if the token is cancelled; in that case `slot`
    /// is cleared and nothing is stored.
    ///
    /// A future keeps one slot across polls, so re-polling replaces its waker
    /// instead of accumulating a new entry each time.
    fn register(&self, slot: &mut Option<usize>, waker: &Waker) -> bool {
        if self.is_cancelled() {
            *slot = None;
            return true;
        }
        let replaced = {
            let mut state = self.lock_state();
            if self.is_cancelled() {
                *slot = None;
                return true;
            }
            match *slot {
                Some(index) => {
                    let entry = &mut state.waiters[index];
                    if matches!(entry.as_ref(), Some(current) if current.will_wake(waker)) {
                        None
                    } else {
                        entry.replace(waker.clone())
                    }
                }
                None => {
                    *slot = Some(state.insert_waker(waker.clone()));
                    None
                }
            }
        };
        // Drop a superseded waker only after unlocking: dropping a waker can
        // run executor code that might re-enter this token.
        drop(replaced);
        false
    }

    /// Releases the slot of a `Cancelled` future dropped before completion, so
    /// abandoned futures do not keep their tasks' wakers alive.
    fn deregister(&self, slot: usize) {
        let released = {
            let mut state = self.lock_state();
            // Cancellation already drained every slot; `slot` is stale.
            if self.is_cancelled() {
                return;
            }
            state.free.push(slot);
            state.waiters[slot].take()
        };
        // Drop the waker outside the lock (see `register`).
        drop(released);
    }
}

impl Drop for Inner {
    /// Unlinks the ancestor chain iteratively. Dropping a long chain through
    /// nested `Arc` drops would otherwise recurse once per level.
    fn drop(&mut self) {
        let mut next = self.parent.take();
        while let Some(parent) = next {
            next = match Arc::try_unwrap(parent) {
                // We held the last strong ref: detach its parent before it drops.
                Ok(mut unique) => unique.parent.take(),
                // Someone else still owns it and will drop it later.
                Err(_shared) => None,
            };
        }
    }
}

/// A handle for signalling cancellation to cooperating tasks.
///
/// Clones share state: cancelling any clone cancels all of them. Tokens made
/// with [`CancellationToken::child`] are cancelled when their parent is, but
/// cancelling a child never affects its parent.
#[derive(Clone)]
pub struct CancellationToken {
    inner: Arc<Inner>,
}

impl CancellationToken {
    /// Creates a new, uncancelled root token.
    pub fn new() -> Self {
        Self {
            inner: Inner::new(),
        }
    }

    /// Creates a child token that is cancelled whenever this token is. If this
    /// token is already cancelled, the child starts out cancelled.
    pub fn child(&self) -> Self {
        Self {
            inner: Inner::new_child(&self.inner),
        }
    }

    /// Cancels this token, its clones, and all of its descendants, waking every
    /// pending [`Cancelled`] future. Calling it again has no further effect.
    pub fn cancel(&self) {
        self.inner.cancel();
    }

    /// Returns `true` once this token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.inner.is_cancelled()
    }

    /// Wraps the token in a guard that cancels it when dropped.
    pub fn drop_guard(self) -> DropGuard {
        DropGuard { token: Some(self) }
    }

    /// Returns a future that completes once this token is cancelled.
    pub fn cancelled(&self) -> Cancelled {
        Cancelled {
            inner: Arc::clone(&self.inner),
            slot: None,
        }
    }
}

impl Default for CancellationToken {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Debug for CancellationToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CancellationToken")
            .field("cancelled", &self.is_cancelled())
            .finish()
    }
}

/// Future returned by [`CancellationToken::cancelled`]. It resolves once the
/// token is cancelled.
///
/// While pending it occupies a single waker slot in the token. The slot is
/// refreshed on each poll and released when the future is dropped.
#[must_use = "futures do nothing unless polled"]
pub struct Cancelled {
    inner: Arc<Inner>,
    /// Index of this future's waker slot, once registered.
    slot: Option<usize>,
}

impl Future for Cancelled {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // `Cancelled` is `Unpin`, so a plain `&mut` is sound.
        let this = self.get_mut();
        if this.inner.register(&mut this.slot, cx.waker()) {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}

impl Drop for Cancelled {
    fn drop(&mut self) {
        if let Some(slot) = self.slot.take() {
            self.inner.deregister(slot);
        }
    }
}

/// Guard returned by [`CancellationToken::drop_guard`]. It cancels the token
/// when dropped, including during unwinding, unless it was disarmed.
#[must_use = "dropping the guard immediately cancels the token"]
pub struct DropGuard {
    token: Option<CancellationToken>,
}

impl DropGuard {
    /// Defuses the guard and returns the token without cancelling it.
    pub fn disarm(mut self) -> CancellationToken {
        self.token.take().expect("DropGuard already disarmed")
    }
}

impl Drop for DropGuard {
    fn drop(&mut self) {
        // `None` means the guard was disarmed and no longer owns the token.
        if let Some(token) = &self.token {
            token.cancel();
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::AtomicUsize;
    use std::task::{RawWaker, RawWakerVTable, Wake};

    /// Builds a waker whose operations all do nothing.
    fn noop_waker() -> Waker {
        fn noop(_: *const ()) {}
        fn noop_clone(p: *const ()) -> RawWaker {
            RawWaker::new(p, &VTABLE)
        }
        const VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
        // SAFETY: every vtable function ignores the (null) data pointer, so the
        // RawWaker contract is trivially upheld.
        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
    }

    /// Polls `f` once with a no-op waker.
    fn poll_once<F: Future>(f: Pin<&mut F>) -> Poll<F::Output> {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        f.poll(&mut cx)
    }

    /// Waker that counts how often it is woken, so tests can check that
    /// cancellation actually notifies pending futures (a no-op waker cannot).
    #[derive(Default)]
    struct CountingWaker {
        wakes: AtomicUsize,
    }

    impl CountingWaker {
        fn count(&self) -> usize {
            self.wakes.load(Ordering::SeqCst)
        }
    }

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.wakes.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Returns a counter together with a `Waker` that increments it.
    fn counting_waker() -> (Arc<CountingWaker>, Waker) {
        let counter = Arc::new(CountingWaker::default());
        let waker = Waker::from(Arc::clone(&counter));
        (counter, waker)
    }

    #[test]
    fn not_cancelled_by_default() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());
    }

    #[test]
    fn cancel_sets_flag() {
        let token = CancellationToken::new();
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn cancel_is_idempotent() {
        let token = CancellationToken::new();
        token.cancel();
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn clone_shares_state() {
        let token = CancellationToken::new();
        let clone = token.clone();
        token.cancel();
        assert!(clone.is_cancelled());
    }

    #[test]
    fn child_sees_parent_cancel() {
        let parent = CancellationToken::new();
        let child = parent.child();
        assert!(!child.is_cancelled());
        parent.cancel();
        assert!(child.is_cancelled());
    }

    #[test]
    fn grandchild_sees_ancestor_cancel() {
        let root = CancellationToken::new();
        let child = root.child();
        let grandchild = child.child();
        assert!(!grandchild.is_cancelled());
        root.cancel();
        assert!(grandchild.is_cancelled());
    }

    #[test]
    fn grandchild_cancelled_after_intermediate_token_dropped() {
        let root = CancellationToken::new();
        // The intermediate child token is a temporary and is dropped here.
        let grandchild = root.child().child();
        root.cancel();
        assert!(grandchild.is_cancelled());
    }

    #[test]
    fn child_cancel_does_not_affect_parent() {
        let parent = CancellationToken::new();
        let child = parent.child();
        child.cancel();
        assert!(child.is_cancelled());
        assert!(!parent.is_cancelled());
    }

    #[test]
    fn cancelled_future_ready_when_cancelled() {
        let token = CancellationToken::new();
        token.cancel();
        let mut fut = std::pin::pin!(token.cancelled());
        assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(())));
    }

    #[test]
    fn cancelled_future_pending_then_ready() {
        let token = CancellationToken::new();
        let mut fut = std::pin::pin!(token.cancelled());
        assert!(matches!(poll_once(fut.as_mut()), Poll::Pending));
        token.cancel();
        assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(())));
    }

    #[test]
    fn cancel_wakes_pending_future() {
        let token = CancellationToken::new();
        let (counter, waker) = counting_waker();
        let mut cx = Context::from_waker(&waker);
        let mut fut = std::pin::pin!(token.cancelled());
        assert!(fut.as_mut().poll(&mut cx).is_pending());
        assert_eq!(counter.count(), 0);
        token.cancel();
        assert_eq!(counter.count(), 1);
        assert!(fut.as_mut().poll(&mut cx).is_ready());
    }

    #[test]
    fn parent_cancel_wakes_child_future() {
        let parent = CancellationToken::new();
        let child = parent.child();
        let (counter, waker) = counting_waker();
        let mut fut = std::pin::pin!(child.cancelled());
        assert!(fut.as_mut().poll(&mut Context::from_waker(&waker)).is_pending());
        parent.cancel();
        assert_eq!(counter.count(), 1);
    }

    #[test]
    fn repoll_with_new_waker_wakes_latest() {
        let token = CancellationToken::new();
        let (_, first) = counting_waker();
        let (latest_counter, latest) = counting_waker();
        let mut fut = std::pin::pin!(token.cancelled());
        assert!(fut.as_mut().poll(&mut Context::from_waker(&first)).is_pending());
        assert!(fut.as_mut().poll(&mut Context::from_waker(&latest)).is_pending());
        token.cancel();
        // The future contract only requires waking the most recent waker.
        assert_eq!(latest_counter.count(), 1);
    }

    #[test]
    fn child_cancelled_future_from_parent() {
        let parent = CancellationToken::new();
        let child = parent.child();
        let mut fut = std::pin::pin!(child.cancelled());
        assert!(matches!(poll_once(fut.as_mut()), Poll::Pending));
        parent.cancel();
        assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(())));
    }

    #[test]
    fn multiple_waiters() {
        let token = CancellationToken::new();
        let mut fut1 = std::pin::pin!(token.cancelled());
        let mut fut2 = std::pin::pin!(token.cancelled());
        assert!(matches!(poll_once(fut1.as_mut()), Poll::Pending));
        assert!(matches!(poll_once(fut2.as_mut()), Poll::Pending));
        token.cancel();
        assert!(matches!(poll_once(fut1.as_mut()), Poll::Ready(())));
        assert!(matches!(poll_once(fut2.as_mut()), Poll::Ready(())));
    }

    #[test]
    fn dropped_future_does_not_block_cancel() {
        let token = CancellationToken::new();
        {
            let mut fut = std::pin::pin!(token.cancelled());
            assert!(poll_once(fut.as_mut()).is_pending());
        }
        token.cancel();
        let mut fut = std::pin::pin!(token.cancelled());
        assert!(poll_once(fut.as_mut()).is_ready());
    }

    #[test]
    fn cross_thread_cancel() {
        let token = CancellationToken::new();
        let clone = token.clone();
        let handle = std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(10));
            clone.cancel();
        });
        while !token.is_cancelled() {
            std::hint::spin_loop();
        }
        handle.join().unwrap();
    }

    #[test]
    fn concurrent_cancel_never_loses_a_wakeup() {
        for _ in 0..50 {
            let token = CancellationToken::new();
            let remote = token.clone();
            let (counter, waker) = counting_waker();
            let handle = std::thread::spawn(move || remote.cancel());
            let mut fut = std::pin::pin!(token.cancelled());
            let ready = fut.as_mut().poll(&mut Context::from_waker(&waker)).is_ready();
            handle.join().unwrap();
            // Either the poll observed the cancellation, or the waker it
            // registered must have been woken by the racing cancel.
            assert!(ready || counter.count() >= 1);
        }
    }

    #[test]
    fn drop_guard_cancels_on_drop() {
        let token = CancellationToken::new();
        let clone = token.clone();
        {
            let _guard = token.drop_guard();
            assert!(!clone.is_cancelled());
        }
        assert!(clone.is_cancelled());
    }

    #[test]
    fn drop_guard_disarm() {
        let token = CancellationToken::new();
        let clone = token.clone();
        let guard = token.drop_guard();
        let recovered = guard.disarm();
        drop(recovered);
        assert!(!clone.is_cancelled());
    }

    #[test]
    fn drop_guard_on_panic() {
        let token = CancellationToken::new();
        let clone = token.clone();
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard = token.drop_guard();
            panic!("simulated panic");
        }));
        assert!(result.is_err());
        assert!(clone.is_cancelled());
    }

    #[test]
    fn send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<CancellationToken>();
        assert_send_sync::<Cancelled>();
    }

    #[test]
    fn drop_without_cancel_cleans_up() {
        let token = CancellationToken::new();
        let _child = token.child();
        let mut fut = std::pin::pin!(token.cancelled());
        let _ = poll_once(fut.as_mut());
    }

    #[test]
    fn many_children() {
        let parent = CancellationToken::new();
        let children: Vec<_> = (0..100).map(|_| parent.child()).collect();
        parent.cancel();
        for child in &children {
            assert!(child.is_cancelled());
        }
    }

    #[test]
    fn child_created_after_parent_cancelled() {
        let parent = CancellationToken::new();
        parent.cancel();
        let child = parent.child();
        assert!(child.is_cancelled());
    }
}