async_cancellation_token/lib.rs
//! # async-cancellation-token
//!
//! `async-cancellation-token` is a lightweight **single-threaded** Rust library that provides
//! **cancellation tokens** for cooperative cancellation of asynchronous tasks and callbacks.
//!
//! This crate works in single-threaded async environments (e.g., `futures::executor::LocalPool`)
//! and uses `Rc`, `Cell`, and `RefCell` internally. It is **not thread-safe**: the source, token,
//! and future types are neither `Send` nor `Sync`.
//!
//! ## Example
//!
//! The first task prints a step every 100 ms until the second task cancels the shared token
//! after 250 ms.
//!
//! ```rust
//! use std::time::Duration;
//! use async_cancellation_token::CancellationTokenSource;
//! use futures::{FutureExt, executor::LocalPool, pin_mut, select, task::LocalSpawnExt};
//! use futures_timer::Delay;
//!
//! let cts = CancellationTokenSource::new();
//! let token = cts.token();
//!
//! let mut pool = LocalPool::new();
//! let spawner = pool.spawner();
//!
//! spawner.spawn_local(async move {
//!     for i in 1..=5 {
//!         let delay = Delay::new(Duration::from_millis(100)).fuse();
//!         let cancelled = token.cancelled().fuse();
//!         pin_mut!(delay, cancelled);
//!
//!         select! {
//!             _ = delay => println!("Step {i}"),
//!             _ = cancelled => {
//!                 println!("Cancelled!");
//!                 break;
//!             }
//!         }
//!     }
//! }.map(|_| ())).unwrap();
//!
//! spawner.spawn_local(async move {
//!     Delay::new(Duration::from_millis(250)).await;
//!     cts.cancel();
//! }.map(|_| ())).unwrap();
//!
//! pool.run();
//! ```
46
47use std::{
48 cell::{Cell, RefCell},
49 error::Error,
50 fmt::Display,
51 future::Future,
52 pin::Pin,
53 rc::Rc,
54 task::{Context, Poll, Waker},
55};
56
/// State shared by a `CancellationTokenSource` and every `CancellationToken` it hands out.
///
/// Each handle owns an `Rc<Inner>`, so all clones see the same flag and the same lists.
#[derive(Default)]
struct Inner {
    /// Cancellation flag; `false` until the source is cancelled.
    cancelled: Cell<bool>,
    /// Wakers registered by `CancelledFuture` polls; woken and cleared on cancellation.
    wakers: RefCell<Vec<Waker>>,
    /// One-shot callbacks queued before cancellation; run and cleared on cancellation.
    callbacks: RefCell<Vec<Callback>>,
}

/// A boxed one-shot callback as stored by `Inner`.
type Callback = Box<dyn FnOnce()>;
67
68/// A source that can cancel associated `CancellationToken`s.
69///
70/// # Example
71///
72/// ```rust
73/// use async_cancellation_token::CancellationTokenSource;
74///
75/// let cts = CancellationTokenSource::new();
76/// let token = cts.token();
77///
78/// assert!(!cts.is_cancelled());
79/// cts.cancel();
80/// assert!(cts.is_cancelled());
81/// ```
82#[derive(Clone)]
83pub struct CancellationTokenSource {
84 inner: Rc<Inner>,
85}
86
87/// A token that can be checked for cancellation or awaited.
88///
89/// # Example
90///
91/// ```rust
92/// use async_cancellation_token::CancellationTokenSource;
93/// use futures::{FutureExt, executor::LocalPool, task::LocalSpawnExt};
94///
95/// let cts = CancellationTokenSource::new();
96/// let token = cts.token();
97///
98/// let mut pool = LocalPool::new();
99/// pool.spawner().spawn_local(async move {
100/// token.cancelled().await;
101/// println!("Cancelled!");
102/// }.map(|_| ())).unwrap();
103///
104/// cts.cancel();
105/// pool.run();
106/// ```
107#[derive(Clone)]
108pub struct CancellationToken {
109 inner: Rc<Inner>,
110}
111
/// Error produced by [`CancellationToken::check_cancelled`] once the token's source has
/// been cancelled.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Cancelled;

impl Display for Cancelled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "cancelled by CancellationTokenSource")
    }
}

// `Cancelled` wraps no underlying error, so the default `Error` methods suffice.
impl Error for Cancelled {}
123
124impl Default for CancellationTokenSource {
125 fn default() -> Self {
126 Self::new()
127 }
128}
129
130impl CancellationTokenSource {
131 /// Create a new `CancellationTokenSource`.
132 pub fn new() -> Self {
133 Self {
134 inner: Rc::new(Inner::default()),
135 }
136 }
137
138 /// Get a `CancellationToken` associated with this source.
139 pub fn token(&self) -> CancellationToken {
140 CancellationToken {
141 inner: self.inner.clone(),
142 }
143 }
144
145 /// Cancel all associated tokens.
146 ///
147 /// This triggers any registered callbacks and wakes all wakers.
148 pub fn cancel(&self) {
149 if !self.inner.cancelled.replace(true) {
150 // Call all registered callbacks
151 for cb in self.inner.callbacks.borrow_mut().drain(..) {
152 cb();
153 }
154
155 // Wake all tasks waiting for cancellation
156 for w in self.inner.wakers.borrow_mut().drain(..) {
157 w.wake();
158 }
159 }
160 }
161
162 /// Check if this source has been cancelled.
163 pub fn is_cancelled(&self) -> bool {
164 self.inner.cancelled.get()
165 }
166}
167
168impl CancellationToken {
169 /// Check if the token has been cancelled.
170 pub fn is_cancelled(&self) -> bool {
171 self.inner.cancelled.get()
172 }
173
174 /// Synchronously check cancellation and return `Err(Cancelled)` if cancelled.
175 pub fn check_cancelled(&self) -> Result<(), Cancelled> {
176 if self.is_cancelled() {
177 Err(Cancelled)
178 } else {
179 Ok(())
180 }
181 }
182
183 /// Returns a `Future` that completes when the token is cancelled.
184 ///
185 /// # Example
186 ///
187 /// ```rust
188 /// use async_cancellation_token::CancellationTokenSource;
189 /// use futures::{FutureExt, executor::LocalPool, task::LocalSpawnExt};
190 ///
191 /// let cts = CancellationTokenSource::new();
192 /// let token = cts.token();
193 ///
194 /// let mut pool = LocalPool::new();
195 /// pool.spawner().spawn_local(async move {
196 /// token.cancelled().await;
197 /// println!("Cancelled!");
198 /// }.map(|_| ())).unwrap();
199 ///
200 /// cts.cancel();
201 /// pool.run();
202 /// ```
203 pub fn cancelled(&self) -> CancelledFuture {
204 CancelledFuture {
205 token: self.clone(),
206 }
207 }
208
209 /// Register a callback to run when the token is cancelled.
210 ///
211 /// If the token is already cancelled, the callback is called immediately.
212 ///
213 /// # Example
214 ///
215 /// ```rust
216 /// use std::{cell::Cell, rc::Rc};
217 /// use async_cancellation_token::CancellationTokenSource;
218 ///
219 /// let cts = CancellationTokenSource::new();
220 /// let token = cts.token();
221 ///
222 /// let flag = Rc::new(Cell::new(false));
223 /// let flag_clone = Rc::clone(&flag);
224 ///
225 /// token.register(move || {
226 /// flag_clone.set(true);
227 /// });
228 ///
229 /// cts.cancel();
230 /// assert!(flag.get());
231 /// ```
232 pub fn register(&self, f: impl FnOnce() + 'static) {
233 if self.is_cancelled() {
234 f();
235 } else {
236 self.inner.callbacks.borrow_mut().push(Box::new(f));
237 }
238 }
239}
240
241/// Future that completes when a `CancellationToken` is cancelled.
242pub struct CancelledFuture {
243 token: CancellationToken,
244}
245
246impl Future for CancelledFuture {
247 type Output = ();
248
249 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
250 if self.token.is_cancelled() {
251 Poll::Ready(())
252 } else {
253 let mut wakers = self.token.inner.wakers.borrow_mut();
254 if !wakers.iter().any(|w| w.will_wake(cx.waker())) {
255 wakers.push(cx.waker().clone());
256 }
257 Poll::Pending
258 }
259 }
260}
261
// Tests that drive tokens on a single-threaded `LocalPool` with real timers
// (`futures_timer::Delay`), so they consume wall-clock time.
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::{FutureExt, executor::LocalPool, pin_mut, select, task::LocalSpawnExt};
    use futures_timer::Delay;

    use super::*;

    // Two tasks share one source and react to cancellation differently:
    // task A races every step against `token.cancelled()`, while task B only polls
    // `check_cancelled()` between steps. A third task cancels after 2 s, so this test
    // takes at least two seconds of wall-clock time.
    #[test]
    fn cancel_two_tasks() {
        // Set by each task's cleanup path; both must be `true` once the pool drains.
        let cancelled_a = Rc::new(Cell::new(false));
        let cancelled_b = Rc::new(Cell::new(false));

        // Task A: each 300 ms step is raced against cancellation, so it stops as soon
        // as `cancel` wakes it rather than at a step boundary.
        let task_a = |token: CancellationToken| {
            let cancelled_a = Rc::clone(&cancelled_a);

            async move {
                println!("Task A started");

                for i in 1..=10 {
                    // Fresh futures each iteration; `fuse` + `pin_mut!` are what
                    // `select!` requires.
                    let delay = Delay::new(Duration::from_millis(300)).fuse();
                    let cancelled = token.cancelled().fuse();

                    pin_mut!(delay, cancelled);

                    select! {
                        _ = delay => {
                            println!("Task A step {i}");
                        },
                        _ = cancelled => {
                            println!("Task A detected cancellation, cleaning up...");
                            // Cleanup && Dispose
                            cancelled_a.set(true);
                            break;
                        },
                    }
                }

                println!("Task A finished");
            }
        };

        // Task B: checks the flag synchronously only after each 500 ms step, so it
        // notices cancellation at the next step boundary.
        let task_b = |token: CancellationToken| {
            let cancelled_b = Rc::clone(&cancelled_b);

            async move {
                println!("Task B started");

                for i in 1..=10 {
                    Delay::new(Duration::from_millis(500)).await;

                    println!("Task B step {i}");
                    if token.check_cancelled().is_err() {
                        println!("Task B noticed cancellation after step {i}");
                        // Cleanup && Dispose
                        cancelled_b.set(true);
                        break;
                    }
                }

                println!("Task B finished");
            }
        };

        let cts = CancellationTokenSource::new();

        let mut pool = LocalPool::new();
        let spawner = pool.spawner();

        spawner
            .spawn_local(task_a(cts.token()).map(|_| ()))
            .unwrap();
        spawner
            .spawn_local(task_b(cts.token()).map(|_| ()))
            .unwrap();

        // Canceller task: a cloned source moved into the task, so the outer `cts`
        // remains available for the assertions below.
        {
            let cts = cts.clone();
            spawner
                .spawn_local(
                    async move {
                        Delay::new(Duration::from_secs(2)).await;
                        println!("Cancelling all tasks!");
                        cts.cancel();
                    }
                    .map(|_| ()),
                )
                .unwrap();
        }

        // Runs until all three spawned tasks have completed.
        pool.run();

        assert!(cts.is_cancelled());
        assert!(cancelled_a.get());
        assert!(cancelled_b.get());
    }

    // A callback registered before `cancel` runs during `cancel`; one registered
    // afterwards runs immediately inside `register`.
    #[test]
    fn cancellation_register_callbacks() {
        let cts = CancellationTokenSource::new();
        let token = cts.token();

        let flag1 = Rc::new(Cell::new(false));
        let flag2 = Rc::new(Cell::new(false));

        {
            let flag1 = Rc::clone(&flag1);
            token.register(move || {
                flag1.set(true);
            });
        }

        cts.cancel();
        assert!(flag1.get());

        // Already cancelled: this callback must fire synchronously.
        {
            let flag2 = Rc::clone(&flag2);
            token.register(move || {
                flag2.set(true);
            });
        }

        assert!(flag2.get());
    }
}