1use std::{
16 fmt::Debug,
17 mem::ManuallyDrop,
18 ops::{Deref, DerefMut},
19 sync::Arc,
20};
21
22use tokio::{
23 runtime::{Handle, Runtime},
24 task::JoinHandle,
25};
26
27use crate::error::{Error, ErrorKind, Result};
28
/// Owning wrapper around a tokio [`Runtime`] that controls how the runtime is
/// torn down when the wrapper is dropped.
///
/// The runtime is stored in a [`ManuallyDrop`] so that this type's `Drop` impl
/// can move it out and stop it with [`Runtime::shutdown_background`] (or plain
/// `drop` under `cfg(madsim)`), rather than relying on `Runtime`'s own `Drop`.
/// Runtime methods are reachable directly through `Deref`/`DerefMut`.
pub struct BackgroundShutdownRuntime(ManuallyDrop<Runtime>);
33
34impl Debug for BackgroundShutdownRuntime {
35 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
36 f.debug_tuple("BackgroundShutdownRuntime").finish()
37 }
38}
39
impl Drop for BackgroundShutdownRuntime {
    /// Moves the runtime out of its `ManuallyDrop` slot and tears it down.
    ///
    /// Without `cfg(madsim)` the runtime is stopped via
    /// `Runtime::shutdown_background`. With `cfg(madsim)` it is simply dropped.
    /// NOTE(review): presumably the madsim runtime has no `shutdown_background`;
    /// confirm.
    fn drop(&mut self) {
        // SAFETY: `self.0` is only taken here, inside `drop`, which runs at most
        // once per value. Nothing reads `self.0` after this point, and the drop
        // glue of `ManuallyDrop` does nothing, so the runtime is neither used
        // after being moved out nor dropped twice.
        let runtime = unsafe { ManuallyDrop::take(&mut self.0) };

        #[cfg(madsim)]
        drop(runtime);
        #[cfg(not(madsim))]
        runtime.shutdown_background();
    }
}
51
52impl Deref for BackgroundShutdownRuntime {
53 type Target = Runtime;
54
55 fn deref(&self) -> &Self::Target {
56 &self.0
57 }
58}
59
60impl DerefMut for BackgroundShutdownRuntime {
61 fn deref_mut(&mut self) -> &mut Self::Target {
62 &mut self.0
63 }
64}
65
66impl From<Runtime> for BackgroundShutdownRuntime {
67 fn from(runtime: Runtime) -> Self {
68 Self(ManuallyDrop::new(runtime))
69 }
70}
71
/// Future returned by [`Spawner::spawn`] and [`Spawner::spawn_blocking`].
///
/// Awaiting it yields the spawned task's output. A tokio join failure is
/// converted into the crate's [`Error`] with kind [`ErrorKind::Join`].
#[derive(Debug)]
pub struct SpawnHandle<T> {
    /// Underlying tokio join handle of the spawned task.
    inner: JoinHandle<T>,
}
77
78impl<T> std::future::Future for SpawnHandle<T> {
79 type Output = Result<T>;
80
81 fn poll(mut self: std::pin::Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> std::task::Poll<Self::Output> {
82 match std::pin::Pin::new(&mut self.inner).poll(cx) {
83 std::task::Poll::Ready(res) => match res {
84 Ok(v) => std::task::Poll::Ready(Ok(v)),
85 Err(e) => std::task::Poll::Ready(Err(Error::new(ErrorKind::Join, "tokio join error").with_source(e))),
86 },
87 std::task::Poll::Pending => std::task::Poll::Pending,
88 }
89 }
90}
91
/// Spawns async and blocking tasks. The target is either a runtime the spawner
/// owns or an existing runtime reached through a tokio [`Handle`].
#[derive(Debug, Clone)]
pub enum Spawner {
    /// Owned runtime, shared between clones via `Arc`. It is torn down by
    /// `BackgroundShutdownRuntime`'s `Drop` impl when the last clone goes away.
    Runtime(Arc<BackgroundShutdownRuntime>),
    /// Existing runtime referenced through its handle. The spawner does not own
    /// this runtime.
    Handle(Handle),
}
100
101impl From<Runtime> for Spawner {
102 fn from(runtime: Runtime) -> Self {
103 Self::Runtime(Arc::new(runtime.into()))
104 }
105}
106
107impl From<Handle> for Spawner {
108 fn from(handle: Handle) -> Self {
109 Self::Handle(handle)
110 }
111}
112
113impl Spawner {
114 pub fn spawn<F>(&self, future: F) -> SpawnHandle<<F as std::future::Future>::Output>
116 where
117 F: std::future::Future + Send + 'static,
118 F::Output: Send + 'static,
119 {
120 let inner = match self {
121 Spawner::Runtime(rt) => rt.spawn(future),
122 Spawner::Handle(h) => h.spawn(future),
123 };
124 SpawnHandle { inner }
125 }
126
127 pub fn spawn_blocking<F, R>(&self, func: F) -> SpawnHandle<R>
129 where
130 F: FnOnce() -> R + Send + 'static,
131 R: Send + 'static,
132 {
133 let inner = match self {
134 Spawner::Runtime(rt) => rt.spawn_blocking(func),
135 Spawner::Handle(h) => h.spawn_blocking(func),
136 };
137 SpawnHandle { inner }
138 }
139
140 pub fn current() -> Self {
142 Spawner::Handle(Handle::current())
143 }
144}