use async_trait::async_trait;
use std::path::{Component, Path, PathBuf};
use thiserror::Error;
/// Outcome of a successful [`FileSaver::save`] call.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SaveResult {
    /// Path the bytes were written to, as a (possibly lossily converted) UTF-8 string.
    pub path: String,
    /// Number of bytes written.
    pub bytes_written: u64,
}
/// Errors produced while validating a save path or writing a file.
#[derive(Debug, Error)]
pub enum FileSaveError {
    /// The requested path was rejected by path validation; the payload explains why.
    #[error("Path not allowed: {0}")]
    PathNotAllowed(String),
    /// An underlying filesystem operation failed.
    #[error("IO error: {0}")]
    Io(#[from] std::io::Error),
    /// Any other failure, described by the contained message.
    #[error("Save error: {0}")]
    Other(String),
}
/// Persists raw bytes at a caller-supplied path.
///
/// Implementations must be `Send + Sync`.
#[async_trait]
pub trait FileSaver: Send + Sync {
    /// Writes `bytes` to `path` and reports where they were stored.
    async fn save(&self, path: &str, bytes: &[u8]) -> Result<SaveResult, FileSaveError>;

    /// Checks whether `path` is acceptable without writing anything.
    ///
    /// The default implementation accepts every path; implementations that restrict
    /// where files may be written are expected to override it.
    async fn validate_path(&self, _path: &str) -> Result<(), FileSaveError> {
        Ok(())
    }
}
/// [`FileSaver`] that writes to the local filesystem.
///
/// When `base_dir` is `Some`, every save is confined to that directory; when it is `None`,
/// only absolute paths are accepted. Construct it with [`LocalFileSaver::new`].
#[derive(Debug, Clone)]
pub struct LocalFileSaver {
    base_dir: Option<PathBuf>,
}
impl LocalFileSaver {
pub fn new(base_dir: Option<PathBuf>) -> Self {
Self { base_dir }
}
fn resolve_path(&self, path: &str) -> Result<PathBuf, FileSaveError> {
if path.is_empty() {
return Err(FileSaveError::PathNotAllowed(
"Path must name a file".into(),
));
}
let input = PathBuf::from(path);
if let Some(base) = &self.base_dir {
let joined = if input.is_absolute() {
input
} else {
base.join(&input)
};
let normalized = normalize_path(&joined);
let normalized_base = normalize_path(base);
if !normalized.starts_with(&normalized_base) {
return Err(FileSaveError::PathNotAllowed(format!(
"Path escapes base directory: {}",
path
)));
}
Ok(normalized)
} else {
if !input.is_absolute() {
return Err(FileSaveError::PathNotAllowed(
"Path must be absolute when no base_dir is set".into(),
));
}
Ok(normalize_path(&input))
}
}
async fn canonicalize_base_dir(&self, base: &Path) -> Result<PathBuf, FileSaveError> {
tokio::fs::create_dir_all(base).await?;
let meta = tokio::fs::symlink_metadata(base).await?;
if meta.file_type().is_symlink() {
return Err(FileSaveError::PathNotAllowed(
"Base directory must not be a symlink".into(),
));
}
if !meta.is_dir() {
return Err(FileSaveError::PathNotAllowed(
"Base directory must be a directory".into(),
));
}
Ok(tokio::fs::canonicalize(base).await?)
}
async fn prepare_parent_dir(&self, resolved: &Path) -> Result<PathBuf, FileSaveError> {
let Some(base) = &self.base_dir else {
return Ok(resolved
.parent()
.ok_or_else(|| FileSaveError::PathNotAllowed("Path must name a file".into()))?
.to_path_buf());
};
let normalized_base = normalize_path(base);
let relative = resolved
.strip_prefix(&normalized_base)
.map_err(|_| FileSaveError::PathNotAllowed("Path escapes base directory".into()))?;
let canonical_base = self.canonicalize_base_dir(base).await?;
let mut current = canonical_base.clone();
for component in relative
.parent()
.unwrap_or_else(|| Path::new(""))
.components()
{
let Component::Normal(name) = component else {
return Err(FileSaveError::PathNotAllowed(format!(
"Unsupported path component in save path: {}",
resolved.display()
)));
};
let candidate = current.join(name);
let meta = match tokio::fs::symlink_metadata(&candidate).await {
Ok(meta) => meta,
Err(err) if err.kind() == std::io::ErrorKind::NotFound => {
if let Err(create_err) = tokio::fs::create_dir(&candidate).await {
if create_err.kind() != std::io::ErrorKind::AlreadyExists {
return Err(create_err.into());
}
}
tokio::fs::symlink_metadata(&candidate).await?
}
Err(err) => return Err(err.into()),
};
if meta.file_type().is_symlink() {
return Err(FileSaveError::PathNotAllowed(format!(
"Path traverses symlink: {}",
candidate.display()
)));
}
if !meta.is_dir() {
return Err(FileSaveError::PathNotAllowed(format!(
"Parent path is not a directory: {}",
candidate.display()
)));
}
let canonical_candidate = tokio::fs::canonicalize(&candidate).await?;
if !canonical_candidate.starts_with(&canonical_base) {
return Err(FileSaveError::PathNotAllowed(format!(
"Path escapes base directory via symlink: {}",
candidate.display()
)));
}
current = canonical_candidate;
}
Ok(current)
}
}
#[async_trait]
impl FileSaver for LocalFileSaver {
    /// Writes `bytes` to `path`, creating missing parent directories.
    ///
    /// An existing regular file at the target is overwritten. Under a base directory, every
    /// parent directory is vetted by `prepare_parent_dir`. In all modes, writing through a
    /// symlink at the final path is refused.
    ///
    /// # Errors
    ///
    /// [`FileSaveError::PathNotAllowed`] when the path is rejected; [`FileSaveError::Io`] when a
    /// filesystem operation fails.
    async fn save(&self, path: &str, bytes: &[u8]) -> Result<SaveResult, FileSaveError> {
        let resolved = self.resolve_path(path)?;
        if let Some(base_dir) = &self.base_dir {
            if resolved == normalize_path(base_dir) {
                return Err(FileSaveError::PathNotAllowed(
                    "Path must name a file".into(),
                ));
            }
        }
        let file_name = resolved
            .file_name()
            .ok_or_else(|| FileSaveError::PathNotAllowed("Path must name a file".into()))?;
        let parent_dir = self.prepare_parent_dir(&resolved).await?;
        let final_path = parent_dir.join(file_name);
        if self.base_dir.is_none() {
            // Base-confined saves already created (and vetted) every parent in
            // `prepare_parent_dir`; unconfined saves create the parents here.
            if let Some(parent) = final_path.parent() {
                tokio::fs::create_dir_all(parent).await?;
            }
        }
        // NOTE(review): this check and the `write` below are separate operations, so a symlink
        // created in between would still be followed by `write`.
        match tokio::fs::symlink_metadata(&final_path).await {
            Ok(meta) if meta.file_type().is_symlink() => {
                return Err(FileSaveError::PathNotAllowed(format!(
                    "Refusing to write through symlink: {}",
                    final_path.display()
                )));
            }
            Ok(_) => {}
            Err(err) if err.kind() == std::io::ErrorKind::NotFound => {}
            Err(err) => return Err(err.into()),
        }
        tokio::fs::write(&final_path, bytes).await?;
        Ok(SaveResult {
            path: final_path.to_string_lossy().into_owned(),
            bytes_written: bytes.len() as u64,
        })
    }

    /// Checks `path` against the same lexical rules `save` applies before touching the
    /// filesystem, without creating directories or writing anything.
    ///
    /// Filesystem-dependent rejections made later by `save` (symlinks, non-directory parents)
    /// are not detected here.
    async fn validate_path(&self, path: &str) -> Result<(), FileSaveError> {
        let resolved = self.resolve_path(path)?;
        // Mirror `save`'s pre-I/O checks so validation never approves a path that names the
        // base directory itself or has no file name (e.g. "/").
        if let Some(base_dir) = &self.base_dir {
            if resolved == normalize_path(base_dir) {
                return Err(FileSaveError::PathNotAllowed(
                    "Path must name a file".into(),
                ));
            }
        }
        if resolved.file_name().is_none() {
            return Err(FileSaveError::PathNotAllowed(
                "Path must name a file".into(),
            ));
        }
        Ok(())
    }
}
/// Lexically normalizes `path` without touching the filesystem.
///
/// - `.` components are dropped.
/// - `..` removes the preceding normal component.
/// - `..` directly after the root is dropped, since the root has no parent (`/..` is `/`).
/// - Leading `..` in a relative path, or after a drive prefix without a root, is kept.
///
/// Symlinks are not resolved, so the result can differ from what the OS would open.
fn normalize_path(path: &Path) -> PathBuf {
    let mut stack: Vec<Component<'_>> = Vec::new();
    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => match stack.last() {
                Some(Component::Normal(_)) => {
                    stack.pop();
                }
                // Going "up" from the root stays at the root.
                Some(Component::RootDir) => {}
                _ => stack.push(component),
            },
            other => stack.push(other),
        }
    }
    stack.iter().collect()
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    #[test]
    fn test_normalize_path() {
        assert_eq!(
            normalize_path(Path::new("/a/b/../c")),
            PathBuf::from("/a/c")
        );
        assert_eq!(
            normalize_path(Path::new("/a/b/./c")),
            PathBuf::from("/a/b/c")
        );
        assert_eq!(
            normalize_path(Path::new("/a/b/c/../..")),
            PathBuf::from("/a")
        );
    }
    #[test]
    fn test_local_file_saver_resolve_relative() {
        let saver = LocalFileSaver::new(Some(PathBuf::from("/tmp/downloads")));
        let resolved = saver.resolve_path("file.txt").unwrap();
        assert_eq!(resolved, PathBuf::from("/tmp/downloads/file.txt"));
    }
    #[test]
    fn test_local_file_saver_resolve_subdirectory() {
        let saver = LocalFileSaver::new(Some(PathBuf::from("/tmp/downloads")));
        let resolved = saver.resolve_path("sub/dir/file.txt").unwrap();
        assert_eq!(resolved, PathBuf::from("/tmp/downloads/sub/dir/file.txt"));
    }
    #[test]
    fn test_local_file_saver_reject_traversal() {
        let saver = LocalFileSaver::new(Some(PathBuf::from("/tmp/downloads")));
        let result = saver.resolve_path("../../etc/passwd");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            FileSaveError::PathNotAllowed(_)
        ));
    }
    #[test]
    fn test_local_file_saver_reject_traversal_absolute() {
        let saver = LocalFileSaver::new(Some(PathBuf::from("/tmp/downloads")));
        let result = saver.resolve_path("/etc/passwd");
        assert!(result.is_err());
    }
    #[test]
    fn test_local_file_saver_no_base_requires_absolute() {
        let saver = LocalFileSaver::new(None);
        let result = saver.resolve_path("relative.txt");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            FileSaveError::PathNotAllowed(_)
        ));
    }
    #[test]
    fn test_local_file_saver_no_base_accepts_absolute() {
        let saver = LocalFileSaver::new(None);
        let resolved = saver.resolve_path("/tmp/file.txt").unwrap();
        assert_eq!(resolved, PathBuf::from("/tmp/file.txt"));
    }
    #[tokio::test]
    async fn test_local_file_saver_save_and_read() {
        let dir = tempfile::tempdir().unwrap();
        let saver = LocalFileSaver::new(Some(dir.path().to_path_buf()));
        let result = saver.save("test.txt", b"hello world").await.unwrap();
        assert_eq!(result.bytes_written, 11);
        assert!(result.path.ends_with("test.txt"));
        let content = tokio::fs::read_to_string(dir.path().join("test.txt"))
            .await
            .unwrap();
        assert_eq!(content, "hello world");
    }
    #[tokio::test]
    async fn test_local_file_saver_creates_parent_dirs() {
        let dir = tempfile::tempdir().unwrap();
        let saver = LocalFileSaver::new(Some(dir.path().to_path_buf()));
        let result = saver
            .save("sub/dir/file.bin", &[0xFF, 0x00, 0xAB])
            .await
            .unwrap();
        assert_eq!(result.bytes_written, 3);
        let content = tokio::fs::read(dir.path().join("sub/dir/file.bin"))
            .await
            .unwrap();
        assert_eq!(content, vec![0xFF, 0x00, 0xAB]);
    }
    #[tokio::test]
    async fn test_local_file_saver_validate_path() {
        let dir = tempfile::tempdir().unwrap();
        let saver = LocalFileSaver::new(Some(dir.path().to_path_buf()));
        assert!(saver.validate_path("safe.txt").await.is_ok());
        assert!(saver.validate_path("sub/dir/safe.txt").await.is_ok());
        assert!(saver.validate_path("../../escape.txt").await.is_err());
    }
    // Paths that normalize to the base directory itself must never be written.
    #[tokio::test]
    async fn test_local_file_saver_save_rejects_base_dir_itself() {
        let dir = tempfile::tempdir().unwrap();
        let saver = LocalFileSaver::new(Some(dir.path().to_path_buf()));
        for path in [".", "sub/.."].iter() {
            let err = saver.save(path, b"x").await.unwrap_err();
            assert!(
                matches!(err, FileSaveError::PathNotAllowed(_)),
                "{}: {}",
                path,
                err
            );
        }
    }
    #[tokio::test]
    async fn test_local_file_saver_overwrites_existing_file() {
        let dir = tempfile::tempdir().unwrap();
        let saver = LocalFileSaver::new(Some(dir.path().to_path_buf()));
        saver.save("f.txt", b"first version").await.unwrap();
        let result = saver.save("f.txt", b"second").await.unwrap();
        assert_eq!(result.bytes_written, 6);
        let content = tokio::fs::read_to_string(dir.path().join("f.txt"))
            .await
            .unwrap();
        assert_eq!(content, "second");
    }
    #[tokio::test]
    async fn test_local_file_saver_no_base_saves_absolute() {
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("nested").join("out.bin");
        let saver = LocalFileSaver::new(None);
        let result = saver
            .save(target.to_str().unwrap(), b"abc")
            .await
            .unwrap();
        assert_eq!(result.bytes_written, 3);
        assert_eq!(tokio::fs::read(&target).await.unwrap(), b"abc");
    }
    // A symlink at the final path must not be followed, even if it points outside the base.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_local_file_saver_refuses_symlinked_file() {
        let dir = tempfile::tempdir().unwrap();
        let outside = tempfile::tempdir().unwrap();
        let target = outside.path().join("target.txt");
        std::fs::write(&target, b"original").unwrap();
        std::os::unix::fs::symlink(&target, dir.path().join("link.txt")).unwrap();
        let saver = LocalFileSaver::new(Some(dir.path().to_path_buf()));
        let err = saver.save("link.txt", b"pwned").await.unwrap_err();
        assert!(matches!(err, FileSaveError::PathNotAllowed(_)));
        assert_eq!(std::fs::read(&target).unwrap(), b"original");
    }
    // A symlinked directory inside the base must not be traversed.
    #[cfg(unix)]
    #[tokio::test]
    async fn test_local_file_saver_refuses_symlinked_directory() {
        let dir = tempfile::tempdir().unwrap();
        let outside = tempfile::tempdir().unwrap();
        std::os::unix::fs::symlink(outside.path(), dir.path().join("linkdir")).unwrap();
        let saver = LocalFileSaver::new(Some(dir.path().to_path_buf()));
        let err = saver.save("linkdir/file.txt", b"data").await.unwrap_err();
        assert!(matches!(err, FileSaveError::PathNotAllowed(_)));
        assert!(!outside.path().join("file.txt").exists());
    }
}