1use crate::{RunToken, scope_guard::scope_guard};
3use futures_util::{
4    Future, FutureExt,
5    future::{self},
6    pin_mut,
7};
8use log::{debug, error, info};
9use std::{
10    borrow::Cow,
11    sync::{
12        Arc,
13        atomic::{AtomicUsize, Ordering},
14    },
15};
16use std::{collections::HashMap, sync::atomic::AtomicBool};
17use std::{fmt::Display, sync::Mutex};
18use std::{pin::Pin, task::Poll};
19use tokio::{
20    sync::Notify,
21    task::{JoinError, JoinHandle},
22};
23
24#[cfg(feature = "ordered-locks")]
25use ordered_locks::{CleanLockToken, L0, LockToken};
26
/// Registry of live tasks keyed by task id. The map is wrapped in `Option` because
/// `HashMap::new` is not `const`; every access initialises it via `get_or_insert_default`.
static TASKS: Mutex<Option<HashMap<usize, Arc<dyn TaskBase>>>> = Mutex::new(None);
/// Signalled by the background shutdown task once shutdown has finished; awaited by `run_tasks`.
static SHUTDOWN_NOTIFY: Notify = Notify::const_new();
/// Source of unique task ids, incremented by `TaskBuilder::new`.
static TASK_ID_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Set by the first call to `shutdown`; later calls observe it and return `false`.
static SHUTTING_DOWN: AtomicBool = AtomicBool::new(false);
35
/// Error produced by [`cancelable`] when the run token is cancelled before the
/// wrapped future completes. Displays as `CancelledError`.
#[derive(Debug)]
pub struct CancelledError {}

impl Display for CancelledError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CancelledError")
    }
}

impl std::error::Error for CancelledError {}
45
/// Heap-allocated, pinned, `Send` future with output `T` that may borrow for `'a`
/// (the type returned by [`TaskBase::cancel`]).
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
48
49pub async fn cancelable<T, F: Future<Output = T>>(
51    run_token: &RunToken,
52    fut: F,
53) -> Result<T, CancelledError> {
54    let c = run_token.cancelled();
55    pin_mut!(fut, c);
56    let f = future::select(c, fut).await;
57    match f {
58        future::Either::Right((v, _)) => Ok(v),
59        future::Either::Left(_) => Err(CancelledError {}),
60    }
61}
62
/// Lock-order-checked variant of [`cancelable`].
///
/// Identical to `cancelable`, except that the cancellation future is obtained with
/// `RunToken::cancelled_checked`, which consumes a level-`L0` lock token.
#[cfg(feature = "ordered-locks")]
pub async fn cancelable_checked<T, F: Future<Output = T>>(
    run_token: &RunToken,
    lock_token: LockToken<'_, L0>,
    fut: F,
) -> Result<T, CancelledError> {
    let cancelled = run_token.cancelled_checked(lock_token);
    pin_mut!(fut, cancelled);
    let winner = future::select(cancelled, fut).await;
    match winner {
        future::Either::Left(_) => Err(CancelledError {}),
        future::Either::Right((output, _)) => Ok(output),
    }
}
78
/// How a task finished; passed to [`TaskBase::_internal_handle_finished`] after the
/// task has been removed from the registry.
#[doc(hidden)]
#[derive(Debug)]
pub enum FinishState<'a> {
    /// The task's future returned `Ok`.
    Success,
    /// The task's future was dropped before completing (the guard installed by
    /// `TaskBuilder::create` fired).
    Drop,
    /// `Task::cancel` aborted a task built with `abort` and found it still registered.
    Abort,
    /// Awaiting the join handle in `Task::cancel` returned this error.
    JoinError(JoinError),
    /// The task's future returned `Err`; borrows that error value.
    Failure(&'a (dyn std::fmt::Debug + Sync + Send)),
}
88
/// Describes a task before it is spawned with [`TaskBuilder::create`].
pub struct TaskBuilder {
    /// Unique id drawn from `TASK_ID_COUNT` in `new`.
    id: usize,
    /// Name used in log messages.
    name: Cow<'static, str>,
    /// Token handed to the task's future; cancelled by `Task::cancel`.
    run_token: RunToken,
    /// Failure, join error or drop of the task triggers `shutdown`.
    critical: bool,
    /// Like `critical`; additionally, finishing or being aborted triggers `shutdown`.
    main: bool,
    /// `Task::cancel` aborts the tokio task instead of waiting for it; drops of such
    /// tasks are logged at debug level only.
    abort: bool,
    /// The task is skipped by `shutdown`.
    no_shutdown: bool,
    /// Tasks with lower values are cancelled earlier during `shutdown`.
    shutdown_order: i32,
}
108
109impl TaskBuilder {
110    pub fn new(name: impl Into<Cow<'static, str>>) -> Self {
112        Self {
113            id: TASK_ID_COUNT.fetch_add(1, Ordering::SeqCst),
114            name: name.into(),
115            run_token: Default::default(),
116            critical: false,
117            main: false,
118            abort: false,
119            no_shutdown: false,
120            shutdown_order: 0,
121        }
122    }
123
124    pub fn id(&self) -> usize {
126        self.id
127    }
128
129    pub fn set_run_token(self, run_token: RunToken) -> Self {
132        Self { run_token, ..self }
133    }
134
135    pub fn critical(self) -> Self {
137        Self {
138            critical: true,
139            ..self
140        }
141    }
142
143    pub fn main(self) -> Self {
145        Self { main: true, ..self }
146    }
147
148    pub fn abort(self) -> Self {
150        Self {
151            abort: true,
152            ..self
153        }
154    }
155
156    pub fn no_shutdown(self) -> Self {
158        Self {
159            no_shutdown: true,
160            ..self
161        }
162    }
163
164    pub fn shutdown_order(self, shutdown_order: i32) -> Self {
166        Self {
167            shutdown_order,
168            ..self
169        }
170    }
171
172    pub fn create<
174        T: 'static + Send + Sync,
175        E: std::fmt::Debug + Sync + Send + 'static,
176        Fu: Future<Output = Result<T, E>> + Send + 'static,
177        F: FnOnce(RunToken) -> Fu,
178    >(
179        self,
180        fun: F,
181    ) -> Arc<Task<T, E>> {
182        let fut = fun(self.run_token.clone());
183        let id = self.id;
184        let mut tasks = TASKS.lock().unwrap();
186        debug!("Started task {} ({})", self.name, id);
187        let join_handle = tokio::spawn(async move {
188            let g = scope_guard(|| {
189                if let Some(t) = TASKS.lock().unwrap().get_or_insert_default().remove(&id) {
190                    t._internal_handle_finished(FinishState::Drop);
191                }
192            });
193            let r = fut.await;
194            let s = match &r {
195                Ok(_) => FinishState::Success,
196                Err(e) => FinishState::Failure(e),
197            };
198            g.release();
199            if let Some(t) = TASKS.lock().unwrap().get_or_insert_default().remove(&id) {
200                t._internal_handle_finished(s);
201            }
202            r
203        });
204        let task = Arc::new(Task {
205            id: self.id,
206            name: self.name,
207            critical: self.critical,
208            main: self.main,
209            abort: self.abort,
210            no_shutdown: self.no_shutdown,
211            shutdown_order: self.shutdown_order,
212            run_token: self.run_token,
213            start_time: std::time::SystemTime::now()
214                .duration_since(std::time::UNIX_EPOCH)
215                .unwrap()
216                .as_secs_f64(),
217            join_handle: Mutex::new(Some(join_handle)),
218        });
219        tasks.get_or_insert_default().insert(self.id, task.clone());
220        task
221    }
222
223    #[cfg(feature = "ordered-locks")]
225    pub fn create_with_lock_token<
226        T: 'static + Send + Sync,
227        E: std::fmt::Debug + Sync + Send + 'static,
228        Fu: Future<Output = Result<T, E>> + Send + 'static,
229        F: FnOnce(RunToken, CleanLockToken) -> Fu,
230    >(
231        self,
232        fun: F,
233    ) -> Arc<Task<T, E>> {
234        self.create(|run_token| fun(run_token, unsafe { CleanLockToken::new() }))
236    }
237}
238
/// Type-erased view of a [`Task`], as stored in the task registry.
pub trait TaskBase: Send + Sync {
    /// Report how the task finished. This module calls it after removing the task from
    /// the registry; `Task` logs the outcome and may start [`shutdown`] for main or
    /// critical tasks.
    #[doc(hidden)]
    fn _internal_handle_finished(&self, state: FinishState);
    /// Shutdown ordering key: registered tasks with the lowest value are cancelled first.
    fn shutdown_order(&self) -> i32;
    /// Name used in log messages.
    fn name(&self) -> &str;
    /// Unique id (the registry key).
    fn id(&self) -> usize;
    /// Whether this is a main task (for `Task`: finishing or being aborted triggers shutdown).
    fn main(&self) -> bool;
    /// Whether cancelling aborts the tokio task instead of waiting for it.
    fn abort(&self) -> bool;
    /// Whether failure, join error or drop of the task triggers shutdown.
    fn critical(&self) -> bool;
    /// Wall-clock creation time in seconds since the UNIX epoch.
    fn start_time(&self) -> f64;
    /// Cancel the task; the returned future completes when cancellation is done.
    fn cancel(self: Arc<Self>) -> BoxFuture<'static, ()>;
    /// Run token handed to the task's future.
    fn run_token(&self) -> &RunToken;
    /// Whether the task is skipped by [`shutdown`].
    fn no_shutdown(&self) -> bool;
}
264
265pub struct Task<T: Send + Sync, E: Sync + Sync> {
267    id: usize,
269    name: Cow<'static, str>,
271    critical: bool,
273    main: bool,
275    abort: bool,
277    no_shutdown: bool,
279    shutdown_order: i32,
281    run_token: RunToken,
283    start_time: f64,
285    join_handle: Mutex<Option<JoinHandle<Result<T, E>>>>,
287}
288
289impl<T: Send + Sync + 'static, E: Send + Sync + 'static> TaskBase for Task<T, E> {
290    fn shutdown_order(&self) -> i32 {
291        self.shutdown_order
292    }
293
294    fn name(&self) -> &str {
295        self.name.as_ref()
296    }
297
298    fn id(&self) -> usize {
299        self.id
300    }
301
302    fn _internal_handle_finished(&self, state: FinishState) {
303        match state {
304            FinishState::Success => {
305                if !self.main
306                    || !shutdown(format!(
307                        "Main task {} ({}) finished unexpected",
308                        self.name, self.id
309                    ))
310                {
311                    debug!("Finished task {} ({})", self.name, self.id);
312                }
313            }
314            FinishState::Drop => {
315                if self.main || self.critical {
316                    if shutdown(format!("Critical task {} ({}) dropped", self.name, self.id)) {
317                    } else if !self.abort {
318                        error!("Critical task {} ({}) dropped", self.name, self.id);
320                    } else {
321                        debug!("Critical task {} ({}) dropped", self.name, self.id)
322                    }
323                } else if !self.abort {
324                    error!("Task {} ({}) dropped", self.name, self.id);
326                } else {
327                    debug!("Task {} ({}) dropped", self.name, self.id)
328                }
329            }
330            FinishState::JoinError(e) => {
331                if (!self.main && !self.critical)
332                    || !shutdown(format!(
333                        "Join error in critical task {} ({}): {:?}",
334                        self.name, self.id, e
335                    ))
336                {
337                    error!("Join error in task {} ({}): {:?}", self.name, self.id, e);
338                }
339            }
340            FinishState::Failure(e) => {
341                if (!self.main && !self.critical)
342                    || !shutdown(format!(
343                        "Failure in critical task {} ({}) @ {:?}: {:?}",
344                        self.name,
345                        self.id,
346                        self.run_token().location(),
347                        e
348                    ))
349                {
350                    let location = self.run_token().location();
351                    error!(
352                        "Failure in task {} ({}) @ {:?}: {:?}",
353                        self.name, self.id, location, e
354                    );
355                }
356            }
357            FinishState::Abort => {
358                if !self.main
359                    || !shutdown(format!(
360                        "Main task {} ({}) aborted unexpected",
361                        self.name, self.id
362                    ))
363                {
364                    debug!("Aborted task {} ({})", self.name, self.id);
365                }
366            }
367        }
368    }
369
370    fn cancel(self: Arc<Self>) -> BoxFuture<'static, ()> {
371        Box::pin(self.cancel())
372    }
373
374    fn main(&self) -> bool {
375        self.main
376    }
377
378    fn abort(&self) -> bool {
379        self.abort
380    }
381
382    fn critical(&self) -> bool {
383        self.critical
384    }
385
386    fn start_time(&self) -> f64 {
387        self.start_time
388    }
389
390    fn run_token(&self) -> &RunToken {
391        &self.run_token
392    }
393
394    fn no_shutdown(&self) -> bool {
395        self.no_shutdown
396    }
397}
398
399#[derive(Debug)]
401pub enum WaitError<E: Send + Sync> {
402    HandleUnset(String),
404    JoinError(tokio::task::JoinError),
406    TaskFailure(E),
408}
409
410impl<E: std::fmt::Display + Send + Sync> std::fmt::Display for WaitError<E> {
411    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
412        match self {
413            WaitError::HandleUnset(v) => write!(f, "Handle unset: {v}"),
414            WaitError::JoinError(v) => write!(f, "Join Error: {v}"),
415            WaitError::TaskFailure(v) => write!(f, "Task Failure: {v}"),
416        }
417    }
418}
419
420impl<E: std::error::Error + Send + Sync> std::error::Error for WaitError<E> {}
421
422struct TaskJoinHandleBorrow<'a, T: Send + Sync, E: Send + Sync> {
424    task: &'a Arc<Task<T, E>>,
426    jh: Option<JoinHandle<Result<T, E>>>,
428}
429
430impl<'a, T: Send + Sync, E: Send + Sync> TaskJoinHandleBorrow<'a, T, E> {
431    fn new(task: &'a Arc<Task<T, E>>) -> Self {
433        let jh = task.join_handle.lock().unwrap().take();
434        Self { task, jh }
435    }
436}
437
438impl<'a, T: Send + Sync, E: Send + Sync> Drop for TaskJoinHandleBorrow<'a, T, E> {
439    fn drop(&mut self) {
440        *self.task.join_handle.lock().unwrap() = self.jh.take();
441    }
442}
443
444impl<T: Send + Sync, E: Send + Sync> Task<T, E> {
445    pub async fn cancel(self: Arc<Self>) {
449        let mut b = TaskJoinHandleBorrow::new(&self);
450        self.run_token.cancel();
451        if let Some(jh) = &mut b.jh {
452            if self.abort {
453                jh.abort();
454                let _ = jh.await;
455                if let Some(t) = TASKS
456                    .lock()
457                    .unwrap()
458                    .get_or_insert_default()
459                    .remove(&self.id)
460                {
461                    t._internal_handle_finished(FinishState::Abort);
462                }
463            } else if let Err(e) = jh.await {
464                info!("Unable to join task {e:?}");
465                if let Some(t) = TASKS
466                    .lock()
467                    .unwrap()
468                    .get_or_insert_default()
469                    .remove(&self.id)
470                {
471                    t._internal_handle_finished(FinishState::JoinError(e));
472                }
473            }
474        }
475        if !SHUTTING_DOWN.load(Ordering::SeqCst) {
476            info!("  canceled {} ({})", self.name, self.id);
477        }
478        std::mem::forget(b);
479    }
480
481    pub async fn wait(self: Arc<Self>) -> Result<T, WaitError<E>> {
483        let mut b = TaskJoinHandleBorrow::new(&self);
484        let r = match &mut b.jh {
485            None => Err(WaitError::HandleUnset(self.name.to_string())),
486            Some(jh) => match jh.await {
487                Ok(Ok(v)) => Ok(v),
488                Ok(Err(e)) => Err(WaitError::TaskFailure(e)),
489                Err(e) => Err(WaitError::JoinError(e)),
490            },
491        };
492        std::mem::forget(b);
493        r
494    }
495}
496
497struct WaitTasks<'a, Sleep, Fut>(Sleep, &'a mut Vec<(String, usize, Fut, RunToken)>);
499impl<'a, Sleep: Unpin, Fut: Unpin> Unpin for WaitTasks<'a, Sleep, Fut> {}
500impl<'a, Sleep: Future + Unpin, Fut: Future + Unpin> Future for WaitTasks<'a, Sleep, Fut> {
501    type Output = bool;
502
503    fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<bool> {
504        if self.0.poll_unpin(cx).is_ready() {
505            return Poll::Ready(false);
506        }
507
508        self.1
509            .retain_mut(|(_, _, f, _)| !matches!(f.poll_unpin(cx), Poll::Ready(_)));
510
511        if self.1.is_empty() {
512            Poll::Ready(true)
513        } else {
514            Poll::Pending
515        }
516    }
517}
518
519pub fn shutdown(message: String) -> bool {
521    if SHUTTING_DOWN
522        .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
523        .is_err()
524    {
525        return false;
527    }
528    info!("Shutting down: {message}");
529    tokio::spawn(async move {
530        let mut shutdown_tasks: Vec<Arc<dyn TaskBase>> = Vec::new();
531        loop {
532            for (_, task) in TASKS.lock().unwrap().get_or_insert_default().iter() {
533                if task.no_shutdown() {
534                    continue;
535                }
536                if let Some(t) = shutdown_tasks.first() {
537                    if t.shutdown_order() < task.shutdown_order() {
538                        continue;
539                    }
540                    if t.shutdown_order() > task.shutdown_order() {
541                        shutdown_tasks.clear();
542                    }
543                }
544                shutdown_tasks.push(task.clone());
545            }
546            if shutdown_tasks.is_empty() {
547                break;
548            }
549            info!(
550                "shutting down {} tasks with order {}",
551                shutdown_tasks.len(),
552                shutdown_tasks[0].shutdown_order()
553            );
554            let mut stop_futures: Vec<(String, usize, _, RunToken)> = shutdown_tasks
555                .iter()
556                .map(|t| {
557                    (
558                        t.name().to_string(),
559                        t.id(),
560                        t.clone().cancel(),
561                        t.run_token().clone(),
562                    )
563                })
564                .collect();
565            while !WaitTasks(
566                Box::pin(tokio::time::sleep(tokio::time::Duration::from_secs(30))),
567                &mut stop_futures,
568            )
569            .await
570            {
571                info!("still waiting for {} tasks", stop_futures.len(),);
572                for (name, id, _, rt) in &stop_futures {
573                    if let Some((file, line)) = rt.location() {
574                        info!("  {name} ({id}) at {file}:{line}");
575                    } else {
576                        info!("  {name} ({id})");
577                    }
578                }
579            }
580            shutdown_tasks.clear();
581        }
582        info!("shutdown done");
583        SHUTDOWN_NOTIFY.notify_waiters();
584    });
585    true
586}
587
588pub async fn run_tasks() {
590    SHUTDOWN_NOTIFY.notified().await
591}
592
593pub fn list_tasks() -> Vec<Arc<dyn TaskBase>> {
595    TASKS
596        .lock()
597        .unwrap()
598        .get_or_insert_default()
599        .values()
600        .cloned()
601        .collect()
602}
603
604pub fn try_list_tasks_for(duration: std::time::Duration) -> Option<Vec<Arc<dyn TaskBase>>> {
607    let tries = 50;
608    for _ in 0..tries {
609        if let Ok(mut tasks) = TASKS.try_lock() {
610            return Some(tasks.get_or_insert_default().values().cloned().collect());
611        }
612        std::thread::sleep(duration / tries);
613    }
614    if let Ok(mut tasks) = TASKS.try_lock() {
615        return Some(tasks.get_or_insert_default().values().cloned().collect());
616    }
617    None
618}