use crate::utils::error::{Error, Result};
use crate::utils::safe_path::SafePath;
use std::convert::TryFrom;
use std::fmt;
use std::process::Command;
/// Upper bound, in bytes, on the size of a command accepted by
/// `SafeCommand::validate`: the command name plus, for every argument, one
/// separator byte and the argument's own byte length.
const MAX_COMMAND_LENGTH: usize = 4 * 1024;
/// Characters that `CommandArg::new` refuses to accept anywhere in an argument.
///
/// The order matters: the first entry (in this order) found in an argument is
/// the one named in the resulting error message.
const SHELL_METACHARACTERS: &[char] = &[
    ';',  // command separator
    '|',  // pipe
    '&',  // background / logical and
    '>',  // output redirection
    '<',  // input redirection
    '$',  // variable / command substitution
    '`',  // backtick command substitution
    '\n', // line feed
    '\r', // carriage return
];
/// Executable names accepted by `CommandName::new`; every other name is
/// rejected. The list is also printed verbatim (in this order) in the
/// rejection error message.
///
/// NOTE(review): `sh` and `bash` are on this list, and arguments are only
/// screened for the characters in `SHELL_METACHARACTERS`, so an invocation
/// such as `sh -c "<script>"` is not prevented by this module — confirm that
/// this is intended.
const ALLOWED_COMMANDS: &[&str] = &[
    "cargo",
    "git",
    "npm",
    "rustc",
    "rustfmt",
    "clippy-driver",
    "timeout",
    "make",
    "cmake",
    "sh",
    "bash",
    "ggen",
];
/// Name of an executable that has passed the allow-list check.
///
/// The field is private, and the `new` constructor in this file only builds a
/// value after its checks succeed: the name is non-empty, contains no
/// whitespace, and is listed in `ALLOWED_COMMANDS`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CommandName {
// The accepted executable name, stored exactly as it was given.
inner: String,
}
impl CommandName {
pub fn new<S: AsRef<str>>(name: S) -> Result<Self> {
let name = name.as_ref();
if name.is_empty() {
return Err(Error::invalid_input("Command name cannot be empty"));
}
if name.contains(char::is_whitespace) {
return Err(Error::invalid_input(
"Command name cannot contain whitespace",
));
}
if !ALLOWED_COMMANDS.contains(&name) {
return Err(Error::invalid_input(format!(
"Command '{}' is not in whitelist. Allowed commands: {:?}",
name, ALLOWED_COMMANDS
)));
}
Ok(Self {
inner: name.to_string(),
})
}
#[must_use]
pub fn as_str(&self) -> &str {
&self.inner
}
#[must_use]
pub fn into_string(self) -> String {
self.inner
}
}
impl TryFrom<&str> for CommandName {
type Error = Error;
fn try_from(value: &str) -> Result<Self> {
Self::new(value)
}
}
impl TryFrom<String> for CommandName {
type Error = Error;
fn try_from(value: String) -> Result<Self> {
Self::new(value)
}
}
impl fmt::Display for CommandName {
    /// Writes the bare command name, with no quoting or decoration.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.inner)
    }
}
/// A single command-line argument.
///
/// Values are built either by `CommandArg::new`, which rejects any string
/// containing a character from `SHELL_METACHARACTERS`, or by
/// `CommandArg::from_path`, which stores the display form of a `SafePath`
/// without running that character check.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CommandArg {
// The argument text, stored exactly as accepted.
inner: String,
}
impl CommandArg {
    /// Wraps `arg` after checking it for shell metacharacters.
    ///
    /// An empty string is accepted.
    ///
    /// # Errors
    ///
    /// Returns an error built with `Error::invalid_input` when `arg` contains
    /// any character from `SHELL_METACHARACTERS`. The reported character is
    /// the first one in that list's order that occurs in `arg`.
    pub fn new<S: AsRef<str>>(arg: S) -> Result<Self> {
        let raw = arg.as_ref();

        let offending = SHELL_METACHARACTERS
            .iter()
            .find(|&&candidate| raw.contains(candidate));

        if let Some(found) = offending {
            return Err(Error::invalid_input(format!(
                "Argument contains shell metacharacter '{}': {}",
                found, raw
            )));
        }

        Ok(Self {
            inner: raw.to_owned(),
        })
    }

    /// Builds an argument from a [`SafePath`] using the path's display form.
    ///
    /// No metacharacter check is performed here.
    // NOTE(review): `Path::display()` is lossy for non-UTF-8 paths — confirm
    // that `SafePath` guarantees UTF-8 if exact round-tripping matters.
    #[must_use]
    pub fn from_path(path: &SafePath) -> Self {
        let rendered = path.as_path().display().to_string();
        Self { inner: rendered }
    }

    /// Borrows the argument text as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        self.inner.as_str()
    }

    /// Consumes the wrapper and returns the owned argument text.
    #[must_use]
    pub fn into_string(self) -> String {
        let Self { inner } = self;
        inner
    }
}
impl TryFrom<&str> for CommandArg {
type Error = Error;
fn try_from(value: &str) -> Result<Self> {
Self::new(value)
}
}
impl TryFrom<String> for CommandArg {
type Error = Error;
fn try_from(value: String) -> Result<Self> {
Self::new(value)
}
}
impl fmt::Display for CommandArg {
    /// Writes the argument text as stored, with no quoting or escaping.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.inner)
    }
}
/// Typestate marker for a `SafeCommand` that is still being built and has not
/// yet passed `validate`.
///
/// Derives `Clone` (like `Validated`) so that the `Clone` derive on
/// `SafeCommand<State>` also applies to the unvalidated builder.
#[derive(Debug, Clone)]
pub struct Unvalidated;
/// Typestate marker for a `SafeCommand` that has been returned by `validate`.
#[derive(Debug, Clone)]
pub struct Validated;
/// A command name plus its arguments, tagged with a compile-time state.
///
/// `SafeCommand<Unvalidated>` (the default) is the builder form: arguments are
/// appended to it and `validate` turns it into `SafeCommand<Validated>`.
/// Only the validated form offers `into_command`, so a `std::process::Command`
/// cannot be obtained from this type without going through `validate`.
#[derive(Debug, Clone)]
pub struct SafeCommand<State = Unvalidated> {
// Executable to run; already checked against the allow list.
command: CommandName,
// Arguments in the order they were added.
args: Vec<CommandArg>,
// Zero-sized marker carrying the `State` type parameter.
_state: std::marker::PhantomData<State>,
}
impl SafeCommand<Unvalidated> {
pub fn new<S: AsRef<str>>(command: S) -> Result<Self> {
let command = CommandName::new(command)?;
Ok(Self {
command,
args: Vec::new(),
_state: std::marker::PhantomData,
})
}
pub fn arg<S: AsRef<str>>(mut self, arg: S) -> Result<Self> {
let arg = CommandArg::new(arg)?;
self.args.push(arg);
Ok(self)
}
#[must_use]
pub fn arg_path(mut self, path: &SafePath) -> Self {
let arg = CommandArg::from_path(path);
self.args.push(arg);
self
}
pub fn args<I, S>(mut self, args: I) -> Result<Self>
where
I: IntoIterator<Item = S>,
S: AsRef<str>,
{
for arg in args {
let validated_arg = CommandArg::new(arg)?;
self.args.push(validated_arg);
}
Ok(self)
}
pub fn validate(self) -> Result<SafeCommand<Validated>> {
let total_length = self.total_length();
if total_length > MAX_COMMAND_LENGTH {
return Err(Error::invalid_input(format!(
"Command length {} exceeds maximum allowed length of {}",
total_length, MAX_COMMAND_LENGTH
)));
}
Ok(SafeCommand {
command: self.command,
args: self.args,
_state: std::marker::PhantomData,
})
}
fn total_length(&self) -> usize {
let mut length = self.command.as_str().len();
for arg in &self.args {
length += 1; length += arg.as_str().len();
}
length
}
}
impl SafeCommand<Validated> {
#[must_use]
pub fn into_command(self) -> Command {
let mut cmd = Command::new(self.command.as_str());
for arg in self.args {
cmd.arg(arg.as_str());
}
cmd
}
#[must_use]
pub fn command(&self) -> &CommandName {
&self.command
}
#[must_use]
pub fn args(&self) -> &[CommandArg] {
&self.args
}
#[must_use]
pub fn to_string_debug(&self) -> String {
let mut result = self.command.as_str().to_string();
for arg in &self.args {
result.push(' ');
result.push_str(arg.as_str());
}
result
}
}
impl fmt::Display for SafeCommand<Validated> {
    /// Writes the same space-joined text that `to_string_debug` produces.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let rendered = self.to_string_debug();
        f.write_str(&rendered)
    }
}
// Unit tests for command-name allow-listing, argument screening, the
// unvalidated -> validated builder flow, and the total-length limit.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- CommandName ----

    #[test]
    fn test_command_name_valid() {
        let name = "cargo";
        let result = CommandName::new(name);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_str(), "cargo");
    }

    #[test]
    fn test_command_name_all_whitelisted() {
        for &cmd in ALLOWED_COMMANDS {
            let result = CommandName::new(cmd);
            assert!(result.is_ok(), "Should allow whitelisted command: {}", cmd);
        }
    }

    #[test]
    fn test_command_name_not_whitelisted() {
        let dangerous_commands = vec!["rm", "mv", "dd", "mkfs", "kill"];
        for cmd in dangerous_commands {
            let result = CommandName::new(cmd);
            assert!(
                result.is_err(),
                "Should block non-whitelisted command: {}",
                cmd
            );
            assert!(
                result.unwrap_err().to_string().contains("not in whitelist"),
                "Error should mention whitelist"
            );
        }
    }

    #[test]
    fn test_command_name_empty() {
        let name = "";
        let result = CommandName::new(name);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("cannot be empty"));
    }

    #[test]
    fn test_command_name_with_whitespace() {
        let name = "cargo build";
        let result = CommandName::new(name);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("whitespace"));
    }

    #[test]
    fn test_command_name_try_from_str() {
        let name = "git";
        let result = CommandName::try_from(name);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_str(), "git");
    }

    #[test]
    fn test_command_name_try_from_string() {
        let name = String::from("npm");
        let result = CommandName::try_from(name);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_str(), "npm");
    }

    #[test]
    fn test_command_name_display() {
        let name = CommandName::new("cargo").unwrap();
        let display = format!("{}", name);
        assert_eq!(display, "cargo");
    }

    // ---- CommandArg ----

    #[test]
    fn test_command_arg_valid() {
        let arg = "build";
        let result = CommandArg::new(arg);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_str(), "build");
    }

    #[test]
    fn test_command_arg_with_dashes() {
        let args = vec!["--release", "-v", "--all-features"];
        for arg in args {
            let result = CommandArg::new(arg);
            assert!(result.is_ok(), "Should allow arg with dashes: {}", arg);
        }
    }

    #[test]
    fn test_command_arg_empty() {
        let arg = "";
        let result = CommandArg::new(arg);
        assert!(result.is_ok());
    }

    #[test]
    fn test_command_arg_shell_metacharacters() {
        let attacks = vec![
            ("build; rm -rf /", ';'),
            ("build | tee output", '|'),
            ("build && rm -rf /", '&'),
            ("build > /dev/null", '>'),
            ("build < input", '<'),
            ("$(whoami)", '$'),
            ("`whoami`", '`'),
            ("build\nrm -rf /", '\n'),
            ("build\rrm -rf /", '\r'),
        ];
        for (attack, metachar) in attacks {
            let result = CommandArg::new(attack);
            assert!(
                result.is_err(),
                "Should block shell metacharacter: {}",
                metachar
            );
            assert!(
                result.unwrap_err().to_string().contains("metacharacter"),
                "Error should mention metacharacter"
            );
        }
    }

    #[test]
    fn test_command_arg_from_path() {
        let path = SafePath::new("src/generated").unwrap();
        let arg = CommandArg::from_path(&path);
        assert_eq!(arg.as_str(), "src/generated");
    }

    #[test]
    fn test_command_arg_try_from() {
        let arg_str = "--release";
        let result = CommandArg::try_from(arg_str);
        assert!(result.is_ok());
        assert_eq!(result.unwrap().as_str(), "--release");
    }

    // ---- SafeCommand builder ----

    #[test]
    fn test_safe_command_new() {
        let cmd = "cargo";
        let result = SafeCommand::new(cmd);
        assert!(result.is_ok());
    }

    #[test]
    fn test_safe_command_new_invalid() {
        let cmd = "rm";
        let result = SafeCommand::new(cmd);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not in whitelist"));
    }

    #[test]
    fn test_safe_command_single_arg() {
        let result = SafeCommand::new("cargo").unwrap().arg("build");
        assert!(result.is_ok());
    }

    #[test]
    fn test_safe_command_multiple_args() {
        let result = SafeCommand::new("cargo")
            .unwrap()
            .arg("build")
            .unwrap()
            .arg("--release")
            .unwrap()
            .arg("--all-features");
        assert!(result.is_ok());
    }

    #[test]
    fn test_safe_command_args_bulk() {
        let result =
            SafeCommand::new("cargo")
                .unwrap()
                .args(["build", "--release", "--all-features"]);
        assert!(result.is_ok());
    }

    #[test]
    fn test_safe_command_arg_path() {
        let path = SafePath::new("src/generated").unwrap();
        let cmd = SafeCommand::new("cargo")
            .unwrap()
            .arg("build")
            .unwrap()
            .arg_path(&path);
        let validated = cmd.validate();
        assert!(validated.is_ok());
    }

    #[test]
    fn test_safe_command_validate_success() {
        let cmd = SafeCommand::new("cargo")
            .unwrap()
            .arg("build")
            .unwrap()
            .arg("--release")
            .unwrap();
        let result = cmd.validate();
        assert!(result.is_ok());
    }

    #[test]
    fn test_safe_command_into_command() {
        let safe_cmd = SafeCommand::new("cargo")
            .unwrap()
            .arg("--version")
            .unwrap()
            .validate()
            .unwrap();
        let process_cmd = safe_cmd.into_command();

        // Verify the conversion actually carried the program and arguments
        // over, instead of only checking that it does not panic.
        assert_eq!(process_cmd.get_program().to_str(), Some("cargo"));
        let args: Vec<_> = process_cmd.get_args().map(|arg| arg.to_str()).collect();
        assert_eq!(args, vec![Some("--version")]);
    }

    #[test]
    fn test_safe_command_to_string_debug() {
        let cmd = SafeCommand::new("cargo")
            .unwrap()
            .arg("build")
            .unwrap()
            .arg("--release")
            .unwrap()
            .validate()
            .unwrap();
        let debug_string = cmd.to_string_debug();
        assert_eq!(debug_string, "cargo build --release");
    }

    #[test]
    fn test_safe_command_display() {
        let cmd = SafeCommand::new("git")
            .unwrap()
            .arg("status")
            .unwrap()
            .validate()
            .unwrap();
        let display = format!("{}", cmd);
        assert_eq!(display, "git status");
    }

    // ---- Injection attempts ----

    #[test]
    fn test_safe_command_injection_in_arg() {
        let result = SafeCommand::new("cargo").unwrap().arg("build; rm -rf /");
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("metacharacter"));
    }

    #[test]
    fn test_safe_command_injection_in_multiple_args() {
        let result = SafeCommand::new("cargo")
            .unwrap()
            .arg("build")
            .unwrap()
            .arg("--release && rm -rf /");
        assert!(result.is_err());
    }

    // ---- Length limit ----

    #[test]
    fn test_safe_command_max_length() {
        let long_arg = "a".repeat(MAX_COMMAND_LENGTH);
        let result = SafeCommand::new("cargo")
            .unwrap()
            .arg(&long_arg)
            .unwrap()
            .validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
    }

    #[test]
    fn test_safe_command_max_length_boundary() {
        // Total = command name + one separator byte + argument, which lands
        // exactly on the limit.
        let arg_length = MAX_COMMAND_LENGTH - "cargo".len() - 1;
        let exact_arg = "a".repeat(arg_length);
        let result = SafeCommand::new("cargo")
            .unwrap()
            .arg(&exact_arg)
            .unwrap()
            .validate();
        assert!(result.is_ok());
    }

    #[test]
    fn test_safe_command_max_length_one_over_boundary() {
        // One byte more than the exact-fit case must be rejected.
        let arg_length = MAX_COMMAND_LENGTH - "cargo".len();
        let too_long_arg = "a".repeat(arg_length);
        let result = SafeCommand::new("cargo")
            .unwrap()
            .arg(&too_long_arg)
            .unwrap()
            .validate();
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("exceeds maximum"));
    }

    // ---- Accessors and paths ----

    #[test]
    fn test_safe_command_command_and_args_accessors() {
        let cmd = SafeCommand::new("cargo")
            .unwrap()
            .arg("build")
            .unwrap()
            .arg("--release")
            .unwrap()
            .validate()
            .unwrap();
        let command = cmd.command();
        let args = cmd.args();
        assert_eq!(command.as_str(), "cargo");
        assert_eq!(args.len(), 2);
        assert_eq!(args[0].as_str(), "build");
        assert_eq!(args[1].as_str(), "--release");
    }

    #[test]
    fn test_safe_command_with_safe_path() {
        let path = SafePath::new("src/generated/output.rs").unwrap();
        let cmd = SafeCommand::new("rustfmt")
            .unwrap()
            .arg_path(&path)
            .validate()
            .unwrap();
        let debug_string = cmd.to_string_debug();
        assert!(debug_string.contains("src/generated/output.rs"));
    }

    #[test]
    fn test_safe_command_multiple_paths() {
        let path1 = SafePath::new("src/main.rs").unwrap();
        let path2 = SafePath::new("src/lib.rs").unwrap();
        let cmd = SafeCommand::new("rustfmt")
            .unwrap()
            .arg_path(&path1)
            .arg_path(&path2)
            .validate()
            .unwrap();
        let debug_string = cmd.to_string_debug();
        assert!(debug_string.contains("src/main.rs"));
        assert!(debug_string.contains("src/lib.rs"));
    }

    #[test]
    fn test_safe_command_no_args() {
        let result = SafeCommand::new("git").unwrap().validate();
        assert!(result.is_ok());
    }

    #[test]
    fn test_safe_command_many_small_args() {
        let mut cmd = SafeCommand::new("cargo").unwrap();
        for i in 0..100 {
            cmd = cmd.arg(format!("arg{}", i)).unwrap();
        }
        let result = cmd.validate();
        assert!(result.is_ok());
    }

    // ---- Clone ----

    #[test]
    fn test_command_name_clone() {
        let name = CommandName::new("cargo").unwrap();
        let cloned = name.clone();
        assert_eq!(name, cloned);
    }

    #[test]
    fn test_command_arg_clone() {
        let arg = CommandArg::new("build").unwrap();
        let cloned = arg.clone();
        assert_eq!(arg, cloned);
    }

    #[test]
    fn test_safe_command_clone() {
        let cmd = SafeCommand::new("cargo")
            .unwrap()
            .arg("build")
            .unwrap()
            .validate()
            .unwrap();
        let cloned = cmd.clone();
        assert_eq!(cmd.to_string_debug(), cloned.to_string_debug());
    }
}