futures_ext/future/on_cancel.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under both the MIT license found in the
 * LICENSE-MIT file in the root directory of this source tree and the Apache
 * License, Version 2.0 found in the LICENSE-APACHE file in the root directory
 * of this source tree.
 */
9
10use std::pin::Pin;
11
12use futures::future::Future;
13use futures::ready;
14use futures::task::Context;
15use futures::task::Poll;
16use pin_project::pin_project;
17use pin_project::pinned_drop;
18
19/// Future combinator that executes the `on_cancel` closure if the inner future
20/// is canceled (dropped before completion).
21#[pin_project(PinnedDrop)]
22pub struct OnCancel<Fut, OnCancelFn>
23where
24    Fut: Future,
25    OnCancelFn: FnOnce(),
26{
27    #[pin]
28    inner: Fut,
29
30    on_cancel: Option<OnCancelFn>,
31}
32
33impl<Fut, OnCancelFn> OnCancel<Fut, OnCancelFn>
34where
35    Fut: Future,
36    OnCancelFn: FnOnce(),
37{
38    /// Construct an `OnCancel` combinator that will run `on_cancel` if `inner`
39    /// is canceled.
40    pub fn new(inner: Fut, on_cancel: OnCancelFn) -> Self {
41        Self {
42            inner,
43            on_cancel: Some(on_cancel),
44        }
45    }
46}
47
48impl<Fut, OnCancelFn> Future for OnCancel<Fut, OnCancelFn>
49where
50    Fut: Future,
51    OnCancelFn: FnOnce(),
52{
53    type Output = Fut::Output;
54
55    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
56        let this = self.project();
57        let v = ready!(this.inner.poll(cx));
58        *this.on_cancel = None;
59        Poll::Ready(v)
60    }
61}
62
63#[pinned_drop]
64impl<Fut, OnCancelFn> PinnedDrop for OnCancel<Fut, OnCancelFn>
65where
66    Fut: Future,
67    OnCancelFn: FnOnce(),
68{
69    fn drop(self: Pin<&mut Self>) {
70        let this = self.project();
71        if let Some(on_cancel) = this.on_cancel.take() {
72            on_cancel()
73        }
74    }
75}
76
#[cfg(test)]
mod test {
    use std::sync::atomic::AtomicBool;
    use std::sync::atomic::Ordering;

    use futures::future::FutureExt;

    use super::*;

    /// Dropping a combinator that was never polled counts as a cancellation.
    #[tokio::test]
    async fn runs_when_canceled() {
        let canceled = AtomicBool::new(false);
        let fut = OnCancel::new(async {}, || canceled.store(true, Ordering::Relaxed));
        drop(fut);
        assert!(canceled.load(Ordering::Relaxed));
    }

    /// The common real-world case: the inner future has started (polled and
    /// returned `Pending`) and is then dropped, e.g. by a timeout or `select!`.
    #[test]
    fn runs_when_canceled_while_pending() {
        let canceled = AtomicBool::new(false);
        let mut fut = Box::pin(OnCancel::new(futures::future::pending::<()>(), || {
            canceled.store(true, Ordering::Relaxed)
        }));
        // Poll exactly once; `pending()` never resolves.
        assert!(fut.as_mut().now_or_never().is_none());
        // Being pending is not a cancellation by itself.
        assert!(!canceled.load(Ordering::Relaxed));
        drop(fut);
        assert!(canceled.load(Ordering::Relaxed));
    }

    /// Completing the inner future disarms the callback.
    #[tokio::test]
    async fn doesnt_run_when_complete() {
        let canceled = AtomicBool::new(false);
        let fut = OnCancel::new(async {}, || canceled.store(true, Ordering::Relaxed));
        fut.await;
        assert!(!canceled.load(Ordering::Relaxed));
    }

    /// The combinator yields exactly what the inner future yields.
    #[tokio::test]
    async fn forwards_output_of_inner_future() {
        let canceled = AtomicBool::new(false);
        let fut = OnCancel::new(async { 42 }, || canceled.store(true, Ordering::Relaxed));
        assert_eq!(fut.await, 42);
        assert!(!canceled.load(Ordering::Relaxed));
    }
}