// rab/builtin/write.rs

1use crate::agent::extension::{AgentTool, Cancel, Extension, ToolOutput};
2use crate::agent::extension::{ToolRenderContext, ToolRenderer};
3use crate::tui::Theme;
4use anyhow::Context;
5use async_trait::async_trait;
6use std::borrow::Cow;
7use tokio::sync::mpsc::UnboundedSender;
8
/// Extension that exposes the built-in `write` tool.
///
/// It only carries the directory that is copied into each tool it creates.
#[derive(Debug, Clone)]
pub struct WriteExtension {
    /// Directory handed to the tools created by this extension.
    cwd: std::path::PathBuf,
}

impl WriteExtension {
    /// Creates an extension that stores `cwd` for the tools it later builds.
    pub fn new(cwd: std::path::PathBuf) -> Self {
        Self { cwd }
    }
}
18
19impl Extension for WriteExtension {
20    fn name(&self) -> Cow<'static, str> {
21        "write".into()
22    }
23
24    fn tools(&self) -> Vec<Box<dyn AgentTool>> {
25        vec![Box::new(WriteTool {
26            cwd: self.cwd.clone(),
27        })]
28    }
29}
30
31struct WriteTool {
32    cwd: std::path::PathBuf,
33}
34
35#[async_trait]
36impl AgentTool for WriteTool {
37    fn name(&self) -> &str {
38        "write"
39    }
40
41    fn description(&self) -> &str {
42        "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. \
43         Automatically creates parent directories."
44    }
45
46    fn parameters(&self) -> serde_json::Value {
47        serde_json::json!({
48            "type": "object",
49            "required": ["path", "content"],
50            "properties": {
51                "path": {
52                    "type": "string",
53                    "description": "Path to the file to write (relative or absolute)"
54                },
55                "content": {
56                    "type": "string",
57                    "description": "Content to write to the file"
58                }
59            }
60        })
61    }
62
63    fn prompt_guidelines(&self) -> Vec<String> {
64        vec!["Use write only for new files or complete rewrites.".into()]
65    }
66
67    fn label(&self) -> &str {
68        "Create or overwrite files"
69    }
70
71    fn renderer(&self) -> Option<Box<dyn ToolRenderer>> {
72        Some(Box::new(WriteRenderer::new()))
73    }
74
75    async fn execute(
76        &self,
77        tool_call_id: String,
78        args: serde_json::Value,
79        cancel: Cancel,
80        _on_update: Option<UnboundedSender<ToolOutput>>,
81    ) -> anyhow::Result<ToolOutput> {
82        let _ = tool_call_id;
83        let path = args["path"]
84            .as_str()
85            .ok_or_else(|| anyhow::anyhow!("Missing 'path' argument"))?;
86        let content = args["content"]
87            .as_str()
88            .ok_or_else(|| anyhow::anyhow!("Missing 'content' argument"))?;
89
90        cancel.check()?;
91
92        let cwd = self.cwd.clone();
93        let path_for_queue = path.to_owned();
94        let cwd_for_closure = cwd.clone();
95        let path_for_closure = path.to_owned();
96        let content_owned = content.to_owned();
97
98        let result = crate::builtin::file_mutation_queue::with_file_mutation_queue(
99            &path_for_queue,
100            &cwd,
101            || async move {
102                let abs_path = {
103                    let p = std::path::Path::new(&path_for_closure);
104                    if p.is_absolute() {
105                        p.to_path_buf()
106                    } else {
107                        cwd_for_closure.join(p)
108                    }
109                };
110
111                // Create parent directories
112                if let Some(parent) = abs_path.parent() {
113                    std::fs::create_dir_all(parent).with_context(|| {
114                        format!("Failed to create directory {}", parent.display())
115                    })?;
116                }
117
118                // Write to temp file, then atomic rename
119                let tmp_path = abs_path.with_extension(format!("tmp{}", uuid::Uuid::new_v4()));
120                std::fs::write(&tmp_path, &content_owned)
121                    .with_context(|| format!("Failed to write {}", tmp_path.display()))?;
122                std::fs::rename(&tmp_path, &abs_path).with_context(|| {
123                    format!(
124                        "Failed to rename {} → {}",
125                        tmp_path.display(),
126                        abs_path.display()
127                    )
128                })?;
129
130                Ok::<_, anyhow::Error>(format!(
131                    "Successfully wrote {} bytes to {}",
132                    content_owned.len(),
133                    path_for_closure
134                ))
135            },
136        )
137        .await?;
138
139        Ok(ToolOutput::ok(result))
140    }
141}
142
/// Tool renderer for the `write` tool.
///
/// The call view shows the file path and a content preview; the result view is empty on
/// success. The most recently computed preview is memoized so that re-rendering the same
/// call does not highlight the same text again.
struct WriteRenderer {
    /// Single-entry preview memo, behind an `RwLock` because rendering only has `&self`.
    cache: std::sync::RwLock<WriteCache>,
}

/// The last preview computed by `WriteRenderer::get_highlighted_lines`.
struct WriteCache {
    /// Cache key: (hash of path and content, expanded, preview line count).
    /// `None` until the first preview has been stored.
    key: Option<(u64, bool, usize)>,
    /// Cached preview lines (without the leading `\n` that `render_call` adds).
    lines: Vec<String>,
    /// Number of content lines that were not part of the preview.
    remaining: usize,
}

impl WriteRenderer {
    /// Number of content lines shown while the call is collapsed.
    const COLLAPSED_PREVIEW_LINES: usize = 5;

    /// Creates a renderer with an empty cache.
    fn new() -> Self {
        Self {
            cache: std::sync::RwLock::new(WriteCache {
                key: None,
                lines: Vec::new(),
                remaining: 0,
            }),
        }
    }

    /// Hashes everything that determines the highlighted text.
    ///
    /// `path` is included because it selects the highlighting language: identical content
    /// shown under a different path must not reuse the cached lines.
    fn cache_hash(path: &str, content: &str) -> u64 {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        path.hash(&mut hasher);
        content.hash(&mut hasher);
        hasher.finish()
    }

    /// Returns the preview lines for `content` and the number of lines left out.
    ///
    /// At most `COLLAPSED_PREVIEW_LINES` lines are returned unless `expanded` is set, in
    /// which case all lines are returned. With the `syntect` feature enabled and a language
    /// known for `path`, the lines are syntax-highlighted; otherwise they are returned
    /// verbatim. The result is served from the cache when path, content, `expanded` and the
    /// preview size all match the previous call.
    fn get_highlighted_lines(
        &self,
        content: &str,
        path: &str,
        expanded: bool,
    ) -> (Vec<String>, usize) {
        let max_preview = if expanded {
            usize::MAX
        } else {
            Self::COLLAPSED_PREVIEW_LINES
        };
        let total_lines = content.lines().count();
        let preview_count = total_lines.min(max_preview);
        let remaining = total_lines - preview_count;
        let key = (Self::cache_hash(path, content), expanded, preview_count);

        // Fast path under the read lock. A poisoned lock is tolerated: the cache is only
        // ever replaced as a whole, so its content is consistent even after a panic.
        {
            let cache = self
                .cache
                .read()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            if cache.key == Some(key) && !cache.lines.is_empty() {
                return (cache.lines.clone(), cache.remaining);
            }
        }

        let display: Vec<&str> = content.lines().take(preview_count).collect();
        let mut highlighted: Vec<String> = Vec::new();

        #[cfg(feature = "syntect")]
        {
            let lang = if path.is_empty() {
                None
            } else {
                crate::tui::components::path_to_language(path)
            };
            if let Some(lang) = lang {
                let text = display.join("\n");
                let hl = crate::tui::components::highlight_code(&text, Some(lang));
                if !hl.is_empty() {
                    highlighted = hl;
                }
            }
        }

        // Fallback: no highlighter, unknown language, or empty highlighter output.
        if highlighted.is_empty() {
            highlighted = display.iter().map(|line| line.to_string()).collect();
        }

        // Build the new entry first and swap it in with a single assignment.
        let entry = WriteCache {
            key: Some(key),
            lines: highlighted.clone(),
            remaining,
        };
        *self
            .cache
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner) = entry;

        (highlighted, remaining)
    }
}
241
242impl ToolRenderer for WriteRenderer {
243    fn render_call(
244        &self,
245        args: &serde_json::Value,
246        _width: usize,
247        theme: &dyn Theme,
248        ctx: &ToolRenderContext,
249    ) -> Vec<String> {
250        let path = args
251            .get("file_path")
252            .or_else(|| args.get("path"))
253            .and_then(|v| v.as_str())
254            .unwrap_or("");
255        let content = args.get("content").and_then(|v| v.as_str()).unwrap_or("");
256
257        let short = if let Ok(home) = std::env::var("HOME") {
258            path.replacen(&home, "~", 1)
259        } else {
260            path.to_string()
261        };
262        let path_disp = if short.is_empty() {
263            String::new()
264        } else {
265            theme.fg("accent", &short)
266        };
267
268        let header = format!(
269            "{} {}",
270            theme.fg("toolTitle", &theme.bold("write")),
271            path_disp
272        );
273
274        let mut lines = vec![header];
275
276        // Show content preview (first few lines) when not expanded
277        if !content.is_empty() {
278            let (display, remaining) = self.get_highlighted_lines(content, path, ctx.expanded);
279
280            for line in &display {
281                lines.push(format!("\n{}", theme.fg("toolOutput", line)));
282            }
283
284            if remaining > 0 {
285                lines.push(theme.fg(
286                    "muted",
287                    &format!(
288                        "... ({} more lines, {} total, {} to expand)",
289                        remaining,
290                        content.lines().count(),
291                        ctx.expand_key
292                    ),
293                ));
294            }
295        }
296
297        lines
298    }
299
300    fn render_result(
301        &self,
302        content: &str,
303        _width: usize,
304        theme: &dyn Theme,
305        ctx: &ToolRenderContext,
306    ) -> Vec<String> {
307        // On success, pi shows no result output (just the background color transition).
308        // On error, show the error text.
309        if !ctx.is_error || content.is_empty() {
310            return vec![];
311        }
312        vec![theme.fg("error", content)]
313    }
314}