async_cancellation_token/
lib.rs1use std::{
2 cell::{Cell, RefCell},
3 future::Future,
4 pin::Pin,
5 rc::Rc,
6 task::{Context, Poll, Waker},
7};
8
/// State shared (through an `Rc`) by a source and the tokens created from it.
///
/// All fields use interior mutability, so the state is updated through shared
/// references.
#[derive(Default)]
struct Inner {
    /// Whether cancellation has been requested; `false` by default.
    cancelled: Cell<bool>,
    /// Wakers stored by pending `CancelledFuture` polls, woken on cancellation.
    wakers: RefCell<Vec<Waker>>,
    /// One-shot callbacks queued by `CancellationToken::register`, run on cancellation.
    callbacks: RefCell<Vec<Box<dyn FnOnce()>>>,
}
15
16#[derive(Clone)]
17pub struct CancellationTokenSource {
18 inner: Rc<Inner>,
19}
20
21#[derive(Clone)]
22pub struct CancellationToken {
23 inner: Rc<Inner>,
24}
25
/// Error returned by [`CancellationToken::check_cancelled`] once cancellation
/// has been requested.
///
/// Implements [`std::error::Error`], so `?` can convert it into
/// `Box<dyn std::error::Error>` and similar error containers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Cancelled;

impl std::fmt::Display for Cancelled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("operation was cancelled")
    }
}

impl std::error::Error for Cancelled {}
28
29impl Default for CancellationTokenSource {
30 fn default() -> Self {
31 Self::new()
32 }
33}
34
35impl CancellationTokenSource {
36 pub fn new() -> Self {
37 Self {
38 inner: Rc::new(Inner::default()),
39 }
40 }
41
42 pub fn token(&self) -> CancellationToken {
43 CancellationToken {
44 inner: self.inner.clone(),
45 }
46 }
47
48 pub fn cancel(&self) {
49 if !self.inner.cancelled.replace(true) {
50 for cb in self.inner.callbacks.borrow_mut().drain(..) {
51 cb();
52 }
53
54 for w in self.inner.wakers.borrow_mut().drain(..) {
55 w.wake();
56 }
57 }
58 }
59
60 pub fn is_cancelled(&self) -> bool {
61 self.inner.cancelled.get()
62 }
63}
64
65impl CancellationToken {
66 pub fn is_cancelled(&self) -> bool {
67 self.inner.cancelled.get()
68 }
69
70 pub fn check_cancelled(&self) -> Result<(), Cancelled> {
71 if self.is_cancelled() {
72 Err(Cancelled)
73 } else {
74 Ok(())
75 }
76 }
77
78 pub fn cancelled(&self) -> CancelledFuture {
79 CancelledFuture {
80 token: self.clone(),
81 }
82 }
83
84 pub fn register(&self, f: impl FnOnce() + 'static) {
85 if self.is_cancelled() {
86 f();
87 } else {
88 self.inner.callbacks.borrow_mut().push(Box::new(f));
89 }
90 }
91}
92
93pub struct CancelledFuture {
94 token: CancellationToken,
95}
96
97impl Future for CancelledFuture {
98 type Output = ();
99
100 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
101 if self.token.is_cancelled() {
102 Poll::Ready(())
103 } else {
104 let mut wakers = self.token.inner.wakers.borrow_mut();
105 if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
106 wakers.push(cx.waker().clone());
107 }
108 Poll::Pending
109 }
110 }
111}
112
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::{FutureExt, executor::LocalPool, pin_mut, select, task::LocalSpawnExt};
    use futures_timer::Delay;

    use super::*;

    /// Runs two workers and a canceller on a single-threaded pool and checks
    /// that both workers observe the cancellation.
    #[test]
    fn cancel_two_tasks() {
        /// Worker A: races each 300 ms step against the cancellation future.
        async fn run_task_a(token: CancellationToken, saw_cancel: Rc<Cell<bool>>) {
            println!("Task A started");

            for i in 1..=10 {
                let delay = Delay::new(Duration::from_millis(300)).fuse();
                let cancelled = token.cancelled().fuse();

                pin_mut!(delay, cancelled);

                select! {
                    _ = delay => {
                        println!("Task A step {i}");
                    },
                    _ = cancelled => {
                        println!("Task A detected cancellation, cleaning up...");
                        saw_cancel.set(true);
                        break;
                    },
                }
            }

            println!("Task A finished");
        }

        /// Worker B: checks the flag only after each 500 ms step completes.
        async fn run_task_b(token: CancellationToken, saw_cancel: Rc<Cell<bool>>) {
            println!("Task B started");

            for i in 1..=10 {
                Delay::new(Duration::from_millis(500)).await;

                println!("Task B step {i}");
                if token.check_cancelled().is_err() {
                    println!("Task B noticed cancellation after step {i}");
                    saw_cancel.set(true);
                    break;
                }
            }

            println!("Task B finished");
        }

        let cancelled_a = Rc::new(Cell::new(false));
        let cancelled_b = Rc::new(Cell::new(false));

        let cts = CancellationTokenSource::new();

        let mut pool = LocalPool::new();
        let spawner = pool.spawner();

        spawner
            .spawn_local(run_task_a(cts.token(), Rc::clone(&cancelled_a)))
            .unwrap();
        spawner
            .spawn_local(run_task_b(cts.token(), Rc::clone(&cancelled_b)))
            .unwrap();

        // Third task: request cancellation after two seconds.
        let canceller = {
            let cts = cts.clone();
            async move {
                Delay::new(Duration::from_secs(2)).await;
                println!("Cancelling all tasks!");
                cts.cancel();
            }
        };
        spawner.spawn_local(canceller).unwrap();

        pool.run();

        assert!(cts.is_cancelled());
        assert!(cancelled_a.get());
        assert!(cancelled_b.get());
    }

    /// A callback registered before cancellation runs when `cancel` is called;
    /// one registered afterwards runs immediately.
    #[test]
    fn cancellation_register_callbacks() {
        let cts = CancellationTokenSource::new();
        let token = cts.token();

        let fired_on_cancel = Rc::new(Cell::new(false));
        let fired_on_register = Rc::new(Cell::new(false));

        let on_cancel = Rc::clone(&fired_on_cancel);
        token.register(move || on_cancel.set(true));

        cts.cancel();
        assert!(fired_on_cancel.get());

        let on_register = Rc::clone(&fired_on_register);
        token.register(move || on_register.set(true));

        assert!(fired_on_register.get());
    }
}