Skip to main content

nexus_async_rt/
cancel.rs

//! Cancellation tokens for cooperative task shutdown.
//!
//! Adapted from tokio-util's `CancellationToken` design, built for
//! the nexus-async-rt runtime. `Clone + Send + Sync`. Hierarchical —
//! cancelling a parent cancels all children.
//!
//! Lock-free: `is_cancelled()` is a single atomic load. Registration
//! and cancellation use atomic Treiber stacks (CAS on head). No mutex.
//!
//! Any holder can cancel or await cancellation — there are no separate
//! sender/receiver roles. This allows any task in a group to trigger shutdown.
//!
//! ```ignore
//! use nexus_async_rt::CancellationToken;
//!
//! let token = CancellationToken::new();
//!
//! // Any clone can cancel or await:
//! let t = token.clone();
//! spawn_boxed(async move {
//!     match do_work().await {
//!         Ok(()) => t.cancelled().await,  // wait
//!         Err(_) => t.cancel(),           // or trigger
//!     }
//! });
//!
//! // Hierarchical:
//! let child = token.child();  // cancelled when parent is
//!
//! // Drop guard — cancels on scope exit:
//! let _guard = token.drop_guard();
//! ```

34use std::future::Future;
35use std::pin::Pin;
36use std::sync::Arc;
37use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
38use std::task::{Context, Poll, Waker};
39
// =============================================================================
// Inner state — lock-free via atomic Treiber stacks
// =============================================================================

/// Shared state behind every clone of one [`CancellationToken`].
///
/// Ownership between levels of the hierarchy is deliberately asymmetric:
///
/// * a parent refers to each child only through a `Weak` (in `ChildNode`).
///   Once every handle to a child is gone, the child's state is freed
///   immediately: its pending waker clones and its own child list. It is not
///   pinned by the parent until the parent is cancelled or dropped.
/// * a child holds a strong `Arc` to its parent (`parent`). An intermediate
///   token therefore stays alive while any descendant is alive, so
///   `root.child().child()` still propagates cancellation after the middle
///   handle has been dropped.
///
/// Strong edges only point upward, so the graph has no reference cycles.
/// Each dropped child still leaves behind its `ChildNode` plus the `Arc`
/// allocation that the `Weak` keeps reserved. Those are reclaimed when the
/// parent is cancelled or dropped.
struct Inner {
    /// Set by `cancel()`; never cleared.
    cancelled: AtomicBool,
    /// Head of the waiter Treiber stack. Each node is a heap-allocated
    /// `WaiterNode`. Push via CAS, drain-all via swap-to-null on cancel.
    waiter_head: AtomicPtr<WaiterNode>,
    /// Head of the child Treiber stack. Each node is a heap-allocated
    /// `ChildNode`. Same push/drain pattern.
    child_head: AtomicPtr<ChildNode>,
    /// Keeps the parent alive for as long as this child is alive. Set at
    /// most once, by the parent's `add_child`; left unset for root tokens.
    parent: std::sync::OnceLock<Arc<Inner>>,
}

/// One registered waker on the waiter stack.
struct WaiterNode {
    waker: Waker,
    next: *mut WaiterNode,
}

/// One registered child on the child stack.
struct ChildNode {
    /// Weak, so that dropping every handle to the child releases its state.
    inner: std::sync::Weak<Inner>,
    next: *mut ChildNode,
}

// SAFETY: nodes are only reached through the atomic stack heads: they are
// pushed from any thread and drained by whichever thread swaps the head to
// null, which then owns the detached list exclusively. Their payloads
// (`Waker`, `Weak<Inner>`) are themselves Send + Sync; only the raw `next`
// pointer suppresses the auto impl.
unsafe impl Send for WaiterNode {}
unsafe impl Send for ChildNode {}

impl Inner {
    fn new() -> Arc<Self> {
        Arc::new(Self {
            cancelled: AtomicBool::new(false),
            waiter_head: AtomicPtr::new(std::ptr::null_mut()),
            child_head: AtomicPtr::new(std::ptr::null_mut()),
            parent: std::sync::OnceLock::new(),
        })
    }

    /// O(1) — single atomic load.
    fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::Acquire)
    }

    /// Cancel: set flag, drain and wake all waiters, drain and cancel all
    /// live children.
    ///
    /// Idempotent — safe to call multiple times. Storing `true` again is
    /// harmless, and swapping an already-null list head to null drains
    /// nothing. This matters because register()/add_child() call cancel()
    /// again to catch nodes they pushed during a race with a concurrent cancel.
    fn cancel(&self) {
        // Set the flag. If it was already set, we still drain below to
        // catch nodes pushed between a prior cancel()'s drain and now.
        self.cancelled.store(true, Ordering::Release);

        // Drain waiters — swap head to null, walk the detached list.
        let mut waiter = self
            .waiter_head
            .swap(std::ptr::null_mut(), Ordering::AcqRel);
        while !waiter.is_null() {
            // SAFETY: every node came from Box::into_raw in register(), and the
            // swap above handed this thread exclusive ownership of the list.
            let node = unsafe { Box::from_raw(waiter) };
            waiter = node.next;
            node.waker.wake();
        }

        // Drain children — swap head to null, cancel each one still alive.
        let mut child = self.child_head.swap(std::ptr::null_mut(), Ordering::AcqRel);
        while !child.is_null() {
            // SAFETY: every node came from Box::into_raw in add_child(), and the
            // swap above handed this thread exclusive ownership of the list.
            let node = unsafe { Box::from_raw(child) };
            child = node.next;
            // A failed upgrade means every handle to that child (tokens,
            // Cancelled futures, descendants) is gone, so nothing can
            // observe its cancellation — skip it.
            if let Some(child_inner) = node.inner.upgrade() {
                child_inner.cancel();
            }
        }
    }

    /// Register a child. If already cancelled, cancels the child immediately.
    fn add_child(self: &Arc<Self>, child: &Arc<Inner>) {
        if self.is_cancelled() {
            child.cancel();
            return;
        }

        // The child keeps its parent alive (see `Inner` docs). `add_child` is
        // only called once per freshly created child, so `set` cannot fail in
        // practice; a second call would simply keep the first parent.
        let _ = child.parent.set(Arc::clone(self));

        let node = Box::into_raw(Box::new(ChildNode {
            inner: Arc::downgrade(child),
            next: std::ptr::null_mut(),
        }));

        // CAS push onto the child stack.
        loop {
            // Check cancelled before pushing — avoid leaking the node.
            if self.is_cancelled() {
                // SAFETY: the node was never published; we still own it.
                drop(unsafe { Box::from_raw(node) });
                child.cancel();
                return;
            }

            let head = self.child_head.load(Ordering::Acquire);
            // SAFETY: the node is not yet published, so this thread has
            // exclusive access to it.
            unsafe { (*node).next = head };
            if self
                .child_head
                .compare_exchange_weak(head, node, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                // Successfully pushed. If a cancel() ran between our check and
                // the CAS, its drain may have missed our node. The CAS (an
                // acquire RMW ordered after that cancel's swap) guarantees we
                // now observe the flag, so re-cancel to drain it (idempotent).
                if self.is_cancelled() {
                    self.cancel();
                }
                return;
            }
        }
    }

    /// Register a waker. Returns true if already cancelled.
    fn register(&self, waker: &Waker) -> bool {
        if self.is_cancelled() {
            return true;
        }

        let node = Box::into_raw(Box::new(WaiterNode {
            waker: waker.clone(),
            next: std::ptr::null_mut(),
        }));

        // CAS push onto the waiter stack.
        loop {
            if self.is_cancelled() {
                // SAFETY: the node was never published; we still own it.
                unsafe { drop(Box::from_raw(node)) };
                return true;
            }

            let head = self.waiter_head.load(Ordering::Acquire);
            // SAFETY: the node is not yet published, so this thread has
            // exclusive access to it.
            unsafe { (*node).next = head };
            if self
                .waiter_head
                .compare_exchange_weak(head, node, Ordering::AcqRel, Ordering::Relaxed)
                .is_ok()
            {
                // Check for race: cancelled between our check and the CAS.
                if self.is_cancelled() {
                    self.cancel(); // idempotent — drains (and wakes) our node
                    return true;
                }
                return false;
            }
        }
    }
}

impl Drop for Inner {
    fn drop(&mut self) {
        // Free nodes that no cancel() ever drained (the token was dropped
        // without being cancelled). Waiter nodes drop their waker clones
        // without waking; child nodes only hold a `Weak`, so freeing them
        // releases nothing else.
        let mut waiter = *self.waiter_head.get_mut();
        while !waiter.is_null() {
            // SAFETY: `&mut self` rules out concurrent access, and every node
            // came from Box::into_raw in register().
            let node = unsafe { Box::from_raw(waiter) };
            waiter = node.next;
        }
        let mut child = *self.child_head.get_mut();
        while !child.is_null() {
            // SAFETY: `&mut self` rules out concurrent access, and every node
            // came from Box::into_raw in add_child().
            let node = unsafe { Box::from_raw(child) };
            child = node.next;
        }
    }
}
205
206// =============================================================================
207// CancellationToken
208// =============================================================================
209
210/// A token for cooperative cancellation.
211///
212/// `Clone + Send + Sync`. Cloning shares the same cancellation state.
213/// Use [`child()`](CancellationToken::child) for hierarchical cancellation.
214///
215/// # Example
216///
217/// ```ignore
218/// let token = CancellationToken::new();
219///
220/// spawn_boxed(async move {
221///     token.cancelled().await;
222///     println!("shutting down");
223/// });
224///
225/// token.cancel();
226/// ```
227#[derive(Clone)]
228pub struct CancellationToken {
229    inner: Arc<Inner>,
230}
231
232impl CancellationToken {
233    /// Create a new cancellation token.
234    pub fn new() -> Self {
235        Self {
236            inner: Inner::new(),
237        }
238    }
239
240    /// Create a child token. Cancelling this token (or any ancestor)
241    /// also cancels the child and wakes its waiters. Cancelling the
242    /// child does NOT cancel the parent.
243    pub fn child(&self) -> Self {
244        let child = Self {
245            inner: Inner::new(),
246        };
247        self.inner.add_child(&child.inner);
248        child
249    }
250
251    /// Cancel this token. All futures awaiting [`cancelled()`](Self::cancelled)
252    /// will resolve. Child tokens are also cancelled.
253    pub fn cancel(&self) {
254        self.inner.cancel();
255    }
256
257    /// Whether this token has been cancelled.
258    /// O(1) — single atomic load. Parent cancellation propagates
259    /// eagerly (sets the child's flag), so no chain traversal needed.
260    pub fn is_cancelled(&self) -> bool {
261        self.inner.is_cancelled()
262    }
263
264    /// Returns a guard that cancels this token when dropped.
265    ///
266    /// Useful for ensuring cancellation on scope exit or panic.
267    pub fn drop_guard(self) -> DropGuard {
268        DropGuard { token: Some(self) }
269    }
270
271    /// Returns a future that resolves when this token is cancelled.
272    pub fn cancelled(&self) -> Cancelled {
273        Cancelled {
274            inner: self.inner.clone(),
275            last_waker: None,
276        }
277    }
278}
279
280impl Default for CancellationToken {
281    fn default() -> Self {
282        Self::new()
283    }
284}
285
286impl std::fmt::Debug for CancellationToken {
287    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
288        f.debug_struct("CancellationToken")
289            .field("cancelled", &self.is_cancelled())
290            .finish()
291    }
292}
293
294// =============================================================================
295// Cancelled future
296// =============================================================================
297
298/// Future that resolves when a [`CancellationToken`] is cancelled.
299///
300/// Created by [`CancellationToken::cancelled()`].
301///
302/// Tracks the last registered waker and re-registers if the waker changes
303/// (e.g., the future is moved between tasks via `select!` or `Timeout`).
304/// Each registration allocates a `WaiterNode` on the heap. Prior nodes
305/// are cleaned up when `cancel()` drains the Treiber stack.
306pub struct Cancelled {
307    inner: Arc<Inner>,
308    last_waker: Option<Waker>,
309}
310
311impl Future for Cancelled {
312    type Output = ();
313
314    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
315        if self.inner.is_cancelled() {
316            return Poll::Ready(());
317        }
318
319        // Register (or re-register) if this is the first poll or
320        // the waker has changed since last registration.
321        let needs_register = self
322            .last_waker
323            .as_ref()
324            .is_none_or(|prev| !prev.will_wake(cx.waker()));
325
326        if needs_register {
327            if self.inner.register(cx.waker()) {
328                return Poll::Ready(());
329            }
330            self.last_waker = Some(cx.waker().clone());
331        }
332
333        Poll::Pending
334    }
335}
336
337// =============================================================================
338// DropGuard
339// =============================================================================
340
341/// A guard that cancels a [`CancellationToken`] when dropped.
342///
343/// Created by [`CancellationToken::drop_guard()`]. Call
344/// [`disarm()`](DropGuard::disarm) to prevent cancellation on drop.
345pub struct DropGuard {
346    token: Option<CancellationToken>,
347}
348
349impl DropGuard {
350    /// Disarm the guard — the token will NOT be cancelled on drop.
351    /// Returns the token.
352    pub fn disarm(mut self) -> CancellationToken {
353        self.token.take().expect("DropGuard already disarmed")
354    }
355}
356
357impl Drop for DropGuard {
358    fn drop(&mut self) {
359        if let Some(ref token) = self.token {
360            token.cancel();
361        }
362    }
363}
364
// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use std::task::{RawWaker, RawWakerVTable};

    fn noop_waker() -> Waker {
        fn noop(_: *const ()) {}
        fn noop_clone(p: *const ()) -> RawWaker {
            RawWaker::new(p, &VTABLE)
        }
        const VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
    }

    /// Waker that records whether it has been woken, so tests can check that
    /// `cancel()` really invokes `wake()`. The no-op waker cannot show this.
    struct FlagWaker(AtomicBool);

    impl std::task::Wake for FlagWaker {
        fn wake(self: Arc<Self>) {
            self.0.store(true, Ordering::SeqCst);
        }
    }

    fn flag_waker() -> (Arc<FlagWaker>, Waker) {
        let flag = Arc::new(FlagWaker(AtomicBool::new(false)));
        let waker = Waker::from(Arc::clone(&flag));
        (flag, waker)
    }

    fn poll_once<F: Future>(f: Pin<&mut F>) -> Poll<F::Output> {
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);
        f.poll(&mut cx)
    }

    #[test]
    fn not_cancelled_by_default() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());
    }

    #[test]
    fn cancel_sets_flag() {
        let token = CancellationToken::new();
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn cancel_is_idempotent() {
        let token = CancellationToken::new();
        token.cancel();
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn clone_shares_state() {
        let token = CancellationToken::new();
        let clone = token.clone();
        token.cancel();
        assert!(clone.is_cancelled());
    }

    #[test]
    fn child_sees_parent_cancel() {
        let parent = CancellationToken::new();
        let child = parent.child();
        assert!(!child.is_cancelled());
        parent.cancel();
        assert!(child.is_cancelled());
    }

    #[test]
    fn grandchild_sees_ancestor_cancel() {
        let root = CancellationToken::new();
        let child = root.child();
        let grandchild = child.child();
        assert!(!grandchild.is_cancelled());
        root.cancel();
        assert!(grandchild.is_cancelled());
    }

    #[test]
    fn grandchild_of_temporary_intermediate() {
        // The intermediate child handle is dropped right away; cancellation
        // must still reach the grandchild.
        let root = CancellationToken::new();
        let grandchild = root.child().child();
        assert!(!grandchild.is_cancelled());
        root.cancel();
        assert!(grandchild.is_cancelled());
    }

    #[test]
    fn child_cancel_does_not_affect_parent() {
        let parent = CancellationToken::new();
        let child = parent.child();
        child.cancel();
        assert!(child.is_cancelled());
        assert!(!parent.is_cancelled());
    }

    #[test]
    fn cancelled_future_ready_when_cancelled() {
        let token = CancellationToken::new();
        token.cancel();

        let mut fut = std::pin::pin!(token.cancelled());
        assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(())));
    }

    #[test]
    fn cancelled_future_pending_then_ready() {
        let token = CancellationToken::new();

        let mut fut = std::pin::pin!(token.cancelled());
        assert!(matches!(poll_once(fut.as_mut()), Poll::Pending));

        token.cancel();
        // Re-poll — now ready.
        assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(())));
    }

    #[test]
    fn cancel_wakes_registered_waker() {
        let token = CancellationToken::new();
        let (flag, waker) = flag_waker();
        let mut cx = Context::from_waker(&waker);

        let mut fut = std::pin::pin!(token.cancelled());
        assert!(fut.as_mut().poll(&mut cx).is_pending());
        assert!(!flag.0.load(Ordering::SeqCst));

        token.cancel();
        assert!(flag.0.load(Ordering::SeqCst));
        assert!(fut.as_mut().poll(&mut cx).is_ready());
    }

    #[test]
    fn parent_cancel_wakes_child_waiter() {
        let parent = CancellationToken::new();
        let child = parent.child();
        let (flag, waker) = flag_waker();
        let mut cx = Context::from_waker(&waker);

        let mut fut = std::pin::pin!(child.cancelled());
        assert!(fut.as_mut().poll(&mut cx).is_pending());

        parent.cancel();
        assert!(flag.0.load(Ordering::SeqCst));
        assert!(fut.as_mut().poll(&mut cx).is_ready());
    }

    #[test]
    fn reregisters_when_waker_changes() {
        let token = CancellationToken::new();
        let (first, first_waker) = flag_waker();
        let (second, second_waker) = flag_waker();
        let mut fut = std::pin::pin!(token.cancelled());

        let mut cx = Context::from_waker(&first_waker);
        assert!(fut.as_mut().poll(&mut cx).is_pending());
        let mut cx = Context::from_waker(&second_waker);
        assert!(fut.as_mut().poll(&mut cx).is_pending());

        token.cancel();
        // The latest waker must be woken; the stale first node is still on
        // the stack, so it is drained and woken too.
        assert!(second.0.load(Ordering::SeqCst));
        assert!(first.0.load(Ordering::SeqCst));
    }

    #[test]
    fn child_cancelled_future_from_parent() {
        let parent = CancellationToken::new();
        let child = parent.child();

        let mut fut = std::pin::pin!(child.cancelled());
        assert!(matches!(poll_once(fut.as_mut()), Poll::Pending));

        parent.cancel();
        assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(())));
    }

    #[test]
    fn multiple_waiters() {
        let token = CancellationToken::new();

        let mut fut1 = std::pin::pin!(token.cancelled());
        let mut fut2 = std::pin::pin!(token.cancelled());

        assert!(matches!(poll_once(fut1.as_mut()), Poll::Pending));
        assert!(matches!(poll_once(fut2.as_mut()), Poll::Pending));

        token.cancel();

        assert!(matches!(poll_once(fut1.as_mut()), Poll::Ready(())));
        assert!(matches!(poll_once(fut2.as_mut()), Poll::Ready(())));
    }

    #[test]
    fn cross_thread_cancel() {
        let token = CancellationToken::new();
        let clone = token.clone();

        let handle = std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(10));
            clone.cancel();
        });

        while !token.is_cancelled() {
            std::hint::spin_loop();
        }

        handle.join().unwrap();
    }

    #[test]
    fn drop_guard_cancels_on_drop() {
        let token = CancellationToken::new();
        let clone = token.clone();
        {
            let _guard = token.drop_guard();
            assert!(!clone.is_cancelled());
        }
        assert!(clone.is_cancelled());
    }

    #[test]
    fn drop_guard_disarm() {
        let token = CancellationToken::new();
        let clone = token.clone();
        let guard = token.drop_guard();
        let recovered = guard.disarm();
        drop(recovered);
        assert!(!clone.is_cancelled());
    }

    #[test]
    fn drop_guard_on_panic() {
        let token = CancellationToken::new();
        let clone = token.clone();

        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard = token.drop_guard();
            panic!("simulated panic");
        }));

        assert!(result.is_err());
        assert!(clone.is_cancelled());
    }

    #[test]
    fn send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<CancellationToken>();
        assert_send_sync::<Cancelled>();
    }

    #[test]
    fn drop_without_cancel_cleans_up() {
        // Tokens dropped without cancellation — nodes should be freed.
        let token = CancellationToken::new();
        let _child = token.child();
        let mut fut = std::pin::pin!(token.cancelled());
        let _ = poll_once(fut.as_mut()); // register a waiter
        // Everything dropped — no leak (tested under miri if available).
    }

    #[test]
    fn many_children() {
        let parent = CancellationToken::new();
        let children: Vec<_> = (0..100).map(|_| parent.child()).collect();

        parent.cancel();
        for child in &children {
            assert!(child.is_cancelled());
        }
    }

    #[test]
    fn child_created_after_parent_cancelled() {
        let parent = CancellationToken::new();
        parent.cancel();
        let child = parent.child();
        assert!(child.is_cancelled());
    }
}