1use std::collections::HashMap;
7use std::fs;
8use std::path::{Path, PathBuf};
9use std::process::Command;
10
/// Outcome of executing a single tool call.
///
/// Successful results carry their payload in `output` and have `error == None`;
/// failed results carry a message in `error`.
#[derive(Clone, Debug)]
pub struct ToolResult {
    /// Name of the tool this result belongs to.
    pub tool_name: String,
    /// `true` when the tool completed without error.
    pub success: bool,
    /// Payload produced by the tool (e.g. file contents, command stdout).
    pub output: String,
    /// Human-readable failure description, `None` on success.
    pub error: Option<String>,
}
19
20impl ToolResult {
21 pub fn success(tool_name: &str, output: String) -> Self {
22 Self {
23 tool_name: tool_name.to_string(),
24 success: true,
25 output,
26 error: None,
27 }
28 }
29
30 pub fn failure(tool_name: &str, error: String) -> Self {
31 Self {
32 tool_name: tool_name.to_string(),
33 success: false,
34 output: String::new(),
35 error: Some(error),
36 }
37 }
38}
39
/// A request to run one tool: the tool's name plus its string-valued arguments.
#[derive(Clone, Debug)]
pub struct ToolCall {
    /// Tool identifier, e.g. `"read_file"`.
    pub name: String,
    /// Named arguments for the tool, e.g. `"path" -> "src/main.rs"`.
    pub arguments: HashMap<String, String>,
}
46
/// Executes agent tool calls (file I/O, code search, shell commands) relative to a
/// working directory.
pub struct AgentTools {
    /// Base directory: relative paths are joined onto it and commands run inside it.
    working_dir: PathBuf,
    /// When set, `run_command` logs each command before running it.
    require_approval: bool,
}
54
55impl AgentTools {
56 pub fn new(working_dir: PathBuf, require_approval: bool) -> Self {
58 Self {
59 working_dir,
60 require_approval,
61 }
62 }
63
64 pub async fn execute(&self, call: &ToolCall) -> ToolResult {
66 match call.name.as_str() {
67 "read_file" => self.read_file(call),
68 "search_code" => self.search_code(call),
69 "apply_patch" => self.apply_patch(call),
70 "run_command" => self.run_command(call).await,
71 "list_files" => self.list_files(call),
72 "write_file" => self.write_file(call),
73 _ => ToolResult::failure(&call.name, format!("Unknown tool: {}", call.name)),
74 }
75 }
76
77 fn read_file(&self, call: &ToolCall) -> ToolResult {
79 let path = match call.arguments.get("path") {
80 Some(p) => self.resolve_path(p),
81 None => return ToolResult::failure("read_file", "Missing 'path' argument".to_string()),
82 };
83
84 match fs::read_to_string(&path) {
85 Ok(content) => ToolResult::success("read_file", content),
86 Err(e) => ToolResult::failure("read_file", format!("Failed to read {:?}: {}", path, e)),
87 }
88 }
89
90 fn search_code(&self, call: &ToolCall) -> ToolResult {
92 let query = match call.arguments.get("query") {
93 Some(q) => q,
94 None => {
95 return ToolResult::failure("search_code", "Missing 'query' argument".to_string())
96 }
97 };
98
99 let path = call
100 .arguments
101 .get("path")
102 .map(|p| self.resolve_path(p))
103 .unwrap_or_else(|| self.working_dir.clone());
104
105 let output = Command::new("rg")
107 .args(["--json", "-n", query])
108 .current_dir(&path)
109 .output()
110 .or_else(|_| {
111 Command::new("grep")
112 .args(["-rn", query, "."])
113 .current_dir(&path)
114 .output()
115 });
116
117 match output {
118 Ok(out) => {
119 let stdout = String::from_utf8_lossy(&out.stdout).to_string();
120 ToolResult::success("search_code", stdout)
121 }
122 Err(e) => ToolResult::failure("search_code", format!("Search failed: {}", e)),
123 }
124 }
125
126 fn apply_patch(&self, call: &ToolCall) -> ToolResult {
128 let path = match call.arguments.get("path") {
129 Some(p) => self.resolve_path(p),
130 None => {
131 return ToolResult::failure("apply_patch", "Missing 'path' argument".to_string())
132 }
133 };
134
135 let content = match call.arguments.get("content") {
136 Some(c) => c,
137 None => {
138 return ToolResult::failure("apply_patch", "Missing 'content' argument".to_string())
139 }
140 };
141
142 if let Some(parent) = path.parent() {
144 if let Err(e) = fs::create_dir_all(parent) {
145 return ToolResult::failure(
146 "apply_patch",
147 format!("Failed to create directories: {}", e),
148 );
149 }
150 }
151
152 match fs::write(&path, content) {
153 Ok(_) => ToolResult::success("apply_patch", format!("Successfully wrote {:?}", path)),
154 Err(e) => {
155 ToolResult::failure("apply_patch", format!("Failed to write {:?}: {}", path, e))
156 }
157 }
158 }
159
160 async fn run_command(&self, call: &ToolCall) -> ToolResult {
162 let cmd = match call.arguments.get("command") {
163 Some(c) => c,
164 None => {
165 return ToolResult::failure("run_command", "Missing 'command' argument".to_string())
166 }
167 };
168
169 if self.require_approval {
172 log::info!("Command requires approval: {}", cmd);
173 }
175
176 let output = Command::new("sh")
177 .args(["-c", cmd])
178 .current_dir(&self.working_dir)
179 .output();
180
181 match output {
182 Ok(out) => {
183 let stdout = String::from_utf8_lossy(&out.stdout).to_string();
184 let stderr = String::from_utf8_lossy(&out.stderr).to_string();
185
186 if out.status.success() {
187 ToolResult::success("run_command", stdout)
188 } else {
189 ToolResult::failure(
190 "run_command",
191 format!("Exit code: {:?}\n{}", out.status.code(), stderr),
192 )
193 }
194 }
195 Err(e) => ToolResult::failure("run_command", format!("Failed to execute: {}", e)),
196 }
197 }
198
199 fn list_files(&self, call: &ToolCall) -> ToolResult {
201 let path = call
202 .arguments
203 .get("path")
204 .map(|p| self.resolve_path(p))
205 .unwrap_or_else(|| self.working_dir.clone());
206
207 match fs::read_dir(&path) {
208 Ok(entries) => {
209 let files: Vec<String> = entries
210 .filter_map(|e| e.ok())
211 .map(|e| {
212 let name = e.file_name().to_string_lossy().to_string();
213 if e.file_type().map(|t| t.is_dir()).unwrap_or(false) {
214 format!("{}/", name)
215 } else {
216 name
217 }
218 })
219 .collect();
220 ToolResult::success("list_files", files.join("\n"))
221 }
222 Err(e) => {
223 ToolResult::failure("list_files", format!("Failed to list {:?}: {}", path, e))
224 }
225 }
226 }
227
228 fn write_file(&self, call: &ToolCall) -> ToolResult {
230 self.apply_patch(call)
232 }
233
234 fn resolve_path(&self, path: &str) -> PathBuf {
236 let p = Path::new(path);
237 if p.is_absolute() {
238 p.to_path_buf()
239 } else {
240 self.working_dir.join(p)
241 }
242 }
243}
244
245pub fn get_tool_definitions() -> Vec<ToolDefinition> {
247 vec![
248 ToolDefinition {
249 name: "read_file".to_string(),
250 description: "Read the contents of a file".to_string(),
251 parameters: vec![ToolParameter {
252 name: "path".to_string(),
253 description: "Path to the file to read".to_string(),
254 required: true,
255 }],
256 },
257 ToolDefinition {
258 name: "search_code".to_string(),
259 description: "Search for code patterns in the workspace using grep/ripgrep".to_string(),
260 parameters: vec![
261 ToolParameter {
262 name: "query".to_string(),
263 description: "Search pattern (regex supported)".to_string(),
264 required: true,
265 },
266 ToolParameter {
267 name: "path".to_string(),
268 description: "Directory to search in (default: working directory)".to_string(),
269 required: false,
270 },
271 ],
272 },
273 ToolDefinition {
274 name: "apply_patch".to_string(),
275 description: "Write or replace file contents".to_string(),
276 parameters: vec![
277 ToolParameter {
278 name: "path".to_string(),
279 description: "Path to the file to write".to_string(),
280 required: true,
281 },
282 ToolParameter {
283 name: "content".to_string(),
284 description: "New file contents".to_string(),
285 required: true,
286 },
287 ],
288 },
289 ToolDefinition {
290 name: "run_command".to_string(),
291 description: "Execute a shell command in the working directory".to_string(),
292 parameters: vec![ToolParameter {
293 name: "command".to_string(),
294 description: "Shell command to execute".to_string(),
295 required: true,
296 }],
297 },
298 ToolDefinition {
299 name: "list_files".to_string(),
300 description: "List files in a directory".to_string(),
301 parameters: vec![ToolParameter {
302 name: "path".to_string(),
303 description: "Directory path (default: working directory)".to_string(),
304 required: false,
305 }],
306 },
307 ]
308}
309
310#[derive(Debug, Clone)]
312pub struct ToolDefinition {
313 pub name: String,
314 pub description: String,
315 pub parameters: Vec<ToolParameter>,
316}
317
/// One named argument accepted by a tool.
#[derive(Clone, Debug)]
pub struct ToolParameter {
    /// Argument key as it appears in `ToolCall::arguments`.
    pub name: String,
    /// Human-readable explanation of the argument.
    pub description: String,
    /// Whether the tool fails when this argument is absent.
    pub required: bool,
}
325
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;

    /// Creates a fresh, empty directory under the system temp dir.
    ///
    /// The directory name includes `name` and the process id, so parallel tests and
    /// concurrent runs never share files.
    fn scratch_dir(name: &str) -> PathBuf {
        let dir = temp_dir().join(format!("agent_tools_{}_{}", name, std::process::id()));
        let _ = fs::remove_dir_all(&dir);
        fs::create_dir_all(&dir).unwrap();
        dir
    }

    #[tokio::test]
    async fn test_read_file() {
        let dir = scratch_dir("read_file");
        let test_file = dir.join("test_read.txt");
        fs::write(&test_file, "Hello, World!").unwrap();

        let tools = AgentTools::new(dir.clone(), false);
        let call = ToolCall {
            name: "read_file".to_string(),
            arguments: [("path".to_string(), test_file.to_string_lossy().to_string())]
                .into_iter()
                .collect(),
        };

        let result = tools.execute(&call).await;
        // Clean up before asserting so a failure does not leave files behind.
        let _ = fs::remove_dir_all(&dir);
        assert!(result.success);
        assert_eq!(result.output, "Hello, World!");
    }

    #[tokio::test]
    async fn test_list_files() {
        let dir = scratch_dir("list_files");
        fs::write(dir.join("file.txt"), "").unwrap();
        fs::create_dir(dir.join("sub")).unwrap();

        let tools = AgentTools::new(dir.clone(), false);
        let call = ToolCall {
            name: "list_files".to_string(),
            arguments: HashMap::new(),
        };

        let result = tools.execute(&call).await;
        let _ = fs::remove_dir_all(&dir);
        assert!(result.success);
        let listed: Vec<&str> = result.output.lines().collect();
        assert!(listed.contains(&"file.txt"));
        // Directories are reported with a trailing slash.
        assert!(listed.contains(&"sub/"));
    }
}