use thiserror::Error;
/// Errors produced while looking up, validating, or executing a tool.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute
/// (derived via `thiserror`), so these strings are user-visible.
#[derive(Debug, Error)]
pub enum ToolError {
/// A tool with the given name could not be found.
#[error("Tool not found: {0}")]
NotFound(String),
/// The input supplied to a tool was rejected; the string explains why.
#[error("Invalid input: {0}")]
InvalidInput(String),
/// The tool failed while running; the string describes the failure.
#[error("Execution error: {0}")]
Execution(String),
/// The requested operation is not permitted.
#[error("Permission denied: {0}")]
PermissionDenied(String),
/// The tool cannot proceed without approval for the described action.
#[error("Tool requires approval: {0}")]
ApprovalRequired(String),
/// A file referenced by the tool does not exist.
#[error("File not found: {0}")]
FileNotFound(String),
/// Wraps a `std::io::Error`; `#[from]` generates the `From` impl so `?`
/// converts I/O errors automatically.
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
/// Wraps a `serde_json::Error`; `#[from]` generates the `From` impl so `?`
/// converts JSON errors automatically.
#[error("JSON error: {0}")]
Json(#[from] serde_json::Error),
/// The tool did not finish within the given number of seconds.
#[error("Tool execution timed out after {0}s")]
Timeout(u64),
/// Internal failure with a free-form description.
#[error("Internal error: {0}")]
Internal(String),
}
/// Convenience alias for results whose error type is [`ToolError`].
pub type Result<T> = std::result::Result<T, ToolError>;
pub fn expand_tilde(path: &str) -> std::path::PathBuf {
use std::path::PathBuf;
if let Some(rest) = path.strip_prefix("~/")
&& let Some(home) = dirs::home_dir()
{
return home.join(rest);
}
if path == "~"
&& let Some(home) = dirs::home_dir()
{
return home;
}
PathBuf::from(path)
}
/// Turns a path supplied to a tool into a concrete filesystem path.
///
/// Tilde shorthand is expanded first (see `expand_tilde`). A path that is
/// still relative afterwards is interpreted relative to `working_directory`;
/// an absolute one is returned as-is. No normalisation (e.g. of `..`) is done.
pub fn resolve_tool_path(
    requested_path: &str,
    working_directory: &std::path::Path,
) -> std::path::PathBuf {
    let candidate = expand_tilde(requested_path);
    if candidate.is_relative() {
        return working_directory.join(candidate);
    }
    candidate
}
/// Resolves `requested_path` against `working_directory` and checks that it
/// either already exists or could be created (i.e. its parent directory exists).
///
/// NOTE: despite the name, no containment/sandbox check is performed —
/// absolute paths, `~` expansion and `..` components are all accepted and may
/// point outside `working_directory`.
///
/// # Errors
///
/// Returns [`ToolError::InvalidInput`] when the resolved path does not exist
/// and either has no parent component or its parent directory does not exist.
pub fn validate_path_safety(
    requested_path: &str,
    working_directory: &std::path::Path,
) -> Result<std::path::PathBuf> {
    let path = resolve_tool_path(requested_path, working_directory);
    if path.exists() {
        return Ok(path);
    }

    let parent = path
        .parent()
        .ok_or_else(|| ToolError::InvalidInput("Invalid path: no parent directory".into()))?;
    // `Path::parent` yields "" for a bare relative name such as "notes.txt"
    // (possible when `working_directory` is itself empty). "" never exists on
    // disk, so check the current directory instead of rejecting the path.
    let parent = if parent.as_os_str().is_empty() {
        std::path::Path::new(".")
    } else {
        parent
    };
    if !parent.exists() {
        return Err(ToolError::InvalidInput(format!(
            "Parent directory does not exist: {}",
            parent.display()
        )));
    }
    Ok(path)
}
/// Resolves `requested_path` (see `validate_path_safety`) and requires that it
/// names an existing regular file.
///
/// Errors are returned as human-readable strings: path-validation failures are
/// prefixed with `"Invalid path: "` (for `InvalidInput`) or
/// `"Path validation failed: "` (anything else), followed by
/// `"File not found: …"` or `"Path is not a file: …"` for the existence checks.
///
/// NOTE(review): `validate_path_safety` can emit an `InvalidInput` message that
/// itself starts with "Invalid path:", so that prefix may appear twice.
pub fn validate_file_path(
    requested_path: &str,
    working_directory: &std::path::Path,
) -> std::result::Result<std::path::PathBuf, String> {
    let resolved =
        validate_path_safety(requested_path, working_directory).map_err(|err| match err {
            ToolError::InvalidInput(msg) => format!("Invalid path: {}", msg),
            other => format!("Path validation failed: {}", other),
        })?;

    if !resolved.exists() {
        Err(format!("File not found: {}", resolved.display()))
    } else if !resolved.is_file() {
        Err(format!("Path is not a file: {}", resolved.display()))
    } else {
        Ok(resolved)
    }
}
/// Resolves `requested_path` (see `validate_path_safety`) and requires that it
/// names an existing directory.
///
/// Errors are returned as human-readable strings: path-validation failures are
/// prefixed with `"Invalid path: "` (for `InvalidInput`) or
/// `"Path validation failed: "` (anything else), followed by
/// `"Directory not found: …"` or `"Path is not a directory: …"` for the
/// existence checks.
pub fn validate_directory_path(
    requested_path: &str,
    working_directory: &std::path::Path,
) -> std::result::Result<std::path::PathBuf, String> {
    let dir = match validate_path_safety(requested_path, working_directory) {
        Err(ToolError::InvalidInput(reason)) => return Err(format!("Invalid path: {}", reason)),
        Err(failure) => return Err(format!("Path validation failed: {}", failure)),
        Ok(dir) => dir,
    };

    if !dir.exists() {
        Err(format!("Directory not found: {}", dir.display()))
    } else if dir.is_dir() {
        Ok(dir)
    } else {
        Err(format!("Path is not a directory: {}", dir.display()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::{Path, PathBuf};

    /// A file/directory name that almost certainly does not exist in the temp dir.
    fn missing_name() -> String {
        format!("opencrabs-missing-{}", std::process::id())
    }

    #[test]
    fn test_tool_error_display() {
        let err = ToolError::NotFound("test_tool".to_string());
        assert_eq!(err.to_string(), "Tool not found: test_tool");
        let err = ToolError::PermissionDenied("dangerous_operation".to_string());
        assert_eq!(err.to_string(), "Permission denied: dangerous_operation");
    }

    #[test]
    fn test_expand_tilde_prefix() {
        let home = dirs::home_dir().expect("home dir required for this test");
        assert_eq!(expand_tilde("~/foo/bar"), home.join("foo/bar"));
        assert_eq!(expand_tilde("~"), home);
    }

    #[test]
    fn test_expand_tilde_passthrough() {
        assert_eq!(
            expand_tilde("/tmp/~backup").to_string_lossy(),
            "/tmp/~backup"
        );
        assert_eq!(expand_tilde("foo/bar").to_string_lossy(), "foo/bar");
        assert_eq!(expand_tilde("/abs/path").to_string_lossy(), "/abs/path");
        // `~user` forms are not expanded.
        assert_eq!(expand_tilde("~user/x").to_string_lossy(), "~user/x");
    }

    #[test]
    fn test_resolve_tool_path_tilde_becomes_absolute() {
        // The working directory must be ignored once `~` has been expanded.
        let cwd = Path::new("/tmp/project");
        let resolved = resolve_tool_path("~/.opencrabs/logs", cwd);
        let home = dirs::home_dir().expect("home dir required");
        assert_eq!(resolved, home.join(".opencrabs/logs"));
        assert!(resolved.is_absolute());
    }

    #[test]
    fn test_resolve_tool_path_relative_joins_cwd() {
        let cwd = Path::new("/tmp/project");
        assert_eq!(
            resolve_tool_path("src/main.rs", cwd),
            PathBuf::from("/tmp/project/src/main.rs"),
        );
    }

    #[test]
    fn test_resolve_tool_path_absolute_passthrough() {
        let cwd = Path::new("/tmp/project");
        assert_eq!(
            resolve_tool_path("/etc/hosts", cwd),
            PathBuf::from("/etc/hosts"),
        );
    }

    #[test]
    fn test_validate_path_safety_allows_new_file_in_existing_dir() {
        let tmp = std::env::temp_dir();
        let name = missing_name();
        let resolved = validate_path_safety(&name, &tmp).expect("parent directory exists");
        assert_eq!(resolved, tmp.join(&name));
    }

    #[test]
    fn test_validate_path_safety_rejects_missing_parent() {
        let tmp = std::env::temp_dir();
        let requested = format!("{}/child.txt", missing_name());
        let result = validate_path_safety(&requested, &tmp);
        assert!(matches!(result, Err(ToolError::InvalidInput(_))));
    }

    #[test]
    fn test_validate_directory_path_accepts_existing_dir() {
        let tmp = std::env::temp_dir();
        let requested = tmp.to_str().expect("temp dir should be valid UTF-8");
        let resolved =
            validate_directory_path(requested, Path::new("/")).expect("temp dir is a directory");
        assert_eq!(resolved, tmp);
    }

    #[test]
    fn test_validate_file_path_rejects_directory() {
        let tmp = std::env::temp_dir();
        let requested = tmp.to_str().expect("temp dir should be valid UTF-8");
        let err = validate_file_path(requested, Path::new("/")).unwrap_err();
        assert!(err.starts_with("Path is not a file: "), "unexpected error: {err}");
    }

    #[test]
    fn test_validate_file_path_reports_missing_file() {
        let tmp = std::env::temp_dir();
        let err = validate_file_path(&missing_name(), &tmp).unwrap_err();
        assert!(err.starts_with("File not found: "), "unexpected error: {err}");
    }
}