1#![forbid(unsafe_code)]
2
3use std::future::Future;
4use std::pin::Pin;
5use std::sync::{
6 Arc, Mutex,
7 atomic::{AtomicBool, Ordering},
8};
9use std::task::{Context, Poll};
10
11#[derive(Clone, Debug, Default)]
12pub struct CancellationToken {
13 inner: Arc<Inner>,
14}
15
/// State shared by every clone of a token.
///
/// `Default` is derived: it yields a cleared `cancelled` flag and an empty
/// waker list, which is the initial "not cancelled, nobody waiting" state.
#[derive(Debug, Default)]
struct Inner {
    /// Set to `true` once cancellation has been requested; never reset here.
    cancelled: AtomicBool,
    /// Wakers of tasks waiting for cancellation, guarded by a mutex so they
    /// can be registered and drained from different threads.
    wakers: Mutex<Vec<std::task::Waker>>,
}
30
31impl CancellationToken {
32 pub fn new() -> Self {
33 Self::default()
34 }
35
36 pub fn cancel(&self) {
37 self.inner.cancelled.store(true, Ordering::SeqCst);
38 let mut wakers = match self.inner.wakers.lock() {
39 Ok(guard) => guard,
40 Err(poisoned) => poisoned.into_inner(),
41 };
42 let wakers = std::mem::take(&mut *wakers);
43 for w in wakers {
44 w.wake();
45 }
46 }
47
48 pub fn is_cancelled(&self) -> bool {
49 self.inner.cancelled.load(Ordering::SeqCst)
50 }
51
52 pub fn cancelled(&self) -> Cancelled {
53 Cancelled {
54 token: self.clone(),
55 }
56 }
57}
58
59pub struct Cancelled {
60 token: CancellationToken,
61}
62
63impl Future for Cancelled {
64 type Output = ();
65
66 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
67 if self.token.is_cancelled() {
68 return Poll::Ready(());
69 }
70 let waker = cx.waker().clone();
71 let mut wakers = match self.token.inner.wakers.lock() {
72 Ok(guard) => guard,
73 Err(poisoned) => poisoned.into_inner(),
74 };
75 if !wakers.iter().any(|w| w.will_wake(&waker)) {
76 wakers.push(waker);
77 }
78 if self.token.is_cancelled() {
79 Poll::Ready(())
80 } else {
81 Poll::Pending
82 }
83 }
84}