1use diffy::{apply, Patch};
7use std::collections::HashMap;
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::process::{Command, Stdio};
11use tokio::io::{AsyncBufReadExt, BufReader};
12use tokio::process::Command as AsyncCommand;
13
/// Outcome of executing a single tool call.
///
/// Built through [`ToolResult::success`] or [`ToolResult::failure`], which keep
/// `success`, `output` and `error` consistent with each other.
#[derive(Debug, Clone)]
pub struct ToolResult {
    /// Name of the tool that produced this result.
    pub tool_name: String,
    /// `true` when the tool completed without error.
    pub success: bool,
    /// Tool output on success; empty when built via `ToolResult::failure`.
    pub output: String,
    /// Human-readable error description on failure; `None` on success.
    pub error: Option<String>,
}
22
23impl ToolResult {
24 pub fn success(tool_name: &str, output: String) -> Self {
25 Self {
26 tool_name: tool_name.to_string(),
27 success: true,
28 output,
29 error: None,
30 }
31 }
32
33 pub fn failure(tool_name: &str, error: String) -> Self {
34 Self {
35 tool_name: tool_name.to_string(),
36 success: false,
37 output: String::new(),
38 error: Some(error),
39 }
40 }
41}
42
/// A request to run one tool.
///
/// All arguments are passed as strings keyed by parameter name; each tool
/// looks up the keys it needs and reports a failure for missing required ones.
#[derive(Debug, Clone)]
pub struct ToolCall {
    /// Tool identifier, e.g. `"read_file"`.
    pub name: String,
    /// Argument values keyed by parameter name.
    pub arguments: HashMap<String, String>,
}
49
50pub struct AgentTools {
52 working_dir: PathBuf,
54 require_approval: bool,
56 event_sender: Option<perspt_core::events::channel::EventSender>,
58}
59
60impl AgentTools {
61 pub fn new(working_dir: PathBuf, require_approval: bool) -> Self {
63 Self {
64 working_dir,
65 require_approval,
66 event_sender: None,
67 }
68 }
69
70 pub fn set_event_sender(&mut self, sender: perspt_core::events::channel::EventSender) {
72 self.event_sender = Some(sender);
73 }
74
75 pub async fn execute(&self, call: &ToolCall) -> ToolResult {
77 match call.name.as_str() {
78 "read_file" => self.read_file(call),
79 "search_code" => self.search_code(call),
80 "apply_patch" => self.apply_patch(call),
81 "run_command" => self.run_command(call).await,
82 "list_files" => self.list_files(call),
83 "write_file" => self.write_file(call),
84 "apply_diff" => self.apply_diff(call),
85 "sed_replace" => self.sed_replace(call),
87 "awk_filter" => self.awk_filter(call),
88 "diff_files" => self.diff_files(call),
89 _ => ToolResult::failure(&call.name, format!("Unknown tool: {}", call.name)),
90 }
91 }
92
93 fn read_file(&self, call: &ToolCall) -> ToolResult {
95 let path = match call.arguments.get("path") {
96 Some(p) => self.resolve_path(p),
97 None => return ToolResult::failure("read_file", "Missing 'path' argument".to_string()),
98 };
99
100 match fs::read_to_string(&path) {
101 Ok(content) => ToolResult::success("read_file", content),
102 Err(e) => ToolResult::failure("read_file", format!("Failed to read {:?}: {}", path, e)),
103 }
104 }
105
106 fn search_code(&self, call: &ToolCall) -> ToolResult {
108 let query = match call.arguments.get("query") {
109 Some(q) => q,
110 None => {
111 return ToolResult::failure("search_code", "Missing 'query' argument".to_string())
112 }
113 };
114
115 let path = call
116 .arguments
117 .get("path")
118 .map(|p| self.resolve_path(p))
119 .unwrap_or_else(|| self.working_dir.clone());
120
121 let output = Command::new("rg")
123 .args(["--json", "-n", query])
124 .current_dir(&path)
125 .output()
126 .or_else(|_| {
127 Command::new("grep")
128 .args(["-rn", query, "."])
129 .current_dir(&path)
130 .output()
131 });
132
133 match output {
134 Ok(out) => {
135 let stdout = String::from_utf8_lossy(&out.stdout).to_string();
136 ToolResult::success("search_code", stdout)
137 }
138 Err(e) => ToolResult::failure("search_code", format!("Search failed: {}", e)),
139 }
140 }
141
142 fn apply_patch(&self, call: &ToolCall) -> ToolResult {
144 let path = match call.arguments.get("path") {
145 Some(p) => self.resolve_path(p),
146 None => {
147 return ToolResult::failure("apply_patch", "Missing 'path' argument".to_string())
148 }
149 };
150
151 let content = match call.arguments.get("content") {
152 Some(c) => c,
153 None => {
154 return ToolResult::failure("apply_patch", "Missing 'content' argument".to_string())
155 }
156 };
157
158 if let Some(parent) = path.parent() {
160 if let Err(e) = fs::create_dir_all(parent) {
161 return ToolResult::failure(
162 "apply_patch",
163 format!("Failed to create directories: {}", e),
164 );
165 }
166 }
167
168 match fs::write(&path, content) {
169 Ok(_) => ToolResult::success("apply_patch", format!("Successfully wrote {:?}", path)),
170 Err(e) => {
171 ToolResult::failure("apply_patch", format!("Failed to write {:?}: {}", path, e))
172 }
173 }
174 }
175
176 fn apply_diff(&self, call: &ToolCall) -> ToolResult {
178 let path = match call.arguments.get("path") {
179 Some(p) => self.resolve_path(p),
180 None => {
181 return ToolResult::failure("apply_diff", "Missing 'path' argument".to_string())
182 }
183 };
184
185 let diff_content = match call.arguments.get("diff") {
186 Some(c) => c,
187 None => {
188 return ToolResult::failure("apply_diff", "Missing 'diff' argument".to_string())
189 }
190 };
191
192 let original = match fs::read_to_string(&path) {
194 Ok(c) => c,
195 Err(e) => {
196 return ToolResult::failure(
199 "apply_diff",
200 format!("Failed to read base file {:?}: {}", path, e),
201 );
202 }
203 };
204
205 let patch = match Patch::from_str(diff_content) {
207 Ok(p) => p,
208 Err(e) => {
209 return ToolResult::failure("apply_diff", format!("Failed to parse diff: {}", e));
210 }
211 };
212
213 match apply(&original, &patch) {
215 Ok(patched) => match fs::write(&path, patched) {
216 Ok(_) => {
217 ToolResult::success("apply_diff", format!("Successfully patched {:?}", path))
218 }
219 Err(e) => ToolResult::failure(
220 "apply_diff",
221 format!("Failed to write patched file: {}", e),
222 ),
223 },
224 Err(e) => ToolResult::failure("apply_diff", format!("Failed to apply patch: {}", e)),
225 }
226 }
227
228 async fn run_command(&self, call: &ToolCall) -> ToolResult {
230 let cmd_str = match call.arguments.get("command") {
231 Some(c) => c,
232 None => {
233 return ToolResult::failure("run_command", "Missing 'command' argument".to_string())
234 }
235 };
236
237 if self.require_approval {
238 log::info!("Command requires approval: {}", cmd_str);
239 }
240
241 let mut child = match AsyncCommand::new("sh")
242 .args(["-c", cmd_str])
243 .current_dir(&self.working_dir)
244 .stdout(Stdio::piped())
245 .stderr(Stdio::piped())
246 .spawn()
247 {
248 Ok(child) => child,
249 Err(e) => return ToolResult::failure("run_command", format!("Failed to spawn: {}", e)),
250 };
251
252 let stdout = child.stdout.take().expect("Failed to open stdout");
253 let stderr = child.stderr.take().expect("Failed to open stderr");
254 let sender = self.event_sender.clone();
255
256 let stdout_handle = tokio::spawn(async move {
257 let mut reader = BufReader::new(stdout).lines();
258 let mut output = String::new();
259 while let Ok(Some(line)) = reader.next_line().await {
260 if let Some(ref s) = sender {
261 let _ = s.send(perspt_core::AgentEvent::Log(line.clone()));
262 }
263 output.push_str(&line);
264 output.push('\n');
265 }
266 output
267 });
268
269 let sender_err = self.event_sender.clone();
270 let stderr_handle = tokio::spawn(async move {
271 let mut reader = BufReader::new(stderr).lines();
272 let mut output = String::new();
273 while let Ok(Some(line)) = reader.next_line().await {
274 if let Some(ref s) = sender_err {
275 let _ = s.send(perspt_core::AgentEvent::Log(format!("ERR: {}", line)));
276 }
277 output.push_str(&line);
278 output.push('\n');
279 }
280 output
281 });
282
283 let status = match child.wait().await {
284 Ok(s) => s,
285 Err(e) => return ToolResult::failure("run_command", format!("Failed to wait: {}", e)),
286 };
287
288 let stdout_str = stdout_handle.await.unwrap_or_default();
289 let stderr_str = stderr_handle.await.unwrap_or_default();
290
291 if status.success() {
292 ToolResult::success("run_command", stdout_str)
293 } else {
294 ToolResult::failure(
295 "run_command",
296 format!("Exit code: {:?}\n{}", status.code(), stderr_str),
297 )
298 }
299 }
300
301 fn list_files(&self, call: &ToolCall) -> ToolResult {
303 let path = call
304 .arguments
305 .get("path")
306 .map(|p| self.resolve_path(p))
307 .unwrap_or_else(|| self.working_dir.clone());
308
309 match fs::read_dir(&path) {
310 Ok(entries) => {
311 let files: Vec<String> = entries
312 .filter_map(|e| e.ok())
313 .map(|e| {
314 let name = e.file_name().to_string_lossy().to_string();
315 if e.file_type().map(|t| t.is_dir()).unwrap_or(false) {
316 format!("{}/", name)
317 } else {
318 name
319 }
320 })
321 .collect();
322 ToolResult::success("list_files", files.join("\n"))
323 }
324 Err(e) => {
325 ToolResult::failure("list_files", format!("Failed to list {:?}: {}", path, e))
326 }
327 }
328 }
329
330 fn write_file(&self, call: &ToolCall) -> ToolResult {
332 self.apply_patch(call)
334 }
335
336 fn resolve_path(&self, path: &str) -> PathBuf {
338 let p = Path::new(path);
339 if p.is_absolute() {
340 p.to_path_buf()
341 } else {
342 self.working_dir.join(p)
343 }
344 }
345
346 fn sed_replace(&self, call: &ToolCall) -> ToolResult {
352 let path = match call.arguments.get("path") {
353 Some(p) => self.resolve_path(p),
354 None => {
355 return ToolResult::failure("sed_replace", "Missing 'path' argument".to_string())
356 }
357 };
358
359 let pattern = match call.arguments.get("pattern") {
360 Some(p) => p,
361 None => {
362 return ToolResult::failure("sed_replace", "Missing 'pattern' argument".to_string())
363 }
364 };
365
366 let replacement = match call.arguments.get("replacement") {
367 Some(r) => r,
368 None => {
369 return ToolResult::failure(
370 "sed_replace",
371 "Missing 'replacement' argument".to_string(),
372 )
373 }
374 };
375
376 match fs::read_to_string(&path) {
378 Ok(content) => {
379 let new_content = content.replace(pattern, replacement);
380 match fs::write(&path, &new_content) {
381 Ok(_) => ToolResult::success(
382 "sed_replace",
383 format!(
384 "Replaced '{}' with '{}' in {:?}",
385 pattern, replacement, path
386 ),
387 ),
388 Err(e) => ToolResult::failure("sed_replace", format!("Failed to write: {}", e)),
389 }
390 }
391 Err(e) => {
392 ToolResult::failure("sed_replace", format!("Failed to read {:?}: {}", path, e))
393 }
394 }
395 }
396
397 fn awk_filter(&self, call: &ToolCall) -> ToolResult {
399 let path = match call.arguments.get("path") {
400 Some(p) => self.resolve_path(p),
401 None => {
402 return ToolResult::failure("awk_filter", "Missing 'path' argument".to_string())
403 }
404 };
405
406 let filter = match call.arguments.get("filter") {
407 Some(f) => f,
408 None => {
409 return ToolResult::failure("awk_filter", "Missing 'filter' argument".to_string())
410 }
411 };
412
413 let output = Command::new("awk").arg(filter).arg(&path).output();
415
416 match output {
417 Ok(out) => {
418 if out.status.success() {
419 ToolResult::success(
420 "awk_filter",
421 String::from_utf8_lossy(&out.stdout).to_string(),
422 )
423 } else {
424 ToolResult::failure(
425 "awk_filter",
426 String::from_utf8_lossy(&out.stderr).to_string(),
427 )
428 }
429 }
430 Err(e) => ToolResult::failure("awk_filter", format!("Failed to run awk: {}", e)),
431 }
432 }
433
434 fn diff_files(&self, call: &ToolCall) -> ToolResult {
436 let file1 = match call.arguments.get("file1") {
437 Some(p) => self.resolve_path(p),
438 None => {
439 return ToolResult::failure("diff_files", "Missing 'file1' argument".to_string())
440 }
441 };
442
443 let file2 = match call.arguments.get("file2") {
444 Some(p) => self.resolve_path(p),
445 None => {
446 return ToolResult::failure("diff_files", "Missing 'file2' argument".to_string())
447 }
448 };
449
450 let output = Command::new("diff")
452 .args([
453 "--unified",
454 &file1.to_string_lossy(),
455 &file2.to_string_lossy(),
456 ])
457 .output();
458
459 match output {
460 Ok(out) => {
461 let stdout = String::from_utf8_lossy(&out.stdout).to_string();
463 if stdout.is_empty() {
464 ToolResult::success("diff_files", "Files are identical".to_string())
465 } else {
466 ToolResult::success("diff_files", stdout)
467 }
468 }
469 Err(e) => ToolResult::failure("diff_files", format!("Failed to run diff: {}", e)),
470 }
471 }
472}
473
474pub fn get_tool_definitions() -> Vec<ToolDefinition> {
476 vec![
477 ToolDefinition {
478 name: "read_file".to_string(),
479 description: "Read the contents of a file".to_string(),
480 parameters: vec![ToolParameter {
481 name: "path".to_string(),
482 description: "Path to the file to read".to_string(),
483 required: true,
484 }],
485 },
486 ToolDefinition {
487 name: "search_code".to_string(),
488 description: "Search for code patterns in the workspace using grep/ripgrep".to_string(),
489 parameters: vec![
490 ToolParameter {
491 name: "query".to_string(),
492 description: "Search pattern (regex supported)".to_string(),
493 required: true,
494 },
495 ToolParameter {
496 name: "path".to_string(),
497 description: "Directory to search in (default: working directory)".to_string(),
498 required: false,
499 },
500 ],
501 },
502 ToolDefinition {
503 name: "apply_patch".to_string(),
504 description: "Write or replace file contents".to_string(),
505 parameters: vec![
506 ToolParameter {
507 name: "path".to_string(),
508 description: "Path to the file to write".to_string(),
509 required: true,
510 },
511 ToolParameter {
512 name: "content".to_string(),
513 description: "New file contents".to_string(),
514 required: true,
515 },
516 ],
517 },
518 ToolDefinition {
519 name: "apply_diff".to_string(),
520 description: "Apply a Unified Diff patch to a file".to_string(),
521 parameters: vec![
522 ToolParameter {
523 name: "path".to_string(),
524 description: "Path to the file to patch".to_string(),
525 required: true,
526 },
527 ToolParameter {
528 name: "diff".to_string(),
529 description: "Unified Diff content".to_string(),
530 required: true,
531 },
532 ],
533 },
534 ToolDefinition {
535 name: "run_command".to_string(),
536 description: "Execute a shell command in the working directory".to_string(),
537 parameters: vec![ToolParameter {
538 name: "command".to_string(),
539 description: "Shell command to execute".to_string(),
540 required: true,
541 }],
542 },
543 ToolDefinition {
544 name: "list_files".to_string(),
545 description: "List files in a directory".to_string(),
546 parameters: vec![ToolParameter {
547 name: "path".to_string(),
548 description: "Directory path (default: working directory)".to_string(),
549 required: false,
550 }],
551 },
552 ToolDefinition {
554 name: "sed_replace".to_string(),
555 description: "Replace text in a file using sed-like pattern matching".to_string(),
556 parameters: vec![
557 ToolParameter {
558 name: "path".to_string(),
559 description: "Path to the file".to_string(),
560 required: true,
561 },
562 ToolParameter {
563 name: "pattern".to_string(),
564 description: "Search pattern".to_string(),
565 required: true,
566 },
567 ToolParameter {
568 name: "replacement".to_string(),
569 description: "Replacement text".to_string(),
570 required: true,
571 },
572 ],
573 },
574 ToolDefinition {
575 name: "awk_filter".to_string(),
576 description: "Filter file content using awk-like field selection".to_string(),
577 parameters: vec![
578 ToolParameter {
579 name: "path".to_string(),
580 description: "Path to the file".to_string(),
581 required: true,
582 },
583 ToolParameter {
584 name: "filter".to_string(),
585 description: "Awk filter expression (e.g., '$1 == \"error\"')".to_string(),
586 required: true,
587 },
588 ],
589 },
590 ToolDefinition {
591 name: "diff_files".to_string(),
592 description: "Show differences between two files".to_string(),
593 parameters: vec![
594 ToolParameter {
595 name: "file1".to_string(),
596 description: "First file path".to_string(),
597 required: true,
598 },
599 ToolParameter {
600 name: "file2".to_string(),
601 description: "Second file path".to_string(),
602 required: true,
603 },
604 ],
605 },
606 ]
607}
608
609#[derive(Debug, Clone)]
611pub struct ToolDefinition {
612 pub name: String,
613 pub description: String,
614 pub parameters: Vec<ToolParameter>,
615}
616
/// One named string argument of a [`ToolDefinition`].
#[derive(Debug, Clone)]
pub struct ToolParameter {
    /// Key under which the value appears in `ToolCall::arguments`.
    pub name: String,
    /// Explanation of the expected value.
    pub description: String,
    /// Whether the model must always supply this argument.
    pub required: bool,
}
624
#[cfg(test)]
mod tests {
    use super::*;
    use std::env::temp_dir;

    /// Returns a fresh, empty directory private to one test.
    ///
    /// Writing fixed file names straight into the shared system temp dir lets
    /// concurrent test runs clobber each other's fixtures and leaves files
    /// behind. Keying the directory by test name and process id avoids both.
    fn scratch_dir(test_name: &str) -> PathBuf {
        let dir = temp_dir().join(format!(
            "agent_tools_test_{}_{}",
            test_name,
            std::process::id()
        ));
        // Clear leftovers from an earlier run that panicked before cleanup.
        let _ = fs::remove_dir_all(&dir);
        fs::create_dir_all(&dir).expect("failed to create scratch dir");
        dir
    }

    #[tokio::test]
    async fn test_read_file() {
        let dir = scratch_dir("read_file");
        let test_file = dir.join("test_read.txt");
        fs::write(&test_file, "Hello, World!").unwrap();

        let tools = AgentTools::new(dir.clone(), false);
        let call = ToolCall {
            name: "read_file".to_string(),
            arguments: [("path".to_string(), test_file.to_string_lossy().to_string())]
                .into_iter()
                .collect(),
        };

        let result = tools.execute(&call).await;
        assert!(result.success);
        assert_eq!(result.output, "Hello, World!");

        let _ = fs::remove_dir_all(&dir);
    }

    #[tokio::test]
    async fn test_list_files() {
        let dir = scratch_dir("list_files");
        fs::write(dir.join("a.txt"), "a").unwrap();
        fs::create_dir(dir.join("sub")).unwrap();

        let tools = AgentTools::new(dir.clone(), false);
        let call = ToolCall {
            name: "list_files".to_string(),
            arguments: HashMap::new(),
        };

        let result = tools.execute(&call).await;
        assert!(result.success);
        assert!(result.output.lines().any(|l| l == "a.txt"));
        assert!(result.output.lines().any(|l| l == "sub/"));

        let _ = fs::remove_dir_all(&dir);
    }

    #[tokio::test]
    async fn test_apply_diff_tool() {
        let dir = scratch_dir("apply_diff");
        let file_path = dir.join("test_diff.txt");
        fs::write(&file_path, "Hello world\nThis is a test\n").unwrap();

        let tools = AgentTools::new(dir.clone(), true);

        let diff = "--- test_diff.txt\n+++ test_diff.txt\n@@ -1,2 +1,2 @@\n-Hello world\n+Hello diffy\n This is a test\n";

        let mut args = HashMap::new();
        args.insert("path".to_string(), "test_diff.txt".to_string());
        args.insert("diff".to_string(), diff.to_string());

        let call = ToolCall {
            name: "apply_diff".to_string(),
            arguments: args,
        };

        let result = tools.apply_diff(&call);
        assert!(
            result.success,
            "Diff application failed: {:?}",
            result.error
        );

        let content = fs::read_to_string(&file_path).unwrap();
        assert_eq!(content, "Hello diffy\nThis is a test\n");

        let _ = fs::remove_dir_all(&dir);
    }

    #[tokio::test]
    async fn test_apply_patch_creates_parent_dirs() {
        let dir = scratch_dir("apply_patch");
        let tools = AgentTools::new(dir.clone(), false);

        let mut args = HashMap::new();
        args.insert("path".to_string(), "nested/deeper/out.txt".to_string());
        args.insert("content".to_string(), "payload".to_string());
        let call = ToolCall {
            name: "apply_patch".to_string(),
            arguments: args,
        };

        let result = tools.execute(&call).await;
        assert!(result.success, "apply_patch failed: {:?}", result.error);
        assert_eq!(
            fs::read_to_string(dir.join("nested/deeper/out.txt")).unwrap(),
            "payload"
        );

        let _ = fs::remove_dir_all(&dir);
    }

    #[tokio::test]
    async fn test_sed_replace_literal() {
        let dir = scratch_dir("sed_replace");
        let file_path = dir.join("sed.txt");
        fs::write(&file_path, "a.b a.b axb\n").unwrap();

        let tools = AgentTools::new(dir.clone(), false);
        let mut args = HashMap::new();
        args.insert("path".to_string(), "sed.txt".to_string());
        args.insert("pattern".to_string(), "a.b".to_string());
        args.insert("replacement".to_string(), "X".to_string());
        let call = ToolCall {
            name: "sed_replace".to_string(),
            arguments: args,
        };

        let result = tools.execute(&call).await;
        assert!(result.success, "sed_replace failed: {:?}", result.error);
        // '.' is matched literally, so "axb" is untouched.
        assert_eq!(fs::read_to_string(&file_path).unwrap(), "X X axb\n");

        let _ = fs::remove_dir_all(&dir);
    }

    #[tokio::test]
    async fn test_unknown_tool() {
        let tools = AgentTools::new(temp_dir(), false);
        let call = ToolCall {
            name: "no_such_tool".to_string(),
            arguments: HashMap::new(),
        };

        let result = tools.execute(&call).await;
        assert!(!result.success);
        assert_eq!(result.tool_name, "no_such_tool");
        assert_eq!(result.error.as_deref(), Some("Unknown tool: no_such_tool"));
    }
}