1use cfg_if::cfg_if;
2use futures::channel::mpsc;
3use futures_util::{FutureExt, SinkExt, StreamExt};
4use log::info;
5use std::cell::RefCell;
6use std::future::Future;
7use std::rc::Rc;
8
9type LocalPanicChannel = Rc<
10 RefCell<
11 Option<(
12 Option<RefCell<mpsc::UnboundedSender<Signal>>>,
13 Option<mpsc::UnboundedReceiver<Signal>>,
14 )>,
15 >,
16>;
17thread_local! {
18 static LOCAL_PANIC_CHANNEL: LocalPanicChannel = Rc::new(RefCell::new(None));
19}
20
/// Message carried on the per-thread panic channel.
#[derive(Debug, Clone, PartialEq, Eq)]
enum Signal {
    /// A task spawned with `spawn_with_panic_notify` panicked; carries the panic message.
    Panic(String),
    /// Sent by `close_panic_notify` once the `run_async!` body has finished.
    Exit,
}
25
/// Runs an async body on a fresh tokio `LocalSet`, forwarding panics from local tasks.
///
/// Before the body runs, the thread's panic channel is initialized. After the body
/// finishes, the channel is closed and the first signal on it is inspected. A panic
/// forwarded by a task spawned via `spawn_with_panic_notify` is re-raised. The body's
/// value is returned.
///
/// The expansion contains `.await`, so this macro must be used inside an async fn or block.
#[macro_export]
macro_rules! run_async {
    ($($body:tt)*) => {{
        // `$crate` resolves to this crate both for downstream users and from inside the
        // crate itself, unlike a hard-coded `datex_core::` path.
        $crate::task::init_panic_notify();

        tokio::task::LocalSet::new()
            .run_until(async move {
                let res = (async move { $($body)* }).await;
                $crate::task::close_panic_notify().await;
                $crate::task::unwind_local_spawn_panics().await;
                res
            })
            .await
    }}
}
59
/// Spawns a new OS thread that runs the body through `run_async!` on its own tokio runtime.
///
/// Evaluates to the `std::thread::JoinHandle<()>` of the spawned thread. The body's value
/// is discarded.
///
/// # Panics
///
/// The spawned thread panics if the tokio runtime cannot be created, or if the body (or a
/// task it spawned with panic notification) panics.
#[macro_export]
macro_rules! run_async_thread {
    ($($body:tt)*) => {{
        // Fully qualified paths so the expansion does not require the caller to have
        // `std::thread` or the `run_async!` macro in scope.
        ::std::thread::spawn(move || {
            let runtime = tokio::runtime::Runtime::new().unwrap();

            runtime.block_on(async {
                $crate::run_async! {
                    $($body)*
                }
            });
        })
    }}
}
79
80pub fn init_panic_notify() {
81 let (tx, rx) = mpsc::unbounded::<Signal>();
82 LOCAL_PANIC_CHANNEL
83 .try_with(|channel| {
84 let mut channel = channel.borrow_mut();
85 if channel.is_none() {
86 *channel = Some((Some(RefCell::new(tx)), Some(rx)));
87 } else {
88 panic!("Panic channel already initialized");
89 }
90 })
91 .expect("Failed to initialize panic channel");
92}
93
94#[allow(clippy::await_holding_refcell_ref)]
95pub async fn close_panic_notify() {
96 LOCAL_PANIC_CHANNEL
97 .with(|channel| {
98 let channel = channel.clone();
99 let mut channel = channel.borrow_mut();
100 if let Some((tx, _)) = &mut *channel {
101 tx.take()
102 } else {
103 panic!("Panic channel not initialized");
104 }
105 })
106 .expect("Failed to access panic channel")
107 .clone()
108 .borrow_mut()
109 .send(Signal::Exit)
110 .await
111 .expect("Failed to send exit signal");
112}
113
114pub async fn unwind_local_spawn_panics() {
115 let mut rx = LOCAL_PANIC_CHANNEL
116 .with(|channel| {
117 let channel = channel.clone();
118 let mut channel = channel.borrow_mut();
119 if let Some((_, rx)) = &mut *channel {
120 rx.take()
121 } else {
122 panic!("Panic channel not initialized");
123 }
124 })
125 .expect("Failed to access panic channel");
126 info!("Waiting for local spawn panics...");
127 if let Some(panic_msg) = rx.next().await {
128 match panic_msg {
129 Signal::Exit => {}
130 Signal::Panic(panic_msg) => {
131 panic!("Panic in local spawn: {panic_msg}");
132 }
133 }
134 }
135}
136
137#[allow(clippy::await_holding_refcell_ref)]
138async fn send_panic(panic: String) {
139 LOCAL_PANIC_CHANNEL
140 .try_with(|channel| {
141 let channel = channel.clone();
142 let channel = channel.borrow_mut();
143 if let Some((tx, _)) = &*channel {
144 tx.clone().expect("Panic channel not initialized")
145 } else {
146 panic!("Panic channel not initialized");
147 }
148 })
149 .expect("Failed to access panic channel")
150 .borrow_mut()
151 .send(Signal::Panic(panic))
152 .await
153 .expect("Failed to send panic");
154}
155
156pub fn spawn_with_panic_notify<F>(fut: F)
157where
158 F: Future<Output = ()> + 'static,
159{
160 spawn_local(async {
161 let result = std::panic::AssertUnwindSafe(fut).catch_unwind().await;
162 if let Err(err) = result {
163 let panic_msg = if let Some(s) = err.downcast_ref::<&str>() {
164 s.to_string()
165 } else if let Some(s) = err.downcast_ref::<String>() {
166 s.clone()
167 } else {
168 "Unknown panic type".to_string()
169 };
170 send_panic(panic_msg).await;
171 }
172 });
173}
174
175cfg_if! {
176 if #[cfg(feature = "tokio_runtime")] {
177 pub fn timeout<F>(duration: std::time::Duration, fut: F) -> tokio::time::Timeout<F::IntoFuture>
178 where
179 F: std::future::IntoFuture,
180 {
181 tokio::time::timeout(duration, fut)
182 }
183
184 pub fn spawn_local<F>(fut: F)-> tokio::task::JoinHandle<()>
185 where
186 F: std::future::Future<Output = ()> + 'static,
187 {
188 tokio::task::spawn_local(fut)
189 }
190 pub fn spawn<F>(fut: F) -> tokio::task::JoinHandle<F::Output>
191 where
192 F: Future<Output = ()> + Send + 'static,
193 {
194 tokio::spawn(fut)
195 }
196 pub fn spawn_blocking<F, R>(f: F) -> tokio::task::JoinHandle<R>
197 where
198 F: FnOnce() -> R + Send + 'static,
199 R: Send + 'static,
200 {
201 tokio::task::spawn_blocking(f)
202 }
203 pub async fn sleep(dur: std::time::Duration) {
204 tokio::time::sleep(dur).await;
205 }
206
207 } else if #[cfg(feature = "wasm_runtime")] {
208 use futures::future;
209
210 pub async fn timeout<F, T>(
211 duration: std::time::Duration,
212 fut: F,
213 ) -> Result<T, &'static str>
214 where
215 F: std::future::Future<Output = T>,
216 {
217 let timeout_fut = sleep(duration);
218 futures::pin_mut!(fut);
219 futures::pin_mut!(timeout_fut);
220
221 match future::select(fut, timeout_fut).await {
222 future::Either::Left((res, _)) => Ok(res),
223 future::Either::Right(_) => Err("timed out"),
224 }
225 }
226 pub async fn sleep(dur: std::time::Duration) {
227 gloo_timers::future::sleep(dur).await;
228 }
229
230 pub fn spawn_local<F>(fut: F)
231 where
232 F: std::future::Future<Output = ()> + 'static,
233 {
234 wasm_bindgen_futures::spawn_local(fut);
235 }
236 pub fn spawn<F>(fut: F)
237 where
238 F: std::future::Future<Output = ()> + 'static,
239 {
240 wasm_bindgen_futures::spawn_local(fut);
241 }
242 pub fn spawn_blocking<F>(_fut: F) -> !
243 where
244 F: std::future::Future + 'static,
245 {
246 panic!("`spawn_blocking` is not supported in the wasm runtime.");
247 }
248 } else {
249 compile_error!("Unsupported runtime. Please enable either 'tokio_runtime' or 'wasm_runtime' feature.");
250 }
251}