ll/
task_tree.rs

1use crate::data::{Data, DataEntry, DataValue};
2use crate::reporters::Reporter;
3use crate::task::{Task, TaskData};
4use crate::uniq_id::UniqID;
5use anyhow::{Context, Result};
6use std::collections::{BTreeMap, BTreeSet, HashMap};
7use std::future::Future;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use std::sync::RwLock;
11use std::thread;
12use std::time::Duration;
13use std::time::SystemTime;
14
lazy_static::lazy_static! {
    /// Process-wide task tree used by the free functions in this module
    /// (e.g. [`add_reporter`]).
    ///
    /// Built lazily on first access via `TaskTree::new`, which also spawns the
    /// tree's background workers at that moment.
    pub static ref TASK_TREE: Arc<TaskTree>  = TaskTree::new();
}
18
19pub fn add_reporter(reporter: Arc<dyn Reporter>) {
20    TASK_TREE.add_reporter(reporter);
21}
22
/// Customizes how a failed task's error is turned into the message stored in
/// its finished status (the text reporters end up seeing).
///
/// Installed tree-wide with `TaskTree::set_error_formatter`. When no formatter
/// is installed, the error's `Debug` representation (`{:?}`) is used instead.
/// The `Send + Sync` bound is required because the tree keeps the formatter in
/// an `Arc` that is shared across threads.
pub trait ErrorFormatter: Send + Sync {
    /// Renders `err` into a message. For tasks run through the tree's spawn
    /// helpers, `err` has already had the task's "[Task] ..." context attached.
    fn format_error(&self, err: &anyhow::Error) -> String;
}
26
/// Thread-safe registry of tasks, their parent/child relationships, attached
/// data, and the queue of start/end events waiting to be sent to reporters.
pub struct TaskTree {
    /// All mutable state, guarded by a single reader/writer lock.
    pub(crate) tree_internal: RwLock<TaskTreeInternal>,
    /// If true, it will block the current thread until all task events are
    /// reported (e.g. written to STDOUT)
    force_flush: AtomicBool,
}
33
/// Lock-protected state of a [`TaskTree`].
pub(crate) struct TaskTreeInternal {
    /// Every live (not yet garbage-collected) task, keyed by id.
    tasks_internal: BTreeMap<UniqID, TaskInternal>,
    /// Parent id -> ids of its direct children.
    parent_to_children: BTreeMap<UniqID, BTreeSet<UniqID>>,
    /// Child id -> ids of its parents.
    child_to_parents: BTreeMap<UniqID, BTreeSet<UniqID>>,
    /// Ids of tasks created without a (still present) parent.
    root_tasks: BTreeSet<UniqID>,
    /// Sinks that receive task start/end events.
    reporters: Vec<Arc<dyn Reporter>>,
    /// Tasks eligible for removal, mapped to the time they were first marked.
    tasks_marked_for_deletion: HashMap<UniqID, SystemTime>,
    /// Ids of started tasks whose start event has not been reported yet.
    report_start: Vec<UniqID>,
    /// Ids of finished tasks whose end event has not been reported yet.
    report_end: Vec<UniqID>,
    /// Tree-wide transitive data copied into every newly created task.
    data_transitive: Data,
    /// How long (in ms) a marked task is kept before garbage collection
    /// removes it.
    remove_task_after_done_ms: u64,
    /// Initial `hide_errors` value for newly created tasks.
    hide_errors_default_msg: Option<Arc<String>>,
    /// Initial `attach_transitive_data_to_errors` value for new tasks.
    attach_transitive_data_to_errors_default: bool,
    /// Optional custom renderer for task failure messages.
    error_formatter: Option<Arc<dyn ErrorFormatter>>,
}
49
/// Snapshot-able record of a single task tracked by a [`TaskTree`].
#[derive(Clone)]
pub struct TaskInternal {
    /// Unique identifier of the task.
    pub id: UniqID,
    /// Task name with any tags split off (see `tags`).
    pub name: String,
    /// Names of all ancestors, outermost (root) first.
    pub parent_names: Vec<String>,
    /// Wall-clock time at which the task was created.
    pub started_at: SystemTime,
    /// Whether the task is still running or how it finished.
    pub status: TaskStatus,
    /// Data attached directly to this task.
    pub data: Data,
    /// Data inherited from the tree and from the parent task.
    pub data_transitive: Data,
    /// Tags extracted from the name the task was created with.
    pub tags: BTreeSet<String>,
    /// optional tuple containing values indicating task progress, where
    /// first value is how many items finished and the second value is how many
    /// items there are total. E.g. if it's a task processing 10 pieces of work,
    /// (1, 10) would mean that 1 out of ten pieces is done.
    pub progress: Option<(i64, i64)>,
    /// If `Some`, reporters may show this message instead of the real error.
    pub hide_errors: Option<Arc<String>>,
    /// Whether transitive data (not just direct data) is attached to errors
    /// returned from this task.
    pub attach_transitive_data_to_errors: bool,
}
68
/// Lifecycle state of a task.
#[derive(Clone)]
pub enum TaskStatus {
    /// The task has been created and not yet marked done.
    Running,
    /// The task is done: its outcome and the wall-clock time it was marked done.
    Finished(TaskResult, SystemTime),
}
74
/// Outcome of a finished task.
#[derive(Clone)]
pub enum TaskResult {
    /// The task completed without an error message.
    Success,
    /// The task failed; holds the error message passed when it was marked done.
    Failure(String),
}
80
81impl TaskTree {
82    pub fn new() -> Arc<Self> {
83        let s = Arc::new(Self {
84            tree_internal: RwLock::new(TaskTreeInternal {
85                tasks_internal: BTreeMap::new(),
86                parent_to_children: BTreeMap::new(),
87                child_to_parents: BTreeMap::new(),
88                root_tasks: BTreeSet::new(),
89                reporters: vec![],
90                tasks_marked_for_deletion: HashMap::new(),
91                report_start: vec![],
92                report_end: vec![],
93                data_transitive: Data::empty(),
94                remove_task_after_done_ms: 0,
95                hide_errors_default_msg: None,
96                attach_transitive_data_to_errors_default: true,
97                error_formatter: None,
98            }),
99            force_flush: AtomicBool::new(false),
100        });
101        let clone = s.clone();
102        tokio::spawn(async move {
103            loop {
104                tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
105                let mut tree = clone.tree_internal.write().unwrap();
106                tree.garbage_collect();
107            }
108        });
109        let clone = s.clone();
110        thread::spawn(move || loop {
111            thread::sleep(std::time::Duration::from_millis(10));
112            clone.report_all();
113        });
114
115        s
116    }
117
118    pub fn set_force_flush(&self, enabled: bool) {
119        self.force_flush.store(enabled, Ordering::SeqCst)
120    }
121
122    pub fn force_flush_enabled(&self) -> bool {
123        self.force_flush.load(Ordering::SeqCst)
124    }
125
126    pub fn create_task(self: &Arc<Self>, name: &str) -> Task {
127        let id = self.create_task_internal(name, None);
128        Task(Arc::new(TaskData {
129            id,
130            task_tree: self.clone(),
131            mark_done_on_drop: true,
132        }))
133    }
134
135    pub fn add_reporter(&self, reporter: Arc<dyn Reporter>) {
136        self.tree_internal.write().unwrap().reporters.push(reporter);
137    }
138
139    fn pre_spawn(self: &Arc<Self>, name: String, parent: Option<UniqID>) -> Task {
140        let task = Task(Arc::new(TaskData {
141            id: self.create_task_internal(&name, parent),
142            task_tree: self.clone(),
143            mark_done_on_drop: false,
144        }));
145        self.maybe_force_flush();
146        task
147    }
148
149    fn post_spawn<T>(self: &Arc<Self>, id: UniqID, result: Result<T>) -> Result<T> {
150        let result = result.with_context(|| {
151            let mut desc = String::from("[Task]");
152            if let Some(task_internal) = self.get_cloned_task(id) {
153                desc.push_str(&format!(" {}", task_internal.name));
154                if task_internal.attach_transitive_data_to_errors {
155                    for (k, v) in task_internal.all_data() {
156                        desc.push_str(&format!("\n  {}: {}", k, v.0));
157                    }
158                } else {
159                    for (k, v) in &task_internal.data.map {
160                        desc.push_str(&format!("\n  {}: {}", k, v.0));
161                    }
162                };
163                if !desc.is_empty() {
164                    desc.push('\n');
165                }
166            }
167            desc
168        });
169        let error_msg = if let Err(err) = &result {
170            let formatter = {
171                let formatter = self.tree_internal.read().unwrap().error_formatter.clone();
172                formatter
173            };
174            if let Some(formatter) = formatter {
175                Some(formatter.format_error(err))
176            } else {
177                Some(format!("{:?}", err))
178            }
179        } else {
180            None
181        };
182        self.mark_done(id, error_msg);
183        self.maybe_force_flush();
184        result
185    }
186
187    pub fn spawn_sync<F, T>(
188        self: &Arc<Self>,
189        name: String,
190        f: F,
191        parent: Option<UniqID>,
192    ) -> Result<T>
193    where
194        F: FnOnce(Task) -> Result<T>,
195        T: Send,
196    {
197        let task = self.pre_spawn(name, parent);
198        let id = task.0.id;
199        let result = f(task);
200        self.post_spawn(id, result)
201    }
202
203    pub(crate) async fn spawn<F, FT, T>(
204        self: &Arc<Self>,
205        name: String,
206        f: F,
207        parent: Option<UniqID>,
208    ) -> Result<T>
209    where
210        F: FnOnce(Task) -> FT,
211        FT: Future<Output = Result<T>> + Send,
212        T: Send,
213    {
214        let task = self.pre_spawn(name, parent);
215        let id = task.0.id;
216        let result = f(task).await;
217        self.post_spawn(id, result)
218    }
219
220    pub fn create_task_internal<S: Into<String>>(
221        self: &Arc<Self>,
222        name: S,
223        parent: Option<UniqID>,
224    ) -> UniqID {
225        let mut tree = self.tree_internal.write().unwrap();
226
227        let mut parent_names = vec![];
228        let mut data_transitive = tree.data_transitive.clone();
229        let (name, tags) = crate::utils::extract_tags(name.into());
230        let id = UniqID::new();
231        if let Some(parent_task) = parent.and_then(|pid| tree.tasks_internal.get(&pid)) {
232            parent_names = parent_task.parent_names.clone();
233            parent_names.push(parent_task.name.clone());
234            data_transitive.merge(&parent_task.data_transitive);
235            let parent_id = parent_task.id;
236
237            tree.parent_to_children
238                .entry(parent_id)
239                .or_insert_with(BTreeSet::new)
240                .insert(id);
241            tree.child_to_parents
242                .entry(id)
243                .or_insert_with(BTreeSet::new)
244                .insert(parent_id);
245        } else {
246            tree.root_tasks.insert(id);
247        }
248
249        let task_internal = TaskInternal {
250            status: TaskStatus::Running,
251            name,
252            parent_names,
253            id,
254            started_at: SystemTime::now(),
255            data: Data::empty(),
256            data_transitive,
257            tags,
258            progress: None,
259            hide_errors: tree.hide_errors_default_msg.clone(),
260            attach_transitive_data_to_errors: tree.attach_transitive_data_to_errors_default,
261        };
262
263        tree.tasks_internal.insert(id, task_internal);
264        tree.report_start.push(id);
265
266        id
267    }
268
269    pub fn mark_done(&self, id: UniqID, error_message: Option<String>) {
270        let mut tree = self.tree_internal.write().unwrap();
271        if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
272            task_internal.mark_done(error_message);
273            tree.mark_for_gc(id);
274            tree.report_end.push(id);
275        }
276    }
277
278    pub fn add_data<S: Into<String>, D: Into<DataValue>>(&self, id: UniqID, key: S, value: D) {
279        let mut tree = self.tree_internal.write().unwrap();
280        if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
281            task_internal.data.add(key, value);
282        }
283    }
284
285    pub fn get_data<S: Into<String>>(&self, id: UniqID, key: S) -> Option<DataValue> {
286        let mut tree = self.tree_internal.write().unwrap();
287        if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
288            let all_data: BTreeMap<_, _> = task_internal.all_data().collect();
289            return all_data.get(&key.into()).map(|de| de.0.clone());
290        }
291        None
292    }
293
294    pub(crate) fn add_data_transitive_for_task<S: Into<String>, D: Into<DataValue>>(
295        &self,
296        id: UniqID,
297        key: S,
298        value: D,
299    ) {
300        let mut tree = self.tree_internal.write().unwrap();
301        if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
302            task_internal.data_transitive.add(key, value);
303        }
304    }
305    /// Reporters can use this flag to choose to not report errors.
306    /// This is useful for cases where there's a large task chain and every
307    /// single task reports a partial errors (that gets built up with each task)
308    /// It would make sense to report it only once at the top level (thrift
309    /// request, cli call, etc) and only mark other tasks.
310    /// If set to Some, the message inside is what would be reported by default
311    /// instead of reporting errors to avoid confusion (e.g. "error was hidden,
312    /// see ...")
313    pub fn hide_errors_default_msg<S: Into<String>>(&self, msg: Option<S>) {
314        let mut tree = self.tree_internal.write().unwrap();
315        let msg = msg.map(|msg| Arc::new(msg.into()));
316        tree.hide_errors_default_msg = msg;
317    }
318
319    pub(crate) fn hide_error_msg_for_task(&self, id: UniqID, msg: Option<Arc<String>>) {
320        let mut tree = self.tree_internal.write().unwrap();
321        if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
322            task_internal.hide_errors = msg;
323        }
324    }
325
326    /// When errors occur, we attach task data to it in the description.
327    /// If set to false, only task direct data will be attached and not
328    /// transitive data. This is useful sometimes to remove the noise of
329    /// transitive data appearing in every error in the chain (e.g. hostname)
330    pub fn attach_transitive_data_to_errors_default(&self, val: bool) {
331        let mut tree = self.tree_internal.write().unwrap();
332        tree.attach_transitive_data_to_errors_default = val;
333    }
334
335    pub(crate) fn attach_transitive_data_to_errors_for_task(&self, id: UniqID, val: bool) {
336        let mut tree = self.tree_internal.write().unwrap();
337        if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
338            task_internal.attach_transitive_data_to_errors = val;
339        }
340    }
341
342    /// Add a custom error formatter to change how error messages look in
343    /// reporters.
344    /// Unfortunately it is not configurable per reporter, because errors
345    /// normally don't implement `Clone` and it will be almost impossible to add
346    /// reference counters to all errors in all chains
347    pub fn set_error_formatter(&self, error_formatter: Option<Arc<dyn ErrorFormatter>>) {
348        let mut tree = self.tree_internal.write().unwrap();
349        tree.error_formatter = error_formatter;
350    }
351
352    /// Add transitive data to the task tree. This transitive data will be
353    /// added to every task created in this task tree
354    pub fn add_data_transitive<S: Into<String>, D: Into<DataValue>>(&self, key: S, value: D) {
355        let mut tree = self.tree_internal.write().unwrap();
356        tree.data_transitive.add(key, value);
357    }
358
359    pub fn task_progress(&self, id: UniqID, done: i64, total: i64) {
360        let mut tree = self.tree_internal.write().unwrap();
361        if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
362            task_internal.progress = Some((done, total));
363        }
364    }
365
366    fn get_cloned_task(&self, id: UniqID) -> Option<TaskInternal> {
367        let tree = self.tree_internal.read().unwrap();
368        tree.get_task(id).ok().cloned()
369    }
370
371    /// If force_flush set to true, this function will block the thread until everything
372    /// is reported. Useful for cases when the process exits before all async events
373    /// are reported and stuff is missing from stdout.
374    pub fn maybe_force_flush(&self) {
375        if self.force_flush.load(Ordering::SeqCst) {
376            self.report_all();
377        }
378    }
379
380    pub fn report_all(&self) {
381        let mut tree = self.tree_internal.write().unwrap();
382        let (start_tasks, end_tasks, reporters) = tree.get_tasks_and_reporters();
383        drop(tree);
384        for reporter in reporters {
385            for task in &start_tasks {
386                reporter.task_start(task.clone());
387            }
388            for task in &end_tasks {
389                reporter.task_end(task.clone());
390            }
391        }
392    }
393}
394
395impl TaskTreeInternal {
396    pub fn get_task(&self, id: UniqID) -> Result<&TaskInternal> {
397        self.tasks_internal.get(&id).context("task must be present")
398    }
399
400    pub fn root_tasks(&self) -> &BTreeSet<UniqID> {
401        &self.root_tasks
402    }
403
404    pub fn child_to_parents(&self) -> &BTreeMap<UniqID, BTreeSet<UniqID>> {
405        &self.child_to_parents
406    }
407
408    pub fn parent_to_children(&self) -> &BTreeMap<UniqID, BTreeSet<UniqID>> {
409        &self.parent_to_children
410    }
411
412    fn mark_for_gc(&mut self, id: UniqID) {
413        let mut stack = vec![id];
414
415        let mut tasks_to_finished_status = BTreeMap::new();
416
417        while let Some(id) = stack.pop() {
418            if let Some(task_internal) = self.tasks_internal.get(&id) {
419                tasks_to_finished_status
420                    .insert(id, matches!(task_internal.status, TaskStatus::Finished(..)));
421            }
422
423            for child_id in self.parent_to_children.get(&id).into_iter().flatten() {
424                stack.push(*child_id);
425            }
426        }
427
428        if tasks_to_finished_status
429            .iter()
430            .all(|(_, finished)| *finished)
431        {
432            for id in tasks_to_finished_status.keys().copied() {
433                self.tasks_marked_for_deletion
434                    .entry(id)
435                    .or_insert_with(SystemTime::now);
436            }
437
438            // This sub branch might have been holding other parent branches that
439            // weren't able to be garbage collected because of this subtree. we'll go
440            // level up and perform the same logic.
441            let parents = self.child_to_parents.get(&id).cloned().unwrap_or_default();
442            for parent_id in parents {
443                self.mark_for_gc(parent_id);
444            }
445        }
446    }
447
448    fn garbage_collect(&mut self) {
449        let mut will_delete = vec![];
450        for (id, time) in &self.tasks_marked_for_deletion {
451            if let Ok(elapsed) = time.elapsed() {
452                if elapsed > Duration::from_millis(self.remove_task_after_done_ms) {
453                    will_delete.push(*id);
454                }
455            }
456        }
457
458        for id in will_delete {
459            self.tasks_internal.remove(&id);
460            self.parent_to_children.remove(&id);
461            self.root_tasks.remove(&id);
462            if let Some(parents) = self.child_to_parents.remove(&id) {
463                for parent in parents {
464                    if let Some(children) = self.parent_to_children.get_mut(&parent) {
465                        children.remove(&id);
466                    }
467                }
468            }
469            self.tasks_marked_for_deletion.remove(&id);
470        }
471    }
472
473    #[allow(clippy::type_complexity)]
474    fn get_tasks_and_reporters(
475        &mut self,
476    ) -> (
477        Vec<Arc<TaskInternal>>,
478        Vec<Arc<TaskInternal>>,
479        Vec<Arc<dyn Reporter>>,
480    ) {
481        let mut start_ids = vec![];
482        std::mem::swap(&mut start_ids, &mut self.report_start);
483        let mut end_ids = vec![];
484        std::mem::swap(&mut end_ids, &mut self.report_end);
485
486        let mut start_tasks = vec![];
487        let mut end_tasks = vec![];
488
489        for id in start_ids {
490            if let Ok(task_internal) = self.get_task(id) {
491                start_tasks.push(Arc::new(task_internal.clone()));
492            }
493        }
494        for id in end_ids {
495            if let Ok(task_internal) = self.get_task(id) {
496                end_tasks.push(Arc::new(task_internal.clone()));
497            }
498        }
499
500        let reporters = self.reporters.clone();
501
502        (start_tasks, end_tasks, reporters)
503    }
504}
505
506impl TaskInternal {
507    pub(crate) fn mark_done(&mut self, error_message: Option<String>) {
508        let task_status = match error_message {
509            None => TaskResult::Success,
510            Some(msg) => TaskResult::Failure(msg),
511        };
512        self.status = TaskStatus::Finished(task_status, SystemTime::now());
513    }
514
515    pub fn full_name(&self) -> String {
516        let mut full_name = String::new();
517        for parent_name in &self.parent_names {
518            full_name.push_str(parent_name);
519            full_name.push(':');
520        }
521        full_name.push_str(&self.name);
522        full_name
523    }
524
525    pub fn all_data(
526        &self,
527    ) -> std::iter::Chain<
528        std::collections::btree_map::Iter<String, DataEntry>,
529        std::collections::btree_map::Iter<String, DataEntry>,
530    > {
531        self.data.map.iter().chain(self.data_transitive.map.iter())
532    }
533}