use anyhow::{Result, anyhow};
use std::path::{Path, PathBuf};
/// Trims `path` and expands a leading `~` to the home directory.
///
/// Only a bare `~` or a `~/` prefix is expanded; anything else (including
/// `~user/...`) is returned trimmed but otherwise unchanged.
///
/// # Errors
/// Returns an error when expansion is needed and [`home_dir`] fails.
fn expand_path(path: &str) -> Result<String> {
    let path = path.trim();
    if let Some(after_tilde) = path.strip_prefix('~') {
        if after_tilde.is_empty() {
            let home = home_dir()?;
            return Ok(home.to_string_lossy().to_string());
        }
        if let Some(tail) = after_tilde.strip_prefix('/') {
            let home = home_dir()?;
            return Ok(format!("{}/{}", home.display(), tail));
        }
    }
    Ok(path.to_string())
}
/// Resolves `path` to an absolute path with all symlinks and `.`/`..`
/// components resolved, via [`Path::canonicalize`].
///
/// # Errors
/// Fails (with the offending path in the message) when the path cannot be
/// resolved, e.g. because it does not exist.
fn canonicalize_existing(path: &Path) -> Result<PathBuf> {
    match path.canonicalize() {
        Ok(resolved) => Ok(resolved),
        Err(err) => Err(anyhow!(
            "Cannot canonicalize path '{}': {}",
            path.display(),
            err
        )),
    }
}
/// Cheap textual screen for suspicious path strings.
///
/// Returns `true` if `path` contains `..` or `./` anywhere, or a NUL, LF or
/// CR character. This is intentionally coarse: it also rejects harmless names
/// such as `file..txt`. Real containment is enforced afterwards by
/// canonicalization and the allowed-base check.
fn contains_traversal(path: &str) -> bool {
    // No case folding is needed: '.', '/' and these control characters have
    // no case variants, so lowercasing would only cost an allocation.
    path.contains("..")
        || path.contains("./")
        || path.chars().any(|c| matches!(c, '\0' | '\n' | '\r'))
}
/// Returns the current user's home directory, taken from `$HOME`.
///
/// Only an absolute value is accepted. An empty `$HOME` would otherwise become
/// an empty `PathBuf`, and every path `starts_with` the empty path. That would
/// whitelist the whole filesystem in [`is_under_allowed_base`] and make `~/x`
/// expand to `/x`. A relative value would make `~` expand relative to the
/// working directory.
///
/// # Errors
/// Returns an error when `$HOME` is unset, not valid Unicode, empty, or not
/// an absolute path.
fn home_dir() -> Result<PathBuf> {
    std::env::var("HOME")
        .ok()
        .map(PathBuf::from)
        .filter(|home| home.is_absolute())
        .ok_or_else(|| anyhow!("Cannot determine home directory from $HOME"))
}
/// Returns `true` when `path` lies under one of the directories this module
/// permits access to.
///
/// Allowed bases:
/// * the current user's home directory (`$HOME`, compared in canonical form);
/// * on macOS, any entry at or below `/Users/<name>`;
/// * `/tmp`, `/var/folders`, `/private/tmp` and `/private/var/folders`.
///
/// `path` should already be canonical. Every caller in this module passes the
/// output of [`canonicalize_existing`]. [`Path::starts_with`] compares whole
/// components, so `/tmpfoo` does not match `/tmp`.
///
/// # Errors
/// Returns an error when the home directory cannot be determined.
fn is_under_allowed_base(path: &Path) -> Result<bool> {
    let home = home_dir()?;
    // `path` is canonical, so compare against the canonical home as well.
    // Otherwise a symlinked `$HOME` would never match. Fall back to the raw
    // value if `$HOME` cannot be resolved.
    let home = home.canonicalize().unwrap_or(home);
    // A home of `/` (or an empty one) would admit every path, so it is not
    // accepted as an allowed base.
    if home.parent().is_some() && path.starts_with(&home) {
        return Ok(true);
    }
    #[cfg(target_os = "macos")]
    if path.starts_with("/Users") && path.components().nth(2).is_some() {
        // Any `/Users/<name>/...` entry, not only the current user's.
        return Ok(true);
    }
    const TEMP_BASES: [&str; 4] = [
        "/tmp",
        "/var/folders",
        "/private/tmp",
        "/private/var/folders",
    ];
    Ok(TEMP_BASES.iter().any(|base| path.starts_with(base)))
}
pub fn sanitize_existing_path(path: &str) -> Result<PathBuf> {
if contains_traversal(path) {
return Err(anyhow!(
"Path contains invalid traversal sequence: {}",
path
));
}
let expanded = expand_path(path)?;
if contains_traversal(&expanded) {
return Err(anyhow!(
"Expanded path contains invalid sequence: {}",
expanded
));
}
let path_buf = PathBuf::from(&expanded);
let canonical = canonicalize_existing(&path_buf)?;
if !is_under_allowed_base(&canonical)? {
return Err(anyhow!(
"Path '{}' is not under an allowed directory",
canonical.display()
));
}
Ok(canonical)
}
pub fn sanitize_new_path(path: &str) -> Result<PathBuf> {
if contains_traversal(path) {
return Err(anyhow!(
"Path contains invalid traversal sequence: {}",
path
));
}
let expanded = expand_path(path)?;
if contains_traversal(&expanded) {
return Err(anyhow!(
"Expanded path contains invalid sequence: {}",
expanded
));
}
let path_buf = PathBuf::from(&expanded);
if let Some(parent) = path_buf.parent() {
if parent.exists() {
let canonical_parent = canonicalize_existing(parent)?;
if !is_under_allowed_base(&canonical_parent)? {
return Err(anyhow!(
"Parent directory '{}' is not under an allowed directory",
canonical_parent.display()
));
}
} else if let Some(grandparent) = parent.parent()
&& grandparent.exists()
{
let canonical_gp = canonicalize_existing(grandparent)?;
if !is_under_allowed_base(&canonical_gp)? {
return Err(anyhow!(
"Path '{}' would be created outside allowed directories",
path_buf.display()
));
}
}
}
Ok(path_buf)
}
/// Checks that `path` exists and resolves inside an allowed directory, and
/// returns its canonical form.
///
/// Unlike [`sanitize_existing_path`], this takes a `&Path` and applies no
/// textual [`contains_traversal`] screen or `~` expansion. Containment relies
/// on canonicalization plus the allowed-base check.
///
/// # Errors
/// Fails if the path does not exist or cannot be canonicalized, if the home
/// directory is unavailable, or if the path lies outside the allowed
/// directories.
pub fn validate_read_path(path: &Path) -> Result<PathBuf> {
    if path.exists() {
        let resolved = canonicalize_existing(path)?;
        if is_under_allowed_base(&resolved)? {
            return Ok(resolved);
        }
        return Err(anyhow!(
            "Cannot read from path outside allowed directories: {}",
            resolved.display()
        ));
    }
    Err(anyhow!("Path does not exist: {}", path.display()))
}
/// Checks that `path` is an acceptable write destination.
///
/// If the path exists, it is canonicalized and must lie inside an allowed
/// directory; the canonical path is returned. Otherwise the path is handed to
/// [`sanitize_new_path`]. Note that `exists()` follows symlinks, so a dangling
/// symlink also takes that branch.
///
/// NOTE(review): [`sanitize_new_path`] trims surrounding whitespace and
/// expands a leading `~`, so for non-existent paths the returned `PathBuf`
/// can differ from `path`. Confirm this is intended for `&Path` callers.
///
/// # Errors
/// Fails when:
/// * the path text contains a forbidden sequence;
/// * an existing path resolves outside the allowed directories;
/// * a non-existent path is not valid UTF-8;
/// * [`sanitize_new_path`] rejects it.
pub fn validate_write_path(path: &Path) -> Result<PathBuf> {
    let path_str = path.to_string_lossy();
    if contains_traversal(&path_str) {
        return Err(anyhow!("Path contains invalid traversal sequence"));
    }
    if path.exists() {
        let canonical = canonicalize_existing(path)?;
        if !is_under_allowed_base(&canonical)? {
            return Err(anyhow!(
                "Cannot write to path outside allowed directories: {}",
                canonical.display()
            ));
        }
        return Ok(canonical);
    }
    // `sanitize_new_path` works on `&str`. A lossy conversion would replace
    // invalid bytes with U+FFFD, then validate and return a *different* path
    // than the one requested, so refuse such paths instead.
    let utf8 = path
        .to_str()
        .ok_or_else(|| anyhow!("Path is not valid UTF-8: {}", path.display()))?;
    sanitize_new_path(utf8)
}
/// Validates `path` with [`sanitize_existing_path`] and reads the file as
/// UTF-8 text.
///
/// Returns the canonical path together with the file contents.
///
/// # Errors
/// Propagates validation errors. Read failures are reported with the
/// canonical path.
pub fn safe_read_to_string(path: &str) -> Result<(PathBuf, String)> {
    let checked = sanitize_existing_path(path)?;
    let outcome = std::fs::read_to_string(&checked);
    match outcome {
        Ok(text) => Ok((checked, text)),
        Err(err) => Err(anyhow!("Failed to read '{}': {}", checked.display(), err)),
    }
}
/// Async counterpart of [`safe_read_to_string`] for `&Path` inputs.
///
/// Validation goes through [`validate_read_path`], which uses synchronous
/// `std` filesystem calls. Only the read itself uses `tokio::fs`.
///
/// # Errors
/// Propagates validation errors. Read failures are reported with the
/// canonical path.
pub async fn safe_read_to_string_async(path: &Path) -> Result<(PathBuf, String)> {
    let checked = validate_read_path(path)?;
    let outcome = tokio::fs::read_to_string(&checked).await;
    match outcome {
        Ok(text) => Ok((checked, text)),
        Err(err) => Err(anyhow!("Failed to read '{}': {}", checked.display(), err)),
    }
}
/// Validates `path` with [`validate_read_path`] and opens it read-only via
/// `tokio::fs::File::open`.
///
/// Returns the canonical path together with the open file handle.
///
/// # Errors
/// Propagates validation errors. Open failures are reported with the
/// canonical path.
pub async fn safe_open_file_async(path: &Path) -> Result<(PathBuf, tokio::fs::File)> {
    let checked = validate_read_path(path)?;
    let outcome = tokio::fs::File::open(&checked).await;
    match outcome {
        Ok(handle) => Ok((checked, handle)),
        Err(err) => Err(anyhow!("Failed to open '{}': {}", checked.display(), err)),
    }
}
/// Validates `path` with [`validate_read_path`] and starts an asynchronous
/// directory listing via `tokio::fs::read_dir`.
///
/// Returns the canonical path together with the entry stream.
///
/// # Errors
/// Propagates validation errors. Listing failures are reported with the
/// canonical path.
pub async fn safe_read_dir(path: &Path) -> Result<(PathBuf, tokio::fs::ReadDir)> {
    let checked = validate_read_path(path)?;
    let outcome = tokio::fs::read_dir(&checked).await;
    match outcome {
        Ok(listing) => Ok((checked, listing)),
        Err(err) => Err(anyhow!(
            "Failed to read directory '{}': {}",
            checked.display(),
            err
        )),
    }
}
/// Copies `src` to `dst` after validating both ends.
///
/// The source is checked with [`validate_read_path`] and the destination with
/// [`validate_write_path`], in that order. The copy itself is done by
/// `std::fs::copy`.
///
/// Returns the validated destination path.
///
/// # Errors
/// Propagates either validation error. Copy failures are reported with both
/// validated paths.
pub fn safe_copy(src: &Path, dst: &Path) -> Result<PathBuf> {
    let from = validate_read_path(src)?;
    let to = validate_write_path(dst)?;
    if let Err(err) = std::fs::copy(&from, &to) {
        return Err(anyhow!(
            "Failed to copy '{}' → '{}': {}",
            from.display(),
            to.display(),
            err
        ));
    }
    Ok(to)
}
#[cfg(test)]
mod tests {
    //! Unit tests for the path sanitization helpers.
    //!
    //! These rely on `tempfile::tempdir()` creating directories under one of
    //! the allowed temporary bases (e.g. `/tmp`). Confirm this on exotic
    //! `TMPDIR` setups.
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[test]
    fn test_traversal_detection() {
        for suspicious in ["../etc/passwd", "foo/../bar", "./hidden", "path\0with\0nulls"] {
            assert!(contains_traversal(suspicious));
        }
        for benign in ["/normal/path", "~/Documents"] {
            assert!(!contains_traversal(benign));
        }
    }

    #[test]
    fn test_sanitize_existing_path() {
        let dir = tempdir().unwrap();
        let file = dir.path().join("test.txt");
        fs::write(&file, "test").unwrap();

        let outcome = sanitize_existing_path(file.to_str().unwrap());
        assert!(
            outcome.is_ok(),
            "Failed for path: {:?}, error: {:?}",
            file,
            outcome
        );

        // The textual traversal screen must reject this before it is resolved.
        let escaping = format!("{}/../../../etc/passwd", dir.path().display());
        assert!(sanitize_existing_path(&escaping).is_err());
    }

    #[test]
    fn test_validate_read_path() {
        let dir = tempdir().unwrap();
        let readable = dir.path().join("readable.txt");
        fs::write(&readable, "content").unwrap();
        assert!(validate_read_path(&readable).is_ok());

        let absent = dir.path().join("missing.txt");
        assert!(validate_read_path(&absent).is_err());
    }

    #[test]
    fn test_validate_write_path() {
        let dir = tempdir().unwrap();
        let fresh = dir.path().join("new.txt");
        assert!(validate_write_path(&fresh).is_ok());

        let existing = dir.path().join("existing.txt");
        fs::write(&existing, "data").unwrap();
        assert!(validate_write_path(&existing).is_ok());
    }
}