// mk_lib/schema/task_context.rs

use std::collections::HashSet;
use std::path::PathBuf;
use std::sync::atomic::{
  AtomicBool,
  Ordering,
};
use std::sync::{
  Arc,
  Mutex,
};

use hashbrown::HashMap;
use indicatif::{
  MultiProgress,
  ProgressDrawTarget,
};
use serde::Serialize;

use crate::cache::CacheStore;
use crate::defaults::{
  default_ignore_errors,
  default_shell,
  default_verbose,
};
use crate::secrets::SecretConfig;

use super::{
  ActiveTasks,
  CompletedTasks,
  ContainerRuntime,
  Shell,
  TaskRoot,
};
34
35/// Used to pass information to tasks
36/// This use arc to allow for sharing of data between tasks
37/// and allow parallel runs of tasks
38#[derive(Clone)]
39pub struct TaskContext {
40  pub task_root: Arc<TaskRoot>,
41  pub active_tasks: ActiveTasks,
42  pub completed_tasks: CompletedTasks,
43  pub multi: Arc<MultiProgress>,
44  pub env_vars: HashMap<String, String>,
45  pub task_outputs: Arc<Mutex<HashMap<String, String>>>,
46  pub secret_config: Option<SecretConfig>,
47  pub shell: Option<Arc<Shell>>,
48  pub container_runtime: Option<ContainerRuntime>,
49  pub ignore_errors: Option<bool>,
50  pub verbose: Option<bool>,
51  pub force: bool,
52  pub json_events: bool,
53  pub is_nested: bool,
54  pub cache_store: Arc<Mutex<CacheStore>>,
55  pub current_task_name: Option<String>,
56  pub cancellation_requested: Arc<AtomicBool>,
57}
58
59impl TaskContext {
60  pub fn empty() -> Self {
61    let mp = MultiProgress::with_draw_target(ProgressDrawTarget::hidden());
62    Self {
63      task_root: Arc::new(TaskRoot::default()),
64      active_tasks: Arc::new(Mutex::new(HashSet::new())),
65      completed_tasks: Arc::new(Mutex::new(HashSet::new())),
66      multi: Arc::new(mp),
67      env_vars: HashMap::new(),
68      task_outputs: Arc::new(Mutex::new(HashMap::new())),
69      secret_config: None,
70      shell: None,
71      container_runtime: None,
72      ignore_errors: None,
73      verbose: None,
74      force: false,
75      json_events: false,
76      is_nested: false,
77      cache_store: Arc::new(Mutex::new(CacheStore::default())),
78      current_task_name: None,
79      cancellation_requested: Arc::new(AtomicBool::new(false)),
80    }
81  }
82
83  pub fn empty_with_root(task_root: Arc<TaskRoot>) -> Self {
84    let mp = MultiProgress::with_draw_target(ProgressDrawTarget::hidden());
85    Self {
86      task_root: task_root.clone(),
87      active_tasks: Arc::new(Mutex::new(HashSet::new())),
88      completed_tasks: Arc::new(Mutex::new(HashSet::new())),
89      multi: Arc::new(mp),
90      env_vars: HashMap::new(),
91      task_outputs: Arc::new(Mutex::new(HashMap::new())),
92      secret_config: None,
93      shell: None,
94      container_runtime: None,
95      ignore_errors: None,
96      verbose: None,
97      force: false,
98      json_events: false,
99      is_nested: false,
100      cache_store: Arc::new(Mutex::new(CacheStore::default())),
101      current_task_name: None,
102      cancellation_requested: Arc::new(AtomicBool::new(false)),
103    }
104  }
105
106  pub fn new(task_root: Arc<TaskRoot>) -> Self {
107    let cache_store = CacheStore::load_in_dir(&task_root.cache_base_dir()).unwrap_or_default();
108    Self {
109      task_root: task_root.clone(),
110      active_tasks: Arc::new(Mutex::new(HashSet::new())),
111      completed_tasks: Arc::new(Mutex::new(HashSet::new())),
112      multi: Arc::new(MultiProgress::new()),
113      env_vars: HashMap::new(),
114      task_outputs: Arc::new(Mutex::new(HashMap::new())),
115      secret_config: None,
116      shell: None,
117      container_runtime: task_root.container_runtime.clone(),
118      ignore_errors: None,
119      verbose: None,
120      force: false,
121      json_events: false,
122      is_nested: false,
123      cache_store: Arc::new(Mutex::new(cache_store)),
124      current_task_name: None,
125      cancellation_requested: Arc::new(AtomicBool::new(false)),
126    }
127  }
128
129  pub fn new_with_options(task_root: Arc<TaskRoot>, force: bool, json_events: bool) -> Self {
130    let cache_store = CacheStore::load_in_dir(&task_root.cache_base_dir()).unwrap_or_default();
131    let multi = if json_events {
132      Arc::new(MultiProgress::with_draw_target(ProgressDrawTarget::hidden()))
133    } else {
134      Arc::new(MultiProgress::new())
135    };
136    Self {
137      task_root: task_root.clone(),
138      active_tasks: Arc::new(Mutex::new(HashSet::new())),
139      completed_tasks: Arc::new(Mutex::new(HashSet::new())),
140      multi,
141      env_vars: HashMap::new(),
142      task_outputs: Arc::new(Mutex::new(HashMap::new())),
143      secret_config: None,
144      shell: None,
145      container_runtime: task_root.container_runtime.clone(),
146      ignore_errors: None,
147      verbose: None,
148      force,
149      json_events,
150      is_nested: false,
151      cache_store: Arc::new(Mutex::new(cache_store)),
152      current_task_name: None,
153      cancellation_requested: Arc::new(AtomicBool::new(false)),
154    }
155  }
156
157  pub fn from_context(context: &TaskContext) -> Self {
158    Self {
159      task_root: context.task_root.clone(),
160      active_tasks: context.active_tasks.clone(),
161      completed_tasks: context.completed_tasks.clone(),
162      multi: context.multi.clone(),
163      env_vars: context.env_vars.clone(),
164      task_outputs: Arc::new(Mutex::new(HashMap::new())),
165      secret_config: context.secret_config.clone(),
166      shell: context.shell.clone(),
167      container_runtime: context.container_runtime.clone(),
168      ignore_errors: context.ignore_errors,
169      verbose: context.verbose,
170      force: context.force,
171      json_events: context.json_events,
172      is_nested: true,
173      cache_store: context.cache_store.clone(),
174      current_task_name: context.current_task_name.clone(),
175      cancellation_requested: context.cancellation_requested.clone(),
176    }
177  }
178
179  pub fn from_context_with_args(context: &TaskContext, ignore_errors: bool, verbose: bool) -> Self {
180    Self {
181      task_root: context.task_root.clone(),
182      active_tasks: context.active_tasks.clone(),
183      completed_tasks: context.completed_tasks.clone(),
184      multi: context.multi.clone(),
185      env_vars: context.env_vars.clone(),
186      task_outputs: Arc::new(Mutex::new(HashMap::new())),
187      secret_config: context.secret_config.clone(),
188      shell: context.shell.clone(),
189      container_runtime: context.container_runtime.clone(),
190      ignore_errors: Some(ignore_errors),
191      verbose: Some(verbose),
192      force: context.force,
193      json_events: context.json_events,
194      is_nested: true,
195      cache_store: context.cache_store.clone(),
196      current_task_name: context.current_task_name.clone(),
197      cancellation_requested: context.cancellation_requested.clone(),
198    }
199  }
200
201  pub fn extend_env_vars<I>(&mut self, iter: I)
202  where
203    I: IntoIterator<Item = (String, String)>,
204  {
205    self.env_vars.extend(iter);
206  }
207
208  pub fn set_shell(&mut self, shell: &Shell) {
209    let shell = Arc::new(Shell::from_shell(shell));
210    self.shell = Some(shell);
211  }
212
213  pub fn set_secret_config(&mut self, secret_config: SecretConfig) {
214    self.secret_config = Some(secret_config);
215  }
216
217  pub fn set_container_runtime(&mut self, runtime: &ContainerRuntime) {
218    self.container_runtime = Some(runtime.clone());
219  }
220
221  pub fn set_ignore_errors(&mut self, ignore_errors: bool) {
222    self.ignore_errors = Some(ignore_errors);
223  }
224
225  pub fn set_verbose(&mut self, verbose: bool) {
226    self.verbose = Some(verbose);
227  }
228
229  pub fn insert_task_output(&self, name: impl Into<String>, value: impl Into<String>) -> anyhow::Result<()> {
230    let name = name.into();
231    let mut outputs = self
232      .task_outputs
233      .lock()
234      .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
235    if outputs.contains_key(&name) {
236      anyhow::bail!("Task output already exists - {}", name);
237    }
238    outputs.insert(name, value.into());
239    Ok(())
240  }
241
242  pub fn get_task_output(&self, name: &str) -> anyhow::Result<Option<String>> {
243    let outputs = self
244      .task_outputs
245      .lock()
246      .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
247    Ok(outputs.get(name).cloned())
248  }
249
250  pub fn has_task_output(&self, name: &str) -> anyhow::Result<bool> {
251    let outputs = self
252      .task_outputs
253      .lock()
254      .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
255    Ok(outputs.contains_key(name))
256  }
257
258  pub fn shell(&self) -> Arc<Shell> {
259    self.shell.clone().unwrap_or_else(|| Arc::new(default_shell()))
260  }
261
262  pub fn ignore_errors(&self) -> bool {
263    self.ignore_errors.unwrap_or(default_ignore_errors())
264  }
265
266  pub fn verbose(&self) -> bool {
267    self.verbose.unwrap_or(default_verbose())
268  }
269
270  pub fn is_task_active(&self, task_name: &str) -> anyhow::Result<bool> {
271    let active = self
272      .active_tasks
273      .lock()
274      .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
275    Ok(active.contains(task_name))
276  }
277
278  pub fn is_task_completed(&self, task_name: &str) -> anyhow::Result<bool> {
279    let completed = self
280      .completed_tasks
281      .lock()
282      .map_err(|e| anyhow::anyhow!("Failed to lock completed tasks - {}", e))?;
283    Ok(completed.contains(task_name))
284  }
285
286  pub fn mark_task_active(&self, task_name: &str) -> anyhow::Result<()> {
287    let mut active = self
288      .active_tasks
289      .lock()
290      .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
291    active.insert(task_name.to_string());
292    Ok(())
293  }
294
295  pub fn unmark_task_active(&self, task_name: &str) -> anyhow::Result<()> {
296    let mut active = self
297      .active_tasks
298      .lock()
299      .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
300    active.remove(task_name);
301    Ok(())
302  }
303
304  pub fn mark_task_complete(&self, task_name: &str) -> anyhow::Result<()> {
305    let mut completed = self
306      .completed_tasks
307      .lock()
308      .map_err(|e| anyhow::anyhow!("Failed to lock completed tasks - {}", e))?;
309    completed.insert(task_name.to_string());
310    Ok(())
311  }
312
313  pub fn emit_event<T: Serialize>(&self, value: &T) -> anyhow::Result<()> {
314    if self.json_events {
315      println!("{}", serde_json::to_string(value)?);
316    }
317    Ok(())
318  }
319
320  pub fn set_current_task_name(&mut self, task_name: &str) {
321    self.current_task_name = Some(task_name.to_string());
322  }
323
324  pub fn request_cancellation(&self) {
325    self.cancellation_requested.store(true, Ordering::SeqCst);
326  }
327
328  pub fn cancellation_requested(&self) -> bool {
329    self.cancellation_requested.load(Ordering::SeqCst)
330  }
331
332  pub fn clear_cancellation(&self) {
333    self.cancellation_requested.store(false, Ordering::SeqCst);
334  }
335
336  pub fn resolve_from_config(&self, value: &str) -> PathBuf {
337    self.task_root.resolve_from_config(value)
338  }
339}
340
#[cfg(test)]
mod test {
  use super::*;

  /// An empty context falls back to the platform default shell and the
  /// default `ignore_errors` / `verbose` values.
  #[test]
  fn test_task_context_1() {
    let context = TaskContext::empty();
    let expected = if cfg!(windows) { "cmd" } else { "sh" };
    assert_eq!(context.shell().cmd(), expected.to_string());
    assert!(!context.ignore_errors());
    assert!(context.verbose());
  }

  /// An explicitly set shell replaces the default one.
  #[test]
  fn test_task_context_2() {
    let mut context = TaskContext::empty();
    context.set_shell(&Shell::String("bash".to_string()));
    assert_eq!(context.shell().cmd(), "bash".to_string());
  }

  /// Extending the environment makes the new variables readable.
  #[test]
  fn test_task_context_3() {
    let mut context = TaskContext::empty();
    context.extend_env_vars(vec![("key".to_string(), "value".to_string())]);
    assert_eq!(context.env_vars.get("key"), Some(&"value".to_string()));
  }

  /// The `ignore_errors` override wins over the default.
  #[test]
  fn test_task_context_4() {
    let mut context = TaskContext::empty();
    context.set_ignore_errors(true);
    assert!(context.ignore_errors());
  }

  /// The `verbose` override is reported back.
  #[test]
  fn test_task_context_5() {
    let mut context = TaskContext::empty();
    context.set_verbose(true);
    assert!(context.verbose());
  }

  /// Outputs can be stored and read back, unknown names yield nothing, and
  /// an existing output is never overwritten.
  #[test]
  fn test_task_context_outputs_are_stored() -> anyhow::Result<()> {
    let context = TaskContext::empty();
    assert_eq!(context.get_task_output("tag")?, None);
    assert!(!context.has_task_output("tag")?);

    context.insert_task_output("tag", "v1.0.0")?;
    assert_eq!(context.get_task_output("tag")?, Some("v1.0.0".to_string()));
    assert!(context.has_task_output("tag")?);

    // A second insert under the same name is rejected and keeps the value.
    assert!(context.insert_task_output("tag", "v2.0.0").is_err());
    assert_eq!(context.get_task_output("tag")?, Some("v1.0.0".to_string()));
    Ok(())
  }
}
410}