futures_ext/future/
on_cancel_with_data.rs

1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 *
4 * This source code is licensed under both the MIT license found in the
5 * LICENSE-MIT file in the root directory of this source tree and the Apache
6 * License, Version 2.0 found in the LICENSE-APACHE file in the root directory
7 * of this source tree.
8 */
9
10use std::pin::Pin;
11
12use futures::future::Future;
13use futures::ready;
14use futures::task::Context;
15use futures::task::Poll;
16use pin_project::pin_project;
17use pin_project::pinned_drop;
18
/// Lets a future report a value describing its state at the moment it is
/// cancelled.
///
/// [`OnCancelWithData`] asks for this value when it is dropped while its
/// cancellation callback is still armed, and passes the result to that
/// callback.
pub trait CancelData {
    /// Value handed to the cancellation callback.
    type Data;

    /// Returns the data describing this future at cancellation time.
    fn cancel_data(&self) -> Self::Data;
}
28
29/// Future combinator that executes the `on_cancel` closure if the inner future
30/// is canceled (dropped before completion).
31#[pin_project(PinnedDrop)]
32pub struct OnCancelWithData<Fut, OnCancelFn>
33where
34    Fut: Future + CancelData,
35    OnCancelFn: FnOnce(Fut::Data),
36{
37    #[pin]
38    inner: Fut,
39
40    on_cancel: Option<OnCancelFn>,
41}
42
43impl<Fut, OnCancelFn> OnCancelWithData<Fut, OnCancelFn>
44where
45    Fut: Future + CancelData,
46    OnCancelFn: FnOnce(Fut::Data),
47{
48    /// Construct an `OnCancelWithData` combinator that will run `on_cancel` if `inner`
49    /// is canceled.  Additional data will be extracted from `inner` and
50    /// passed to `on_cancel`.
51    pub fn new(inner: Fut, on_cancel: OnCancelFn) -> Self {
52        Self {
53            inner,
54            on_cancel: Some(on_cancel),
55        }
56    }
57}
58
59impl<Fut, OnCancelFn> Future for OnCancelWithData<Fut, OnCancelFn>
60where
61    Fut: Future + CancelData,
62    OnCancelFn: FnOnce(Fut::Data),
63{
64    type Output = Fut::Output;
65
66    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
67        let this = self.project();
68        let v = ready!(this.inner.poll(cx));
69        *this.on_cancel = None;
70        Poll::Ready(v)
71    }
72}
73
74#[pinned_drop]
75impl<Fut, OnCancelFn> PinnedDrop for OnCancelWithData<Fut, OnCancelFn>
76where
77    Fut: Future + CancelData,
78    OnCancelFn: FnOnce(Fut::Data),
79{
80    fn drop(self: Pin<&mut Self>) {
81        let this = self.project();
82        if let Some(on_cancel) = this.on_cancel.take() {
83            let data = this.inner.as_ref().get_ref().cancel_data();
84            on_cancel(data)
85        }
86    }
87}
88
#[cfg(test)]
mod test {
    use std::sync::atomic::AtomicUsize;
    use std::sync::atomic::Ordering;

    use futures::FutureExt;

    use super::*;

    /// Test future that is immediately ready with `result` and reports `data`
    /// as its cancellation data.
    struct WithCancelData {
        result: usize,
        data: usize,
    }

    impl Future for WithCancelData {
        type Output = usize;

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            Poll::Ready(self.as_ref().get_ref().result)
        }
    }

    impl CancelData for WithCancelData {
        type Data = usize;

        fn cancel_data(&self) -> Self::Data {
            self.data
        }
    }

    /// Test future that never completes and reports `data` as its
    /// cancellation data.
    struct NeverReady {
        data: usize,
    }

    impl Future for NeverReady {
        type Output = usize;

        fn poll(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Self::Output> {
            Poll::Pending
        }
    }

    impl CancelData for NeverReady {
        type Data = usize;

        fn cancel_data(&self) -> Self::Data {
            self.data
        }
    }

    #[tokio::test]
    async fn runs_when_canceled() {
        let canceled = AtomicUsize::new(0);
        let fut = WithCancelData {
            result: 100,
            data: 200,
        };
        let fut = OnCancelWithData::new(fut, |data| canceled.store(data, Ordering::Relaxed));
        drop(fut);
        assert_eq!(canceled.load(Ordering::Relaxed), 200);
    }

    #[tokio::test]
    async fn doesnt_run_when_complete() {
        let canceled = AtomicUsize::new(0);
        let fut = WithCancelData {
            result: 100,
            data: 200,
        };
        let fut = OnCancelWithData::new(fut, |data| canceled.store(data, Ordering::Relaxed));
        let val = fut.await;
        assert_eq!(val, 100);
        assert_eq!(canceled.load(Ordering::Relaxed), 0);
    }

    #[test]
    fn runs_when_canceled_after_being_polled() {
        let canceled = AtomicUsize::new(0);
        let fut = OnCancelWithData::new(NeverReady { data: 300 }, |data| {
            canceled.store(data, Ordering::Relaxed)
        });
        // `now_or_never` polls exactly once and, because the future is still
        // pending, drops it: a cancellation after the future has started.
        assert_eq!(fut.now_or_never(), None);
        assert_eq!(canceled.load(Ordering::Relaxed), 300);
    }
}