1use super::{Task, TaskDefinition, TaskGraph, TaskGroup, Tasks};
9use crate::environment::Environment;
10use crate::{Error, Result};
11use async_recursion::async_recursion;
12use std::collections::HashMap;
13use std::process::Stdio;
14use std::sync::Arc;
15use tokio::io::{AsyncBufReadExt, BufReader};
16use tokio::process::Command;
17use tokio::task::JoinSet;
18
/// Outcome of running a single task process.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskResult {
    /// Name the task was executed under (group members get a derived name).
    pub name: String,
    /// Process exit code; `None` when the process did not exit with a code
    /// (per `std::process::ExitStatus::code`, e.g. terminated by a signal on Unix).
    pub exit_code: Option<i32>,
    /// Captured standard output, split into lines and re-joined with `\n`
    /// (no trailing newline); empty when output was not captured.
    pub stdout: String,
    /// Captured standard error, in the same format as `stdout`.
    pub stderr: String,
    /// Whether the process exited successfully.
    pub success: bool,
}
33
/// Settings that control how a [`TaskExecutor`] runs task processes.
#[derive(Debug, Clone)]
pub struct ExecutorConfig {
    /// When `true`, the child's stdout/stderr are piped and collected into the
    /// `TaskResult`; when `false`, they are inherited from this process.
    pub capture_output: bool,
    /// Maximum number of tasks running at once within a parallel group or graph
    /// level; `0` means no limit.
    pub max_parallel: usize,
    /// Extra environment variables set on every task; merged with the system
    /// environment (via `merge_with_system`) before spawning.
    pub environment: Environment,
}
44
45impl Default for ExecutorConfig {
46 fn default() -> Self {
47 Self {
48 capture_output: false,
49 max_parallel: 0,
50 environment: Environment::new(),
51 }
52 }
53}
54
55pub struct TaskExecutor {
57 config: ExecutorConfig,
58}
59
60impl TaskExecutor {
61 pub fn new(config: ExecutorConfig) -> Self {
63 Self { config }
64 }
65
66 pub async fn execute_task(&self, name: &str, task: &Task) -> Result<TaskResult> {
68 tracing::info!("Executing task: {}", name);
69
70 let mut cmd = if let Some(shell) = &task.shell {
72 if shell.command.is_some() && shell.flag.is_some() {
74 let shell_command = shell.command.as_ref().unwrap();
76 let shell_flag = shell.flag.as_ref().unwrap();
77 let mut cmd = Command::new(shell_command);
78 cmd.arg(shell_flag);
79
80 if task.args.is_empty() {
81 cmd.arg(&task.command);
83 } else {
84 let full_command = if task.command.is_empty() {
86 task.args.join(" ")
87 } else {
88 format!("{} {}", task.command, task.args.join(" "))
89 };
90 cmd.arg(full_command);
91 }
92 cmd
93 } else {
94 let mut cmd = Command::new(&task.command);
96 for arg in &task.args {
97 cmd.arg(arg);
98 }
99 cmd
100 }
101 } else {
102 let mut cmd = Command::new(&task.command);
104 for arg in &task.args {
105 cmd.arg(arg);
106 }
107 cmd
108 };
109
110 let env_vars = self.config.environment.merge_with_system();
112 for (key, value) in env_vars {
113 cmd.env(key, value);
114 }
115
116 if self.config.capture_output {
118 cmd.stdout(Stdio::piped());
119 cmd.stderr(Stdio::piped());
120 } else {
121 cmd.stdout(Stdio::inherit());
122 cmd.stderr(Stdio::inherit());
123 }
124
125 let mut child = cmd
127 .spawn()
128 .map_err(|e| Error::configuration(format!("Failed to spawn task '{}': {}", name, e)))?;
129
130 let (stdout, stderr) = if self.config.capture_output {
131 let stdout_handle = child.stdout.take();
133 let stderr_handle = child.stderr.take();
134
135 let stdout_task = async {
136 if let Some(stdout) = stdout_handle {
137 let reader = BufReader::new(stdout);
138 let mut lines = reader.lines();
139 let mut stdout_lines = Vec::new();
140 while let Ok(Some(line)) = lines.next_line().await {
141 stdout_lines.push(line);
142 }
143 stdout_lines.join("\n")
144 } else {
145 String::new()
146 }
147 };
148
149 let stderr_task = async {
150 if let Some(stderr) = stderr_handle {
151 let reader = BufReader::new(stderr);
152 let mut lines = reader.lines();
153 let mut stderr_lines = Vec::new();
154 while let Ok(Some(line)) = lines.next_line().await {
155 stderr_lines.push(line);
156 }
157 stderr_lines.join("\n")
158 } else {
159 String::new()
160 }
161 };
162
163 tokio::join!(stdout_task, stderr_task)
165 } else {
166 (String::new(), String::new())
167 };
168
169 let status = child.wait().await.map_err(|e| {
171 Error::configuration(format!("Failed to wait for task '{}': {}", name, e))
172 })?;
173
174 let exit_code = status.code();
175 let success = status.success();
176
177 if !success {
178 tracing::warn!("Task '{}' failed with exit code: {:?}", name, exit_code);
179 } else {
180 tracing::info!("Task '{}' completed successfully", name);
181 }
182
183 Ok(TaskResult {
184 name: name.to_string(),
185 exit_code,
186 stdout,
187 stderr,
188 success,
189 })
190 }
191
192 #[async_recursion]
194 pub async fn execute_definition(
195 &self,
196 name: &str,
197 definition: &TaskDefinition,
198 all_tasks: &Tasks,
199 ) -> Result<Vec<TaskResult>> {
200 match definition {
201 TaskDefinition::Single(task) => {
202 let result = self.execute_task(name, task).await?;
203 Ok(vec![result])
204 }
205 TaskDefinition::Group(group) => self.execute_group(name, group, all_tasks).await,
206 }
207 }
208
209 async fn execute_group(
211 &self,
212 prefix: &str,
213 group: &TaskGroup,
214 all_tasks: &Tasks,
215 ) -> Result<Vec<TaskResult>> {
216 match group {
217 TaskGroup::Sequential(tasks) => self.execute_sequential(prefix, tasks, all_tasks).await,
218 TaskGroup::Parallel(tasks) => self.execute_parallel(prefix, tasks, all_tasks).await,
219 }
220 }
221
222 async fn execute_sequential(
224 &self,
225 prefix: &str,
226 tasks: &[TaskDefinition],
227 all_tasks: &Tasks,
228 ) -> Result<Vec<TaskResult>> {
229 let mut results = Vec::new();
230
231 for (i, task_def) in tasks.iter().enumerate() {
232 let task_name = format!("{}[{}]", prefix, i);
233 let task_results = self
234 .execute_definition(&task_name, task_def, all_tasks)
235 .await?;
236
237 for result in &task_results {
239 if !result.success {
240 return Err(Error::configuration(format!(
241 "Task '{}' failed in sequential group",
242 result.name
243 )));
244 }
245 }
246
247 results.extend(task_results);
248 }
249
250 Ok(results)
251 }
252
253 async fn execute_parallel(
255 &self,
256 prefix: &str,
257 tasks: &HashMap<String, TaskDefinition>,
258 all_tasks: &Tasks,
259 ) -> Result<Vec<TaskResult>> {
260 let mut join_set = JoinSet::new();
261 let all_tasks = Arc::new(all_tasks.clone());
262
263 for (name, task_def) in tasks {
264 let task_name = format!("{}.{}", prefix, name);
265 let task_def = task_def.clone();
266 let all_tasks = Arc::clone(&all_tasks);
267 let executor = self.clone_with_config();
268
269 join_set.spawn(async move {
270 executor
271 .execute_definition(&task_name, &task_def, &all_tasks)
272 .await
273 });
274
275 if self.config.max_parallel > 0 && join_set.len() >= self.config.max_parallel {
277 if let Some(result) = join_set.join_next().await {
279 match result {
280 Ok(Ok(_)) => {} Ok(Err(e)) => return Err(e),
282 Err(e) => {
283 return Err(Error::configuration(format!(
284 "Task execution panicked: {}",
285 e
286 )));
287 }
288 }
289 }
290 }
291 }
292
293 let mut all_results = Vec::new();
295 while let Some(result) = join_set.join_next().await {
296 match result {
297 Ok(Ok(results)) => all_results.extend(results),
298 Ok(Err(e)) => return Err(e),
299 Err(e) => {
300 return Err(Error::configuration(format!(
301 "Task execution panicked: {}",
302 e
303 )));
304 }
305 }
306 }
307
308 Ok(all_results)
309 }
310
311 pub async fn execute_graph(&self, graph: &TaskGraph) -> Result<Vec<TaskResult>> {
313 let parallel_groups = graph.get_parallel_groups()?;
314 let mut all_results = Vec::new();
315
316 let mut join_set = JoinSet::new();
318 let mut group_iter = parallel_groups.into_iter();
319 let mut current_group = group_iter.next();
320
321 while current_group.is_some() || !join_set.is_empty() {
322 if let Some(group) = current_group.as_mut() {
324 while let Some(node) = group.pop() {
325 let task = node.task.clone();
326 let name = node.name.clone();
327 let executor = self.clone_with_config();
328
329 join_set.spawn(async move { executor.execute_task(&name, &task).await });
330
331 if self.config.max_parallel > 0 && join_set.len() >= self.config.max_parallel {
333 break;
334 }
335 }
336
337 if group.is_empty() {
339 current_group = group_iter.next();
340 }
341 }
342
343 if let Some(result) = join_set.join_next().await {
345 match result {
346 Ok(Ok(task_result)) => {
347 if !task_result.success {
348 return Err(Error::configuration(format!(
349 "Task '{}' failed",
350 task_result.name
351 )));
352 }
353 all_results.push(task_result);
354 }
355 Ok(Err(e)) => return Err(e),
356 Err(e) => {
357 return Err(Error::configuration(format!(
358 "Task execution panicked: {}",
359 e
360 )));
361 }
362 }
363 }
364 }
365
366 Ok(all_results)
367 }
368
369 fn clone_with_config(&self) -> Self {
371 Self {
372 config: self.config.clone(),
373 }
374 }
375}
376
377pub async fn execute_command(
379 command: &str,
380 args: &[String],
381 environment: &Environment,
382) -> Result<i32> {
383 tracing::info!("Executing command: {} {:?}", command, args);
384
385 let mut cmd = Command::new(command);
386 cmd.args(args);
387
388 let env_vars = environment.merge_with_system();
390 for (key, value) in env_vars {
391 cmd.env(key, value);
392 }
393
394 cmd.stdout(Stdio::inherit());
396 cmd.stderr(Stdio::inherit());
397 cmd.stdin(Stdio::inherit());
398
399 let status = cmd.status().await.map_err(|e| {
401 Error::configuration(format!("Failed to execute command '{}': {}", command, e))
402 })?;
403
404 Ok(status.code().unwrap_or(1))
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// Config that captures child stdout/stderr; everything else is defaulted.
    fn capturing_config() -> ExecutorConfig {
        ExecutorConfig {
            capture_output: true,
            ..Default::default()
        }
    }

    /// A direct-exec (no shell) task with no env, dependencies, inputs or outputs.
    fn plain_task(command: &str, args: &[&str], description: &str) -> Task {
        Task {
            command: command.to_string(),
            args: args.iter().map(|arg| arg.to_string()).collect(),
            shell: None,
            env: HashMap::new(),
            depends_on: vec![],
            inputs: vec![],
            outputs: vec![],
            description: Some(description.to_string()),
        }
    }

    #[tokio::test]
    async fn test_executor_config_default() {
        let defaults = ExecutorConfig::default();
        assert!(!defaults.capture_output);
        assert_eq!(defaults.max_parallel, 0);
        assert!(defaults.environment.is_empty());
    }

    #[tokio::test]
    async fn test_task_result() {
        let outcome = TaskResult {
            name: "test".to_string(),
            exit_code: Some(0),
            stdout: "output".to_string(),
            stderr: String::new(),
            success: true,
        };

        assert_eq!(outcome.name, "test");
        assert_eq!(outcome.exit_code, Some(0));
        assert!(outcome.success);
        assert_eq!(outcome.stdout, "output");
    }

    #[tokio::test]
    async fn test_execute_simple_task() {
        let executor = TaskExecutor::new(capturing_config());
        let task = plain_task("echo", &["hello"], "Hello task");

        let outcome = executor.execute_task("test", &task).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.exit_code, Some(0));
        assert!(outcome.stdout.contains("hello"));
    }

    #[tokio::test]
    async fn test_execute_with_environment() {
        let mut config = capturing_config();
        config
            .environment
            .set("TEST_VAR".to_string(), "test_value".to_string());
        let executor = TaskExecutor::new(config);
        let task = plain_task("printenv", &["TEST_VAR"], "Print env task");

        let outcome = executor.execute_task("test", &task).await.unwrap();

        assert!(outcome.success);
        assert!(outcome.stdout.contains("test_value"));
    }

    #[tokio::test]
    async fn test_execute_failing_task() {
        let executor = TaskExecutor::new(capturing_config());
        let task = plain_task("false", &[], "Failing task");

        let outcome = executor.execute_task("test", &task).await.unwrap();

        assert!(!outcome.success);
        assert_eq!(outcome.exit_code, Some(1));
    }

    #[tokio::test]
    async fn test_execute_sequential_group() {
        let executor = TaskExecutor::new(capturing_config());
        let group = TaskGroup::Sequential(vec![
            TaskDefinition::Single(plain_task("echo", &["first"], "First task")),
            TaskDefinition::Single(plain_task("echo", &["second"], "Second task")),
        ]);
        let all_tasks = Tasks::new();

        let outcomes = executor
            .execute_group("seq", &group, &all_tasks)
            .await
            .unwrap();

        assert_eq!(outcomes.len(), 2);
        assert!(outcomes[0].stdout.contains("first"));
        assert!(outcomes[1].stdout.contains("second"));
    }

    #[tokio::test]
    async fn test_command_injection_prevention() {
        let executor = TaskExecutor::new(capturing_config());
        // Without a shell, "; rm -rf /" must reach echo as a literal argument.
        let malicious_task = plain_task("echo", &["hello", "; rm -rf /"], "Malicious task test");

        let outcome = executor
            .execute_task("malicious", &malicious_task)
            .await
            .unwrap();

        assert!(outcome.success);
        assert!(outcome.stdout.contains("hello ; rm -rf /"));
    }

    #[tokio::test]
    async fn test_special_characters_in_args() {
        const SPECIAL_ARGS: &[&str] = &[
            "$USER",
            "$(whoami)",
            "`whoami`",
            "&& echo hacked",
            "|| echo failed",
            "> /tmp/hack",
            "| cat",
        ];
        let executor = TaskExecutor::new(capturing_config());

        for &special_arg in SPECIAL_ARGS {
            let task = plain_task("echo", &["safe", special_arg], "Special character test");

            let outcome = executor.execute_task("special", &task).await.unwrap();

            // Each argument must be echoed verbatim, never expanded.
            assert!(outcome.success);
            assert!(outcome.stdout.contains("safe"));
            assert!(outcome.stdout.contains(special_arg));
        }
    }

    #[tokio::test]
    async fn test_environment_variable_safety() {
        let mut config = capturing_config();
        // Environment values are passed through as data, not interpreted.
        config
            .environment
            .set("DANGEROUS_VAR".to_string(), "; rm -rf /".to_string());
        let executor = TaskExecutor::new(config);
        let task = plain_task(
            "printenv",
            &["DANGEROUS_VAR"],
            "Environment variable safety test",
        );

        let outcome = executor.execute_task("env_test", &task).await.unwrap();

        assert!(outcome.success);
        assert!(outcome.stdout.contains("; rm -rf /"));
    }
}