Skip to main content

batuta/agent/tool/
file.rs

1//! File system tools for agent code editing.
2//!
3//! Provides `file_read`, `file_write`, and `file_edit` tools for the
4//! `apr code` agentic coding assistant. Each tool enforces path
5//! restrictions via `Capability::FileRead` / `Capability::FileWrite`
6//! (Poka-Yoke: mistake-proofing).
7//!
8//! Design follows Claude Code's tool semantics:
9//! - `file_read`: Read file with optional line range
10//! - `file_write`: Create or overwrite a file
11//! - `file_edit`: Replace a unique string in a file
12//!
//! Security: paths are canonicalized (which resolves symlinks) and checked
//! against the allowed prefixes before any file is read or written.
15
16use std::path::{Path, PathBuf};
17
18use async_trait::async_trait;
19
20use crate::agent::capability::Capability;
21use crate::agent::driver::ToolDefinition;
22
23use super::{Tool, ToolResult};
24
/// Largest file, in bytes, that `file_read` is willing to load (128 KiB).
const MAX_READ_BYTES: usize = 128 << 10;

/// Most lines a single `file_read` call will return.
const MAX_READ_LINES: usize = 2_000;
30
31// ─── Path validation (shared) ───────────────────────────────
32
33/// Validate that a path is within allowed prefixes.
34/// Returns the canonicalized path on success.
35fn validate_path(raw: &str, allowed: &[String]) -> Result<PathBuf, String> {
36    if raw.is_empty() {
37        return Err("path is empty".into());
38    }
39    // Canonicalize to resolve symlinks (Poka-Yoke: block symlink traversal)
40    let canonical = PathBuf::from(raw)
41        .canonicalize()
42        .map_err(|e| format!("cannot resolve path '{}': {}", raw, e))?;
43    check_prefix(&canonical, &canonical, allowed)
44}
45
46/// Validate a path for writing. Parent directory must exist.
47/// For new files, we validate the parent directory instead.
48fn validate_write_path(raw: &str, allowed: &[String]) -> Result<PathBuf, String> {
49    if raw.is_empty() {
50        return Err("path is empty".into());
51    }
52
53    let path = PathBuf::from(raw);
54
55    // For existing files, canonicalize normally
56    if path.exists() {
57        return validate_path(raw, allowed);
58    }
59
60    // For new files, validate parent directory exists and is allowed
61    let parent = path.parent().ok_or_else(|| format!("cannot determine parent of '{}'", raw))?;
62
63    let parent_canon = parent
64        .canonicalize()
65        .map_err(|e| format!("parent directory '{}' not found: {}", parent.display(), e))?;
66
67    let target = parent_canon.join(path.file_name().unwrap_or_default());
68    check_prefix(&target, &parent_canon, allowed)
69}
70
/// Accept `target` if `canonical` sits under at least one allowed prefix.
///
/// `canonical` is the already-resolved path that is actually compared, and
/// `target` is what is returned on success. For reads the two are the same
/// path. For a write to a new file, the parent directory is compared and the
/// file inside it is returned.
///
/// A literal `"*"` entry allows everything. Every other prefix is
/// canonicalized before comparison, and prefixes that cannot be resolved
/// (e.g. because they do not exist) are skipped. The comparison uses
/// `Path::starts_with`, which matches whole path components.
///
/// # Errors
/// Returns a message naming `target` and the configured prefixes when no
/// prefix matches.
fn check_prefix(target: &Path, canonical: &Path, allowed: &[String]) -> Result<PathBuf, String> {
    let wildcard = allowed.iter().any(|p| p == "*");
    let permitted = wildcard
        || allowed
            .iter()
            .filter_map(|prefix| Path::new(prefix).canonicalize().ok())
            .any(|root| canonical.starts_with(&root));

    if permitted {
        Ok(target.to_path_buf())
    } else {
        Err(format!("path '{}' outside allowed prefixes: {:?}", target.display(), allowed))
    }
}
85
86// ─── FileReadTool ───────────────────────────────────────────
87
/// Tool that returns a file's contents as numbered lines, optionally limited
/// to a window given by a starting line and a line count.
///
/// Output resembles `cat -n`. Files larger than `MAX_READ_BYTES` are refused,
/// and at most `MAX_READ_LINES` lines come back per call, which keeps
/// responses small enough for the model's context.
pub struct FileReadTool {
    /// Path prefixes (or `"*"`) this tool may read under.
    allowed_paths: Vec<String>,
}

impl FileReadTool {
    /// Build a reader restricted to `allowed_paths`.
    pub fn new(allowed_paths: Vec<String>) -> Self {
        FileReadTool { allowed_paths }
    }
}
101
102#[async_trait]
103impl Tool for FileReadTool {
104    fn name(&self) -> &'static str {
105        "file_read"
106    }
107
108    fn definition(&self) -> ToolDefinition {
109        ToolDefinition {
110            name: "file_read".into(),
111            description: "Read a file's contents. Returns numbered lines.".into(),
112            input_schema: serde_json::json!({
113                "type": "object",
114                "required": ["path"],
115                "properties": {
116                    "path": {
117                        "type": "string",
118                        "description": "Absolute path to the file"
119                    },
120                    "offset": {
121                        "type": "integer",
122                        "description": "Line number to start from (1-based, default 1)"
123                    },
124                    "limit": {
125                        "type": "integer",
126                        "description": "Maximum lines to read (default 2000)"
127                    }
128                }
129            }),
130        }
131    }
132
133    async fn execute(&self, input: serde_json::Value) -> ToolResult {
134        let path_str = match input.get("path").and_then(|v| v.as_str()) {
135            Some(p) => p,
136            None => return ToolResult::error("missing required field 'path'"),
137        };
138
139        let offset = input.get("offset").and_then(|v| v.as_u64()).unwrap_or(1).max(1) as usize;
140        let limit = input
141            .get("limit")
142            .and_then(|v| v.as_u64())
143            .unwrap_or(MAX_READ_LINES as u64)
144            .min(MAX_READ_LINES as u64) as usize;
145
146        let path = match validate_path(path_str, &self.allowed_paths) {
147            Ok(p) => p,
148            Err(e) => return ToolResult::error(e),
149        };
150
151        // Check file size (Jidoka: don't read huge files)
152        match std::fs::metadata(&path) {
153            Ok(meta) if meta.len() > MAX_READ_BYTES as u64 => {
154                return ToolResult::error(format!(
155                    "file too large ({} bytes, max {}). Use offset/limit to read a portion.",
156                    meta.len(),
157                    MAX_READ_BYTES
158                ));
159            }
160            Err(e) => return ToolResult::error(format!("cannot stat '{}': {}", path.display(), e)),
161            _ => {}
162        }
163
164        match std::fs::read_to_string(&path) {
165            Ok(content) => {
166                let lines: Vec<&str> = content.lines().collect();
167                let start = (offset - 1).min(lines.len());
168                let end = (start + limit).min(lines.len());
169                let selected = &lines[start..end];
170
171                let mut result = String::with_capacity(selected.len() * 80);
172                for (i, line) in selected.iter().enumerate() {
173                    let line_num = start + i + 1;
174                    result.push_str(&format!("{line_num}\t{line}\n"));
175                }
176
177                if end < lines.len() {
178                    result.push_str(&format!(
179                        "\n[{} more lines, use offset={} to continue]",
180                        lines.len() - end,
181                        end + 1
182                    ));
183                }
184
185                ToolResult::success(result)
186            }
187            Err(e) => ToolResult::error(format!("cannot read '{}': {}", path.display(), e)),
188        }
189    }
190
191    fn required_capability(&self) -> Capability {
192        Capability::FileRead { allowed_paths: self.allowed_paths.clone() }
193    }
194}
195
196// ─── FileWriteTool ──────────────────────────────────────────
197
/// Write content to a file, creating it or overwriting it.
///
/// The target's parent directory must already exist. `validate_write_path`
/// canonicalizes the parent before the prefix check, so a path whose parent
/// directory is missing is rejected rather than created.
pub struct FileWriteTool {
    /// Path prefixes (or `"*"`) this tool may write under.
    allowed_paths: Vec<String>,
}

impl FileWriteTool {
    /// Build a writer restricted to `allowed_paths`.
    pub fn new(allowed_paths: Vec<String>) -> Self {
        Self { allowed_paths }
    }
}
210
211#[async_trait]
212impl Tool for FileWriteTool {
213    fn name(&self) -> &'static str {
214        "file_write"
215    }
216
217    fn definition(&self) -> ToolDefinition {
218        ToolDefinition {
219            name: "file_write".into(),
220            description: "Create or overwrite a file with the given content.".into(),
221            input_schema: serde_json::json!({
222                "type": "object",
223                "required": ["path", "content"],
224                "properties": {
225                    "path": {
226                        "type": "string",
227                        "description": "Absolute path to the file"
228                    },
229                    "content": {
230                        "type": "string",
231                        "description": "File content to write"
232                    }
233                }
234            }),
235        }
236    }
237
238    async fn execute(&self, input: serde_json::Value) -> ToolResult {
239        let path_str = match input.get("path").and_then(|v| v.as_str()) {
240            Some(p) => p,
241            None => return ToolResult::error("missing required field 'path'"),
242        };
243
244        let content = match input.get("content").and_then(|v| v.as_str()) {
245            Some(c) => c,
246            None => return ToolResult::error("missing required field 'content'"),
247        };
248
249        let path = match validate_write_path(path_str, &self.allowed_paths) {
250            Ok(p) => p,
251            Err(e) => return ToolResult::error(e),
252        };
253
254        // Create parent directories
255        if let Some(parent) = path.parent() {
256            if !parent.exists() {
257                if let Err(e) = std::fs::create_dir_all(parent) {
258                    return ToolResult::error(format!(
259                        "cannot create directory '{}': {}",
260                        parent.display(),
261                        e
262                    ));
263                }
264            }
265        }
266
267        match std::fs::write(&path, content) {
268            Ok(()) => {
269                ToolResult::success(format!("Wrote {} bytes to {}", content.len(), path.display()))
270            }
271            Err(e) => ToolResult::error(format!("cannot write '{}': {}", path.display(), e)),
272        }
273    }
274
275    fn required_capability(&self) -> Capability {
276        Capability::FileWrite { allowed_paths: self.allowed_paths.clone() }
277    }
278}
279
280// ─── FileEditTool ───────────────────────────────────────────
281
/// Tool that edits a file in place by replacing one exact, unique substring.
///
/// Modelled on Claude Code's Edit tool:
/// - `old_string` has to occur exactly once in the file;
/// - that single occurrence is replaced by `new_string`;
/// - zero or multiple occurrences make the edit fail without touching the file.
pub struct FileEditTool {
    /// Path prefixes (or `"*"`) this tool may edit under.
    allowed_paths: Vec<String>,
}

impl FileEditTool {
    /// Build an editor restricted to `allowed_paths`.
    pub fn new(allowed_paths: Vec<String>) -> Self {
        FileEditTool { allowed_paths }
    }
}
297
298#[async_trait]
299impl Tool for FileEditTool {
300    fn name(&self) -> &'static str {
301        "file_edit"
302    }
303
304    fn definition(&self) -> ToolDefinition {
305        ToolDefinition {
306            name: "file_edit".into(),
307            description: "Replace a unique string in a file. old_string must appear exactly once."
308                .into(),
309            input_schema: serde_json::json!({
310                "type": "object",
311                "required": ["path", "old_string", "new_string"],
312                "properties": {
313                    "path": {
314                        "type": "string",
315                        "description": "Absolute path to the file"
316                    },
317                    "old_string": {
318                        "type": "string",
319                        "description": "Exact string to find (must be unique in the file)"
320                    },
321                    "new_string": {
322                        "type": "string",
323                        "description": "Replacement string"
324                    }
325                }
326            }),
327        }
328    }
329
330    async fn execute(&self, input: serde_json::Value) -> ToolResult {
331        let path_str = match input.get("path").and_then(|v| v.as_str()) {
332            Some(p) => p,
333            None => return ToolResult::error("missing required field 'path'"),
334        };
335
336        let old_string = match input.get("old_string").and_then(|v| v.as_str()) {
337            Some(s) => s,
338            None => return ToolResult::error("missing required field 'old_string'"),
339        };
340
341        let new_string = match input.get("new_string").and_then(|v| v.as_str()) {
342            Some(s) => s,
343            None => return ToolResult::error("missing required field 'new_string'"),
344        };
345
346        if old_string == new_string {
347            return ToolResult::error("old_string and new_string are identical");
348        }
349
350        let path = match validate_path(path_str, &self.allowed_paths) {
351            Ok(p) => p,
352            Err(e) => return ToolResult::error(e),
353        };
354
355        let content = match std::fs::read_to_string(&path) {
356            Ok(c) => c,
357            Err(e) => return ToolResult::error(format!("cannot read '{}': {}", path.display(), e)),
358        };
359
360        let count = content.matches(old_string).count();
361        match count {
362            0 => ToolResult::error(format!(
363                "old_string not found in {}. Provide more context to match.",
364                path.display()
365            )),
366            1 => {
367                let new_content = content.replacen(old_string, new_string, 1);
368                match std::fs::write(&path, &new_content) {
369                    Ok(()) => ToolResult::success(format!(
370                        "Edited {}. Replaced 1 occurrence ({} bytes → {} bytes).",
371                        path.display(),
372                        old_string.len(),
373                        new_string.len()
374                    )),
375                    Err(e) => {
376                        ToolResult::error(format!("cannot write '{}': {}", path.display(), e))
377                    }
378                }
379            }
380            n => ToolResult::error(format!(
381                "old_string found {} times in {}. Provide more context to make it unique.",
382                n,
383                path.display()
384            )),
385        }
386    }
387
388    fn required_capability(&self) -> Capability {
389        Capability::FileWrite { allowed_paths: self.allowed_paths.clone() }
390    }
391}
392
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::TempDir;

    /// Create `name` inside `dir` holding `content`, and return its full path.
    fn temp_file(dir: &Path, name: &str, content: &str) -> PathBuf {
        let file_path = dir.join(name);
        std::fs::File::create(&file_path)
            .and_then(|mut handle| handle.write_all(content.as_bytes()))
            .unwrap();
        file_path
    }

    /// Borrow a path as UTF-8 so it can be embedded in a JSON tool request.
    fn utf8(path: &Path) -> &str {
        path.to_str().unwrap()
    }

    /// An allow-list that permits every path.
    fn allow_all() -> Vec<String> {
        vec!["*".to_string()]
    }

    /// An allow-list whose only prefix does not exist, so nothing is permitted.
    fn allow_none() -> Vec<String> {
        vec!["/nonexistent_allowed_prefix".to_string()]
    }

    /// Run `file_edit` on `path`, asking it to replace `old` with `new`.
    async fn edit(tool: &FileEditTool, path: &Path, old: &str, new: &str) -> ToolResult {
        tool.execute(serde_json::json!({
            "path": utf8(path),
            "old_string": old,
            "new_string": new
        }))
        .await
    }

    // ─── FileReadTool tests ─────────────────────────────

    #[tokio::test]
    async fn test_file_read_basic() {
        let tmp = TempDir::new().unwrap();
        let file = temp_file(tmp.path(), "test.txt", "line1\nline2\nline3\n");
        let reader = FileReadTool::new(allow_all());

        let out = reader.execute(serde_json::json!({ "path": utf8(&file) })).await;
        assert!(!out.is_error, "error: {}", out.content);
        for expected in ["1\tline1", "2\tline2", "3\tline3"] {
            assert!(out.content.contains(expected));
        }
    }

    #[tokio::test]
    async fn test_file_read_with_offset_and_limit() {
        let tmp = TempDir::new().unwrap();
        let body: String = (1..=100).map(|n| format!("line{n}\n")).collect();
        let file = temp_file(tmp.path(), "big.txt", &body);
        let reader = FileReadTool::new(allow_all());

        let request = serde_json::json!({ "path": utf8(&file), "offset": 50, "limit": 5 });
        let out = reader.execute(request).await;
        assert!(!out.is_error);
        assert!(out.content.contains("50\tline50"));
        assert!(out.content.contains("54\tline54"));
        assert!(!out.content.contains("55\tline55"));
    }

    #[tokio::test]
    async fn test_file_read_nonexistent() {
        let reader = FileReadTool::new(allow_all());
        let out = reader.execute(serde_json::json!({ "path": "/nonexistent_file_xyz" })).await;
        assert!(out.is_error);
        assert!(out.content.contains("cannot resolve"));
    }

    #[tokio::test]
    async fn test_file_read_missing_path_field() {
        let reader = FileReadTool::new(allow_all());
        let out = reader.execute(serde_json::json!({ "file": "test.txt" })).await;
        assert!(out.is_error);
        assert!(out.content.contains("missing"));
    }

    #[tokio::test]
    async fn test_file_read_path_restricted() {
        let tmp = TempDir::new().unwrap();
        let file = temp_file(tmp.path(), "secret.txt", "secret data");
        let reader = FileReadTool::new(allow_none());

        let out = reader.execute(serde_json::json!({ "path": utf8(&file) })).await;
        assert!(out.is_error);
        assert!(out.content.contains("outside allowed"));
    }

    // ─── FileWriteTool tests ────────────────────────────

    #[tokio::test]
    async fn test_file_write_create() {
        let tmp = TempDir::new().unwrap();
        let file = tmp.path().join("new_file.txt");
        let writer = FileWriteTool::new(allow_all());

        let out = writer
            .execute(serde_json::json!({ "path": utf8(&file), "content": "hello world" }))
            .await;
        assert!(!out.is_error, "error: {}", out.content);
        assert!(out.content.contains("11 bytes"));
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "hello world");
    }

    #[tokio::test]
    async fn test_file_write_overwrite() {
        let tmp = TempDir::new().unwrap();
        let file = temp_file(tmp.path(), "existing.txt", "old content");
        let writer = FileWriteTool::new(allow_all());

        let out = writer
            .execute(serde_json::json!({ "path": utf8(&file), "content": "new content" }))
            .await;
        assert!(!out.is_error);
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "new content");
    }

    #[tokio::test]
    async fn test_file_write_path_restricted() {
        let writer = FileWriteTool::new(allow_none());
        let out = writer
            .execute(serde_json::json!({ "path": "/tmp/evil.txt", "content": "bad" }))
            .await;
        assert!(out.is_error);
        assert!(out.content.contains("outside allowed"));
    }

    #[tokio::test]
    async fn test_file_write_missing_content() {
        let writer = FileWriteTool::new(allow_all());
        let out = writer.execute(serde_json::json!({ "path": "/tmp/test.txt" })).await;
        assert!(out.is_error);
        assert!(out.content.contains("missing"));
    }

    // ─── FileEditTool tests ─────────────────────────────

    #[tokio::test]
    async fn test_file_edit_unique_match() {
        let tmp = TempDir::new().unwrap();
        let file = temp_file(tmp.path(), "code.rs", "fn main() {\n    println!(\"hello\");\n}\n");
        let editor = FileEditTool::new(allow_all());

        let out = edit(&editor, &file, "println!(\"hello\")", "println!(\"world\")").await;
        assert!(!out.is_error, "error: {}", out.content);
        assert!(out.content.contains("Replaced 1 occurrence"));

        let edited = std::fs::read_to_string(&file).unwrap();
        assert!(edited.contains("println!(\"world\")"));
        assert!(!edited.contains("println!(\"hello\")"));
    }

    #[tokio::test]
    async fn test_file_edit_no_match() {
        let tmp = TempDir::new().unwrap();
        let file = temp_file(tmp.path(), "code.rs", "fn main() {}\n");
        let editor = FileEditTool::new(allow_all());

        let out = edit(&editor, &file, "nonexistent string", "replacement").await;
        assert!(out.is_error);
        assert!(out.content.contains("not found"));
    }

    #[tokio::test]
    async fn test_file_edit_multiple_matches() {
        let tmp = TempDir::new().unwrap();
        let file = temp_file(tmp.path(), "code.rs", "let x = 1;\nlet y = 1;\n");
        let editor = FileEditTool::new(allow_all());

        let out = edit(&editor, &file, "= 1", "= 2").await;
        assert!(out.is_error);
        assert!(out.content.contains("2 times"));
    }

    #[tokio::test]
    async fn test_file_edit_identical_strings() {
        let tmp = TempDir::new().unwrap();
        let file = temp_file(tmp.path(), "code.rs", "hello\n");
        let editor = FileEditTool::new(allow_all());

        let out = edit(&editor, &file, "hello", "hello").await;
        assert!(out.is_error);
        assert!(out.content.contains("identical"));
    }

    #[tokio::test]
    async fn test_file_edit_path_restricted() {
        let tmp = TempDir::new().unwrap();
        let file = temp_file(tmp.path(), "code.rs", "hello\n");
        let editor = FileEditTool::new(allow_none());

        let out = edit(&editor, &file, "hello", "world").await;
        assert!(out.is_error);
        assert!(out.content.contains("outside allowed"));
    }

    // ─── Capability tests ───────────────────────────────

    #[test]
    fn test_file_read_capability() {
        let cap = FileReadTool::new(vec!["/home".into()]).required_capability();
        if let Capability::FileRead { allowed_paths } = &cap {
            assert_eq!(allowed_paths, &vec!["/home".to_string()]);
        } else {
            panic!("expected FileRead, got: {cap:?}");
        }
    }

    #[test]
    fn test_file_write_capability() {
        let cap = FileWriteTool::new(vec!["/tmp".into()]).required_capability();
        if let Capability::FileWrite { allowed_paths } = &cap {
            assert_eq!(allowed_paths, &vec!["/tmp".to_string()]);
        } else {
            panic!("expected FileWrite, got: {cap:?}");
        }
    }

    #[test]
    fn test_file_edit_capability() {
        let cap = FileEditTool::new(vec!["/project".into()]).required_capability();
        if let Capability::FileWrite { allowed_paths } = &cap {
            assert_eq!(allowed_paths, &vec!["/project".to_string()]);
        } else {
            panic!("expected FileWrite, got: {cap:?}");
        }
    }

    #[test]
    fn test_tool_names() {
        let reader = FileReadTool::new(vec![]);
        let writer = FileWriteTool::new(vec![]);
        let editor = FileEditTool::new(vec![]);
        assert_eq!(reader.name(), "file_read");
        assert_eq!(writer.name(), "file_write");
        assert_eq!(editor.name(), "file_edit");
    }

    #[test]
    fn test_tool_schemas() {
        let tools: [Box<dyn Tool>; 3] = [
            Box::new(FileReadTool::new(vec![])),
            Box::new(FileWriteTool::new(vec![])),
            Box::new(FileEditTool::new(vec![])),
        ];
        for def in tools.iter().map(|t| t.definition()) {
            assert_eq!(def.input_schema["type"], "object");
            let required = def.input_schema["required"].as_array().unwrap();
            assert!(required.iter().any(|v| v == "path"));
        }
    }
}