exocore_core/futures/
owned_spawn.rs1use std::{
2 pin::Pin,
3 task::{Context, Poll},
4};
5
6use futures::{
7 channel::{oneshot, oneshot::Canceled},
8 prelude::*,
9 FutureExt,
10};
11
12use super::spawn_future;
13
14pub fn owned_spawn<F, O>(fut: F) -> OwnedSpawn<O>
18where
19 F: Future<Output = O> + 'static + Send,
20 O: Send + 'static,
21{
22 let (wrapped_future, spawn) = owned_future(fut);
23 spawn_future(wrapped_future);
24 spawn
25}
26
27pub fn owned_future<F, O>(fut: F) -> (impl Future<Output = ()> + 'static + Send, OwnedSpawn<O>)
31where
32 F: Future<Output = O> + 'static + Send,
33 O: Send + 'static,
34{
35 let (owner_drop_sender, owner_drop_receiver) = oneshot::channel();
36 let (spawned_drop_sender, spawned_drop_receiver) = oneshot::channel();
37
38 let wrapped = async move {
39 let spawned_drop_sender = spawned_drop_sender;
40
41 futures::select! {
42 _ = owner_drop_receiver.fuse() => {
43 },
45 result = fut.fuse() => {
46 let _ = spawned_drop_sender.send(result);
47 },
48 };
49 };
50
51 let spawn = OwnedSpawn {
52 _owner_drop_sender: owner_drop_sender,
53 spawned_drop_receiver,
54 };
55
56 (wrapped, spawn)
57}
58
59pub struct OwnedSpawn<O>
61where
62 O: Send + 'static,
63{
64 _owner_drop_sender: oneshot::Sender<()>,
65 spawned_drop_receiver: oneshot::Receiver<O>,
66}
67
68impl<O> Future for OwnedSpawn<O>
69where
70 O: Send + 'static,
71{
72 type Output = Result<O, Canceled>;
73
74 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
75 self.spawned_drop_receiver.poll_unpin(cx)
76 }
77}
78
79pub struct OwnedSpawnSet<O>
85where
86 O: Send + 'static,
87{
88 spawns: Vec<OwnedSpawn<O>>,
89}
90
91impl<O> OwnedSpawnSet<O>
92where
93 O: Send + 'static,
94{
95 pub fn new() -> OwnedSpawnSet<O> {
96 OwnedSpawnSet { spawns: Vec::new() }
97 }
98
99 pub fn spawn<F>(&mut self, fut: F)
100 where
101 F: Future<Output = O> + 'static + Send,
102 {
103 let spawn = owned_spawn(fut);
104 self.spawns.push(spawn);
105 }
106
107 pub async fn cleanup(self) -> OwnedSpawnSet<O> {
110 let remaining_spawns = OwnedSpawnCleaner(self.spawns).await;
111 OwnedSpawnSet {
112 spawns: remaining_spawns,
113 }
114 }
115
116 pub fn len(&self) -> usize {
117 self.spawns.len()
118 }
119
120 pub fn is_empty(&self) -> bool {
121 self.spawns.is_empty()
122 }
123}
124
125impl<O> Default for OwnedSpawnSet<O>
126where
127 O: Send + 'static,
128{
129 fn default() -> Self {
130 OwnedSpawnSet::new()
131 }
132}
133
134struct OwnedSpawnCleaner<O>(Vec<OwnedSpawn<O>>)
135where
136 O: Send + 'static;
137
138impl<O> Future for OwnedSpawnCleaner<O>
139where
140 O: Send + 'static,
141{
142 type Output = Vec<OwnedSpawn<O>>;
143
144 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
145 if self.0.is_empty() {
146 return Poll::Ready(Vec::new());
147 }
148
149 let mut current_spawns = Vec::new();
150 std::mem::swap(&mut self.0, &mut current_spawns);
151
152 let mut remaining_spawns = Vec::new();
153 for mut spawn in current_spawns {
154 let polled = spawn.poll_unpin(cx);
155 if polled.is_pending() {
156 remaining_spawns.push(spawn);
157 }
158 }
159
160 Poll::Ready(remaining_spawns)
161 }
162}
163
#[cfg(test)]
mod tests {
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        time::Duration,
    };

    use super::{super::sleep, *};

    /// Awaiting the handle yields the spawned future's output.
    #[tokio::test]
    async fn propagate_spawned_result() -> anyhow::Result<()> {
        let output = owned_spawn(async move { 1 + 1 }).await?;
        assert_eq!(2, output);

        Ok(())
    }

    /// Dropping the handle tears down the still-running spawned future.
    #[tokio::test]
    async fn owner_drop_cancels_spawned() -> anyhow::Result<()> {
        let guard = Dropper::default();
        let guard_dropped = Arc::clone(&guard.dropped);

        let handle = owned_spawn(async move {
            sleep(Duration::from_secs(3600)).await;
            drop(guard);
            Result::<(), ()>::Ok(())
        });

        // While the handle is alive, the spawned future (and its guard) survive.
        sleep(Duration::from_millis(100)).await;
        assert!(!guard_dropped.load(Ordering::SeqCst));

        // Dropping the handle must drop the spawned future, and the guard with it.
        drop(handle);
        sleep(Duration::from_millis(100)).await;
        assert!(guard_dropped.load(Ordering::SeqCst));

        Ok(())
    }

    /// `cleanup` prunes completed spawns; dropping the set cancels the rest.
    #[tokio::test]
    async fn spawn_set_cleanup() -> anyhow::Result<()> {
        let mut set = OwnedSpawnSet::<i32>::new();

        // Cleaning an empty set is a no-op.
        set = set.cleanup().await;

        // A future that completes immediately is pruned by the next cleanup.
        set.spawn(async { 1 + 1 });
        assert_eq!(1, set.len());
        sleep(Duration::from_millis(100)).await;
        set = set.cleanup().await;
        assert_eq!(0, set.len());

        // A long-running future survives cleanup...
        let guard = Dropper::default();
        let guard_dropped = Arc::clone(&guard.dropped);
        set.spawn(async move {
            sleep(Duration::from_secs(3600)).await;
            drop(guard);
            1 + 1
        });
        set = set.cleanup().await;
        assert_eq!(1, set.len());

        // ...but is cancelled once the set itself is dropped.
        drop(set);
        sleep(Duration::from_millis(100)).await;
        assert!(guard_dropped.load(Ordering::SeqCst));

        Ok(())
    }

    /// Sets the shared `dropped` flag when dropped, so a test can observe
    /// whether the future that owned it was torn down.
    #[derive(Default)]
    struct Dropper {
        dropped: Arc<AtomicBool>,
    }

    impl Drop for Dropper {
        fn drop(&mut self) {
            self.dropped.store(true, Ordering::SeqCst);
        }
    }
}