use crate::utils::error::{Error, Result};
use std::path::{Path, PathBuf};
/// Reasons an input can be rejected by the validators in this module.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, thiserror::Error)]
pub enum ValidationError {
/// The path is unacceptable for a reason other than traversal
/// (for example a missing or disallowed file extension).
#[error("Invalid path: {0}")]
InvalidPath(String),
/// The path contains a rejected component or resolves outside its base directory.
#[error("Path traversal detected: {0}")]
PathTraversal(String),
/// An environment variable was rejected; the payload describes why.
#[error("Invalid environment variable: {0}")]
InvalidEnvVar(String),
/// The input exceeds a length limit: `(actual length, maximum allowed)`.
#[error("Input too long: {0} (max: {1})")]
TooLong(usize, usize),
/// The input contains characters the validator does not permit.
#[error("Invalid characters: {0}")]
InvalidCharacters(String),
/// The input was empty where a non-empty value is required.
#[error("Empty input")]
EmptyInput,
}
impl From<ValidationError> for Error {
fn from(err: ValidationError) -> Self {
Error::new(&err.to_string())
}
}
/// Stateless namespace for filesystem path checks.
pub struct PathValidator;

impl PathValidator {
    /// Substrings that cause a path component to be rejected when they occur
    /// anywhere inside it (parent-directory marker, home shortcut, shell
    /// metacharacters and line breaks).
    const DANGEROUS_COMPONENTS: &'static [&'static str] = &[
        "..", "~", "$", "`", "|", ";", "&", "<", ">", "(", ")", "{", "}", "\n", "\r",
    ];

    /// Upper bound, in bytes of the lossy UTF-8 rendering, for an accepted path.
    const MAX_PATH_LENGTH: usize = 4096;

    /// Checks `path` for excessive length and rejected substrings.
    ///
    /// # Errors
    ///
    /// * `ValidationError::TooLong` if the lossy string form is longer than
    ///   `MAX_PATH_LENGTH` bytes.
    /// * `ValidationError::PathTraversal` if any component contains one of
    ///   `DANGEROUS_COMPONENTS`.
    ///
    /// On success an owned copy of `path` is returned unchanged.
    pub fn validate(path: &Path) -> Result<PathBuf> {
        let byte_len = path.to_string_lossy().len();
        if byte_len > Self::MAX_PATH_LENGTH {
            return Err(ValidationError::TooLong(byte_len, Self::MAX_PATH_LENGTH).into());
        }

        // Locate the first component that contains any rejected substring.
        let offending = path
            .components()
            .map(|part| part.as_os_str().to_string_lossy())
            .find(|text| {
                Self::DANGEROUS_COMPONENTS
                    .iter()
                    .any(|needle| text.contains(needle))
            });

        match offending {
            Some(text) => {
                let detail = format!("Path contains dangerous component: {}", text);
                Err(ValidationError::PathTraversal(detail).into())
            }
            None => Ok(path.to_path_buf()),
        }
    }

    /// Validates `path` and additionally requires it to live under `base`.
    ///
    /// Relative paths are joined onto `base` before the containment check;
    /// absolute paths are compared as given. The comparison is purely lexical
    /// (`Path::starts_with`); the filesystem is not consulted.
    ///
    /// # Errors
    ///
    /// Propagates every error of [`PathValidator::validate`], and returns
    /// `ValidationError::PathTraversal` when the resolved path is not prefixed
    /// by `base`.
    ///
    /// On success the validated path is returned in its original
    /// (possibly relative) form, not the joined form.
    pub fn validate_within(path: &Path, base: &Path) -> Result<PathBuf> {
        let validated = Self::validate(path)?;

        let resolved = if validated.is_absolute() {
            validated.clone()
        } else {
            base.join(&validated)
        };

        if resolved.starts_with(base) {
            return Ok(validated);
        }

        let detail = format!(
            "Path escapes base directory: {} not in {}",
            resolved.display(),
            base.display()
        );
        Err(ValidationError::PathTraversal(detail).into())
    }

    /// Requires `path` to carry a UTF-8 extension that appears in `allowed`.
    ///
    /// The comparison is exact (case-sensitive) and `allowed` entries are
    /// given without a leading dot.
    ///
    /// # Errors
    ///
    /// `ValidationError::InvalidPath` when the extension is missing, is not
    /// valid UTF-8, or is not listed in `allowed`.
    pub fn validate_extension(path: &Path, allowed: &[&str]) -> Result<()> {
        let ext = match path.extension().and_then(|raw| raw.to_str()) {
            Some(ext) => ext,
            None => {
                return Err(ValidationError::InvalidPath("No file extension".to_string()).into());
            }
        };

        if allowed.iter().any(|candidate| *candidate == ext) {
            Ok(())
        } else {
            let detail = format!("Extension '{}' not in allowed list", ext);
            Err(ValidationError::InvalidPath(detail).into())
        }
    }
}
/// Stateless namespace for environment-variable name and value checks.
pub struct EnvVarValidator;

impl EnvVarValidator {
    /// Characters that are never accepted in a name or value: shell
    /// metacharacters, line breaks, the backslash, and the NUL character.
    ///
    /// NUL is included because it cannot be represented in a process
    /// environment; `std::env::set_var` is documented to panic when a value
    /// contains it, so it must be rejected here rather than reach that call.
    const DANGEROUS_CHARS: &'static [char] = &[
        ';', '|', '&', '$', '`', '\n', '\r', '<', '>', '(', ')', '{', '}', '\\', '\0',
    ];

    /// Upper bound, in bytes, for an accepted name or value.
    const MAX_ENV_LENGTH: usize = 32768;

    /// Validates an environment variable name.
    ///
    /// A name must be non-empty, at most `MAX_ENV_LENGTH` bytes, and consist
    /// only of alphanumeric characters and `_`.
    ///
    /// # Errors
    ///
    /// * `ValidationError::EmptyInput` for an empty name.
    /// * `ValidationError::TooLong` when the byte length exceeds the limit.
    /// * `ValidationError::InvalidCharacters` when the name contains one of
    ///   `DANGEROUS_CHARS` or any other character that is neither
    ///   alphanumeric nor `_`.
    pub fn validate_name(name: &str) -> Result<String> {
        if name.is_empty() {
            return Err(ValidationError::EmptyInput.into());
        }
        if name.len() > Self::MAX_ENV_LENGTH {
            return Err(ValidationError::TooLong(name.len(), Self::MAX_ENV_LENGTH).into());
        }
        // Checked separately from the allow-list below so that injection
        // attempts get the more specific "dangerous characters" message.
        if name.chars().any(|c| Self::DANGEROUS_CHARS.contains(&c)) {
            return Err(ValidationError::InvalidCharacters(format!(
                "Environment variable name contains dangerous characters: {}",
                name
            ))
            .into());
        }
        if !name.chars().all(|c| c.is_alphanumeric() || c == '_') {
            return Err(ValidationError::InvalidCharacters(format!(
                "Environment variable name must be alphanumeric: {}",
                name
            ))
            .into());
        }
        Ok(name.to_string())
    }

    /// Validates an environment variable value.
    ///
    /// Empty values are accepted. The value must be at most `MAX_ENV_LENGTH`
    /// bytes and must not contain any of `DANGEROUS_CHARS` (which includes
    /// the NUL character).
    ///
    /// # Errors
    ///
    /// * `ValidationError::TooLong` when the byte length exceeds the limit.
    /// * `ValidationError::InvalidCharacters` when a rejected character is
    ///   present. The value itself is deliberately left out of the message.
    pub fn validate_value(value: &str) -> Result<String> {
        if value.len() > Self::MAX_ENV_LENGTH {
            return Err(ValidationError::TooLong(value.len(), Self::MAX_ENV_LENGTH).into());
        }
        if value.chars().any(|c| Self::DANGEROUS_CHARS.contains(&c)) {
            return Err(ValidationError::InvalidCharacters(
                "Environment variable value contains dangerous characters".to_string(),
            )
            .into());
        }
        Ok(value.to_string())
    }
}
/// Stateless namespace for generic string-input checks.
pub struct InputValidator;

impl InputValidator {
    /// Validates `input` against a length limit and a per-character predicate.
    ///
    /// * `max_length` is compared against the byte length of `input`.
    /// * `allowed_chars` must return `true` for every character of `input`.
    ///
    /// # Errors
    ///
    /// * `ValidationError::EmptyInput` for an empty string.
    /// * `ValidationError::TooLong` when `input.len() > max_length`.
    /// * `ValidationError::InvalidCharacters` when any character fails
    ///   `allowed_chars`.
    pub fn validate_string(
        input: &str,
        max_length: usize,
        allowed_chars: fn(char) -> bool,
    ) -> Result<String> {
        if input.is_empty() {
            return Err(ValidationError::EmptyInput.into());
        }
        if input.len() > max_length {
            return Err(ValidationError::TooLong(input.len(), max_length).into());
        }
        if !input.chars().all(allowed_chars) {
            return Err(ValidationError::InvalidCharacters(
                "Input contains invalid characters".to_string(),
            )
            .into());
        }
        Ok(input.to_string())
    }

    /// Validates an identifier: 1 to 256 bytes of alphanumerics, `_` or `-`.
    ///
    /// # Errors
    ///
    /// See [`InputValidator::validate_string`].
    pub fn validate_identifier(input: &str) -> Result<String> {
        Self::validate_string(input, 256, |c| c.is_alphanumeric() || c == '_' || c == '-')
    }

    /// Validates a template name: 1 to 256 bytes of alphanumerics, `_`, `-`,
    /// `.` or `/`, with no `..` path segment.
    ///
    /// Because both `.` and `/` are permitted, the character allow-list alone
    /// would accept names such as `../../etc/passwd`. Any `/`-separated
    /// segment that is exactly `..` is therefore rejected explicitly.
    ///
    /// # Errors
    ///
    /// * Everything [`InputValidator::validate_string`] can return.
    /// * `ValidationError::PathTraversal` when a segment equals `..`.
    pub fn validate_template_name(input: &str) -> Result<String> {
        let name = Self::validate_string(input, 256, |c| {
            c.is_alphanumeric() || c == '_' || c == '-' || c == '.' || c == '/'
        })?;
        if name.split('/').any(|segment| segment == "..") {
            return Err(ValidationError::PathTraversal(format!(
                "Template name contains parent-directory segment: {}",
                name
            ))
            .into());
        }
        Ok(name)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Traversal-style paths are rejected; ordinary relative paths pass.
    #[test]
    fn test_path_traversal_detection() {
        let rejected = &[
            "../../../etc/passwd",
            "../../.ssh/id_rsa",
            "..\\..\\windows\\system32",
        ];
        for candidate in rejected {
            assert!(PathValidator::validate(Path::new(candidate)).is_err());
        }

        let accepted = &["src/main.rs", "templates/rust.tmpl"];
        for candidate in accepted {
            assert!(PathValidator::validate(Path::new(candidate)).is_ok());
        }
    }

    /// Paths beyond the length limit are rejected; short ones pass.
    #[test]
    fn test_path_length_validation() {
        let oversized = "a/".repeat(3000);
        let too_long = PathValidator::validate(Path::new(&oversized));
        assert!(too_long.is_err());

        let short = PathValidator::validate(Path::new("src/lib.rs"));
        assert!(short.is_ok());
    }

    /// Plain names pass; empty names and shell-injection attempts fail.
    #[test]
    fn test_env_var_name_validation() {
        for name in &["PATH", "MY_VAR_123"] {
            assert!(EnvVarValidator::validate_name(name).is_ok());
        }
        for name in &["", "VAR; rm -rf /", "VAR|cat", "$(whoami)"] {
            assert!(EnvVarValidator::validate_name(name).is_err());
        }
    }

    /// Plain values pass; values carrying shell metacharacters fail.
    #[test]
    fn test_env_var_value_validation() {
        for value in &["value", "/usr/bin"] {
            assert!(EnvVarValidator::validate_value(value).is_ok());
        }
        for value in &["value; rm -rf /", "$(whoami)", "`whoami`"] {
            assert!(EnvVarValidator::validate_value(value).is_err());
        }
    }

    /// Identifiers allow alphanumerics, `_` and `-` only, and must be non-empty.
    #[test]
    fn test_identifier_validation() {
        for ident in &["my_var", "my-var-123"] {
            assert!(InputValidator::validate_identifier(ident).is_ok());
        }
        for ident in &["", "my var", "my;var"] {
            assert!(InputValidator::validate_identifier(ident).is_err());
        }
    }

    /// Template names additionally allow `.` and `/`.
    #[test]
    fn test_template_name_validation() {
        for name in &["rust-cli", "templates/rust.tmpl"] {
            assert!(InputValidator::validate_template_name(name).is_ok());
        }
        for name in &["", "template; rm -rf /"] {
            assert!(InputValidator::validate_template_name(name).is_err());
        }
    }
}