use crate::error::{Error, Result};
use std::fs::{self, File, OpenOptions};
use std::path::{Path, PathBuf};
/// Validates `path` against symlink and hard-link tricks and returns the path to operate on.
///
/// The directory entry at `path` is inspected *without* following symlinks:
///
/// * If the entry cannot be inspected (for example because nothing exists there yet), the
///   path is returned unchanged so that callers can go on to create it.
/// * If the entry is a symlink — including a dangling one — it is rejected unless
///   `allow_symlinks` is `true`. When symlinks are allowed, the link target (resolved against
///   the symlink's own directory when it is relative) must be accepted by
///   `is_safe_symlink_target`, and that target is what gets returned.
/// * On Unix, any other entry that has more than one hard link is rejected.
///
/// This is a point-in-time check: the filesystem may still change between this call and the
/// moment the returned path is used.
///
/// # Errors
///
/// Returns [`Error::Validation`] when one of the checks above fails, or the converted I/O
/// error when the symlink itself cannot be read.
pub fn safe_file_path(path: &Path, allow_symlinks: bool) -> Result<PathBuf> {
    // `symlink_metadata` describes the entry itself rather than what it points to, so a
    // dangling symlink is still detected here. `Path::exists` follows links and would report
    // such a link as missing, letting it bypass every check below.
    let metadata = match fs::symlink_metadata(path) {
        Ok(metadata) => metadata,
        // Nothing inspectable at this path: there is nothing to validate.
        Err(_) => return Ok(path.to_path_buf()),
    };

    if metadata.file_type().is_symlink() {
        if !allow_symlinks {
            return Err(Error::Validation(format!(
                "Security error: Path {} is a symlink, which is not allowed",
                path.display()
            )));
        }

        let link_target = fs::read_link(path)?;
        // A relative link target is relative to the directory that holds the symlink, not
        // to the process's current working directory.
        let target = match path.parent() {
            Some(parent) if link_target.is_relative() => parent.join(&link_target),
            _ => link_target,
        };

        if !is_safe_symlink_target(&target) {
            return Err(Error::Validation(format!(
                "Security error: Symlink target {} is not in an allowed location",
                target.display()
            )));
        }
        return Ok(target);
    }

    #[cfg(unix)]
    {
        use std::os::unix::fs::MetadataExt;

        // More than one link means the same inode is reachable under another name.
        if metadata.nlink() > 1 {
            return Err(Error::Validation(format!(
                "Security error: Path {} has multiple hard links ({})",
                path.display(),
                metadata.nlink()
            )));
        }
    }

    Ok(path.to_path_buf())
}
/// Reports whether a symlink target resolves into one of the allowed directory trees.
///
/// The target is canonicalized first; when canonicalization fails (for example because the
/// target does not exist) the target is treated as unsafe. Otherwise the canonical path must
/// lie under `/tmp` or `/var/app/data`.
fn is_safe_symlink_target(target: &Path) -> bool {
    match target.canonicalize() {
        Ok(resolved) => ["/tmp", "/var/app/data"]
            .iter()
            .any(|root| resolved.starts_with(root)),
        Err(_) => false,
    }
}
/// Opens an existing file for reading after vetting its path with [`safe_file_path`].
///
/// `allow_symlinks` is forwarded to [`safe_file_path`]; the file that gets opened is the path
/// that function returns.
///
/// # Errors
///
/// Propagates any error from [`safe_file_path`], and converts an I/O failure from
/// [`File::open`] into this crate's [`Error`].
pub fn safe_open_file(path: &Path, allow_symlinks: bool) -> Result<File> {
    let checked = safe_file_path(path, allow_symlinks)?;
    let handle = File::open(&checked)?;
    Ok(handle)
}
/// Creates (or truncates) a file for writing after vetting its path with [`safe_file_path`].
///
/// `allow_symlinks` is forwarded to [`safe_file_path`]; the file that gets created is the
/// path that function returns.
///
/// # Errors
///
/// Propagates any error from [`safe_file_path`], and converts an I/O failure from
/// [`File::create`] into this crate's [`Error`].
pub fn safe_create_file(path: &Path, allow_symlinks: bool) -> Result<File> {
    let checked = safe_file_path(path, allow_symlinks)?;
    let handle = File::create(&checked)?;
    Ok(handle)
}
/// Vets `path` with [`safe_file_path`] and, on success, hands back a fresh [`OpenOptions`].
///
/// The returned options are not tied to the validated path in any way: the caller still
/// chooses which path to pass to `open`.
///
/// # Errors
///
/// Propagates any error from [`safe_file_path`].
pub fn safe_open_options(path: &Path, allow_symlinks: bool) -> Result<OpenOptions> {
    // Only the outcome of the validation matters here; the resolved path is discarded.
    safe_file_path(path, allow_symlinks).map(|_| OpenOptions::new())
}
/// Unit tests for the path-validation helpers.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::error::{Error, Result};
    use std::fs::File;
    use std::io::{Read, Write};
    use std::path::{Path, PathBuf};
    use tempfile::tempdir;

    /// An existing regular file with a single link is returned unchanged.
    #[test]
    fn test_safe_file_path_normal() -> Result<()> {
        let dir = tempdir()?;
        let normal_path = dir.path().join("test_file.txt");
        // Create the file so the existing-entry checks are exercised, not just the
        // "nothing there" shortcut.
        File::create(&normal_path)?;
        let result = safe_file_path(&normal_path, false)?;
        assert_eq!(result, normal_path);
        Ok(())
    }

    /// A path that does not exist yet is returned unchanged.
    #[test]
    fn test_safe_file_path_nonexistent() -> Result<()> {
        let dir = tempdir()?;
        let nonexistent_path = dir.path().join("nonexistent_file.txt");
        let result = safe_file_path(&nonexistent_path, false)?;
        assert_eq!(result, nonexistent_path);
        Ok(())
    }

    /// Symlinks are rejected by default and resolved to their target when allowed.
    #[test]
    #[cfg(unix)]
    fn test_safe_file_path_symlink() -> Result<()> {
        let dir = tempdir()?;
        let target_path = dir.path().join("target_file.txt");
        let symlink_path = dir.path().join("symlink_file.txt");
        let mut file = File::create(&target_path)?;
        file.write_all(b"target file content")?;
        std::os::unix::fs::symlink(&target_path, &symlink_path)?;

        let result = safe_file_path(&symlink_path, false);
        assert!(result.is_err(), "Should reject symlinks when not allowed");

        let result = safe_file_path(&symlink_path, true)?;
        assert_eq!(result, target_path);
        Ok(())
    }

    /// Even when symlinks are allowed, a target outside the allowed roots is rejected.
    #[test]
    #[cfg(unix)]
    fn test_safe_file_path_unsafe_symlink() {
        let dir = tempdir().unwrap();
        let unsafe_target = PathBuf::from("/etc/passwd");
        let unsafe_symlink = dir.path().join("unsafe_symlink.txt");
        std::os::unix::fs::symlink(&unsafe_target, &unsafe_symlink).unwrap();

        let result = safe_file_path(&unsafe_symlink, true);
        assert!(
            result.is_err(),
            "Should reject symlinks to unsafe locations"
        );
    }

    /// A file reachable through a second hard link is rejected.
    #[test]
    #[cfg(unix)]
    fn test_safe_file_path_hardlink() -> Result<()> {
        let dir = tempdir()?;
        let target_path = dir.path().join("target_file.txt");
        let hardlink_path = dir.path().join("hardlink_file.txt");
        let mut file = File::create(&target_path)?;
        file.write_all(b"target file content")?;
        std::fs::hard_link(&target_path, &hardlink_path)?;

        let result = safe_file_path(&hardlink_path, false);
        assert!(
            result.is_err(),
            "Should reject files with multiple hard links"
        );
        Ok(())
    }

    /// `safe_open_file` opens an existing file and its content can be read back.
    #[test]
    fn test_safe_open_file() -> Result<()> {
        let dir = tempdir()?;
        let file_path = dir.path().join("test_open.txt");
        {
            let mut file = File::create(&file_path)?;
            file.write_all(b"test content")?;
        }

        let mut file = safe_open_file(&file_path, false)?;
        let mut content = String::new();
        file.read_to_string(&mut content)?;
        assert_eq!(content, "test content");
        Ok(())
    }

    /// `safe_create_file` creates a new file that can be written to.
    #[test]
    fn test_safe_create_file() -> Result<()> {
        let dir = tempdir()?;
        let file_path = dir.path().join("test_create.txt");
        {
            let mut file = safe_create_file(&file_path, false)?;
            file.write_all(b"created content")?;
        }

        let mut content = String::new();
        let mut file = File::open(&file_path)?;
        file.read_to_string(&mut content)?;
        assert_eq!(content, "created content");
        Ok(())
    }

    /// The options returned by `safe_open_options` can be configured and used to open a file.
    #[test]
    fn test_safe_open_options() -> Result<()> {
        let dir = tempdir()?;
        let file_path = dir.path().join("test_options.txt");
        {
            let mut options = safe_open_options(&file_path, false)?;
            let mut file = options
                .write(true)
                .create(true)
                .truncate(true)
                .open(&file_path)?;
            file.write_all(b"options content")?;
        }

        let mut content = String::new();
        let mut file = File::open(&file_path)?;
        file.read_to_string(&mut content)?;
        assert_eq!(content, "options content");
        Ok(())
    }

    /// Opening a missing file surfaces the underlying I/O error.
    #[test]
    fn test_safe_open_file_nonexistent() {
        // A fresh temporary directory guarantees the path is absent without touching (or
        // deleting) anything at a fixed, shared location.
        let dir = tempdir().unwrap();
        let nonexistent_path = dir.path().join("this_file_should_not_exist.txt");

        match safe_open_file(&nonexistent_path, false) {
            Err(Error::Io(_)) => {}
            Err(e) => panic!("Unexpected error type: {e:?}"),
            Ok(_) => panic!("Opening a non-existent file should fail"),
        }
    }

    /// Exercises the real `is_safe_symlink_target` predicate.
    #[test]
    fn test_is_safe_symlink_target() {
        let tmp_dir = tempdir().unwrap();
        assert!(
            is_safe_symlink_target(tmp_dir.path()),
            "Temporary directory should be considered safe"
        );
        assert!(
            !is_safe_symlink_target(Path::new("/etc/passwd")),
            "/etc/passwd should not be considered safe"
        );
        assert!(
            !is_safe_symlink_target(Path::new("/home/user/file.txt")),
            "/home/user/file.txt should not be considered safe"
        );
    }

    /// Walks `safe_open_file` through missing, present, and malformed paths.
    #[test]
    fn test_safe_open_file_comprehensive() -> Result<()> {
        let dir = tempdir()?;
        let file_path = dir.path().join("comprehensive_test.txt");

        let result = safe_open_file(&file_path, false);
        assert!(result.is_err(), "Opening non-existent file should fail");

        {
            let mut file = File::create(&file_path)?;
            file.write_all(b"comprehensive test")?;
        }
        let mut file = safe_open_file(&file_path, false)?;
        let mut content = String::new();
        file.read_to_string(&mut content)?;
        assert_eq!(content, "comprehensive test");

        // A path containing an interior NUL byte cannot be opened.
        let invalid_path = PathBuf::from("\0invalid");
        let result = safe_open_file(&invalid_path, false);
        assert!(
            result.is_err(),
            "Opening file with invalid path should fail"
        );
        Ok(())
    }

    /// `safe_create_file` on an existing file replaces its content.
    #[test]
    fn test_safe_create_file_existing() -> Result<()> {
        let dir = tempdir()?;
        let file_path = dir.path().join("existing.txt");
        {
            let mut file = File::create(&file_path)?;
            file.write_all(b"initial content")?;
        }
        {
            let mut file = safe_create_file(&file_path, false)?;
            file.write_all(b"overwritten content")?;
        }

        let mut content = String::new();
        let mut file = File::open(&file_path)?;
        file.read_to_string(&mut content)?;
        assert_eq!(content, "overwritten content");
        Ok(())
    }
}