datex_core/
task.rs

1use cfg_if::cfg_if;
2use futures::channel::mpsc;
3use futures_util::{FutureExt, SinkExt, StreamExt};
4use log::info;
5use std::cell::RefCell;
6use std::future::Future;
7use std::rc::Rc;
8
9thread_local! {
10    static LOCAL_PANIC_CHANNEL: Rc<RefCell<
11        Option<(
12            Option<RefCell<mpsc::UnboundedSender<Signal>>>,
13            Option<mpsc::UnboundedReceiver<Signal>>
14        )>
15    >> = Rc::new(RefCell::new(None));
16}
17
/// Message carried by the thread-local panic channel.
enum Signal {
    /// A task spawned with `spawn_with_panic_notify` panicked.
    /// Carries the panic message.
    Panic(String),
    /// Sent by `close_panic_notify` once the `run_async!` body has finished.
    Exit,
}
22
/// Creates an async execution context in which `spawn_local` or `spawn_with_panic_notify` can be used.
/// When a panic occurs in a background task spawned with `spawn_with_panic_notify`, the panic will
/// be propagated to the main task and the execution will be stopped.
///
/// The body runs inside a fresh `tokio::task::LocalSet`, so this macro needs the Tokio runtime
/// and must itself be awaited inside one. The macro evaluates to the value of the body.
///
/// Only panics reported while the body is still running are propagated. After the body
/// completes, an exit signal is queued behind any already-reported panics, and only the first
/// queued signal is inspected. A background task that has not been polled (and thus has not
/// panicked) by then is not observed. This is why the example below keeps the body alive with
/// `sleep` after spawning the task.
///
/// # Panics
/// - If `init_panic_notify` rejects the initialization, for example when nested inside another
///   `run_async!` on the same thread.
/// - With `"Panic in local spawn: …"` if a task spawned with `spawn_with_panic_notify` panicked
///   before the body completed.
///
/// Example usage:
/// ```rust
/// use datex_core::run_async;
/// use datex_core::task::spawn_with_panic_notify;
///
/// async fn example() {
///     run_async! {
///         spawn_with_panic_notify(async {
///             // Simulate a panic
///             panic!("This is a test panic");
///         });
///         // Give the background task a chance to run (and panic) before the body completes;
///         // otherwise the exit signal is queued first and the panic is not propagated.
///         tokio::time::sleep(std::time::Duration::from_secs(1)).await;
///     }
/// }
/// ```
#[macro_export]
macro_rules! run_async {
    ($($body:tt)*) => {{
        // `$crate` resolves to this crate regardless of how (or where) it is named.
        $crate::task::init_panic_notify();

        tokio::task::LocalSet::new()
            .run_until(async move {
                let res = (async move { $($body)* }).await;
                $crate::task::close_panic_notify().await;
                $crate::task::unwind_local_spawn_panics().await;
                res
            }).await
    }}
}
56
/// Spawns a thread that runs an async block using the Tokio runtime.
/// The behavior is similar to `run_async! {}`, with the only difference being that
/// it runs in a separate thread.
///
/// Evaluates to the `std::thread::JoinHandle<()>` of the spawned thread; the value of the body
/// is discarded. A panic inside the thread, whether propagated by `run_async!` or caused by a
/// failure to build the runtime, surfaces as an `Err` from `JoinHandle::join`.
#[macro_export]
macro_rules! run_async_thread {
    ($($body:tt)*) => {{
        // Fully qualified paths so callers need neither `use std::thread` nor `use run_async`.
        ::std::thread::spawn(move || {
            // tokio runtime setup
            let runtime = tokio::runtime::Runtime::new()
                .expect("Failed to create Tokio runtime");

            // Run an async block using the runtime
            runtime.block_on(async {
                $crate::run_async! {
                    $($body)*
                }
            });
        })
    }}
}
76
77pub fn init_panic_notify() {
78    let (tx, rx) = mpsc::unbounded::<Signal>();
79    LOCAL_PANIC_CHANNEL
80        .try_with(|channel| {
81            let mut channel = channel.borrow_mut();
82            if channel.is_none() {
83                *channel = Some((Some(RefCell::new(tx)), Some(rx)));
84            } else {
85                panic!("Panic channel already initialized");
86            }
87        })
88        .expect("Failed to initialize panic channel");
89}
90
91pub async fn close_panic_notify() {
92    LOCAL_PANIC_CHANNEL
93        .with(|channel| {
94            let channel = channel.clone();
95            let mut channel = channel.borrow_mut();
96            if let Some((tx, _)) = &mut *channel {
97                tx.take()
98            } else {
99                panic!("Panic channel not initialized");
100            }
101        })
102        .expect("Failed to access panic channel")
103        .clone()
104        .borrow_mut()
105        .send(Signal::Exit)
106        .await
107        .expect("Failed to send exit signal");
108}
109
110pub async fn unwind_local_spawn_panics() {
111    let mut rx = LOCAL_PANIC_CHANNEL
112        .with(|channel| {
113            let channel = channel.clone();
114            let mut channel = channel.borrow_mut();
115            if let Some((_, rx)) = &mut *channel {
116                rx.take()
117            } else {
118                panic!("Panic channel not initialized");
119            }
120        })
121        .expect("Failed to access panic channel");
122    info!("Waiting for local spawn panics...");
123    if let Some(panic_msg) = rx.next().await {
124        match panic_msg {
125            Signal::Exit => {}
126            Signal::Panic(panic_msg) => {
127                panic!("Panic in local spawn: {panic_msg}");
128            }
129        }
130    }
131}
132async fn send_panic(panic: String) {
133    LOCAL_PANIC_CHANNEL
134        .try_with(|channel| {
135            let channel = channel.clone();
136            let channel = channel.borrow_mut();
137            if let Some((tx, _)) = &*channel {
138                tx.clone().expect("Panic channel not initialized")
139            } else {
140                panic!("Panic channel not initialized");
141            }
142        })
143        .expect("Failed to access panic channel")
144        .borrow_mut()
145        .send(Signal::Panic(panic))
146        .await
147        .expect("Failed to send panic");
148}
149
150pub fn spawn_with_panic_notify<F>(fut: F)
151where
152    F: Future<Output = ()> + 'static,
153{
154    spawn_local(async {
155        let result = std::panic::AssertUnwindSafe(fut).catch_unwind().await;
156        if let Err(err) = result {
157            let panic_msg = if let Some(s) = err.downcast_ref::<&str>() {
158                s.to_string()
159            } else if let Some(s) = err.downcast_ref::<String>() {
160                s.clone()
161            } else {
162                "Unknown panic type".to_string()
163            };
164            send_panic(panic_msg).await;
165        }
166    });
167}
168
169cfg_if! {
170    if #[cfg(feature = "tokio_runtime")] {
171        pub fn timeout<F>(duration: std::time::Duration, fut: F) -> tokio::time::Timeout<F::IntoFuture>
172        where
173            F: std::future::IntoFuture,
174        {
175            tokio::time::timeout(duration, fut)
176        }
177
178        pub fn spawn_local<F>(fut: F)-> tokio::task::JoinHandle<()>
179        where
180            F: std::future::Future<Output = ()> + 'static,
181        {
182            tokio::task::spawn_local(fut)
183        }
184        pub fn spawn<F>(fut: F) -> tokio::task::JoinHandle<F::Output>
185        where
186            F: Future<Output = ()> + Send + 'static,
187        {
188            tokio::spawn(fut)
189        }
190        pub fn spawn_blocking<F, R>(f: F) -> tokio::task::JoinHandle<R>
191        where
192            F: FnOnce() -> R + Send + 'static,
193            R: Send + 'static,
194        {
195            tokio::task::spawn_blocking(f)
196        }
197        pub async fn sleep(dur: std::time::Duration) {
198            tokio::time::sleep(dur).await;
199        }
200
201    } else if #[cfg(feature = "wasm_runtime")] {
202        use futures::future;
203
204        pub async fn timeout<F, T>(
205            duration: std::time::Duration,
206            fut: F,
207        ) -> Result<T, &'static str>
208        where
209            F: std::future::Future<Output = T>,
210        {
211            let timeout_fut = sleep(duration);
212            futures::pin_mut!(fut);
213            futures::pin_mut!(timeout_fut);
214
215            match future::select(fut, timeout_fut).await {
216                future::Either::Left((res, _)) => Ok(res),
217                future::Either::Right(_) => Err("timed out"),
218            }
219        }
220        pub async fn sleep(dur: std::time::Duration) {
221            gloo_timers::future::sleep(dur).await;
222        }
223
224        pub fn spawn_local<F>(fut: F)
225        where
226            F: std::future::Future<Output = ()> + 'static,
227        {
228            wasm_bindgen_futures::spawn_local(fut);
229        }
230        pub fn spawn<F>(fut: F)
231        where
232            F: std::future::Future<Output = ()> + 'static,
233        {
234            wasm_bindgen_futures::spawn_local(fut);
235        }
236        pub fn spawn_blocking<F>(_fut: F) -> !
237        where
238            F: std::future::Future + 'static,
239        {
240            panic!("`spawn_blocking` is not supported in the wasm runtime.");
241        }
242    } else {
243        compile_error!("Unsupported runtime. Please enable either 'tokio_runtime' or 'wasm_runtime' feature.");
244    }
245}