use crate::utils::error::{Error, Result};
use std::path::Path;
use std::process::{Command, Output};
#[derive(Debug, thiserror::Error)]
pub enum CommandError {
#[error("Invalid command: {0}")]
InvalidCommand(String),
#[error("Command not allowed: {0}")]
NotAllowed(String),
#[error("Command execution failed: {0}")]
ExecutionFailed(String),
#[error("Invalid argument: {0}")]
InvalidArgument(String),
#[error("Command output invalid UTF-8")]
InvalidUtf8,
}
impl From<CommandError> for Error {
fn from(err: CommandError) -> Self {
Error::new(&err.to_string())
}
}
/// Builder for running an allow-listed program with screened arguments.
///
/// The program is spawned directly through [`std::process::Command`] (no shell is involved).
/// Both the program name and every argument are checked against `DANGEROUS_CHARS` as a
/// defence-in-depth measure.
#[derive(Clone, Debug)]
pub struct SafeCommand {
    program: String,
    args: Vec<String>,
    // Stored as a path (not a `String`) so non-UTF-8 directories are passed through exactly,
    // instead of being mangled by a lossy UTF-8 conversion.
    current_dir: Option<std::path::PathBuf>,
}

impl SafeCommand {
    /// Program names that `new` accepts.
    const ALLOWED_COMMANDS: &'static [&'static str] =
        &["git", "cargo", "npm", "node", "rustc", "rustup", "ls", "echo"];

    /// Characters rejected in the program name and in arguments.
    ///
    /// Since no shell is invoked, the shell metacharacters here are rejected as defence in
    /// depth. `'\0'` is rejected because the OS cannot carry it inside an argument; without this
    /// check it would only surface later as a spawn failure.
    const DANGEROUS_CHARS: &'static [char] = &[
        ';', '|', '&', '$', '`', '\n', '\r', '<', '>', '(', ')', '{', '}', '\0',
    ];

    /// Returns `true` if `s` contains any character from `DANGEROUS_CHARS`.
    fn has_dangerous_chars(s: &str) -> bool {
        s.chars().any(|c| Self::DANGEROUS_CHARS.contains(&c))
    }

    /// Creates a command for `program`, which must be on the allow-list.
    ///
    /// # Errors
    /// - `CommandError::InvalidCommand` if `program` is empty or contains a rejected character.
    /// - `CommandError::NotAllowed` if `program` is not in `ALLOWED_COMMANDS`.
    pub fn new(program: &str) -> Result<Self> {
        if program.is_empty() {
            return Err(CommandError::InvalidCommand("Empty command".to_string()).into());
        }
        if Self::has_dangerous_chars(program) {
            return Err(CommandError::InvalidCommand(format!(
                "Command contains dangerous characters: {}",
                program
            ))
            .into());
        }
        if !Self::ALLOWED_COMMANDS.contains(&program) {
            return Err(CommandError::NotAllowed(format!(
                "Command '{}' is not in allowed list",
                program
            ))
            .into());
        }
        Ok(Self {
            program: program.to_string(),
            args: Vec::new(),
            current_dir: None,
        })
    }

    /// Appends one argument.
    ///
    /// # Errors
    /// `CommandError::InvalidArgument` if `arg` contains a rejected character.
    pub fn arg(mut self, arg: &str) -> Result<Self> {
        if Self::has_dangerous_chars(arg) {
            return Err(CommandError::InvalidArgument(format!(
                "Argument contains dangerous characters: {}",
                arg
            ))
            .into());
        }
        self.args.push(arg.to_string());
        Ok(self)
    }

    /// Appends each argument in order, validating it like [`SafeCommand::arg`].
    ///
    /// # Errors
    /// Stops at, and returns the error for, the first rejected argument.
    pub fn args<I, S>(mut self, args: I) -> Result<Self>
    where
        I: IntoIterator<Item = S>,
        S: AsRef<str>,
    {
        for arg in args {
            self = self.arg(arg.as_ref())?;
        }
        Ok(self)
    }

    /// Sets the working directory the program will run in.
    ///
    /// # Errors
    /// Returns an error if `dir` does not exist or is not a directory (checked at call time).
    pub fn current_dir(mut self, dir: &Path) -> Result<Self> {
        if !dir.exists() {
            return Err(Error::new(&format!(
                "Directory does not exist: {}",
                dir.display()
            )));
        }
        if !dir.is_dir() {
            return Err(Error::new(&format!(
                "Path is not a directory: {}",
                dir.display()
            )));
        }
        self.current_dir = Some(dir.to_path_buf());
        Ok(self)
    }

    /// Builds the `std::process::Command` described by this builder.
    fn build(&self) -> Command {
        let mut cmd = Command::new(&self.program);
        cmd.args(&self.args);
        if let Some(dir) = &self.current_dir {
            cmd.current_dir(dir);
        }
        cmd
    }

    /// Runs the program to completion and collects its output without consuming `self`.
    fn spawn_output(&self) -> Result<Output> {
        let output = self
            .build()
            .output()
            .map_err(|e| CommandError::ExecutionFailed(format!("{}: {}", self.program, e)))?;
        Ok(output)
    }

    /// Runs the program to completion and returns its captured output.
    ///
    /// A non-zero exit status is *not* treated as an error here; inspect `Output::status`.
    ///
    /// # Errors
    /// `CommandError::ExecutionFailed` if the process could not be spawned.
    pub fn execute(self) -> Result<Output> {
        self.spawn_output()
    }

    /// Runs the program and returns its stdout as a `String`.
    ///
    /// # Errors
    /// - `CommandError::ExecutionFailed` if spawning fails, or if the exit status is
    ///   unsuccessful (the message includes stderr, lossily decoded).
    /// - `CommandError::InvalidUtf8` if stdout is not valid UTF-8.
    pub fn execute_stdout(self) -> Result<String> {
        let output = self.spawn_output()?;
        if !output.status.success() {
            let stderr = String::from_utf8_lossy(&output.stderr);
            return Err(
                CommandError::ExecutionFailed(format!("{}: {}", self.program, stderr)).into(),
            );
        }
        String::from_utf8(output.stdout).map_err(|_| CommandError::InvalidUtf8.into())
    }
}
/// Shortcuts for running common allow-listed tools through [`SafeCommand`].
pub struct CommandExecutor;

impl CommandExecutor {
    /// Runs `git` with `args` (each validated by [`SafeCommand::arg`]) and returns its output.
    ///
    /// # Errors
    /// Propagates validation and spawn errors from [`SafeCommand`]; a non-zero exit status is
    /// not an error.
    pub fn git(args: &[&str]) -> Result<Output> {
        Self::run("git", args)
    }

    /// Runs `cargo` with `args`; same contract as [`CommandExecutor::git`].
    pub fn cargo(args: &[&str]) -> Result<Output> {
        Self::run("cargo", args)
    }

    /// Runs `npm` with `args`; same contract as [`CommandExecutor::git`].
    pub fn npm(args: &[&str]) -> Result<Output> {
        Self::run("npm", args)
    }

    /// Shared implementation: validate `program` and `args`, then execute.
    fn run(program: &str, args: &[&str]) -> Result<Output> {
        SafeCommand::new(program)?
            .args(args.iter().copied())?
            .execute()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_safe_command_new_validates_whitelist() {
        assert!(SafeCommand::new("git").is_ok());
        assert!(SafeCommand::new("cargo").is_ok());
        assert!(SafeCommand::new("rm").is_err());
        assert!(SafeCommand::new("sh").is_err());
        assert!(SafeCommand::new("").is_err());
    }

    #[test]
    fn test_safe_command_rejects_dangerous_chars() {
        assert!(SafeCommand::new("git; rm -rf /").is_err());
        assert!(SafeCommand::new("git | cat").is_err());
        assert!(SafeCommand::new("git && ls").is_err());
        let cmd1 = SafeCommand::new("git").unwrap();
        assert!(cmd1.arg("init; rm -rf /").is_err());
        let cmd2 = SafeCommand::new("git").unwrap();
        assert!(cmd2.arg("init | cat").is_err());
    }

    #[test]
    fn test_safe_command_arg_validation() {
        let cmd = SafeCommand::new("git").unwrap();
        assert!(cmd.clone().arg("init").is_ok());
        assert!(cmd.clone().arg("status").is_ok());
        assert!(cmd.clone().arg("init; ls").is_err());
        assert!(cmd.clone().arg("$(whoami)").is_err());
        assert!(cmd.clone().arg("`whoami`").is_err());
    }

    #[test]
    fn test_args_validates_each_argument() {
        let cmd = SafeCommand::new("git").unwrap();
        assert!(cmd.clone().args(["status", "--short"]).is_ok());
        // A single bad argument anywhere in the list rejects the whole call.
        assert!(cmd.args(["status", "a && b"]).is_err());
    }

    #[test]
    fn test_current_dir_validation() {
        let cmd = SafeCommand::new("git").unwrap();
        assert!(cmd.clone().current_dir(&std::env::temp_dir()).is_ok());
        assert!(cmd
            .clone()
            .current_dir(std::path::Path::new(
                "/this/path/should/not/exist/safe_command_test"
            ))
            .is_err());
        // The test binary itself exists but is a file, not a directory.
        let exe = std::env::current_exe().unwrap();
        assert!(cmd.current_dir(&exe).is_err());
    }

    #[test]
    fn test_command_injection_prevention() {
        let result = SafeCommand::new("git")
            .unwrap()
            .arg("init")
            .unwrap()
            .arg("; rm -rf /");
        assert!(result.is_err());
        let result = SafeCommand::new("git; whoami");
        assert!(result.is_err());
    }

    #[cfg(unix)]
    #[test]
    fn test_execute_stdout_captures_output() {
        let out = SafeCommand::new("echo")
            .unwrap()
            .arg("hello")
            .unwrap()
            .execute_stdout()
            .unwrap();
        assert_eq!(out.trim_end(), "hello");
    }

    #[test]
    fn test_executor_git() {
        // git may be absent on some machines (spawn fails); only check the result when it ran.
        if let Ok(output) = CommandExecutor::git(&["--version"]) {
            assert!(output.status.success());
        }
    }
}