// mk_lib/schema/task_context.rs

1use std::collections::HashSet;
2use std::path::PathBuf;
3use std::sync::{
4  Arc,
5  Mutex,
6};
7
8use hashbrown::HashMap;
9use indicatif::{
10  MultiProgress,
11  ProgressDrawTarget,
12};
13use serde::Serialize;
14
15use crate::cache::CacheStore;
16use crate::defaults::{
17  default_ignore_errors,
18  default_shell,
19  default_verbose,
20};
21
22use super::{
23  ActiveTasks,
24  CompletedTasks,
25  ContainerRuntime,
26  Shell,
27  TaskRoot,
28};
29
30/// Used to pass information to tasks
31/// This use arc to allow for sharing of data between tasks
32/// and allow parallel runs of tasks
33#[derive(Clone)]
34pub struct TaskContext {
35  pub task_root: Arc<TaskRoot>,
36  pub active_tasks: ActiveTasks,
37  pub completed_tasks: CompletedTasks,
38  pub multi: Arc<MultiProgress>,
39  pub env_vars: HashMap<String, String>,
40  pub task_outputs: Arc<Mutex<HashMap<String, String>>>,
41  pub secret_vault_location: Option<String>,
42  pub secret_keys_location: Option<String>,
43  pub secret_key_name: Option<String>,
44  pub shell: Option<Arc<Shell>>,
45  pub container_runtime: Option<ContainerRuntime>,
46  pub ignore_errors: Option<bool>,
47  pub verbose: Option<bool>,
48  pub force: bool,
49  pub json_events: bool,
50  pub is_nested: bool,
51  pub cache_store: Arc<Mutex<CacheStore>>,
52  pub current_task_name: Option<String>,
53}
54
55impl TaskContext {
56  pub fn empty() -> Self {
57    let mp = MultiProgress::with_draw_target(ProgressDrawTarget::hidden());
58    Self {
59      task_root: Arc::new(TaskRoot::default()),
60      active_tasks: Arc::new(Mutex::new(HashSet::new())),
61      completed_tasks: Arc::new(Mutex::new(HashSet::new())),
62      multi: Arc::new(mp),
63      env_vars: HashMap::new(),
64      task_outputs: Arc::new(Mutex::new(HashMap::new())),
65      secret_vault_location: None,
66      secret_keys_location: None,
67      secret_key_name: None,
68      shell: None,
69      container_runtime: None,
70      ignore_errors: None,
71      verbose: None,
72      force: false,
73      json_events: false,
74      is_nested: false,
75      cache_store: Arc::new(Mutex::new(CacheStore::default())),
76      current_task_name: None,
77    }
78  }
79
80  pub fn empty_with_root(task_root: Arc<TaskRoot>) -> Self {
81    let mp = MultiProgress::with_draw_target(ProgressDrawTarget::hidden());
82    Self {
83      task_root: task_root.clone(),
84      active_tasks: Arc::new(Mutex::new(HashSet::new())),
85      completed_tasks: Arc::new(Mutex::new(HashSet::new())),
86      multi: Arc::new(mp),
87      env_vars: HashMap::new(),
88      task_outputs: Arc::new(Mutex::new(HashMap::new())),
89      secret_vault_location: None,
90      secret_keys_location: None,
91      secret_key_name: None,
92      shell: None,
93      container_runtime: None,
94      ignore_errors: None,
95      verbose: None,
96      force: false,
97      json_events: false,
98      is_nested: false,
99      cache_store: Arc::new(Mutex::new(CacheStore::default())),
100      current_task_name: None,
101    }
102  }
103
104  pub fn new(task_root: Arc<TaskRoot>) -> Self {
105    let cache_store = CacheStore::load_in_dir(&task_root.cache_base_dir()).unwrap_or_default();
106    Self {
107      task_root: task_root.clone(),
108      active_tasks: Arc::new(Mutex::new(HashSet::new())),
109      completed_tasks: Arc::new(Mutex::new(HashSet::new())),
110      multi: Arc::new(MultiProgress::new()),
111      env_vars: HashMap::new(),
112      task_outputs: Arc::new(Mutex::new(HashMap::new())),
113      secret_vault_location: task_root.vault_location.clone(),
114      secret_keys_location: task_root.keys_location.clone(),
115      secret_key_name: task_root.key_name.clone(),
116      shell: None,
117      container_runtime: task_root.container_runtime.clone(),
118      ignore_errors: None,
119      verbose: None,
120      force: false,
121      json_events: false,
122      is_nested: false,
123      cache_store: Arc::new(Mutex::new(cache_store)),
124      current_task_name: None,
125    }
126  }
127
128  pub fn new_with_options(task_root: Arc<TaskRoot>, force: bool, json_events: bool) -> Self {
129    let cache_store = CacheStore::load_in_dir(&task_root.cache_base_dir()).unwrap_or_default();
130    let multi = if json_events {
131      Arc::new(MultiProgress::with_draw_target(ProgressDrawTarget::hidden()))
132    } else {
133      Arc::new(MultiProgress::new())
134    };
135    Self {
136      task_root: task_root.clone(),
137      active_tasks: Arc::new(Mutex::new(HashSet::new())),
138      completed_tasks: Arc::new(Mutex::new(HashSet::new())),
139      multi,
140      env_vars: HashMap::new(),
141      task_outputs: Arc::new(Mutex::new(HashMap::new())),
142      secret_vault_location: task_root.vault_location.clone(),
143      secret_keys_location: task_root.keys_location.clone(),
144      secret_key_name: task_root.key_name.clone(),
145      shell: None,
146      container_runtime: task_root.container_runtime.clone(),
147      ignore_errors: None,
148      verbose: None,
149      force,
150      json_events,
151      is_nested: false,
152      cache_store: Arc::new(Mutex::new(cache_store)),
153      current_task_name: None,
154    }
155  }
156
157  pub fn from_context(context: &TaskContext) -> Self {
158    Self {
159      task_root: context.task_root.clone(),
160      active_tasks: context.active_tasks.clone(),
161      completed_tasks: context.completed_tasks.clone(),
162      multi: context.multi.clone(),
163      env_vars: context.env_vars.clone(),
164      task_outputs: Arc::new(Mutex::new(HashMap::new())),
165      secret_vault_location: context.secret_vault_location.clone(),
166      secret_keys_location: context.secret_keys_location.clone(),
167      secret_key_name: context.secret_key_name.clone(),
168      shell: context.shell.clone(),
169      container_runtime: context.container_runtime.clone(),
170      ignore_errors: context.ignore_errors,
171      verbose: context.verbose,
172      force: context.force,
173      json_events: context.json_events,
174      is_nested: true,
175      cache_store: context.cache_store.clone(),
176      current_task_name: context.current_task_name.clone(),
177    }
178  }
179
180  pub fn from_context_with_args(context: &TaskContext, ignore_errors: bool, verbose: bool) -> Self {
181    Self {
182      task_root: context.task_root.clone(),
183      active_tasks: context.active_tasks.clone(),
184      completed_tasks: context.completed_tasks.clone(),
185      multi: context.multi.clone(),
186      env_vars: context.env_vars.clone(),
187      task_outputs: Arc::new(Mutex::new(HashMap::new())),
188      secret_vault_location: context.secret_vault_location.clone(),
189      secret_keys_location: context.secret_keys_location.clone(),
190      secret_key_name: context.secret_key_name.clone(),
191      shell: context.shell.clone(),
192      container_runtime: context.container_runtime.clone(),
193      ignore_errors: Some(ignore_errors),
194      verbose: Some(verbose),
195      force: context.force,
196      json_events: context.json_events,
197      is_nested: true,
198      cache_store: context.cache_store.clone(),
199      current_task_name: context.current_task_name.clone(),
200    }
201  }
202
203  pub fn extend_env_vars<I>(&mut self, iter: I)
204  where
205    I: IntoIterator<Item = (String, String)>,
206  {
207    self.env_vars.extend(iter);
208  }
209
210  pub fn set_shell(&mut self, shell: &Shell) {
211    let shell = Arc::new(Shell::from_shell(shell));
212    self.shell = Some(shell);
213  }
214
215  pub fn set_secret_vault_location(&mut self, vault_location: impl Into<String>) {
216    self.secret_vault_location = Some(vault_location.into());
217  }
218
219  pub fn set_secret_keys_location(&mut self, keys_location: impl Into<String>) {
220    self.secret_keys_location = Some(keys_location.into());
221  }
222
223  pub fn set_secret_key_name(&mut self, key_name: impl Into<String>) {
224    self.secret_key_name = Some(key_name.into());
225  }
226
227  pub fn set_container_runtime(&mut self, runtime: &ContainerRuntime) {
228    self.container_runtime = Some(runtime.clone());
229  }
230
231  pub fn set_ignore_errors(&mut self, ignore_errors: bool) {
232    self.ignore_errors = Some(ignore_errors);
233  }
234
235  pub fn set_verbose(&mut self, verbose: bool) {
236    self.verbose = Some(verbose);
237  }
238
239  pub fn insert_task_output(&self, name: impl Into<String>, value: impl Into<String>) -> anyhow::Result<()> {
240    let name = name.into();
241    let mut outputs = self
242      .task_outputs
243      .lock()
244      .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
245    if outputs.contains_key(&name) {
246      anyhow::bail!("Task output already exists - {}", name);
247    }
248    outputs.insert(name, value.into());
249    Ok(())
250  }
251
252  pub fn get_task_output(&self, name: &str) -> anyhow::Result<Option<String>> {
253    let outputs = self
254      .task_outputs
255      .lock()
256      .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
257    Ok(outputs.get(name).cloned())
258  }
259
260  pub fn has_task_output(&self, name: &str) -> anyhow::Result<bool> {
261    let outputs = self
262      .task_outputs
263      .lock()
264      .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
265    Ok(outputs.contains_key(name))
266  }
267
268  pub fn shell(&self) -> Arc<Shell> {
269    self.shell.clone().unwrap_or_else(|| Arc::new(default_shell()))
270  }
271
272  pub fn ignore_errors(&self) -> bool {
273    self.ignore_errors.unwrap_or(default_ignore_errors())
274  }
275
276  pub fn verbose(&self) -> bool {
277    self.verbose.unwrap_or(default_verbose())
278  }
279
280  pub fn is_task_active(&self, task_name: &str) -> anyhow::Result<bool> {
281    let active = self
282      .active_tasks
283      .lock()
284      .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
285    Ok(active.contains(task_name))
286  }
287
288  pub fn is_task_completed(&self, task_name: &str) -> anyhow::Result<bool> {
289    let completed = self
290      .completed_tasks
291      .lock()
292      .map_err(|e| anyhow::anyhow!("Failed to lock completed tasks - {}", e))?;
293    Ok(completed.contains(task_name))
294  }
295
296  pub fn mark_task_active(&self, task_name: &str) -> anyhow::Result<()> {
297    let mut active = self
298      .active_tasks
299      .lock()
300      .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
301    active.insert(task_name.to_string());
302    Ok(())
303  }
304
305  pub fn unmark_task_active(&self, task_name: &str) -> anyhow::Result<()> {
306    let mut active = self
307      .active_tasks
308      .lock()
309      .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
310    active.remove(task_name);
311    Ok(())
312  }
313
314  pub fn mark_task_complete(&self, task_name: &str) -> anyhow::Result<()> {
315    let mut completed = self
316      .completed_tasks
317      .lock()
318      .map_err(|e| anyhow::anyhow!("Failed to lock completed tasks - {}", e))?;
319    completed.insert(task_name.to_string());
320    Ok(())
321  }
322
323  pub fn emit_event<T: Serialize>(&self, value: &T) -> anyhow::Result<()> {
324    if self.json_events {
325      println!("{}", serde_json::to_string(value)?);
326    }
327    Ok(())
328  }
329
330  pub fn set_current_task_name(&mut self, task_name: &str) {
331    self.current_task_name = Some(task_name.to_string());
332  }
333
334  pub fn resolve_from_config(&self, value: &str) -> PathBuf {
335    self.task_root.resolve_from_config(value)
336  }
337}
338
#[cfg(test)]
mod test {
  use super::*;

  #[test]
  fn empty_context_uses_defaults() -> anyhow::Result<()> {
    let context = TaskContext::empty();
    assert_eq!(context.shell().cmd(), "sh".to_string());
    assert!(!context.ignore_errors());
    assert!(context.verbose());
    assert!(!context.force);
    assert!(!context.json_events);
    assert!(!context.is_nested);
    // With JSON events disabled, emitting is a silent no-op.
    context.emit_event(&"ignored")?;
    Ok(())
  }

  #[test]
  fn set_shell_overrides_default() {
    let mut context = TaskContext::empty();
    context.set_shell(&Shell::String("bash".to_string()));
    assert_eq!(context.shell().cmd(), "bash".to_string());
  }

  #[test]
  fn extend_env_vars_adds_entries() {
    let mut context = TaskContext::empty();
    context.extend_env_vars(vec![("key".to_string(), "value".to_string())]);
    assert_eq!(context.env_vars.get("key"), Some(&"value".to_string()));
  }

  #[test]
  fn set_ignore_errors_overrides_default() {
    let mut context = TaskContext::empty();
    context.set_ignore_errors(true);
    assert!(context.ignore_errors());
    context.set_ignore_errors(false);
    assert!(!context.ignore_errors());
  }

  #[test]
  fn set_verbose_overrides_default() {
    // The default verbosity is `true`, so `false` must be set to observe the override.
    let mut context = TaskContext::empty();
    context.set_verbose(false);
    assert!(!context.verbose());
    context.set_verbose(true);
    assert!(context.verbose());
  }

  #[test]
  fn task_outputs_are_stored() -> anyhow::Result<()> {
    let context = TaskContext::empty();
    assert!(!context.has_task_output("tag")?);
    assert_eq!(context.get_task_output("tag")?, None);
    context.insert_task_output("tag", "v1.0.0")?;
    assert_eq!(context.get_task_output("tag")?, Some("v1.0.0".to_string()));
    assert!(context.has_task_output("tag")?);
    Ok(())
  }

  #[test]
  fn duplicate_task_output_is_rejected() -> anyhow::Result<()> {
    let context = TaskContext::empty();
    context.insert_task_output("tag", "v1.0.0")?;
    assert!(context.insert_task_output("tag", "v2.0.0").is_err());
    // The rejected insert must not clobber the first value.
    assert_eq!(context.get_task_output("tag")?, Some("v1.0.0".to_string()));
    Ok(())
  }

  #[test]
  fn task_state_is_tracked() -> anyhow::Result<()> {
    let context = TaskContext::empty();
    assert!(!context.is_task_active("build")?);
    context.mark_task_active("build")?;
    assert!(context.is_task_active("build")?);
    context.unmark_task_active("build")?;
    assert!(!context.is_task_active("build")?);

    assert!(!context.is_task_completed("build")?);
    context.mark_task_complete("build")?;
    assert!(context.is_task_completed("build")?);
    Ok(())
  }

  #[test]
  fn nested_context_shares_task_state_but_not_outputs() -> anyhow::Result<()> {
    let mut parent = TaskContext::empty();
    parent.set_ignore_errors(true);
    parent.insert_task_output("tag", "v1.0.0")?;

    let nested = TaskContext::from_context(&parent);
    assert!(nested.is_nested);
    assert!(nested.ignore_errors());
    assert!(!nested.has_task_output("tag")?);

    nested.mark_task_complete("build")?;
    assert!(parent.is_task_completed("build")?);
    Ok(())
  }

  #[test]
  fn nested_context_with_args_overrides_flags() {
    let parent = TaskContext::empty();
    let nested = TaskContext::from_context_with_args(&parent, true, false);
    assert!(nested.is_nested);
    assert!(nested.ignore_errors());
    assert!(!nested.verbose());
  }
}