smol_cancellation_token/
lib.rs

//! An asynchronously awaitable `CancellationToken`.
//! The token can be used to signal a cancellation request to one or more tasks.
3pub(crate) mod guard;
4mod tree_node;
5mod util;
6#[cfg(test)]
7mod test;
8
9use crate::util::MaybeDangling;
10use core::future::Future;
11use core::pin::Pin;
12use core::task::{Context, Poll};
13use std::sync::Arc;
14
15use event_listener::EventListener;
16use guard::DropGuard;
17use pin_project_lite::pin_project;
18
19/// A token which can be used to signal a cancellation request to one or more
20/// tasks.
21///
22/// Tasks can call [`CancellationToken::cancelled()`] in order to
23/// obtain a Future which will be resolved when cancellation is requested.
24///
25/// Cancellation can be requested through the [`CancellationToken::cancel`] method.
26///
27/// # Examples
28///
29/// ```no_run
30/// use tokio::select;
31/// use tokio_util::sync::CancellationToken;
32///
33/// #[tokio::main]
34/// async fn main() {
35///     let token = CancellationToken::new();
36///     let cloned_token = token.clone();
37///
38///     let join_handle = tokio::spawn(async move {
39///         // Wait for either cancellation or a very long time
40///         select! {
41///             _ = cloned_token.cancelled() => {
42///                 // The token was cancelled
43///                 5
44///             }
45///             _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
46///                 99
47///             }
48///         }
49///     });
50///
51///     tokio::spawn(async move {
52///         tokio::time::sleep(std::time::Duration::from_millis(10)).await;
53///         token.cancel();
54///     });
55///
56///     assert_eq!(5, join_handle.await.unwrap());
57/// }
58/// ```
59pub struct CancellationToken {
60    inner: Arc<tree_node::TreeNode>,
61}
62
63impl std::panic::UnwindSafe for CancellationToken {}
64impl std::panic::RefUnwindSafe for CancellationToken {}
65
66pin_project! {
67    /// A Future that is resolved once the corresponding [`CancellationToken`]
68    /// is cancelled.
69    #[must_use = "futures do nothing unless polled"]
70    pub struct WaitForCancellationFuture<'a> {
71        cancellation_token: &'a CancellationToken,
72        #[pin]
73        future: EventListener<()>,
74    }
75}
76
77pin_project! {
78    /// A Future that is resolved once the corresponding [`CancellationToken`]
79    /// is cancelled.
80    ///
81    /// This is the counterpart to [`WaitForCancellationFuture`] that takes
82    /// [`CancellationToken`] by value instead of using a reference.
83    #[must_use = "futures do nothing unless polled"]
84    pub struct WaitForCancellationFutureOwned {
85        // This field internally has a reference to the cancellation token, but camouflages
86        // the relationship with `'static`. To avoid Undefined Behavior, we must ensure
87        // that the reference is only used while the cancellation token is still alive. To
88        // do that, we ensure that the future is the first field, so that it is dropped
89        // before the cancellation token.
90        //
91        // We use `MaybeDanglingFuture` here because without it, the compiler could assert
92        // the reference inside `future` to be valid even after the destructor of that
93        // field runs. (Specifically, when the `WaitForCancellationFutureOwned` is passed
94        // as an argument to a function, the reference can be asserted to be valid for the
95        // rest of that function.) To avoid that, we use `MaybeDangling` which tells the
96        // compiler that the reference stored inside it might not be valid.
97        //
98        // See <https://users.rust-lang.org/t/unsafe-code-review-semi-owning-weak-rwlock-t-guard/95706>
99        // for more info.
100        #[pin]
101        future: MaybeDangling<EventListener<()>>,
102        cancellation_token: CancellationToken,
103    }
104}
105
106// ===== impl CancellationToken =====
107
108impl core::fmt::Debug for CancellationToken {
109    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
110        f.debug_struct("CancellationToken")
111            .field("is_cancelled", &self.is_cancelled())
112            .finish()
113    }
114}
115
116impl Clone for CancellationToken {
117    /// Creates a clone of the `CancellationToken` which will get cancelled
118    /// whenever the current token gets cancelled, and vice versa.
119    fn clone(&self) -> Self {
120        tree_node::increase_handle_refcount(&self.inner);
121        CancellationToken {
122            inner: self.inner.clone(),
123        }
124    }
125}
126
127impl Drop for CancellationToken {
128    fn drop(&mut self) {
129        tree_node::decrease_handle_refcount(&self.inner);
130    }
131}
132
133impl Default for CancellationToken {
134    fn default() -> CancellationToken {
135        CancellationToken::new()
136    }
137}
138
139impl CancellationToken {
140    /// Creates a new `CancellationToken` in the non-cancelled state.
141    pub fn new() -> CancellationToken {
142        CancellationToken {
143            inner: Arc::new(tree_node::TreeNode::new()),
144        }
145    }
146
147    /// Creates a `CancellationToken` which will get cancelled whenever the
148    /// current token gets cancelled. Unlike a cloned `CancellationToken`,
149    /// cancelling a child token does not cancel the parent token.
150    ///
151    /// If the current token is already cancelled, the child token will get
152    /// returned in cancelled state.
153    ///
154    /// # Examples
155    ///
156    /// ```no_run
157    /// use tokio::select;
158    /// use tokio_util::sync::CancellationToken;
159    ///
160    /// #[tokio::main]
161    /// async fn main() {
162    ///     let token = CancellationToken::new();
163    ///     let child_token = token.child_token();
164    ///
165    ///     let join_handle = tokio::spawn(async move {
166    ///         // Wait for either cancellation or a very long time
167    ///         select! {
168    ///             _ = child_token.cancelled() => {
169    ///                 // The token was cancelled
170    ///                 5
171    ///             }
172    ///             _ = tokio::time::sleep(std::time::Duration::from_secs(9999)) => {
173    ///                 99
174    ///             }
175    ///         }
176    ///     });
177    ///
178    ///     tokio::spawn(async move {
179    ///         tokio::time::sleep(std::time::Duration::from_millis(10)).await;
180    ///         token.cancel();
181    ///     });
182    ///
183    ///     assert_eq!(5, join_handle.await.unwrap());
184    /// }
185    /// ```
186    pub fn child_token(&self) -> CancellationToken {
187        CancellationToken {
188            inner: tree_node::child_node(&self.inner),
189        }
190    }
191
192    /// Cancel the [`CancellationToken`] and all child tokens which had been
193    /// derived from it.
194    ///
195    /// This will wake up all tasks which are waiting for cancellation.
196    ///
197    /// Be aware that cancellation is not an atomic operation. It is possible
198    /// for another thread running in parallel with a call to `cancel` to first
199    /// receive `true` from `is_cancelled` on one child node, and then receive
200    /// `false` from `is_cancelled` on another child node. However, once the
201    /// call to `cancel` returns, all child nodes have been fully cancelled.
202    pub fn cancel(&self) {
203        tree_node::cancel(&self.inner);
204    }
205
206    /// Returns `true` if the `CancellationToken` is cancelled.
207    pub fn is_cancelled(&self) -> bool {
208        tree_node::is_cancelled(&self.inner)
209    }
210
211    /// Returns a `Future` that gets fulfilled when cancellation is requested.
212    ///
213    /// The future will complete immediately if the token is already cancelled
214    /// when this method is called.
215    ///
216    /// # Cancel safety
217    ///
218    /// This method is cancel safe.
219    pub fn cancelled(&self) -> WaitForCancellationFuture<'_> {
220        WaitForCancellationFuture {
221            cancellation_token: self,
222            future: self.inner.notified(),
223        }
224    }
225
226    /// Returns a `Future` that gets fulfilled when cancellation is requested.
227    ///
228    /// The future will complete immediately if the token is already cancelled
229    /// when this method is called.
230    ///
231    /// The function takes self by value and returns a future that owns the
232    /// token.
233    ///
234    /// # Cancel safety
235    ///
236    /// This method is cancel safe.
237    pub fn cancelled_owned(self) -> WaitForCancellationFutureOwned {
238        WaitForCancellationFutureOwned::new(self)
239    }
240
241    /// Creates a `DropGuard` for this token.
242    ///
243    /// Returned guard will cancel this token (and all its children) on drop
244    /// unless disarmed.
245    pub fn drop_guard(self) -> DropGuard {
246        DropGuard { inner: Some(self) }
247    }
248
249    /// Runs a future to completion and returns its result wrapped inside of an `Option`
250    /// unless the `CancellationToken` is cancelled. In that case the function returns
251    /// `None` and the future gets dropped.
252    ///
253    /// # Cancel safety
254    ///
255    /// This method is only cancel safe if `fut` is cancel safe.
256    pub async fn run_until_cancelled<F>(&self, fut: F) -> Option<F::Output>
257    where
258        F: Future,
259    {
260        pin_project! {
261            /// A Future that is resolved once the corresponding [`CancellationToken`]
262            /// is cancelled or a given Future gets resolved. It is biased towards the
263            /// Future completion.
264            #[must_use = "futures do nothing unless polled"]
265            struct RunUntilCancelledFuture<'a, F: Future> {
266                #[pin]
267                cancellation: WaitForCancellationFuture<'a>,
268                #[pin]
269                future: F,
270            }
271        }
272
273        impl<'a, F: Future> Future for RunUntilCancelledFuture<'a, F> {
274            type Output = Option<F::Output>;
275
276            fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
277                let this = self.project();
278                if let Poll::Ready(res) = this.future.poll(cx) {
279                    Poll::Ready(Some(res))
280                } else if this.cancellation.poll(cx).is_ready() {
281                    Poll::Ready(None)
282                } else {
283                    Poll::Pending
284                }
285            }
286        }
287
288        RunUntilCancelledFuture {
289            cancellation: self.cancelled(),
290            future: fut,
291        }
292        .await
293    }
294
295    /// Runs a future to completion and returns its result wrapped inside of an `Option`
296    /// unless the `CancellationToken` is cancelled. In that case the function returns
297    /// `None` and the future gets dropped.
298    ///
299    /// The function takes self by value and returns a future that owns the token.
300    ///
301    /// # Cancel safety
302    ///
303    /// This method is only cancel safe if `fut` is cancel safe.
304    pub async fn run_until_cancelled_owned<F>(self, fut: F) -> Option<F::Output>
305    where
306        F: Future,
307    {
308        self.run_until_cancelled(fut).await
309    }
310}
311
312// ===== impl WaitForCancellationFuture =====
313
314impl<'a> core::fmt::Debug for WaitForCancellationFuture<'a> {
315    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
316        f.debug_struct("WaitForCancellationFuture").finish()
317    }
318}
319
320impl<'a> Future for WaitForCancellationFuture<'a> {
321    type Output = ();
322
323    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
324        let mut this = self.project();
325        loop {
326            if this.cancellation_token.is_cancelled() {
327                return Poll::Ready(());
328            }
329
330            // No wakeups can be lost here because there is always a call to
331            // `is_cancelled` between the creation of the future and the call to
332            // `poll`, and the code that sets the cancelled flag does so before
333            // waking the `Notified`.
334            if this.future.as_mut().poll(cx).is_pending() {
335                return Poll::Pending;
336            }
337
338            this.future.set(this.cancellation_token.inner.notified());
339        }
340    }
341}
342
343// ===== impl WaitForCancellationFutureOwned =====
344
345impl core::fmt::Debug for WaitForCancellationFutureOwned {
346    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
347        f.debug_struct("WaitForCancellationFutureOwned").finish()
348    }
349}
350
351impl WaitForCancellationFutureOwned {
352    fn new(cancellation_token: CancellationToken) -> Self {
353        WaitForCancellationFutureOwned {
354            // cancellation_token holds a heap allocation and is guaranteed to have a
355            // stable deref, thus it would be ok to move the cancellation_token while
356            // the future holds a reference to it.
357            //
358            // # Safety
359            //
360            // cancellation_token is dropped after future due to the field ordering.
361            future: MaybeDangling::new(unsafe { Self::new_future(&cancellation_token) }),
362            cancellation_token,
363        }
364    }
365
366    /// # Safety
367    /// The returned future must be destroyed before the cancellation token is
368    /// destroyed.
369    unsafe fn new_future(cancellation_token: &CancellationToken) -> EventListener<()> {
370        let inner_ptr = Arc::as_ptr(&cancellation_token.inner);
371        // SAFETY: The `Arc::as_ptr` method guarantees that `inner_ptr` remains
372        // valid until the strong count of the Arc drops to zero, and the caller
373        // guarantees that they will drop the future before that happens.
374        unsafe { (*inner_ptr).notified() }
375    }
376}
377
378impl Future for WaitForCancellationFutureOwned {
379    type Output = ();
380
381    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
382        let mut this = self.project();
383
384        loop {
385            if this.cancellation_token.is_cancelled() {
386                return Poll::Ready(());
387            }
388
389            // No wakeups can be lost here because there is always a call to
390            // `is_cancelled` between the creation of the future and the call to
391            // `poll`, and the code that sets the cancelled flag does so before
392            // waking the `Notified`.
393            if this.future.as_mut().poll(cx).is_pending() {
394                return Poll::Pending;
395            }
396
397            // # Safety
398            //
399            // cancellation_token is dropped after future due to the field ordering.
400            this.future.set(MaybeDangling::new(unsafe {
401                Self::new_future(this.cancellation_token)
402            }));
403        }
404    }
405}