1#![cfg_attr(target_family = "wasm", allow(dead_code))]
2
3mod inner;
4
5pub mod jit;
7pub mod waiter;
8
9use std::future::Future;
10use std::pin::{Pin, pin};
11use std::sync::Arc;
12use std::time::SystemTime;
13
14use anyhow::bail;
15use fedimint_core::time::now;
16use fedimint_logging::{LOG_TASK, LOG_TEST};
17use futures::future::{self, Either};
18use inner::TaskGroupInner;
19use scopeguard::defer;
20use thiserror::Error;
21use tokio::sync::{oneshot, watch};
22use tracing::{debug, error, info, trace};
23
24use crate::runtime;
25pub use crate::runtime::*;
28#[derive(Clone, Default, Debug)]
39pub struct TaskGroup {
40    inner: Arc<TaskGroupInner>,
41}
42
43impl TaskGroup {
44    pub fn new() -> Self {
45        Self::default()
46    }
47
48    pub fn make_handle(&self) -> TaskHandle {
49        TaskHandle {
50            inner: self.inner.clone(),
51        }
52    }
53
54    pub fn make_subgroup(&self) -> Self {
67        let new_tg = Self::new();
68        self.inner.add_subgroup(new_tg.clone());
69        new_tg
70    }
71
72    pub fn is_shutting_down(&self) -> bool {
74        self.inner.is_shutting_down()
75    }
76
77    pub fn shutdown(&self) {
80        self.inner.shutdown();
81    }
82
83    pub async fn shutdown_join_all(
85        self,
86        join_timeout: impl Into<Option<Duration>>,
87    ) -> Result<(), anyhow::Error> {
88        self.shutdown();
89        self.join_all(join_timeout.into()).await
90    }
91
92    #[cfg(not(target_family = "wasm"))]
95    pub fn install_kill_handler(&self) {
96        async fn wait_for_shutdown_signal() {
98            use tokio::signal;
99
100            let ctrl_c = async {
101                signal::ctrl_c()
102                    .await
103                    .expect("failed to install Ctrl+C handler");
104            };
105
106            #[cfg(unix)]
107            let terminate = async {
108                signal::unix::signal(signal::unix::SignalKind::terminate())
109                    .expect("failed to install signal handler")
110                    .recv()
111                    .await;
112            };
113
114            #[cfg(not(unix))]
115            let terminate = std::future::pending::<()>();
116
117            tokio::select! {
118                () = ctrl_c => {},
119                () = terminate => {},
120            }
121        }
122
123        runtime::spawn("kill handlers", {
124            let task_group = self.clone();
125            async move {
126                wait_for_shutdown_signal().await;
127                info!(
128                    target: LOG_TASK,
129                    "signal received, starting graceful shutdown"
130                );
131                task_group.shutdown();
132            }
133        });
134    }
135
136    pub fn spawn<Fut, R>(
137        &self,
138        name: impl Into<String>,
139        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
140    ) -> oneshot::Receiver<R>
141    where
142        Fut: Future<Output = R> + MaybeSend + 'static,
143        R: MaybeSend + 'static,
144    {
145        self.spawn_inner(name, f, false)
146    }
147
148    pub fn spawn_silent<Fut, R>(
152        &self,
153        name: impl Into<String>,
154        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
155    ) -> oneshot::Receiver<R>
156    where
157        Fut: Future<Output = R> + MaybeSend + 'static,
158        R: MaybeSend + 'static,
159    {
160        self.spawn_inner(name, f, true)
161    }
162
163    fn spawn_inner<Fut, R>(
164        &self,
165        name: impl Into<String>,
166        f: impl FnOnce(TaskHandle) -> Fut + MaybeSend + 'static,
167        quiet: bool,
168    ) -> oneshot::Receiver<R>
169    where
170        Fut: Future<Output = R> + MaybeSend + 'static,
171        R: MaybeSend + 'static,
172    {
173        let name = name.into();
174        let mut guard = TaskPanicGuard {
175            name: name.clone(),
176            inner: self.inner.clone(),
177            completed: false,
178        };
179        let handle = self.make_handle();
180
181        let (tx, rx) = oneshot::channel();
182        self.inner
183            .active_tasks_join_handles
184            .lock()
185            .expect("Locking failed")
186            .insert_with_key(move |task_key| {
187                (
188                    name.clone(),
189                    crate::runtime::spawn(&name, {
190                        let name = name.clone();
191                        async move {
192                            defer! {
193                                if handle
198                                    .inner
199                                    .active_tasks_join_handles
200                                    .lock()
201                                    .expect("Locking failed")
202                                    .remove(task_key)
203                                    .is_none() {
204                                        trace!(target: LOG_TASK, %name, "Task already canceled");
205                                    }
206                            }
207                            if quiet {
209                                trace!(target: LOG_TASK, %name, "Starting task");
210                            } else {
211                                debug!(target: LOG_TASK, %name, "Starting task");
212                            }
213                            let r = f(handle.clone()).await;
214                            guard.completed = true;
215
216                            if quiet {
217                                trace!(target: LOG_TASK, %name, "Finished task");
218                            } else {
219                                debug!(target: LOG_TASK, %name, "Finished task");
220                            }
221                            let _ = tx.send(r);
223
224                            drop(guard);
227                        }
228                    }),
229                )
230            });
231
232        rx
233    }
234
235    pub fn spawn_cancellable<R>(
238        &self,
239        name: impl Into<String>,
240        future: impl Future<Output = R> + MaybeSend + 'static,
241    ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
242    where
243        R: MaybeSend + 'static,
244    {
245        self.spawn(name, |handle| async move {
246            let value = handle.cancel_on_shutdown(future).await;
247            if value.is_err() {
248                debug!(target: LOG_TASK, "task cancelled on shutdown");
250            }
251            value
252        })
253    }
254
255    pub fn spawn_cancellable_silent<R>(
256        &self,
257        name: impl Into<String>,
258        future: impl Future<Output = R> + MaybeSend + 'static,
259    ) -> oneshot::Receiver<Result<R, ShuttingDownError>>
260    where
261        R: MaybeSend + 'static,
262    {
263        self.spawn_silent(name, |handle| async move {
264            let value = handle.cancel_on_shutdown(future).await;
265            if value.is_err() {
266                debug!(target: LOG_TASK, "task cancelled on shutdown");
268            }
269            value
270        })
271    }
272
273    pub async fn join_all(self, timeout: Option<Duration>) -> Result<(), anyhow::Error> {
274        let deadline = timeout.map(|timeout| now() + timeout);
275        let mut errors = vec![];
276
277        self.join_all_inner(deadline, &mut errors).await;
278
279        if errors.is_empty() {
280            Ok(())
281        } else {
282            let num_errors = errors.len();
283            bail!("{num_errors} tasks did not finish cleanly: {errors:?}")
284        }
285    }
286
287    #[cfg_attr(not(target_family = "wasm"), ::async_recursion::async_recursion)]
288    #[cfg_attr(target_family = "wasm", ::async_recursion::async_recursion(?Send))]
289    pub async fn join_all_inner(self, deadline: Option<SystemTime>, errors: &mut Vec<JoinError>) {
290        self.inner.join_all(deadline, errors).await;
291    }
292}
293
294struct TaskPanicGuard {
295    name: String,
296    inner: Arc<TaskGroupInner>,
297    completed: bool,
299}
300
301impl Drop for TaskPanicGuard {
302    fn drop(&mut self) {
303        trace!(
304            target: LOG_TASK,
305            name = %self.name,
306            "Task drop"
307        );
308        if !self.completed {
309            info!(
310                target: LOG_TASK,
311                name = %self.name,
312                "Task shut down uncleanly"
313            );
314            self.inner.shutdown();
315        }
316    }
317}
318
319#[derive(Clone, Debug)]
320pub struct TaskHandle {
321    inner: Arc<TaskGroupInner>,
322}
323
324#[derive(thiserror::Error, Debug, Clone)]
325#[error("Task group is shutting down")]
326#[non_exhaustive]
327pub struct ShuttingDownError {}
328
329impl TaskHandle {
330    pub fn is_shutting_down(&self) -> bool {
334        self.inner.is_shutting_down()
335    }
336
337    pub fn make_shutdown_rx(&self) -> TaskShutdownToken {
342        self.inner.make_shutdown_rx()
343    }
344
345    pub async fn cancel_on_shutdown<F: Future>(
347        &self,
348        fut: F,
349    ) -> Result<F::Output, ShuttingDownError> {
350        let rx = self.make_shutdown_rx();
351        match future::select(pin!(rx), pin!(fut)).await {
352            Either::Left(((), _)) => Err(ShuttingDownError {}),
353            Either::Right((value, _)) => Ok(value),
354        }
355    }
356}
357
/// Type-erased future that resolves when a task group signals shutdown.
///
/// Built by `TaskShutdownToken::new` from a `watch::Receiver<bool>`; boxing
/// hides the concrete async-block type behind a nameable struct.
pub struct TaskShutdownToken(Pin<Box<dyn Future<Output = ()> + Send>>);
359
360impl TaskShutdownToken {
361    fn new(mut rx: watch::Receiver<bool>) -> Self {
362        Self(Box::pin(async move {
363            let _ = rx.wait_for(|v| *v).await;
364        }))
365    }
366}
367
368impl Future for TaskShutdownToken {
369    type Output = ();
370
371    fn poll(
372        mut self: Pin<&mut Self>,
373        cx: &mut std::task::Context<'_>,
374    ) -> std::task::Poll<Self::Output> {
375        self.0.as_mut().poll(cx)
376    }
377}
378
/// Applies `#[async_trait::async_trait]` to the wrapped item, using the
/// `?Send` variant when compiling for wasm.
#[macro_export]
macro_rules! async_trait_maybe_send {
    ($($item:tt)*) => {
        #[cfg_attr(target_family = "wasm", ::async_trait::async_trait(?Send))]
        #[cfg_attr(not(target_family = "wasm"), ::async_trait::async_trait)]
        $($item)*
    };
}
403
/// Appends `+ Send` to the given bounds (non-wasm variant).
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send {
    ($($bounds:tt)*) => {
        $($bounds)* + Send
    };
}

/// Passes the given bounds through unchanged (wasm variant).
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send {
    ($($bounds:tt)*) => {
        $($bounds)*
    };
}
436
/// Appends `+ Send + Sync` to the given bounds (non-wasm variant).
#[cfg(not(target_family = "wasm"))]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($bounds:tt)*) => {
        $($bounds)* + Send + Sync
    };
}

/// Passes the given bounds through unchanged (wasm variant).
#[cfg(target_family = "wasm")]
#[macro_export]
macro_rules! maybe_add_send_sync {
    ($($bounds:tt)*) => {
        $($bounds)*
    };
}
454
/// `Send` under another name on non-wasm targets; blanket-implemented for
/// every `Send` type.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSend: Send {}

#[cfg(not(target_family = "wasm"))]
impl<T> MaybeSend for T where T: Send {}

/// On wasm this bound imposes nothing: every type implements it.
#[cfg(target_family = "wasm")]
pub trait MaybeSend {}

#[cfg(target_family = "wasm")]
impl<T> MaybeSend for T {}
474
/// `Sync` under another name on non-wasm targets; blanket-implemented for
/// every `Sync` type.
#[cfg(not(target_family = "wasm"))]
pub trait MaybeSync: Sync {}

#[cfg(not(target_family = "wasm"))]
impl<T> MaybeSync for T where T: Sync {}

/// On wasm this bound imposes nothing: every type implements it.
#[cfg(target_family = "wasm")]
pub trait MaybeSync {}

#[cfg(target_family = "wasm")]
impl<T> MaybeSync for T {}
488
489pub async fn sleep_in_test(comment: impl AsRef<str>, duration: Duration) {
492    info!(
493        target: LOG_TEST,
494        "Sleeping for {}.{:03} seconds because: {}",
495        duration.as_secs(),
496        duration.subsec_millis(),
497        comment.as_ref()
498    );
499    sleep(duration).await;
500}
501
502#[derive(Error, Debug)]
504#[error("Operation cancelled")]
505pub struct Cancelled;
506
507pub type Cancellable<T> = std::result::Result<T, Cancelled>;
510
511#[cfg(test)]
512mod tests;