Skip to main content

mk_lib/schema/
task_context.rs

1use std::collections::HashSet;
2use std::path::PathBuf;
3use std::sync::{
4  Arc,
5  Mutex,
6};
7
8use hashbrown::HashMap;
9use indicatif::{
10  MultiProgress,
11  ProgressDrawTarget,
12};
13use serde::Serialize;
14
15use crate::cache::CacheStore;
16use crate::defaults::{
17  default_ignore_errors,
18  default_shell,
19  default_verbose,
20};
21
22use super::{
23  ActiveTasks,
24  CompletedTasks,
25  ContainerRuntime,
26  Shell,
27  TaskRoot,
28};
29
/// Used to pass information to tasks.
///
/// Shared state is wrapped in `Arc` so that it can be shared between tasks
/// and so that tasks can run in parallel. Cloning a context shares every
/// `Arc`-wrapped field and copies the plain ones.
#[derive(Clone)]
pub struct TaskContext {
  /// Shared task root. `new`/`new_with_options` read the secret settings,
  /// container runtime and cache directory from it.
  pub task_root: Arc<TaskRoot>,
  /// Names of tasks currently marked active; shared with derived contexts.
  pub active_tasks: ActiveTasks,
  /// Names of tasks marked complete; shared with derived contexts.
  pub completed_tasks: CompletedTasks,
  /// Progress-bar container; shared with derived contexts.
  pub multi: Arc<MultiProgress>,
  /// Environment variables collected through `extend_env_vars`; copied (not
  /// shared) into derived contexts.
  pub env_vars: HashMap<String, String>,
  /// Named task outputs. Contexts built with `from_context*` start with a
  /// fresh, empty map rather than sharing this one.
  pub task_outputs: Arc<Mutex<HashMap<String, String>>>,
  /// Optional secret vault location (initialised from the task root's
  /// `vault_location` by `new*`).
  pub secret_vault_location: Option<String>,
  /// Optional secret keys location (initialised from the task root's
  /// `keys_location` by `new*`).
  pub secret_keys_location: Option<String>,
  /// Optional secret key name (initialised from the task root's `key_name`
  /// by `new*`).
  pub secret_key_name: Option<String>,
  /// Optional GPG key id (initialised from the task root's `gpg_key_id` by
  /// `new*`).
  pub secret_gpg_key_id: Option<String>,
  /// Shell override; `shell()` falls back to `default_shell()` when `None`.
  pub shell: Option<Arc<Shell>>,
  /// Optional container runtime (initialised from the task root by `new*`).
  pub container_runtime: Option<ContainerRuntime>,
  /// Override for error handling; `ignore_errors()` falls back to
  /// `default_ignore_errors()` when `None`.
  pub ignore_errors: Option<bool>,
  /// Override for verbosity; `verbose()` falls back to `default_verbose()`
  /// when `None`.
  pub verbose: Option<bool>,
  /// Force flag supplied through `new_with_options`.
  pub force: bool,
  /// When `true`, `emit_event` prints events as JSON to stdout and
  /// `new_with_options` hides progress drawing.
  pub json_events: bool,
  /// `true` for contexts derived through `from_context*`.
  pub is_nested: bool,
  /// Cache store; shared with derived contexts.
  pub cache_store: Arc<Mutex<CacheStore>>,
  /// Task name recorded through `set_current_task_name`, if any.
  pub current_task_name: Option<String>,
}
55
56impl TaskContext {
57  pub fn empty() -> Self {
58    let mp = MultiProgress::with_draw_target(ProgressDrawTarget::hidden());
59    Self {
60      task_root: Arc::new(TaskRoot::default()),
61      active_tasks: Arc::new(Mutex::new(HashSet::new())),
62      completed_tasks: Arc::new(Mutex::new(HashSet::new())),
63      multi: Arc::new(mp),
64      env_vars: HashMap::new(),
65      task_outputs: Arc::new(Mutex::new(HashMap::new())),
66      secret_vault_location: None,
67      secret_keys_location: None,
68      secret_key_name: None,
69      secret_gpg_key_id: None,
70      shell: None,
71      container_runtime: None,
72      ignore_errors: None,
73      verbose: None,
74      force: false,
75      json_events: false,
76      is_nested: false,
77      cache_store: Arc::new(Mutex::new(CacheStore::default())),
78      current_task_name: None,
79    }
80  }
81
82  pub fn empty_with_root(task_root: Arc<TaskRoot>) -> Self {
83    let mp = MultiProgress::with_draw_target(ProgressDrawTarget::hidden());
84    Self {
85      task_root: task_root.clone(),
86      active_tasks: Arc::new(Mutex::new(HashSet::new())),
87      completed_tasks: Arc::new(Mutex::new(HashSet::new())),
88      multi: Arc::new(mp),
89      env_vars: HashMap::new(),
90      task_outputs: Arc::new(Mutex::new(HashMap::new())),
91      secret_vault_location: None,
92      secret_keys_location: None,
93      secret_key_name: None,
94      secret_gpg_key_id: None,
95      shell: None,
96      container_runtime: None,
97      ignore_errors: None,
98      verbose: None,
99      force: false,
100      json_events: false,
101      is_nested: false,
102      cache_store: Arc::new(Mutex::new(CacheStore::default())),
103      current_task_name: None,
104    }
105  }
106
107  pub fn new(task_root: Arc<TaskRoot>) -> Self {
108    let cache_store = CacheStore::load_in_dir(&task_root.cache_base_dir()).unwrap_or_default();
109    Self {
110      task_root: task_root.clone(),
111      active_tasks: Arc::new(Mutex::new(HashSet::new())),
112      completed_tasks: Arc::new(Mutex::new(HashSet::new())),
113      multi: Arc::new(MultiProgress::new()),
114      env_vars: HashMap::new(),
115      task_outputs: Arc::new(Mutex::new(HashMap::new())),
116      secret_vault_location: task_root.vault_location.clone(),
117      secret_keys_location: task_root.keys_location.clone(),
118      secret_key_name: task_root.key_name.clone(),
119      secret_gpg_key_id: task_root.gpg_key_id.clone(),
120      shell: None,
121      container_runtime: task_root.container_runtime.clone(),
122      ignore_errors: None,
123      verbose: None,
124      force: false,
125      json_events: false,
126      is_nested: false,
127      cache_store: Arc::new(Mutex::new(cache_store)),
128      current_task_name: None,
129    }
130  }
131
132  pub fn new_with_options(task_root: Arc<TaskRoot>, force: bool, json_events: bool) -> Self {
133    let cache_store = CacheStore::load_in_dir(&task_root.cache_base_dir()).unwrap_or_default();
134    let multi = if json_events {
135      Arc::new(MultiProgress::with_draw_target(ProgressDrawTarget::hidden()))
136    } else {
137      Arc::new(MultiProgress::new())
138    };
139    Self {
140      task_root: task_root.clone(),
141      active_tasks: Arc::new(Mutex::new(HashSet::new())),
142      completed_tasks: Arc::new(Mutex::new(HashSet::new())),
143      multi,
144      env_vars: HashMap::new(),
145      task_outputs: Arc::new(Mutex::new(HashMap::new())),
146      secret_vault_location: task_root.vault_location.clone(),
147      secret_keys_location: task_root.keys_location.clone(),
148      secret_key_name: task_root.key_name.clone(),
149      secret_gpg_key_id: task_root.gpg_key_id.clone(),
150      shell: None,
151      container_runtime: task_root.container_runtime.clone(),
152      ignore_errors: None,
153      verbose: None,
154      force,
155      json_events,
156      is_nested: false,
157      cache_store: Arc::new(Mutex::new(cache_store)),
158      current_task_name: None,
159    }
160  }
161
162  pub fn from_context(context: &TaskContext) -> Self {
163    Self {
164      task_root: context.task_root.clone(),
165      active_tasks: context.active_tasks.clone(),
166      completed_tasks: context.completed_tasks.clone(),
167      multi: context.multi.clone(),
168      env_vars: context.env_vars.clone(),
169      task_outputs: Arc::new(Mutex::new(HashMap::new())),
170      secret_vault_location: context.secret_vault_location.clone(),
171      secret_keys_location: context.secret_keys_location.clone(),
172      secret_key_name: context.secret_key_name.clone(),
173      secret_gpg_key_id: context.secret_gpg_key_id.clone(),
174      shell: context.shell.clone(),
175      container_runtime: context.container_runtime.clone(),
176      ignore_errors: context.ignore_errors,
177      verbose: context.verbose,
178      force: context.force,
179      json_events: context.json_events,
180      is_nested: true,
181      cache_store: context.cache_store.clone(),
182      current_task_name: context.current_task_name.clone(),
183    }
184  }
185
186  pub fn from_context_with_args(context: &TaskContext, ignore_errors: bool, verbose: bool) -> Self {
187    Self {
188      task_root: context.task_root.clone(),
189      active_tasks: context.active_tasks.clone(),
190      completed_tasks: context.completed_tasks.clone(),
191      multi: context.multi.clone(),
192      env_vars: context.env_vars.clone(),
193      task_outputs: Arc::new(Mutex::new(HashMap::new())),
194      secret_vault_location: context.secret_vault_location.clone(),
195      secret_keys_location: context.secret_keys_location.clone(),
196      secret_key_name: context.secret_key_name.clone(),
197      secret_gpg_key_id: context.secret_gpg_key_id.clone(),
198      shell: context.shell.clone(),
199      container_runtime: context.container_runtime.clone(),
200      ignore_errors: Some(ignore_errors),
201      verbose: Some(verbose),
202      force: context.force,
203      json_events: context.json_events,
204      is_nested: true,
205      cache_store: context.cache_store.clone(),
206      current_task_name: context.current_task_name.clone(),
207    }
208  }
209
210  pub fn extend_env_vars<I>(&mut self, iter: I)
211  where
212    I: IntoIterator<Item = (String, String)>,
213  {
214    self.env_vars.extend(iter);
215  }
216
217  pub fn set_shell(&mut self, shell: &Shell) {
218    let shell = Arc::new(Shell::from_shell(shell));
219    self.shell = Some(shell);
220  }
221
222  pub fn set_secret_vault_location(&mut self, vault_location: impl Into<String>) {
223    self.secret_vault_location = Some(vault_location.into());
224  }
225
226  pub fn set_secret_keys_location(&mut self, keys_location: impl Into<String>) {
227    self.secret_keys_location = Some(keys_location.into());
228  }
229
230  pub fn set_secret_key_name(&mut self, key_name: impl Into<String>) {
231    self.secret_key_name = Some(key_name.into());
232  }
233
234  pub fn set_secret_gpg_key_id(&mut self, gpg_key_id: impl Into<String>) {
235    self.secret_gpg_key_id = Some(gpg_key_id.into());
236  }
237
238  pub fn set_container_runtime(&mut self, runtime: &ContainerRuntime) {
239    self.container_runtime = Some(runtime.clone());
240  }
241
242  pub fn set_ignore_errors(&mut self, ignore_errors: bool) {
243    self.ignore_errors = Some(ignore_errors);
244  }
245
246  pub fn set_verbose(&mut self, verbose: bool) {
247    self.verbose = Some(verbose);
248  }
249
250  pub fn insert_task_output(&self, name: impl Into<String>, value: impl Into<String>) -> anyhow::Result<()> {
251    let name = name.into();
252    let mut outputs = self
253      .task_outputs
254      .lock()
255      .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
256    if outputs.contains_key(&name) {
257      anyhow::bail!("Task output already exists - {}", name);
258    }
259    outputs.insert(name, value.into());
260    Ok(())
261  }
262
263  pub fn get_task_output(&self, name: &str) -> anyhow::Result<Option<String>> {
264    let outputs = self
265      .task_outputs
266      .lock()
267      .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
268    Ok(outputs.get(name).cloned())
269  }
270
271  pub fn has_task_output(&self, name: &str) -> anyhow::Result<bool> {
272    let outputs = self
273      .task_outputs
274      .lock()
275      .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
276    Ok(outputs.contains_key(name))
277  }
278
279  pub fn shell(&self) -> Arc<Shell> {
280    self.shell.clone().unwrap_or_else(|| Arc::new(default_shell()))
281  }
282
283  pub fn ignore_errors(&self) -> bool {
284    self.ignore_errors.unwrap_or(default_ignore_errors())
285  }
286
287  pub fn verbose(&self) -> bool {
288    self.verbose.unwrap_or(default_verbose())
289  }
290
291  pub fn is_task_active(&self, task_name: &str) -> anyhow::Result<bool> {
292    let active = self
293      .active_tasks
294      .lock()
295      .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
296    Ok(active.contains(task_name))
297  }
298
299  pub fn is_task_completed(&self, task_name: &str) -> anyhow::Result<bool> {
300    let completed = self
301      .completed_tasks
302      .lock()
303      .map_err(|e| anyhow::anyhow!("Failed to lock completed tasks - {}", e))?;
304    Ok(completed.contains(task_name))
305  }
306
307  pub fn mark_task_active(&self, task_name: &str) -> anyhow::Result<()> {
308    let mut active = self
309      .active_tasks
310      .lock()
311      .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
312    active.insert(task_name.to_string());
313    Ok(())
314  }
315
316  pub fn unmark_task_active(&self, task_name: &str) -> anyhow::Result<()> {
317    let mut active = self
318      .active_tasks
319      .lock()
320      .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
321    active.remove(task_name);
322    Ok(())
323  }
324
325  pub fn mark_task_complete(&self, task_name: &str) -> anyhow::Result<()> {
326    let mut completed = self
327      .completed_tasks
328      .lock()
329      .map_err(|e| anyhow::anyhow!("Failed to lock completed tasks - {}", e))?;
330    completed.insert(task_name.to_string());
331    Ok(())
332  }
333
334  pub fn emit_event<T: Serialize>(&self, value: &T) -> anyhow::Result<()> {
335    if self.json_events {
336      println!("{}", serde_json::to_string(value)?);
337    }
338    Ok(())
339  }
340
341  pub fn set_current_task_name(&mut self, task_name: &str) {
342    self.current_task_name = Some(task_name.to_string());
343  }
344
345  pub fn resolve_from_config(&self, value: &str) -> PathBuf {
346    self.task_root.resolve_from_config(value)
347  }
348}
349
#[cfg(test)]
mod test {
  use super::*;

  /// An empty context reports the built-in defaults.
  #[test]
  fn test_task_context_1() -> anyhow::Result<()> {
    let context = TaskContext::empty();
    assert_eq!(context.shell().cmd(), "sh".to_string());
    assert!(!context.ignore_errors());
    assert!(context.verbose());
    Ok(())
  }

  /// A shell override replaces the default shell.
  #[test]
  fn test_task_context_2() -> anyhow::Result<()> {
    let mut context = TaskContext::empty();
    context.set_shell(&Shell::String("bash".to_string()));
    assert_eq!(context.shell().cmd(), "bash".to_string());
    Ok(())
  }

  /// Extending the environment makes the new variable visible.
  #[test]
  fn test_task_context_3() -> anyhow::Result<()> {
    let mut context = TaskContext::empty();
    context.extend_env_vars(vec![("key".to_string(), "value".to_string())]);
    assert_eq!(context.env_vars.get("key"), Some(&"value".to_string()));
    Ok(())
  }

  /// The ignore-errors override wins over the (false) default.
  #[test]
  fn test_task_context_4() -> anyhow::Result<()> {
    let mut context = TaskContext::empty();
    context.set_ignore_errors(true);
    assert!(context.ignore_errors());
    Ok(())
  }

  /// The verbose override is honoured in both directions.
  #[test]
  fn test_task_context_5() -> anyhow::Result<()> {
    let mut context = TaskContext::empty();
    // The default is already `true` (see `test_task_context_1`), so only
    // flipping to `false` proves the override is actually applied.
    context.set_verbose(false);
    assert!(!context.verbose());
    context.set_verbose(true);
    assert!(context.verbose());
    Ok(())
  }

  /// A stored output can be read back and is reported as present.
  #[test]
  fn test_task_context_outputs_are_stored() -> anyhow::Result<()> {
    let context = TaskContext::empty();
    context.insert_task_output("tag", "v1.0.0")?;
    assert_eq!(context.get_task_output("tag")?, Some("v1.0.0".to_string()));
    assert!(context.has_task_output("tag")?);
    Ok(())
  }

  /// Inserting under an existing name fails and keeps the first value.
  #[test]
  fn test_task_context_duplicate_output_is_rejected() -> anyhow::Result<()> {
    let context = TaskContext::empty();
    context.insert_task_output("tag", "v1.0.0")?;
    assert!(context.insert_task_output("tag", "v2.0.0").is_err());
    assert_eq!(context.get_task_output("tag")?, Some("v1.0.0".to_string()));
    Ok(())
  }

  /// Unknown output names yield `None` / `false` rather than an error.
  #[test]
  fn test_task_context_missing_output() -> anyhow::Result<()> {
    let context = TaskContext::empty();
    assert_eq!(context.get_task_output("missing")?, None);
    assert!(!context.has_task_output("missing")?);
    Ok(())
  }

  /// A derived context is nested, shares task tracking and starts with
  /// empty outputs.
  #[test]
  fn test_task_context_from_context() -> anyhow::Result<()> {
    let parent = TaskContext::empty();
    parent.insert_task_output("tag", "v1.0.0")?;
    parent.mark_task_active("build")?;

    let child = TaskContext::from_context(&parent);
    assert!(child.is_nested);
    assert!(child.is_task_active("build")?);
    assert!(!child.has_task_output("tag")?);
    Ok(())
  }
}