datex_core/
task.rs

1use cfg_if::cfg_if;
2use futures::channel::mpsc;
3use futures_util::{FutureExt, SinkExt, StreamExt};
4use log::info;
5use std::cell::RefCell;
6use std::future::Future;
7use std::rc::Rc;
8
9type LocalPanicChannel = Rc<
10    RefCell<
11        Option<(
12            Option<RefCell<mpsc::UnboundedSender<Signal>>>,
13            Option<mpsc::UnboundedReceiver<Signal>>,
14        )>,
15    >,
16>;
17thread_local! {
18    static LOCAL_PANIC_CHANNEL: LocalPanicChannel = Rc::new(RefCell::new(None));
19}
20
/// Message carried over the thread-local panic channel.
#[derive(Debug)]
enum Signal {
    /// A task spawned with `spawn_with_panic_notify` panicked; carries its panic message.
    Panic(String),
    /// Sent by `close_panic_notify` after the `run_async!` body has finished.
    Exit,
}
25
/// Runs an async block inside a `tokio::task::LocalSet`, so `spawn_local` and
/// `spawn_with_panic_notify` can be used within it. The macro evaluates to the body's value.
///
/// If a background task started with `spawn_with_panic_notify` panics, its panic message is
/// forwarded over a thread-local channel. Once the body has finished, the macro closes that
/// channel and re-raises the first forwarded panic on the calling task as
/// `Panic in local spawn: <message>`.
///
/// Only panics forwarded *before* the body finishes are observed. Background tasks that have
/// not run yet, or are still pending at that point, are not awaited.
///
/// Example usage:
/// ```rust
/// use datex_core::run_async;
/// use datex_core::task::spawn_with_panic_notify;
///
/// async fn example() {
///     run_async! {
///         spawn_with_panic_notify(async {
///             // Simulate a panic
///             panic!("This is a test panic");
///         });
///         // Give the background task a chance to run before the body ends.
///         tokio::time::sleep(std::time::Duration::from_secs(1)).await;
///     }
/// }
/// ```
#[macro_export]
macro_rules! run_async {
    ($($body:tt)*) => {{
        // `$crate` resolves to this crate both from dependents and from inside the crate itself.
        $crate::task::init_panic_notify();

        tokio::task::LocalSet::new()
            .run_until(async move {
                let res = (async move { $($body)* }).await;
                $crate::task::close_panic_notify().await;
                $crate::task::unwind_local_spawn_panics().await;
                res
            }).await
    }}
}
59
/// Spawns a new OS thread that creates its own Tokio runtime and runs the given body with
/// `run_async!` on it.
///
/// The behavior is similar to `run_async! {}`, with the only difference being that it runs
/// in a separate thread. The macro evaluates to the thread's `std::thread::JoinHandle<()>`;
/// the body's value is discarded.
///
/// # Panics
/// The spawned thread panics if the Tokio runtime cannot be created, or if `run_async!`
/// re-raises a panic from a background task.
#[macro_export]
macro_rules! run_async_thread {
    ($($body:tt)*) => {{
        // Fully-qualified paths: callers need neither `use std::thread` nor `use run_async`.
        ::std::thread::spawn(move || {
            // tokio runtime setup
            let runtime = tokio::runtime::Runtime::new().unwrap();

            // Run an async block using the runtime
            runtime.block_on(async {
                $crate::run_async! {
                    $($body)*
                }
            });
        })
    }}
}
79
80pub fn init_panic_notify() {
81    let (tx, rx) = mpsc::unbounded::<Signal>();
82    LOCAL_PANIC_CHANNEL
83        .try_with(|channel| {
84            let mut channel = channel.borrow_mut();
85            if channel.is_none() {
86                *channel = Some((Some(RefCell::new(tx)), Some(rx)));
87            } else {
88                panic!("Panic channel already initialized");
89            }
90        })
91        .expect("Failed to initialize panic channel");
92}
93
94#[allow(clippy::await_holding_refcell_ref)]
95pub async fn close_panic_notify() {
96    LOCAL_PANIC_CHANNEL
97        .with(|channel| {
98            let channel = channel.clone();
99            let mut channel = channel.borrow_mut();
100            if let Some((tx, _)) = &mut *channel {
101                tx.take()
102            } else {
103                panic!("Panic channel not initialized");
104            }
105        })
106        .expect("Failed to access panic channel")
107        .clone()
108        .borrow_mut()
109        .send(Signal::Exit)
110        .await
111        .expect("Failed to send exit signal");
112}
113
114pub async fn unwind_local_spawn_panics() {
115    let mut rx = LOCAL_PANIC_CHANNEL
116        .with(|channel| {
117            let channel = channel.clone();
118            let mut channel = channel.borrow_mut();
119            if let Some((_, rx)) = &mut *channel {
120                rx.take()
121            } else {
122                panic!("Panic channel not initialized");
123            }
124        })
125        .expect("Failed to access panic channel");
126    info!("Waiting for local spawn panics...");
127    if let Some(panic_msg) = rx.next().await {
128        match panic_msg {
129            Signal::Exit => {}
130            Signal::Panic(panic_msg) => {
131                panic!("Panic in local spawn: {panic_msg}");
132            }
133        }
134    }
135}
136
137#[allow(clippy::await_holding_refcell_ref)]
138async fn send_panic(panic: String) {
139    LOCAL_PANIC_CHANNEL
140        .try_with(|channel| {
141            let channel = channel.clone();
142            let channel = channel.borrow_mut();
143            if let Some((tx, _)) = &*channel {
144                tx.clone().expect("Panic channel not initialized")
145            } else {
146                panic!("Panic channel not initialized");
147            }
148        })
149        .expect("Failed to access panic channel")
150        .borrow_mut()
151        .send(Signal::Panic(panic))
152        .await
153        .expect("Failed to send panic");
154}
155
156pub fn spawn_with_panic_notify<F>(fut: F)
157where
158    F: Future<Output = ()> + 'static,
159{
160    spawn_local(async {
161        let result = std::panic::AssertUnwindSafe(fut).catch_unwind().await;
162        if let Err(err) = result {
163            let panic_msg = if let Some(s) = err.downcast_ref::<&str>() {
164                s.to_string()
165            } else if let Some(s) = err.downcast_ref::<String>() {
166                s.clone()
167            } else {
168                "Unknown panic type".to_string()
169            };
170            send_panic(panic_msg).await;
171        }
172    });
173}
174
175cfg_if! {
176    if #[cfg(feature = "tokio_runtime")] {
177        pub fn timeout<F>(duration: std::time::Duration, fut: F) -> tokio::time::Timeout<F::IntoFuture>
178        where
179            F: std::future::IntoFuture,
180        {
181            tokio::time::timeout(duration, fut)
182        }
183
184        pub fn spawn_local<F>(fut: F)-> tokio::task::JoinHandle<()>
185        where
186            F: std::future::Future<Output = ()> + 'static,
187        {
188            tokio::task::spawn_local(fut)
189        }
190        pub fn spawn<F>(fut: F) -> tokio::task::JoinHandle<F::Output>
191        where
192            F: Future<Output = ()> + Send + 'static,
193        {
194            tokio::spawn(fut)
195        }
196        pub fn spawn_blocking<F, R>(f: F) -> tokio::task::JoinHandle<R>
197        where
198            F: FnOnce() -> R + Send + 'static,
199            R: Send + 'static,
200        {
201            tokio::task::spawn_blocking(f)
202        }
203        pub async fn sleep(dur: std::time::Duration) {
204            tokio::time::sleep(dur).await;
205        }
206
207    } else if #[cfg(feature = "wasm_runtime")] {
208        use futures::future;
209
210        pub async fn timeout<F, T>(
211            duration: std::time::Duration,
212            fut: F,
213        ) -> Result<T, &'static str>
214        where
215            F: std::future::Future<Output = T>,
216        {
217            let timeout_fut = sleep(duration);
218            futures::pin_mut!(fut);
219            futures::pin_mut!(timeout_fut);
220
221            match future::select(fut, timeout_fut).await {
222                future::Either::Left((res, _)) => Ok(res),
223                future::Either::Right(_) => Err("timed out"),
224            }
225        }
226        pub async fn sleep(dur: std::time::Duration) {
227            gloo_timers::future::sleep(dur).await;
228        }
229
230        pub fn spawn_local<F>(fut: F)
231        where
232            F: std::future::Future<Output = ()> + 'static,
233        {
234            wasm_bindgen_futures::spawn_local(fut);
235        }
236        pub fn spawn<F>(fut: F)
237        where
238            F: std::future::Future<Output = ()> + 'static,
239        {
240            wasm_bindgen_futures::spawn_local(fut);
241        }
242        pub fn spawn_blocking<F>(_fut: F) -> !
243        where
244            F: std::future::Future + 'static,
245        {
246            panic!("`spawn_blocking` is not supported in the wasm runtime.");
247        }
248    } else {
249        compile_error!("Unsupported runtime. Please enable either 'tokio_runtime' or 'wasm_runtime' feature.");
250    }
251}