use crate::errors::{ExecutionError, ValidationError};
use crate::types::{Command, ExecutionRequest};
use std::path::{Path, PathBuf};
use tokio::process::Command as TokioCommand;
pub fn build_command(request: &ExecutionRequest) -> Result<TokioCommand, ExecutionError> {
let mut cmd = match &request.command {
Command::Shell { command, shell } => {
validate_shell_command(command)?;
build_shell_command(command, shell)
}
Command::Exec { program, args } => {
validate_exec_command(program)?;
build_exec_command(program, args)
}
Command::Script { path, interpreter } => {
validate_script_path(path)?;
build_script_command(path, interpreter)?
}
Command::AwsCli {
service,
operation,
args,
profile,
region,
} => {
validate_aws_cli(service, operation)?;
build_aws_cli_command(
service,
operation,
args,
profile.as_deref(),
region.as_deref(),
)
}
};
for (key, value) in &request.env {
cmd.env(key, value);
}
if let Some(working_dir) = &request.working_dir {
validate_working_directory(working_dir)?;
cmd.current_dir(working_dir);
}
cmd.stdout(std::process::Stdio::piped());
cmd.stderr(std::process::Stdio::piped());
cmd.kill_on_drop(true);
Ok(cmd)
}
fn build_shell_command(command: &str, shell: &str) -> TokioCommand {
let mut cmd = TokioCommand::new(shell);
if cfg!(target_os = "windows") {
if shell == "powershell" {
cmd.args(["-Command", command]);
} else {
cmd.args(["/C", command]);
}
} else {
cmd.args(["-c", command]);
}
cmd
}
/// Builds a command that launches `program` directly, without a shell, passing `args`
/// through verbatim in order.
fn build_exec_command(program: &str, args: &[String]) -> TokioCommand {
    let mut exec = TokioCommand::new(program);
    exec.args(args.iter().map(String::as_str));
    exec
}
/// Builds the process for a script request.
///
/// With an explicit `interpreter`, this runs `interpreter <script>`. Otherwise, on Unix:
/// - a file with any execute bit set is run directly;
/// - a non-executable file whose first line starts with `#!` is run through the named
///   interpreter. As the kernel does, the first word after `#!` is the program, and the
///   rest of the line (trimmed, if non-empty) is passed as ONE extra argument before the
///   script path. For example, `#!/usr/bin/env python3` runs `/usr/bin/env python3 <script>`;
/// - anything else is rejected as not executable.
///
/// On non-Unix targets the script path is run directly.
///
/// A relative `path` is first resolved against the current process directory. That is
/// where `validate_script_path` looked for it. Resolving it means a bare name such as
/// `deploy.sh` is not searched for on `PATH`, and a request `working_dir` cannot change
/// which file runs.
///
/// # Errors
///
/// - `ExecutionError::Io` if the current directory, the file metadata, or the first line
///   cannot be read.
/// - `ValidationError::ScriptNotExecutable` (converted into `ExecutionError`), reporting
///   the path as given, when a non-executable file has no usable `#!` line.
fn build_script_command(
    path: &Path,
    interpreter: &Option<String>,
) -> Result<TokioCommand, ExecutionError> {
    let resolved = if path.is_relative() {
        std::env::current_dir()
            .map_err(ExecutionError::Io)?
            .join(path)
    } else {
        path.to_path_buf()
    };

    if let Some(interp) = interpreter {
        let mut cmd = TokioCommand::new(interp);
        cmd.arg(&resolved);
        return Ok(cmd);
    }

    #[cfg(unix)]
    {
        use std::fs::File;
        use std::io::{BufRead, BufReader, Read};
        use std::os::unix::fs::PermissionsExt;

        // Upper bound on the bytes read while looking for a `#!` line. Without it, a
        // large file with no newline would be read into memory in full.
        const MAX_SHEBANG_LEN: u64 = 4096;

        let metadata = std::fs::metadata(&resolved).map_err(ExecutionError::Io)?;
        if metadata.permissions().mode() & 0o111 != 0 {
            // Some execute bit is set: hand the file to the OS as-is.
            return Ok(TokioCommand::new(&resolved));
        }

        // Read raw bytes rather than using `read_line`. A non-UTF-8 (e.g. binary) file
        // is then reported as "not executable" instead of as a decoding I/O error.
        let file = File::open(&resolved).map_err(ExecutionError::Io)?;
        let mut first_line = Vec::new();
        BufReader::new(file)
            .take(MAX_SHEBANG_LEN)
            .read_until(b'\n', &mut first_line)
            .map_err(ExecutionError::Io)?;

        let directive = match first_line.strip_prefix(b"#!") {
            Some(rest) => String::from_utf8_lossy(rest),
            None => {
                return Err(ValidationError::ScriptNotExecutable(path.to_path_buf()).into())
            }
        };
        let directive = directive.trim();
        let (program, shebang_arg) =
            match directive.split_once(|c: char| c.is_ascii_whitespace()) {
                Some((program, rest)) => (program, rest.trim()),
                None => (directive, ""),
            };
        if program.is_empty() {
            // A bare `#!` names no interpreter.
            return Err(ValidationError::ScriptNotExecutable(path.to_path_buf()).into());
        }

        let mut cmd = TokioCommand::new(program);
        if !shebang_arg.is_empty() {
            cmd.arg(shebang_arg);
        }
        cmd.arg(&resolved);
        Ok(cmd)
    }
    #[cfg(not(unix))]
    {
        Ok(TokioCommand::new(&resolved))
    }
}
/// Builds an AWS CLI invocation.
///
/// The arguments appear in this order:
/// `aws <service> <operation> [--profile P] [--region R] <args...> --output json`.
/// JSON output is always requested and is placed after the caller's arguments.
fn build_aws_cli_command(
    service: &str,
    operation: &str,
    args: &[String],
    profile: Option<&str>,
    region: Option<&str>,
) -> TokioCommand {
    let optional_globals = [("--profile", profile), ("--region", region)];

    let mut aws = TokioCommand::new("aws");
    aws.args([service, operation]);
    for (flag, value) in optional_globals {
        if let Some(value) = value {
            aws.args([flag, value]);
        }
    }
    aws.args(args).args(["--output", "json"]);
    aws
}
/// Rejects shell command strings that are empty or contain only whitespace.
///
/// # Errors
///
/// Returns [`ValidationError::EmptyCommand`] for a blank command.
fn validate_shell_command(command: &str) -> Result<(), ValidationError> {
    match command.trim() {
        "" => Err(ValidationError::EmptyCommand),
        _ => Ok(()),
    }
}
/// Rejects program names that are empty or contain only whitespace.
///
/// # Errors
///
/// Returns [`ValidationError::EmptyCommand`] for a blank program name.
fn validate_exec_command(program: &str) -> Result<(), ValidationError> {
    // Every char being whitespace (vacuously true for "") is exactly `trim().is_empty()`.
    if program.chars().all(char::is_whitespace) {
        Err(ValidationError::EmptyCommand)
    } else {
        Ok(())
    }
}
/// Checks that `path` names an existing regular file before it is run as a script.
///
/// A relative path is checked against the current process directory.
///
/// # Errors
///
/// - [`ValidationError::ScriptNotFound`] if nothing exists at `path`.
/// - [`ValidationError::InvalidCommand`] if something exists there but is not a file,
///   e.g. a directory.
fn validate_script_path(path: &Path) -> Result<(), ValidationError> {
    if !path.exists() {
        return Err(ValidationError::ScriptNotFound(path.to_path_buf()));
    }
    if !path.is_file() {
        return Err(ValidationError::InvalidCommand(format!(
            "Script path is not a file: {path:?}"
        )));
    }
    Ok(())
}
/// Ensures both the AWS service and operation names are non-blank.
///
/// The service is checked first, so it is the one reported when both are blank.
///
/// # Errors
///
/// Returns [`ValidationError::MissingField`] naming the first blank field
/// (`"service"` or `"operation"`).
fn validate_aws_cli(service: &str, operation: &str) -> Result<(), ValidationError> {
    for (field, value) in [("service", service), ("operation", operation)] {
        if value.trim().is_empty() {
            return Err(ValidationError::MissingField(field.to_string()));
        }
    }
    Ok(())
}
/// Checks that `path` exists and is a directory before it is used as a child's cwd.
///
/// # Errors
///
/// - [`ValidationError::WorkingDirNotFound`] if nothing exists at `path`.
/// - [`ValidationError::InvalidWorkingDir`] if it exists but is not a directory.
fn validate_working_directory(path: &Path) -> Result<(), ValidationError> {
    if !path.exists() {
        Err(ValidationError::WorkingDirNotFound(path.to_path_buf()))
    } else if path.is_dir() {
        Ok(())
    } else {
        Err(ValidationError::InvalidWorkingDir(path.to_path_buf()))
    }
}
/// Renders `cmd` as a single human-readable command line for logs and diagnostics.
///
/// The result approximates what [`build_command`] runs. It is not guaranteed to be
/// directly re-runnable:
/// - `Shell` is always shown in POSIX form, `<shell> -c '<command>'`, whatever the
///   platform. Embedded single quotes are escaped as `'\''`, so the quoted command stays
///   one unambiguous word.
/// - `Exec` is the program followed by its arguments, space-separated, with no trailing
///   space when there are no arguments.
/// - `Script` shows the path in `Debug` form, prefixed by the interpreter when one is set.
/// - `AwsCli` lists the arguments in the order the builder passes them, including the
///   trailing `--output json` that the builder always appends.
pub fn command_to_string(cmd: &Command) -> String {
    match cmd {
        Command::Shell { command, shell } => {
            let quoted = command.replace('\'', r"'\''");
            format!("{shell} -c '{quoted}'")
        }
        Command::Exec { program, args } => {
            let mut line = program.clone();
            for arg in args {
                line.push(' ');
                line.push_str(arg);
            }
            line
        }
        Command::Script { path, interpreter } => {
            if let Some(interp) = interpreter {
                format!("{interp} {path:?}")
            } else {
                format!("{path:?}")
            }
        }
        Command::AwsCli {
            service,
            operation,
            args,
            profile,
            region,
        } => {
            let mut parts = vec!["aws".to_string(), service.clone(), operation.clone()];
            if let Some(prof) = profile {
                parts.push(format!("--profile {prof}"));
            }
            if let Some(reg) = region {
                parts.push(format!("--region {reg}"));
            }
            parts.extend(args.iter().cloned());
            parts.push("--output json".to_string());
            parts.join(" ")
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;
    use std::ffi::OsStr;
    use uuid::Uuid;

    /// Builds a minimal request around `command`. It has no env vars, working dir,
    /// timeout or output log, and default metadata.
    fn create_test_request(command: Command) -> ExecutionRequest {
        ExecutionRequest {
            id: Uuid::new_v4(),
            command,
            env: HashMap::new(),
            working_dir: None,
            timeout_ms: None,
            output_log_path: None,
            metadata: Default::default(),
        }
    }

    /// Returns a built command's arguments as owned strings, for exact comparison.
    fn args_of(cmd: &TokioCommand) -> Vec<String> {
        cmd.as_std()
            .get_args()
            .map(|arg| arg.to_string_lossy().into_owned())
            .collect()
    }

    #[test]
    fn test_build_shell_command() {
        let request = create_test_request(Command::Shell {
            command: "echo hello".to_string(),
            shell: "bash".to_string(),
        });
        let cmd = build_command(&request).unwrap();
        let program = cmd.as_std().get_program();
        #[cfg(unix)]
        assert_eq!(program, "bash");
        #[cfg(windows)]
        assert!(program.to_str().unwrap().contains("bash"));
    }

    #[cfg(unix)]
    #[test]
    fn test_build_shell_command_args() {
        let request = create_test_request(Command::Shell {
            command: "echo hello".to_string(),
            shell: "bash".to_string(),
        });
        let cmd = build_command(&request).unwrap();
        assert_eq!(args_of(&cmd), ["-c", "echo hello"]);
    }

    #[test]
    fn test_build_exec_command() {
        let request = create_test_request(Command::Exec {
            program: "ls".to_string(),
            args: vec!["-la".to_string()],
        });
        let cmd = build_command(&request).unwrap();
        assert_eq!(cmd.as_std().get_program(), "ls");
        assert_eq!(args_of(&cmd), ["-la"]);
    }

    #[test]
    fn test_build_aws_cli_command() {
        let request = create_test_request(Command::AwsCli {
            service: "ec2".to_string(),
            operation: "describe-instances".to_string(),
            args: vec!["--max-items".to_string(), "10".to_string()],
            profile: Some("prod".to_string()),
            region: Some("us-west-2".to_string()),
        });
        let cmd = build_command(&request).unwrap();
        assert_eq!(cmd.as_std().get_program(), "aws");
        assert_eq!(
            args_of(&cmd),
            [
                "ec2",
                "describe-instances",
                "--profile",
                "prod",
                "--region",
                "us-west-2",
                "--max-items",
                "10",
                "--output",
                "json",
            ]
        );
    }

    #[test]
    fn test_command_with_env_vars() {
        let mut env = HashMap::new();
        env.insert("TEST_VAR".to_string(), "test_value".to_string());
        let mut request = create_test_request(Command::Shell {
            command: "echo $TEST_VAR".to_string(),
            shell: "bash".to_string(),
        });
        request.env = env;
        let cmd = build_command(&request).unwrap();
        assert!(cmd
            .as_std()
            .get_envs()
            .any(|(k, v)| k == "TEST_VAR" && v == Some(OsStr::new("test_value"))));
    }

    #[test]
    fn test_build_command_rejects_missing_working_dir() {
        let mut request = create_test_request(Command::Exec {
            program: "ls".to_string(),
            args: vec![],
        });
        request.working_dir = Some("/nonexistent/working/dir".into());
        assert!(build_command(&request).is_err());
    }

    #[test]
    fn test_validate_empty_shell_command() {
        let err = validate_shell_command(" ").unwrap_err();
        assert!(matches!(err, ValidationError::EmptyCommand));
    }

    #[test]
    fn test_validate_empty_program() {
        let err = validate_exec_command("").unwrap_err();
        assert!(matches!(err, ValidationError::EmptyCommand));
    }

    #[test]
    fn test_validate_missing_script() {
        let path = PathBuf::from("/nonexistent/script.sh");
        let err = validate_script_path(&path).unwrap_err();
        assert!(matches!(err, ValidationError::ScriptNotFound(_)));
    }

    #[test]
    fn test_validate_script_path_rejects_directory() {
        let err = validate_script_path(&std::env::temp_dir()).unwrap_err();
        assert!(matches!(err, ValidationError::InvalidCommand(_)));
    }

    #[test]
    fn test_validate_aws_cli_missing_service() {
        let err = validate_aws_cli("", "describe-instances").unwrap_err();
        assert!(matches!(err, ValidationError::MissingField(_)));
    }

    #[test]
    fn test_validate_aws_cli_missing_operation() {
        let err = validate_aws_cli("ec2", "  ").unwrap_err();
        assert!(matches!(err, ValidationError::MissingField(ref f) if f == "operation"));
    }

    #[test]
    fn test_validate_missing_working_dir() {
        let err = validate_working_directory(Path::new("/nonexistent/working/dir")).unwrap_err();
        assert!(matches!(err, ValidationError::WorkingDirNotFound(_)));
    }

    #[test]
    fn test_validate_working_dir_rejects_file() {
        // The running test binary is guaranteed to exist and to be a regular file.
        let exe = std::env::current_exe().unwrap();
        let err = validate_working_directory(&exe).unwrap_err();
        assert!(matches!(err, ValidationError::InvalidWorkingDir(_)));
    }

    #[test]
    fn test_command_to_string_shell() {
        let cmd = Command::Shell {
            command: "ls -la".to_string(),
            shell: "bash".to_string(),
        };
        let s = command_to_string(&cmd);
        assert!(s.contains("bash"));
        assert!(s.contains("ls -la"));
    }

    #[test]
    fn test_command_to_string_exec() {
        let cmd = Command::Exec {
            program: "python3".to_string(),
            args: vec!["script.py".to_string(), "--verbose".to_string()],
        };
        let s = command_to_string(&cmd);
        assert!(s.contains("python3"));
        assert!(s.contains("script.py"));
        assert!(s.contains("--verbose"));
    }

    #[test]
    fn test_command_to_string_script_with_interpreter() {
        let cmd = Command::Script {
            path: PathBuf::from("run.sh"),
            interpreter: Some("bash".to_string()),
        };
        let s = command_to_string(&cmd);
        assert!(s.starts_with("bash "));
        assert!(s.contains("run.sh"));
    }

    #[test]
    fn test_command_to_string_aws_cli() {
        let cmd = Command::AwsCli {
            service: "s3".to_string(),
            operation: "ls".to_string(),
            args: vec![],
            profile: Some("dev".to_string()),
            region: Some("us-east-1".to_string()),
        };
        let s = command_to_string(&cmd);
        assert!(s.contains("aws"));
        assert!(s.contains("s3"));
        assert!(s.contains("ls"));
        assert!(s.contains("--profile dev"));
        assert!(s.contains("--region us-east-1"));
    }
}