use std::path::{Component, Path, PathBuf};
/// Error describing why a path was rejected by this module's safety checks.
#[derive(Debug)]
pub struct PathSecurityError {
    /// Human-readable explanation of the rejection.
    pub message: String,
}

impl std::fmt::Display for PathSecurityError {
    /// Renders the error as `path security: <message>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("path security: ")?;
        f.write_str(&self.message)
    }
}

impl std::error::Error for PathSecurityError {}

/// Converts the error into its `Display` text, so it can be propagated with
/// `?` from functions whose error type is `String`.
impl From<PathSecurityError> for String {
    fn from(err: PathSecurityError) -> Self {
        err.to_string()
    }
}
/// Joins `user_path` onto `project_root`, refusing anything that could leave the root.
///
/// The check happens in two stages:
///
/// 1. **Lexical** – after stripping leading `./` prefixes, the path must be
///    relative and must not contain `..`, a root directory, or a drive/UNC prefix.
/// 2. **Filesystem** – symlinks are resolved and the result must still live
///    under the canonicalized root. If the target does not exist yet, the
///    deepest *existing* ancestor is resolved instead, so a new file cannot be
///    smuggled out through a symlinked parent directory.
///
/// # Returns
///
/// * The canonical path when the target exists.
/// * `canonical_root.join(user_path)` (not canonicalized further) when the
///   target does not exist yet but its existing ancestors stay inside the root.
///
/// # Errors
///
/// Returns [`PathSecurityError`] when:
///
/// * `project_root` cannot be canonicalized,
/// * `user_path` is absolute or contains `..`, a root, or a prefix component,
/// * the resolved target (or its deepest existing ancestor) is outside the root,
/// * an entry on the path exists but cannot be resolved (for example a dangling
///   symlink or a symlink loop), because its final destination cannot be verified.
///
/// Note that the check is inherently racy with respect to concurrent filesystem
/// changes made after this function returns.
pub fn safe_join(project_root: &str, user_path: &str) -> Result<PathBuf, PathSecurityError> {
    let root = Path::new(project_root)
        .canonicalize()
        .map_err(|e| PathSecurityError {
            message: format!("project root {:?} not accessible: {}", project_root, e),
        })?;

    // Stage 1: purely lexical validation of the user-supplied path.
    let trimmed = user_path.trim_start_matches("./");
    let candidate = Path::new(trimmed);
    if candidate.is_absolute() {
        return Err(PathSecurityError {
            message: format!("absolute paths are not allowed: {}", user_path),
        });
    }
    for comp in candidate.components() {
        match comp {
            Component::ParentDir => {
                return Err(PathSecurityError {
                    message: format!("'..' traversal is not allowed: {}", user_path),
                });
            }
            Component::Prefix(_) | Component::RootDir => {
                return Err(PathSecurityError {
                    message: format!("absolute/drive paths are not allowed: {}", user_path),
                });
            }
            Component::CurDir | Component::Normal(_) => {}
        }
    }

    // Stage 2: resolve symlinks and confirm containment.
    let joined = root.join(candidate);
    let escapes_root = || PathSecurityError {
        message: format!("resolved path escapes project root: {}", user_path),
    };

    // Fast path: the target exists, so it can be canonicalized directly.
    if let Ok(canonical) = joined.canonicalize() {
        if !canonical.starts_with(&root) {
            return Err(escapes_root());
        }
        return Ok(canonical);
    }

    // The target could not be resolved (typically because it does not exist
    // yet). Climb towards the root until an entry that actually exists is
    // found, and require that entry to resolve to somewhere inside the root.
    let unverifiable = || PathSecurityError {
        message: format!(
            "path cannot be verified to stay inside project root: {}",
            user_path
        ),
    };
    let mut anchored = false;
    for ancestor in joined.ancestors() {
        // `symlink_metadata` does not follow links, so it also reports
        // dangling symlinks as existing entries.
        if ancestor.symlink_metadata().is_err() {
            continue;
        }
        match ancestor.canonicalize() {
            Ok(resolved) if resolved.starts_with(&root) => {
                anchored = true;
                break;
            }
            Ok(_) => return Err(escapes_root()),
            // The entry exists but cannot be followed to a real location.
            Err(_) => return Err(unverifiable()),
        }
    }
    if !anchored {
        // No existing ancestor was found at all (the root vanished meanwhile).
        return Err(unverifiable());
    }
    Ok(joined)
}
/// Performs a purely lexical safety check on a user-supplied relative path.
///
/// Leading `./` prefixes are stripped first. The remainder is accepted only if
/// it is not absolute and contains no `..`, root-directory, or prefix
/// (drive/UNC) component. The filesystem is never touched, so symlinks are not
/// considered here.
///
/// # Returns
///
/// The input with its leading `./` prefixes removed, borrowed from `user_path`.
///
/// # Errors
///
/// Returns [`PathSecurityError`] if the path is absolute or contains an unsafe
/// component.
pub fn validate_relative(user_path: &str) -> Result<&str, PathSecurityError> {
    let relative = user_path.trim_start_matches("./");
    let path = Path::new(relative);

    if path.is_absolute() {
        return Err(PathSecurityError {
            message: format!("absolute paths are not allowed: {}", user_path),
        });
    }

    let has_unsafe_component = path.components().any(|part| match part {
        Component::ParentDir | Component::RootDir | Component::Prefix(_) => true,
        Component::CurDir | Component::Normal(_) => false,
    });

    if has_unsafe_component {
        Err(PathSecurityError {
            message: format!("unsafe path component in: {}", user_path),
        })
    } else {
        Ok(relative)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// A Unix-style absolute path is refused before any filesystem access on it.
    #[test]
    fn rejects_absolute_unix() {
        let tmp = tempdir().unwrap();
        let root = tmp.path().to_str().unwrap();
        let err = safe_join(root, "/etc/passwd").unwrap_err();
        assert!(err.message.contains("absolute"));
    }

    /// Leading `..` components are refused.
    #[test]
    fn rejects_parent_traversal() {
        let tmp = tempdir().unwrap();
        let root = tmp.path().to_str().unwrap();
        let err = safe_join(root, "../../../etc/passwd").unwrap_err();
        assert!(err.message.contains("..") || err.message.contains("traversal"));
    }

    /// `..` is refused even when it appears after a normal component.
    #[test]
    fn rejects_sneaky_parent_midpath() {
        let tmp = tempdir().unwrap();
        let root = tmp.path().to_str().unwrap();
        let err = safe_join(root, "src/../../etc/passwd").unwrap_err();
        assert!(err.message.contains("traversal"));
    }

    /// An existing file below the root is accepted.
    #[test]
    fn accepts_normal_relative() {
        let tmp = tempdir().unwrap();
        let sub = tmp.path().join("src");
        fs::create_dir(&sub).unwrap();
        fs::write(sub.join("lib.rs"), "fn main(){}").unwrap();
        let root = tmp.path().to_str().unwrap();
        let out = safe_join(root, "src/lib.rs").unwrap();
        assert!(out.ends_with("src/lib.rs"));
    }

    /// A leading `./` is stripped and the path is otherwise treated normally.
    #[test]
    fn accepts_dot_slash_prefix() {
        let tmp = tempdir().unwrap();
        let sub = tmp.path().join("src");
        fs::create_dir(&sub).unwrap();
        fs::write(sub.join("lib.rs"), "").unwrap();
        let root = tmp.path().to_str().unwrap();
        let out = safe_join(root, "./src/lib.rs").unwrap();
        assert!(out.ends_with("src/lib.rs"));
    }

    /// A path that does not exist yet is still returned, rooted in the canonical root.
    #[test]
    fn allows_nonexistent_child() {
        let tmp = tempdir().unwrap();
        let root = tmp.path().to_str().unwrap();
        let out = safe_join(root, "new/file/path.rs").unwrap();
        assert!(out.starts_with(tmp.path().canonicalize().unwrap()));
    }

    /// A symlink inside the root that points at a file outside it is refused.
    ///
    /// The whole test is gated on Unix (rather than only its body) so other
    /// platforms do not report a vacuous pass for an empty test.
    #[cfg(unix)]
    #[test]
    fn rejects_symlink_escape() {
        use std::os::unix::fs::symlink;

        let tmp = tempdir().unwrap();
        let outside = tempdir().unwrap();
        fs::write(outside.path().join("secret"), "pw").unwrap();
        symlink(outside.path().join("secret"), tmp.path().join("link")).unwrap();
        let root = tmp.path().to_str().unwrap();
        let err = safe_join(root, "link").unwrap_err();
        assert!(err.message.contains("escapes project root"));
    }

    /// Plain relative paths pass, with any leading `./` removed.
    #[test]
    fn validate_relative_ok() {
        assert_eq!(validate_relative("src/foo.rs").unwrap(), "src/foo.rs");
        assert_eq!(validate_relative("./src/foo.rs").unwrap(), "src/foo.rs");
    }

    /// Parent traversal and absolute paths are both refused lexically.
    #[test]
    fn validate_relative_rejects_traversal() {
        assert!(validate_relative("../etc/passwd").is_err());
        assert!(validate_relative("/etc/passwd").is_err());
    }
}