use std::{collections::HashMap, path::Path};
use anyhow::Result;
use async_trait::async_trait;
use serde_json::Value;
use crate::agent::types::UndoAction;
/// A named, asynchronously executed capability that can be stored in a [`ToolRegistry`].
///
/// The `Send + Sync` supertraits allow implementations to be held as `Box<dyn Tool>` and
/// shared across threads.
#[async_trait]
pub trait Tool: Send + Sync {
    /// The key under which [`ToolRegistry::register`] stores this tool and by which
    /// [`ToolRegistry::execute`] looks it up.
    fn name(&self) -> &str;

    /// Runs the tool once.
    ///
    /// * `args` — named JSON arguments for this invocation.
    /// * `undo_stack` — mutable stack of [`UndoAction`]s; presumably tools push entries that
    ///   can later revert their effects (confirm against `crate::agent::types`).
    /// * `cwd` — optional working directory supplied by the caller; `None` when absent.
    ///
    /// Returns the tool's textual output, or an error.
    async fn execute(
        &self,
        args: &HashMap<String, Value>,
        undo_stack: &mut Vec<UndoAction>,
        cwd: Option<&Path>,
    ) -> Result<String>;
}
/// A name-indexed collection of [`Tool`] implementations.
pub struct ToolRegistry {
    /// Registered tools, keyed by the string each tool returns from [`Tool::name`].
    tools: HashMap<String, Box<dyn Tool>>,
}

impl Default for ToolRegistry {
    /// Same as [`ToolRegistry::new`]: a registry with no tools.
    fn default() -> Self {
        Self::new()
    }
}

impl ToolRegistry {
    /// Creates an empty registry.
    pub fn new() -> Self {
        let tools = HashMap::new();
        Self { tools }
    }

    /// Stores `tool` under the name it reports.
    ///
    /// A previously registered tool with the same name is replaced and dropped.
    pub fn register(&mut self, tool: Box<dyn Tool>) {
        let key = String::from(tool.name());
        self.tools.insert(key, tool);
    }

    /// Looks up the tool registered as `name` and forwards the call to its
    /// [`Tool::execute`].
    ///
    /// # Errors
    ///
    /// Fails if no tool is registered under `name`. Otherwise it returns whatever the
    /// tool's own `execute` returns.
    pub async fn execute(
        &self,
        name: &str,
        args: &HashMap<String, Value>,
        undo_stack: &mut Vec<UndoAction>,
        cwd: Option<&Path>,
    ) -> Result<String> {
        let Some(tool) = self.tools.get(name) else {
            anyhow::bail!("Tool '{}' not found", name);
        };
        tool.execute(args, undo_stack, cwd).await
    }
}
/// Process-wide override for the sandbox-root check in `validate_path`.
///
/// While this is `true`, `validate_path` accepts paths that resolve outside the root.
/// `PathTraversalGuard` sets it for the lifetime of a guard.
pub static ALLOW_PATH_TRAVERSAL: std::sync::atomic::AtomicBool =
std::sync::atomic::AtomicBool::new(false);
/// Sandbox root used by `validate_path`.
///
/// It is set at most once, by `init_startup_dir`. While unset, `validate_path` falls back to
/// the canonicalized current working directory.
pub static STARTUP_DIR: std::sync::OnceLock<std::path::PathBuf> = std::sync::OnceLock::new();
pub fn init_startup_dir() {
if let Ok(dir) = std::env::current_dir() {
if let Ok(canonical) = std::fs::canonicalize(&dir) {
let _ = STARTUP_DIR.set(canonical);
} else {
let _ = STARTUP_DIR.set(dir);
}
}
}
/// Converts a Windows verbatim (`\\?\`) path, as produced by `std::fs::canonicalize`, back
/// into its ordinary Win32 spelling. On other platforms the path is returned unchanged.
///
/// On Windows:
/// * `\\?\C:\dir\file` becomes `C:\dir\file`.
/// * `\\?\UNC\server\share\file` becomes `\\server\share\file`. Simply dropping `\\?\`
///   here would leave the relative path `UNC\server\share\file`.
/// * Other verbatim forms (e.g. `\\?\Volume{…}\`) have no plain spelling and are returned
///   unchanged, as are paths that are not valid Unicode. A lossy text round-trip could turn
///   those into a different path.
pub fn strip_unc_prefix(path: &std::path::Path) -> std::path::PathBuf {
    #[cfg(windows)]
    {
        if let Some(text) = path.to_str() {
            if let Some(rest) = text.strip_prefix(r"\\?\UNC\") {
                return std::path::PathBuf::from(format!(r"\\{rest}"));
            }
            if let Some(rest) = text.strip_prefix(r"\\?\") {
                // Only drive-letter paths ("C:...") have an equivalent non-verbatim form.
                if matches!(rest.as_bytes(), [drive, b':', ..] if drive.is_ascii_alphabetic()) {
                    return std::path::PathBuf::from(rest);
                }
            }
        }
        path.to_path_buf()
    }
    #[cfg(not(windows))]
    {
        path.to_path_buf()
    }
}
/// RAII guard that enables [`ALLOW_PATH_TRAVERSAL`] for its lifetime.
///
/// When created with `active == true`, the guard sets the flag and remembers the value the
/// flag had before. On drop it restores that value instead of forcing `false`. This keeps an
/// outer guard, or an override that was switched on globally, in effect after an inner guard
/// goes away. With `active == false` the guard never touches the flag.
///
/// The flag is a single process-wide static. Restoration is exact only when guards are dropped
/// in reverse creation order; guards that overlap on different threads or tasks still share
/// one flag.
pub struct PathTraversalGuard {
    /// Whether this guard changed the flag and therefore must restore it on drop.
    active: bool,
    /// The flag's value just before this guard set it; only meaningful when `active`.
    previous: bool,
}

impl PathTraversalGuard {
    /// Creates a guard; when `active` is true, path traversal is allowed until it is dropped.
    pub fn new(active: bool) -> Self {
        let previous = if active {
            ALLOW_PATH_TRAVERSAL.swap(true, std::sync::atomic::Ordering::SeqCst)
        } else {
            false
        };
        Self { active, previous }
    }
}

impl Drop for PathTraversalGuard {
    fn drop(&mut self) {
        if self.active {
            // Restore rather than clear, so enclosing guards / global overrides survive.
            ALLOW_PATH_TRAVERSAL.store(self.previous, std::sync::atomic::Ordering::SeqCst);
        }
    }
}
/// Resolves `path` to an absolute, canonical location and enforces the sandbox root.
///
/// Steps:
/// 1. A relative `path` is joined onto the process's current working directory.
/// 2. The result is passed through `crate::agent::security::normalize_path`. Presumably
///    this resolves `.`/`..` lexically (confirm in that module).
/// 3. The normalized path is canonicalized. If that fails (e.g. the target does not exist
///    yet), the nearest existing ancestor is canonicalized instead and the remaining
///    components are re-appended.
///
/// The sandbox root is [`STARTUP_DIR`] when it has been initialized. Otherwise it is the
/// canonicalized current directory, or `"."` if that cannot be determined.
///
/// # Errors
///
/// * Propagates the error from `std::env::current_dir` when `path` is relative and the
///   current directory cannot be read.
/// * Fails with "Path traversal detected" when the resolved path is outside the root. This
///   does not apply if `path` is empty or [`ALLOW_PATH_TRAVERSAL`] is set.
///
/// On success the returned path has been passed through `strip_unc_prefix`.
pub fn validate_path(path: &str) -> Result<std::path::PathBuf> {
    let requested = std::path::PathBuf::from(path);
    let absolute = if requested.is_absolute() {
        requested
    } else {
        std::env::current_dir()?.join(requested)
    };

    let normalized = crate::agent::security::normalize_path(&absolute);
    let canonical = std::fs::canonicalize(&normalized)
        .unwrap_or_else(|_| resolve_via_existing_ancestor(&normalized));

    let root = match STARTUP_DIR.get() {
        Some(dir) => dir.clone(),
        None => {
            let cwd = std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from("."));
            std::fs::canonicalize(cwd).unwrap_or_else(|_| std::path::PathBuf::from("."))
        }
    };

    // Accepted when inside the root, when the caller passed an empty path, or when the
    // process-wide override is on (checked in that order).
    let permitted = canonical.starts_with(&root)
        || path.is_empty()
        || ALLOW_PATH_TRAVERSAL.load(std::sync::atomic::Ordering::SeqCst);
    if !permitted {
        anyhow::bail!("Path traversal detected: access to '{}' is denied", path);
    }
    Ok(strip_unc_prefix(&canonical))
}

/// Canonicalizes the nearest existing proper ancestor of `normalized`, then re-appends the
/// file names of the components below it.
///
/// `normalized` is returned unchanged in two cases: the first existing ancestor cannot be
/// canonicalized, or no ancestor is found to exist before `parent()` runs out.
fn resolve_via_existing_ancestor(normalized: &std::path::Path) -> std::path::PathBuf {
    // File names collected from the deepest component upward.
    let mut missing = Vec::new();
    let mut cursor = normalized;
    while let Some(parent) = cursor.parent() {
        if let Some(name) = cursor.file_name() {
            missing.push(name);
        }
        if parent.exists() {
            return match std::fs::canonicalize(parent) {
                Ok(base) => missing.iter().rev().fold(base, |mut acc, name| {
                    acc.push(name);
                    acc
                }),
                Err(_) => normalized.to_path_buf(),
            };
        }
        cursor = parent;
    }
    normalized.to_path_buf()
}
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Test double that echoes its string `val` argument, or `"default"` when it is absent
    /// or not a string.
    struct MockTool;

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        async fn execute(
            &self,
            args: &HashMap<String, Value>,
            _undo: &mut Vec<UndoAction>,
            _cwd: Option<&Path>,
        ) -> Result<String> {
            let val = match args.get("val") {
                Some(Value::String(text)) => text.as_str(),
                _ => "default",
            };
            Ok(format!("mock: {val}"))
        }
    }

    /// A registered tool is dispatched by name; an unknown name yields an error.
    #[tokio::test]
    async fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(MockTool));

        let args: HashMap<String, Value> = [("val".to_string(), json!("hello"))]
            .into_iter()
            .collect();
        let mut undo = Vec::new();

        let output = registry
            .execute("mock_tool", &args, &mut undo, None)
            .await
            .expect("mock_tool is registered");
        assert_eq!(output, "mock: hello");

        let missing = registry.execute("unknown", &args, &mut undo, None).await;
        assert!(missing.is_err());
    }

    /// A relative file name resolves to an absolute path ending in that name.
    #[test]
    fn test_validate_path() {
        let resolved = validate_path("test.txt").expect("relative path inside the working dir");
        assert!(resolved.is_absolute());
        assert!(resolved.ends_with("test.txt"));
    }
}