// deepseek_rust_cli/tools/base.rs

1use std::{collections::HashMap, path::Path};
2
3use anyhow::Result;
4use async_trait::async_trait;
5use serde_json::Value;
6
7use crate::agent::types::UndoAction;
8
9#[async_trait]
10pub trait Tool: Send + Sync {
11    fn name(&self) -> &str;
12    async fn execute(
13        &self,
14        args: &HashMap<String, Value>,
15        undo_stack: &mut Vec<UndoAction>,
16        cwd: Option<&Path>,
17    ) -> Result<String>;
18}
19
20pub struct ToolRegistry {
21    tools: HashMap<String, Box<dyn Tool>>,
22}
23
24impl Default for ToolRegistry {
25    fn default() -> Self {
26        Self::new()
27    }
28}
29
30impl ToolRegistry {
31    pub fn new() -> Self {
32        Self {
33            tools: HashMap::new(),
34        }
35    }
36
37    pub fn register(&mut self, tool: Box<dyn Tool>) {
38        self.tools.insert(tool.name().to_string(), tool);
39    }
40
41    pub async fn execute(
42        &self,
43        name: &str,
44        args: &HashMap<String, Value>,
45        undo_stack: &mut Vec<UndoAction>,
46        cwd: Option<&Path>,
47    ) -> Result<String> {
48        if let Some(tool) = self.tools.get(name) {
49            tool.execute(args, undo_stack, cwd).await
50        } else {
51            Err(anyhow::anyhow!("Tool '{}' not found", name))
52        }
53    }
54}
55
/// Process-wide switch that disables the sandbox-root check in [`validate_path`].
///
/// Starts out `false`; [`PathTraversalGuard`] raises it temporarily.
pub static ALLOW_PATH_TRAVERSAL: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);

/// Sandbox root consulted by [`validate_path`]; recorded by [`init_startup_dir`].
pub static STARTUP_DIR: std::sync::OnceLock<std::path::PathBuf> = std::sync::OnceLock::new();

/// Records the current working directory as [`STARTUP_DIR`].
///
/// The directory is canonicalized when possible; if canonicalization fails, the
/// raw directory is stored instead. Nothing is stored if the current directory
/// cannot be read. Once a value is stored, later calls have no effect because a
/// `OnceLock` can only be set once.
pub fn init_startup_dir() {
    let Ok(cwd) = std::env::current_dir() else {
        return;
    };
    let root = std::fs::canonicalize(&cwd).unwrap_or(cwd);
    // `set` only fails when a value is already present: the first caller wins.
    let _ = STARTUP_DIR.set(root);
}
70
/// Converts a Windows verbatim path (the form returned by [`std::fs::canonicalize`])
/// into its ordinary form, for display and for tools that do not understand `\\?\`.
///
/// * `\\?\C:\dir` becomes `C:\dir`.
/// * `\\?\UNC\server\share\dir` becomes `\\server\share\dir`.
/// * Any other path is returned unchanged. This includes other verbatim forms such
///   as `\\?\Volume{…}\`, whose prefix cannot simply be dropped, and paths that are
///   not valid Unicode, which would otherwise be corrupted by lossy conversion.
///
/// On non-Windows targets the path is always returned unchanged.
pub fn strip_unc_prefix(path: &std::path::Path) -> std::path::PathBuf {
    #[cfg(windows)]
    {
        if let Some(plain) = path.to_str().and_then(strip_verbatim_prefix) {
            return std::path::PathBuf::from(plain);
        }
    }
    path.to_path_buf()
}

/// String-level core of [`strip_unc_prefix`].
///
/// Returns `None` when `path` has no verbatim prefix that can be safely removed.
// Only called on Windows, but kept platform-independent so it can be unit-tested anywhere.
#[cfg_attr(not(windows), allow(dead_code))]
fn strip_verbatim_prefix(path: &str) -> Option<String> {
    let rest = path.strip_prefix(r"\\?\")?;
    if let Some(unc) = rest.strip_prefix(r"UNC\") {
        // `\\?\UNC\server\share` is the verbatim spelling of `\\server\share`.
        return Some(format!(r"\\{unc}"));
    }
    let mut chars = rest.chars();
    match (chars.next(), chars.next()) {
        (Some(drive), Some(':')) if drive.is_ascii_alphabetic() => Some(rest.to_owned()),
        _ => None,
    }
}
86
87pub struct PathTraversalGuard {
88    active: bool,
89}
90
91impl PathTraversalGuard {
92    pub fn new(active: bool) -> Self {
93        if active {
94            ALLOW_PATH_TRAVERSAL.store(true, std::sync::atomic::Ordering::SeqCst);
95        }
96        Self { active }
97    }
98}
99
100impl Drop for PathTraversalGuard {
101    fn drop(&mut self) {
102        if self.active {
103            ALLOW_PATH_TRAVERSAL.store(false, std::sync::atomic::Ordering::SeqCst);
104        }
105    }
106}
107
108pub fn validate_path(path: &str) -> Result<std::path::PathBuf> {
109    let p = std::path::PathBuf::from(path);
110    let abs = if p.is_absolute() {
111        p
112    } else {
113        let mut a = std::env::current_dir()?;
114        a.push(p);
115        a
116    };
117
118    let normalized = crate::agent::security::normalize_path(&abs);
119
120    let canonical = match std::fs::canonicalize(&normalized) {
121        Ok(c) => c,
122        Err(_) => {
123            let mut ancestor = normalized.as_path();
124            let mut components = Vec::new();
125            let mut resolved = normalized.clone();
126            while let Some(parent) = ancestor.parent() {
127                if let Some(file_name) = ancestor.file_name() {
128                    components.push(file_name);
129                }
130                if parent.exists() {
131                    if let Ok(can_parent) = std::fs::canonicalize(parent) {
132                        let mut result = can_parent;
133                        for comp in components.iter().rev() {
134                            result.push(comp);
135                        }
136                        resolved = result;
137                        break;
138                    }
139                    break;
140                }
141                ancestor = parent;
142            }
143            resolved
144        }
145    };
146
147    let root = STARTUP_DIR.get().cloned().unwrap_or_else(|| {
148        std::fs::canonicalize(
149            std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
150        )
151        .unwrap_or_else(|_| std::path::PathBuf::from("."))
152    });
153
154    if !canonical.starts_with(&root)
155        && !path.is_empty()
156        && !ALLOW_PATH_TRAVERSAL.load(std::sync::atomic::Ordering::SeqCst)
157    {
158        anyhow::bail!("Path traversal detected: access to '{}' is denied", path);
159    }
160
161    Ok(strip_unc_prefix(&canonical))
162}
163
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Minimal [`Tool`] that echoes its `val` argument, or `"default"` when absent.
    struct MockTool;

    #[async_trait]
    impl Tool for MockTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        async fn execute(
            &self,
            args: &HashMap<String, Value>,
            _undo: &mut Vec<UndoAction>,
            _cwd: Option<&Path>,
        ) -> Result<String> {
            let val = match args.get("val").and_then(Value::as_str) {
                Some(text) => text,
                None => "default",
            };
            Ok(format!("mock: {}", val))
        }
    }

    #[tokio::test]
    async fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(MockTool));

        let args: HashMap<String, Value> =
            HashMap::from([("val".to_string(), json!("hello"))]);
        let mut undo = Vec::new();

        let output = registry
            .execute("mock_tool", &args, &mut undo, None)
            .await
            .expect("mock_tool is registered");
        assert_eq!(output, "mock: hello");

        let missing = registry.execute("unknown", &args, &mut undo, None).await;
        assert!(missing.is_err());
    }

    #[test]
    fn test_validate_path() {
        let resolved = validate_path("test.txt").unwrap();
        assert!(resolved.is_absolute());
        assert!(resolved.ends_with("test.txt"));
    }
}