1use crate::{RunToken, scope_guard::scope_guard};
2use futures_util::{
3 Future, FutureExt,
4 future::{self},
5 pin_mut,
6};
7use log::{debug, error, info};
8use std::{
9 borrow::Cow,
10 sync::{
11 Arc,
12 atomic::{AtomicUsize, Ordering},
13 },
14};
15use std::{collections::HashMap, sync::atomic::AtomicBool};
16use std::{fmt::Display, sync::Mutex};
17use std::{pin::Pin, task::Poll};
18use tokio::{
19 sync::Notify,
20 task::{JoinError, JoinHandle},
21};
22
23#[cfg(feature = "ordered-locks")]
24use ordered_locks::{CleanLockToken, L0, LockToken};
25
/// Registry of tasks that have been created and have not finished yet, keyed
/// by task id. Entries are inserted by `TaskBuilder::create` and removed by
/// whichever code path observes the task's end. The inner `Option` is filled
/// lazily via `get_or_insert_default` at every use.
static TASKS: Mutex<Option<HashMap<usize, Arc<dyn TaskBase>>>> = Mutex::new(None);
/// Woken with `notify_waiters` by the shutdown task once all tasks have been
/// shut down; awaited by `run_tasks`.
static SHUTDOWN_NOTIFY: Notify = Notify::const_new();
/// Monotonic counter handing out task ids in `TaskBuilder::new`.
static TASK_ID_COUNT: AtomicUsize = AtomicUsize::new(0);
/// Set by the first call to `shutdown`; later calls see it and do nothing.
static SHUTTING_DOWN: AtomicBool = AtomicBool::new(false);
30
/// Error produced by `cancelable` (and `cancelable_checked`) when the run
/// token is cancelled before the wrapped future completes.
#[derive(Debug)]
pub struct CancelledError {}

impl Display for CancelledError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("CancelledError")
    }
}

impl std::error::Error for CancelledError {}
40
/// A heap-allocated, type-erased, `Send` future that may borrow data for `'a`
/// and resolves to `T`.
pub type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
42
43pub async fn cancelable<T, F: Future<Output = T>>(
45 run_token: &RunToken,
46 fut: F,
47) -> Result<T, CancelledError> {
48 let c = run_token.cancelled();
49 pin_mut!(fut, c);
50 let f = future::select(c, fut).await;
51 match f {
52 future::Either::Right((v, _)) => Ok(v),
53 future::Either::Left(_) => Err(CancelledError {}),
54 }
55}
56
#[cfg(feature = "ordered-locks")]
/// Like `cancelable`, but observes cancellation through
/// `RunToken::cancelled_checked`, which consumes a lock token at level `L0`.
///
/// Returns `Ok` with the future's output, or `Err(CancelledError)` when the
/// cancellation future resolves first.
pub async fn cancelable_checked<T, F: Future<Output = T>>(
    run_token: &RunToken,
    lock_token: LockToken<'_, L0>,
    fut: F,
) -> Result<T, CancelledError> {
    let cancelled = run_token.cancelled_checked(lock_token);
    pin_mut!(fut, cancelled);
    match future::select(cancelled, fut).await {
        future::Either::Left(_) => Err(CancelledError {}),
        future::Either::Right((value, _)) => Ok(value),
    }
}
72
/// How a task ended; handed to `TaskBase::_internal_handle_finished` by the
/// code path that removed the task from the registry.
#[doc(hidden)]
#[derive(Debug)]
pub enum FinishState<'a> {
    /// The task's future returned `Ok`.
    Success,
    /// Reported by the scope guard installed in `TaskBuilder::create`. That
    /// guard is released right after the future completes, so this presumably
    /// means the future was dropped (or unwound) before finishing.
    Drop,
    /// Reported by `Task::cancel` after aborting a task built with `abort`.
    Abort,
    /// Reported by `Task::cancel` when awaiting the join handle failed.
    JoinError(JoinError),
    /// The task's future returned `Err`; carries a borrow of that error.
    Failure(&'a (dyn std::fmt::Debug + Sync + Send)),
}
82
/// Configures and spawns a task that is tracked in the global registry.
/// Create with `TaskBuilder::new`, adjust with the builder methods, then call
/// `create`.
pub struct TaskBuilder {
    /// Unique id, taken from `TASK_ID_COUNT` in `new`.
    id: usize,
    /// Task name used in log and shutdown messages.
    name: Cow<'static, str>,
    /// Run token passed to the task's future; cancelled by `Task::cancel`.
    run_token: RunToken,
    /// If set, a drop, join error or `Err` result of the task calls `shutdown`.
    critical: bool,
    /// If set, every way the task can end (including success and abort)
    /// calls `shutdown`.
    main: bool,
    /// If set, `Task::cancel` aborts the tokio task instead of waiting for it.
    abort: bool,
    /// If set, `shutdown` does not cancel this task.
    no_shutdown: bool,
    /// `shutdown` cancels tasks group by group, lowest value first.
    shutdown_order: i32,
}
94
95impl TaskBuilder {
96 pub fn new(name: impl Into<Cow<'static, str>>) -> Self {
98 Self {
99 id: TASK_ID_COUNT.fetch_add(1, Ordering::SeqCst),
100 name: name.into(),
101 run_token: Default::default(),
102 critical: false,
103 main: false,
104 abort: false,
105 no_shutdown: false,
106 shutdown_order: 0,
107 }
108 }
109
110 pub fn id(&self) -> usize {
112 self.id
113 }
114
115 pub fn set_run_token(self, run_token: RunToken) -> Self {
118 Self { run_token, ..self }
119 }
120
121 pub fn critical(self) -> Self {
123 Self {
124 critical: true,
125 ..self
126 }
127 }
128
129 pub fn main(self) -> Self {
131 Self { main: true, ..self }
132 }
133
134 pub fn abort(self) -> Self {
136 Self {
137 abort: true,
138 ..self
139 }
140 }
141
142 pub fn no_shutdown(self) -> Self {
144 Self {
145 no_shutdown: true,
146 ..self
147 }
148 }
149
150 pub fn shutdown_order(self, shutdown_order: i32) -> Self {
152 Self {
153 shutdown_order,
154 ..self
155 }
156 }
157
158 pub fn create<
160 T: 'static + Send + Sync,
161 E: std::fmt::Debug + Sync + Send + 'static,
162 Fu: Future<Output = Result<T, E>> + Send + 'static,
163 F: FnOnce(RunToken) -> Fu,
164 >(
165 self,
166 fun: F,
167 ) -> Arc<Task<T, E>> {
168 let fut = fun(self.run_token.clone());
169 let id = self.id;
170 let mut tasks = TASKS.lock().unwrap();
172 debug!("Started task {} ({})", self.name, id);
173 let join_handle = tokio::spawn(async move {
174 let g = scope_guard(|| {
175 if let Some(t) = TASKS.lock().unwrap().get_or_insert_default().remove(&id) {
176 t._internal_handle_finished(FinishState::Drop);
177 }
178 });
179 let r = fut.await;
180 let s = match &r {
181 Ok(_) => FinishState::Success,
182 Err(e) => FinishState::Failure(e),
183 };
184 g.release();
185 if let Some(t) = TASKS.lock().unwrap().get_or_insert_default().remove(&id) {
186 t._internal_handle_finished(s);
187 }
188 r
189 });
190 let task = Arc::new(Task {
191 id: self.id,
192 name: self.name,
193 critical: self.critical,
194 main: self.main,
195 abort: self.abort,
196 no_shutdown: self.no_shutdown,
197 shutdown_order: self.shutdown_order,
198 run_token: self.run_token,
199 start_time: std::time::SystemTime::now()
200 .duration_since(std::time::UNIX_EPOCH)
201 .unwrap()
202 .as_secs_f64(),
203 join_handle: Mutex::new(Some(join_handle)),
204 });
205 tasks.get_or_insert_default().insert(self.id, task.clone());
206 task
207 }
208
209 #[cfg(feature = "ordered-locks")]
211 pub fn create_with_lock_token<
212 T: 'static + Send + Sync,
213 E: std::fmt::Debug + Sync + Send + 'static,
214 Fu: Future<Output = Result<T, E>> + Send + 'static,
215 F: FnOnce(RunToken, CleanLockToken) -> Fu,
216 >(
217 self,
218 fun: F,
219 ) -> Arc<Task<T, E>> {
220 self.create(|run_token| fun(run_token, unsafe { CleanLockToken::new() }))
221 }
222}
223
/// Type-erased view of a `Task`, used by the global registry, `list_tasks`
/// and `shutdown`.
pub trait TaskBase: Send + Sync {
    /// Invoked by the code path that removed the task from the registry, with
    /// the way the task ended. The `Task` implementation logs the outcome and
    /// may call `shutdown` for main/critical tasks.
    #[doc(hidden)]
    fn _internal_handle_finished(&self, state: FinishState);
    /// Shutdown group; `shutdown` cancels lower values first.
    fn shutdown_order(&self) -> i32;
    /// Task name as given to `TaskBuilder::new`.
    fn name(&self) -> &str;
    /// Unique task id (the registry key).
    fn id(&self) -> usize;
    /// Whether the task was built with `TaskBuilder::main`.
    fn main(&self) -> bool;
    /// Whether the task was built with `TaskBuilder::abort`.
    fn abort(&self) -> bool;
    /// Whether the task was built with `TaskBuilder::critical`.
    fn critical(&self) -> bool;
    /// Creation time; for `Task`, seconds since the Unix epoch.
    fn start_time(&self) -> f64;
    /// Cancels the task; the returned future completes when `Task::cancel` does.
    fn cancel(self: Arc<Self>) -> BoxFuture<'static, ()>;
    /// The run token handed to the task's future.
    fn run_token(&self) -> &RunToken;
    /// Whether `shutdown` skips this task.
    fn no_shutdown(&self) -> bool;
}
249
250pub struct Task<T: Send + Sync, E: Sync + Sync> {
252 id: usize,
253 name: Cow<'static, str>,
254 critical: bool,
255 main: bool,
256 abort: bool,
257 no_shutdown: bool,
258 shutdown_order: i32,
259 run_token: RunToken,
260 start_time: f64,
261 join_handle: Mutex<Option<JoinHandle<Result<T, E>>>>,
262}
263
264impl<T: Send + Sync + 'static, E: Send + Sync + 'static> TaskBase for Task<T, E> {
265 fn shutdown_order(&self) -> i32 {
266 self.shutdown_order
267 }
268
269 fn name(&self) -> &str {
270 self.name.as_ref()
271 }
272
273 fn id(&self) -> usize {
274 self.id
275 }
276
277 fn _internal_handle_finished(&self, state: FinishState) {
278 match state {
279 FinishState::Success => {
280 if !self.main
281 || !shutdown(format!(
282 "Main task {} ({}) finished unexpected",
283 self.name, self.id
284 ))
285 {
286 debug!("Finished task {} ({})", self.name, self.id);
287 }
288 }
289 FinishState::Drop => {
290 if self.main || self.critical {
291 if shutdown(format!("Critical task {} ({}) dropped", self.name, self.id)) {
292 } else if !self.abort {
293 error!("Critical task {} ({}) dropped", self.name, self.id);
295 } else {
296 debug!("Critical task {} ({}) dropped", self.name, self.id)
297 }
298 } else if !self.abort {
299 error!("Task {} ({}) dropped", self.name, self.id);
301 } else {
302 debug!("Task {} ({}) dropped", self.name, self.id)
303 }
304 }
305 FinishState::JoinError(e) => {
306 if (!self.main && !self.critical)
307 || !shutdown(format!(
308 "Join error in critical task {} ({}): {:?}",
309 self.name, self.id, e
310 ))
311 {
312 error!("Join error in task {} ({}): {:?}", self.name, self.id, e);
313 }
314 }
315 FinishState::Failure(e) => {
316 if (!self.main && !self.critical)
317 || !shutdown(format!(
318 "Failure in critical task {} ({}) @ {:?}: {:?}",
319 self.name,
320 self.id,
321 self.run_token().location(),
322 e
323 ))
324 {
325 let location = self.run_token().location();
326 error!(
327 "Failure in task {} ({}) @ {:?}: {:?}",
328 self.name, self.id, location, e
329 );
330 }
331 }
332 FinishState::Abort => {
333 if !self.main
334 || !shutdown(format!(
335 "Main task {} ({}) aborted unexpected",
336 self.name, self.id
337 ))
338 {
339 debug!("Aborted task {} ({})", self.name, self.id);
340 }
341 }
342 }
343 }
344
345 fn cancel(self: Arc<Self>) -> BoxFuture<'static, ()> {
346 Box::pin(self.cancel())
347 }
348
349 fn main(&self) -> bool {
350 self.main
351 }
352
353 fn abort(&self) -> bool {
354 self.abort
355 }
356
357 fn critical(&self) -> bool {
358 self.critical
359 }
360
361 fn start_time(&self) -> f64 {
362 self.start_time
363 }
364
365 fn run_token(&self) -> &RunToken {
366 &self.run_token
367 }
368
369 fn no_shutdown(&self) -> bool {
370 self.no_shutdown
371 }
372}
373
374#[derive(Debug)]
376pub enum WaitError<E: Send + Sync> {
377 HandleUnset(String),
379 JoinError(tokio::task::JoinError),
381 TaskFailure(E),
383}
384
385impl<E: std::fmt::Display + Send + Sync> std::fmt::Display for WaitError<E> {
386 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
387 match self {
388 WaitError::HandleUnset(v) => write!(f, "Handle unset: {v}"),
389 WaitError::JoinError(v) => write!(f, "Join Error: {v}"),
390 WaitError::TaskFailure(v) => write!(f, "Task Failure: {v}"),
391 }
392 }
393}
394
395impl<E: std::error::Error + Send + Sync> std::error::Error for WaitError<E> {}
396
397struct TaskJoinHandleBorrow<'a, T: Send + Sync, E: Send + Sync> {
398 task: &'a Arc<Task<T, E>>,
399 jh: Option<JoinHandle<Result<T, E>>>,
400}
401
402impl<'a, T: Send + Sync, E: Send + Sync> TaskJoinHandleBorrow<'a, T, E> {
403 fn new(task: &'a Arc<Task<T, E>>) -> Self {
404 let jh = task.join_handle.lock().unwrap().take();
405 Self { task, jh }
406 }
407}
408
409impl<'a, T: Send + Sync, E: Send + Sync> Drop for TaskJoinHandleBorrow<'a, T, E> {
410 fn drop(&mut self) {
411 *self.task.join_handle.lock().unwrap() = self.jh.take();
412 }
413}
414
415impl<T: Send + Sync, E: Send + Sync> Task<T, E> {
416 pub async fn cancel(self: Arc<Self>) {
420 let mut b = TaskJoinHandleBorrow::new(&self);
421 self.run_token.cancel();
422 if let Some(jh) = &mut b.jh {
423 if self.abort {
424 jh.abort();
425 let _ = jh.await;
426 if let Some(t) = TASKS
427 .lock()
428 .unwrap()
429 .get_or_insert_default()
430 .remove(&self.id)
431 {
432 t._internal_handle_finished(FinishState::Abort);
433 }
434 } else if let Err(e) = jh.await {
435 info!("Unable to join task {e:?}");
436 if let Some(t) = TASKS
437 .lock()
438 .unwrap()
439 .get_or_insert_default()
440 .remove(&self.id)
441 {
442 t._internal_handle_finished(FinishState::JoinError(e));
443 }
444 }
445 }
446 if !SHUTTING_DOWN.load(Ordering::SeqCst) {
447 info!(" canceled {} ({})", self.name, self.id);
448 }
449 std::mem::forget(b);
450 }
451
452 pub async fn wait(self: Arc<Self>) -> Result<T, WaitError<E>> {
454 let mut b = TaskJoinHandleBorrow::new(&self);
455 let r = match &mut b.jh {
456 None => Err(WaitError::HandleUnset(self.name.to_string())),
457 Some(jh) => match jh.await {
458 Ok(Ok(v)) => Ok(v),
459 Ok(Err(e)) => Err(WaitError::TaskFailure(e)),
460 Err(e) => Err(WaitError::JoinError(e)),
461 },
462 };
463 std::mem::forget(b);
464 r
465 }
466}
467struct WaitTasks<'a, Sleep, Fut>(Sleep, &'a mut Vec<(String, usize, Fut, RunToken)>);
468impl<'a, Sleep: Unpin, Fut: Unpin> Unpin for WaitTasks<'a, Sleep, Fut> {}
469impl<'a, Sleep: Future + Unpin, Fut: Future + Unpin> Future for WaitTasks<'a, Sleep, Fut> {
470 type Output = bool;
471
472 fn poll(mut self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<bool> {
473 if self.0.poll_unpin(cx).is_ready() {
474 return Poll::Ready(false);
475 }
476
477 self.1
478 .retain_mut(|(_, _, f, _)| !matches!(f.poll_unpin(cx), Poll::Ready(_)));
479
480 if self.1.is_empty() {
481 Poll::Ready(true)
482 } else {
483 Poll::Pending
484 }
485 }
486}
487
488pub fn shutdown(message: String) -> bool {
490 if SHUTTING_DOWN
491 .compare_exchange(false, true, Ordering::SeqCst, Ordering::SeqCst)
492 .is_err()
493 {
494 return false;
496 }
497 info!("Shutting down: {message}");
498 tokio::spawn(async move {
499 let mut shutdown_tasks: Vec<Arc<dyn TaskBase>> = Vec::new();
500 loop {
501 for (_, task) in TASKS.lock().unwrap().get_or_insert_default().iter() {
502 if task.no_shutdown() {
503 continue;
504 }
505 if let Some(t) = shutdown_tasks.first() {
506 if t.shutdown_order() < task.shutdown_order() {
507 continue;
508 }
509 if t.shutdown_order() > task.shutdown_order() {
510 shutdown_tasks.clear();
511 }
512 }
513 shutdown_tasks.push(task.clone());
514 }
515 if shutdown_tasks.is_empty() {
516 break;
517 }
518 info!(
519 "shutting down {} tasks with order {}",
520 shutdown_tasks.len(),
521 shutdown_tasks[0].shutdown_order()
522 );
523 let mut stop_futures: Vec<(String, usize, _, RunToken)> = shutdown_tasks
524 .iter()
525 .map(|t| {
526 (
527 t.name().to_string(),
528 t.id(),
529 t.clone().cancel(),
530 t.run_token().clone(),
531 )
532 })
533 .collect();
534 while !WaitTasks(
535 Box::pin(tokio::time::sleep(tokio::time::Duration::from_secs(30))),
536 &mut stop_futures,
537 )
538 .await
539 {
540 info!("still waiting for {} tasks", stop_futures.len(),);
541 for (name, id, _, rt) in &stop_futures {
542 if let Some((file, line)) = rt.location() {
543 info!(" {name} ({id}) at {file}:{line}");
544 } else {
545 info!(" {name} ({id})");
546 }
547 }
548 }
549 shutdown_tasks.clear();
550 }
551 info!("shutdown done");
552 SHUTDOWN_NOTIFY.notify_waiters();
553 });
554 true
555}
556
557pub async fn run_tasks() {
559 SHUTDOWN_NOTIFY.notified().await
560}
561
562pub fn list_tasks() -> Vec<Arc<dyn TaskBase>> {
564 TASKS
565 .lock()
566 .unwrap()
567 .get_or_insert_default()
568 .values()
569 .cloned()
570 .collect()
571}
572
573pub fn try_list_tasks_for(duration: std::time::Duration) -> Option<Vec<Arc<dyn TaskBase>>> {
576 let tries = 50;
577 for _ in 0..tries {
578 if let Ok(mut tasks) = TASKS.try_lock() {
579 return Some(tasks.get_or_insert_default().values().cloned().collect());
580 }
581 std::thread::sleep(duration / tries);
582 }
583 if let Ok(mut tasks) = TASKS.try_lock() {
584 return Some(tasks.get_or_insert_default().values().cloned().collect());
585 }
586 None
587}