1pub use tokio::{
14 runtime::{Handle as TokioHandle, Runtime as TokioRuntime},
15 sync::{
16 mpsc::{channel, Receiver, Sender},
17 Mutex, RwLock,
18 },
19 task::JoinHandle as TokioJoinHandle,
20};
21
22use std::{
23 future::Future,
24 pin::Pin,
25 sync::OnceLock,
26 task::{Context, Poll},
27};
28
29static RUNTIME: OnceLock<GlobalRuntime> = OnceLock::new();
30
31struct GlobalRuntime {
32 runtime: Option<Runtime>,
33 handle: RuntimeHandle,
34}
35
36impl GlobalRuntime {
37 fn handle(&self) -> RuntimeHandle {
38 if let Some(r) = &self.runtime {
39 r.handle()
40 } else {
41 self.handle.clone()
42 }
43 }
44
45 #[track_caller]
46 fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
47 where
48 F: Future + Send + 'static,
49 F::Output: Send + 'static,
50 {
51 if let Some(r) = &self.runtime {
52 r.spawn(task)
53 } else {
54 self.handle.spawn(task)
55 }
56 }
57
58 #[track_caller]
59 pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
60 where
61 F: FnOnce() -> R + Send + 'static,
62 R: Send + 'static,
63 {
64 if let Some(r) = &self.runtime {
65 r.spawn_blocking(func)
66 } else {
67 self.handle.spawn_blocking(func)
68 }
69 }
70
71 #[track_caller]
72 fn block_on<F: Future>(&self, task: F) -> F::Output {
73 if let Some(r) = &self.runtime {
74 r.block_on(task)
75 } else {
76 self.handle.block_on(task)
77 }
78 }
79}
80
81pub enum Runtime {
83 Tokio(TokioRuntime),
85}
86
87impl Runtime {
88 pub fn inner(&self) -> &TokioRuntime {
90 let Self::Tokio(r) = self;
91 r
92 }
93
94 pub fn handle(&self) -> RuntimeHandle {
96 match self {
97 Self::Tokio(r) => RuntimeHandle::Tokio(r.handle().clone()),
98 }
99 }
100
101 #[track_caller]
102 pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
104 where
105 F: Future + Send + 'static,
106 F::Output: Send + 'static,
107 {
108 match self {
109 Self::Tokio(r) => {
110 let _guard = r.enter();
111 JoinHandle::Tokio(tokio::spawn(task))
112 }
113 }
114 }
115
116 #[track_caller]
117 pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
119 where
120 F: FnOnce() -> R + Send + 'static,
121 R: Send + 'static,
122 {
123 match self {
124 Self::Tokio(r) => JoinHandle::Tokio(r.spawn_blocking(func)),
125 }
126 }
127
128 #[track_caller]
129 pub fn block_on<F: Future>(&self, task: F) -> F::Output {
131 match self {
132 Self::Tokio(r) => r.block_on(task),
133 }
134 }
135}
136
137#[derive(Debug)]
139pub enum JoinHandle<T> {
140 Tokio(TokioJoinHandle<T>),
142}
143
144impl<T> JoinHandle<T> {
145 pub fn inner(&self) -> &TokioJoinHandle<T> {
147 let Self::Tokio(t) = self;
148 t
149 }
150
151 pub fn abort(&self) {
157 match self {
158 Self::Tokio(t) => t.abort(),
159 }
160 }
161}
162
163impl<T> Future for JoinHandle<T> {
164 type Output = crate::Result<T>;
165 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
166 match self.get_mut() {
167 Self::Tokio(t) => Pin::new(t).poll(cx).map_err(Into::into),
168 }
169 }
170}
171
172#[derive(Clone)]
174pub enum RuntimeHandle {
175 Tokio(TokioHandle),
177}
178
179impl RuntimeHandle {
180 pub fn inner(&self) -> &TokioHandle {
182 let Self::Tokio(h) = self;
183 h
184 }
185
186 #[track_caller]
187 pub fn spawn_blocking<F, R>(&self, func: F) -> JoinHandle<R>
189 where
190 F: FnOnce() -> R + Send + 'static,
191 R: Send + 'static,
192 {
193 match self {
194 Self::Tokio(h) => JoinHandle::Tokio(h.spawn_blocking(func)),
195 }
196 }
197
198 #[track_caller]
199 pub fn spawn<F>(&self, task: F) -> JoinHandle<F::Output>
201 where
202 F: Future + Send + 'static,
203 F::Output: Send + 'static,
204 {
205 match self {
206 Self::Tokio(h) => {
207 let _guard = h.enter();
208 JoinHandle::Tokio(tokio::spawn(task))
209 }
210 }
211 }
212
213 #[track_caller]
214 pub fn block_on<F: Future>(&self, task: F) -> F::Output {
216 match self {
217 Self::Tokio(h) => h.block_on(task),
218 }
219 }
220}
221
222fn default_runtime() -> GlobalRuntime {
223 let runtime = Runtime::Tokio(TokioRuntime::new().unwrap());
224 let handle = runtime.handle();
225 GlobalRuntime {
226 runtime: Some(runtime),
227 handle,
228 }
229}
230
231pub fn set(handle: TokioHandle) {
256 RUNTIME
257 .set(GlobalRuntime {
258 runtime: None,
259 handle: RuntimeHandle::Tokio(handle),
260 })
261 .unwrap_or_else(|_| panic!("runtime already initialized"))
262}
263
264pub fn handle() -> RuntimeHandle {
266 let runtime = RUNTIME.get_or_init(default_runtime);
267 runtime.handle()
268}
269
270#[track_caller]
271pub fn block_on<F: Future>(task: F) -> F::Output {
273 let runtime = RUNTIME.get_or_init(default_runtime);
274 runtime.block_on(task)
275}
276
277#[track_caller]
278pub fn spawn<F>(task: F) -> JoinHandle<F::Output>
280where
281 F: Future + Send + 'static,
282 F::Output: Send + 'static,
283{
284 let runtime = RUNTIME.get_or_init(default_runtime);
285 runtime.spawn(task)
286}
287
288#[track_caller]
289pub fn spawn_blocking<F, R>(func: F) -> JoinHandle<R>
291where
292 F: FnOnce() -> R + Send + 'static,
293 R: Send + 'static,
294{
295 let runtime = RUNTIME.get_or_init(default_runtime);
296 runtime.spawn_blocking(func)
297}
298
299#[track_caller]
300#[allow(dead_code)]
301pub(crate) fn safe_block_on<F>(task: F) -> F::Output
302where
303 F: Future + Send + 'static,
304 F::Output: Send + 'static,
305{
306 if let Ok(handle) = tokio::runtime::Handle::try_current() {
307 let (tx, rx) = std::sync::mpsc::sync_channel(1);
308 let handle_ = handle.clone();
309 handle.spawn_blocking(move || {
310 tx.send(handle_.block_on(task)).unwrap();
311 });
312 rx.recv().unwrap()
313 } else {
314 block_on(task)
315 }
316}
317
#[cfg(test)]
mod tests {
    use super::*;

    /// Spawning on the global runtime yields the task's output.
    #[tokio::test]
    async fn runtime_spawn() {
        let value = spawn(async { 5 }).await.unwrap();
        assert_eq!(value, 5);
    }

    /// Blocking on the global runtime returns the future's output.
    #[test]
    fn runtime_block_on() {
        let value = block_on(async { 0 });
        assert_eq!(value, 0);
    }

    /// Spawning through a global runtime handle yields the task's output.
    #[tokio::test]
    async fn handle_spawn() {
        let rt = handle();
        let value = rt.spawn(async { 5 }).await.unwrap();
        assert_eq!(value, 5);
    }

    /// Blocking through a global runtime handle returns the future's output.
    #[test]
    fn handle_block_on() {
        let rt = handle();
        let value = rt.block_on(async { 0 });
        assert_eq!(value, 0);
    }

    /// Aborting a still-sleeping task surfaces a cancelled `JoinError`.
    #[tokio::test]
    async fn handle_abort() {
        let rt = handle();
        let join = rt.spawn(async {
            tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;
            5
        });
        join.abort();

        let err = join.await.unwrap_err();
        let crate::Error::JoinError(raw_error) = err else {
            panic!("Abort did not result in the expected `JoinError`");
        };
        assert!(raw_error.is_cancelled());
    }
}