cuenv_core/tasks/executor.rs

//! Task executor for running tasks with environment support
//!
//! This module handles the actual execution of tasks, including:
//! - Environment variable propagation
//! - Parallel and sequential execution
//! - Output capture and streaming
7
8use super::{Task, TaskDefinition, TaskGraph, TaskGroup, Tasks};
9use crate::environment::Environment;
10use crate::{Error, Result};
11use async_recursion::async_recursion;
12use std::collections::HashMap;
13use std::process::Stdio;
14use std::sync::Arc;
15use tokio::io::{AsyncBufReadExt, BufReader};
16use tokio::process::Command;
17use tokio::task::JoinSet;
18
/// Outcome of running a single task process.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct TaskResult {
    /// Name the task was executed under (members of a group carry a
    /// prefixed name such as `group[0]` or `group.member`).
    pub name: String,
    /// Process exit code, or `None` when the process did not exit with a
    /// code (for example, it was terminated by a signal on Unix).
    pub exit_code: Option<i32>,
    /// Captured standard output; empty when output was not captured.
    pub stdout: String,
    /// Captured standard error; empty when output was not captured.
    pub stderr: String,
    /// Whether the process exited successfully.
    pub success: bool,
}
33
34/// Task executor configuration
35#[derive(Debug, Clone)]
36pub struct ExecutorConfig {
37    /// Whether to capture output (vs streaming to stdout/stderr)
38    pub capture_output: bool,
39    /// Maximum parallel tasks (0 = unlimited)
40    pub max_parallel: usize,
41    /// Environment variables to propagate
42    pub environment: Environment,
43}
44
45impl Default for ExecutorConfig {
46    fn default() -> Self {
47        Self {
48            capture_output: false,
49            max_parallel: 0,
50            environment: Environment::new(),
51        }
52    }
53}
54
55/// Task executor
56pub struct TaskExecutor {
57    config: ExecutorConfig,
58}
59
60impl TaskExecutor {
61    /// Create a new task executor
62    pub fn new(config: ExecutorConfig) -> Self {
63        Self { config }
64    }
65
66    /// Execute a single task
67    pub async fn execute_task(&self, name: &str, task: &Task) -> Result<TaskResult> {
68        tracing::info!("Executing task: {}", name);
69
70        // Build the command based on shell and args configuration
71        let mut cmd = if let Some(shell) = &task.shell {
72            // Check if shell is properly configured
73            if shell.command.is_some() && shell.flag.is_some() {
74                // Execute via specified shell
75                let shell_command = shell.command.as_ref().unwrap();
76                let shell_flag = shell.flag.as_ref().unwrap();
77                let mut cmd = Command::new(shell_command);
78                cmd.arg(shell_flag);
79
80                if task.args.is_empty() {
81                    // Just execute the command string as-is
82                    cmd.arg(&task.command);
83                } else {
84                    // Concatenate command and args with proper shell quoting
85                    let full_command = if task.command.is_empty() {
86                        task.args.join(" ")
87                    } else {
88                        format!("{} {}", task.command, task.args.join(" "))
89                    };
90                    cmd.arg(full_command);
91                }
92                cmd
93            } else {
94                // Shell field present but not properly configured, fall back to direct execution
95                let mut cmd = Command::new(&task.command);
96                for arg in &task.args {
97                    cmd.arg(arg);
98                }
99                cmd
100            }
101        } else {
102            // Direct execution (secure by default)
103            let mut cmd = Command::new(&task.command);
104            for arg in &task.args {
105                cmd.arg(arg);
106            }
107            cmd
108        };
109
110        // Set environment variables
111        let env_vars = self.config.environment.merge_with_system();
112        for (key, value) in env_vars {
113            cmd.env(key, value);
114        }
115
116        // Configure output handling
117        if self.config.capture_output {
118            cmd.stdout(Stdio::piped());
119            cmd.stderr(Stdio::piped());
120        } else {
121            cmd.stdout(Stdio::inherit());
122            cmd.stderr(Stdio::inherit());
123        }
124
125        // Execute the command
126        let mut child = cmd
127            .spawn()
128            .map_err(|e| Error::configuration(format!("Failed to spawn task '{}': {}", name, e)))?;
129
130        let (stdout, stderr) = if self.config.capture_output {
131            // Capture output concurrently to prevent deadlocks
132            let stdout_handle = child.stdout.take();
133            let stderr_handle = child.stderr.take();
134
135            let stdout_task = async {
136                if let Some(stdout) = stdout_handle {
137                    let reader = BufReader::new(stdout);
138                    let mut lines = reader.lines();
139                    let mut stdout_lines = Vec::new();
140                    while let Ok(Some(line)) = lines.next_line().await {
141                        stdout_lines.push(line);
142                    }
143                    stdout_lines.join("\n")
144                } else {
145                    String::new()
146                }
147            };
148
149            let stderr_task = async {
150                if let Some(stderr) = stderr_handle {
151                    let reader = BufReader::new(stderr);
152                    let mut lines = reader.lines();
153                    let mut stderr_lines = Vec::new();
154                    while let Ok(Some(line)) = lines.next_line().await {
155                        stderr_lines.push(line);
156                    }
157                    stderr_lines.join("\n")
158                } else {
159                    String::new()
160                }
161            };
162
163            // Read stdout and stderr concurrently
164            tokio::join!(stdout_task, stderr_task)
165        } else {
166            (String::new(), String::new())
167        };
168
169        // Wait for completion
170        let status = child.wait().await.map_err(|e| {
171            Error::configuration(format!("Failed to wait for task '{}': {}", name, e))
172        })?;
173
174        let exit_code = status.code();
175        let success = status.success();
176
177        if !success {
178            tracing::warn!("Task '{}' failed with exit code: {:?}", name, exit_code);
179        } else {
180            tracing::info!("Task '{}' completed successfully", name);
181        }
182
183        Ok(TaskResult {
184            name: name.to_string(),
185            exit_code,
186            stdout,
187            stderr,
188            success,
189        })
190    }
191
192    /// Execute a task definition (single task or group)
193    #[async_recursion]
194    pub async fn execute_definition(
195        &self,
196        name: &str,
197        definition: &TaskDefinition,
198        all_tasks: &Tasks,
199    ) -> Result<Vec<TaskResult>> {
200        match definition {
201            TaskDefinition::Single(task) => {
202                let result = self.execute_task(name, task).await?;
203                Ok(vec![result])
204            }
205            TaskDefinition::Group(group) => self.execute_group(name, group, all_tasks).await,
206        }
207    }
208
209    /// Execute a task group
210    async fn execute_group(
211        &self,
212        prefix: &str,
213        group: &TaskGroup,
214        all_tasks: &Tasks,
215    ) -> Result<Vec<TaskResult>> {
216        match group {
217            TaskGroup::Sequential(tasks) => self.execute_sequential(prefix, tasks, all_tasks).await,
218            TaskGroup::Parallel(tasks) => self.execute_parallel(prefix, tasks, all_tasks).await,
219        }
220    }
221
222    /// Execute tasks sequentially
223    async fn execute_sequential(
224        &self,
225        prefix: &str,
226        tasks: &[TaskDefinition],
227        all_tasks: &Tasks,
228    ) -> Result<Vec<TaskResult>> {
229        let mut results = Vec::new();
230
231        for (i, task_def) in tasks.iter().enumerate() {
232            let task_name = format!("{}[{}]", prefix, i);
233            let task_results = self
234                .execute_definition(&task_name, task_def, all_tasks)
235                .await?;
236
237            // Check if any task failed
238            for result in &task_results {
239                if !result.success {
240                    return Err(Error::configuration(format!(
241                        "Task '{}' failed in sequential group",
242                        result.name
243                    )));
244                }
245            }
246
247            results.extend(task_results);
248        }
249
250        Ok(results)
251    }
252
253    /// Execute tasks in parallel
254    async fn execute_parallel(
255        &self,
256        prefix: &str,
257        tasks: &HashMap<String, TaskDefinition>,
258        all_tasks: &Tasks,
259    ) -> Result<Vec<TaskResult>> {
260        let mut join_set = JoinSet::new();
261        let all_tasks = Arc::new(all_tasks.clone());
262
263        for (name, task_def) in tasks {
264            let task_name = format!("{}.{}", prefix, name);
265            let task_def = task_def.clone();
266            let all_tasks = Arc::clone(&all_tasks);
267            let executor = self.clone_with_config();
268
269            join_set.spawn(async move {
270                executor
271                    .execute_definition(&task_name, &task_def, &all_tasks)
272                    .await
273            });
274
275            // Apply parallelism limit if configured
276            if self.config.max_parallel > 0 && join_set.len() >= self.config.max_parallel {
277                // Wait for one to complete before starting more
278                if let Some(result) = join_set.join_next().await {
279                    match result {
280                        Ok(Ok(_)) => {} // Task completed successfully, continue
281                        Ok(Err(e)) => return Err(e),
282                        Err(e) => {
283                            return Err(Error::configuration(format!(
284                                "Task execution panicked: {}",
285                                e
286                            )));
287                        }
288                    }
289                }
290            }
291        }
292
293        // Wait for all remaining tasks
294        let mut all_results = Vec::new();
295        while let Some(result) = join_set.join_next().await {
296            match result {
297                Ok(Ok(results)) => all_results.extend(results),
298                Ok(Err(e)) => return Err(e),
299                Err(e) => {
300                    return Err(Error::configuration(format!(
301                        "Task execution panicked: {}",
302                        e
303                    )));
304                }
305            }
306        }
307
308        Ok(all_results)
309    }
310
311    /// Execute tasks using a task graph (respects dependencies)
312    pub async fn execute_graph(&self, graph: &TaskGraph) -> Result<Vec<TaskResult>> {
313        let parallel_groups = graph.get_parallel_groups()?;
314        let mut all_results = Vec::new();
315
316        // Use a single JoinSet for all groups to enforce global parallelism limit
317        let mut join_set = JoinSet::new();
318        let mut group_iter = parallel_groups.into_iter();
319        let mut current_group = group_iter.next();
320
321        while current_group.is_some() || !join_set.is_empty() {
322            // Start tasks from current group up to parallelism limit
323            if let Some(group) = current_group.as_mut() {
324                while let Some(node) = group.pop() {
325                    let task = node.task.clone();
326                    let name = node.name.clone();
327                    let executor = self.clone_with_config();
328
329                    join_set.spawn(async move { executor.execute_task(&name, &task).await });
330
331                    // Apply parallelism limit if configured
332                    if self.config.max_parallel > 0 && join_set.len() >= self.config.max_parallel {
333                        break;
334                    }
335                }
336
337                // Move to next group if current group is empty
338                if group.is_empty() {
339                    current_group = group_iter.next();
340                }
341            }
342
343            // Wait for at least one task to complete
344            if let Some(result) = join_set.join_next().await {
345                match result {
346                    Ok(Ok(task_result)) => {
347                        if !task_result.success {
348                            return Err(Error::configuration(format!(
349                                "Task '{}' failed",
350                                task_result.name
351                            )));
352                        }
353                        all_results.push(task_result);
354                    }
355                    Ok(Err(e)) => return Err(e),
356                    Err(e) => {
357                        return Err(Error::configuration(format!(
358                            "Task execution panicked: {}",
359                            e
360                        )));
361                    }
362                }
363            }
364        }
365
366        Ok(all_results)
367    }
368
369    /// Clone executor with same config (for parallel execution)
370    fn clone_with_config(&self) -> Self {
371        Self {
372            config: self.config.clone(),
373        }
374    }
375}
376
377/// Execute an arbitrary command with the cuenv environment
378pub async fn execute_command(
379    command: &str,
380    args: &[String],
381    environment: &Environment,
382) -> Result<i32> {
383    tracing::info!("Executing command: {} {:?}", command, args);
384
385    let mut cmd = Command::new(command);
386    cmd.args(args);
387
388    // Set environment variables
389    let env_vars = environment.merge_with_system();
390    for (key, value) in env_vars {
391        cmd.env(key, value);
392    }
393
394    // Inherit stdio for interactive commands
395    cmd.stdout(Stdio::inherit());
396    cmd.stderr(Stdio::inherit());
397    cmd.stdin(Stdio::inherit());
398
399    // Execute and wait
400    let status = cmd.status().await.map_err(|e| {
401        Error::configuration(format!("Failed to execute command '{}': {}", command, e))
402    })?;
403
404    Ok(status.code().unwrap_or(1))
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// Executor that captures stdout/stderr and otherwise uses defaults.
    fn capturing_executor() -> TaskExecutor {
        TaskExecutor::new(ExecutorConfig {
            capture_output: true,
            ..Default::default()
        })
    }

    /// Capturing executor whose environment additionally contains `key=value`.
    fn capturing_executor_with_var(key: &str, value: &str) -> TaskExecutor {
        let mut config = ExecutorConfig {
            capture_output: true,
            ..Default::default()
        };
        config.environment.set(key.to_string(), value.to_string());
        TaskExecutor::new(config)
    }

    /// Direct-execution task (no shell) with empty env, dependencies,
    /// inputs and outputs.
    fn direct_task(command: &str, args: &[&str], description: &str) -> Task {
        Task {
            command: command.to_string(),
            args: args.iter().map(|&arg| arg.to_string()).collect(),
            shell: None,
            env: HashMap::new(),
            depends_on: vec![],
            inputs: vec![],
            outputs: vec![],
            description: Some(description.to_string()),
        }
    }

    #[tokio::test]
    async fn test_executor_config_default() {
        let defaults = ExecutorConfig::default();
        assert!(!defaults.capture_output);
        assert_eq!(defaults.max_parallel, 0);
        assert!(defaults.environment.is_empty());
    }

    #[tokio::test]
    async fn test_task_result() {
        let result = TaskResult {
            name: "test".to_string(),
            exit_code: Some(0),
            stdout: "output".to_string(),
            stderr: String::new(),
            success: true,
        };

        assert_eq!(result.name, "test");
        assert_eq!(result.exit_code, Some(0));
        assert!(result.success);
        assert_eq!(result.stdout, "output");
    }

    #[tokio::test]
    async fn test_execute_simple_task() {
        let executor = capturing_executor();
        let task = direct_task("echo", &["hello"], "Hello task");

        let outcome = executor.execute_task("test", &task).await.unwrap();

        assert!(outcome.success);
        assert_eq!(outcome.exit_code, Some(0));
        assert!(outcome.stdout.contains("hello"));
    }

    #[tokio::test]
    async fn test_execute_with_environment() {
        let executor = capturing_executor_with_var("TEST_VAR", "test_value");
        let task = direct_task("printenv", &["TEST_VAR"], "Print env task");

        let outcome = executor.execute_task("test", &task).await.unwrap();

        assert!(outcome.success);
        assert!(outcome.stdout.contains("test_value"));
    }

    #[tokio::test]
    async fn test_execute_failing_task() {
        let executor = capturing_executor();
        let task = direct_task("false", &[], "Failing task");

        let outcome = executor.execute_task("test", &task).await.unwrap();

        assert!(!outcome.success);
        assert_eq!(outcome.exit_code, Some(1));
    }

    #[tokio::test]
    async fn test_execute_sequential_group() {
        let executor = capturing_executor();
        let steps = vec![
            TaskDefinition::Single(direct_task("echo", &["first"], "First task")),
            TaskDefinition::Single(direct_task("echo", &["second"], "Second task")),
        ];
        let group = TaskGroup::Sequential(steps);
        let all_tasks = Tasks::new();

        let outcomes = executor
            .execute_group("seq", &group, &all_tasks)
            .await
            .unwrap();

        assert_eq!(outcomes.len(), 2);
        assert!(outcomes[0].stdout.contains("first"));
        assert!(outcomes[1].stdout.contains("second"));
    }

    #[tokio::test]
    async fn test_command_injection_prevention() {
        let executor = capturing_executor();
        // Shell metacharacters in arguments must reach echo literally.
        let malicious_task = direct_task("echo", &["hello", "; rm -rf /"], "Malicious task test");

        let outcome = executor
            .execute_task("malicious", &malicious_task)
            .await
            .unwrap();

        assert!(outcome.success);
        assert!(outcome.stdout.contains("hello ; rm -rf /"));
    }

    #[tokio::test]
    async fn test_special_characters_in_args() {
        let executor = capturing_executor();

        // Inputs that a shell would interpret if one were (wrongly) involved.
        let special_chars = [
            "$USER",          // Variable expansion
            "$(whoami)",      // Command substitution
            "`whoami`",       // Backtick command substitution
            "&& echo hacked", // Command chaining
            "|| echo failed", // Error chaining
            "> /tmp/hack",    // Redirection
            "| cat",          // Piping
        ];

        for special_arg in special_chars {
            let task = direct_task("echo", &["safe", special_arg], "Special character test");

            let outcome = executor.execute_task("special", &task).await.unwrap();

            assert!(outcome.success);
            assert!(outcome.stdout.contains("safe"));
            assert!(outcome.stdout.contains(special_arg));
        }
    }

    #[tokio::test]
    async fn test_environment_variable_safety() {
        // A variable value that would be dangerous if shell-evaluated.
        let executor = capturing_executor_with_var("DANGEROUS_VAR", "; rm -rf /");
        let task = direct_task(
            "printenv",
            &["DANGEROUS_VAR"],
            "Environment variable safety test",
        );

        let outcome = executor.execute_task("env_test", &task).await.unwrap();

        assert!(outcome.success);
        assert!(outcome.stdout.contains("; rm -rf /"));
    }
}