use crate::server::error_codes::{ValidationError, ValidationErrorCode};
use std::path::{Path, PathBuf};
/// Options controlling how [`validate_path`] vets an incoming path string.
#[derive(Clone, Debug)]
pub struct PathValidationConfig {
    /// Directory every validated path must stay under; relative inputs are joined onto it.
    pub base_dir: PathBuf,
    /// When `true`, paths are resolved through the filesystem (`canonicalize`);
    /// when `false`, they are only normalized lexically.
    pub resolve_symlinks: bool,
    /// When `true`, relative inputs and inputs containing `..` are not rejected up front.
    pub allow_relative: bool,
    /// When `true`, path components beginning with `.` are accepted.
    pub allow_hidden: bool,
    /// Maximum number of path components allowed below the base directory, if any.
    pub max_depth: Option<usize>,
    /// Glob patterns (`*`, `?`) matched against the whole resolved path; a match rejects it.
    pub blocked_patterns: Vec<String>,
}
impl PathValidationConfig {
    /// Creates a configuration rooted at `base_dir` with the strictest defaults:
    /// symlinks are resolved, relative and hidden paths are rejected, and there is
    /// no depth limit and no blocked pattern.
    #[must_use]
    pub fn new(base_dir: impl Into<PathBuf>) -> Self {
        Self {
            base_dir: base_dir.into(),
            resolve_symlinks: true,
            allow_relative: false,
            allow_hidden: false,
            max_depth: None,
            blocked_patterns: Vec::new(),
        }
    }

    /// Sets whether paths are resolved through the filesystem (`true`) or only
    /// normalized lexically (`false`).
    #[must_use]
    pub fn resolve_symlinks(mut self, resolve: bool) -> Self {
        self.resolve_symlinks = resolve;
        self
    }

    /// Sets whether relative inputs (and inputs containing `..`) are accepted.
    #[must_use]
    pub fn allow_relative(mut self, allow: bool) -> Self {
        self.allow_relative = allow;
        self
    }

    /// Sets whether path components beginning with `.` are accepted.
    #[must_use]
    pub fn allow_hidden(mut self, allow: bool) -> Self {
        self.allow_hidden = allow;
        self
    }

    /// Limits how many components a path may have below the base directory.
    #[must_use]
    pub fn max_depth(mut self, depth: usize) -> Self {
        self.max_depth = Some(depth);
        self
    }

    /// Replaces the list of blocked glob patterns with `patterns`.
    #[must_use]
    pub fn block_patterns(mut self, patterns: Vec<String>) -> Self {
        self.blocked_patterns = patterns;
        self
    }
}
/// Validates an untrusted `path` string against `config` and returns the resolved path.
///
/// Checks run in this order:
/// 1. empty input and embedded NUL bytes are rejected;
/// 2. after separator normalization, any `..` substring is rejected unless
///    `allow_relative` is set;
/// 3. a relative input is rejected unless `allow_relative` is set, in which case it is
///    joined onto `base_dir`;
/// 4. the path is resolved: through the filesystem when `resolve_symlinks` is set (a
///    missing final component is tolerated if its parent resolves), lexically otherwise;
/// 5. the resolved path must start with the canonicalized `base_dir`;
/// 6. hidden components, the depth limit and blocked glob patterns are checked.
///
/// The hidden-component and depth checks look only at the part of the path below the
/// base directory, so a base that itself lives under a dot-directory stays usable.
///
/// # Errors
///
/// Returns the error produced by `ValidationError::to_error()` for the first check that
/// fails.
pub fn validate_path(path: &str, config: &PathValidationConfig) -> crate::Result<PathBuf> {
    if path.is_empty() {
        return Err(
            ValidationError::new(ValidationErrorCode::MissingField, "path")
                .expected("Non-empty path")
                .to_error(),
        );
    }
    if path.contains('\0') {
        return Err(
            ValidationError::new(ValidationErrorCode::SecurityViolation, "path")
                .message("Path contains null bytes")
                .to_error(),
        );
    }

    let path = normalize_path_separators(path);
    // Substring test: this also rejects names that merely contain `..` (e.g. `a..b`).
    if path.contains("..") && !config.allow_relative {
        return Err(
            ValidationError::new(ValidationErrorCode::SecurityViolation, "path")
                .message("Path traversal detected (.. not allowed)")
                .to_error(),
        );
    }

    let mut path_buf = PathBuf::from(&path);
    if !path_buf.is_absolute() {
        if !config.allow_relative {
            return Err(
                ValidationError::new(ValidationErrorCode::SecurityViolation, "path")
                    .message("Relative paths not allowed")
                    .expected("Absolute path")
                    .to_error(),
            );
        }
        path_buf = config.base_dir.join(&path_buf);
    }

    let canonical_path = if config.resolve_symlinks {
        canonicalize_allowing_missing_leaf(&path_buf)?
    } else {
        normalize_path(&path_buf)?
    };

    // NOTE(review): the base is canonicalized through the filesystem even when
    // `resolve_symlinks` is false and the candidate was only normalized lexically; a base
    // reached through a symlink will then never match. Confirm whether that is intended.
    let canonical_base = config
        .base_dir
        .canonicalize()
        .unwrap_or_else(|_| config.base_dir.clone());
    if !canonical_path.starts_with(&canonical_base) {
        return Err(
            ValidationError::new(ValidationErrorCode::SecurityViolation, "path")
                .message(format!(
                    "Path escapes base directory. Path must be under: {}",
                    canonical_base.display()
                ))
                .to_error(),
        );
    }

    // Only the caller-controlled part of the path is subject to the hidden/depth rules.
    let relative = canonical_path
        .strip_prefix(&canonical_base)
        .unwrap_or(&canonical_path);

    if !config.allow_hidden {
        let has_hidden_component = relative.components().any(|component| match component {
            // Lossy conversion keeps a leading `.` intact, so non-UTF-8 names are checked too.
            std::path::Component::Normal(name) => name.to_string_lossy().starts_with('.'),
            _ => false,
        });
        if has_hidden_component {
            return Err(
                ValidationError::new(ValidationErrorCode::NotAllowed, "path")
                    .message("Hidden files/directories not allowed")
                    .to_error(),
            );
        }
    }

    if let Some(max_depth) = config.max_depth {
        let depth = relative.components().count();
        if depth > max_depth {
            return Err(
                ValidationError::new(ValidationErrorCode::OutOfRange, "path")
                    .message(format!(
                        "Path depth {} exceeds maximum {}",
                        depth, max_depth
                    ))
                    .to_error(),
            );
        }
    }

    if !config.blocked_patterns.is_empty() {
        let path_str = canonical_path.to_string_lossy();
        for pattern in &config.blocked_patterns {
            if glob_match(pattern, &path_str) {
                return Err(
                    ValidationError::new(ValidationErrorCode::NotAllowed, "path")
                        .message(format!("Path matches blocked pattern: {}", pattern))
                        .to_error(),
                );
            }
        }
    }

    Ok(canonical_path)
}

/// Canonicalizes `path`; if it does not exist, canonicalizes its parent and re-attaches
/// the final component so that not-yet-created files can still be validated.
fn canonicalize_allowing_missing_leaf(path: &Path) -> crate::Result<PathBuf> {
    let not_found = match path.canonicalize() {
        Ok(resolved) => return Ok(resolved),
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => e,
        Err(e) => {
            return Err(
                ValidationError::new(ValidationErrorCode::InvalidFormat, "path")
                    .message(format!("Cannot canonicalize path: {}", e))
                    .to_error(),
            );
        },
    };

    let Some(parent) = path.parent() else {
        return Err(
            ValidationError::new(ValidationErrorCode::InvalidFormat, "path")
                .message("Path has no parent directory")
                .to_error(),
        );
    };
    let Ok(canonical_parent) = parent.canonicalize() else {
        return Err(
            ValidationError::new(ValidationErrorCode::InvalidFormat, "path")
                .message(format!("Cannot resolve parent directory: {}", not_found))
                .to_error(),
        );
    };
    let Some(file_name) = path.file_name() else {
        return Err(
            ValidationError::new(ValidationErrorCode::InvalidFormat, "path")
                .message(format!("Invalid path format: {}", not_found))
                .to_error(),
        );
    };
    Ok(canonical_parent.join(file_name))
}
/// Rewrites every foreign path separator in `path` to the host platform's separator:
/// `/` becomes `\` on Windows, `\` becomes `/` everywhere else.
fn normalize_path_separators(path: &str) -> String {
    let (foreign, native) = if cfg!(windows) {
        ('/', '\\')
    } else {
        ('\\', '/')
    };
    path.chars()
        .map(|ch| if ch == foreign { native } else { ch })
        .collect()
}
/// Lexically normalizes `path` without touching the filesystem: `.` components are
/// dropped and each `..` removes the preceding normal component.
///
/// # Errors
///
/// Returns a `SecurityViolation` error when a `..` has no preceding normal component
/// to remove, i.e. the path would climb above its starting point.
fn normalize_path(path: &Path) -> crate::Result<PathBuf> {
    use std::path::Component;

    let mut normalized = PathBuf::new();
    // Number of normal components currently on `normalized` that a `..` may still remove.
    let mut depth = 0usize;
    for component in path.components() {
        match component {
            // Pushing (rather than assigning "/") keeps a Windows prefix such as `C:`
            // in front of the root; on Unix pushing "/" yields exactly "/".
            Component::Prefix(_) | Component::RootDir => {
                normalized.push(component.as_os_str());
            },
            Component::CurDir => {},
            Component::ParentDir => {
                if depth == 0 {
                    return Err(ValidationError::new(
                        ValidationErrorCode::SecurityViolation,
                        "path",
                    )
                    .message("Path escapes root with too many '..' components")
                    .to_error());
                }
                depth -= 1;
                normalized.pop();
            },
            Component::Normal(name) => {
                depth += 1;
                normalized.push(name);
            },
        }
    }
    Ok(normalized)
}
/// Returns `true` when the whole of `text` matches the glob `pattern`.
///
/// `*` matches any run of characters (including none) and `?` matches exactly one
/// character; every other character, including `.`, `+`, `(`, `[` and `\`, matches only
/// itself. Wildcards also match newline characters, so a file name containing `\n`
/// cannot slip past a pattern such as `*.exe`.
fn glob_match(pattern: &str, text: &str) -> bool {
    let pattern: Vec<char> = pattern.chars().collect();
    let text: Vec<char> = text.chars().collect();

    let mut p = 0;
    let mut t = 0;
    // Most recent `*`: (pattern index just after it, text index where its match ends).
    let mut last_star: Option<(usize, usize)> = None;

    while t < text.len() {
        match pattern.get(p).copied() {
            Some('*') => {
                // Start by letting the star match nothing; widen on later mismatches.
                last_star = Some((p + 1, t));
                p += 1;
            },
            Some(expected) if expected == '?' || expected == text[t] => {
                p += 1;
                t += 1;
            },
            _ => match last_star {
                // Mismatch: let the last star swallow one more character and retry.
                Some((after_star, absorbed_to)) => {
                    p = after_star;
                    t = absorbed_to + 1;
                    last_star = Some((after_star, absorbed_to + 1));
                },
                None => return false,
            },
        }
    }
    // Text is exhausted; only trailing stars may remain in the pattern.
    pattern[p..].iter().all(|&c| c == '*')
}
/// Builds a reusable validator closure for paths under `base_dir`.
///
/// The closure calls [`validate_path`] with a configuration that rejects relative and
/// hidden paths and blocks the `*.exe`, `*.dll`, `*.so` and `*.dylib` patterns.
pub fn secure_path_validator(
    base_dir: impl Into<PathBuf>,
) -> impl Fn(&str) -> crate::Result<PathBuf> {
    let blocked_binaries: Vec<String> = ["*.exe", "*.dll", "*.so", "*.dylib"]
        .iter()
        .map(|pattern| String::from(*pattern))
        .collect();
    let strict = PathValidationConfig::new(base_dir)
        .allow_relative(false)
        .allow_hidden(false)
        .block_patterns(blocked_binaries);
    move |candidate: &str| validate_path(candidate, &strict)
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::env;

    /// Traversal and NUL-byte inputs are rejected under the default configuration.
    #[test]
    fn test_path_validation() {
        let temp_dir = env::temp_dir();
        let config = PathValidationConfig::new(&temp_dir);
        let result = validate_path("../etc/passwd", &config);
        assert!(result.is_err());
        let result = validate_path("/tmp/file\0.txt", &config);
        assert!(result.is_err());

        // The outcome depends on whether `subdir` exists under the temp directory, so it
        // is only logged, not asserted.
        let config = PathValidationConfig::new(&temp_dir).allow_relative(true);
        let result = validate_path("subdir/file.txt", &config);
        if result.is_err() {
            eprintln!("Path validation error: {:?}", result);
        }
    }

    /// A dot-prefixed file name is rejected when hidden paths are disallowed.
    #[test]
    fn test_hidden_files() {
        let temp_dir = env::temp_dir();
        // Relative input must be allowed here: otherwise `.hidden` is rejected as a
        // relative path before the hidden-component check is ever reached.
        let config = PathValidationConfig::new(&temp_dir)
            .allow_relative(true)
            .allow_hidden(false);
        let result = validate_path(".hidden", &config);
        assert!(result.is_err());

        let config = PathValidationConfig::new(&temp_dir)
            .allow_relative(true)
            .allow_hidden(true);
        let _result = validate_path(".hidden", &config);
    }

    /// Paths matching a blocked glob pattern are rejected.
    #[test]
    fn test_blocked_patterns() {
        let temp_dir = env::temp_dir();
        let config = PathValidationConfig::new(&temp_dir)
            .block_patterns(vec!["*.exe".to_string(), "*.dll".to_string()]);
        let exe_path = temp_dir.join("test.exe");
        let txt_path = temp_dir.join("test.txt");
        let result = validate_path(&exe_path.to_string_lossy(), &config);
        assert!(result.is_err());
        let _result = validate_path(&txt_path.to_string_lossy(), &config);
    }

    /// Foreign separators are rewritten to the host platform's separator.
    #[test]
    fn test_cross_platform_separators() {
        let path1 = normalize_path_separators("C:\\Users\\test\\file.txt");
        let path2 = normalize_path_separators("/home/user/file.txt");
        #[cfg(windows)]
        {
            assert_eq!(path1, "C:\\Users\\test\\file.txt");
            assert_eq!(path2, "\\home\\user\\file.txt");
        }
        #[cfg(not(windows))]
        {
            assert_eq!(path1, "C:/Users/test/file.txt");
            assert_eq!(path2, "/home/user/file.txt");
        }
    }
}