async_cancellation_token/
lib.rs

1use std::{
2    cell::{Cell, RefCell},
3    future::Future,
4    pin::Pin,
5    rc::Rc,
6    task::{Context, Poll, Waker},
7};
8
/// State shared between a [`CancellationTokenSource`] and every
/// [`CancellationToken`] handed out from it.
///
/// All fields use single-threaded interior mutability (`Cell` / `RefCell`),
/// and the state is shared through `Rc`.
#[derive(Default)]
struct Inner {
    /// `true` once cancellation has been requested.
    cancelled: Cell<bool>,
    /// Wakers stored by pending [`CancelledFuture`]s; drained and woken on cancel.
    wakers: RefCell<Vec<Waker>>,
    /// One-shot callbacks registered through a token; drained and run on cancel.
    callbacks: RefCell<Vec<Box<dyn FnOnce()>>>,
}
15
16#[derive(Clone)]
17pub struct CancellationTokenSource {
18    inner: Rc<Inner>,
19}
20
21#[derive(Clone)]
22pub struct CancellationToken {
23    inner: Rc<Inner>,
24}
25
/// Error value signalling that cancellation has been requested.
///
/// Implements [`std::error::Error`] and [`std::fmt::Display`] so it can be
/// propagated with `?` into `Box<dyn std::error::Error>` and similar error
/// containers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Cancelled;

impl std::fmt::Display for Cancelled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("operation was cancelled")
    }
}

impl std::error::Error for Cancelled {}
28
29impl Default for CancellationTokenSource {
30    fn default() -> Self {
31        Self::new()
32    }
33}
34
35impl CancellationTokenSource {
36    pub fn new() -> Self {
37        Self {
38            inner: Rc::new(Inner::default()),
39        }
40    }
41
42    pub fn token(&self) -> CancellationToken {
43        CancellationToken {
44            inner: self.inner.clone(),
45        }
46    }
47
48    pub fn cancel(&self) {
49        if !self.inner.cancelled.replace(true) {
50            for cb in self.inner.callbacks.borrow_mut().drain(..) {
51                cb();
52            }
53
54            for w in self.inner.wakers.borrow_mut().drain(..) {
55                w.wake();
56            }
57        }
58    }
59
60    pub fn is_cancelled(&self) -> bool {
61        self.inner.cancelled.get()
62    }
63}
64
65impl CancellationToken {
66    pub fn is_cancelled(&self) -> bool {
67        self.inner.cancelled.get()
68    }
69
70    pub fn check_cancelled(&self) -> Result<(), Cancelled> {
71        if self.is_cancelled() {
72            Err(Cancelled)
73        } else {
74            Ok(())
75        }
76    }
77
78    pub fn cancelled(&self) -> CancelledFuture {
79        CancelledFuture {
80            token: self.clone(),
81        }
82    }
83
84    pub fn register(&self, f: impl FnOnce() + 'static) {
85        if self.is_cancelled() {
86            f();
87        } else {
88            self.inner.callbacks.borrow_mut().push(Box::new(f));
89        }
90    }
91}
92
93pub struct CancelledFuture {
94    token: CancellationToken,
95}
96
97impl Future for CancelledFuture {
98    type Output = ();
99
100    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
101        if self.token.is_cancelled() {
102            Poll::Ready(())
103        } else {
104            let mut wakers = self.token.inner.wakers.borrow_mut();
105            if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
106                wakers.push(cx.waker().clone());
107            }
108            Poll::Pending
109        }
110    }
111}
112
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::{FutureExt, executor::LocalPool, pin_mut, select, task::LocalSpawnExt};
    use futures_timer::Delay;

    use super::*;

    /// Task A: runs up to ten 300 ms steps, racing each step against the
    /// token's cancellation future, and records in `saw_cancel` when the
    /// cancellation branch wins.
    async fn run_task_a(token: CancellationToken, saw_cancel: Rc<Cell<bool>>) {
        println!("Task A started");

        for i in 1..=10 {
            let delay = Delay::new(Duration::from_millis(300)).fuse();
            let cancelled = token.cancelled().fuse();

            pin_mut!(delay, cancelled);

            select! {
                _ = delay => {
                    println!("Task A step {i}");
                },
                _ = cancelled => {
                    println!("Task A detected cancellation, cleaning up...");
                    // Cleanup && Dispose
                    saw_cancel.set(true);
                    break;
                },
            }
        }

        println!("Task A finished");
    }

    /// Task B: runs up to ten 500 ms steps and checks the token synchronously
    /// after each one, recording in `saw_cancel` when the check fails.
    async fn run_task_b(token: CancellationToken, saw_cancel: Rc<Cell<bool>>) {
        println!("Task B started");

        for i in 1..=10 {
            Delay::new(Duration::from_millis(500)).await;

            println!("Task B step {i}");
            if token.check_cancelled().is_err() {
                println!("Task B noticed cancellation after step {i}");
                // Cleanup && Dispose
                saw_cancel.set(true);
                break;
            }
        }

        println!("Task B finished");
    }

    /// Spawns both tasks plus a third task that cancels the source after two
    /// seconds, then checks that each task observed the cancellation.
    #[test]
    fn cancel_two_tasks() {
        let cancelled_a = Rc::new(Cell::new(false));
        let cancelled_b = Rc::new(Cell::new(false));

        let cts = CancellationTokenSource::new();

        let mut pool = LocalPool::new();
        let spawner = pool.spawner();

        spawner
            .spawn_local(run_task_a(cts.token(), Rc::clone(&cancelled_a)))
            .unwrap();
        spawner
            .spawn_local(run_task_b(cts.token(), Rc::clone(&cancelled_b)))
            .unwrap();

        // The cancelling task owns its own clone of the source so the
        // original stays available for the assertions below.
        let canceller = cts.clone();
        spawner
            .spawn_local(async move {
                Delay::new(Duration::from_secs(2)).await;
                println!("Cancelling all tasks!");
                canceller.cancel();
            })
            .unwrap();

        pool.run();

        assert!(cts.is_cancelled());
        assert!(cancelled_a.get());
        assert!(cancelled_b.get());
    }

    /// A callback registered before `cancel` runs during `cancel`; a callback
    /// registered afterwards runs immediately inside `register`.
    #[test]
    fn cancellation_register_callbacks() {
        let cts = CancellationTokenSource::new();
        let token = cts.token();

        let before = Rc::new(Cell::new(false));
        let after = Rc::new(Cell::new(false));

        let before_handle = Rc::clone(&before);
        token.register(move || before_handle.set(true));

        cts.cancel();
        assert!(before.get());

        let after_handle = Rc::clone(&after);
        token.register(move || after_handle.set(true));

        assert!(after.get());
    }
}