1use std::cell::Cell;
2use std::pin::Pin;
3use std::sync::atomic::{self, AtomicUsize};
4use std::sync::{Arc, Mutex};
5use std::task::{Context, Poll};
6
7use futures::Future;
8use futures::task::{FutureObj, LocalFutureObj, LocalSpawn, LocalSpawnExt, Spawn, SpawnError};
9use pin_project::{pin_project, pinned_drop};
10
11use super::relay_pad::{RelayPad, TaskDequeueErr};
12
/// Strategy for putting another relay future onto an executor.
///
/// Implemented in this file by `GlobalRespawn` (wrapping a `Spawn` executor) and
/// `LocalRespawn` (wrapping a `LocalSpawn` executor).
trait Respawn<'sc>: 'sc {
    /// Builds a new relay future over `pad` that shares `manager` and has the given
    /// `root` flag, then spawns it. The implementations here silently ignore
    /// failures (pad refused registration, or the executor rejected the spawn).
    fn respawn(&self, pad: Arc<RelayPad<'sc>>, manager: Arc<SpawnManager>, root: bool);
}
16
17#[derive(Clone)]
18pub struct GlobalRespawn<Sp>(Sp);
19
20impl<'sc, Sp: Spawn + Clone + Send + 'sc> Respawn<'sc> for GlobalRespawn<Sp> {
21 fn respawn(&self, pad: Arc<RelayPad<'sc>>, manager: Arc<SpawnManager>, root: bool) {
22 let fut = unsafe { RelayFuture::new_global_raw(pad, self.clone(), root, manager) };
23 if let Some(fut) = fut {
24 self.0.spawn_obj(fut).ok();
25 }
26 }
27}
28
29#[derive(Clone)]
30pub struct LocalRespawn<Sp>(Sp);
31
32impl<'sc, Sp: LocalSpawn + Clone + 'sc> Respawn<'sc> for LocalRespawn<Sp> {
33 fn respawn(&self, pad: Arc<RelayPad<'sc>>, manager: Arc<SpawnManager>, root: bool) {
34 let fut = unsafe { RelayFuture::new_local_raw(pad, self.clone(), root, manager) };
35 if let Some(fut) = fut {
36 self.0.spawn_local(fut).ok();
37 }
38 }
39}
40
/// Identity of one relay future: the id of the `SpawnManager` it was registered
/// with plus an instance number handed out by that manager.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RelayFutureId {
    /// The `id` of the `SpawnManager` (the `spawn_id` given to `spawn_on_*`).
    spawn_id: usize,
    /// Value taken from `SpawnManager::next_instance` at registration; distinct per
    /// registration within one manager (until the counter wraps).
    instance: usize,
}
46
47#[derive(Debug)]
48struct SpawnManager {
49 non_working: AtomicUsize,
50 all: AtomicUsize,
51 id: usize,
52 next_instance: AtomicUsize,
53}
54
55impl SpawnManager {
56 fn new(id: usize) -> Self {
57 Self {
58 non_working: AtomicUsize::new(0),
59 all: AtomicUsize::new(0),
60 id,
61 next_instance: AtomicUsize::new(0),
62 }
63 }
64
65 fn register(&self) -> RelayFutureId {
66 self.non_working.fetch_add(1, atomic::Ordering::Relaxed);
67 self.all.fetch_add(1, atomic::Ordering::Relaxed);
68
69 RelayFutureId {
70 spawn_id: self.id,
71 instance: self.next_instance.fetch_add(1, atomic::Ordering::Relaxed),
72 }
73 }
74
75 fn unregister(&self) {
76 self.non_working.fetch_sub(1, atomic::Ordering::Relaxed);
77 self.all.fetch_sub(1, atomic::Ordering::Relaxed);
78 }
79
80 fn start_polling(&self) -> RespawnCounterPollingGuard<'_> {
81 let non_working = self.non_working.fetch_sub(1, atomic::Ordering::Relaxed) - 1;
82 let should_respawn = non_working < 5;
83 RespawnCounterPollingGuard(self, should_respawn)
84 }
85}
86
87#[derive(Debug)]
88struct RespawnCounterPollingGuard<'c>(&'c SpawnManager, bool);
89
90impl<'c> RespawnCounterPollingGuard<'c> {
91 fn should_respawn(&self) -> bool {
92 self.1
93 }
94}
95
96impl<'c> Drop for RespawnCounterPollingGuard<'c> {
97 fn drop(&mut self) {
98 self.0.non_working.fetch_add(1, atomic::Ordering::Relaxed);
99 }
100}
101
102#[derive(Debug)]
103struct Unpinned<'sc, Sp: 'sc> {
104 pad: Arc<RelayPad<'sc>>,
105 panicked: Cell<bool>,
106 root: bool,
107 spawn: Sp,
108 manager: Arc<SpawnManager>,
109}
110
111impl<'sc, Sp: Respawn<'sc>> Unpinned<'sc, Sp> {
112 fn respawn(&self, root: bool) {
113 self.spawn.respawn(self.pad.clone(), self.manager.clone(), root);
115 }
116}
117
118#[derive(Debug)]
119struct ActiveFuture<'sc> {
120 future: Option<FutureObj<'sc, ()>>,
121 destroyed: bool,
122}
123
124impl<'sc> ActiveFuture<'sc> {
125 fn new() -> Self {
126 Self {
127 future: None,
128 destroyed: false,
129 }
130 }
131}
132
/// State of one relay future that is registered with its `RelayPad`
/// (`RelayPad::register_relay_future` returns it wrapped in an `Arc`).
///
/// `RelayFuture::poll` keeps `active` locked for its whole duration, including
/// while the current task is polled, and `destroy` takes the same lock.
#[pin_project]
#[derive(Debug)]
pub struct RelayFutureInner<'sc> {
    // NOTE(review): `active` is marked `#[pin]`, but nothing visible uses the pinned
    // projection; all access goes through `&self` and `Mutex::lock`.
    #[pin]
    active: Mutex<ActiveFuture<'sc>>,
    /// Identity passed to the pad when unregistering.
    id: RelayFutureId,
}
140
141impl<'sc> RelayFutureInner<'sc> {
142 pub fn destroy(&self, pad: &RelayPad<'sc>, rescue_future: bool) {
143 let mut guard = self.active.lock().unwrap_or_else(|err| err.into_inner());
144 if !guard.destroyed {
145 let fut = guard.future.take();
148 guard.destroyed = true;
149 pad.unregister_relay_future(self.id, fut.filter(|_| rescue_future));
150 }
151 debug_assert!(guard.future.is_none());
152 }
153
154 pub fn id(&self) -> RelayFutureId {
155 self.id
156 }
157}
158
159impl<'sc> RelayFutureInner<'sc> {}
160
/// Future that keeps dequeuing tasks from a `RelayPad` and polling them.
///
/// On drop (see its `PinnedDrop` impl) it destroys the shared `inner` state and
/// releases its registration in the `SpawnManager`.
#[pin_project(PinnedDrop)]
#[derive(Debug)]
struct RelayFuture<'sc, Sp> {
    /// State registered with the pad.
    #[pin]
    inner: Arc<RelayFutureInner<'sc>>,
    /// Remaining fields, reached through a plain `&mut` projection.
    unpinned: Unpinned<'sc, Sp>,
}
168
169impl<'sc, Sp> RelayFuture<'sc, Sp> {
170 fn new(pad: Arc<RelayPad<'sc>>, spawn: Sp, root: bool, manager: Arc<SpawnManager>) -> Option<Self> {
171 let id = manager.register();
172 let inner = pad.register_relay_future(RelayFutureInner {
173 active: Mutex::new(ActiveFuture::new()),
174 id,
175 })?;
176
177 Some(Self {
178 inner,
179 unpinned: Unpinned {
180 pad,
181 panicked: Cell::new(false),
182 root,
183 spawn,
184 manager,
185 },
186 })
187 }
188}
189
190#[pinned_drop]
191impl<'sc, Sp> PinnedDrop for RelayFuture<'sc, Sp> {
192 #[allow(clippy::needless_lifetimes)]
193 fn drop(self: Pin<&mut Self>) {
194 let this = self.project();
195 let unpinned = this.unpinned;
196 this.inner.destroy(&unpinned.pad, !unpinned.panicked.get());
197
198 unpinned.manager.unregister();
199 }
200}
201
impl<'sc, Sp: Respawn<'sc>> Future for RelayFuture<'sc, Sp> {
    type Output = ();

    /// Drives tasks taken from the pad.
    ///
    /// Each loop iteration either polls the task held in the `active` slot or, if the
    /// slot is empty, dequeues the next task from the pad. Returns:
    /// - `Pending` when the current task is pending (after `respawn(false)`), or
    ///   when the pad reports `WaitingForTasks`;
    /// - `Ready(())` when the pad reports `NoTasks` or `Destroy`, when
    ///   `start_future_polling` returns `None`, or once six tasks have completed
    ///   during this call (a replacement with the same `root` flag is spawned first).
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<()> {
        let this = self.as_mut().project();
        let unpinned = this.unpinned;
        // Tasks completed during this call; caps the work done in one poll.
        let mut finished_tasks = 0;
        let this_id = this.inner.id;
        // The `MutexGuard` temporary is lifetime-extended by this `let`, so `active`
        // stays locked until `poll` returns, including while the task is polled.
        // `unwrap` panics if the lock was poisoned by an earlier panicking poll.
        let future_cell = &mut this.inner.active.lock().unwrap().future;
        loop {
            if let Some(fut) = future_cell {
                if let Some(mut poll_guard) = unpinned.pad.start_future_polling(this_id) {
                    // Decrements the manager's idle count until this guard is dropped;
                    // spawn a helper if fewer than five other relays are idle.
                    let respawn_guard = unpinned.manager.start_polling();
                    if respawn_guard.should_respawn() {
                        unpinned.respawn(false);
                    }

                    // Drop guard that stays armed only while the task is polled. If that
                    // poll unwinds and the bomb was armed, it spawns a new root relay and
                    // records the panic so `drop` does not hand the task back to the pad.
                    // NOTE(review): it is armed only when `root` is true. For a non-root
                    // relay a panicking task leaves `panicked == false`, so `drop` passes
                    // `rescue_future = true` and the task that panicked goes back to the
                    // pad. Confirm this is intended.
                    struct Bomb<'l, 'sc, Sp: Respawn<'sc>>(&'l Unpinned<'sc, Sp>, bool);
                    impl<'l, 'sc, Sp: Respawn<'sc>> Drop for Bomb<'l, 'sc, Sp> {
                        fn drop(&mut self) {
                            if self.1 {
                                self.0.respawn(true);
                                self.0.panicked.set(true);
                            }
                        }
                    }

                    let mut bomb = Bomb(unpinned, unpinned.root);
                    let poll_result = Pin::new(fut).poll(cx);
                    // The task returned normally: disarm.
                    bomb.1 = false;

                    match poll_result {
                        Poll::Ready(()) => {
                            // Task finished: empty the slot and fetch another one.
                            future_cell.take();
                            finished_tasks += 1;
                            if finished_tasks > 5 {
                                // Budget used up: spawn a successor with the same root flag
                                // and let this relay future complete.
                                unpinned.respawn(unpinned.root);
                                return Poll::Ready(());
                            }
                            continue;
                        }
                        Poll::Pending => {
                            poll_guard.will_poll_again();
                            // This relay stays parked on the pending task; spawn a non-root
                            // relay (presumably so other queued tasks keep being served).
                            unpinned.respawn(false);
                            return Poll::Pending;
                        }
                    }
                } else {
                    // The pad declined to let this relay poll its task; finish.
                    return Poll::Ready(());
                }
            } else {
                // Slot empty: ask the pad for the next task. Only the root relay passes
                // `cx` (presumably so the pad can wake it when tasks arrive).
                match unpinned.pad.dequeue_task(unpinned.root.then_some(cx)) {
                    Ok(task) => *future_cell = Some(task),
                    Err(TaskDequeueErr::WaitingForTasks) => return Poll::Pending,
                    Err(TaskDequeueErr::NoTasks) => return Poll::Ready(()),
                    Err(TaskDequeueErr::Destroy) => return Poll::Ready(()),
                };

            }
        }
    }
}
271
impl<'sc, Sp: 'sc> RelayFuture<'sc, Sp> {
    /// Builds a relay future and boxes it as a `FutureObj` whose lifetime has been
    /// erased to `'static`, ready for a `Spawn` executor.
    ///
    /// Returns `None` when `RelayFuture::new` does (the pad refused registration).
    ///
    /// # Safety
    ///
    /// The returned object claims to be `'static`, but it holds data that is only
    /// valid for `'sc` (the `RelayPad<'sc>` and `Sp`). The caller must guarantee the
    /// object is neither polled nor dropped after `'sc` ends. NOTE(review): callers in
    /// this file presumably rely on `RelayPad` to enforce this; confirm.
    unsafe fn new_global_raw(
        pad: Arc<RelayPad<'sc>>,
        spawn: Sp,
        root: bool,
        manager: Arc<SpawnManager>,
    ) -> Option<FutureObj<'static, ()>>
    where
        Sp: Respawn<'sc> + Send,
    {
        let fut = Self::new(pad, spawn, root, manager)?;
        let fut_obj = FutureObj::new(Box::new(fut));
        // SAFETY: the two types differ only in a lifetime parameter, so the transmute
        // does not change layout; validity beyond `'sc` is the caller's obligation
        // (see `# Safety`).
        let static_fut = unsafe { std::mem::transmute::<FutureObj<'sc, ()>, FutureObj<'static, ()>>(fut_obj) };
        Some(static_fut)
    }

    /// Same as `new_global_raw`, but produces a `LocalFutureObj` (no `Send` bound) for
    /// a `LocalSpawn` executor.
    ///
    /// # Safety
    ///
    /// Same contract as `new_global_raw`: the `'static` object must not be polled or
    /// dropped after `'sc` ends.
    unsafe fn new_local_raw(
        pad: Arc<RelayPad<'sc>>,
        spawn: Sp,
        root: bool,
        manager: Arc<SpawnManager>,
    ) -> Option<LocalFutureObj<'static, ()>>
    where
        Sp: Respawn<'sc>,
    {
        let fut = Self::new(pad, spawn, root, manager)?;
        let fut_obj = LocalFutureObj::new(Box::new(fut));
        // SAFETY: lifetime-only transmute; see `# Safety` for the caller's obligation.
        let static_fut =
            unsafe { std::mem::transmute::<LocalFutureObj<'sc, ()>, LocalFutureObj<'static, ()>>(fut_obj) };
        Some(static_fut)
    }
}
318
319pub fn spawn_on_global<'sc, Sp: Spawn + Clone + Send + 'sc>(
320 pad: Arc<RelayPad<'sc>>,
321 spawn: Sp,
322 spawn_id: usize,
323) -> Result<(), SpawnError> {
324 let fut = unsafe {
325 RelayFuture::new_global_raw(
326 pad,
327 GlobalRespawn(spawn.clone()),
328 true,
329 Arc::new(SpawnManager::new(spawn_id)),
330 )
331 };
332 if let Some(fut) = fut {
333 spawn.spawn_obj(fut)
334 } else {
335 Err(SpawnError::shutdown())
336 }
337}
338
339pub fn spawn_on_local<'sc, Sp: LocalSpawn + Clone + 'sc>(
340 pad: Arc<RelayPad<'sc>>,
341 spawn: Sp,
342 spawn_id: usize,
343) -> Result<(), SpawnError> {
344 let fut = unsafe {
345 RelayFuture::new_local_raw(
346 pad,
347 LocalRespawn(spawn.clone()),
348 true,
349 Arc::new(SpawnManager::new(spawn_id)),
350 )
351 };
352 if let Some(fut) = fut {
353 spawn.spawn_local_obj(fut)
354 } else {
355 Err(SpawnError::shutdown())
356 }
357}