use crate::error::{CliError, CliResult};
use std::path::{Component, Path, PathBuf};
/// Maximum length of a single sanitized file name.
///
/// Compared against the UTF-8 *byte* length (`str::len`), not the number of
/// characters.
const MAX_FILENAME_LENGTH: usize = 255;
/// Names that must never be produced as output file names.
///
/// Entries are lowercase; callers compare against a lowercased candidate. Besides
/// `.` and `..`, these are the device names Windows reserves regardless of case,
/// including `COM0`/`LPT0` and the superscript-digit variants (`COM¹`..`LPT³`).
/// The superscript digits count as numeric for `char::is_alphanumeric`, so they
/// survive character filtering.
const RESERVED_FILENAMES: &[&str] = &[
    ".", "..", "con", "prn", "aux", "nul", //
    "com0", "com1", "com2", "com3", "com4", "com5", "com6", "com7", "com8", "com9", //
    "com¹", "com²", "com³", //
    "lpt0", "lpt1", "lpt2", "lpt3", "lpt4", "lpt5", "lpt6", "lpt7", "lpt8", "lpt9", //
    "lpt¹", "lpt²", "lpt³",
];
pub fn validate_output_path(base_dir: &Path, requested_path: &str) -> CliResult<PathBuf> {
if requested_path.contains("..") {
return Err(CliError::SecurityViolation {
reason: format!("Path traversal detected: '{}'", requested_path),
details: "Paths containing '..' are not allowed for security reasons".to_string(),
});
}
let requested = PathBuf::from(requested_path);
if requested.is_absolute() {
return Err(CliError::SecurityViolation {
reason: format!("Absolute path not allowed: '{}'", requested_path),
details: "All output files must use relative paths within the output directory"
.to_string(),
});
}
for component in requested.components() {
if matches!(component, Component::ParentDir) {
return Err(CliError::SecurityViolation {
reason: format!("Path traversal detected: '{}'", requested_path),
details: "Paths containing '..' components are not allowed for security reasons"
.to_string(),
});
}
}
let full_path = base_dir.join(&requested);
let base_canonical = base_dir.canonicalize().map_err(CliError::Io)?;
if full_path.exists() {
let canonical = full_path.canonicalize().map_err(CliError::Io)?;
if !canonical.starts_with(&base_canonical) {
return Err(CliError::SecurityViolation {
reason: format!("Path escapes output directory: '{}'", canonical.display()),
details: format!(
"Resolved path '{}' is outside base directory '{}'",
canonical.display(),
base_canonical.display()
),
});
}
return Ok(canonical);
}
let relative_to_base =
full_path
.strip_prefix(base_dir)
.map_err(|_| CliError::SecurityViolation {
reason: "Internal error: path not relative to base".to_string(),
details: "Path validation failed unexpectedly".to_string(),
})?;
Ok(base_canonical.join(relative_to_base))
}
/// Reduces an untrusted `name` to a conservative, portable file name.
///
/// All characters other than alphanumerics, `-`, `_` and `.` are dropped, so
/// path separators never survive. The result is rejected when it:
///
/// * is empty (before or after filtering),
/// * contains `..`,
/// * is longer than [`MAX_FILENAME_LENGTH`] bytes,
/// * matches an entry of [`RESERVED_FILENAMES`] case-insensitively, either in
///   full or by the part before its first `.` (Windows treats `con.v1` like
///   `con`),
/// * starts with `.` (hidden files).
///
/// # Errors
///
/// Returns [`CliError::SecurityViolation`] describing the first rule violated.
pub fn sanitize_filename(name: &str) -> CliResult<String> {
    if name.is_empty() {
        return Err(CliError::SecurityViolation {
            reason: "Empty filename".to_string(),
            details: "Filename cannot be empty".to_string(),
        });
    }
    let sanitized: String = name
        .chars()
        .filter(|c| c.is_alphanumeric() || *c == '-' || *c == '_' || *c == '.')
        .collect();
    if sanitized.is_empty() {
        return Err(CliError::SecurityViolation {
            reason: format!("Invalid filename: '{}'", name),
            details: "Filename must contain at least one alphanumeric character".to_string(),
        });
    }
    if sanitized.contains("..") {
        return Err(CliError::SecurityViolation {
            reason: format!("Invalid filename pattern: '{}'", sanitized),
            details: "Filenames containing '..' patterns are not allowed".to_string(),
        });
    }
    // `len()` counts UTF-8 bytes. `is_alphanumeric` admits non-ASCII letters
    // that take several bytes each, so report bytes, not characters.
    if sanitized.len() > MAX_FILENAME_LENGTH {
        return Err(CliError::SecurityViolation {
            reason: format!("Filename too long: {} bytes", sanitized.len()),
            details: format!("Filename must be at most {} bytes", MAX_FILENAME_LENGTH),
        });
    }
    let lower = sanitized.to_lowercase();
    // Windows reserves device names even when an extension follows them
    // (`nul.txt` is `nul`), so check the part before the first dot too.
    let stem = lower.split('.').next().unwrap_or("");
    if RESERVED_FILENAMES.contains(&lower.as_str()) || RESERVED_FILENAMES.contains(&stem) {
        return Err(CliError::SecurityViolation {
            reason: format!("Reserved filename: '{}'", sanitized),
            details: "This filename is reserved by the operating system".to_string(),
        });
    }
    // Enforce the rule the message states. The previous `len() <= 2` guard
    // only caught names like `.a` and let `.bashrc` through.
    if sanitized.starts_with('.') {
        return Err(CliError::SecurityViolation {
            reason: format!("Invalid filename: '{}'", sanitized),
            details: "Filenames starting with '.' are not allowed".to_string(),
        });
    }
    Ok(sanitized)
}
pub fn safe_output_path(base_dir: &Path, name: &str, extension: &str) -> CliResult<PathBuf> {
let sanitized = sanitize_filename(name)?;
let filename = if extension.is_empty() {
sanitized
} else {
format!("{}.{}", sanitized, extension)
};
validate_output_path(base_dir, &filename)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Fresh scratch directory that is deleted when the returned guard drops.
    fn scratch() -> TempDir {
        TempDir::new().unwrap()
    }

    #[test]
    fn test_sanitize_valid_filenames() {
        // Already-clean names must come back unchanged.
        for &name in &["my_tool", "tool-123", "tool.v1", "Tool_Name_123"] {
            assert_eq!(sanitize_filename(name).unwrap(), name);
        }
    }

    #[test]
    fn test_sanitize_removes_unsafe_chars() {
        let cases = [
            ("my/tool", "mytool"),
            ("my\\tool", "mytool"),
            ("tool:name", "toolname"),
            ("tool*name", "toolname"),
        ];
        for &(raw, expected) in &cases {
            assert_eq!(sanitize_filename(raw).unwrap(), expected);
        }
    }

    #[test]
    fn test_sanitize_rejects_reserved_names() {
        let reserved = [".", "..", "con", "CON", "prn", "aux", "nul", "com1", "lpt1"];
        for &name in &reserved {
            assert!(sanitize_filename(name).is_err());
        }
    }

    #[test]
    fn test_sanitize_rejects_empty() {
        // Inputs that are empty, or become empty once filtered.
        for &name in &["", "///", "***"] {
            assert!(sanitize_filename(name).is_err());
        }
    }

    #[test]
    fn test_validate_accepts_relative_paths() {
        let dir = scratch();
        let base = dir.path();
        assert!(validate_output_path(base, "tool.json").is_ok());
        fs::create_dir_all(base.join("subdir")).unwrap();
        assert!(validate_output_path(base, "subdir/tool.json").is_ok());
    }

    #[test]
    fn test_validate_rejects_absolute_paths() {
        let dir = scratch();
        let base = dir.path();
        for &absolute in &["/etc/passwd", "/tmp/evil"] {
            assert!(validate_output_path(base, absolute).is_err());
        }
        #[cfg(windows)]
        {
            assert!(validate_output_path(base, "C:\\Windows\\System32").is_err());
        }
    }

    #[test]
    fn test_validate_rejects_parent_directory() {
        let dir = scratch();
        let attempts = [
            "..",
            "../etc/passwd",
            "../../.ssh/authorized_keys",
            "subdir/../../../etc/passwd",
        ];
        assert!(attempts
            .iter()
            .all(|attempt| validate_output_path(dir.path(), attempt).is_err()));
    }

    #[test]
    fn test_validate_handles_existing_files() {
        let dir = scratch();
        let base = dir.path();
        fs::write(base.join("test.json"), "{}").unwrap();
        assert!(validate_output_path(base, "test.json").is_ok());
    }

    #[test]
    fn test_validate_handles_nonexistent_files() {
        let dir = scratch();
        let base = dir.path();
        for &fresh in &["new_file.json", "newdir/file.json"] {
            assert!(validate_output_path(base, fresh).is_ok());
        }
    }

    #[test]
    fn test_safe_output_path_integration() {
        let dir = scratch();
        let base = dir.path();
        let built = safe_output_path(base, "my_tool", "json").unwrap();
        assert!(built.ends_with("my_tool.json"));
        assert!(
            safe_output_path(base, "../../../etc/passwd", "json").is_err(),
            "Should reject path traversal attempts"
        );
    }

    #[test]
    fn test_comprehensive_attack_scenarios() {
        let dir = scratch();
        let base = dir.path();
        let base_canonical = base.canonicalize().unwrap();
        let malicious_inputs = [
            "../../../etc/passwd",
            "../../.ssh/authorized_keys",
            "../../../.bash_history",
            "/etc/shadow",
            "../../../../../../../../etc/passwd",
            "..\\..\\..\\windows\\system32",
            "subdir/../../etc/passwd",
        ];
        for &input in &malicious_inputs {
            assert!(
                validate_output_path(base, input).is_err(),
                "Should reject malicious path directly: {}",
                input
            );
            // Whatever survives sanitization must still resolve inside the base.
            let confined = sanitize_filename(input)
                .ok()
                .and_then(|clean| validate_output_path(base, &clean).ok());
            if let Some(path) = confined {
                assert!(
                    path.starts_with(&base_canonical),
                    "Sanitized path must be within base dir: {} -> {} (base: {})",
                    input,
                    path.display(),
                    base_canonical.display()
                );
            }
        }
    }
}