datex_core/
task.rs

1use crate::stdlib::string::String;
2use crate::stdlib::string::ToString;
3use cfg_if::cfg_if;
4use core::cell::RefCell;
5use core::clone::Clone;
6use core::future::Future;
7use core::prelude::rust_2024::*;
8use futures_util::{FutureExt, SinkExt, StreamExt};
9use log::info;
10
/// Storage shape for the panic-notification channel.
///
/// The outer `Option` is `None` until `init_panic_notify` has stored a channel.
/// The first tuple element is the sending half (taken out by `close_panic_notify`),
/// the second is the receiving half (taken out by `unwind_local_spawn_panics`);
/// each becomes `None` once it has been taken.
type LocalPanicChannel = Option<(
    Option<RefCell<UnboundedSender<Signal>>>,
    Option<UnboundedReceiver<Signal>>,
)>;
15
// Holds the panic-notification channel for the current execution context.
// Thread-local on every build except `embassy_runtime`, where it is a plain global static.
// Being a `static mut`, every access in this file happens inside an `unsafe` block.
#[cfg_attr(not(feature = "embassy_runtime"), thread_local)]
static mut LOCAL_PANIC_CHANNEL: LocalPanicChannel = None;
18
/// Messages carried over the panic-notification channel.
enum Signal {
    /// A background task panicked; carries the extracted panic message.
    Panic(String),
    /// Sent by `close_panic_notify` after the main body has finished.
    Exit,
}
23
/// Creates an async execution context in which `spawn_local` or `spawn_with_panic_notify` can be used.
/// When a panic occurs in a background task spawned with `spawn_with_panic_notify`, the panic is
/// reported to the main task and re-raised there once the body has finished.
///
/// The macro expands to an expression that must be `.await`ed by the caller. The body runs
/// inside a `tokio::task::LocalSet`; `tokio` is referenced by path, so the calling crate
/// needs it as a dependency. The value of the body is the value of the macro expression.
///
/// # Panics
/// - "Panic channel already initialized" if a panic channel is already stored for this context.
/// - "Panic in local spawn: ..." if a background task reported a panic before the body finished.
///
/// Example usage:
/// ```rust
/// use datex_core::run_async;
/// use datex_core::task::{spawn_with_panic_notify_default};
///
/// async fn example() {
///     run_async! {
///         tokio::time::sleep(core::time::Duration::from_secs(1)).await;
///         spawn_with_panic_notify_default(async {
///             // Simulate a panic
///             core::panic!("This is a test panic");
///         });
///     }
/// }
/// ```
#[macro_export]
macro_rules! run_async {
    ($($body:tt)*) => {{
        // `$crate` keeps the expansion valid even when a dependent renames this crate.
        $crate::task::init_panic_notify();

        tokio::task::LocalSet::new()
            .run_until(async move {
                let res = (async move { $($body)* }).await;
                // Queue the exit marker first so the unwind step below always has a message to read.
                $crate::task::close_panic_notify().await;
                $crate::task::unwind_local_spawn_panics().await;
                res
            }).await
    }}
}
57
/// Spawns a thread that runs an async block using the Tokio runtime.
/// The behavior is similar to `run_async! {}`, with the only difference being that
/// it runs in a separate thread.
///
/// Expands to the `JoinHandle` returned by `std::thread::spawn`; the body's value is discarded.
/// `tokio` is referenced by path, so the calling crate needs it as a dependency.
///
/// # Panics
/// The spawned thread panics if the Tokio runtime cannot be created, or for any of the
/// reasons `run_async!` panics.
#[macro_export]
macro_rules! run_async_thread {
    ($($body:tt)*) => {{
        // Fully qualified so callers need neither `use std::thread` nor `run_async` in scope.
        ::std::thread::spawn(move || {
            // tokio runtime setup
            let runtime = tokio::runtime::Runtime::new().unwrap();

            // Run an async block using the runtime
            runtime.block_on(async {
                $crate::run_async! {
                    $($body)*
                }
            });
        })
    }}
}
77
78pub fn init_panic_notify() {
79    let (tx, rx) = create_unbounded_channel::<Signal>();
80    unsafe {
81        let channel = &mut LOCAL_PANIC_CHANNEL;
82        if channel.is_none() {
83            *channel = Some((Some(RefCell::new(tx)), Some(rx)));
84        } else {
85            core::panic!("Panic channel already initialized");
86        }
87    }
88}
89
90#[allow(clippy::await_holding_refcell_ref)]
91pub async fn close_panic_notify() {
92    unsafe {
93        if let Some((tx, _)) = &mut LOCAL_PANIC_CHANNEL {
94            tx.take()
95                .clone()
96                .unwrap()
97                .borrow_mut()
98                .send(Signal::Exit)
99                .await
100                .expect("Failed to send exit signal");
101        } else {
102            core::panic!("Panic channel not initialized");
103        }
104    }
105}
106
107pub async fn unwind_local_spawn_panics() {
108    unsafe {
109        if let Some((_, rx)) = &mut LOCAL_PANIC_CHANNEL {
110            let mut rx = rx.take().unwrap();
111            info!("Waiting for local spawn panics...");
112            if let Some(panic_msg) = rx.next().await {
113                match panic_msg {
114                    Signal::Exit => {}
115                    Signal::Panic(panic_msg) => {
116                        core::panic!("Panic in local spawn: {panic_msg}");
117                    }
118                }
119            }
120        } else {
121            core::panic!("Panic channel not initialized");
122        }
123    }
124}
125
126#[allow(clippy::await_holding_refcell_ref)]
127async fn send_panic(panic: String) {
128    unsafe {
129        if let Some((tx, _)) = &LOCAL_PANIC_CHANNEL {
130            tx.clone()
131                .expect("Panic channel not initialized")
132                .borrow_mut()
133                .send(Signal::Panic(panic))
134                .await
135                .expect("Failed to send panic");
136        } else {
137            core::panic!("Panic channel not initialized");
138        }
139    }
140}
/// Embassy variant: hands `spawn_token` to the spawner stored in `async_context`.
///
/// Unlike the tokio/wasm variant, this function does not forward task panics
/// to the panic channel.
///
/// # Panics
/// Panics with "Spawn Error" if the spawner returns an error.
#[cfg(feature = "embassy_runtime")]
pub fn spawn_with_panic_notify<S>(
    async_context: &AsyncContext,
    spawn_token: embassy_executor::SpawnToken<S>,
) {
    let spawn_result = async_context.spawner.spawn(spawn_token);
    spawn_result.expect("Spawn Error");
}
151
/// Tokio/wasm variant: spawns `fut` via `spawn_with_panic_notify_default`.
///
/// `_async_context` is accepted only for signature parity with the embassy variant
/// and is not used.
#[cfg(any(feature = "tokio_runtime", feature = "wasm_runtime"))]
pub fn spawn_with_panic_notify<F>(_async_context: &AsyncContext, fut: F)
where
    F: Future<Output = ()> + 'static,
{
    spawn_with_panic_notify_default(fut)
}
159
/// Spawns `fut` as a local task and reports a panic inside it over the panic channel.
///
/// The future is wrapped in `catch_unwind`. If it panics, the payload is turned into a
/// message (`&str` and `String` payloads are used verbatim, anything else becomes
/// "Unknown panic type") and passed to `send_panic`.
#[cfg(any(feature = "tokio_runtime", feature = "wasm_runtime"))]
pub fn spawn_with_panic_notify_default<F>(fut: F)
where
    F: Future<Output = ()> + 'static,
{
    let guarded = async move {
        let Err(payload) = core::panic::AssertUnwindSafe(fut).catch_unwind().await else {
            // Task finished normally: nothing to report.
            return;
        };
        let panic_msg = payload
            .downcast_ref::<&str>()
            .map(|text| text.to_string())
            .or_else(|| payload.downcast_ref::<String>().cloned())
            .unwrap_or_else(|| "Unknown panic type".to_string());
        send_panic(panic_msg).await;
    };
    spawn_local(guarded);
}
179
180cfg_if! {
181    if #[cfg(feature = "tokio_runtime")] {
182        pub async fn timeout<T>(
183            duration: core::time::Duration,
184            fut: impl Future<Output = T>,
185        ) -> Result<T, ()> {
186            tokio::time::timeout(duration, fut)
187                .await
188                .map_err(|_| ())
189        }
190
191        pub fn spawn_local<F>(fut: F)-> tokio::task::JoinHandle<()>
192        where
193            F: Future<Output = ()> + 'static,
194        {
195            tokio::task::spawn_local(fut)
196        }
197        pub fn spawn<F>(fut: F) -> tokio::task::JoinHandle<F::Output>
198        where
199            F: Future<Output = ()> + Send + 'static,
200        {
201            tokio::spawn(fut)
202        }
203        pub fn spawn_blocking<F, R>(f: F) -> tokio::task::JoinHandle<R>
204        where
205            F: FnOnce() -> R + Send + 'static,
206            R: Send + 'static,
207        {
208            tokio::task::spawn_blocking(f)
209        }
210        pub async fn sleep(dur: core::time::Duration) {
211            tokio::time::sleep(dur).await;
212        }
213
214    }
215
216    else if #[cfg(feature = "wasm_runtime")] {
217        use futures::future;
218
219        pub async fn timeout<T>(
220            duration: core::time::Duration,
221            fut: impl Future<Output = T>,
222        ) -> Result<T, ()> {
223            let timeout_fut = sleep(duration);
224            futures::pin_mut!(fut);
225            futures::pin_mut!(timeout_fut);
226
227            match future::select(fut, timeout_fut).await {
228                future::Either::Left((res, _)) => Ok(res),
229                future::Either::Right(_) => Err(()),
230            }
231        }
232        pub async fn sleep(dur: core::time::Duration) {
233            gloo_timers::future::sleep(dur).await;
234        }
235
236        pub fn spawn_local<F>(fut: F)
237        where
238            F: core::future::Future<Output = ()> + 'static,
239        {
240            wasm_bindgen_futures::spawn_local(fut);
241        }
242        pub fn spawn<F>(fut: F)
243        where
244            F: core::future::Future<Output = ()> + 'static,
245        {
246            wasm_bindgen_futures::spawn_local(fut);
247        }
248        pub fn spawn_blocking<F>(_fut: F) -> !
249        where
250            F: core::future::Future + 'static,
251        {
252            core::panic!("`spawn_blocking` is not supported in the wasm runtime.");
253        }
254    }
255
256    else if #[cfg(feature = "embassy_runtime")] {
257        use embassy_time::{Duration, Timer};
258        use embassy_futures::select::select;
259        use embassy_futures::select::Either;
260
261        pub async fn sleep(dur: core::time::Duration) {
262            let emb_dur = Duration::from_millis(dur.as_millis() as u64);
263            Timer::after(emb_dur).await;
264        }
265
266        pub async fn timeout<T>(
267            duration: core::time::Duration,
268            fut: impl Future<Output = T>,
269        ) -> Result<T, ()> {
270            let emb_dur = Duration::from_millis(duration.as_millis() as u64);
271            let timeout = Timer::after(emb_dur);
272
273            match select(fut, timeout).await {
274                Either::First(t) => Ok(t),
275                Either::Second(_) => Err(()),
276            }
277        }
278
279    }
280    else {
281        compile_error!("Unsupported runtime. Please enable either 'tokio_runtime', 'embassy_runtime' or 'wasm_runtime' feature.");
282    }
283}
284
285#[cfg(feature = "embassy_runtime")]
286pub use async_unsync::bounded::{Receiver as _Receiver, Sender as _Sender};
287#[cfg(feature = "embassy_runtime")]
288pub use async_unsync::unbounded::{
289    UnboundedReceiver as _UnboundedReceiver,
290    UnboundedSender as _UnboundedSender,
291};
292use datex_core::runtime::AsyncContext;
293#[cfg(any(feature = "tokio_runtime", feature = "wasm_runtime"))]
294use futures::channel::mpsc::{
295    Receiver as _Receiver, Sender as _Sender,
296    UnboundedReceiver as _UnboundedReceiver,
297    UnboundedSender as _UnboundedSender,
298};
299
300#[derive(Debug)]
301pub struct Receiver<T>(_Receiver<T>);
302impl<T> Receiver<T> {
303    pub fn new(receiver: _Receiver<T>) -> Self {
304        Receiver(receiver)
305    }
306
307    pub async fn next(&mut self) -> Option<T> {
308        #[cfg(any(feature = "tokio_runtime", feature = "wasm_runtime"))]
309        {
310            self.0.next().await
311        }
312        #[cfg(feature = "embassy_runtime")]
313        {
314            self.0.recv().await
315        }
316    }
317}
318
319#[derive(Debug)]
320pub struct UnboundedReceiver<T>(_UnboundedReceiver<T>);
321impl<T> UnboundedReceiver<T> {
322    pub fn new(receiver: _UnboundedReceiver<T>) -> Self {
323        UnboundedReceiver(receiver)
324    }
325    pub async fn next(&mut self) -> Option<T> {
326        #[cfg(any(feature = "tokio_runtime", feature = "wasm_runtime"))]
327        {
328            self.0.next().await
329        }
330        #[cfg(feature = "embassy_runtime")]
331        {
332            self.0.recv().await
333        }
334    }
335}
336
337#[derive(Debug)]
338pub struct Sender<T>(_Sender<T>);
339
340impl<T> Clone for Sender<T> {
341    fn clone(&self) -> Self {
342        Sender(self.0.clone())
343    }
344}
345impl<T> Sender<T> {
346    pub fn new(sender: _Sender<T>) -> Self {
347        Sender(sender)
348    }
349
350    pub fn start_send(&mut self, item: T) -> Result<(), ()> {
351        #[cfg(any(feature = "tokio_runtime", feature = "wasm_runtime"))]
352        {
353            self.0.start_send(item).map_err(|_| ())
354        }
355        #[cfg(feature = "embassy_runtime")]
356        {
357            self.0.try_send(item).map_err(|_| ())
358        }
359    }
360
361    pub async fn send(&mut self, item: T) -> Result<(), ()> {
362        #[cfg(any(feature = "tokio_runtime", feature = "wasm_runtime"))]
363        {
364            self.0.send(item).await.map_err(|_| ()).map(|_| ())
365        }
366        #[cfg(feature = "embassy_runtime")]
367        {
368            self.0.send(item).await.map(|_| ()).map_err(|_| ())
369        }
370    }
371
372    pub fn close_channel(&mut self) {
373        #[cfg(any(feature = "tokio_runtime", feature = "wasm_runtime"))]
374        {
375            self.0.close_channel();
376        }
377        #[cfg(feature = "embassy_runtime")]
378        {}
379    }
380}
381
382#[derive(Debug)]
383pub struct UnboundedSender<T>(_UnboundedSender<T>);
384
385// FIXME #603: derive Clone?
386impl<T> Clone for UnboundedSender<T> {
387    fn clone(&self) -> Self {
388        UnboundedSender(self.0.clone())
389    }
390}
391
392impl<T> UnboundedSender<T> {
393    pub fn new(sender: _UnboundedSender<T>) -> Self {
394        UnboundedSender(sender)
395    }
396
397    pub fn start_send(&mut self, item: T) -> Result<(), ()> {
398        #[cfg(any(feature = "tokio_runtime", feature = "wasm_runtime"))]
399        {
400            self.0.start_send(item).map_err(|_| ())
401        }
402        #[cfg(feature = "embassy_runtime")]
403        {
404            self.0.send(item).map_err(|_| ())
405        }
406    }
407
408    pub async fn send(&mut self, item: T) -> Result<(), ()> {
409        #[cfg(any(feature = "tokio_runtime", feature = "wasm_runtime"))]
410        {
411            self.0.send(item).await.map_err(|_| ()).map(|_| ())
412        }
413        #[cfg(feature = "embassy_runtime")]
414        {
415            self.0.send(item).map(|_| ()).map_err(|_| ())
416        }
417    }
418
419    pub fn close_channel(&self) {
420        #[cfg(any(feature = "tokio_runtime", feature = "wasm_runtime"))]
421        {
422            self.0.close_channel();
423        }
424        #[cfg(feature = "embassy_runtime")]
425        {}
426    }
427}
428
429cfg_if! {
430    if #[cfg(any(feature = "tokio_runtime", feature = "wasm_runtime"))] {
431        pub fn create_bounded_channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
432            let (sender, receiver) = futures::channel::mpsc::channel::<T>(capacity);
433            (Sender::new(sender), Receiver::new(receiver))
434        }
435        pub fn create_unbounded_channel<T>() -> (UnboundedSender<T>, UnboundedReceiver<T>) {
436            let (sender, receiver) = futures::channel::mpsc::unbounded::<T>();
437            (UnboundedSender::new(sender), UnboundedReceiver::new(receiver))
438        }
439    }
440    else if #[cfg(feature = "embassy_runtime")] {
441        pub fn create_bounded_channel<T>(capacity: usize) -> (Sender<T>, Receiver<T>) {
442            let (sender, receiver) = async_unsync::bounded::channel::<T>(capacity).into_split();
443            (Sender::new(sender), Receiver::new(receiver))
444        }
445         pub fn create_unbounded_channel<T>() -> (UnboundedSender<T>, UnboundedReceiver<T>) {
446            let (sender, receiver) = async_unsync::unbounded::channel::<T>().into_split();
447            (UnboundedSender::new(sender), UnboundedReceiver::new(receiver))
448        }
449    }
450    else {
451        compile_error!("Unsupported runtime. Please enable either 'tokio_runtime', 'embassy_runtime' or 'wasm_runtime' feature.");
452    }
453}