use anyhow::{anyhow, Result};
use std::path::PathBuf;
pub fn validate_file_path(path: &str, allowed_base: Option<&str>) -> Result<PathBuf> {
let path_buf = PathBuf::from(path);
if let Some(base) = allowed_base {
let canon_path = std::fs::canonicalize(&path_buf)
.map_err(|e| anyhow!("Failed to canonicalize path '{}': {}", path, e))?;
let canon_base = std::fs::canonicalize(base)
.map_err(|e| anyhow!("Failed to canonicalize base '{}': {}", base, e))?;
if !canon_path.starts_with(canon_base) {
return Err(anyhow!(
"Path '{}' is outside allowed directory '{}'",
path,
base
));
}
Ok(canon_path)
} else {
std::fs::canonicalize(&path_buf).map_err(|e| anyhow!("Invalid path '{}': {}", path, e))
}
}
pub fn validate_command(command: &str, allowed_commands: &[&str]) -> Result<()> {
const DANGEROUS_CHARS: &[char] = &[
'|', '&', ';', '\n', '`', '$', '(', ')', '<', '>', '"', '\'', '\\',
];
if command.chars().any(|c| DANGEROUS_CHARS.contains(&c)) {
return Err(anyhow!(
"Command contains dangerous characters: '{}'",
command
));
}
let path_buf = PathBuf::from(command);
let cmd_name = path_buf
.file_name()
.and_then(|s| s.to_str())
.unwrap_or(command);
if !allowed_commands.is_empty() && !allowed_commands.contains(&cmd_name) {
return Err(anyhow!(
"Command '{}' is not in the allowed list. Permitted commands: {:?}",
cmd_name,
allowed_commands
));
}
Ok(())
}
pub fn validate_command_args(args: &[String]) -> Result<()> {
for arg in args {
if arg.contains("&&") || arg.contains("||") || arg.contains(";") || arg.contains("|") {
return Err(anyhow!(
"Argument contains dangerous shell operators: '{}'",
arg
));
}
if arg.contains("$(") || arg.contains("`") {
return Err(anyhow!("Argument contains command substitution: '{}'", arg));
}
}
Ok(())
}
pub fn validate_url_security(url: &str, require_tls: bool) -> Result<()> {
let url_lower = url.to_lowercase();
if require_tls {
if !(url_lower.starts_with("https://")
|| url_lower.starts_with("wss://")
|| url_lower.starts_with("grpcs://"))
{
return Err(anyhow!(
"URL must use TLS/SSL (https://, wss://, grpcs://): '{}'",
url
));
}
}
if url_lower.contains("localhost") || url_lower.contains("127.0.0.1") {
eprintln!("Warning: URL uses localhost/127.0.0.1: '{}'", url);
}
Ok(())
}
pub fn validate_size_limit(data: &[u8], max_size: usize) -> Result<()> {
if data.len() > max_size {
return Err(anyhow!(
"Data size {} bytes exceeds maximum allowed size {} bytes",
data.len(),
max_size
));
}
Ok(())
}
pub fn validate_timeout(timeout_ms: u64, max_timeout_ms: u64) -> Result<()> {
if timeout_ms == 0 {
return Err(anyhow!("Timeout cannot be zero"));
}
if timeout_ms > max_timeout_ms {
return Err(anyhow!(
"Timeout {}ms exceeds maximum allowed {}ms",
timeout_ms,
max_timeout_ms
));
}
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::TempDir;
#[test]
fn test_validate_command_rejects_dangerous_chars() {
let allowed = vec!["python3", "node"];
assert!(validate_command("python3", &allowed).is_ok());
assert!(validate_command("ls; rm -rf /", &[]).is_err());
assert!(validate_command("cat /etc/passwd | grep root", &[]).is_err());
assert!(validate_command("echo `whoami`", &[]).is_err());
assert!(validate_command("cmd && evil", &[]).is_err());
}
#[test]
fn test_validate_command_allowlist() {
let allowed = vec!["python3", "node", "npm"];
assert!(validate_command("python3", &allowed).is_ok());
assert!(validate_command("node", &allowed).is_ok());
assert!(validate_command("bash", &allowed).is_err());
assert!(validate_command("/usr/bin/python3", &allowed).is_ok()); }
#[test]
fn test_validate_command_args() {
assert!(validate_command_args(&["--help".to_string()]).is_ok());
assert!(validate_command_args(&["-v".to_string(), "file.txt".to_string()]).is_ok());
assert!(validate_command_args(&["arg && evil".to_string()]).is_err());
assert!(validate_command_args(&["$(whoami)".to_string()]).is_err());
assert!(validate_command_args(&["`id`".to_string()]).is_err());
assert!(validate_command_args(&["arg | grep".to_string()]).is_err());
}
#[test]
fn test_validate_url_security() {
assert!(validate_url_security("https://api.example.com", true).is_ok());
assert!(validate_url_security("wss://ws.example.com", true).is_ok());
assert!(validate_url_security("http://api.example.com", true).is_err());
assert!(validate_url_security("http://api.example.com", false).is_ok());
}
#[test]
fn test_validate_size_limit() {
let data = vec![0u8; 1000];
assert!(validate_size_limit(&data, 2000).is_ok());
assert!(validate_size_limit(&data, 500).is_err());
}
#[test]
fn test_validate_timeout() {
assert!(validate_timeout(1000, 60000).is_ok());
assert!(validate_timeout(0, 60000).is_err());
assert!(validate_timeout(100000, 60000).is_err());
}
#[test]
fn test_validate_file_path() {
let temp_dir = TempDir::new().unwrap();
let temp_path = temp_dir.path();
let test_file = temp_path.join("test.txt");
fs::write(&test_file, b"test").unwrap();
let result = validate_file_path(
test_file.to_str().unwrap(),
Some(temp_path.to_str().unwrap()),
);
assert!(result.is_ok());
let outside_path = "/tmp/outside.txt";
let result = validate_file_path(outside_path, Some(temp_path.to_str().unwrap()));
assert!(result.is_err());
}
}