Skip to main content

ll/
task_tree.rs

1use crate::data::{Data, DataEntry, DataValue};
2use crate::reporters::{EventQueue, Reporter, TaskEvent};
3use crate::task::{Task, TaskData};
4use crate::uniq_id::UniqID;
5use anyhow::{Context, Result};
6use std::collections::{BTreeMap, BTreeSet};
7use std::future::Future;
8use std::sync::Arc;
9use std::sync::{Mutex, RwLock};
10use web_time::SystemTime;
11
12lazy_static::lazy_static! {
13    pub static ref TASK_TREE: Arc<TaskTree>  = TaskTree::new();
14}
15
16pub fn add_reporter(reporter: Arc<dyn Reporter>) {
17    TASK_TREE.add_reporter(reporter);
18}
19
20pub trait ErrorFormatter: Send + Sync {
21    fn format_error(&self, err: &anyhow::Error) -> String;
22}
23
24pub struct TaskTree {
25    pub tree_internal: RwLock<TaskTreeInternal>,
26}
27
28pub struct TaskTreeInternal {
29    pub tasks_internal: BTreeMap<UniqID, TaskInternal>,
30    parent_to_children: BTreeMap<UniqID, BTreeSet<UniqID>>,
31    child_to_parents: BTreeMap<UniqID, BTreeSet<UniqID>>,
32    root_tasks: BTreeSet<UniqID>,
33    event_queues: Vec<EventQueue>,
34    data_transitive: Data,
35    hide_errors_default_msg: Option<Arc<String>>,
36    attach_transitive_data_to_errors_default: bool,
37    error_formatter: Option<Arc<dyn ErrorFormatter>>,
38}
39
40#[derive(Clone)]
41pub struct TaskInternal {
42    pub id: UniqID,
43    pub name: String,
44    pub parent_names: Vec<String>,
45    pub started_at: SystemTime,
46    pub status: TaskStatus,
47    pub data: Data,
48    pub data_transitive: Data,
49    pub tags: BTreeSet<String>,
50    /// optional tuple containing values indicating task progress, where
51    /// first value is how many items finished and the second value is how many
52    /// items there are total. E.g. if it's a task processing 10 pieces of work,
53    /// (1, 10) would mean that 1 out of ten pieces is done.
54    pub progress: Option<(i64, i64)>,
55    pub hide_errors: Option<Arc<String>>,
56    pub attach_transitive_data_to_errors: bool,
57}
58
59#[derive(Clone)]
60pub enum TaskStatus {
61    Running,
62    Finished(TaskResult, SystemTime),
63}
64
/// Outcome of a finished task.
///
/// Derives `Debug`/`PartialEq`/`Eq` in addition to `Clone`, so callers and
/// reporters can log and compare results.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum TaskResult {
    /// The task completed without an error.
    Success,
    /// The task failed; holds the error message passed to `mark_done`.
    Failure(String),
}
70
// ── TaskTree ─────────────────────────────────────────────────────
72
73impl TaskTree {
74    pub fn new() -> Arc<Self> {
75        Arc::new(Self {
76            tree_internal: RwLock::new(TaskTreeInternal {
77                tasks_internal: BTreeMap::new(),
78                parent_to_children: BTreeMap::new(),
79                child_to_parents: BTreeMap::new(),
80                root_tasks: BTreeSet::new(),
81                event_queues: vec![],
82                data_transitive: Data::empty(),
83                hide_errors_default_msg: None,
84                attach_transitive_data_to_errors_default: true,
85                error_formatter: None,
86            }),
87        })
88    }
89
90    pub fn create_task(self: &Arc<Self>, name: &str) -> Task {
91        let id = self.create_task_internal(name, None);
92        Task(Arc::new(TaskData {
93            id,
94            task_tree: self.clone(),
95            mark_done_on_drop: true,
96        }))
97    }
98
99    pub fn add_reporter(&self, reporter: Arc<dyn Reporter>) {
100        let queue: EventQueue = Arc::new(Mutex::new(Vec::new()));
101        self.tree_internal
102            .write()
103            .unwrap()
104            .event_queues
105            .push(queue.clone());
106        reporter.start(queue);
107    }
108
109    fn pre_spawn(self: &Arc<Self>, name: String, parent: Option<UniqID>) -> Task {
110        Task(Arc::new(TaskData {
111            id: self.create_task_internal(&name, parent),
112            task_tree: self.clone(),
113            mark_done_on_drop: false,
114        }))
115    }
116
117    fn post_spawn<T>(self: &Arc<Self>, id: UniqID, result: Result<T>) -> Result<T> {
118        let result = result.with_context(|| {
119            let mut desc = String::from("[Task]");
120            if let Some(task_internal) = self.get_cloned_task(id) {
121                desc.push_str(&format!(" {}", task_internal.name));
122                if task_internal.attach_transitive_data_to_errors {
123                    for (k, v) in task_internal.all_data() {
124                        desc.push_str(&format!("\n  {k}: {}", v.0));
125                    }
126                } else {
127                    for (k, v) in &task_internal.data.map {
128                        desc.push_str(&format!("\n  {k}: {}", v.0));
129                    }
130                };
131                if !desc.is_empty() {
132                    desc.push('\n');
133                }
134            }
135            desc
136        });
137        let error_msg = if let Err(err) = &result {
138            let formatter = {
139                let formatter = self.tree_internal.read().unwrap().error_formatter.clone();
140                formatter
141            };
142            if let Some(formatter) = formatter {
143                Some(formatter.format_error(err))
144            } else {
145                Some(format!("{err:?}"))
146            }
147        } else {
148            None
149        };
150        self.mark_done(id, error_msg);
151        result
152    }
153
154    pub fn spawn_sync<F, T>(
155        self: &Arc<Self>,
156        name: String,
157        f: F,
158        parent: Option<UniqID>,
159    ) -> Result<T>
160    where
161        F: FnOnce(Task) -> Result<T>,
162        T: Send,
163    {
164        let task = self.pre_spawn(name, parent);
165        let id = task.0.id;
166        let result = f(task);
167        self.post_spawn(id, result)
168    }
169
170    pub(crate) async fn spawn<F, FT, T>(
171        self: &Arc<Self>,
172        name: String,
173        f: F,
174        parent: Option<UniqID>,
175    ) -> Result<T>
176    where
177        F: FnOnce(Task) -> FT,
178        FT: Future<Output = Result<T>> + Send,
179        T: Send,
180    {
181        let task = self.pre_spawn(name, parent);
182        let id = task.0.id;
183        let result = f(task).await;
184        self.post_spawn(id, result)
185    }
186
187    pub fn create_task_internal<S: Into<String>>(
188        self: &Arc<Self>,
189        name: S,
190        parent: Option<UniqID>,
191    ) -> UniqID {
192        let mut tree = self.tree_internal.write().unwrap();
193
194        let mut parent_names = vec![];
195        let mut data_transitive = tree.data_transitive.clone();
196        let (name, tags) = crate::utils::extract_tags(name.into());
197        let id = UniqID::new();
198        if let Some(parent_task) = parent.and_then(|pid| tree.tasks_internal.get(&pid)) {
199            parent_names = parent_task.parent_names.clone();
200            parent_names.push(parent_task.name.clone());
201            data_transitive.merge(&parent_task.data_transitive);
202            let parent_id = parent_task.id;
203
204            tree.parent_to_children
205                .entry(parent_id)
206                .or_default()
207                .insert(id);
208            tree.child_to_parents
209                .entry(id)
210                .or_default()
211                .insert(parent_id);
212        } else {
213            tree.root_tasks.insert(id);
214        }
215
216        let task_internal = TaskInternal {
217            status: TaskStatus::Running,
218            name,
219            parent_names,
220            id,
221            started_at: SystemTime::now(),
222            data: Data::empty(),
223            data_transitive,
224            tags,
225            progress: None,
226            hide_errors: tree.hide_errors_default_msg.clone(),
227            attach_transitive_data_to_errors: tree.attach_transitive_data_to_errors_default,
228        };
229
230        tree.tasks_internal.insert(id, task_internal.clone());
231
232        // Push start event to all reporter queues.
233        let task_arc = Arc::new(task_internal);
234        for queue in &tree.event_queues {
235            queue
236                .lock()
237                .unwrap()
238                .push(TaskEvent::Start(task_arc.clone()));
239        }
240
241        id
242    }
243
244    pub fn mark_done(&self, id: UniqID, error_message: Option<String>) {
245        let mut tree = self.tree_internal.write().unwrap();
246        if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
247            task_internal.mark_done(error_message);
248
249            // Push end event to all reporter queues.
250            let task_arc = Arc::new(task_internal.clone());
251            for queue in &tree.event_queues {
252                queue.lock().unwrap().push(TaskEvent::End(task_arc.clone()));
253            }
254
255            // Clean up this task and any finished ancestors.
256            tree.try_remove(id);
257        }
258    }
259
260    pub fn add_data<S: Into<String>, D: Into<DataValue>>(&self, id: UniqID, key: S, value: D) {
261        let mut tree = self.tree_internal.write().unwrap();
262        if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
263            task_internal.data.add(key, value);
264        }
265    }
266
267    pub fn get_data<S: Into<String>>(&self, id: UniqID, key: S) -> Option<DataValue> {
268        let tree = self.tree_internal.read().unwrap();
269        if let Some(task_internal) = tree.tasks_internal.get(&id) {
270            let all_data: BTreeMap<_, _> = task_internal.all_data().collect();
271            return all_data.get(&key.into()).map(|de| de.0.clone());
272        }
273        None
274    }
275
276    pub(crate) fn add_data_transitive_for_task<S: Into<String>, D: Into<DataValue>>(
277        &self,
278        id: UniqID,
279        key: S,
280        value: D,
281    ) {
282        let mut tree = self.tree_internal.write().unwrap();
283        if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
284            task_internal.data_transitive.add(key, value);
285        }
286    }
287
288    pub fn hide_errors_default_msg<S: Into<String>>(&self, msg: Option<S>) {
289        let mut tree = self.tree_internal.write().unwrap();
290        let msg = msg.map(|msg| Arc::new(msg.into()));
291        tree.hide_errors_default_msg = msg;
292    }
293
294    pub(crate) fn hide_error_msg_for_task(&self, id: UniqID, msg: Option<Arc<String>>) {
295        let mut tree = self.tree_internal.write().unwrap();
296        if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
297            task_internal.hide_errors = msg;
298        }
299    }
300
301    pub fn attach_transitive_data_to_errors_default(&self, val: bool) {
302        let mut tree = self.tree_internal.write().unwrap();
303        tree.attach_transitive_data_to_errors_default = val;
304    }
305
306    pub(crate) fn attach_transitive_data_to_errors_for_task(&self, id: UniqID, val: bool) {
307        let mut tree = self.tree_internal.write().unwrap();
308        if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
309            task_internal.attach_transitive_data_to_errors = val;
310        }
311    }
312
313    pub fn set_error_formatter(&self, error_formatter: Option<Arc<dyn ErrorFormatter>>) {
314        let mut tree = self.tree_internal.write().unwrap();
315        tree.error_formatter = error_formatter;
316    }
317
318    pub fn add_data_transitive<S: Into<String>, D: Into<DataValue>>(&self, key: S, value: D) {
319        let mut tree = self.tree_internal.write().unwrap();
320        tree.data_transitive.add(key, value);
321    }
322
323    pub fn task_progress(&self, id: UniqID, done: i64, total: i64) {
324        let mut tree = self.tree_internal.write().unwrap();
325        if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
326            task_internal.progress = Some((done, total));
327
328            // Push progress event to all reporter queues.
329            let task_arc = Arc::new(task_internal.clone());
330            for queue in &tree.event_queues {
331                queue
332                    .lock()
333                    .unwrap()
334                    .push(TaskEvent::Progress(task_arc.clone()));
335            }
336        }
337    }
338
339    fn get_cloned_task(&self, id: UniqID) -> Option<TaskInternal> {
340        let tree = self.tree_internal.read().unwrap();
341        tree.get_task(id).ok().cloned()
342    }
343}
344
// ── TaskTreeInternal ─────────────────────────────────────────────
346
347#[allow(dead_code)]
348impl TaskTreeInternal {
349    pub fn get_task(&self, id: UniqID) -> Result<&TaskInternal> {
350        self.tasks_internal.get(&id).context("task must be present")
351    }
352
353    pub fn root_tasks(&self) -> &BTreeSet<UniqID> {
354        &self.root_tasks
355    }
356
357    pub fn child_to_parents(&self) -> &BTreeMap<UniqID, BTreeSet<UniqID>> {
358        &self.child_to_parents
359    }
360
361    pub fn parent_to_children(&self) -> &BTreeMap<UniqID, BTreeSet<UniqID>> {
362        &self.parent_to_children
363    }
364
365    /// Remove a finished task if all its children are also gone.
366    /// Then cascade up to the parent — it may now be removable too.
367    fn try_remove(&mut self, id: UniqID) {
368        if let Some(children) = self.parent_to_children.get(&id) {
369            if !children.is_empty() {
370                return;
371            }
372        }
373
374        let is_finished = self
375            .tasks_internal
376            .get(&id)
377            .is_some_and(|t| matches!(t.status, TaskStatus::Finished(..)));
378        if !is_finished {
379            return;
380        }
381
382        self.tasks_internal.remove(&id);
383        self.parent_to_children.remove(&id);
384        self.root_tasks.remove(&id);
385
386        if let Some(parents) = self.child_to_parents.remove(&id) {
387            for parent_id in parents {
388                if let Some(children) = self.parent_to_children.get_mut(&parent_id) {
389                    children.remove(&id);
390                }
391                self.try_remove(parent_id);
392            }
393        }
394    }
395}
396
// ── TaskInternal ─────────────────────────────────────────────────
398
399impl TaskInternal {
400    pub(crate) fn mark_done(&mut self, error_message: Option<String>) {
401        let task_status = match error_message {
402            None => TaskResult::Success,
403            Some(msg) => TaskResult::Failure(msg),
404        };
405        self.status = TaskStatus::Finished(task_status, SystemTime::now());
406    }
407
408    pub fn full_name(&self) -> String {
409        let mut full_name = String::new();
410        for parent_name in &self.parent_names {
411            full_name.push_str(parent_name);
412            full_name.push(':');
413        }
414        full_name.push_str(&self.name);
415        full_name
416    }
417
418    pub fn all_data(
419        &self,
420    ) -> std::iter::Chain<
421        std::collections::btree_map::Iter<'_, String, DataEntry>,
422        std::collections::btree_map::Iter<'_, String, DataEntry>,
423    > {
424        self.data.map.iter().chain(self.data_transitive.map.iter())
425    }
426}