// pipeworks_tasks/task_tree.rs

use std::{
    future::Future,
    sync::{
        atomic::{AtomicU64, Ordering},
        Arc,
    },
};

use tokio_util::{
    sync::{CancellationToken, DropGuard},
    task::TaskTracker,
};

// Task IDs

/// Identifier attached to tasks spawned through [`TaskTree`].
///
/// Distinct values are handed out by [`TaskId::next`]; because the inner
/// `u64` is public, callers may also construct ids directly.
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, PartialOrd, Ord)]
pub struct TaskId(pub u64);

18static GLOBAL_ID_COUNTER: AtomicU64 = AtomicU64::new(0);
19
20impl TaskId {
21    pub fn next() -> TaskId {
22        TaskId(GLOBAL_ID_COUNTER.fetch_add(1, Ordering::Relaxed))
23    }
24
25    pub fn current() -> TaskId {
26        TASK_ID.with(Clone::clone)
27    }
28
29    pub fn try_current() -> Option<TaskId> {
30        TASK_ID.try_with(Clone::clone).ok()
31    }
32
33    pub fn try_parent() -> Option<TaskId> {
34        PARENT_TASK_ID.try_with(Clone::clone).ok().flatten()
35    }
36}
37
38tokio::task_local! {
39    static TASK_TREE: Arc<TaskTree>;
40    static TASK_ID: TaskId;
41    static PARENT_TASK_ID: Option<TaskId>;
42}
43
44pub struct TaskTree {
45    pub cancel: CancellationToken,
46    pub tasks: TaskTracker,
47    _drop_guard: DropGuard,
48}
49
50impl TaskTree {
51    pub fn spawn_current_tree<F>(task_id: TaskId, future: F) -> (Arc<Self>, TaskId)
52    where
53        F: Future<Output = ()> + Send + Sync + 'static,
54    {
55        let task_tree = Arc::new(
56            TASK_TREE
57                .try_with(|t| TaskTree {
58                    cancel: t.cancel.clone(),
59                    tasks: t.tasks.clone(),
60                    _drop_guard: t.cancel.clone().drop_guard(),
61                })
62                .unwrap_or_default(),
63        );
64
65        let future = TASK_TREE.scope(task_tree.clone(), future);
66        let future = TASK_ID.scope(task_id, future);
67        let future = PARENT_TASK_ID.scope(TaskId::try_current(), future);
68
69        task_tree.spawn(future);
70
71        (task_tree, task_id)
72    }
73
74    pub fn spawn_isolated_subtree<F>(task_id: TaskId, future: F) -> (Arc<Self>, TaskId)
75    where
76        F: Future<Output = ()> + Send + Sync + 'static,
77    {
78        let task_tree = Arc::new(
79            TASK_TREE
80                .try_with(|t| {
81                    let cancel = t.cancel.child_token();
82                    TaskTree {
83                        cancel: cancel.clone(),
84                        tasks: TaskTracker::new(),
85                        _drop_guard: cancel.drop_guard(),
86                    }
87                })
88                .unwrap_or_default(),
89        );
90
91        let future = TASK_TREE.scope(task_tree.clone(), future);
92        let future = TASK_ID.scope(task_id, future);
93        let future = PARENT_TASK_ID.scope(TaskId::try_current(), future);
94
95        task_tree.spawn(future);
96
97        (task_tree, task_id)
98    }
99
100    fn spawn<F>(&self, future: F)
101    where
102        F: Future<Output = ()> + Send + Sync + 'static,
103    {
104        // Make the task cancellable
105        let future = self.cancel.clone().run_until_cancelled_owned(future);
106
107        // When any task fails, it should fail the whole sub-tree.
108        let drop_guard = self.cancel.clone().drop_guard();
109        let future = async move {
110            let _drop_guard = drop_guard;
111            future.await;
112        };
113
114        self.tasks.spawn(future);
115
116        // Close the Task Tracker after the first task is spawned.
117        self.tasks.close();
118    }
119}
120
121impl Default for TaskTree {
122    fn default() -> Self {
123        let cancel = CancellationToken::new();
124        Self {
125            cancel: cancel.clone(),
126            tasks: TaskTracker::new(),
127            _drop_guard: cancel.drop_guard(),
128        }
129    }
130}