1use std::collections::HashSet;
2use std::path::PathBuf;
3use std::sync::{
4 Arc,
5 Mutex,
6};
7
8use hashbrown::HashMap;
9use indicatif::{
10 MultiProgress,
11 ProgressDrawTarget,
12};
13use serde::Serialize;
14
15use crate::cache::CacheStore;
16use crate::defaults::{
17 default_ignore_errors,
18 default_shell,
19 default_verbose,
20};
21
22use super::{
23 ActiveTasks,
24 CompletedTasks,
25 ContainerRuntime,
26 Shell,
27 TaskRoot,
28};
29
/// Execution state for a task run, including the nested contexts derived from it.
///
/// When a nested context is created with [`TaskContext::from_context`], the `Arc`-wrapped
/// fields are shared with the parent by reference. The exception is `task_outputs`, which
/// starts empty in the nested context.
#[derive(Clone)]
pub struct TaskContext {
    /// Root configuration the context was built from; `resolve_from_config` delegates to it.
    pub task_root: Arc<TaskRoot>,
    /// Names of tasks currently marked active (`mark_task_active` / `unmark_task_active`).
    pub active_tasks: ActiveTasks,
    /// Names of tasks recorded via `mark_task_complete`.
    pub completed_tasks: CompletedTasks,
    /// Progress-bar container. The empty constructors, and `json_events` mode, create it
    /// with a hidden draw target.
    pub multi: Arc<MultiProgress>,
    /// Extra environment variables accumulated via `extend_env_vars`.
    /// NOTE(review): presumably exported to task commands — confirm at the call sites.
    pub env_vars: HashMap<String, String>,
    /// Named outputs recorded by tasks; `insert_task_output` rejects duplicate names.
    pub task_outputs: Arc<Mutex<HashMap<String, String>>>,
    /// Secret vault location; `new`/`new_with_options` copy it from `task_root`.
    pub secret_vault_location: Option<String>,
    /// Secret keys location; `new`/`new_with_options` copy it from `task_root`.
    pub secret_keys_location: Option<String>,
    /// Secret key name; `new`/`new_with_options` copy it from `task_root`.
    pub secret_key_name: Option<String>,
    /// GPG key id used for secrets; `new`/`new_with_options` copy it from `task_root`.
    pub secret_gpg_key_id: Option<String>,
    /// Shell override; when `None`, `shell()` falls back to `default_shell()`.
    pub shell: Option<Arc<Shell>>,
    /// Container runtime override; `new`/`new_with_options` copy it from `task_root`.
    pub container_runtime: Option<ContainerRuntime>,
    /// When `None`, `ignore_errors()` falls back to `default_ignore_errors()`.
    pub ignore_errors: Option<bool>,
    /// When `None`, `verbose()` falls back to `default_verbose()`.
    pub verbose: Option<bool>,
    /// Set through `new_with_options`.
    /// NOTE(review): presumably forces tasks to run regardless of cache state — confirm.
    pub force: bool,
    /// When true, `emit_event` prints each event as one JSON line on stdout.
    pub json_events: bool,
    /// True for contexts derived from a parent via `from_context` / `from_context_with_args`.
    pub is_nested: bool,
    /// Cache store. `new`/`new_with_options` load it from `task_root.cache_base_dir()`,
    /// and nested contexts share the same `Arc` rather than copying it.
    pub cache_store: Arc<Mutex<CacheStore>>,
    /// Name of the task this context is executing, once set.
    pub current_task_name: Option<String>,
}
55
56impl TaskContext {
57 pub fn empty() -> Self {
58 let mp = MultiProgress::with_draw_target(ProgressDrawTarget::hidden());
59 Self {
60 task_root: Arc::new(TaskRoot::default()),
61 active_tasks: Arc::new(Mutex::new(HashSet::new())),
62 completed_tasks: Arc::new(Mutex::new(HashSet::new())),
63 multi: Arc::new(mp),
64 env_vars: HashMap::new(),
65 task_outputs: Arc::new(Mutex::new(HashMap::new())),
66 secret_vault_location: None,
67 secret_keys_location: None,
68 secret_key_name: None,
69 secret_gpg_key_id: None,
70 shell: None,
71 container_runtime: None,
72 ignore_errors: None,
73 verbose: None,
74 force: false,
75 json_events: false,
76 is_nested: false,
77 cache_store: Arc::new(Mutex::new(CacheStore::default())),
78 current_task_name: None,
79 }
80 }
81
82 pub fn empty_with_root(task_root: Arc<TaskRoot>) -> Self {
83 let mp = MultiProgress::with_draw_target(ProgressDrawTarget::hidden());
84 Self {
85 task_root: task_root.clone(),
86 active_tasks: Arc::new(Mutex::new(HashSet::new())),
87 completed_tasks: Arc::new(Mutex::new(HashSet::new())),
88 multi: Arc::new(mp),
89 env_vars: HashMap::new(),
90 task_outputs: Arc::new(Mutex::new(HashMap::new())),
91 secret_vault_location: None,
92 secret_keys_location: None,
93 secret_key_name: None,
94 secret_gpg_key_id: None,
95 shell: None,
96 container_runtime: None,
97 ignore_errors: None,
98 verbose: None,
99 force: false,
100 json_events: false,
101 is_nested: false,
102 cache_store: Arc::new(Mutex::new(CacheStore::default())),
103 current_task_name: None,
104 }
105 }
106
107 pub fn new(task_root: Arc<TaskRoot>) -> Self {
108 let cache_store = CacheStore::load_in_dir(&task_root.cache_base_dir()).unwrap_or_default();
109 Self {
110 task_root: task_root.clone(),
111 active_tasks: Arc::new(Mutex::new(HashSet::new())),
112 completed_tasks: Arc::new(Mutex::new(HashSet::new())),
113 multi: Arc::new(MultiProgress::new()),
114 env_vars: HashMap::new(),
115 task_outputs: Arc::new(Mutex::new(HashMap::new())),
116 secret_vault_location: task_root.vault_location.clone(),
117 secret_keys_location: task_root.keys_location.clone(),
118 secret_key_name: task_root.key_name.clone(),
119 secret_gpg_key_id: task_root.gpg_key_id.clone(),
120 shell: None,
121 container_runtime: task_root.container_runtime.clone(),
122 ignore_errors: None,
123 verbose: None,
124 force: false,
125 json_events: false,
126 is_nested: false,
127 cache_store: Arc::new(Mutex::new(cache_store)),
128 current_task_name: None,
129 }
130 }
131
132 pub fn new_with_options(task_root: Arc<TaskRoot>, force: bool, json_events: bool) -> Self {
133 let cache_store = CacheStore::load_in_dir(&task_root.cache_base_dir()).unwrap_or_default();
134 let multi = if json_events {
135 Arc::new(MultiProgress::with_draw_target(ProgressDrawTarget::hidden()))
136 } else {
137 Arc::new(MultiProgress::new())
138 };
139 Self {
140 task_root: task_root.clone(),
141 active_tasks: Arc::new(Mutex::new(HashSet::new())),
142 completed_tasks: Arc::new(Mutex::new(HashSet::new())),
143 multi,
144 env_vars: HashMap::new(),
145 task_outputs: Arc::new(Mutex::new(HashMap::new())),
146 secret_vault_location: task_root.vault_location.clone(),
147 secret_keys_location: task_root.keys_location.clone(),
148 secret_key_name: task_root.key_name.clone(),
149 secret_gpg_key_id: task_root.gpg_key_id.clone(),
150 shell: None,
151 container_runtime: task_root.container_runtime.clone(),
152 ignore_errors: None,
153 verbose: None,
154 force,
155 json_events,
156 is_nested: false,
157 cache_store: Arc::new(Mutex::new(cache_store)),
158 current_task_name: None,
159 }
160 }
161
162 pub fn from_context(context: &TaskContext) -> Self {
163 Self {
164 task_root: context.task_root.clone(),
165 active_tasks: context.active_tasks.clone(),
166 completed_tasks: context.completed_tasks.clone(),
167 multi: context.multi.clone(),
168 env_vars: context.env_vars.clone(),
169 task_outputs: Arc::new(Mutex::new(HashMap::new())),
170 secret_vault_location: context.secret_vault_location.clone(),
171 secret_keys_location: context.secret_keys_location.clone(),
172 secret_key_name: context.secret_key_name.clone(),
173 secret_gpg_key_id: context.secret_gpg_key_id.clone(),
174 shell: context.shell.clone(),
175 container_runtime: context.container_runtime.clone(),
176 ignore_errors: context.ignore_errors,
177 verbose: context.verbose,
178 force: context.force,
179 json_events: context.json_events,
180 is_nested: true,
181 cache_store: context.cache_store.clone(),
182 current_task_name: context.current_task_name.clone(),
183 }
184 }
185
186 pub fn from_context_with_args(context: &TaskContext, ignore_errors: bool, verbose: bool) -> Self {
187 Self {
188 task_root: context.task_root.clone(),
189 active_tasks: context.active_tasks.clone(),
190 completed_tasks: context.completed_tasks.clone(),
191 multi: context.multi.clone(),
192 env_vars: context.env_vars.clone(),
193 task_outputs: Arc::new(Mutex::new(HashMap::new())),
194 secret_vault_location: context.secret_vault_location.clone(),
195 secret_keys_location: context.secret_keys_location.clone(),
196 secret_key_name: context.secret_key_name.clone(),
197 secret_gpg_key_id: context.secret_gpg_key_id.clone(),
198 shell: context.shell.clone(),
199 container_runtime: context.container_runtime.clone(),
200 ignore_errors: Some(ignore_errors),
201 verbose: Some(verbose),
202 force: context.force,
203 json_events: context.json_events,
204 is_nested: true,
205 cache_store: context.cache_store.clone(),
206 current_task_name: context.current_task_name.clone(),
207 }
208 }
209
210 pub fn extend_env_vars<I>(&mut self, iter: I)
211 where
212 I: IntoIterator<Item = (String, String)>,
213 {
214 self.env_vars.extend(iter);
215 }
216
217 pub fn set_shell(&mut self, shell: &Shell) {
218 let shell = Arc::new(Shell::from_shell(shell));
219 self.shell = Some(shell);
220 }
221
222 pub fn set_secret_vault_location(&mut self, vault_location: impl Into<String>) {
223 self.secret_vault_location = Some(vault_location.into());
224 }
225
226 pub fn set_secret_keys_location(&mut self, keys_location: impl Into<String>) {
227 self.secret_keys_location = Some(keys_location.into());
228 }
229
230 pub fn set_secret_key_name(&mut self, key_name: impl Into<String>) {
231 self.secret_key_name = Some(key_name.into());
232 }
233
234 pub fn set_secret_gpg_key_id(&mut self, gpg_key_id: impl Into<String>) {
235 self.secret_gpg_key_id = Some(gpg_key_id.into());
236 }
237
238 pub fn set_container_runtime(&mut self, runtime: &ContainerRuntime) {
239 self.container_runtime = Some(runtime.clone());
240 }
241
242 pub fn set_ignore_errors(&mut self, ignore_errors: bool) {
243 self.ignore_errors = Some(ignore_errors);
244 }
245
246 pub fn set_verbose(&mut self, verbose: bool) {
247 self.verbose = Some(verbose);
248 }
249
250 pub fn insert_task_output(&self, name: impl Into<String>, value: impl Into<String>) -> anyhow::Result<()> {
251 let name = name.into();
252 let mut outputs = self
253 .task_outputs
254 .lock()
255 .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
256 if outputs.contains_key(&name) {
257 anyhow::bail!("Task output already exists - {}", name);
258 }
259 outputs.insert(name, value.into());
260 Ok(())
261 }
262
263 pub fn get_task_output(&self, name: &str) -> anyhow::Result<Option<String>> {
264 let outputs = self
265 .task_outputs
266 .lock()
267 .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
268 Ok(outputs.get(name).cloned())
269 }
270
271 pub fn has_task_output(&self, name: &str) -> anyhow::Result<bool> {
272 let outputs = self
273 .task_outputs
274 .lock()
275 .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
276 Ok(outputs.contains_key(name))
277 }
278
279 pub fn shell(&self) -> Arc<Shell> {
280 self.shell.clone().unwrap_or_else(|| Arc::new(default_shell()))
281 }
282
283 pub fn ignore_errors(&self) -> bool {
284 self.ignore_errors.unwrap_or(default_ignore_errors())
285 }
286
287 pub fn verbose(&self) -> bool {
288 self.verbose.unwrap_or(default_verbose())
289 }
290
291 pub fn is_task_active(&self, task_name: &str) -> anyhow::Result<bool> {
292 let active = self
293 .active_tasks
294 .lock()
295 .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
296 Ok(active.contains(task_name))
297 }
298
299 pub fn is_task_completed(&self, task_name: &str) -> anyhow::Result<bool> {
300 let completed = self
301 .completed_tasks
302 .lock()
303 .map_err(|e| anyhow::anyhow!("Failed to lock completed tasks - {}", e))?;
304 Ok(completed.contains(task_name))
305 }
306
307 pub fn mark_task_active(&self, task_name: &str) -> anyhow::Result<()> {
308 let mut active = self
309 .active_tasks
310 .lock()
311 .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
312 active.insert(task_name.to_string());
313 Ok(())
314 }
315
316 pub fn unmark_task_active(&self, task_name: &str) -> anyhow::Result<()> {
317 let mut active = self
318 .active_tasks
319 .lock()
320 .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
321 active.remove(task_name);
322 Ok(())
323 }
324
325 pub fn mark_task_complete(&self, task_name: &str) -> anyhow::Result<()> {
326 let mut completed = self
327 .completed_tasks
328 .lock()
329 .map_err(|e| anyhow::anyhow!("Failed to lock completed tasks - {}", e))?;
330 completed.insert(task_name.to_string());
331 Ok(())
332 }
333
334 pub fn emit_event<T: Serialize>(&self, value: &T) -> anyhow::Result<()> {
335 if self.json_events {
336 println!("{}", serde_json::to_string(value)?);
337 }
338 Ok(())
339 }
340
341 pub fn set_current_task_name(&mut self, task_name: &str) {
342 self.current_task_name = Some(task_name.to_string());
343 }
344
345 pub fn resolve_from_config(&self, value: &str) -> PathBuf {
346 self.task_root.resolve_from_config(value)
347 }
348}
349
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_task_context_1() -> anyhow::Result<()> {
        let context = TaskContext::empty();
        assert_eq!(context.shell().cmd(), "sh".to_string());
        assert!(!context.ignore_errors());
        assert!(context.verbose());
        Ok(())
    }

    #[test]
    fn test_task_context_2() -> anyhow::Result<()> {
        let mut context = TaskContext::empty();
        context.set_shell(&Shell::String("bash".to_string()));
        assert_eq!(context.shell().cmd(), "bash".to_string());
        Ok(())
    }

    #[test]
    fn test_task_context_3() -> anyhow::Result<()> {
        let mut context = TaskContext::empty();
        context.extend_env_vars(vec![("key".to_string(), "value".to_string())]);
        assert_eq!(context.env_vars.get("key"), Some(&"value".to_string()));
        Ok(())
    }

    #[test]
    fn test_task_context_4() -> anyhow::Result<()> {
        let mut context = TaskContext::empty();
        // The default is false (see test_task_context_1), so exercise both directions.
        context.set_ignore_errors(true);
        assert!(context.ignore_errors());
        context.set_ignore_errors(false);
        assert!(!context.ignore_errors());
        Ok(())
    }

    #[test]
    fn test_task_context_5() -> anyhow::Result<()> {
        let mut context = TaskContext::empty();
        // The default is true, so setting only `true` would not prove the setter works.
        context.set_verbose(false);
        assert!(!context.verbose());
        context.set_verbose(true);
        assert!(context.verbose());
        Ok(())
    }

    #[test]
    fn test_task_context_outputs_are_stored() -> anyhow::Result<()> {
        let context = TaskContext::empty();
        context.insert_task_output("tag", "v1.0.0")?;
        assert_eq!(context.get_task_output("tag")?, Some("v1.0.0".to_string()));
        assert!(context.has_task_output("tag")?);
        Ok(())
    }

    #[test]
    fn test_task_context_duplicate_output_is_rejected() -> anyhow::Result<()> {
        let context = TaskContext::empty();
        context.insert_task_output("tag", "v1.0.0")?;
        assert!(context.insert_task_output("tag", "v2.0.0").is_err());
        // The original value must survive the rejected insert.
        assert_eq!(context.get_task_output("tag")?, Some("v1.0.0".to_string()));
        assert_eq!(context.get_task_output("missing")?, None);
        assert!(!context.has_task_output("missing")?);
        Ok(())
    }

    #[test]
    fn test_task_context_task_tracking() -> anyhow::Result<()> {
        let context = TaskContext::empty();
        assert!(!context.is_task_active("build")?);
        context.mark_task_active("build")?;
        assert!(context.is_task_active("build")?);
        context.unmark_task_active("build")?;
        assert!(!context.is_task_active("build")?);

        assert!(!context.is_task_completed("build")?);
        context.mark_task_complete("build")?;
        assert!(context.is_task_completed("build")?);
        Ok(())
    }

    #[test]
    fn test_task_context_nested_contexts() -> anyhow::Result<()> {
        let mut parent = TaskContext::empty();
        parent.set_verbose(true);
        parent.insert_task_output("tag", "v1.0.0")?;
        assert!(!parent.is_nested);

        let child = TaskContext::from_context(&parent);
        assert!(child.is_nested);
        // Outputs are per-context; task tracking is shared with the parent.
        assert!(!child.has_task_output("tag")?);
        child.mark_task_active("build")?;
        assert!(parent.is_task_active("build")?);

        let child = TaskContext::from_context_with_args(&parent, true, false);
        assert!(child.is_nested);
        assert!(child.ignore_errors());
        assert!(!child.verbose());
        Ok(())
    }
}