use std::cell::UnsafeCell;
use std::future::Future;
use std::marker::PhantomPinned;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use std::task::{Context, Poll, Waker};
/// Acquires the spin lock stored in `lock` (`false` = free, `true` = held).
///
/// A single compare-exchange is attempted inline. If it fails, whether from
/// real contention or a spurious weak-CAS failure, the out-of-line
/// `spin_lock_slow` retries until the lock is obtained.
#[inline]
fn spin_lock(lock: &AtomicBool) {
    let first_attempt =
        lock.compare_exchange_weak(false, true, Ordering::AcqRel, Ordering::Acquire);
    if first_attempt.is_err() {
        spin_lock_slow(lock);
    }
}
/// Contended path of `spin_lock`: retries the compare-exchange until it wins.
///
/// After each failed attempt the caller backs off. The first six failures
/// busy-wait for 1, 2, 4, ... 32 `spin_loop` hints respectively; every later
/// failure yields the thread to the scheduler instead.
#[cold]
#[inline(never)]
fn spin_lock_slow(lock: &AtomicBool) {
    // Number of exponential busy-wait rounds before switching to yielding.
    const BUSY_WAIT_ROUNDS: u32 = 6;

    let mut round: u32 = 0;
    while lock
        .compare_exchange_weak(false, true, Ordering::AcqRel, Ordering::Acquire)
        .is_err()
    {
        if round >= BUSY_WAIT_ROUNDS {
            std::thread::yield_now();
            continue;
        }
        (0..(1u32 << round)).for_each(|_| std::hint::spin_loop());
        round += 1;
    }
}
/// Releases the spin lock by storing `false` with `Release` ordering.
#[inline]
fn spin_unlock(lock: &AtomicBool) {
    AtomicBool::store(lock, false, Ordering::Release);
}
/// Scope guard for the spin lock: locks on construction, unlocks on drop.
struct SpinGuard<'a> {
    /// The lock flag this guard currently holds.
    lock: &'a AtomicBool,
}

impl<'a> SpinGuard<'a> {
    /// Spins until `lock` is acquired, then returns a guard owning it.
    #[inline]
    fn new(lock: &'a AtomicBool) -> Self {
        spin_lock(lock);
        SpinGuard { lock }
    }
}

impl Drop for SpinGuard<'_> {
    /// Releases the lock, including when the scope unwinds from a panic.
    #[inline]
    fn drop(&mut self) {
        let held = self.lock;
        spin_unlock(held);
    }
}
/// Shared state behind a `CancellationToken` and every `Cancelled` future
/// created from it.
struct Inner {
// Set to `true` by `cancel` and never reset in this file.
cancelled: AtomicBool,
// Spin-lock flag (`false` = free) guarding `head` and the `next`/`prev`/`waker`
// cells of every `WaiterNode` linked into the list.
list_lock: AtomicBool,
// Head of the intrusive doubly linked list of registered waiters; null when
// empty. Accessed only with `list_lock` held (or via `&mut self` in `Drop`).
head: UnsafeCell<*mut WaiterNode>,
// Head of the singly linked list of heap-allocated `ChildNode`s; pushed with
// compare-exchange and detached wholesale with `swap`.
child_head: AtomicPtr<ChildNode>,
// Test-only switch: when set, `cancel` yields the thread after detaching each
// waiter to widen the race window exercised by the regression test.
#[cfg(test)]
race_yield: AtomicBool,
}
/// Intrusive list node embedded in a `Cancelled` future.
///
/// `next`, `prev` and `waker` are plain cells; the code that links nodes into
/// a token's waiter list touches them only while holding that token's list
/// lock. `in_list` is atomic so it can be inspected without the lock.
struct WaiterNode {
    /// Successor in the waiter list, or null.
    next: UnsafeCell<*mut WaiterNode>,
    /// Predecessor in the waiter list, or null when this node is the head.
    prev: UnsafeCell<*mut WaiterNode>,
    /// Waker to invoke on cancellation, if one has been registered.
    waker: UnsafeCell<Option<Waker>>,
    /// `true` while the node is linked into a waiter list.
    in_list: AtomicBool,
}

impl WaiterNode {
    /// Creates a detached node: null links, no waker, not in any list.
    const fn new() -> Self {
        let unlinked = std::ptr::null_mut::<WaiterNode>();
        WaiterNode {
            in_list: AtomicBool::new(false),
            waker: UnsafeCell::new(None),
            prev: UnsafeCell::new(unlinked),
            next: UnsafeCell::new(unlinked),
        }
    }
}

// SAFETY: the `UnsafeCell` fields are only accessed under the owning token's
// list lock (or by the node's exclusive owner while it is unlinked), and
// `in_list` is atomic, so sharing or sending a node across threads is sound.
unsafe impl Send for WaiterNode {}
unsafe impl Sync for WaiterNode {}
/// Heap-allocated entry in a parent's singly linked list of child tokens
/// (`Inner::child_head`).
///
/// NOTE(review): in this file nodes are freed only by `Inner::cancel`, by the
/// already-cancelled path of `Inner::add_child`, and by `Drop for Inner`.
/// Dropping a child token does not unlink its node, so a long-lived parent
/// that is never cancelled keeps every child state it ever created alive.
struct ChildNode {
// Strong reference to the child's shared state, cancelled when the parent is.
inner: Arc<Inner>,
// Next entry in the parent's child list, or null at the tail.
next: *mut ChildNode,
}
// SAFETY: a `ChildNode` is exclusively owned by whoever holds its raw pointer:
// the pushing thread until the compare-exchange succeeds, afterwards the thread
// that detached the list. Its `next` pointer is never shared mutably.
unsafe impl Send for ChildNode {}
// SAFETY: `head` is only read or written with `list_lock` held (or through
// `&mut self` in `Drop`), and all remaining fields are atomics.
unsafe impl Send for Inner {}
unsafe impl Sync for Inner {}
impl Inner {
/// Allocates a fresh, un-cancelled state with no waiters and no children.
fn new() -> Arc<Self> {
Arc::new(Self {
cancelled: AtomicBool::new(false),
list_lock: AtomicBool::new(false),
head: UnsafeCell::new(std::ptr::null_mut()),
child_head: AtomicPtr::new(std::ptr::null_mut()),
#[cfg(test)]
race_yield: AtomicBool::new(false),
})
}
/// Returns whether `cancel` has been called on this state.
fn is_cancelled(&self) -> bool {
self.cancelled.load(Ordering::Acquire)
}
/// Marks the state cancelled, wakes every registered waiter, then cancels
/// every registered child (recursively). Calling it again repeats the drain
/// on whatever has been registered since.
fn cancel(&self) {
// Publish the flag before draining so a waiter that registers after the
// drain observes it once it has taken the list lock.
self.cancelled.store(true, Ordering::Release);
let mut wakers: Vec<Waker> = Vec::new();
{
let _guard = SpinGuard::new(&self.list_lock);
// Detach the whole list; the nodes are then walked privately.
let mut cur = unsafe { *self.head.get() };
unsafe { *self.head.get() = std::ptr::null_mut() };
while !cur.is_null() {
// Read everything needed from the node BEFORE clearing `in_list`:
// `Cancelled::drop` returns without taking the lock once it sees
// `in_list == false`, so the node may be freed right after the store.
let next = unsafe { *(*cur).next.get() };
let waker = unsafe { (*(*cur).waker.get()).take() };
unsafe { (*cur).in_list.store(false, Ordering::Release) };
// Test hook: widen the window between the store above and the next
// iteration so the regression test can race a concurrent drop.
#[cfg(test)]
if self.race_yield.load(Ordering::Relaxed) {
std::thread::yield_now();
}
cur = next;
if let Some(w) = waker {
wakers.push(w);
}
}
}
// Wake only after the lock is released so waker code never runs under it.
for w in wakers {
w.wake();
}
// Take ownership of the entire child list, then cancel and free each node.
let mut child = self.child_head.swap(std::ptr::null_mut(), Ordering::AcqRel);
while !child.is_null() {
let node = unsafe { Box::from_raw(child) };
child = node.next;
node.inner.cancel();
}
}
/// Registers `child` so that cancelling `self` also cancels it. If `self` is
/// already cancelled, `child` is cancelled immediately instead.
fn add_child(&self, child: &Arc<Inner>) {
let node = Box::into_raw(Box::new(ChildNode {
inner: child.clone(),
next: std::ptr::null_mut(),
}));
loop {
if self.is_cancelled() {
// Not (or no longer) worth publishing: reclaim the node and cancel
// the child directly.
let node = unsafe { Box::from_raw(node) };
node.inner.cancel();
return;
}
// Lock-free push onto the front of the child list.
let head = self.child_head.load(Ordering::Acquire);
unsafe { (*node).next = head };
if self
.child_head
.compare_exchange_weak(head, node, Ordering::AcqRel, Ordering::Relaxed)
.is_ok()
{
// A concurrent `cancel` may have detached the list just before the
// push landed; re-running `cancel` drains the newly pushed node.
if self.is_cancelled() {
self.cancel();
}
return;
}
}
}
}
impl Drop for Inner {
fn drop(&mut self) {
#[cfg(debug_assertions)]
{
let head = unsafe { *self.head.get() };
debug_assert!(
head.is_null(),
"Inner::Drop with waiter list non-empty — Cancelled futures \
must outlive their Inner via Arc<Inner>; if you see this, \
the list-discipline invariant has been violated"
);
}
let mut child = *self.child_head.get_mut();
while !child.is_null() {
let node = unsafe { Box::from_raw(child) };
child = node.next;
}
}
}
/// Handle to a shared cancellation flag.
///
/// Clones refer to the same state; tokens made with [`CancellationToken::child`]
/// have their own state that is cancelled together with the parent.
#[derive(Clone)]
pub struct CancellationToken {
    inner: Arc<Inner>,
}

impl CancellationToken {
    /// Creates a new, un-cancelled token with no parent.
    pub fn new() -> Self {
        let inner = Inner::new();
        Self { inner }
    }

    /// Creates a token that is cancelled whenever `self` is cancelled.
    ///
    /// Cancelling the child leaves `self` untouched. If `self` is already
    /// cancelled, the returned token is cancelled as well.
    pub fn child(&self) -> Self {
        let inner = Inner::new();
        self.inner.add_child(&inner);
        Self { inner }
    }

    /// Cancels this token, waking its waiters and cancelling its children.
    pub fn cancel(&self) {
        self.inner.cancel()
    }

    /// Returns `true` once this token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.inner.is_cancelled()
    }

    /// Wraps the token in a guard that cancels it when the guard is dropped.
    pub fn drop_guard(self) -> DropGuard {
        let token = Some(self);
        DropGuard { token }
    }

    /// Returns a future that completes once this token is cancelled.
    pub fn cancelled(&self) -> Cancelled {
        let inner = Arc::clone(&self.inner);
        Cancelled {
            inner,
            node: WaiterNode::new(),
            _pin: PhantomPinned,
        }
    }

    /// Test hook: makes `cancel` yield after detaching each waiter.
    #[cfg(test)]
    pub(crate) fn enable_race_yield(&self) {
        let hook = &self.inner.race_yield;
        hook.store(true, Ordering::Relaxed);
    }
}

impl Default for CancellationToken {
    /// Equivalent to [`CancellationToken::new`].
    fn default() -> Self {
        CancellationToken::new()
    }
}

impl std::fmt::Debug for CancellationToken {
    /// Prints the token as `CancellationToken { cancelled: <bool> }`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let cancelled = self.is_cancelled();
        let mut out = f.debug_struct("CancellationToken");
        out.field("cancelled", &cancelled);
        out.finish()
    }
}
/// Future returned by `CancellationToken::cancelled`; resolves to `()` once
/// the token has been cancelled.
pub struct Cancelled {
// Shared token state; also keeps that state alive while this future exists.
inner: Arc<Inner>,
// Intrusive list node; `poll` links its address into `inner`'s waiter list.
node: WaiterNode,
// Makes the type `!Unpin`: the waiter list stores a raw pointer to `node`,
// so the future must not move once it has been polled.
_pin: PhantomPinned,
}
impl Future for Cancelled {
    type Output = ();

    /// Returns `Ready` once the token is cancelled; otherwise registers (or
    /// refreshes) this future's waker in the token's waiter list and returns
    /// `Pending`.
    ///
    /// No waker is ever dropped while the list spin lock is held. A waker's
    /// drop is arbitrary code and may drop another `Cancelled` for the same
    /// token, which would spin forever on the non-reentrant lock.
    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // Shared access is sufficient: every mutable part of the node sits
        // behind an `UnsafeCell` or an atomic, so no `&mut Self` is needed.
        let this: &Cancelled = &*self;
        if this.inner.is_cancelled() {
            return Poll::Ready(());
        }
        let node = &this.node;

        if !node.in_list.load(Ordering::Acquire) {
            // Not registered. Only this future ever sets `in_list`, so it stays
            // `false` until the store below. Clone outside the critical section.
            let waker = cx.waker().clone();
            let displaced = {
                let _guard = SpinGuard::new(&this.inner.list_lock);
                // SAFETY: the list lock is held, giving exclusive access to
                // `head` and to the cells of every linked node. `node` is
                // pinned, so its address stays valid while it is linked.
                let displaced = unsafe {
                    let displaced = (*node.waker.get()).replace(waker);
                    let head_slot = this.inner.head.get();
                    let old_head = *head_slot;
                    let node_ptr = std::ptr::from_ref(node).cast_mut();
                    *node.next.get() = old_head;
                    *node.prev.get() = std::ptr::null_mut();
                    if !old_head.is_null() {
                        *(*old_head).prev.get() = node_ptr;
                    }
                    *head_slot = node_ptr;
                    displaced
                };
                node.in_list.store(true, Ordering::Release);
                displaced
            };
            // Expected to be `None`; dropped here so it never runs under the lock.
            drop(displaced);
            // `cancel` may have drained the list just before the node was
            // linked; re-check so that cancellation is not missed.
            if this.inner.is_cancelled() {
                return Poll::Ready(());
            }
            return Poll::Pending;
        }

        // Already registered: refresh the stored waker if the task changed.
        let displaced = {
            let _guard = SpinGuard::new(&this.inner.list_lock);
            if !node.in_list.load(Ordering::Relaxed) {
                // `cancel` unlinked the node between the check above and the lock.
                return Poll::Ready(());
            }
            // SAFETY: the list lock is held and the node is linked, so this
            // thread has exclusive access to the waker cell.
            let needs_update = unsafe {
                (*node.waker.get())
                    .as_ref()
                    .is_none_or(|w| !w.will_wake(cx.waker()))
            };
            if needs_update {
                // SAFETY: as above.
                unsafe { (*node.waker.get()).replace(cx.waker().clone()) }
            } else {
                None
            }
        };
        // The replaced waker is dropped only after the lock has been released.
        drop(displaced);
        Poll::Pending
    }
}
impl Drop for Cancelled {
    /// Unlinks this future's node from the token's waiter list if it is still
    /// registered.
    ///
    /// The stored waker is taken out under the list lock but dropped only
    /// after the lock is released. A waker's drop is arbitrary code and may
    /// drop another `Cancelled` for the same token, which would spin forever
    /// on the non-reentrant lock.
    fn drop(&mut self) {
        // Fast path: never registered, or already detached by `cancel`, which
        // does not touch the node again after clearing `in_list`.
        if !self.node.in_list.load(Ordering::Acquire) {
            return;
        }
        let stale_waker = {
            let _guard = SpinGuard::new(&self.inner.list_lock);
            // Re-check under the lock: `cancel` may have drained the list
            // while this thread was waiting for it.
            if self.node.in_list.load(Ordering::Relaxed) {
                // SAFETY: the list lock is held and the node is linked, so
                // `prev`/`next` point at live linked nodes (or are null) and
                // this thread has exclusive access to all list cells.
                let waker = unsafe {
                    let prev = *self.node.prev.get();
                    let next = *self.node.next.get();
                    if prev.is_null() {
                        *self.inner.head.get() = next;
                    } else {
                        *(*prev).next.get() = next;
                    }
                    if !next.is_null() {
                        *(*next).prev.get() = prev;
                    }
                    *self.node.next.get() = std::ptr::null_mut();
                    *self.node.prev.get() = std::ptr::null_mut();
                    (*self.node.waker.get()).take()
                };
                self.node.in_list.store(false, Ordering::Release);
                waker
            } else {
                None
            }
        };
        // Lock released above; running the waker's destructor is now safe.
        drop(stale_waker);
    }
}
/// Guard that cancels its token when dropped, unless it has been disarmed.
pub struct DropGuard {
    /// The guarded token; `None` once `disarm` has taken it.
    token: Option<CancellationToken>,
}

impl DropGuard {
    /// Consumes the guard without cancelling and hands back the token.
    ///
    /// # Panics
    ///
    /// Panics if the token has already been taken. In the code visible here
    /// that cannot happen, because `disarm` consumes the guard.
    pub fn disarm(mut self) -> CancellationToken {
        let token = self.token.take();
        token.expect("DropGuard already disarmed")
    }
}

impl Drop for DropGuard {
    /// Cancels the token if the guard is still armed.
    fn drop(&mut self) {
        if let Some(token) = self.token.as_ref() {
            token.cancel();
        }
    }
}
// Unit tests for the token, its child hierarchy, the `Cancelled` future and
// `DropGuard`.
#[cfg(test)]
mod tests {
use super::*;
use std::task::{RawWaker, RawWakerVTable};
// Builds a waker whose clone/wake/drop callbacks do nothing, for polling
// futures by hand.
fn noop_waker() -> Waker {
fn noop(_: *const ()) {}
fn noop_clone(p: *const ()) -> RawWaker {
RawWaker::new(p, &VTABLE)
}
const VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
}
// Polls `f` exactly once with a no-op waker.
fn poll_once<F: Future>(f: Pin<&mut F>) -> Poll<F::Output> {
let waker = noop_waker();
let mut cx = Context::from_waker(&waker);
f.poll(&mut cx)
}
// A fresh token starts out un-cancelled.
#[test]
fn not_cancelled_by_default() {
let token = CancellationToken::new();
assert!(!token.is_cancelled());
}
#[test]
fn cancel_sets_flag() {
let token = CancellationToken::new();
token.cancel();
assert!(token.is_cancelled());
}
// A second `cancel` must be harmless.
#[test]
fn cancel_is_idempotent() {
let token = CancellationToken::new();
token.cancel();
token.cancel();
assert!(token.is_cancelled());
}
// Clones observe cancellation of the original.
#[test]
fn clone_shares_state() {
let token = CancellationToken::new();
let clone = token.clone();
token.cancel();
assert!(clone.is_cancelled());
}
#[test]
fn child_sees_parent_cancel() {
let parent = CancellationToken::new();
let child = parent.child();
assert!(!child.is_cancelled());
parent.cancel();
assert!(child.is_cancelled());
}
// Cancellation propagates through more than one level.
#[test]
fn grandchild_sees_ancestor_cancel() {
let root = CancellationToken::new();
let child = root.child();
let grandchild = child.child();
assert!(!grandchild.is_cancelled());
root.cancel();
assert!(grandchild.is_cancelled());
}
// Propagation is one-way: parent to child only.
#[test]
fn child_cancel_does_not_affect_parent() {
let parent = CancellationToken::new();
let child = parent.child();
child.cancel();
assert!(child.is_cancelled());
assert!(!parent.is_cancelled());
}
#[test]
fn cancelled_future_ready_when_cancelled() {
let token = CancellationToken::new();
token.cancel();
let mut fut = std::pin::pin!(token.cancelled());
assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(())));
}
// First poll registers and is pending; after `cancel` the next poll is ready.
#[test]
fn cancelled_future_pending_then_ready() {
let token = CancellationToken::new();
let mut fut = std::pin::pin!(token.cancelled());
assert!(matches!(poll_once(fut.as_mut()), Poll::Pending));
token.cancel();
assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(())));
}
// A future on the child completes when the parent is cancelled.
#[test]
fn child_cancelled_future_from_parent() {
let parent = CancellationToken::new();
let child = parent.child();
let mut fut = std::pin::pin!(child.cancelled());
assert!(matches!(poll_once(fut.as_mut()), Poll::Pending));
parent.cancel();
assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(())));
}
// Two registered waiters on the same token both complete.
#[test]
fn multiple_waiters() {
let token = CancellationToken::new();
let mut fut1 = std::pin::pin!(token.cancelled());
let mut fut2 = std::pin::pin!(token.cancelled());
assert!(matches!(poll_once(fut1.as_mut()), Poll::Pending));
assert!(matches!(poll_once(fut2.as_mut()), Poll::Pending));
token.cancel();
assert!(matches!(poll_once(fut1.as_mut()), Poll::Ready(())));
assert!(matches!(poll_once(fut2.as_mut()), Poll::Ready(())));
}
// Cancellation performed on another thread becomes visible to this one.
#[test]
fn cross_thread_cancel() {
let token = CancellationToken::new();
let clone = token.clone();
let handle = std::thread::spawn(move || {
std::thread::sleep(std::time::Duration::from_millis(10));
clone.cancel();
});
while !token.is_cancelled() {
std::hint::spin_loop();
}
handle.join().unwrap();
}
#[test]
fn drop_guard_cancels_on_drop() {
let token = CancellationToken::new();
let clone = token.clone();
{
let _guard = token.drop_guard();
assert!(!clone.is_cancelled());
}
assert!(clone.is_cancelled());
}
// A disarmed guard hands the token back and never cancels it.
#[test]
fn drop_guard_disarm() {
let token = CancellationToken::new();
let clone = token.clone();
let guard = token.drop_guard();
let recovered = guard.disarm();
drop(recovered);
assert!(!clone.is_cancelled());
}
// The guard also fires when its scope is left by unwinding.
#[test]
fn drop_guard_on_panic() {
let token = CancellationToken::new();
let clone = token.clone();
let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
let _guard = token.drop_guard();
panic!("simulated panic");
}));
assert!(result.is_err());
assert!(clone.is_cancelled());
}
// Compile-time check that the public types can cross threads.
#[test]
fn send_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<CancellationToken>();
assert_send_sync::<Cancelled>();
}
// Dropping a registered future and a child without ever cancelling must not
// trip the waiter-list debug assertion in `Drop for Inner`.
#[test]
fn drop_without_cancel_cleans_up() {
let token = CancellationToken::new();
let _child = token.child();
let mut fut = std::pin::pin!(token.cancelled());
let _ = poll_once(fut.as_mut()); }
#[test]
fn many_children() {
let parent = CancellationToken::new();
let children: Vec<_> = (0..100).map(|_| parent.child()).collect();
parent.cancel();
for child in &children {
assert!(child.is_cancelled());
}
}
// A child made from an already-cancelled parent starts out cancelled.
#[test]
fn child_created_after_parent_cancelled() {
let parent = CancellationToken::new();
parent.cancel();
let child = parent.child();
assert!(child.is_cancelled());
}
// NOTE(review): `poll` checks `is_cancelled()` before it looks at `in_list`,
// so the second poll below returns from that early check. The
// `in_list == false` branch named in the test title is not reached by this
// sequence; as written the test duplicates `cancelled_future_pending_then_ready`.
#[test]
fn poll_after_cancel_drained_uses_in_list_false_path() {
let token = CancellationToken::new();
let mut fut = std::pin::pin!(token.cancelled());
assert!(matches!(poll_once(fut.as_mut()), Poll::Pending));
token.cancel();
assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(())));
}
// Regression test: races `cancel` draining the waiter list against another
// thread dropping (and freeing) a registered, heap-allocated future. The
// race-yield hook makes `cancel` yield after detaching each waiter to widen
// the window.
#[test]
fn cancel_drain_race_regression() {
use std::sync::Arc;
use std::sync::atomic::AtomicBool;
// Fewer trials under Miri.
#[cfg(miri)]
const TRIALS: usize = 20;
#[cfg(not(miri))]
const TRIALS: usize = 200;
for _ in 0..TRIALS {
let token = CancellationToken::new();
token.enable_race_yield();
let registered = Arc::new(AtomicBool::new(false));
// Registers a waiter, waits for the cancelled flag, then frees the future.
let drop_thread = {
let token = token.clone();
let registered = registered.clone();
std::thread::spawn(move || {
let mut fut = Box::pin(token.cancelled());
assert!(matches!(poll_once(fut.as_mut()), Poll::Pending));
registered.store(true, Ordering::Release);
while !token.is_cancelled() {
std::hint::spin_loop();
}
drop(fut);
})
};
// Cancels only once the waiter is known to be registered.
let cancel_thread = {
let token = token.clone();
let registered = registered.clone();
std::thread::spawn(move || {
while !registered.load(Ordering::Acquire) {
std::hint::spin_loop();
}
token.cancel();
})
};
drop_thread.join().unwrap();
cancel_thread.join().unwrap();
}
}
}