use super::PathResolver;
use crate::error::{ToolError, ToolResult};
use std::path::{Path, PathBuf};
use std::sync::Arc;
/// A [`PathResolver`] that only resolves paths lying inside a fixed set of
/// base directories.
///
/// The bases live in an `Arc<[PathBuf]>`, so cloning a resolver shares the
/// list instead of copying it.
#[derive(Debug, Clone)]
pub struct AllowedPathResolver {
    // Base directories, tried in order during resolution. Resolution compares
    // canonicalized paths against these by prefix, so they are expected to be
    // canonical already: `new` canonicalizes them, `from_canonical` trusts the
    // caller.
    allowed_paths: Arc<[PathBuf]>,
}
impl AllowedPathResolver {
    /// Builds a resolver from a list of base directories, canonicalizing each.
    ///
    /// Canonicalization requires every entry to exist on disk; relative entries
    /// are resolved against the current working directory.
    ///
    /// # Errors
    ///
    /// Returns [`ToolError::InvalidPath`] for the first entry that cannot be
    /// canonicalized. The message names that entry and the underlying I/O
    /// error. Later entries are not examined.
    pub fn new(allowed_paths: impl IntoIterator<Item = impl AsRef<Path>>) -> ToolResult<Self> {
        let mut resolved = Vec::new();
        for entry in allowed_paths {
            let raw = entry.as_ref();
            let absolute = raw.canonicalize().map_err(|err| {
                ToolError::InvalidPath(format!(
                    "failed to canonicalize allowed path '{}': {}",
                    raw.display(),
                    err
                ))
            })?;
            resolved.push(absolute);
        }
        Ok(Self {
            allowed_paths: Arc::from(resolved),
        })
    }

    /// Builds a resolver from base directories the caller guarantees are
    /// already canonical.
    ///
    /// No filesystem access or validation happens here; each entry is stored
    /// exactly as given. Resolution checks canonicalized paths against these
    /// entries by prefix, so a non-canonical entry (relative, containing `..`,
    /// or routed through a symlink) may never match anything.
    pub fn from_canonical(allowed_paths: impl IntoIterator<Item = impl AsRef<Path>>) -> Self {
        let stored: Vec<PathBuf> = allowed_paths
            .into_iter()
            .map(|entry| PathBuf::from(entry.as_ref()))
            .collect();
        Self {
            allowed_paths: Arc::from(stored),
        }
    }

    /// Returns the configured base directories, in the order they are tried.
    pub fn allowed_paths(&self) -> &[PathBuf] {
        &self.allowed_paths[..]
    }
}
impl PathResolver for AllowedPathResolver {
fn resolve(&self, path: &str) -> ToolResult<PathBuf> {
let input_path = PathBuf::from(path);
for base in self.allowed_paths.iter() {
let candidate = base.join(&input_path);
if let Ok(canonical) = candidate.canonicalize() {
if canonical.starts_with(base) {
return Ok(canonical);
}
continue;
}
if let Some(parent) = candidate.parent() {
if let Ok(canonical_parent) = parent.canonicalize() {
if canonical_parent.starts_with(base) {
let file_name = candidate.file_name().ok_or_else(|| {
ToolError::InvalidPath("path has no file name".into())
})?;
return Ok(canonical_parent.join(file_name));
}
}
}
}
Err(ToolError::InvalidPath(format!(
"path '{}' is not within allowed directories",
path
)))
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::TempDir;

    /// Creates a temp dir holding `file.txt` and `subdir/nested.txt`.
    fn setup_test_dir() -> TempDir {
        let dir = TempDir::new().unwrap();
        fs::create_dir_all(dir.path().join("subdir")).unwrap();
        fs::write(dir.path().join("file.txt"), "content").unwrap();
        fs::write(dir.path().join("subdir/nested.txt"), "nested").unwrap();
        dir
    }

    /// Canonical root of `dir`. The resolver returns canonical paths, and the
    /// temp dir itself may sit behind a symlink (e.g. `/var` -> `/private/var`
    /// on macOS), so expected values are built from this, not `dir.path()`.
    fn canonical_root(dir: &TempDir) -> PathBuf {
        dir.path().canonicalize().unwrap()
    }

    /// Resolver whose only allowed base is `dir`.
    fn resolver_for(dir: &TempDir) -> AllowedPathResolver {
        AllowedPathResolver::new(vec![dir.path().to_path_buf()]).unwrap()
    }

    #[test]
    fn resolves_relative_path_in_allowed_dir() {
        let dir = setup_test_dir();
        let resolved = resolver_for(&dir).resolve("file.txt").unwrap();
        assert_eq!(resolved, canonical_root(&dir).join("file.txt"));
    }

    #[test]
    fn resolves_nested_path() {
        let dir = setup_test_dir();
        let resolved = resolver_for(&dir).resolve("subdir/nested.txt").unwrap();
        assert_eq!(
            resolved,
            canonical_root(&dir).join("subdir").join("nested.txt")
        );
    }

    #[test]
    fn rejects_path_traversal() {
        let dir = setup_test_dir();
        let result = resolver_for(&dir).resolve("../../../etc/passwd");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("not within allowed"));
    }

    #[test]
    fn allows_non_existent_path_for_write() {
        let dir = setup_test_dir();
        let resolved = resolver_for(&dir).resolve("new_file.txt").unwrap();
        assert_eq!(resolved, canonical_root(&dir).join("new_file.txt"));
    }

    #[test]
    fn allows_nested_non_existent_path() {
        let dir = setup_test_dir();
        let resolved = resolver_for(&dir).resolve("subdir/new_file.txt").unwrap();
        assert_eq!(
            resolved,
            canonical_root(&dir).join("subdir").join("new_file.txt")
        );
    }

    #[test]
    fn rejects_non_existent_path_outside_allowed() {
        let dir = setup_test_dir();
        let result = resolver_for(&dir).resolve("subdir/../../../new_file.txt");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("not within allowed"));
    }

    #[test]
    fn rejects_new_file_in_missing_directory() {
        let dir = setup_test_dir();
        // Only the final component may be missing; an absent parent is rejected.
        assert!(resolver_for(&dir).resolve("missing_dir/new_file.txt").is_err());
    }

    #[test]
    fn accepts_absolute_path_inside_allowed_dir() {
        let dir = setup_test_dir();
        let absolute = canonical_root(&dir).join("file.txt");
        let resolved = resolver_for(&dir)
            .resolve(absolute.to_str().unwrap())
            .unwrap();
        assert_eq!(resolved, absolute);
    }

    #[test]
    fn rejects_absolute_path_outside_allowed_dir() {
        let dir = setup_test_dir();
        let other = setup_test_dir();
        let outside = canonical_root(&other).join("file.txt");
        assert!(resolver_for(&dir)
            .resolve(outside.to_str().unwrap())
            .is_err());
    }

    #[test]
    fn tries_multiple_allowed_paths() {
        let dir1 = setup_test_dir();
        let dir2 = setup_test_dir();
        fs::write(dir2.path().join("only_in_dir2.txt"), "content").unwrap();
        let resolver =
            AllowedPathResolver::new(vec![dir1.path().to_path_buf(), dir2.path().to_path_buf()])
                .unwrap();
        let resolved = resolver.resolve("only_in_dir2.txt").unwrap();
        // NOTE(review): only the file name is pinned here; which base the
        // result lands in is not asserted by this test.
        assert!(resolved.ends_with("only_in_dir2.txt"));
    }

    #[test]
    fn returns_canonical_path() {
        let dir = setup_test_dir();
        let resolved = resolver_for(&dir).resolve("subdir/../file.txt").unwrap();
        assert!(!resolved.to_string_lossy().contains(".."));
        assert_eq!(resolved, canonical_root(&dir).join("file.txt"));
    }

    #[test]
    fn new_rejects_missing_allowed_path() {
        let dir = setup_test_dir();
        let err = AllowedPathResolver::new(vec![dir.path().join("missing")]).unwrap_err();
        assert!(matches!(err, ToolError::InvalidPath(_)));
    }

    #[test]
    fn from_canonical_stores_paths_verbatim() {
        let dir = setup_test_dir();
        let root = canonical_root(&dir);
        let resolver = AllowedPathResolver::from_canonical(vec![root.clone()]);
        assert_eq!(resolver.allowed_paths(), &[root][..]);
    }
}