Skip to main content

atomcode_core/tool/
write.rs

1use anyhow::Result;
2use async_trait::async_trait;
3use serde::Deserialize;
4use serde_json::json;
5
6use super::{ApprovalRequirement, Tool, ToolContext, ToolDef, ToolResult};
7
/// Tool that writes a complete file body to disk, creating the file (and any
/// missing parent directories) or overwriting an existing one.
pub struct WriteFileTool;
9
/// Arguments of a `write_file` call, deserialized from the tool-call JSON.
#[derive(Deserialize)]
struct WriteFileArgs {
    /// Target path exactly as supplied by the caller; it is passed together
    /// with the working directory to `inspect_path_access`, and the path that
    /// call returns is the one actually written.
    file_path: String,
    /// Full replacement body, written verbatim to the file.
    content: String,
}
15
16#[async_trait]
17impl Tool for WriteFileTool {
18    fn definition(&self) -> ToolDef {
19        ToolDef {
20            name: "write_file",
21            description:
22                "Write content to a file. Creates new files or overwrites existing ones.\n\
23                Use this for: creating new files, or rewriting an entire file from scratch.\n\
24                For small edits to existing files, prefer edit_file instead.\n\
25                Parent directories are auto-created if they don't exist."
26                    .to_string(),
27            parameters: json!({
28                "type": "object",
29                "properties": {
30                    "file_path": { "type": "string", "description": "Absolute path to the file" },
31                    "content": { "type": "string", "description": "The full content to write" }
32                },
33                "required": ["file_path", "content"]
34            }),
35        }
36    }
37
38    fn validate_args(&self, args: &str) -> std::result::Result<(), String> {
39        // Surface a model-friendly diagnostic (provided/missing keys + a
40        // one-line example) instead of the raw serde "line 1 column N"
41        // error which weak models read as a parser-position complaint and
42        // try to "fix" by switching to positional arguments. See
43        // `diagnose_args` doc for the failure mode this replaces.
44        super::diagnose_args(
45            "write_file",
46            args,
47            &[&["file_path", "content"]],
48            "write_file({\"file_path\": \"<absolute path>\", \"content\": \"<file body>\"})",
49        )?;
50        // Strict struct parse only AFTER the keys are known to be present
51        // — catches type mismatches (e.g. content sent as an array).
52        serde_json::from_str::<WriteFileArgs>(args)
53            .map(|_| ())
54            .map_err(|e| {
55                format!(
56                    "write_file: {e}. Re-issue with file_path as a string and content as a string."
57                )
58            })
59    }
60
61    fn approval(&self, args: &str) -> ApprovalRequirement {
62        let parsed = match serde_json::from_str::<WriteFileArgs>(args) {
63            Ok(p) => p,
64            Err(_) => {
65                // Fail-closed: if we can't parse args, require approval rather than auto-approving.
66                return ApprovalRequirement::RequireApproval(
67                    "Could not parse create_file arguments for safety check.".to_string(),
68                );
69            }
70        };
71        if super::is_sensitive_input_path(&parsed.file_path) {
72            return ApprovalRequirement::RequireApproval(
73                format!("Writing to sensitive system path: {}", parsed.file_path),
74            );
75        }
76        // Overwriting existing files is blocked in execute() — no need to
77        // RequireApproval here. Only new file creation is auto-approved.
78        ApprovalRequirement::AutoApprove
79    }
80
81    fn approval_with_context(&self, args: &str, ctx: &ToolContext) -> ApprovalRequirement {
82        let base = self.approval(args);
83        let parsed = match serde_json::from_str::<WriteFileArgs>(args) {
84            Ok(parsed) => parsed,
85            Err(_) => return base,
86        };
87        let working_dir = match ctx.working_dir.try_read() {
88            Ok(wd) => wd.clone(),
89            Err(_) => return base,
90        };
91        match super::approval_for_path(
92            &parsed.file_path,
93            &working_dir,
94            super::ExternalPathAction::Write,
95        ) {
96            Ok(ApprovalRequirement::RequireApprovalAlways(reason)) => {
97                ApprovalRequirement::RequireApprovalAlways(reason)
98            }
99            Ok(ApprovalRequirement::RequireApproval(reason)) => {
100                ApprovalRequirement::RequireApproval(reason)
101            }
102            Ok(ApprovalRequirement::AutoApprove) => match base {
103                ApprovalRequirement::RequireApproval(reason) => {
104                    ApprovalRequirement::RequireApprovalAlways(reason)
105                }
106                other => other,
107            },
108            Err(_) => base,
109        }
110    }
111
112    async fn execute(&self, args: &str, ctx: &ToolContext) -> Result<ToolResult> {
113        // Defense-in-depth: validate_args runs at the runner gate, but if
114        // it's bypassed (or args mutated between gate and execute), we fall
115        // back to the same diagnose_args path so the model sees a uniform
116        // recovery hint instead of a raw serde error.
117        if let Err(msg) = super::diagnose_args(
118            "write_file",
119            args,
120            &[&["file_path", "content"]],
121            "write_file({\"file_path\": \"<absolute path>\", \"content\": \"<file body>\"})",
122        ) {
123            return Ok(ToolResult {
124                call_id: String::new(),
125                output: msg,
126                success: false,
127            });
128        }
129        let parsed: WriteFileArgs = match serde_json::from_str(args) {
130            Ok(p) => p,
131            Err(e) => {
132                return Ok(ToolResult {
133                    call_id: String::new(),
134                    output: format!(
135                        "write_file: {e}. Re-issue with file_path as a string and content as a string."
136                    ),
137                    success: false,
138                });
139            }
140        };
141        let working_dir = ctx.working_dir.read().await.clone();
142        let path = match super::inspect_path_access(&parsed.file_path, &working_dir) {
143            Ok(access) => access.path,
144            Err(err) => {
145                return Ok(ToolResult {
146                    call_id: String::new(),
147                    output: err.to_string(),
148                    success: false,
149                });
150            }
151        };
152
153        // Backup before write (git checkpoint + file-level backup)
154        ctx.file_history
155            .lock()
156            .await
157            .backup_before_write(&path.to_string_lossy())
158            .await;
159
160        // Check if overwriting existing file — build appropriate output message
161        let overwrite_info = if path.exists() {
162            let old_lines = std::fs::read_to_string(&path)
163                .map(|c| c.lines().count())
164                .unwrap_or(0);
165            Some(old_lines)
166        } else {
167            None
168        };
169
170        if let Some(parent) = path.parent() {
171            tokio::fs::create_dir_all(parent).await?;
172        }
173
174        let new_lines = parsed.content.lines().count();
175        let bytes = parsed.content.len();
176        tokio::fs::write(&path, &parsed.content).await?;
177
178        // D3: drop any FileStore entry for this path. The next peek_file
179        // against the old store_id will report "stale" and route the
180        // model toward a fresh read_file. Without this invalidation a
181        // peek_file could hand the model pre-write content that no
182        // longer matches what just landed on disk.
183        ctx.file_store.write().await.invalidate(&path);
184        // Defense-in-depth: read_cache mtime gate is normally sufficient
185        // because tokio::fs::write bumps mtime, but on FS with coarse
186        // mtime granularity (ext4 1-second precision, NFS) a write within
187        // the same tick as the prior read keeps the same mtime and the
188        // gate stops protecting us. Explicit purge eliminates that
189        // corner case for any path we just wrote.
190        ctx.read_cache
191            .write()
192            .await
193            .retain(|(p, _, _), _| p != &path);
194
195        // Notify LSP that file changed (if LSP is enabled).
196        ctx.notify_lsp_file_changed(&path, &parsed.content).await;
197
198        let output = if let Some(old_lines) = overwrite_info {
199            let diff = new_lines as i64 - old_lines as i64;
200            let sign = if diff >= 0 { "+" } else { "" };
201            let mut msg = format!(
202                "Overwrote {} (was {} lines, now {} lines, {}{})",
203                path.display(),
204                old_lines,
205                new_lines,
206                sign,
207                diff
208            );
209            // Warn if significant content reduction (might have lost code)
210            if old_lines > 20 && new_lines < old_lines / 2 {
211                msg.push_str(&format!(
212                    "\n⚠ WARNING: File shrank by {}%. Verify no important code was lost. Use /undo to revert if needed.",
213                    100 - (new_lines * 100 / old_lines)
214                ));
215            }
216            msg
217        } else {
218            format!(
219                "Created new file {} ({} bytes, {} lines)",
220                path.display(),
221                bytes,
222                new_lines
223            )
224        };
225
226        Ok(ToolResult {
227            call_id: String::new(),
228            output,
229            success: true,
230        })
231    }
232}