1use crate::data::{Data, DataEntry, DataValue};
2use crate::reporters::Reporter;
3use crate::task::{Task, TaskData};
4use crate::uniq_id::UniqID;
5use anyhow::{Context, Result};
6use std::collections::{BTreeMap, BTreeSet, HashMap};
7use std::future::Future;
8use std::sync::atomic::{AtomicBool, Ordering};
9use std::sync::Arc;
10use std::sync::RwLock;
11use std::thread;
12use std::time::Duration;
13use std::time::SystemTime;
14
lazy_static::lazy_static! {
    /// Process-wide task tree, used for example by the free function
    /// `add_reporter`. It is built lazily on first access through
    /// [`TaskTree::new`], which also starts the tree's background workers.
    pub static ref TASK_TREE: Arc<TaskTree> = TaskTree::new();
}
18
19pub fn add_reporter(reporter: Arc<dyn Reporter>) {
20 TASK_TREE.add_reporter(reporter);
21}
22
/// Renders a failed task's error into the message stored in
/// `TaskResult::Failure`.
///
/// Install one with `TaskTree::set_error_formatter`. Without one, errors are
/// rendered with their `Debug` representation (`{:?}`).
pub trait ErrorFormatter: Send + Sync {
    /// Formats `err` into a single string. By the time this is called, the
    /// task tree has already wrapped `err` with its `[Task] ...` context.
    fn format_error(&self, err: &anyhow::Error) -> String;
}
26
/// Registry of tasks. It holds their hierarchy, their attached data, and the
/// queues of start/end events that are delivered to registered [`Reporter`]s.
pub struct TaskTree {
    /// All mutable state behind a single lock. The lock is shared by the public
    /// API and by the background workers started in `TaskTree::new`.
    pub(crate) tree_internal: RwLock<TaskTreeInternal>,
    /// When `true`, `maybe_force_flush` reports queued events immediately
    /// instead of leaving them to the periodic reporting worker.
    force_flush: AtomicBool,
}
33
/// Lock-protected state of a [`TaskTree`].
pub(crate) struct TaskTreeInternal {
    /// Every task that has not been garbage collected yet, keyed by id.
    tasks_internal: BTreeMap<UniqID, TaskInternal>,
    /// Parent id -> ids of the children created under it.
    parent_to_children: BTreeMap<UniqID, BTreeSet<UniqID>>,
    /// Child id -> parent ids. `create_task_internal` records the single
    /// parent a task was created under.
    child_to_parents: BTreeMap<UniqID, BTreeSet<UniqID>>,
    /// Tasks created without a parent that was present in the tree.
    root_tasks: BTreeSet<UniqID>,
    /// Receivers of task start/end events.
    reporters: Vec<Arc<dyn Reporter>>,
    /// Tasks whose whole subtree has finished, mapped to the time they were
    /// first marked. `garbage_collect` removes each entry once it is older
    /// than `remove_task_after_done_ms`.
    tasks_marked_for_deletion: HashMap<UniqID, SystemTime>,
    /// Ids whose start event has not been delivered yet. Drained by
    /// `get_tasks_and_reporters`.
    report_start: Vec<UniqID>,
    /// Ids whose end event has not been delivered yet. Drained by
    /// `get_tasks_and_reporters`.
    report_end: Vec<UniqID>,
    /// Tree-wide data copied into the `data_transitive` of every task created
    /// afterwards.
    data_transitive: Data,
    /// How long, in milliseconds, a marked task is kept before GC removes it.
    /// `TaskTree::new` sets this to 0.
    remove_task_after_done_ms: u64,
    /// Initial `hide_errors` value for newly created tasks.
    hide_errors_default_msg: Option<Arc<String>>,
    /// Initial `attach_transitive_data_to_errors` value for newly created
    /// tasks. `TaskTree::new` sets this to `true`.
    attach_transitive_data_to_errors_default: bool,
    /// Optional custom renderer for task errors. Without one, errors are
    /// rendered with `{:?}`.
    error_formatter: Option<Arc<dyn ErrorFormatter>>,
}
49
/// State of a single task. It is cloned into `Arc` snapshots when handed to
/// reporters.
#[derive(Clone)]
pub struct TaskInternal {
    /// Unique identifier. Also the task's key in the tree's task map.
    pub id: UniqID,
    /// Task name as returned by `crate::utils::extract_tags`. Presumably tag
    /// markup has been stripped (TODO confirm).
    pub name: String,
    /// Names of the ancestors, outermost first and ending with the immediate
    /// parent. Empty for root tasks.
    pub parent_names: Vec<String>,
    /// Time at which the task was created.
    pub started_at: SystemTime,
    /// Whether the task is still running, or how and when it finished.
    pub status: TaskStatus,
    /// Data attached to this task only.
    pub data: Data,
    /// Data inherited at creation: the tree-wide transitive data merged with
    /// the parent's transitive data, plus any entries added for this task
    /// later. Children merge this into their own copy when created.
    pub data_transitive: Data,
    /// Tags extracted from the name passed at creation.
    pub tags: BTreeSet<String>,
    /// `(done, total)` as last reported through `TaskTree::task_progress`.
    pub progress: Option<(i64, i64)>,
    /// Initialized from the tree's `hide_errors_default_msg`. How it is
    /// applied is not visible here. Presumably it is a replacement message
    /// shown instead of the error (TODO confirm).
    pub hide_errors: Option<Arc<String>>,
    /// Whether the error context built for a failed spawned task includes
    /// `data_transitive` in addition to `data`.
    pub attach_transitive_data_to_errors: bool,
}
68
/// Lifecycle state of a task.
#[derive(Clone)]
pub enum TaskStatus {
    /// Created and not yet marked done.
    Running,
    /// Marked done, with its outcome and the time it was marked.
    Finished(TaskResult, SystemTime),
}
74
/// Outcome of a finished task.
#[derive(Clone)]
pub enum TaskResult {
    /// Marked done without an error message.
    Success,
    /// Marked done with the given (already formatted) error message.
    Failure(String),
}
80
81impl TaskTree {
82 pub fn new() -> Arc<Self> {
83 let s = Arc::new(Self {
84 tree_internal: RwLock::new(TaskTreeInternal {
85 tasks_internal: BTreeMap::new(),
86 parent_to_children: BTreeMap::new(),
87 child_to_parents: BTreeMap::new(),
88 root_tasks: BTreeSet::new(),
89 reporters: vec![],
90 tasks_marked_for_deletion: HashMap::new(),
91 report_start: vec![],
92 report_end: vec![],
93 data_transitive: Data::empty(),
94 remove_task_after_done_ms: 0,
95 hide_errors_default_msg: None,
96 attach_transitive_data_to_errors_default: true,
97 error_formatter: None,
98 }),
99 force_flush: AtomicBool::new(false),
100 });
101 let clone = s.clone();
102 tokio::spawn(async move {
103 loop {
104 tokio::time::sleep(tokio::time::Duration::from_millis(500)).await;
105 let mut tree = clone.tree_internal.write().unwrap();
106 tree.garbage_collect();
107 }
108 });
109 let clone = s.clone();
110 thread::spawn(move || loop {
111 thread::sleep(std::time::Duration::from_millis(10));
112 clone.report_all();
113 });
114
115 s
116 }
117
118 pub fn set_force_flush(&self, enabled: bool) {
119 self.force_flush.store(enabled, Ordering::SeqCst)
120 }
121
122 pub fn force_flush_enabled(&self) -> bool {
123 self.force_flush.load(Ordering::SeqCst)
124 }
125
126 pub fn create_task(self: &Arc<Self>, name: &str) -> Task {
127 let id = self.create_task_internal(name, None);
128 Task(Arc::new(TaskData {
129 id,
130 task_tree: self.clone(),
131 mark_done_on_drop: true,
132 }))
133 }
134
135 pub fn add_reporter(&self, reporter: Arc<dyn Reporter>) {
136 self.tree_internal.write().unwrap().reporters.push(reporter);
137 }
138
139 fn pre_spawn(self: &Arc<Self>, name: String, parent: Option<UniqID>) -> Task {
140 let task = Task(Arc::new(TaskData {
141 id: self.create_task_internal(&name, parent),
142 task_tree: self.clone(),
143 mark_done_on_drop: false,
144 }));
145 self.maybe_force_flush();
146 task
147 }
148
149 fn post_spawn<T>(self: &Arc<Self>, id: UniqID, result: Result<T>) -> Result<T> {
150 let result = result.with_context(|| {
151 let mut desc = String::from("[Task]");
152 if let Some(task_internal) = self.get_cloned_task(id) {
153 desc.push_str(&format!(" {}", task_internal.name));
154 if task_internal.attach_transitive_data_to_errors {
155 for (k, v) in task_internal.all_data() {
156 desc.push_str(&format!("\n {}: {}", k, v.0));
157 }
158 } else {
159 for (k, v) in &task_internal.data.map {
160 desc.push_str(&format!("\n {}: {}", k, v.0));
161 }
162 };
163 if !desc.is_empty() {
164 desc.push('\n');
165 }
166 }
167 desc
168 });
169 let error_msg = if let Err(err) = &result {
170 let formatter = {
171 let formatter = self.tree_internal.read().unwrap().error_formatter.clone();
172 formatter
173 };
174 if let Some(formatter) = formatter {
175 Some(formatter.format_error(err))
176 } else {
177 Some(format!("{:?}", err))
178 }
179 } else {
180 None
181 };
182 self.mark_done(id, error_msg);
183 self.maybe_force_flush();
184 result
185 }
186
187 pub fn spawn_sync<F, T>(
188 self: &Arc<Self>,
189 name: String,
190 f: F,
191 parent: Option<UniqID>,
192 ) -> Result<T>
193 where
194 F: FnOnce(Task) -> Result<T>,
195 T: Send,
196 {
197 let task = self.pre_spawn(name, parent);
198 let id = task.0.id;
199 let result = f(task);
200 self.post_spawn(id, result)
201 }
202
203 pub(crate) async fn spawn<F, FT, T>(
204 self: &Arc<Self>,
205 name: String,
206 f: F,
207 parent: Option<UniqID>,
208 ) -> Result<T>
209 where
210 F: FnOnce(Task) -> FT,
211 FT: Future<Output = Result<T>> + Send,
212 T: Send,
213 {
214 let task = self.pre_spawn(name, parent);
215 let id = task.0.id;
216 let result = f(task).await;
217 self.post_spawn(id, result)
218 }
219
220 pub fn create_task_internal<S: Into<String>>(
221 self: &Arc<Self>,
222 name: S,
223 parent: Option<UniqID>,
224 ) -> UniqID {
225 let mut tree = self.tree_internal.write().unwrap();
226
227 let mut parent_names = vec![];
228 let mut data_transitive = tree.data_transitive.clone();
229 let (name, tags) = crate::utils::extract_tags(name.into());
230 let id = UniqID::new();
231 if let Some(parent_task) = parent.and_then(|pid| tree.tasks_internal.get(&pid)) {
232 parent_names = parent_task.parent_names.clone();
233 parent_names.push(parent_task.name.clone());
234 data_transitive.merge(&parent_task.data_transitive);
235 let parent_id = parent_task.id;
236
237 tree.parent_to_children
238 .entry(parent_id)
239 .or_insert_with(BTreeSet::new)
240 .insert(id);
241 tree.child_to_parents
242 .entry(id)
243 .or_insert_with(BTreeSet::new)
244 .insert(parent_id);
245 } else {
246 tree.root_tasks.insert(id);
247 }
248
249 let task_internal = TaskInternal {
250 status: TaskStatus::Running,
251 name,
252 parent_names,
253 id,
254 started_at: SystemTime::now(),
255 data: Data::empty(),
256 data_transitive,
257 tags,
258 progress: None,
259 hide_errors: tree.hide_errors_default_msg.clone(),
260 attach_transitive_data_to_errors: tree.attach_transitive_data_to_errors_default,
261 };
262
263 tree.tasks_internal.insert(id, task_internal);
264 tree.report_start.push(id);
265
266 id
267 }
268
269 pub fn mark_done(&self, id: UniqID, error_message: Option<String>) {
270 let mut tree = self.tree_internal.write().unwrap();
271 if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
272 task_internal.mark_done(error_message);
273 tree.mark_for_gc(id);
274 tree.report_end.push(id);
275 }
276 }
277
278 pub fn add_data<S: Into<String>, D: Into<DataValue>>(&self, id: UniqID, key: S, value: D) {
279 let mut tree = self.tree_internal.write().unwrap();
280 if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
281 task_internal.data.add(key, value);
282 }
283 }
284
285 pub fn get_data<S: Into<String>>(&self, id: UniqID, key: S) -> Option<DataValue> {
286 let mut tree = self.tree_internal.write().unwrap();
287 if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
288 let all_data: BTreeMap<_, _> = task_internal.all_data().collect();
289 return all_data.get(&key.into()).map(|de| de.0.clone());
290 }
291 None
292 }
293
294 pub(crate) fn add_data_transitive_for_task<S: Into<String>, D: Into<DataValue>>(
295 &self,
296 id: UniqID,
297 key: S,
298 value: D,
299 ) {
300 let mut tree = self.tree_internal.write().unwrap();
301 if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
302 task_internal.data_transitive.add(key, value);
303 }
304 }
305 pub fn hide_errors_default_msg<S: Into<String>>(&self, msg: Option<S>) {
314 let mut tree = self.tree_internal.write().unwrap();
315 let msg = msg.map(|msg| Arc::new(msg.into()));
316 tree.hide_errors_default_msg = msg;
317 }
318
319 pub(crate) fn hide_error_msg_for_task(&self, id: UniqID, msg: Option<Arc<String>>) {
320 let mut tree = self.tree_internal.write().unwrap();
321 if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
322 task_internal.hide_errors = msg;
323 }
324 }
325
326 pub fn attach_transitive_data_to_errors_default(&self, val: bool) {
331 let mut tree = self.tree_internal.write().unwrap();
332 tree.attach_transitive_data_to_errors_default = val;
333 }
334
335 pub(crate) fn attach_transitive_data_to_errors_for_task(&self, id: UniqID, val: bool) {
336 let mut tree = self.tree_internal.write().unwrap();
337 if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
338 task_internal.attach_transitive_data_to_errors = val;
339 }
340 }
341
342 pub fn set_error_formatter(&self, error_formatter: Option<Arc<dyn ErrorFormatter>>) {
348 let mut tree = self.tree_internal.write().unwrap();
349 tree.error_formatter = error_formatter;
350 }
351
352 pub fn add_data_transitive<S: Into<String>, D: Into<DataValue>>(&self, key: S, value: D) {
355 let mut tree = self.tree_internal.write().unwrap();
356 tree.data_transitive.add(key, value);
357 }
358
359 pub fn task_progress(&self, id: UniqID, done: i64, total: i64) {
360 let mut tree = self.tree_internal.write().unwrap();
361 if let Some(task_internal) = tree.tasks_internal.get_mut(&id) {
362 task_internal.progress = Some((done, total));
363 }
364 }
365
366 fn get_cloned_task(&self, id: UniqID) -> Option<TaskInternal> {
367 let tree = self.tree_internal.read().unwrap();
368 tree.get_task(id).ok().cloned()
369 }
370
371 pub fn maybe_force_flush(&self) {
375 if self.force_flush.load(Ordering::SeqCst) {
376 self.report_all();
377 }
378 }
379
380 pub fn report_all(&self) {
381 let mut tree = self.tree_internal.write().unwrap();
382 let (start_tasks, end_tasks, reporters) = tree.get_tasks_and_reporters();
383 drop(tree);
384 for reporter in reporters {
385 for task in &start_tasks {
386 reporter.task_start(task.clone());
387 }
388 for task in &end_tasks {
389 reporter.task_end(task.clone());
390 }
391 }
392 }
393}
394
395impl TaskTreeInternal {
396 pub fn get_task(&self, id: UniqID) -> Result<&TaskInternal> {
397 self.tasks_internal.get(&id).context("task must be present")
398 }
399
400 pub fn root_tasks(&self) -> &BTreeSet<UniqID> {
401 &self.root_tasks
402 }
403
404 pub fn child_to_parents(&self) -> &BTreeMap<UniqID, BTreeSet<UniqID>> {
405 &self.child_to_parents
406 }
407
408 pub fn parent_to_children(&self) -> &BTreeMap<UniqID, BTreeSet<UniqID>> {
409 &self.parent_to_children
410 }
411
412 fn mark_for_gc(&mut self, id: UniqID) {
413 let mut stack = vec![id];
414
415 let mut tasks_to_finished_status = BTreeMap::new();
416
417 while let Some(id) = stack.pop() {
418 if let Some(task_internal) = self.tasks_internal.get(&id) {
419 tasks_to_finished_status
420 .insert(id, matches!(task_internal.status, TaskStatus::Finished(..)));
421 }
422
423 for child_id in self.parent_to_children.get(&id).into_iter().flatten() {
424 stack.push(*child_id);
425 }
426 }
427
428 if tasks_to_finished_status
429 .iter()
430 .all(|(_, finished)| *finished)
431 {
432 for id in tasks_to_finished_status.keys().copied() {
433 self.tasks_marked_for_deletion
434 .entry(id)
435 .or_insert_with(SystemTime::now);
436 }
437
438 let parents = self.child_to_parents.get(&id).cloned().unwrap_or_default();
442 for parent_id in parents {
443 self.mark_for_gc(parent_id);
444 }
445 }
446 }
447
448 fn garbage_collect(&mut self) {
449 let mut will_delete = vec![];
450 for (id, time) in &self.tasks_marked_for_deletion {
451 if let Ok(elapsed) = time.elapsed() {
452 if elapsed > Duration::from_millis(self.remove_task_after_done_ms) {
453 will_delete.push(*id);
454 }
455 }
456 }
457
458 for id in will_delete {
459 self.tasks_internal.remove(&id);
460 self.parent_to_children.remove(&id);
461 self.root_tasks.remove(&id);
462 if let Some(parents) = self.child_to_parents.remove(&id) {
463 for parent in parents {
464 if let Some(children) = self.parent_to_children.get_mut(&parent) {
465 children.remove(&id);
466 }
467 }
468 }
469 self.tasks_marked_for_deletion.remove(&id);
470 }
471 }
472
473 #[allow(clippy::type_complexity)]
474 fn get_tasks_and_reporters(
475 &mut self,
476 ) -> (
477 Vec<Arc<TaskInternal>>,
478 Vec<Arc<TaskInternal>>,
479 Vec<Arc<dyn Reporter>>,
480 ) {
481 let mut start_ids = vec![];
482 std::mem::swap(&mut start_ids, &mut self.report_start);
483 let mut end_ids = vec![];
484 std::mem::swap(&mut end_ids, &mut self.report_end);
485
486 let mut start_tasks = vec![];
487 let mut end_tasks = vec![];
488
489 for id in start_ids {
490 if let Ok(task_internal) = self.get_task(id) {
491 start_tasks.push(Arc::new(task_internal.clone()));
492 }
493 }
494 for id in end_ids {
495 if let Ok(task_internal) = self.get_task(id) {
496 end_tasks.push(Arc::new(task_internal.clone()));
497 }
498 }
499
500 let reporters = self.reporters.clone();
501
502 (start_tasks, end_tasks, reporters)
503 }
504}
505
506impl TaskInternal {
507 pub(crate) fn mark_done(&mut self, error_message: Option<String>) {
508 let task_status = match error_message {
509 None => TaskResult::Success,
510 Some(msg) => TaskResult::Failure(msg),
511 };
512 self.status = TaskStatus::Finished(task_status, SystemTime::now());
513 }
514
515 pub fn full_name(&self) -> String {
516 let mut full_name = String::new();
517 for parent_name in &self.parent_names {
518 full_name.push_str(parent_name);
519 full_name.push(':');
520 }
521 full_name.push_str(&self.name);
522 full_name
523 }
524
525 pub fn all_data(
526 &self,
527 ) -> std::iter::Chain<
528 std::collections::btree_map::Iter<String, DataEntry>,
529 std::collections::btree_map::Iter<String, DataEntry>,
530 > {
531 self.data.map.iter().chain(self.data_transitive.map.iter())
532 }
533}