1use std::collections::HashSet;
2use std::path::PathBuf;
3use std::sync::{
4 Arc,
5 Mutex,
6};
7
8use hashbrown::HashMap;
9use indicatif::{
10 MultiProgress,
11 ProgressDrawTarget,
12};
13use serde::Serialize;
14
15use crate::cache::CacheStore;
16use crate::defaults::{
17 default_ignore_errors,
18 default_shell,
19 default_verbose,
20};
21
22use super::{
23 ActiveTasks,
24 CompletedTasks,
25 ContainerRuntime,
26 Shell,
27 TaskRoot,
28};
29
/// Execution state threaded through task runs.
///
/// Nested contexts are derived from a parent with `TaskContext::from_context` /
/// `TaskContext::from_context_with_args`. Cloning a context (or deriving a nested one)
/// shares every `Arc`-wrapped field (a clone only bumps the reference count), while
/// `env_vars` and the `Option`/`bool` settings are copied.
#[derive(Clone)]
pub struct TaskContext {
    /// Root configuration. `TaskContext::new*` read the secret and container settings
    /// and the cache directory from it; `resolve_from_config` delegates to it.
    pub task_root: Arc<TaskRoot>,
    /// Names of tasks currently marked active; shared with nested contexts.
    pub active_tasks: ActiveTasks,
    /// Names of tasks marked complete; shared with nested contexts.
    pub completed_tasks: CompletedTasks,
    /// Progress-bar container. Hidden for the `empty*` constructors and when
    /// `json_events` is enabled.
    pub multi: Arc<MultiProgress>,
    /// Environment variables for this context; nested contexts receive a copy.
    pub env_vars: HashMap<String, String>,
    /// Named task outputs. Inserting an existing name is rejected. Nested contexts
    /// start with their own empty map rather than sharing this one.
    pub task_outputs: Arc<Mutex<HashMap<String, String>>>,
    /// Location of the secret vault, if configured.
    pub secret_vault_location: Option<String>,
    /// Location of the secret keys, if configured.
    pub secret_keys_location: Option<String>,
    /// Name of the secret key to use, if configured.
    pub secret_key_name: Option<String>,
    /// Shell override; `None` falls back to `default_shell()` in `shell()`.
    pub shell: Option<Arc<Shell>>,
    /// Container runtime override, if any.
    pub container_runtime: Option<ContainerRuntime>,
    /// Error-handling override; `None` falls back to `default_ignore_errors()`.
    pub ignore_errors: Option<bool>,
    /// Verbosity override; `None` falls back to `default_verbose()`.
    pub verbose: Option<bool>,
    /// Force flag set by `new_with_options` and inherited by nested contexts
    /// (its effect is decided by callers outside this file).
    pub force: bool,
    /// When true, `emit_event` writes events to stdout as JSON lines and progress
    /// bars are hidden.
    pub json_events: bool,
    /// True for contexts created via `from_context` / `from_context_with_args`.
    pub is_nested: bool,
    /// Cache store shared by a context and all contexts derived from it.
    pub cache_store: Arc<Mutex<CacheStore>>,
    /// Name of the task currently running in this context; inherited by nested contexts.
    pub current_task_name: Option<String>,
}
54
55impl TaskContext {
56 pub fn empty() -> Self {
57 let mp = MultiProgress::with_draw_target(ProgressDrawTarget::hidden());
58 Self {
59 task_root: Arc::new(TaskRoot::default()),
60 active_tasks: Arc::new(Mutex::new(HashSet::new())),
61 completed_tasks: Arc::new(Mutex::new(HashSet::new())),
62 multi: Arc::new(mp),
63 env_vars: HashMap::new(),
64 task_outputs: Arc::new(Mutex::new(HashMap::new())),
65 secret_vault_location: None,
66 secret_keys_location: None,
67 secret_key_name: None,
68 shell: None,
69 container_runtime: None,
70 ignore_errors: None,
71 verbose: None,
72 force: false,
73 json_events: false,
74 is_nested: false,
75 cache_store: Arc::new(Mutex::new(CacheStore::default())),
76 current_task_name: None,
77 }
78 }
79
80 pub fn empty_with_root(task_root: Arc<TaskRoot>) -> Self {
81 let mp = MultiProgress::with_draw_target(ProgressDrawTarget::hidden());
82 Self {
83 task_root: task_root.clone(),
84 active_tasks: Arc::new(Mutex::new(HashSet::new())),
85 completed_tasks: Arc::new(Mutex::new(HashSet::new())),
86 multi: Arc::new(mp),
87 env_vars: HashMap::new(),
88 task_outputs: Arc::new(Mutex::new(HashMap::new())),
89 secret_vault_location: None,
90 secret_keys_location: None,
91 secret_key_name: None,
92 shell: None,
93 container_runtime: None,
94 ignore_errors: None,
95 verbose: None,
96 force: false,
97 json_events: false,
98 is_nested: false,
99 cache_store: Arc::new(Mutex::new(CacheStore::default())),
100 current_task_name: None,
101 }
102 }
103
104 pub fn new(task_root: Arc<TaskRoot>) -> Self {
105 let cache_store = CacheStore::load_in_dir(&task_root.cache_base_dir()).unwrap_or_default();
106 Self {
107 task_root: task_root.clone(),
108 active_tasks: Arc::new(Mutex::new(HashSet::new())),
109 completed_tasks: Arc::new(Mutex::new(HashSet::new())),
110 multi: Arc::new(MultiProgress::new()),
111 env_vars: HashMap::new(),
112 task_outputs: Arc::new(Mutex::new(HashMap::new())),
113 secret_vault_location: task_root.vault_location.clone(),
114 secret_keys_location: task_root.keys_location.clone(),
115 secret_key_name: task_root.key_name.clone(),
116 shell: None,
117 container_runtime: task_root.container_runtime.clone(),
118 ignore_errors: None,
119 verbose: None,
120 force: false,
121 json_events: false,
122 is_nested: false,
123 cache_store: Arc::new(Mutex::new(cache_store)),
124 current_task_name: None,
125 }
126 }
127
128 pub fn new_with_options(task_root: Arc<TaskRoot>, force: bool, json_events: bool) -> Self {
129 let cache_store = CacheStore::load_in_dir(&task_root.cache_base_dir()).unwrap_or_default();
130 let multi = if json_events {
131 Arc::new(MultiProgress::with_draw_target(ProgressDrawTarget::hidden()))
132 } else {
133 Arc::new(MultiProgress::new())
134 };
135 Self {
136 task_root: task_root.clone(),
137 active_tasks: Arc::new(Mutex::new(HashSet::new())),
138 completed_tasks: Arc::new(Mutex::new(HashSet::new())),
139 multi,
140 env_vars: HashMap::new(),
141 task_outputs: Arc::new(Mutex::new(HashMap::new())),
142 secret_vault_location: task_root.vault_location.clone(),
143 secret_keys_location: task_root.keys_location.clone(),
144 secret_key_name: task_root.key_name.clone(),
145 shell: None,
146 container_runtime: task_root.container_runtime.clone(),
147 ignore_errors: None,
148 verbose: None,
149 force,
150 json_events,
151 is_nested: false,
152 cache_store: Arc::new(Mutex::new(cache_store)),
153 current_task_name: None,
154 }
155 }
156
157 pub fn from_context(context: &TaskContext) -> Self {
158 Self {
159 task_root: context.task_root.clone(),
160 active_tasks: context.active_tasks.clone(),
161 completed_tasks: context.completed_tasks.clone(),
162 multi: context.multi.clone(),
163 env_vars: context.env_vars.clone(),
164 task_outputs: Arc::new(Mutex::new(HashMap::new())),
165 secret_vault_location: context.secret_vault_location.clone(),
166 secret_keys_location: context.secret_keys_location.clone(),
167 secret_key_name: context.secret_key_name.clone(),
168 shell: context.shell.clone(),
169 container_runtime: context.container_runtime.clone(),
170 ignore_errors: context.ignore_errors,
171 verbose: context.verbose,
172 force: context.force,
173 json_events: context.json_events,
174 is_nested: true,
175 cache_store: context.cache_store.clone(),
176 current_task_name: context.current_task_name.clone(),
177 }
178 }
179
180 pub fn from_context_with_args(context: &TaskContext, ignore_errors: bool, verbose: bool) -> Self {
181 Self {
182 task_root: context.task_root.clone(),
183 active_tasks: context.active_tasks.clone(),
184 completed_tasks: context.completed_tasks.clone(),
185 multi: context.multi.clone(),
186 env_vars: context.env_vars.clone(),
187 task_outputs: Arc::new(Mutex::new(HashMap::new())),
188 secret_vault_location: context.secret_vault_location.clone(),
189 secret_keys_location: context.secret_keys_location.clone(),
190 secret_key_name: context.secret_key_name.clone(),
191 shell: context.shell.clone(),
192 container_runtime: context.container_runtime.clone(),
193 ignore_errors: Some(ignore_errors),
194 verbose: Some(verbose),
195 force: context.force,
196 json_events: context.json_events,
197 is_nested: true,
198 cache_store: context.cache_store.clone(),
199 current_task_name: context.current_task_name.clone(),
200 }
201 }
202
203 pub fn extend_env_vars<I>(&mut self, iter: I)
204 where
205 I: IntoIterator<Item = (String, String)>,
206 {
207 self.env_vars.extend(iter);
208 }
209
210 pub fn set_shell(&mut self, shell: &Shell) {
211 let shell = Arc::new(Shell::from_shell(shell));
212 self.shell = Some(shell);
213 }
214
215 pub fn set_secret_vault_location(&mut self, vault_location: impl Into<String>) {
216 self.secret_vault_location = Some(vault_location.into());
217 }
218
219 pub fn set_secret_keys_location(&mut self, keys_location: impl Into<String>) {
220 self.secret_keys_location = Some(keys_location.into());
221 }
222
223 pub fn set_secret_key_name(&mut self, key_name: impl Into<String>) {
224 self.secret_key_name = Some(key_name.into());
225 }
226
227 pub fn set_container_runtime(&mut self, runtime: &ContainerRuntime) {
228 self.container_runtime = Some(runtime.clone());
229 }
230
231 pub fn set_ignore_errors(&mut self, ignore_errors: bool) {
232 self.ignore_errors = Some(ignore_errors);
233 }
234
235 pub fn set_verbose(&mut self, verbose: bool) {
236 self.verbose = Some(verbose);
237 }
238
239 pub fn insert_task_output(&self, name: impl Into<String>, value: impl Into<String>) -> anyhow::Result<()> {
240 let name = name.into();
241 let mut outputs = self
242 .task_outputs
243 .lock()
244 .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
245 if outputs.contains_key(&name) {
246 anyhow::bail!("Task output already exists - {}", name);
247 }
248 outputs.insert(name, value.into());
249 Ok(())
250 }
251
252 pub fn get_task_output(&self, name: &str) -> anyhow::Result<Option<String>> {
253 let outputs = self
254 .task_outputs
255 .lock()
256 .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
257 Ok(outputs.get(name).cloned())
258 }
259
260 pub fn has_task_output(&self, name: &str) -> anyhow::Result<bool> {
261 let outputs = self
262 .task_outputs
263 .lock()
264 .map_err(|e| anyhow::anyhow!("Failed to lock task outputs - {}", e))?;
265 Ok(outputs.contains_key(name))
266 }
267
268 pub fn shell(&self) -> Arc<Shell> {
269 self.shell.clone().unwrap_or_else(|| Arc::new(default_shell()))
270 }
271
272 pub fn ignore_errors(&self) -> bool {
273 self.ignore_errors.unwrap_or(default_ignore_errors())
274 }
275
276 pub fn verbose(&self) -> bool {
277 self.verbose.unwrap_or(default_verbose())
278 }
279
280 pub fn is_task_active(&self, task_name: &str) -> anyhow::Result<bool> {
281 let active = self
282 .active_tasks
283 .lock()
284 .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
285 Ok(active.contains(task_name))
286 }
287
288 pub fn is_task_completed(&self, task_name: &str) -> anyhow::Result<bool> {
289 let completed = self
290 .completed_tasks
291 .lock()
292 .map_err(|e| anyhow::anyhow!("Failed to lock completed tasks - {}", e))?;
293 Ok(completed.contains(task_name))
294 }
295
296 pub fn mark_task_active(&self, task_name: &str) -> anyhow::Result<()> {
297 let mut active = self
298 .active_tasks
299 .lock()
300 .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
301 active.insert(task_name.to_string());
302 Ok(())
303 }
304
305 pub fn unmark_task_active(&self, task_name: &str) -> anyhow::Result<()> {
306 let mut active = self
307 .active_tasks
308 .lock()
309 .map_err(|e| anyhow::anyhow!("Failed to lock active tasks - {}", e))?;
310 active.remove(task_name);
311 Ok(())
312 }
313
314 pub fn mark_task_complete(&self, task_name: &str) -> anyhow::Result<()> {
315 let mut completed = self
316 .completed_tasks
317 .lock()
318 .map_err(|e| anyhow::anyhow!("Failed to lock completed tasks - {}", e))?;
319 completed.insert(task_name.to_string());
320 Ok(())
321 }
322
323 pub fn emit_event<T: Serialize>(&self, value: &T) -> anyhow::Result<()> {
324 if self.json_events {
325 println!("{}", serde_json::to_string(value)?);
326 }
327 Ok(())
328 }
329
330 pub fn set_current_task_name(&mut self, task_name: &str) {
331 self.current_task_name = Some(task_name.to_string());
332 }
333
334 pub fn resolve_from_config(&self, value: &str) -> PathBuf {
335 self.task_root.resolve_from_config(value)
336 }
337}
338
#[cfg(test)]
mod test {
    use super::*;

    /// Defaults apply when no overrides are set.
    #[test]
    fn test_task_context_1() -> anyhow::Result<()> {
        let context = TaskContext::empty();
        assert_eq!(context.shell().cmd(), "sh".to_string());
        assert!(!context.ignore_errors());
        assert!(context.verbose());
        Ok(())
    }

    /// A shell override replaces the default shell.
    #[test]
    fn test_task_context_2() -> anyhow::Result<()> {
        let mut context = TaskContext::empty();
        context.set_shell(&Shell::String("bash".to_string()));
        assert_eq!(context.shell().cmd(), "bash".to_string());
        Ok(())
    }

    /// Environment variables can be added after construction.
    #[test]
    fn test_task_context_3() -> anyhow::Result<()> {
        let mut context = TaskContext::empty();
        context.extend_env_vars(vec![("key".to_string(), "value".to_string())]);
        assert_eq!(context.env_vars.get("key"), Some(&"value".to_string()));
        Ok(())
    }

    /// The ignore_errors override is honoured.
    #[test]
    fn test_task_context_4() -> anyhow::Result<()> {
        let mut context = TaskContext::empty();
        context.set_ignore_errors(true);
        assert!(context.ignore_errors());
        Ok(())
    }

    /// The verbose override is honoured.
    #[test]
    fn test_task_context_5() -> anyhow::Result<()> {
        let mut context = TaskContext::empty();
        context.set_verbose(true);
        assert!(context.verbose());
        Ok(())
    }

    #[test]
    fn test_task_context_outputs_are_stored() -> anyhow::Result<()> {
        let context = TaskContext::empty();
        context.insert_task_output("tag", "v1.0.0")?;
        assert_eq!(context.get_task_output("tag")?, Some("v1.0.0".to_string()));
        assert!(context.has_task_output("tag")?);
        Ok(())
    }

    /// Re-inserting an existing output name fails and keeps the original value.
    #[test]
    fn test_task_context_duplicate_output_is_rejected() -> anyhow::Result<()> {
        let context = TaskContext::empty();
        context.insert_task_output("tag", "v1.0.0")?;
        assert!(context.insert_task_output("tag", "v2.0.0").is_err());
        assert_eq!(context.get_task_output("tag")?, Some("v1.0.0".to_string()));
        assert!(!context.has_task_output("missing")?);
        assert_eq!(context.get_task_output("missing")?, None);
        Ok(())
    }

    /// Active and completed task tracking round-trips.
    #[test]
    fn test_task_context_tracks_task_state() -> anyhow::Result<()> {
        let context = TaskContext::empty();
        assert!(!context.is_task_active("build")?);
        context.mark_task_active("build")?;
        assert!(context.is_task_active("build")?);
        context.unmark_task_active("build")?;
        assert!(!context.is_task_active("build")?);
        assert!(!context.is_task_completed("build")?);
        context.mark_task_complete("build")?;
        assert!(context.is_task_completed("build")?);
        Ok(())
    }

    /// Nested contexts share task-tracking state but get their own outputs.
    #[test]
    fn test_task_context_nested_shares_state_but_not_outputs() -> anyhow::Result<()> {
        let mut parent = TaskContext::empty();
        parent.set_current_task_name("parent");
        parent.insert_task_output("tag", "v1.0.0")?;

        let child = TaskContext::from_context_with_args(&parent, true, false);
        assert!(child.is_nested);
        assert!(!parent.is_nested);
        assert!(child.ignore_errors());
        assert!(!child.verbose());
        assert_eq!(child.current_task_name.as_deref(), Some("parent"));
        assert!(!child.has_task_output("tag")?);

        child.mark_task_complete("child")?;
        assert!(parent.is_task_completed("child")?);
        Ok(())
    }

    /// Events are not emitted (and do not fail) when JSON events are disabled.
    #[test]
    fn test_task_context_emit_event_disabled_is_noop() -> anyhow::Result<()> {
        let context = TaskContext::empty();
        assert!(!context.json_events);
        context.emit_event(&"ignored")?;
        Ok(())
    }
}