use std::collections::HashMap;
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicPtr, Ordering};
use std::sync::{Mutex, MutexGuard, PoisonError, Weak};
use std::task::{Context, Poll, Waker};
39
/// Shared state behind a [`CancellationToken`], all of its clones, and every
/// [`Cancelled`] future created from them.
///
/// The cancellation flag is atomic so the common "already cancelled?" check
/// never takes a lock. Waiters and children live in a mutex-guarded
/// [`Registry`] so that entries can be removed again: a dropped [`Cancelled`]
/// future deregisters its waker, and parents hold children only weakly. On a
/// long-lived token that is never cancelled, both collections therefore stay
/// bounded instead of growing with every future polled or child ever created.
struct Inner {
    /// Set exactly once, by [`Inner::cancel_local`]; never reset.
    cancelled: AtomicBool,
    /// Strong link to the parent's state (`None` for a root token).
    ///
    /// Parents reference children only through [`Weak`], so it is the
    /// descendants that keep every intermediate ancestor alive. Without this,
    /// dropping all handles to a middle token would silently detach its own
    /// children from the root's cancellation.
    parent: Option<Arc<Inner>>,
    /// Waiter and child bookkeeping; only accessed through [`Inner::lock`].
    registry: Mutex<Registry>,
}

/// Bookkeeping guarded by [`Inner::registry`].
#[derive(Default)]
struct Registry {
    /// Wakers of pending [`Cancelled`] futures, keyed by an id that each
    /// future remembers so it can replace or remove its own entry.
    waiters: HashMap<u64, Waker>,
    /// Id handed to the next newly registered waiter.
    next_waiter_id: u64,
    /// Child token states. Entries whose child has been dropped are pruned
    /// lazily by [`Inner::add_child`].
    children: Vec<Weak<Inner>>,
}

impl Inner {
    /// Creates fresh, uncancelled state, optionally linked to `parent`.
    fn new(parent: Option<Arc<Inner>>) -> Arc<Self> {
        Arc::new(Self {
            cancelled: AtomicBool::new(false),
            parent,
            registry: Mutex::new(Registry::default()),
        })
    }

    /// Returns `true` once this state has been marked cancelled.
    fn is_cancelled(&self) -> bool {
        self.cancelled.load(Ordering::Acquire)
    }

    /// Locks the registry, recovering from poisoning.
    ///
    /// Poisoning can be ignored: every critical section leaves the registry
    /// consistent. The only foreign code run under the lock is
    /// `Waker::clone`, and it is evaluated before the map is modified.
    fn lock(&self) -> MutexGuard<'_, Registry> {
        self.registry.lock().unwrap_or_else(PoisonError::into_inner)
    }

    /// Cancels this state and, transitively, every live descendant.
    ///
    /// Descendants are processed from an explicit work list rather than by
    /// recursion, so arbitrarily deep child chains cannot overflow the stack.
    fn cancel(&self) {
        let mut pending = Vec::new();
        self.cancel_local(&mut pending);
        while let Some(descendant) = pending.pop() {
            descendant.cancel_local(&mut pending);
        }
    }

    /// Cancels this state only: sets the flag, wakes every registered waiter
    /// and appends the still-alive children to `children_out`.
    ///
    /// Returns immediately if the flag was already set, because whichever
    /// call set it is responsible for draining the registry.
    fn cancel_local(&self, children_out: &mut Vec<Arc<Inner>>) {
        if self.cancelled.swap(true, Ordering::AcqRel) {
            return;
        }
        // The flag is set before the lock is taken; `register` and
        // `add_child` check it while holding the lock. Anything they insert is
        // therefore either drained here or sees the flag and backs off.
        let (waiters, children) = {
            let mut registry = self.lock();
            (
                std::mem::take(&mut registry.waiters),
                std::mem::take(&mut registry.children),
            )
        };
        // Wake outside the lock: `wake` runs arbitrary executor code, which
        // may poll the future again and re-enter `register` on this mutex.
        for waker in waiters.into_values() {
            waker.wake();
        }
        children_out.extend(children.iter().filter_map(Weak::upgrade));
    }

    /// Registers `child` to be cancelled together with this state, or cancels
    /// it right away if this state is already cancelled.
    fn add_child(&self, child: &Arc<Inner>) {
        {
            let mut registry = self.lock();
            if !self.is_cancelled() {
                let children = &mut registry.children;
                // Sweep out dropped children only when the vector is full, then
                // leave room for at least as many pushes as there are
                // survivors, so the O(n) sweep costs amortised O(1) per insert.
                if children.len() == children.capacity() {
                    children.retain(|weak| weak.strong_count() > 0);
                    children.reserve(children.len());
                }
                children.push(Arc::downgrade(child));
                return;
            }
        }
        child.cancel();
    }

    /// Records `waker` to be woken on cancellation.
    ///
    /// `waiter_id` identifies the caller's entry. When it is `None`, a fresh
    /// id is allocated and stored into it. When the existing entry would not
    /// wake the same task, its waker is replaced.
    ///
    /// Returns `true` (and records nothing) if the state is already
    /// cancelled, otherwise `false`.
    fn register(&self, waiter_id: &mut Option<u64>, waker: &Waker) -> bool {
        let mut registry = self.lock();
        if self.is_cancelled() {
            return true;
        }
        let slot = match *waiter_id {
            Some(id) => registry.waiters.get_mut(&id),
            None => None,
        };
        if let Some(current) = slot {
            if !current.will_wake(waker) {
                let stale = std::mem::replace(current, waker.clone());
                // Drop the old waker only after unlocking: its destructor is
                // foreign code.
                drop(registry);
                drop(stale);
            }
            return false;
        }
        let fresh = waker.clone();
        let id = registry.next_waiter_id;
        registry.next_waiter_id += 1;
        registry.waiters.insert(id, fresh);
        *waiter_id = Some(id);
        false
    }

    /// Removes waiter entry `id`, if it is still registered.
    fn unregister(&self, id: u64) {
        // The guard is a temporary released at the end of this statement, so
        // the removed waker is dropped outside the lock.
        let removed = self.lock().waiters.remove(&id);
        drop(removed);
    }
}

impl Drop for Inner {
    /// Releases the ancestor chain iteratively: each uniquely owned ancestor
    /// is unwrapped and its own parent link taken before it is dropped. Very
    /// deep chains therefore do not drop recursively, once per level.
    fn drop(&mut self) {
        let mut next = self.parent.take();
        while let Some(parent) = next {
            next = Arc::into_inner(parent).and_then(|mut ancestor| ancestor.parent.take());
        }
    }
}

/// A cloneable handle used to signal cancellation to any number of tasks.
///
/// All clones share one cancellation state. Tokens created with
/// [`CancellationToken::child`] are cancelled whenever their parent is, but
/// cancelling a child never affects its parent. Cancellation is permanent.
#[derive(Clone)]
pub struct CancellationToken {
    inner: Arc<Inner>,
}

impl CancellationToken {
    /// Creates a new, uncancelled root token.
    pub fn new() -> Self {
        Self {
            inner: Inner::new(None),
        }
    }

    /// Creates a child token. It is cancelled when `self` or any ancestor of
    /// `self` is cancelled. If `self` is already cancelled, the child is
    /// returned already cancelled.
    pub fn child(&self) -> Self {
        let child = Self {
            inner: Inner::new(Some(Arc::clone(&self.inner))),
        };
        self.inner.add_child(&child.inner);
        child
    }

    /// Cancels this token and all of its descendants, and wakes every task
    /// waiting on [`CancellationToken::cancelled`]. Calling it again has no
    /// further effect.
    pub fn cancel(&self) {
        self.inner.cancel();
    }

    /// Returns `true` once cancellation has reached this token, either
    /// directly or through an ancestor.
    pub fn is_cancelled(&self) -> bool {
        self.inner.is_cancelled()
    }

    /// Consumes the token and returns a [`DropGuard`] that cancels it when
    /// dropped, unless the guard is disarmed first.
    pub fn drop_guard(self) -> DropGuard {
        DropGuard { token: Some(self) }
    }

    /// Returns a future that completes once this token is cancelled.
    ///
    /// The future holds its own reference to the shared state, so it is not
    /// tied to the lifetime of `self`.
    pub fn cancelled(&self) -> Cancelled {
        Cancelled {
            inner: Arc::clone(&self.inner),
            waiter_id: None,
        }
    }
}

impl Default for CancellationToken {
    fn default() -> Self {
        Self::new()
    }
}

impl std::fmt::Debug for CancellationToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CancellationToken")
            .field("cancelled", &self.is_cancelled())
            .finish()
    }
}

/// Future returned by [`CancellationToken::cancelled`]. It resolves once the
/// token is cancelled.
///
/// While pending it keeps at most one waker registered with the token. The
/// waker is replaced when the future is polled with a waker that would wake a
/// different task, and it is removed when the future is dropped.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Cancelled {
    inner: Arc<Inner>,
    /// Id of this future's entry in the waiter registry, once registered.
    waiter_id: Option<u64>,
}

impl Future for Cancelled {
    type Output = ();

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        // Lock-free fast path for the already-cancelled case.
        if self.inner.is_cancelled() {
            return Poll::Ready(());
        }
        let this = &mut *self;
        if this.inner.register(&mut this.waiter_id, cx.waker()) {
            Poll::Ready(())
        } else {
            Poll::Pending
        }
    }
}

impl Drop for Cancelled {
    fn drop(&mut self) {
        // Deregister, so a dropped future leaves nothing behind on a token
        // that may never be cancelled.
        if let Some(id) = self.waiter_id.take() {
            self.inner.unregister(id);
        }
    }
}

/// Guard returned by [`CancellationToken::drop_guard`]. It cancels the
/// wrapped token when dropped, including while unwinding from a panic.
///
/// Call [`DropGuard::disarm`] to get the token back without cancelling it.
#[derive(Debug)]
#[must_use = "dropping the guard immediately cancels the token"]
pub struct DropGuard {
    /// `None` only after `disarm` has taken the token out.
    token: Option<CancellationToken>,
}

impl DropGuard {
    /// Defuses the guard and returns the token without cancelling it.
    ///
    /// The `expect` cannot fire: `disarm` takes the guard by value, so the
    /// token has not been taken out yet.
    pub fn disarm(mut self) -> CancellationToken {
        self.token.take().expect("DropGuard already disarmed")
    }
}

impl Drop for DropGuard {
    fn drop(&mut self) {
        if let Some(token) = &self.token {
            token.cancel();
        }
    }
}
364
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::sync::Arc;
    use std::task::{RawWaker, RawWakerVTable, Wake};

    /// Builds a waker whose operations all do nothing.
    fn noop_waker() -> Waker {
        fn noop(_: *const ()) {}
        fn noop_clone(p: *const ()) -> RawWaker {
            RawWaker::new(p, &VTABLE)
        }
        const VTABLE: RawWakerVTable = RawWakerVTable::new(noop_clone, noop, noop, noop);
        unsafe { Waker::from_raw(RawWaker::new(std::ptr::null(), &VTABLE)) }
    }

    /// Waker that counts how many times it has been woken, so tests can check
    /// the wake-up contract and not only the final `Poll` value.
    struct CountingWaker(AtomicUsize);

    impl Wake for CountingWaker {
        fn wake(self: Arc<Self>) {
            self.0.fetch_add(1, Ordering::SeqCst);
        }
    }

    /// Returns the wake counter together with a `Waker` that increments it.
    fn counting_waker() -> (Arc<CountingWaker>, Waker) {
        let counter = Arc::new(CountingWaker(AtomicUsize::new(0)));
        (Arc::clone(&counter), Waker::from(counter))
    }

    fn wake_count(counter: &CountingWaker) -> usize {
        counter.0.load(Ordering::SeqCst)
    }

    fn poll_with<F: Future>(f: Pin<&mut F>, waker: &Waker) -> Poll<F::Output> {
        let mut cx = Context::from_waker(waker);
        f.poll(&mut cx)
    }

    fn poll_once<F: Future>(f: Pin<&mut F>) -> Poll<F::Output> {
        poll_with(f, &noop_waker())
    }

    #[test]
    fn not_cancelled_by_default() {
        let token = CancellationToken::new();
        assert!(!token.is_cancelled());
    }

    #[test]
    fn cancel_sets_flag() {
        let token = CancellationToken::new();
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn cancel_is_idempotent() {
        let token = CancellationToken::new();
        token.cancel();
        token.cancel();
        assert!(token.is_cancelled());
    }

    #[test]
    fn clone_shares_state() {
        let token = CancellationToken::new();
        let clone = token.clone();
        token.cancel();
        assert!(clone.is_cancelled());
    }

    #[test]
    fn child_sees_parent_cancel() {
        let parent = CancellationToken::new();
        let child = parent.child();
        assert!(!child.is_cancelled());
        parent.cancel();
        assert!(child.is_cancelled());
    }

    #[test]
    fn grandchild_sees_ancestor_cancel() {
        let root = CancellationToken::new();
        let child = root.child();
        let grandchild = child.child();
        assert!(!grandchild.is_cancelled());
        root.cancel();
        assert!(grandchild.is_cancelled());
    }

    #[test]
    fn grandchild_cancelled_after_middle_token_dropped() {
        let root = CancellationToken::new();
        let middle = root.child();
        let grandchild = middle.child();
        drop(middle);
        root.cancel();
        assert!(grandchild.is_cancelled());
    }

    #[test]
    fn child_cancel_does_not_affect_parent() {
        let parent = CancellationToken::new();
        let child = parent.child();
        child.cancel();
        assert!(child.is_cancelled());
        assert!(!parent.is_cancelled());
    }

    #[test]
    fn cancelled_future_ready_when_cancelled() {
        let token = CancellationToken::new();
        token.cancel();

        let mut fut = std::pin::pin!(token.cancelled());
        assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(())));
    }

    #[test]
    fn cancelled_future_pending_then_ready() {
        let token = CancellationToken::new();

        let mut fut = std::pin::pin!(token.cancelled());
        assert!(matches!(poll_once(fut.as_mut()), Poll::Pending));

        token.cancel();
        assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(())));
    }

    #[test]
    fn cancel_wakes_registered_waker() {
        let token = CancellationToken::new();
        let (counter, waker) = counting_waker();

        let mut fut = std::pin::pin!(token.cancelled());
        assert!(poll_with(fut.as_mut(), &waker).is_pending());
        assert_eq!(wake_count(&counter), 0);

        token.cancel();
        assert_eq!(wake_count(&counter), 1);
        assert!(poll_with(fut.as_mut(), &waker).is_ready());
    }

    #[test]
    fn repolling_with_same_waker_wakes_once() {
        let token = CancellationToken::new();
        let (counter, waker) = counting_waker();

        let mut fut = std::pin::pin!(token.cancelled());
        assert!(poll_with(fut.as_mut(), &waker).is_pending());
        assert!(poll_with(fut.as_mut(), &waker).is_pending());

        token.cancel();
        assert_eq!(wake_count(&counter), 1);
    }

    #[test]
    fn repolling_with_new_waker_wakes_latest() {
        let token = CancellationToken::new();
        let (_first, first_waker) = counting_waker();
        let (second, second_waker) = counting_waker();

        let mut fut = std::pin::pin!(token.cancelled());
        assert!(poll_with(fut.as_mut(), &first_waker).is_pending());
        assert!(poll_with(fut.as_mut(), &second_waker).is_pending());

        token.cancel();
        assert_eq!(wake_count(&second), 1);
    }

    #[test]
    fn child_cancelled_future_from_parent() {
        let parent = CancellationToken::new();
        let child = parent.child();

        let mut fut = std::pin::pin!(child.cancelled());
        assert!(matches!(poll_once(fut.as_mut()), Poll::Pending));

        parent.cancel();
        assert!(matches!(poll_once(fut.as_mut()), Poll::Ready(())));
    }

    #[test]
    fn parent_cancel_wakes_child_waiter() {
        let parent = CancellationToken::new();
        let child = parent.child();
        let (counter, waker) = counting_waker();

        let mut fut = std::pin::pin!(child.cancelled());
        assert!(poll_with(fut.as_mut(), &waker).is_pending());

        parent.cancel();
        assert_eq!(wake_count(&counter), 1);
    }

    #[test]
    fn multiple_waiters() {
        let token = CancellationToken::new();

        let mut fut1 = std::pin::pin!(token.cancelled());
        let mut fut2 = std::pin::pin!(token.cancelled());

        assert!(matches!(poll_once(fut1.as_mut()), Poll::Pending));
        assert!(matches!(poll_once(fut2.as_mut()), Poll::Pending));

        token.cancel();

        assert!(matches!(poll_once(fut1.as_mut()), Poll::Ready(())));
        assert!(matches!(poll_once(fut2.as_mut()), Poll::Ready(())));
    }

    #[test]
    fn cross_thread_cancel() {
        let token = CancellationToken::new();
        let clone = token.clone();

        let handle = std::thread::spawn(move || {
            std::thread::sleep(std::time::Duration::from_millis(10));
            clone.cancel();
        });

        while !token.is_cancelled() {
            std::hint::spin_loop();
        }

        handle.join().unwrap();
    }

    #[test]
    fn drop_guard_cancels_on_drop() {
        let token = CancellationToken::new();
        let clone = token.clone();
        {
            let _guard = token.drop_guard();
            assert!(!clone.is_cancelled());
        }
        assert!(clone.is_cancelled());
    }

    #[test]
    fn drop_guard_disarm() {
        let token = CancellationToken::new();
        let clone = token.clone();
        let guard = token.drop_guard();
        let recovered = guard.disarm();
        drop(recovered);
        assert!(!clone.is_cancelled());
    }

    #[test]
    fn drop_guard_on_panic() {
        let token = CancellationToken::new();
        let clone = token.clone();

        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let _guard = token.drop_guard();
            panic!("simulated panic");
        }));

        assert!(result.is_err());
        assert!(clone.is_cancelled());
    }

    #[test]
    fn send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<CancellationToken>();
        assert_send_sync::<Cancelled>();
    }

    /// Exercises teardown of a token that still has a child and a registered
    /// waiter without ever being cancelled (useful under Miri / leak checkers).
    #[test]
    fn drop_without_cancel_cleans_up() {
        let token = CancellationToken::new();
        let _child = token.child();
        let mut fut = std::pin::pin!(token.cancelled());
        let _ = poll_once(fut.as_mut());
    }

    #[test]
    fn many_children() {
        let parent = CancellationToken::new();
        let children: Vec<_> = (0..100).map(|_| parent.child()).collect();

        parent.cancel();
        for child in &children {
            assert!(child.is_cancelled());
        }
    }

    #[test]
    fn child_created_after_parent_cancelled() {
        let parent = CancellationToken::new();
        parent.cancel();
        let child = parent.child();
        assert!(child.is_cancelled());
    }
}