async_cancellation_token/
lib.rs

1//! # async-cancellation-token
2//!
3//! `async-cancellation-token` is a lightweight **single-threaded** Rust library that provides
4//! **cancellation tokens** for cooperative cancellation of asynchronous tasks and callbacks.
5//!
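//! This crate is designed for single-threaded async executors (e.g., `futures::executor::LocalPool`)
//! and uses `Rc`, `Cell`, and `RefCell` internally. It is **not thread-safe**: its types are
//! neither `Send` nor `Sync`, so tokens cannot be moved to or shared with other threads.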
8//!
9//! ## Example
10//!
11//! ```rust
12//! use std::time::Duration;
13//! use async_cancellation_token::CancellationTokenSource;
14//! use futures::{FutureExt, executor::LocalPool, pin_mut, select, task::LocalSpawnExt};
15//! use futures_timer::Delay;
16//!
17//! let cts = CancellationTokenSource::new();
18//! let token = cts.token();
19//!
20//! let mut pool = LocalPool::new();
21//! let spawner = pool.spawner();
22//!
23//! spawner.spawn_local(async move {
24//!     for i in 1..=5 {
25//!         let delay = Delay::new(Duration::from_millis(100)).fuse();
26//!         let cancelled = token.cancelled().fuse();
27//!         pin_mut!(delay, cancelled);
28//!
29//!         select! {
30//!             _ = delay => println!("Step {i}"),
31//!             _ = cancelled => {
32//!                 println!("Cancelled!");
33//!                 break;
34//!             }
35//!         }
36//!     }
37//! }.map(|_| ())).unwrap();
38//!
39//! spawner.spawn_local(async move {
40//!     Delay::new(Duration::from_millis(250)).await;
41//!     cts.cancel();
42//! }.map(|_| ())).unwrap();
43//!
44//! pool.run();
45//! ```
46
47use std::{
48    cell::{Cell, RefCell},
49    error::Error,
50    fmt::Display,
51    future::Future,
52    pin::Pin,
53    rc::Rc,
54    task::{Context, Poll, Waker},
55};
56
/// Inner shared state for `CancellationToken` and `CancellationTokenSource`.
#[derive(Default)]
struct Inner {
    /// Whether the token has been cancelled. Once set, it is never cleared.
    cancelled: Cell<bool>,
    /// Wakers of pending `CancelledFuture`s, keyed by a per-future registration id.
    ///
    /// Each future owns at most one entry and removes it when dropped, so futures that are
    /// abandoned before cancellation (e.g. the losing branch of a `select!`) do not leave
    /// stale wakers behind on a long-lived token.
    wakers: RefCell<std::collections::BTreeMap<u64, Waker>>,
    /// Next registration id to hand out for `wakers`.
    next_waker_id: Cell<u64>,
    /// List of callbacks to call when cancellation occurs.
    callbacks: RefCell<Vec<Box<dyn FnOnce()>>>,
}

/// A source that can cancel associated `CancellationToken`s.
///
/// # Example
///
/// ```rust
/// use async_cancellation_token::CancellationTokenSource;
///
/// let cts = CancellationTokenSource::new();
/// let token = cts.token();
///
/// assert!(!cts.is_cancelled());
/// cts.cancel();
/// assert!(cts.is_cancelled());
/// ```
#[derive(Clone)]
pub struct CancellationTokenSource {
    inner: Rc<Inner>,
}

/// A token that can be checked for cancellation or awaited.
///
/// # Example
///
/// ```rust
/// use async_cancellation_token::CancellationTokenSource;
/// use futures::{FutureExt, executor::LocalPool, task::LocalSpawnExt};
///
/// let cts = CancellationTokenSource::new();
/// let token = cts.token();
///
/// let mut pool = LocalPool::new();
/// pool.spawner().spawn_local(async move {
///     token.cancelled().await;
///     println!("Cancelled!");
/// }.map(|_| ())).unwrap();
///
/// cts.cancel();
/// pool.run();
/// ```
#[derive(Clone)]
pub struct CancellationToken {
    inner: Rc<Inner>,
}

/// Error returned when a cancelled token is checked synchronously.
#[derive(Copy, Clone, Debug, Default, Eq, Ord, PartialEq, PartialOrd, Hash)]
pub struct Cancelled;

impl Display for Cancelled {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("cancelled by CancellationTokenSource")
    }
}

impl Error for Cancelled {}

impl Default for CancellationTokenSource {
    fn default() -> Self {
        Self::new()
    }
}

impl CancellationTokenSource {
    /// Create a new `CancellationTokenSource`.
    pub fn new() -> Self {
        Self {
            inner: Rc::new(Inner::default()),
        }
    }

    /// Get a `CancellationToken` associated with this source.
    pub fn token(&self) -> CancellationToken {
        CancellationToken {
            inner: Rc::clone(&self.inner),
        }
    }

    /// Cancel all associated tokens.
    ///
    /// The first call wakes every task currently waiting on
    /// [`CancellationToken::cancelled`] and then runs each registered callback once, in
    /// registration order. Subsequent calls do nothing.
    ///
    /// Wakers are woken before callbacks run so that a panicking callback cannot leave
    /// awaiting tasks asleep. If a callback panics, the callbacks after it are dropped
    /// without being called, but the token stays cancelled.
    pub fn cancel(&self) {
        if self.inner.cancelled.replace(true) {
            return;
        }

        // Move both lists out before running any user code, so no `RefCell` borrow is
        // held while callbacks or `Waker::wake` execute (they may touch this token again).
        let wakers = std::mem::take(&mut *self.inner.wakers.borrow_mut());
        let callbacks = std::mem::take(&mut *self.inner.callbacks.borrow_mut());

        // Wake all tasks waiting for cancellation.
        for waker in wakers.into_values() {
            waker.wake();
        }

        // Call all registered callbacks.
        for callback in callbacks {
            callback();
        }
    }

    /// Check if this source has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.inner.cancelled.get()
    }
}

impl std::fmt::Debug for CancellationTokenSource {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CancellationTokenSource")
            .field("cancelled", &self.is_cancelled())
            .finish_non_exhaustive()
    }
}

impl CancellationToken {
    /// Check if the token has been cancelled.
    pub fn is_cancelled(&self) -> bool {
        self.inner.cancelled.get()
    }

    /// Synchronously check cancellation and return `Err(Cancelled)` if cancelled.
    pub fn check_cancelled(&self) -> Result<(), Cancelled> {
        if self.is_cancelled() {
            Err(Cancelled)
        } else {
            Ok(())
        }
    }

    /// Returns a `Future` that completes when the token is cancelled.
    ///
    /// # Example
    ///
    /// ```rust
    /// use async_cancellation_token::CancellationTokenSource;
    /// use futures::{FutureExt, executor::LocalPool, task::LocalSpawnExt};
    ///
    /// let cts = CancellationTokenSource::new();
    /// let token = cts.token();
    ///
    /// let mut pool = LocalPool::new();
    /// pool.spawner().spawn_local(async move {
    ///     token.cancelled().await;
    ///     println!("Cancelled!");
    /// }.map(|_| ())).unwrap();
    ///
    /// cts.cancel();
    /// pool.run();
    /// ```
    pub fn cancelled(&self) -> CancelledFuture {
        CancelledFuture {
            token: self.clone(),
            waker_id: None,
        }
    }

    /// Register a callback to run when the token is cancelled.
    ///
    /// If the token is already cancelled, the callback is called immediately.
    ///
    /// # Example
    ///
    /// ```rust
    /// use std::{cell::Cell, rc::Rc};
    /// use async_cancellation_token::CancellationTokenSource;
    ///
    /// let cts = CancellationTokenSource::new();
    /// let token = cts.token();
    ///
    /// let flag = Rc::new(Cell::new(false));
    /// let flag_clone = Rc::clone(&flag);
    ///
    /// token.register(move || {
    ///     flag_clone.set(true);
    /// });
    ///
    /// cts.cancel();
    /// assert!(flag.get());
    /// ```
    pub fn register(&self, f: impl FnOnce() + 'static) {
        if self.is_cancelled() {
            f();
        } else {
            self.inner.callbacks.borrow_mut().push(Box::new(f));
        }
    }
}

impl std::fmt::Debug for CancellationToken {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CancellationToken")
            .field("cancelled", &self.is_cancelled())
            .finish_non_exhaustive()
    }
}

/// Future that completes when a `CancellationToken` is cancelled.
///
/// The future registers its task's waker with the token the first time it is polled while
/// pending, and unregisters it when dropped. That makes it cheap to create one per loop
/// iteration (e.g. inside `select!`) without accumulating wakers on the token.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct CancelledFuture {
    token: CancellationToken,
    /// This future's key in `Inner::wakers`, assigned on the first pending poll.
    waker_id: Option<u64>,
}

impl Future for CancelledFuture {
    type Output = ();

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // All fields are `Unpin`, so working through a plain `&mut` is fine.
        let this = self.get_mut();
        let inner = &this.token.inner;
        if inner.cancelled.get() {
            return Poll::Ready(());
        }

        let id = *this.waker_id.get_or_insert_with(|| {
            let id = inner.next_waker_id.get();
            // Wrapping is purely theoretical (2^64 registrations); it avoids a debug-mode panic.
            inner.next_waker_id.set(id.wrapping_add(1));
            id
        });

        // Store the current waker unless an equivalent one is already registered
        // (the common case when the same task polls this future again).
        let waker = cx.waker();
        let mut wakers = inner.wakers.borrow_mut();
        if !matches!(wakers.get(&id), Some(stored) if stored.will_wake(waker)) {
            wakers.insert(id, waker.clone());
        }
        Poll::Pending
    }
}

impl Drop for CancelledFuture {
    fn drop(&mut self) {
        if let Some(id) = self.waker_id.take() {
            // `try_borrow_mut` keeps `drop` from panicking if the map happens to be borrowed
            // (only possible through unusual reentrant waker code). In that case the entry is
            // simply left in place until cancellation clears the map.
            if let Ok(mut wakers) = self.token.inner.wakers.try_borrow_mut() {
                wakers.remove(&id);
            }
        }
    }
}

impl std::fmt::Debug for CancelledFuture {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("CancelledFuture")
            .field("cancelled", &self.token.is_cancelled())
            .finish_non_exhaustive()
    }
}
261
#[cfg(test)]
mod tests {
    use std::time::Duration;

    use futures::{FutureExt, executor::LocalPool, pin_mut, select, task::LocalSpawnExt};
    use futures_timer::Delay;

    use super::*;

    // Timings are short so the test finishes in a fraction of a second. Each task's full
    // 10-step run (10 * STEP_A = 400 ms, 10 * STEP_B = 600 ms) is still far longer than
    // CANCEL_AFTER, so both tasks must observe cancellation before running to completion.
    const STEP_A: Duration = Duration::from_millis(40);
    const STEP_B: Duration = Duration::from_millis(60);
    const CANCEL_AFTER: Duration = Duration::from_millis(150);

    #[test]
    fn cancel_two_tasks() {
        let cancelled_a = Rc::new(Cell::new(false));
        let cancelled_b = Rc::new(Cell::new(false));

        // Task A races every step against `cancelled()` and reacts immediately.
        let task_a = |token: CancellationToken| {
            let cancelled_a = Rc::clone(&cancelled_a);

            async move {
                println!("Task A started");

                for i in 1..=10 {
                    let delay = Delay::new(STEP_A).fuse();
                    let cancelled = token.cancelled().fuse();

                    pin_mut!(delay, cancelled);

                    select! {
                        _ = delay => {
                            println!("Task A step {i}");
                        },
                        _ = cancelled => {
                            println!("Task A detected cancellation, cleaning up...");
                            // Cleanup && Dispose
                            cancelled_a.set(true);
                            break;
                        },
                    }
                }

                println!("Task A finished");
            }
        };

        // Task B only checks the token synchronously between steps.
        let task_b = |token: CancellationToken| {
            let cancelled_b = Rc::clone(&cancelled_b);

            async move {
                println!("Task B started");

                for i in 1..=10 {
                    Delay::new(STEP_B).await;

                    println!("Task B step {i}");
                    if token.check_cancelled().is_err() {
                        println!("Task B noticed cancellation after step {i}");
                        // Cleanup && Dispose
                        cancelled_b.set(true);
                        break;
                    }
                }

                println!("Task B finished");
            }
        };

        let cts = CancellationTokenSource::new();

        let mut pool = LocalPool::new();
        let spawner = pool.spawner();

        // The async blocks already produce `()`, so they can be spawned directly.
        spawner.spawn_local(task_a(cts.token())).unwrap();
        spawner.spawn_local(task_b(cts.token())).unwrap();

        let canceller = cts.clone();
        spawner
            .spawn_local(async move {
                Delay::new(CANCEL_AFTER).await;
                println!("Cancelling all tasks!");
                canceller.cancel();
            })
            .unwrap();

        pool.run();

        assert!(cts.is_cancelled());
        assert!(cancelled_a.get(), "task A should have observed cancellation");
        assert!(cancelled_b.get(), "task B should have observed cancellation");
    }

    #[test]
    fn cancellation_register_callbacks() {
        let cts = CancellationTokenSource::new();
        let token = cts.token();

        let flag1 = Rc::new(Cell::new(false));
        let flag2 = Rc::new(Cell::new(false));

        // Registered before cancellation: runs when `cancel()` is called.
        {
            let flag1 = Rc::clone(&flag1);
            token.register(move || {
                flag1.set(true);
            });
        }

        assert!(!flag1.get());
        cts.cancel();
        assert!(flag1.get());

        // Registered after cancellation: runs immediately.
        {
            let flag2 = Rc::clone(&flag2);
            token.register(move || {
                flag2.set(true);
            });
        }

        assert!(flag2.get());
    }
}