Skip to main content

veilid_tools/
spawn.rs

1use super::*;
2
3cfg_if! {
4    if #[cfg(feature="rt-wasm-bindgen")] {
5        use async_executors::{Bindgen, LocalSpawnHandleExt, SpawnHandleExt};
6
7        cfg_if! {
8            if #[cfg(feature="debug-locks-detect")] {
9                use std::task::{Context, Poll, Wake, Waker};
10                use std::sync::{atomic::AtomicU64, LazyLock};
11                use send_wrapper::SendWrapper;
12
13                static ACTIVE_TASK_ID: LazyLock<SendWrapper<AtomicU64>> = LazyLock::new(|| SendWrapper::new(AtomicU64::new(0)));
14                static NEXT_TASK_ID: LazyLock<SendWrapper<AtomicU64>> = LazyLock::new(|| SendWrapper::new(AtomicU64::new(0)));
15
16                #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
17                pub struct AsyncTaskId(u64);
18                impl AsyncTaskId {
19                    #[must_use]
20                    pub fn this() -> AsyncTaskId {
21                        AsyncTaskId(ACTIVE_TASK_ID.load(Ordering::Relaxed))
22                    }
23                }
24
25                // Wrapper for waker that propagates a task id
26                struct AllocTaskIdWakerWrapper {
27                    inner_waker: Waker,
28                    task_id: u64,
29                }
30                impl<'a> Wake for AllocTaskIdWakerWrapper {
31                    fn wake(self: Arc<Self>) {
32                        ACTIVE_TASK_ID.store(self.task_id, Ordering::Relaxed);
33                        self.inner_waker.wake_by_ref();
34                    }
35                }
36
37                // Wrapper that adds a task id to the context of a future that is the start of a spawned task
38                struct AllocTaskIdFuture<Fut: Future> {
39                    inner: Fut,
40                    task_id: u64,
41                }
42                impl<Fut: Future> From<Fut> for AllocTaskIdFuture<Fut> {
43                    fn from(inner: Fut) -> Self {
44                        let task_id = NEXT_TASK_ID.fetch_add(1, Ordering::Relaxed);
45                        Self {
46                            inner,
47                            task_id,
48                        }
49                    }
50                }
51                impl<Fut: Future> Future for AllocTaskIdFuture<Fut> {
52                    type Output = Fut::Output;
53                    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
54                        // Poll inner future
55                        let task_id = self.task_id;
56                        let inner_fut = unsafe { self.map_unchecked_mut(|s| &mut s.inner) };
57                        let wrapped_waker = Arc::new(AllocTaskIdWakerWrapper {
58                            inner_waker: cx.waker().clone(),
59                            task_id,
60                        }).into();
61                        let mut wrapped_cx = Context::from_waker(&wrapped_waker);
62                        inner_fut.poll(&mut wrapped_cx)
63                    }
64                }
65            }
66        }
67
68        pub fn spawn<Out>(_name: &str, future: impl Future<Output = Out> + Send + 'static) -> MustJoinHandle<Out>
69        where
70            Out: Send + 'static,
71        {
72
73            #[cfg(feature="debug-locks-detect")]
74            let future = AllocTaskIdFuture::from(future);
75
76            MustJoinHandle::new(
77                Bindgen
78                    .spawn_handle(future)
79                    .expect_or_log("wasm-bindgen-futures spawn_handle_local should never error out"),
80            )
81        }
82
83        pub fn spawn_local<Out>(_name: &str, future: impl Future<Output = Out> + 'static) -> MustJoinHandle<Out>
84        where
85            Out: 'static,
86        {
87
88            #[cfg(feature="debug-locks-detect")]
89            let future = AllocTaskIdFuture::from(future);
90
91            MustJoinHandle::new(
92                Bindgen
93                    .spawn_handle_local(future)
94                    .expect_or_log("wasm-bindgen-futures spawn_handle_local should never error out"),
95            )
96        }
97
98        pub fn spawn_detached<Out>(_name: &str, future: impl Future<Output = Out> + Send + 'static)
99        where
100            Out: Send + 'static,
101        {
102            #[cfg(feature="debug-locks-detect")]
103            let future = AllocTaskIdFuture::from(future);
104
105            Bindgen
106                .spawn_handle_local(future)
107                .expect_or_log("wasm-bindgen-futures spawn_handle_local should never error out")
108                .detach()
109        }
110        pub fn spawn_detached_local<Out>(_name: &str, future: impl Future<Output = Out> + 'static)
111        where
112            Out: 'static,
113        {
114            #[cfg(feature="debug-locks-detect")]
115            let future = AllocTaskIdFuture::from(future);
116
117            Bindgen
118                .spawn_handle_local(future)
119                .expect_or_log("wasm-bindgen-futures spawn_handle_local should never error out")
120                .detach()
121        }
122
123    } else {
124
125        pub fn spawn<Out>(name: &str, future: impl Future<Output = Out> + Send + 'static) -> MustJoinHandle<Out>
126        where
127            Out: Send + 'static,
128        {
129            cfg_if! {
130                if #[cfg(feature="rt-async-std")] {
131                    MustJoinHandle::new(async_std::task::Builder::new().name(name.to_string()).spawn(future).unwrap_or_log())
132                } else if #[cfg(all(tokio_unstable, feature="rt-tokio", feature="tracing"))] {
133                    MustJoinHandle::new(tokio::task::Builder::new().name(name).spawn(future).unwrap_or_log())
134                } else if #[cfg(feature="rt-tokio")] {
135                    let _name = name;
136                    MustJoinHandle::new(tokio::task::spawn(future))
137                }
138            }
139        }
140
141        pub fn spawn_local<Out>(name: &str, future: impl Future<Output = Out> + 'static) -> MustJoinHandle<Out>
142        where
143            Out: 'static,
144        {
145            cfg_if! {
146                if #[cfg(feature="rt-async-std")] {
147                    MustJoinHandle::new(async_std::task::Builder::new().name(name.to_string()).local(future).unwrap_or_log())
148                } else if #[cfg(all(tokio_unstable, feature="rt-tokio", feature="tracing"))] {
149                    MustJoinHandle::new(tokio::task::Builder::new().name(name).spawn_local(future).unwrap_or_log())
150                } else if #[cfg(feature="rt-tokio")] {
151                    let _name = name;
152                    MustJoinHandle::new(tokio::task::spawn_local(future))
153                }
154            }
155        }
156
157        pub fn spawn_detached<Out>(name: &str, future: impl Future<Output = Out> + Send + 'static)
158        where
159            Out: Send + 'static,
160        {
161            cfg_if! {
162                if #[cfg(feature="rt-async-std")] {
163                    drop(async_std::task::Builder::new().name(name.to_string()).spawn(future).unwrap_or_log());
164                } else if #[cfg(all(tokio_unstable, feature="rt-tokio", feature="tracing"))] {
165                    drop(tokio::task::Builder::new().name(name).spawn(future).unwrap_or_log());
166                } else if #[cfg(feature="rt-tokio")] {
167                    let _name = name;
168                    drop(tokio::task::spawn(future))
169                }
170            }
171        }
172
173        pub fn spawn_detached_local<Out>(name: &str,future: impl Future<Output = Out> + 'static)
174        where
175            Out: 'static,
176        {
177            cfg_if! {
178                if #[cfg(feature="rt-async-std")] {
179                    drop(async_std::task::Builder::new().name(name.to_string()).local(future).unwrap_or_log());
180                } else if #[cfg(all(tokio_unstable, feature="rt-tokio", feature="tracing"))] {
181                    drop(tokio::task::Builder::new().name(name).spawn_local(future).unwrap_or_log());
182                } else if #[cfg(feature="rt-tokio")] {
183                    let _name = name;
184                    drop(tokio::task::spawn_local(future))
185                }
186            }
187        }
188
189        #[allow(unused_variables)]
190        pub async fn blocking_wrapper<F, R>(name: &str, blocking_task: F, err_result: R) -> R
191        where
192            F: FnOnce() -> R + Send + 'static,
193            R: Send + 'static,
194        {
195            // run blocking stuff in blocking thread
196            cfg_if! {
197                if #[cfg(feature="rt-async-std")] {
198                    let _name = name;
199                    // async_std::task::Builder blocking doesn't work like spawn_blocking()
200                    async_std::task::spawn_blocking(blocking_task).await
201                } else if #[cfg(all(tokio_unstable, feature="rt-tokio", feature="tracing"))] {
202                    tokio::task::Builder::new().name(name).spawn_blocking(blocking_task).unwrap_or_log().await.unwrap_or(err_result)
203                } else if #[cfg(feature="rt-tokio")] {
204                    let _name = name;
205                    tokio::task::spawn_blocking(blocking_task).await.unwrap_or(err_result)
206                } else {
207                    #[compile_error("must use an executor")]
208                }
209            }
210        }
211
212        cfg_if! {
213            if #[cfg(feature="rt-tokio")] {
214                #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
215                pub struct AsyncTaskId(tokio::task::Id);
216                impl AsyncTaskId {
217                    #[must_use]
218                    pub fn this() -> AsyncTaskId {
219                        AsyncTaskId(tokio::task::id())
220                    }
221                }
222            } else if #[cfg(feature="rt-async-std")] {
223                #[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
224                pub struct AsyncTaskId(async_std::task::TaskId);
225                impl AsyncTaskId {
226                    #[must_use]
227                    pub fn this() -> AsyncTaskId {
228                        AsyncTaskId(async_std::task::current().id())
229                    }
230                }
231            } else {
232                #[compile_error("must use an executor")]
233            }
234        }
235    }
236}