exocore_core/futures/
owned_spawn.rs

1use std::{
2    pin::Pin,
3    task::{Context, Poll},
4};
5
6use futures::{
7    channel::{oneshot, oneshot::Canceled},
8    prelude::*,
9    FutureExt,
10};
11
12use super::spawn_future;
13
14/// Spawns a future on current executor that can be cancelled by dropping the
15/// `OwnedSpawn` handle. It is also possible to get the result of the spawned
16/// future by awaiting on the handle.
17pub fn owned_spawn<F, O>(fut: F) -> OwnedSpawn<O>
18where
19    F: Future<Output = O> + 'static + Send,
20    O: Send + 'static,
21{
22    let (wrapped_future, spawn) = owned_future(fut);
23    spawn_future(wrapped_future);
24    spawn
25}
26
27/// Wraps a future that can be cancelled by dropping the `OwnedSpawn` handle.
28/// It is also possible to get the result of the spawned future by awaiting on
29/// the handle.
30pub fn owned_future<F, O>(fut: F) -> (impl Future<Output = ()> + 'static + Send, OwnedSpawn<O>)
31where
32    F: Future<Output = O> + 'static + Send,
33    O: Send + 'static,
34{
35    let (owner_drop_sender, owner_drop_receiver) = oneshot::channel();
36    let (spawned_drop_sender, spawned_drop_receiver) = oneshot::channel();
37
38    let wrapped = async move {
39        let spawned_drop_sender = spawned_drop_sender;
40
41        futures::select! {
42            _ = owner_drop_receiver.fuse() => {
43                // owner got dropped, doing nothing
44            },
45            result = fut.fuse() => {
46                let _ = spawned_drop_sender.send(result);
47            },
48        };
49    };
50
51    let spawn = OwnedSpawn {
52        _owner_drop_sender: owner_drop_sender,
53        spawned_drop_receiver,
54    };
55
56    (wrapped, spawn)
57}
58
59/// Result of `owned_spawn` or `owned_future` function.
60pub struct OwnedSpawn<O>
61where
62    O: Send + 'static,
63{
64    _owner_drop_sender: oneshot::Sender<()>,
65    spawned_drop_receiver: oneshot::Receiver<O>,
66}
67
68impl<O> Future for OwnedSpawn<O>
69where
70    O: Send + 'static,
71{
72    type Output = Result<O, Canceled>;
73
74    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75        self.spawned_drop_receiver.poll_unpin(cx)
76    }
77}
78
79/// Collection of `OwnedSpawn` that allow keeping ownership over spawned futures
80/// and manage their completion.
81///
82/// Caution: The `cleanup` method needs to be called in order to cleanup
83/// completed spawns.
84pub struct OwnedSpawnSet<O>
85where
86    O: Send + 'static,
87{
88    spawns: Vec<OwnedSpawn<O>>,
89}
90
91impl<O> OwnedSpawnSet<O>
92where
93    O: Send + 'static,
94{
95    pub fn new() -> OwnedSpawnSet<O> {
96        OwnedSpawnSet { spawns: Vec::new() }
97    }
98
99    pub fn spawn<F>(&mut self, fut: F)
100    where
101        F: Future<Output = O> + 'static + Send,
102    {
103        let spawn = owned_spawn(fut);
104        self.spawns.push(spawn);
105    }
106
107    /// Cleans up the completed spawns and return a new set with remaining
108    /// spawns.
109    pub async fn cleanup(self) -> OwnedSpawnSet<O> {
110        let remaining_spawns = OwnedSpawnCleaner(self.spawns).await;
111        OwnedSpawnSet {
112            spawns: remaining_spawns,
113        }
114    }
115
116    pub fn len(&self) -> usize {
117        self.spawns.len()
118    }
119
120    pub fn is_empty(&self) -> bool {
121        self.spawns.is_empty()
122    }
123}
124
125impl<O> Default for OwnedSpawnSet<O>
126where
127    O: Send + 'static,
128{
129    fn default() -> Self {
130        OwnedSpawnSet::new()
131    }
132}
133
134struct OwnedSpawnCleaner<O>(Vec<OwnedSpawn<O>>)
135where
136    O: Send + 'static;
137
138impl<O> Future for OwnedSpawnCleaner<O>
139where
140    O: Send + 'static,
141{
142    type Output = Vec<OwnedSpawn<O>>;
143
144    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
145        if self.0.is_empty() {
146            return Poll::Ready(Vec::new());
147        }
148
149        let mut current_spawns = Vec::new();
150        std::mem::swap(&mut self.0, &mut current_spawns);
151
152        let mut remaining_spawns = Vec::new();
153        for mut spawn in current_spawns {
154            let polled = spawn.poll_unpin(cx);
155            if polled.is_pending() {
156                remaining_spawns.push(spawn);
157            }
158        }
159
160        Poll::Ready(remaining_spawns)
161    }
162}
163
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        time::Duration,
    };

    use super::{super::sleep, *};

    /// Test helper whose shared `dropped` flag flips to `true` once the
    /// value is dropped.
    #[derive(Default)]
    struct Dropper {
        dropped: Arc<AtomicBool>,
    }

    impl Drop for Dropper {
        fn drop(&mut self) {
            self.dropped.store(true, Ordering::SeqCst)
        }
    }

    /// Awaiting the handle yields the output of the spawned future.
    #[tokio::test]
    async fn propagate_spawned_result() -> anyhow::Result<()> {
        let handle = owned_spawn(async move { 1 + 1 });
        let output = handle.await?;
        assert_eq!(2, output);

        Ok(())
    }

    /// Dropping the handle drops the spawned future, and with it everything
    /// the future owns, long before the future would complete by itself.
    #[tokio::test]
    async fn owner_drop_cancels_spawned() -> anyhow::Result<()> {
        let pause = Duration::from_millis(100);

        let dropper = Dropper::default();
        let dropped = Arc::clone(&dropper.dropped);

        let handle = owned_spawn(async move {
            sleep(Duration::from_secs(3600)).await;
            drop(dropper);
            Ok::<(), ()>(())
        });

        // The future is still sleeping while its handle is alive.
        sleep(pause).await;
        assert!(!dropped.load(Ordering::SeqCst));

        drop(handle);

        sleep(pause).await;
        assert!(dropped.load(Ordering::SeqCst));

        Ok(())
    }

    /// `cleanup` removes completed spawns, keeps pending ones, and dropping
    /// the set cancels whatever is left.
    #[tokio::test]
    async fn spawn_set_cleanup() -> anyhow::Result<()> {
        let pause = Duration::from_millis(100);

        let mut set = OwnedSpawnSet::<i32>::new();

        // Cleaning up an empty set is accepted.
        set = set.cleanup().await;

        set.spawn(async { 1 + 1 });
        assert_eq!(1, set.len());

        // Once the spawn had time to complete, cleanup removes it.
        sleep(pause).await;
        set = set.cleanup().await;
        assert!(set.is_empty());

        let dropper = Dropper::default();
        let dropped = Arc::clone(&dropper.dropped);
        set.spawn(async move {
            sleep(Duration::from_secs(3600)).await;
            drop(dropper);
            1 + 1
        });

        // A spawn that is still running survives cleanup.
        set = set.cleanup().await;
        assert_eq!(1, set.len());

        drop(set);

        sleep(pause).await;
        assert!(dropped.load(Ordering::SeqCst));

        Ok(())
    }
}