deepseek_rust_cli/tools/
base.rs1use std::{collections::HashMap, path::Path};
2
3use anyhow::Result;
4use async_trait::async_trait;
5use serde_json::Value;
6
7use crate::agent::types::UndoAction;
8
9#[async_trait]
10pub trait Tool: Send + Sync {
11 fn name(&self) -> &str;
12 async fn execute(
13 &self,
14 args: &HashMap<String, Value>,
15 undo_stack: &mut Vec<UndoAction>,
16 cwd: Option<&Path>,
17 ) -> Result<String>;
18}
19
20pub struct ToolRegistry {
21 tools: HashMap<String, Box<dyn Tool>>,
22}
23
24impl Default for ToolRegistry {
25 fn default() -> Self {
26 Self::new()
27 }
28}
29
30impl ToolRegistry {
31 pub fn new() -> Self {
32 Self {
33 tools: HashMap::new(),
34 }
35 }
36
37 pub fn register(&mut self, tool: Box<dyn Tool>) {
38 self.tools.insert(tool.name().to_string(), tool);
39 }
40
41 pub async fn execute(
42 &self,
43 name: &str,
44 args: &HashMap<String, Value>,
45 undo_stack: &mut Vec<UndoAction>,
46 cwd: Option<&Path>,
47 ) -> Result<String> {
48 if let Some(tool) = self.tools.get(name) {
49 tool.execute(args, undo_stack, cwd).await
50 } else {
51 Err(anyhow::anyhow!("Tool '{}' not found", name))
52 }
53 }
54}
55
/// Process-wide switch: while `true`, `validate_path` accepts paths that
/// resolve outside the sandbox root. Starts out `false`.
///
/// It is `pub`, so it may be set directly; `PathTraversalGuard` toggles it
/// for a scope.
pub static ALLOW_PATH_TRAVERSAL: std::sync::atomic::AtomicBool =
    std::sync::atomic::AtomicBool::new(false);
58
/// Directory the process was started in, recorded once by `init_startup_dir`.
/// `validate_path` uses it as the sandbox root when it has been set.
pub static STARTUP_DIR: std::sync::OnceLock<std::path::PathBuf> =
    std::sync::OnceLock::new();
60
61pub fn init_startup_dir() {
62 if let Ok(dir) = std::env::current_dir() {
63 if let Ok(canonical) = std::fs::canonicalize(&dir) {
64 let _ = STARTUP_DIR.set(canonical);
65 } else {
66 let _ = STARTUP_DIR.set(dir);
67 }
68 }
69}
70
/// Converts a Windows verbatim (`\\?\`) path, such as those returned by
/// `std::fs::canonicalize` on Windows, back into its conventional form.
///
/// * `\\?\C:\dir\file` becomes `C:\dir\file`.
/// * `\\?\UNC\server\share\file` becomes `\\server\share\file`. Stripping just
///   `\\?\` here would produce the meaningless relative path
///   `UNC\server\share\file`.
///
/// Paths without a verbatim prefix are returned unchanged. So are paths that
/// are not valid Unicode; they are not lossily re-encoded. On non-Windows
/// targets this function always returns `path` unchanged.
pub fn strip_unc_prefix(path: &std::path::Path) -> std::path::PathBuf {
    #[cfg(windows)]
    {
        if let Some(stripped) = path.to_str().and_then(strip_verbatim_prefix) {
            return std::path::PathBuf::from(stripped);
        }
    }
    path.to_path_buf()
}

/// String-level core of `strip_unc_prefix`.
///
/// Returns the conventional form of a verbatim path, or `None` when `path`
/// does not start with `\\?\`. It is only called on Windows, but it is compiled
/// everywhere so the conversion can be unit-tested on any platform.
#[cfg_attr(not(windows), allow(dead_code))]
fn strip_verbatim_prefix(path: &str) -> Option<String> {
    if let Some(unc_rest) = path.strip_prefix(r"\\?\UNC\") {
        // Verbatim UNC: `\\?\UNC\server\share` -> `\\server\share`.
        Some(format!(r"\\{}", unc_rest))
    } else {
        path.strip_prefix(r"\\?\").map(str::to_owned)
    }
}
86
87pub struct PathTraversalGuard {
88 active: bool,
89}
90
91impl PathTraversalGuard {
92 pub fn new(active: bool) -> Self {
93 if active {
94 ALLOW_PATH_TRAVERSAL.store(true, std::sync::atomic::Ordering::SeqCst);
95 }
96 Self { active }
97 }
98}
99
100impl Drop for PathTraversalGuard {
101 fn drop(&mut self) {
102 if self.active {
103 ALLOW_PATH_TRAVERSAL.store(false, std::sync::atomic::Ordering::SeqCst);
104 }
105 }
106}
107
108pub fn validate_path(path: &str) -> Result<std::path::PathBuf> {
109 let p = std::path::PathBuf::from(path);
110 let abs = if p.is_absolute() {
111 p
112 } else {
113 let mut a = std::env::current_dir()?;
114 a.push(p);
115 a
116 };
117
118 let normalized = crate::agent::security::normalize_path(&abs);
119
120 let canonical = match std::fs::canonicalize(&normalized) {
121 Ok(c) => c,
122 Err(_) => {
123 let mut ancestor = normalized.as_path();
124 let mut components = Vec::new();
125 let mut resolved = normalized.clone();
126 while let Some(parent) = ancestor.parent() {
127 if let Some(file_name) = ancestor.file_name() {
128 components.push(file_name);
129 }
130 if parent.exists() {
131 if let Ok(can_parent) = std::fs::canonicalize(parent) {
132 let mut result = can_parent;
133 for comp in components.iter().rev() {
134 result.push(comp);
135 }
136 resolved = result;
137 break;
138 }
139 break;
140 }
141 ancestor = parent;
142 }
143 resolved
144 }
145 };
146
147 let root = STARTUP_DIR.get().cloned().unwrap_or_else(|| {
148 std::fs::canonicalize(
149 std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
150 )
151 .unwrap_or_else(|_| std::path::PathBuf::from("."))
152 });
153
154 if !canonical.starts_with(&root)
155 && !path.is_empty()
156 && !ALLOW_PATH_TRAVERSAL.load(std::sync::atomic::Ordering::SeqCst)
157 {
158 anyhow::bail!("Path traversal detected: access to '{}' is denied", path);
159 }
160
161 Ok(strip_unc_prefix(&canonical))
162}
163
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;

    /// Test double that echoes its `val` argument, or `"default"` when absent.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "mock_tool"
        }

        async fn execute(
            &self,
            args: &HashMap<String, Value>,
            _undo: &mut Vec<UndoAction>,
            _cwd: Option<&Path>,
        ) -> Result<String> {
            let val = match args.get("val").and_then(Value::as_str) {
                Some(v) => v,
                None => "default",
            };
            Ok(format!("mock: {}", val))
        }
    }

    #[tokio::test]
    async fn test_tool_registry() {
        let mut registry = ToolRegistry::new();
        registry.register(Box::new(EchoTool));

        let args: HashMap<String, Value> =
            HashMap::from([("val".to_string(), json!("hello"))]);
        let mut undo = Vec::new();

        let output = registry
            .execute("mock_tool", &args, &mut undo, None)
            .await
            .expect("registered tool should execute");
        assert_eq!(output, "mock: hello");

        let missing = registry.execute("unknown", &args, &mut undo, None).await;
        assert!(missing.is_err());
    }

    #[test]
    fn test_validate_path() {
        let resolved =
            validate_path("test.txt").expect("a relative path under the cwd should validate");
        assert!(resolved.is_absolute());
        assert!(resolved.ends_with("test.txt"));
    }
}