1use std::path::{Path, PathBuf};
17
18use async_trait::async_trait;
19
20use crate::agent::capability::Capability;
21use crate::agent::driver::ToolDefinition;
22
23use super::{Tool, ToolResult};
24
/// Largest file size, in bytes, that `file_read` will load; larger files are rejected up front.
const MAX_READ_BYTES: usize = 128 * 1024;

/// Cap on the number of lines returned by one `file_read` call; also the default `limit`.
const MAX_READ_LINES: usize = 2000;
30
/// Resolves `raw` to a canonical path and checks it against the allow-list.
///
/// The path must already exist: it is canonicalized (symlinks and `..` resolved) before the
/// prefix comparison, so a link placed inside an allowed directory cannot be used to reach a
/// file outside of it.
///
/// # Errors
/// Returns a human-readable message when `raw` is empty, cannot be canonicalized, or resolves
/// to a location outside every allowed prefix.
fn validate_path(raw: &str, allowed: &[String]) -> Result<PathBuf, String> {
    if raw.is_empty() {
        return Err("path is empty".into());
    }
    let canonical = Path::new(raw)
        .canonicalize()
        .map_err(|e| format!("cannot resolve path '{}': {}", raw, e))?;
    check_prefix(&canonical, &canonical, allowed)
}

/// Resolves a path that is about to be written and checks it against the allow-list.
///
/// * If the path already exists it is validated exactly like [`validate_path`].
/// * If it does not exist yet, its parent directory must exist; the parent is canonicalized,
///   checked against the allow-list, and the returned path is `canonical_parent/file_name`.
///
/// # Errors
/// Returns a human-readable message when `raw` is empty, is a dangling symlink (unless the
/// allow-list contains `"*"`), has no parent or file name, its parent cannot be canonicalized,
/// or the resolved location is outside every allowed prefix.
fn validate_write_path(raw: &str, allowed: &[String]) -> Result<PathBuf, String> {
    if raw.is_empty() {
        return Err("path is empty".into());
    }

    let path = Path::new(raw);

    // Overwriting something that already exists: resolve the entry itself.
    if path.exists() {
        return validate_path(raw, allowed);
    }

    // `exists()` follows symlinks, so a dangling (or otherwise unresolvable) link also ends up
    // here. Checking only its parent would approve the write, and the write would then follow
    // the link to wherever it points, possibly outside the allow-list. With the "*" wildcard
    // there is nothing to protect, so the previous behavior is kept.
    let wildcard = allowed.iter().any(|p| p == "*");
    let is_symlink = path
        .symlink_metadata()
        .map(|meta| meta.file_type().is_symlink())
        .unwrap_or(false);
    if is_symlink && !wildcard {
        return Err(format!(
            "path '{}' is a dangling symlink; refusing to write through it",
            raw
        ));
    }

    let parent = path.parent().ok_or_else(|| format!("cannot determine parent of '{}'", raw))?;
    // A bare relative name such as "notes.txt" has the empty path as its parent, which cannot
    // be canonicalized; it means the current directory.
    let parent = if parent.as_os_str().is_empty() { Path::new(".") } else { parent };

    let parent_canon = parent
        .canonicalize()
        .map_err(|e| format!("parent directory '{}' not found: {}", parent.display(), e))?;

    // Paths ending in `..` have no file name; joining an empty name would silently produce a
    // directory-like target, so report it instead.
    let file_name =
        path.file_name().ok_or_else(|| format!("path '{}' has no file name", raw))?;

    let target = parent_canon.join(file_name);
    check_prefix(&target, &parent_canon, allowed)
}

/// Returns a copy of `target` when `canonical` lies under one of the `allowed` prefixes.
///
/// `canonical` is the path that is actually compared; `target` is what the caller receives on
/// success (the two differ for not-yet-existing files, where only the parent can be resolved).
/// A literal `"*"` entry allows everything. Prefixes that cannot be canonicalized (for example
/// because they do not exist) are skipped.
///
/// # Errors
/// Returns a message naming `target` and the allow-list when no prefix matches.
fn check_prefix(target: &Path, canonical: &Path, allowed: &[String]) -> Result<PathBuf, String> {
    if allowed.iter().any(|p| p == "*") {
        return Ok(target.to_path_buf());
    }

    // `Path::starts_with` compares whole components, so "/data" does not match "/database".
    let permitted = allowed
        .iter()
        .filter_map(|prefix| Path::new(prefix).canonicalize().ok())
        .any(|prefix_canon| canonical.starts_with(&prefix_canon));

    if permitted {
        Ok(target.to_path_buf())
    } else {
        Err(format!("path '{}' outside allowed prefixes: {:?}", target.display(), allowed))
    }
}
85
/// Tool that reads a file and returns its contents as numbered lines.
pub struct FileReadTool {
    /// Path prefixes this tool may read from, as supplied to [`FileReadTool::new`].
    allowed_paths: Vec<String>,
}

impl FileReadTool {
    /// Builds a reader restricted to the given `allowed_paths`.
    pub fn new(allowed_paths: Vec<String>) -> Self {
        FileReadTool { allowed_paths }
    }
}
101
102#[async_trait]
103impl Tool for FileReadTool {
104 fn name(&self) -> &'static str {
105 "file_read"
106 }
107
108 fn definition(&self) -> ToolDefinition {
109 ToolDefinition {
110 name: "file_read".into(),
111 description: "Read a file's contents. Returns numbered lines.".into(),
112 input_schema: serde_json::json!({
113 "type": "object",
114 "required": ["path"],
115 "properties": {
116 "path": {
117 "type": "string",
118 "description": "Absolute path to the file"
119 },
120 "offset": {
121 "type": "integer",
122 "description": "Line number to start from (1-based, default 1)"
123 },
124 "limit": {
125 "type": "integer",
126 "description": "Maximum lines to read (default 2000)"
127 }
128 }
129 }),
130 }
131 }
132
133 async fn execute(&self, input: serde_json::Value) -> ToolResult {
134 let path_str = match input.get("path").and_then(|v| v.as_str()) {
135 Some(p) => p,
136 None => return ToolResult::error("missing required field 'path'"),
137 };
138
139 let offset = input.get("offset").and_then(|v| v.as_u64()).unwrap_or(1).max(1) as usize;
140 let limit = input
141 .get("limit")
142 .and_then(|v| v.as_u64())
143 .unwrap_or(MAX_READ_LINES as u64)
144 .min(MAX_READ_LINES as u64) as usize;
145
146 let path = match validate_path(path_str, &self.allowed_paths) {
147 Ok(p) => p,
148 Err(e) => return ToolResult::error(e),
149 };
150
151 match std::fs::metadata(&path) {
153 Ok(meta) if meta.len() > MAX_READ_BYTES as u64 => {
154 return ToolResult::error(format!(
155 "file too large ({} bytes, max {}). Use offset/limit to read a portion.",
156 meta.len(),
157 MAX_READ_BYTES
158 ));
159 }
160 Err(e) => return ToolResult::error(format!("cannot stat '{}': {}", path.display(), e)),
161 _ => {}
162 }
163
164 match std::fs::read_to_string(&path) {
165 Ok(content) => {
166 let lines: Vec<&str> = content.lines().collect();
167 let start = (offset - 1).min(lines.len());
168 let end = (start + limit).min(lines.len());
169 let selected = &lines[start..end];
170
171 let mut result = String::with_capacity(selected.len() * 80);
172 for (i, line) in selected.iter().enumerate() {
173 let line_num = start + i + 1;
174 result.push_str(&format!("{line_num}\t{line}\n"));
175 }
176
177 if end < lines.len() {
178 result.push_str(&format!(
179 "\n[{} more lines, use offset={} to continue]",
180 lines.len() - end,
181 end + 1
182 ));
183 }
184
185 ToolResult::success(result)
186 }
187 Err(e) => ToolResult::error(format!("cannot read '{}': {}", path.display(), e)),
188 }
189 }
190
191 fn required_capability(&self) -> Capability {
192 Capability::FileRead { allowed_paths: self.allowed_paths.clone() }
193 }
194}
195
/// Tool that creates a file or overwrites an existing one with caller-supplied content.
pub struct FileWriteTool {
    /// Path prefixes this tool may write to, as supplied to [`FileWriteTool::new`].
    allowed_paths: Vec<String>,
}

impl FileWriteTool {
    /// Builds a writer restricted to the given `allowed_paths`.
    pub fn new(allowed_paths: Vec<String>) -> Self {
        FileWriteTool { allowed_paths }
    }
}
210
211#[async_trait]
212impl Tool for FileWriteTool {
213 fn name(&self) -> &'static str {
214 "file_write"
215 }
216
217 fn definition(&self) -> ToolDefinition {
218 ToolDefinition {
219 name: "file_write".into(),
220 description: "Create or overwrite a file with the given content.".into(),
221 input_schema: serde_json::json!({
222 "type": "object",
223 "required": ["path", "content"],
224 "properties": {
225 "path": {
226 "type": "string",
227 "description": "Absolute path to the file"
228 },
229 "content": {
230 "type": "string",
231 "description": "File content to write"
232 }
233 }
234 }),
235 }
236 }
237
238 async fn execute(&self, input: serde_json::Value) -> ToolResult {
239 let path_str = match input.get("path").and_then(|v| v.as_str()) {
240 Some(p) => p,
241 None => return ToolResult::error("missing required field 'path'"),
242 };
243
244 let content = match input.get("content").and_then(|v| v.as_str()) {
245 Some(c) => c,
246 None => return ToolResult::error("missing required field 'content'"),
247 };
248
249 let path = match validate_write_path(path_str, &self.allowed_paths) {
250 Ok(p) => p,
251 Err(e) => return ToolResult::error(e),
252 };
253
254 if let Some(parent) = path.parent() {
256 if !parent.exists() {
257 if let Err(e) = std::fs::create_dir_all(parent) {
258 return ToolResult::error(format!(
259 "cannot create directory '{}': {}",
260 parent.display(),
261 e
262 ));
263 }
264 }
265 }
266
267 match std::fs::write(&path, content) {
268 Ok(()) => {
269 ToolResult::success(format!("Wrote {} bytes to {}", content.len(), path.display()))
270 }
271 Err(e) => ToolResult::error(format!("cannot write '{}': {}", path.display(), e)),
272 }
273 }
274
275 fn required_capability(&self) -> Capability {
276 Capability::FileWrite { allowed_paths: self.allowed_paths.clone() }
277 }
278}
279
/// Tool that edits a file by replacing one uniquely-occurring string with another.
pub struct FileEditTool {
    /// Path prefixes this tool may modify, as supplied to [`FileEditTool::new`].
    allowed_paths: Vec<String>,
}

impl FileEditTool {
    /// Builds an editor restricted to the given `allowed_paths`.
    pub fn new(allowed_paths: Vec<String>) -> Self {
        FileEditTool { allowed_paths }
    }
}
297
298#[async_trait]
299impl Tool for FileEditTool {
300 fn name(&self) -> &'static str {
301 "file_edit"
302 }
303
304 fn definition(&self) -> ToolDefinition {
305 ToolDefinition {
306 name: "file_edit".into(),
307 description: "Replace a unique string in a file. old_string must appear exactly once."
308 .into(),
309 input_schema: serde_json::json!({
310 "type": "object",
311 "required": ["path", "old_string", "new_string"],
312 "properties": {
313 "path": {
314 "type": "string",
315 "description": "Absolute path to the file"
316 },
317 "old_string": {
318 "type": "string",
319 "description": "Exact string to find (must be unique in the file)"
320 },
321 "new_string": {
322 "type": "string",
323 "description": "Replacement string"
324 }
325 }
326 }),
327 }
328 }
329
330 async fn execute(&self, input: serde_json::Value) -> ToolResult {
331 let path_str = match input.get("path").and_then(|v| v.as_str()) {
332 Some(p) => p,
333 None => return ToolResult::error("missing required field 'path'"),
334 };
335
336 let old_string = match input.get("old_string").and_then(|v| v.as_str()) {
337 Some(s) => s,
338 None => return ToolResult::error("missing required field 'old_string'"),
339 };
340
341 let new_string = match input.get("new_string").and_then(|v| v.as_str()) {
342 Some(s) => s,
343 None => return ToolResult::error("missing required field 'new_string'"),
344 };
345
346 if old_string == new_string {
347 return ToolResult::error("old_string and new_string are identical");
348 }
349
350 let path = match validate_path(path_str, &self.allowed_paths) {
351 Ok(p) => p,
352 Err(e) => return ToolResult::error(e),
353 };
354
355 let content = match std::fs::read_to_string(&path) {
356 Ok(c) => c,
357 Err(e) => return ToolResult::error(format!("cannot read '{}': {}", path.display(), e)),
358 };
359
360 let count = content.matches(old_string).count();
361 match count {
362 0 => ToolResult::error(format!(
363 "old_string not found in {}. Provide more context to match.",
364 path.display()
365 )),
366 1 => {
367 let new_content = content.replacen(old_string, new_string, 1);
368 match std::fs::write(&path, &new_content) {
369 Ok(()) => ToolResult::success(format!(
370 "Edited {}. Replaced 1 occurrence ({} bytes → {} bytes).",
371 path.display(),
372 old_string.len(),
373 new_string.len()
374 )),
375 Err(e) => {
376 ToolResult::error(format!("cannot write '{}': {}", path.display(), e))
377 }
378 }
379 }
380 n => ToolResult::error(format!(
381 "old_string found {} times in {}. Provide more context to make it unique.",
382 n,
383 path.display()
384 )),
385 }
386 }
387
388 fn required_capability(&self) -> Capability {
389 Capability::FileWrite { allowed_paths: self.allowed_paths.clone() }
390 }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::TempDir;

    /// Creates `name` inside `dir` with the given `content` and returns its full path.
    fn temp_file(dir: &Path, name: &str, content: &str) -> PathBuf {
        let path = dir.join(name);
        let mut f = std::fs::File::create(&path).unwrap();
        f.write_all(content.as_bytes()).unwrap();
        path
    }

    // ---- file_read ----

    #[tokio::test]
    async fn test_file_read_basic() {
        let dir = TempDir::new().unwrap();
        let path = temp_file(dir.path(), "test.txt", "line1\nline2\nline3\n");
        let tool = FileReadTool::new(vec!["*".into()]);

        let result = tool.execute(serde_json::json!({"path": path.to_str().unwrap()})).await;
        assert!(!result.is_error, "error: {}", result.content);
        assert!(result.content.contains("1\tline1"));
        assert!(result.content.contains("2\tline2"));
        assert!(result.content.contains("3\tline3"));
    }

    #[tokio::test]
    async fn test_file_read_with_offset_and_limit() {
        let dir = TempDir::new().unwrap();
        let content: String = (1..=100).map(|i| format!("line{i}\n")).collect();
        let path = temp_file(dir.path(), "big.txt", &content);
        let tool = FileReadTool::new(vec!["*".into()]);

        let result = tool
            .execute(serde_json::json!({"path": path.to_str().unwrap(), "offset": 50, "limit": 5}))
            .await;
        assert!(!result.is_error);
        assert!(result.content.contains("50\tline50"));
        assert!(result.content.contains("54\tline54"));
        assert!(!result.content.contains("55\tline55"));
    }

    #[tokio::test]
    async fn test_file_read_nonexistent() {
        let tool = FileReadTool::new(vec!["*".into()]);
        let result = tool.execute(serde_json::json!({"path": "/nonexistent_file_xyz"})).await;
        assert!(result.is_error);
        assert!(result.content.contains("cannot resolve"));
    }

    #[tokio::test]
    async fn test_file_read_missing_path_field() {
        let tool = FileReadTool::new(vec!["*".into()]);
        let result = tool.execute(serde_json::json!({"file": "test.txt"})).await;
        assert!(result.is_error);
        assert!(result.content.contains("missing"));
    }

    #[tokio::test]
    async fn test_file_read_path_restricted() {
        let dir = TempDir::new().unwrap();
        let path = temp_file(dir.path(), "secret.txt", "secret data");
        let tool = FileReadTool::new(vec!["/nonexistent_allowed_prefix".into()]);

        let result = tool.execute(serde_json::json!({"path": path.to_str().unwrap()})).await;
        assert!(result.is_error);
        assert!(result.content.contains("outside allowed"));
    }

    // ---- file_write ----

    #[tokio::test]
    async fn test_file_write_create() {
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("new_file.txt");
        let tool = FileWriteTool::new(vec!["*".into()]);

        let result = tool
            .execute(serde_json::json!({"path": path.to_str().unwrap(), "content": "hello world"}))
            .await;
        assert!(!result.is_error, "error: {}", result.content);
        assert!(result.content.contains("11 bytes"));
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello world");
    }

    #[tokio::test]
    async fn test_file_write_overwrite() {
        let dir = TempDir::new().unwrap();
        let path = temp_file(dir.path(), "existing.txt", "old content");
        let tool = FileWriteTool::new(vec!["*".into()]);

        let result = tool
            .execute(serde_json::json!({"path": path.to_str().unwrap(), "content": "new content"}))
            .await;
        assert!(!result.is_error);
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "new content");
    }

    #[tokio::test]
    async fn test_file_write_path_restricted() {
        // Target a directory that is guaranteed to exist so the rejection comes from the
        // prefix check itself, not from a missing `/tmp` on the host running the tests.
        let dir = TempDir::new().unwrap();
        let path = dir.path().join("evil.txt");
        let tool = FileWriteTool::new(vec!["/nonexistent_allowed_prefix".into()]);

        let result = tool
            .execute(serde_json::json!({"path": path.to_str().unwrap(), "content": "bad"}))
            .await;
        assert!(result.is_error);
        assert!(result.content.contains("outside allowed"));
        assert!(!path.exists(), "a rejected write must not create the file");
    }

    #[tokio::test]
    async fn test_file_write_missing_content() {
        // The missing field is detected before the path is looked at, so no file is touched.
        let tool = FileWriteTool::new(vec!["*".into()]);
        let result = tool.execute(serde_json::json!({"path": "/tmp/test.txt"})).await;
        assert!(result.is_error);
        assert!(result.content.contains("missing"));
    }

    // ---- file_edit ----

    #[tokio::test]
    async fn test_file_edit_unique_match() {
        let dir = TempDir::new().unwrap();
        let path = temp_file(dir.path(), "code.rs", "fn main() {\n println!(\"hello\");\n}\n");
        let tool = FileEditTool::new(vec!["*".into()]);

        let result = tool
            .execute(serde_json::json!({
                "path": path.to_str().unwrap(),
                "old_string": "println!(\"hello\")",
                "new_string": "println!(\"world\")"
            }))
            .await;
        assert!(!result.is_error, "error: {}", result.content);
        assert!(result.content.contains("Replaced 1 occurrence"));

        let content = std::fs::read_to_string(&path).unwrap();
        assert!(content.contains("println!(\"world\")"));
        assert!(!content.contains("println!(\"hello\")"));
    }

    #[tokio::test]
    async fn test_file_edit_no_match() {
        let dir = TempDir::new().unwrap();
        let path = temp_file(dir.path(), "code.rs", "fn main() {}\n");
        let tool = FileEditTool::new(vec!["*".into()]);

        let result = tool
            .execute(serde_json::json!({
                "path": path.to_str().unwrap(),
                "old_string": "nonexistent string",
                "new_string": "replacement"
            }))
            .await;
        assert!(result.is_error);
        assert!(result.content.contains("not found"));
    }

    #[tokio::test]
    async fn test_file_edit_multiple_matches() {
        let dir = TempDir::new().unwrap();
        let path = temp_file(dir.path(), "code.rs", "let x = 1;\nlet y = 1;\n");
        let tool = FileEditTool::new(vec!["*".into()]);

        let result = tool
            .execute(serde_json::json!({
                "path": path.to_str().unwrap(),
                "old_string": "= 1",
                "new_string": "= 2"
            }))
            .await;
        assert!(result.is_error);
        assert!(result.content.contains("2 times"));
    }

    #[tokio::test]
    async fn test_file_edit_identical_strings() {
        let dir = TempDir::new().unwrap();
        let path = temp_file(dir.path(), "code.rs", "hello\n");
        let tool = FileEditTool::new(vec!["*".into()]);

        let result = tool
            .execute(serde_json::json!({
                "path": path.to_str().unwrap(),
                "old_string": "hello",
                "new_string": "hello"
            }))
            .await;
        assert!(result.is_error);
        assert!(result.content.contains("identical"));
    }

    #[tokio::test]
    async fn test_file_edit_path_restricted() {
        let dir = TempDir::new().unwrap();
        let path = temp_file(dir.path(), "code.rs", "hello\n");
        let tool = FileEditTool::new(vec!["/nonexistent_allowed_prefix".into()]);

        let result = tool
            .execute(serde_json::json!({
                "path": path.to_str().unwrap(),
                "old_string": "hello",
                "new_string": "world"
            }))
            .await;
        assert!(result.is_error);
        assert!(result.content.contains("outside allowed"));
    }

    // ---- capabilities, names and schemas ----

    #[test]
    fn test_file_read_capability() {
        let tool = FileReadTool::new(vec!["/home".into()]);
        match tool.required_capability() {
            Capability::FileRead { allowed_paths } => {
                assert_eq!(allowed_paths, vec!["/home".to_string()]);
            }
            other => panic!("expected FileRead, got: {other:?}"),
        }
    }

    #[test]
    fn test_file_write_capability() {
        let tool = FileWriteTool::new(vec!["/tmp".into()]);
        match tool.required_capability() {
            Capability::FileWrite { allowed_paths } => {
                assert_eq!(allowed_paths, vec!["/tmp".to_string()]);
            }
            other => panic!("expected FileWrite, got: {other:?}"),
        }
    }

    #[test]
    fn test_file_edit_capability() {
        // Editing requires the write capability, not a dedicated one.
        let tool = FileEditTool::new(vec!["/project".into()]);
        match tool.required_capability() {
            Capability::FileWrite { allowed_paths } => {
                assert_eq!(allowed_paths, vec!["/project".to_string()]);
            }
            other => panic!("expected FileWrite, got: {other:?}"),
        }
    }

    #[test]
    fn test_tool_names() {
        assert_eq!(FileReadTool::new(vec![]).name(), "file_read");
        assert_eq!(FileWriteTool::new(vec![]).name(), "file_write");
        assert_eq!(FileEditTool::new(vec![]).name(), "file_edit");
    }

    #[test]
    fn test_tool_schemas() {
        let tools: Vec<Box<dyn Tool>> = vec![
            Box::new(FileReadTool::new(vec![])),
            Box::new(FileWriteTool::new(vec![])),
            Box::new(FileEditTool::new(vec![])),
        ];
        // Every tool takes a JSON object and requires at least a `path`.
        for tool in &tools {
            let def = tool.definition();
            assert_eq!(def.input_schema["type"], "object");
            assert!(def.input_schema["required"].as_array().unwrap().iter().any(|v| v == "path"));
        }
    }
}