use crate::error::BuildError;
use once_cell::sync::Lazy;
use regex::Regex;
use std::collections::HashSet;
use std::fs;
use std::path::{Component, Path, PathBuf};
/// Default cap on the byte length of a sanitized path (same value as the Windows `MAX_PATH` constant).
const MAX_PATH_LENGTH: usize = 260;
/// Default cap on the number of components in a normalized path.
const MAX_PATH_DEPTH: usize = 32;
/// `../` or `..\` anywhere in the string.
static DIRECTORY_TRAVERSAL: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"(?i)(\.\./|\.\.\x5c|/\.\./|\x5c\.\.\x5c)").expect("static regex is valid")
});
/// Single- and double-percent-encoded `../` and `..\`.
static ENCODED_TRAVERSAL: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"(?i)(%2e%2e%2f|%2e%2e%5c|%252e%252e%252f|%252e%252e%255c)")
        .expect("static regex is valid")
});
/// Absolute or rooted paths: drive-letter roots with either separator (`C:\`, `C:/`),
/// and anything starting with `/` or `\` (POSIX root, Windows root-relative, UNC `\\server`).
static ABSOLUTE_PATH: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"^(?:[a-zA-Z]:[/\x5c]|[/\x5c])").expect("static regex is valid"));
/// C0 control characters, DEL, C1 control characters, or a literal `%00`.
static DANGEROUS_CHARS: Lazy<Regex> =
    Lazy::new(|| Regex::new(r"[\x00-\x1F\x7F-\x9F]|%00").expect("static regex is valid"));
/// Windows device names, alone or followed by an extension (case-insensitive).
static WINDOWS_RESERVED: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"(?i)^(con|prn|aux|nul|com[1-9]|lpt[1-9])(\.|$)").expect("static regex is valid")
});
/// Executable/script extensions that `collect_warnings` flags.
static SUSPICIOUS_EXTENSIONS: Lazy<Regex> = Lazy::new(|| {
    Regex::new(r"(?i)\.(exe|bat|cmd|com|scr|pif|vbs|js|jar|dll|sys)$")
        .expect("static regex is valid")
});
/// Upper-case Windows device names that may not be used as a file or directory name.
static WINDOWS_RESERVED_NAMES: Lazy<HashSet<&str>> = Lazy::new(|| {
    let names: HashSet<&str> = [
        "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "COM3", "COM4", "COM5", "COM6", "COM7",
        "COM8", "COM9", "LPT1", "LPT2", "LPT3", "LPT4", "LPT5", "LPT6", "LPT7", "LPT8", "LPT9",
    ]
    .iter()
    .copied()
    .collect();
    names
});
/// Limits and policies that control how [`PathValidator`] accepts or rejects paths.
#[derive(Debug, Clone)]
pub struct PathValidationConfig {
    /// Maximum accepted length, in bytes, of the path string after URL decoding and
    /// Unicode normalization.
    pub max_path_length: usize,
    /// Maximum number of path segments, not counting empty and `.` segments.
    pub max_path_depth: usize,
    /// Directories that a validated path must live under (checked against the
    /// normalized relative path).
    pub allowed_base_dirs: Vec<PathBuf>,
    /// When true, relative paths are accepted without consulting `allowed_base_dirs`.
    pub allow_relative_outside_base: bool,
    /// When true, `PathValidator::validate` also runs symlink checks on the path and
    /// on its resolved (canonical) form.
    pub validate_symlinks: bool,
    /// When true, `PathValidator::validate` fails for paths that do not exist on disk.
    pub check_existence: bool,
    /// File extensions (without the dot) that do not produce an "unusual extension"
    /// warning. Entries are matched against the lower-cased extension, so store them
    /// in lower case.
    pub allowed_extensions: HashSet<String>,
    /// When false, any path component starting with `.` is rejected.
    pub allow_hidden: bool,
}
impl Default for PathValidationConfig {
fn default() -> Self {
let mut allowed_extensions = HashSet::new();
allowed_extensions.insert("xml".to_string());
allowed_extensions.insert("json".to_string());
allowed_extensions.insert("txt".to_string());
allowed_extensions.insert("csv".to_string());
Self {
max_path_length: MAX_PATH_LENGTH,
max_path_depth: MAX_PATH_DEPTH,
allowed_base_dirs: vec![
PathBuf::from("data"),
PathBuf::from("input"),
PathBuf::from("output"),
PathBuf::from("temp"),
PathBuf::from("."),
],
allow_relative_outside_base: false,
validate_symlinks: true,
check_existence: false,
allowed_extensions,
allow_hidden: false,
}
}
}
/// Outcome of a successful [`PathValidator::validate`] call.
#[derive(Debug, Clone)]
pub struct ValidatedPath {
    /// The input string exactly as passed to `validate`.
    pub original: String,
    /// The sanitized input rebuilt as a relative path: `\` is treated as `/`, and
    /// empty and `.` segments are dropped.
    pub normalized: PathBuf,
    /// Absolute, filesystem-resolved form of `normalized`, when the path (or its
    /// parent directory) could be canonicalized; `None` otherwise.
    pub canonical: Option<PathBuf>,
    /// Whether `normalized` existed on disk at validation time.
    pub exists: bool,
    /// Non-fatal observations such as an unusual extension or deep nesting.
    pub warnings: Vec<String>,
}
/// Validates untrusted path strings according to a [`PathValidationConfig`].
#[derive(Debug, Clone)]
pub struct PathValidator {
    /// Active configuration; read via `config()` and replaced via `update_config()`.
    config: PathValidationConfig,
}
impl PathValidator {
    /// Maximum number of symlink hops followed by `check_symlink_chain` before the
    /// chain is treated as a loop.
    const MAX_SYMLINK_HOPS: usize = 32;

    /// Creates a validator using [`PathValidationConfig::default`].
    pub fn new() -> Self {
        Self {
            config: PathValidationConfig::default(),
        }
    }

    /// Creates a validator with a caller-supplied configuration.
    pub fn with_config(config: PathValidationConfig) -> Self {
        Self { config }
    }

    /// Validates an untrusted path string and returns its normalized form.
    ///
    /// Stages, in order:
    /// 1. Input sanitization: null bytes, a single URL decode, control characters and
    ///    Unicode normalization.
    /// 2. Length limit.
    /// 3. Pattern blacklist.
    /// 4. Lexical normalization.
    /// 5. Per-component checks.
    /// 6. Base-directory whitelist.
    /// 7. Optional symlink checks on the resolved path.
    /// 8. Optional existence check.
    /// 9. Collection of non-fatal warnings.
    ///
    /// # Errors
    ///
    /// Returns `BuildError::InputSanitization` describing the first check that failed.
    pub fn validate(&self, path_str: &str) -> Result<ValidatedPath, BuildError> {
        let sanitized_input = self.sanitize_input(path_str)?;
        // Measured in bytes, after decoding and NFC normalization.
        if sanitized_input.len() > self.config.max_path_length {
            return Err(BuildError::InputSanitization(format!(
                "Path too long: {} > {}",
                sanitized_input.len(),
                self.config.max_path_length
            )));
        }
        self.detect_dangerous_patterns(&sanitized_input)?;
        let normalized = self.normalize_path(&sanitized_input)?;
        self.validate_components(&normalized)?;
        self.validate_against_whitelist(&normalized)?;
        let (canonical, exists) = self.safe_canonicalize(&normalized);
        if self.config.validate_symlinks {
            self.validate_symlinks(&normalized, &canonical)?;
        }
        if self.config.check_existence && !exists {
            return Err(BuildError::InputSanitization(
                "File does not exist".to_string(),
            ));
        }
        let warnings = self.collect_warnings(&sanitized_input, &normalized);
        Ok(ValidatedPath {
            original: path_str.to_string(),
            normalized,
            canonical,
            exists,
            warnings,
        })
    }

    /// Cleans raw input before any structural checks.
    ///
    /// Steps: reject null bytes, URL-decode exactly once (refusing double encoding),
    /// reject control characters other than `\n`, `\r` and `\t`, then return the NFC
    /// form.
    fn sanitize_input(&self, input: &str) -> Result<String, BuildError> {
        if input.contains('\0') {
            return Err(BuildError::InputSanitization(
                "Null byte detected in path".to_string(),
            ));
        }
        let decoded = self.safe_url_decode(input)?;
        // `\n`, `\r` and `\t` pass here but are still rejected later by
        // `DANGEROUS_CHARS` in `detect_dangerous_patterns`.
        if decoded
            .chars()
            .any(|c| c.is_control() && c != '\n' && c != '\r' && c != '\t')
        {
            return Err(BuildError::InputSanitization(
                "Control characters detected in path".to_string(),
            ));
        }
        self.normalize_unicode(&decoded)
    }

    /// Percent-decodes `input` once and rejects input that was encoded more than once.
    fn safe_url_decode(&self, input: &str) -> Result<String, BuildError> {
        let first_decode = urlencoding::decode(input)
            .map_err(|e| BuildError::InputSanitization(format!("URL decode error: {}", e)))?;
        // Treat the input as double-encoded when a second decode either changes the
        // string, or fails. Failure can only happen if the first pass exposed
        // percent-escapes that decode to invalid UTF-8 (e.g. `%25ff` -> `%ff`).
        let double_encoded = match urlencoding::decode(&first_decode) {
            Ok(second_decode) => second_decode != first_decode,
            Err(_) => true,
        };
        if double_encoded {
            return Err(BuildError::InputSanitization(
                "Double URL encoding detected (potential attack)".to_string(),
            ));
        }
        Ok(first_decode.into_owned())
    }

    /// Returns the NFC form of `input`.
    ///
    /// First checks that no normalization form (NFC/NFD/NFKC/NFKD) reveals a
    /// traversal, absolute-path or dangerous-character pattern. For example,
    /// full-width `．．／` becomes `../` under NFKC.
    fn normalize_unicode(&self, input: &str) -> Result<String, BuildError> {
        use unicode_normalization::UnicodeNormalization;
        let nfc = input.nfc().collect::<String>();
        let nfd = input.nfd().collect::<String>();
        let nfkc = input.nfkc().collect::<String>();
        let nfkd = input.nfkd().collect::<String>();
        let forms_identical = nfc == nfd && nfd == nfkc && nfkc == nfkd;
        if !forms_identical {
            let forms = [("NFC", &nfc), ("NFD", &nfd), ("NFKC", &nfkc), ("NFKD", &nfkd)];
            let dangerous_forms: Vec<&str> = forms
                .iter()
                .filter(|(_, form)| {
                    DIRECTORY_TRAVERSAL.is_match(form)
                        || ENCODED_TRAVERSAL.is_match(form)
                        || ABSOLUTE_PATH.is_match(form)
                        || DANGEROUS_CHARS.is_match(form)
                })
                .map(|(name, _)| *name)
                .collect();
            if !dangerous_forms.is_empty() {
                return Err(BuildError::InputSanitization(format!(
                    "Unicode normalization attack detected in forms: {:?}",
                    dangerous_forms
                )));
            }
        }
        Ok(nfc)
    }

    /// Rejects traversal sequences, absolute paths, control characters, and Windows
    /// reserved device names appearing in any path segment.
    fn detect_dangerous_patterns(&self, path: &str) -> Result<(), BuildError> {
        if DIRECTORY_TRAVERSAL.is_match(path) {
            return Err(BuildError::InputSanitization(
                "Directory traversal pattern detected".to_string(),
            ));
        }
        if ENCODED_TRAVERSAL.is_match(path) {
            return Err(BuildError::InputSanitization(
                "Encoded path traversal detected".to_string(),
            ));
        }
        if ABSOLUTE_PATH.is_match(path) {
            return Err(BuildError::InputSanitization(
                "Absolute path not allowed".to_string(),
            ));
        }
        if DANGEROUS_CHARS.is_match(path) {
            return Err(BuildError::InputSanitization(
                "Dangerous characters detected".to_string(),
            ));
        }
        // Windows treats device names as special in *any* component
        // (`data/CON/file.txt`). `\` is a separator there even though `Path` does
        // not split on it on Unix, so split on both separators manually.
        if path
            .split(|c: char| c == '/' || c == '\\')
            .any(Self::is_windows_reserved_name)
        {
            return Err(BuildError::InputSanitization(
                "Windows reserved filename detected".to_string(),
            ));
        }
        Ok(())
    }

    /// Returns true when `name` is a Windows device name, with or without an
    /// extension (`CON`, `con.txt`, `LPT1.log`, ...).
    fn is_windows_reserved_name(name: &str) -> bool {
        if WINDOWS_RESERVED.is_match(name) {
            return true;
        }
        let upper = name.to_uppercase();
        let stem = upper.split('.').next().unwrap_or(&upper);
        // Windows ignores trailing spaces, so `CON ` still names the device.
        WINDOWS_RESERVED_NAMES.contains(stem.trim_end_matches(' '))
    }

    /// Rebuilds `path` as a relative `PathBuf`.
    ///
    /// `\` is treated as `/`, and empty and `.` segments are dropped. Fails for an
    /// empty result, too many segments, or any `..` segment.
    fn normalize_path(&self, path: &str) -> Result<PathBuf, BuildError> {
        let unified = path.replace('\\', "/");
        let components: Vec<&str> = unified
            .split('/')
            .filter(|c| !c.is_empty() && *c != ".")
            .collect();
        // Inputs such as "", "." or "./" would otherwise validate to an empty path.
        if components.is_empty() {
            return Err(BuildError::InputSanitization("Empty path".to_string()));
        }
        if components.len() > self.config.max_path_depth {
            return Err(BuildError::InputSanitization(format!(
                "Path too deep: {} > {}",
                components.len(),
                self.config.max_path_depth
            )));
        }
        if components.contains(&"..") {
            return Err(BuildError::InputSanitization(
                "Path traversal (..) detected".to_string(),
            ));
        }
        Ok(components.into_iter().collect())
    }

    /// Checks each component of a normalized path.
    ///
    /// Rejects hidden names (unless allowed), names over 255 bytes, characters
    /// reserved on Windows, and any non-`Normal` component other than `.`.
    fn validate_components(&self, path: &Path) -> Result<(), BuildError> {
        for component in path.components() {
            match component {
                Component::Normal(name) => {
                    let name_str = name.to_string_lossy();
                    // `Component::Normal` is never "." or "..", so a leading dot means hidden.
                    if !self.config.allow_hidden && name_str.starts_with('.') {
                        return Err(BuildError::InputSanitization(
                            "Hidden files/directories not allowed".to_string(),
                        ));
                    }
                    if name_str.len() > 255 {
                        return Err(BuildError::InputSanitization(
                            "Path component too long".to_string(),
                        ));
                    }
                    if name_str.chars().any(|c| r#"<>:"|?*"#.contains(c)) {
                        return Err(BuildError::InputSanitization(
                            "Dangerous characters in path component".to_string(),
                        ));
                    }
                }
                Component::ParentDir => {
                    return Err(BuildError::InputSanitization(
                        "Parent directory traversal detected".to_string(),
                    ));
                }
                Component::RootDir => {
                    return Err(BuildError::InputSanitization(
                        "Root directory access not allowed".to_string(),
                    ));
                }
                Component::Prefix(_) => {
                    return Err(BuildError::InputSanitization(
                        "Windows path prefix not allowed".to_string(),
                    ));
                }
                Component::CurDir => {}
            }
        }
        Ok(())
    }

    /// Accepts `path` if it lies within one of `allowed_base_dirs`.
    ///
    /// The check is lexical first, then (for existing paths) via canonicalization.
    /// A base of `.` lexically admits only paths with no directory part.
    fn validate_against_whitelist(&self, path: &Path) -> Result<(), BuildError> {
        if self.config.allow_relative_outside_base && path.is_relative() {
            return Ok(());
        }
        let is_top_level = path.parent().map_or(true, |parent| parent == Path::new(""));
        // Lexical pass: no filesystem access. `starts_with` also covers equality.
        let lexically_allowed = self.config.allowed_base_dirs.iter().any(|base_dir| {
            path.starts_with(base_dir) || (base_dir.as_path() == Path::new(".") && is_top_level)
        });
        if lexically_allowed {
            return Ok(());
        }
        // Filesystem pass: an existing path that resolves inside an existing base.
        // The target is canonicalized once, not once per base directory.
        if let Ok(canonical_path) = path.canonicalize() {
            let resolved_inside = self
                .config
                .allowed_base_dirs
                .iter()
                .filter_map(|base_dir| base_dir.canonicalize().ok())
                .any(|base| canonical_path.starts_with(base));
            if resolved_inside {
                return Ok(());
            }
        }
        Err(BuildError::InputSanitization(
            "Path not within allowed directories".to_string(),
        ))
    }

    /// Returns `(canonical, exists)` for `path`.
    ///
    /// For a missing file, the canonical form is the canonicalized parent joined with
    /// the file name. If that cannot be computed, it is `None`.
    fn safe_canonicalize(&self, path: &Path) -> (Option<PathBuf>, bool) {
        if path.exists() {
            return (path.canonicalize().ok(), true);
        }
        // A bare file name has an empty parent, which never "exists"; resolve it
        // against the current directory instead.
        let parent = match path.parent() {
            Some(p) if p.as_os_str().is_empty() => Path::new("."),
            Some(p) => p,
            None => return (None, false),
        };
        let canonical = match (parent.canonicalize(), path.file_name()) {
            (Ok(canonical_parent), Some(file_name)) => Some(canonical_parent.join(file_name)),
            _ => None,
        };
        (canonical, false)
    }

    /// Ensures symlinks cannot redirect `normalized` outside the allowed base
    /// directories, and that the final component is not part of a symlink loop.
    fn validate_symlinks(
        &self,
        normalized: &Path,
        canonical: &Option<PathBuf>,
    ) -> Result<(), BuildError> {
        self.check_symlink_chain(normalized)?;
        let canonical_path = match canonical {
            Some(path) => path,
            None => return Ok(()),
        };
        // `canonicalize` always yields an absolute path, so comparing it with the
        // relative `normalized` path would always report a difference. Anchor
        // `normalized` at the canonical working directory instead. If the two
        // agree, resolution followed no symlinks and the lexical whitelist check
        // already covered this path.
        let anchored = std::env::current_dir()
            .and_then(|cwd| cwd.canonicalize())
            .map(|cwd| cwd.join(normalized))
            .ok();
        if anchored.as_deref() == Some(canonical_path.as_path()) {
            return Ok(());
        }
        let inside_allowed_base = self
            .config
            .allowed_base_dirs
            .iter()
            .filter_map(|base_dir| base_dir.canonicalize().ok())
            .any(|base| canonical_path.starts_with(base));
        if !inside_allowed_base {
            return Err(BuildError::InputSanitization(
                "Symlink target outside allowed directories".to_string(),
            ));
        }
        if let Some(target_str) = canonical_path.to_str() {
            // ABSOLUTE_PATH is deliberately not checked: canonical paths are always absolute.
            if DIRECTORY_TRAVERSAL.is_match(target_str)
                || ENCODED_TRAVERSAL.is_match(target_str)
                || DANGEROUS_CHARS.is_match(target_str)
            {
                return Err(BuildError::InputSanitization(
                    "Symlink target contains dangerous patterns".to_string(),
                ));
            }
        }
        Ok(())
    }

    /// Follows the symlink chain starting at `path` (if `path` itself is a symlink).
    ///
    /// Fails when a link is revisited or more than `MAX_SYMLINK_HOPS` links are
    /// followed.
    fn check_symlink_chain(&self, path: &Path) -> Result<(), BuildError> {
        let is_symlink = fs::symlink_metadata(path)
            .map(|meta| meta.file_type().is_symlink())
            .unwrap_or(false);
        if !is_symlink {
            return Ok(());
        }
        let mut visited = HashSet::new();
        let mut current = path.to_path_buf();
        while current.is_symlink() {
            if visited.len() >= Self::MAX_SYMLINK_HOPS {
                return Err(BuildError::InputSanitization(
                    "Symlink chain too long (potential loop)".to_string(),
                ));
            }
            if !visited.insert(current.clone()) {
                return Err(BuildError::InputSanitization(
                    "Symlink loop detected".to_string(),
                ));
            }
            let target = match fs::read_link(&current) {
                Ok(target) => target,
                // Unreadable link: stop following (best effort).
                Err(_) => break,
            };
            // Relative link targets are resolved against the link's own directory.
            current = if target.is_absolute() {
                target
            } else {
                current
                    .parent()
                    .unwrap_or_else(|| Path::new("."))
                    .join(target)
            };
        }
        Ok(())
    }

    /// Collects non-fatal observations about an accepted path.
    ///
    /// Warns about: non-ASCII input, a file name over 100 bytes, more than 8
    /// components, an extension outside `allowed_extensions`, and a suspicious
    /// executable extension.
    fn collect_warnings(&self, input: &str, normalized: &Path) -> Vec<String> {
        let mut warnings = Vec::new();
        if input.chars().any(|c| !c.is_ascii()) {
            warnings.push("Path contains non-ASCII characters".to_string());
        }
        if let Some(filename) = normalized.file_name().and_then(|s| s.to_str()) {
            if filename.len() > 100 {
                warnings.push("Very long filename".to_string());
            }
        }
        if normalized.components().count() > 8 {
            warnings.push("Deeply nested path".to_string());
        }
        if let Some(extension) = normalized.extension().and_then(|s| s.to_str()) {
            if !self
                .config
                .allowed_extensions
                .contains(&extension.to_lowercase())
            {
                warnings.push(format!("Unusual file extension: {}", extension));
            }
        }
        if let Some(filename) = normalized.file_name().and_then(|s| s.to_str()) {
            if SUSPICIOUS_EXTENSIONS.is_match(filename) {
                warnings.push("Suspicious file extension detected".to_string());
            }
        }
        warnings
    }

    /// Returns the active configuration.
    pub fn config(&self) -> &PathValidationConfig {
        &self.config
    }

    /// Replaces the active configuration.
    pub fn update_config(&mut self, config: PathValidationConfig) {
        self.config = config;
    }
}
impl Default for PathValidator {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `validator` rejects every entry in `inputs`.
    ///
    /// A failure message is `message` immediately followed by the offending input.
    fn assert_rejects_all(validator: &PathValidator, inputs: &[&str], message: &str) {
        for input in inputs {
            assert!(validator.validate(input).is_err(), "{}{}", message, input);
        }
    }

    #[test]
    fn test_basic_path_validation() {
        let validator = PathValidator::new();
        for accepted in &["data/file.xml", "input/subdir/file.json", "./file.txt"] {
            assert!(validator.validate(accepted).is_ok());
        }
        for rejected in &["../etc/passwd", "/etc/passwd", "C:\\Windows\\System32"] {
            assert!(validator.validate(rejected).is_err());
        }
    }

    #[test]
    fn test_dangerous_patterns() {
        let validator = PathValidator::new();
        assert_rejects_all(
            &validator,
            &[
                "../../../etc/passwd",
                "..\\..\\..\\windows\\system32\\config\\sam",
                "/etc/passwd",
                "/proc/self/environ",
                "C:\\Windows\\System32",
                "\\\\server\\share",
                "file%00.txt",
                "%2e%2e%2fpasswd",
                "%252e%252e%252fpasswd",
            ],
            "Should reject dangerous path: ",
        );
    }

    #[test]
    fn test_url_encoding_attacks() {
        let validator = PathValidator::new();
        assert_rejects_all(
            &validator,
            &["%2e%2e%2f", "%2e%2e%5c", "%252e%252e%252f", "..%2f", "..%00"],
            "Should block encoded attack: ",
        );
    }

    #[test]
    fn test_windows_reserved_names() {
        let validator = PathValidator::new();
        assert_rejects_all(
            &validator,
            &[
                "CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2", "con.txt", "prn.xml",
                "aux.json",
            ],
            "Should block reserved name: ",
        );
    }

    #[test]
    fn test_path_normalization() {
        let validator = PathValidator::new();
        let cases = [
            ("data//file.xml", "data/file.xml"),
            ("data\\subdir\\file.json", "data/subdir/file.json"),
            ("./data/./file.txt", "data/file.txt"),
        ];
        for (input, expected) in cases.iter() {
            let validated = validator.validate(input).unwrap();
            assert_eq!(validated.normalized, Path::new(expected));
        }
    }

    #[test]
    fn test_whitelist_validation() {
        let validator = PathValidator::with_config(PathValidationConfig {
            allowed_base_dirs: vec![PathBuf::from("allowed")],
            allow_relative_outside_base: false,
            ..PathValidationConfig::default()
        });
        assert!(validator.validate("allowed/file.xml").is_ok());
        assert!(validator.validate("disallowed/file.xml").is_err());
    }

    #[test]
    fn test_unicode_normalization() {
        assert!(PathValidator::new().validate("data/résumé.txt").is_ok());
    }

    #[test]
    fn test_length_limits() {
        let validator = PathValidator::with_config(PathValidationConfig {
            max_path_length: 50,
            max_path_depth: 3,
            ..PathValidationConfig::default()
        });
        assert!(validator.validate(&"a/".repeat(30)).is_err());
        assert!(validator.validate("a/b/c/d/e/f/g.txt").is_err());
    }

    #[test]
    fn test_file_extensions() {
        let validator = PathValidator::with_config(PathValidationConfig {
            allowed_extensions: ["xml", "json"].iter().map(|ext| ext.to_string()).collect(),
            ..PathValidationConfig::default()
        });
        let xml = validator.validate("data/file.xml").unwrap();
        assert!(xml.warnings.is_empty());
        let exe = validator.validate("data/file.exe").unwrap();
        assert!(exe.warnings.iter().any(|w| w.contains("extension")));
    }
}