pipeworks_tasks/
task_tree.rs1use std::{
2 future::Future,
3 sync::{
4 atomic::{AtomicU64, Ordering},
5 Arc,
6 },
7};
8
9use tokio_util::{
10 sync::{CancellationToken, DropGuard},
11 task::TaskTracker,
12};
13
/// Identifier of a task spawned through `TaskTree`.
///
/// Fresh values are produced by `TaskId::next`; the inner `u64` is public so callers can
/// log or persist it.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct TaskId(pub u64);
17
18static GLOBAL_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
19
20impl TaskId {
21 pub fn next() -> TaskId {
22 TaskId(GLOBAL_ID_COUNTER.fetch_add(1, Ordering::Relaxed))
23 }
24
25 pub fn current() -> TaskId {
26 TASK_ID.with(Clone::clone)
27 }
28
29 pub fn try_current() -> Option<TaskId> {
30 TASK_ID.try_with(Clone::clone).ok()
31 }
32
33 pub fn try_parent() -> Option<TaskId> {
34 PARENT_TASK_ID.try_with(Clone::clone).ok().flatten()
35 }
36}
37
38tokio::task_local! {
39 static TASK_TREE: Arc<TaskTree>;
40 static TASK_ID: TaskId;
41 static PARENT_TASK_ID: Option<TaskId>;
42}
43
44pub struct TaskTree {
45 pub cancel: CancellationToken,
46 pub tasks: TaskTracker,
47 _drop_guard: DropGuard,
48}
49
50impl TaskTree {
51 pub fn spawn_current_tree<F>(task_id: TaskId, future: F) -> (Arc<Self>, TaskId)
52 where
53 F: Future<Output = ()> + Send + Sync + 'static,
54 {
55 let task_tree = Arc::new(
56 TASK_TREE
57 .try_with(|t| TaskTree {
58 cancel: t.cancel.clone(),
59 tasks: t.tasks.clone(),
60 _drop_guard: t.cancel.clone().drop_guard(),
61 })
62 .unwrap_or_default(),
63 );
64
65 let future = TASK_TREE.scope(task_tree.clone(), future);
66 let future = TASK_ID.scope(task_id, future);
67 let future = PARENT_TASK_ID.scope(TaskId::try_current(), future);
68
69 task_tree.spawn(future);
70
71 (task_tree, task_id)
72 }
73
74 pub fn spawn_isolated_subtree<F>(task_id: TaskId, future: F) -> (Arc<Self>, TaskId)
75 where
76 F: Future<Output = ()> + Send + Sync + 'static,
77 {
78 let task_tree = Arc::new(
79 TASK_TREE
80 .try_with(|t| {
81 let cancel = t.cancel.child_token();
82 TaskTree {
83 cancel: cancel.clone(),
84 tasks: TaskTracker::new(),
85 _drop_guard: cancel.drop_guard(),
86 }
87 })
88 .unwrap_or_default(),
89 );
90
91 let future = TASK_TREE.scope(task_tree.clone(), future);
92 let future = TASK_ID.scope(task_id, future);
93 let future = PARENT_TASK_ID.scope(TaskId::try_current(), future);
94
95 task_tree.spawn(future);
96
97 (task_tree, task_id)
98 }
99
100 fn spawn<F>(&self, future: F)
101 where
102 F: Future<Output = ()> + Send + Sync + 'static,
103 {
104 let future = self.cancel.clone().run_until_cancelled_owned(future);
106
107 let drop_guard = self.cancel.clone().drop_guard();
109 let future = async move {
110 let _drop_guard = drop_guard;
111 future.await;
112 };
113
114 self.tasks.spawn(future);
115
116 self.tasks.close();
118 }
119}
120
121impl Default for TaskTree {
122 fn default() -> Self {
123 let cancel = CancellationToken::new();
124 Self {
125 cancel: cancel.clone(),
126 tasks: TaskTracker::new(),
127 _drop_guard: cancel.drop_guard(),
128 }
129 }
130}