1use crate::agent::extension::{Extension, ToolDefinition};
2use crate::agent::extension::{ToolRenderContext, ToolRenderer};
3use crate::builtin;
4use crate::tui::Theme;
5use crate::tui::ThemeKey;
6
7use std::borrow::Cow;
8use std::path::Path;
9use std::sync::Arc;
10
11pub fn prepare_write_args(mut args: serde_json::Value) -> Result<serde_json::Value, String> {
14 if let Some(val) = args.get("path")
16 && !val.is_string()
17 {
18 if val.is_number() || val.is_boolean() {
19 args["path"] = serde_json::Value::String(match val {
20 serde_json::Value::Number(n) => n.to_string(),
21 serde_json::Value::Bool(b) => b.to_string(),
22 _ => unreachable!(),
23 });
24 } else if val.is_null() {
25 return Err("Missing 'path' argument".to_string());
26 }
27 }
28
29 if let Some(val) = args.get("content")
31 && !val.is_string()
32 {
33 if val.is_number() || val.is_boolean() || val.is_null() {
34 args["content"] = serde_json::Value::String(match val {
35 serde_json::Value::Number(n) => n.to_string(),
36 serde_json::Value::Bool(b) => b.to_string(),
37 serde_json::Value::Null => String::new(),
38 _ => unreachable!(),
39 });
40 } else if val.is_array() || val.is_object() {
41 args["content"] =
43 serde_json::Value::String(serde_json::to_string(val).unwrap_or_default());
44 }
45 }
46
47 Ok(args)
48}
49
/// Number of content lines shown in a collapsed (not expanded) `write` call
/// preview; the rest is summarized as "... (N more lines, ...)".
const PREVIEW_LINES: usize = 10;

/// Number of leading lines that `refresh_highlight_prefix` re-highlights as a
/// single block of text after each incremental cache update (newly appended
/// lines are first highlighted one line at a time).
const PARTIAL_FULL_HIGHLIGHT_LINES: usize = 50;
56
/// Filesystem primitives used by the `write` tool.
///
/// `DefaultWriteOperations` goes straight to `std::fs`; a replacement can be
/// installed with `WriteExtension::with_operations` (for example to stub out
/// the filesystem). A `(write_fn, mkdir_fn)` closure pair also implements
/// this trait.
pub trait WriteOperations: Send + Sync {
    /// Writes `content` to `absolute_path`, creating or replacing the file.
    /// The tool passes a path already resolved against its working directory.
    fn write_file(&self, absolute_path: &Path, content: &str) -> anyhow::Result<()>;
    /// Creates `dir` (the default implementation also creates missing
    /// ancestors). The tool calls this for the target's parent directory on
    /// every write, so it must succeed when the directory already exists.
    fn mkdir(&self, dir: &Path) -> anyhow::Result<()>;
}
67
68impl<F1, F2> WriteOperations for (F1, F2)
69where
70 F1: Send + Sync + Fn(&Path, &str) -> anyhow::Result<()>,
71 F2: Send + Sync + Fn(&Path) -> anyhow::Result<()>,
72{
73 fn write_file(&self, absolute_path: &Path, content: &str) -> anyhow::Result<()> {
74 (self.0)(absolute_path, content)
75 }
76 fn mkdir(&self, dir: &Path) -> anyhow::Result<()> {
77 (self.1)(dir)
78 }
79}
80
/// `WriteOperations` backed directly by `std::fs`.
struct DefaultWriteOperations;
82
83impl WriteOperations for DefaultWriteOperations {
84 fn write_file(&self, absolute_path: &Path, content: &str) -> anyhow::Result<()> {
85 Ok(std::fs::write(absolute_path, content)?)
86 }
87 fn mkdir(&self, dir: &Path) -> anyhow::Result<()> {
88 Ok(std::fs::create_dir_all(dir)?)
89 }
90}
91
/// Extension that registers the `write` tool.
pub struct WriteExtension {
    /// Working directory that relative tool paths are resolved against.
    cwd: std::path::PathBuf,
    /// Filesystem backend handed to every `WriteTool` this extension creates.
    operations: Arc<dyn WriteOperations>,
}
98
99impl WriteExtension {
100 pub fn new(cwd: std::path::PathBuf) -> Self {
101 Self {
102 cwd,
103 operations: Arc::new(DefaultWriteOperations),
104 }
105 }
106
107 pub fn with_operations(mut self, operations: Arc<dyn WriteOperations>) -> Self {
109 self.operations = operations;
110 self
111 }
112}
113
114impl Extension for WriteExtension {
115 fn name(&self) -> Cow<'static, str> {
116 "write".into()
117 }
118
119 fn tools(&self) -> Vec<ToolDefinition> {
120 vec![ToolDefinition {
121 tool: Box::new(WriteTool {
122 cwd: self.cwd.clone(),
123 operations: self.operations.clone(),
124 }),
125 snippet: "Create or overwrite files",
126 guidelines: &["Use write only for new files or complete rewrites."],
127 prepare_arguments: Some(prepare_write_args),
128 before_tool_call: None,
129 after_tool_call: None,
130 renderer: Some(std::sync::Arc::new(WriteRenderer::new())),
131 }]
132 }
133}
134
/// The `write` tool: resolves the requested path against `cwd` and performs
/// the directory creation and file write through `operations`.
struct WriteTool {
    /// Base directory for resolving relative paths.
    cwd: std::path::PathBuf,
    /// Filesystem backend used for `mkdir` and `write_file`.
    operations: Arc<dyn WriteOperations>,
}
141
/// Syntax-highlighting state for the content of a `write` call, kept between
/// renders so that appended content can be highlighted incrementally.
///
/// `normalized_lines` and `highlighted_lines` are indexed in parallel: the
/// incremental updater writes both at the same index.
struct WriteHighlightCache {
    /// Path the cache was built for; a different path forces a rebuild.
    raw_path: Option<String>,
    /// Language name derived from `raw_path` via `path_to_language`.
    lang: String,
    /// Content exactly as received (before normalization); used to detect
    /// whether new content is a pure append of the old.
    raw_content: String,
    /// Display lines: `'\r'` removed and tabs replaced before splitting.
    normalized_lines: Vec<String>,
    /// Highlighted rendering of each entry in `normalized_lines` (a plain copy
    /// when the `syntect` feature is disabled).
    highlighted_lines: Vec<String>,
}
153
/// Highlights one line of `lang` code in isolation.
///
/// With the `syntect` feature this is the first line of the highlighter's
/// output; when the highlighter yields nothing, or the feature is disabled,
/// the line is returned unchanged.
fn highlight_single_line(line: &str, lang: &str) -> String {
    #[cfg(feature = "syntect")]
    {
        let mut rendered = crate::tui::components::highlight_code(line, Some(lang)).into_iter();
        if let Some(first) = rendered.next() {
            return first;
        }
    }
    String::from(line)
}
165
166fn refresh_highlight_prefix(cache: &mut WriteHighlightCache) {
169 let prefix_count = PARTIAL_FULL_HIGHLIGHT_LINES.min(cache.normalized_lines.len());
170 if prefix_count == 0 {
171 return;
172 }
173 let prefix_source: Vec<&str> = cache.normalized_lines[..prefix_count]
174 .iter()
175 .map(|s| s.as_str())
176 .collect();
177 let prefix_text = prefix_source.join("\n");
178 #[cfg(feature = "syntect")]
179 {
180 let prefix_highlighted =
181 crate::tui::components::highlight_code(&prefix_text, Some(&cache.lang));
182 for i in 0..prefix_count {
183 cache.highlighted_lines[i] = prefix_highlighted
184 .get(i)
185 .cloned()
186 .unwrap_or_else(|| highlight_single_line(&cache.normalized_lines[i], &cache.lang));
187 }
188 }
189 #[cfg(not(feature = "syntect"))]
190 {
191 let _ = prefix_text;
192 for i in 0..prefix_count {
193 cache.highlighted_lines[i] = cache.normalized_lines[i].clone();
194 }
195 }
196}
197
198fn rebuild_highlight_cache(
200 raw_path: Option<&str>,
201 file_content: &str,
202) -> Option<WriteHighlightCache> {
203 let lang = raw_path
204 .and_then(crate::tui::components::path_to_language)
205 .map(|s| s.to_string());
206 let lang = lang?;
207
208 let display_content = file_content.replace('\r', "");
209 let normalized = display_content.replace('\t', " ");
210 let normalized_lines: Vec<String> = normalized.lines().map(|l| l.to_string()).collect();
211
212 #[cfg(feature = "syntect")]
213 let highlighted_lines = crate::tui::components::highlight_code(&normalized, Some(&lang));
214 #[cfg(not(feature = "syntect"))]
215 let highlighted_lines = normalized_lines.clone();
216
217 Some(WriteHighlightCache {
218 raw_path: raw_path.map(|s| s.to_string()),
219 lang,
220 raw_content: file_content.to_string(),
221 normalized_lines,
222 highlighted_lines,
223 })
224}
225
226fn update_highlight_cache_incremental(
229 cache: Option<WriteHighlightCache>,
230 raw_path: Option<&str>,
231 file_content: &str,
232) -> Option<WriteHighlightCache> {
233 let lang = raw_path
234 .and_then(crate::tui::components::path_to_language)
235 .map(|s| s.to_string());
236 let lang = lang?;
237
238 let mut cache = match cache {
239 Some(c) => c,
240 None => return rebuild_highlight_cache(raw_path, file_content),
241 };
242
243 if cache.lang != lang || cache.raw_path.as_deref() != raw_path {
245 return rebuild_highlight_cache(raw_path, file_content);
246 }
247
248 if !file_content.starts_with(&cache.raw_content) {
250 return rebuild_highlight_cache(raw_path, file_content);
251 }
252
253 if file_content.len() == cache.raw_content.len() {
255 return Some(cache);
256 }
257
258 let delta_raw = &file_content[cache.raw_content.len()..];
260 let delta_display = delta_raw.replace('\r', "");
261 let delta_normalized = delta_display.replace('\t', " ");
262
263 cache.raw_content = file_content.to_string();
264
265 if cache.normalized_lines.is_empty() {
266 cache.normalized_lines.push(String::new());
267 cache.highlighted_lines.push(String::new());
268 }
269
270 let segments: Vec<&str> = delta_normalized.split('\n').collect();
271 if segments.is_empty() {
272 return Some(cache);
273 }
274
275 let last_idx = cache.normalized_lines.len() - 1;
277 cache.normalized_lines[last_idx].push_str(segments[0]);
278 cache.highlighted_lines[last_idx] =
279 highlight_single_line(&cache.normalized_lines[last_idx], &cache.lang);
280
281 for &seg in &segments[1..] {
283 cache.normalized_lines.push(seg.to_string());
284 cache
285 .highlighted_lines
286 .push(highlight_single_line(seg, &cache.lang));
287 }
288
289 refresh_highlight_prefix(&mut cache);
291
292 Some(cache)
293}
294
/// Returns `lines` without its trailing run of empty strings.
///
/// Interior empty lines are kept; an all-empty or empty input yields an empty
/// slice.
fn trim_trailing_empty_lines(lines: &[String]) -> &[String] {
    let keep = lines
        .iter()
        .rposition(|line| !line.is_empty())
        .map_or(0, |last_non_empty| last_non_empty + 1);
    &lines[..keep]
}
303
304#[async_trait::async_trait]
305impl yoagent::types::AgentTool for WriteTool {
306 fn name(&self) -> &str {
307 "write"
308 }
309 fn label(&self) -> &str {
310 "write"
311 }
312 fn description(&self) -> &str {
313 "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. \
314 Automatically creates parent directories."
315 }
316 fn parameters_schema(&self) -> serde_json::Value {
317 serde_json::json!({
318 "type": "object",
319 "required": ["path", "content"],
320 "properties": {
321 "path": {
322 "type": "string",
323 "description": "Path to the file to write"
324 },
325 "content": {
326 "type": "string",
327 "description": "Content to write to the file"
328 }
329 }
330 })
331 }
332 async fn execute(
333 &self,
334 params: serde_json::Value,
335 ctx: yoagent::types::ToolContext,
336 ) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
337 let path = params["path"]
338 .as_str()
339 .ok_or_else(|| {
340 yoagent::types::ToolError::InvalidArgs("Missing 'path' argument".into())
341 })?
342 .to_string();
343 let content = params["content"]
344 .as_str()
345 .ok_or_else(|| {
346 yoagent::types::ToolError::InvalidArgs("Missing 'content' argument".into())
347 })?
348 .to_string();
349
350 if ctx.cancel.is_cancelled() {
351 return Err(yoagent::types::ToolError::Cancelled);
352 }
353
354 let cwd = self.cwd.clone();
355 let cancel = ctx.cancel.clone();
356 let ops = self.operations.clone();
357 let path_for_queue = path.clone();
358 let cwd_for_closure = cwd.clone();
359 let content_for_closure = content.clone();
360
361 let result = crate::builtin::file_mutation_queue::with_file_mutation_queue(
362 &path_for_queue,
363 &cwd,
364 || async move {
365 let abs_path = builtin::resolve_path(&path, &cwd_for_closure);
366
367 if let Some(parent) = abs_path.parent() {
369 ops.mkdir(parent).map_err(|e| {
370 anyhow::anyhow!("Failed to create dir {}: {}", parent.display(), e)
371 })?;
372 }
373
374 if cancel.is_cancelled() {
375 anyhow::bail!("Operation cancelled");
376 }
377
378 ops.write_file(&abs_path, &content_for_closure)
380 .map_err(|e| {
381 anyhow::anyhow!("Failed to write {}: {}", abs_path.display(), e)
382 })?;
383
384 Ok::<_, anyhow::Error>(format!(
385 "Successfully wrote {} bytes to {}",
386 content_for_closure.len(),
387 path
388 ))
389 },
390 )
391 .await
392 .map_err(|e| yoagent::types::ToolError::Failed(e.to_string()))?;
393
394 Ok(yoagent::types::ToolResult {
395 content: vec![yoagent::types::Content::Text { text: result }],
396 details: serde_json::Value::Null,
397 })
398 }
399}
400
/// Renderer for `write` tool calls and results.
struct WriteRenderer {
    /// Highlight state carried between `render_call` invocations so streamed
    /// content can be highlighted incrementally. Behind a `Mutex` because the
    /// renderer methods take `&self`.
    cache: std::sync::Mutex<Option<WriteHighlightCache>>,
}
409
410impl WriteRenderer {
411 fn new() -> Self {
412 Self {
413 cache: std::sync::Mutex::new(None),
414 }
415 }
416}
417
418impl ToolRenderer for WriteRenderer {
419 fn render_call(
420 &self,
421 args: &serde_json::Value,
422 _width: usize,
423 theme: &dyn Theme,
424 ctx: &ToolRenderContext,
425 ) -> Vec<String> {
426 let raw_path = args
427 .get("file_path")
428 .or_else(|| args.get("path"))
429 .and_then(|v| v.as_str());
430 let content = args.get("content");
431
432 let path_display = if let Some(p) = raw_path {
435 let short = builtin::shorten_path(p);
436 let cwd = if ctx.cwd.is_empty() {
437 std::path::Path::new(".")
438 } else {
439 std::path::Path::new(&ctx.cwd)
440 };
441 builtin::link_path(&theme.fg_key(ThemeKey::Accent, &short), p, cwd)
442 } else {
443 String::new()
444 };
445
446 let header = format!(
447 "{} {}",
448 theme.fg_key(ThemeKey::ToolTitle, &theme.bold("write")),
449 path_display
450 );
451
452 let mut lines = vec![header];
453
454 let content_str = match content {
456 Some(content_val) => content_val.as_str(),
457 None => Some(""),
458 };
459
460 match content_str {
461 None => {
462 lines.push(String::new());
463 lines
464 .push(theme.fg_key(ThemeKey::Error, "[invalid content arg - expected string]"));
465 }
466 Some("") => {}
467 Some(text) => {
468 let mut cache_guard = self.cache.lock().unwrap();
470 *cache_guard =
471 update_highlight_cache_incremental(cache_guard.take(), raw_path, text);
472
473 let lang = raw_path.and_then(crate::tui::components::path_to_language);
474
475 let rendered_lines: Vec<String> = if let Some(ref cache) = *cache_guard {
477 cache.highlighted_lines.clone()
478 } else if lang.is_some() {
479 let normalized = text.replace('\r', "").replace('\t', " ");
481 #[cfg(feature = "syntect")]
482 {
483 let hl = crate::tui::components::highlight_code(&normalized, lang);
484 if !hl.is_empty() {
485 hl
486 } else {
487 normalized.lines().map(|l| l.to_string()).collect()
488 }
489 }
490 #[cfg(not(feature = "syntect"))]
491 {
492 normalized.lines().map(|l| l.to_string()).collect()
493 }
494 } else {
495 text.replace('\r', "")
497 .split('\n')
498 .map(|l| l.to_string())
499 .collect()
500 };
501
502 let trimmed = trim_trailing_empty_lines(&rendered_lines);
504 let total_lines = trimmed.len();
505 let max_lines = if ctx.expanded {
506 total_lines
507 } else {
508 PREVIEW_LINES
509 };
510 let display_lines = if total_lines > max_lines {
511 &trimmed[..max_lines]
512 } else {
513 trimmed
514 };
515 let remaining = total_lines.saturating_sub(max_lines);
516
517 let has_highlighting = cache_guard.is_some();
518
519 lines.push(String::new());
521
522 for line in display_lines {
523 let styled = if has_highlighting {
524 line.clone()
525 } else {
526 theme.fg_key(ThemeKey::ToolOutput, &line.replace('\t', " "))
527 };
528 lines.push(styled);
529 }
530
531 if remaining > 0 {
536 let dim_key = theme.fg_key(ThemeKey::Dim, &ctx.expand_key);
537 let muted_rest = theme.fg_key(
538 ThemeKey::Muted,
539 &format!("... ({} more lines, {} total, ", remaining, total_lines),
540 );
541 let muted_to_expand = theme.fg_key(ThemeKey::Muted, " to expand");
542 let muted_paren = theme.fg_key(ThemeKey::Muted, ")");
543 lines.push(format!(
544 "{}{}{}{}",
545 muted_rest, dim_key, muted_to_expand, muted_paren
546 ));
547 }
548 }
549 }
550
551 lines
552 }
553
554 fn render_result(
555 &self,
556 content: &str,
557 _width: usize,
558 theme: &dyn Theme,
559 ctx: &ToolRenderContext,
560 ) -> Vec<String> {
561 if !ctx.is_error || content.is_empty() {
564 return vec![];
565 }
566 vec![theme.fg_key(ThemeKey::Error, content)]
567 }
568}
569
#[cfg(test)]
mod tests {
    use super::*;
    use yoagent::AgentTool;
    use yoagent::types::ToolContext;

    /// A tool context with a fresh, never-cancelled token.
    fn tool_ctx() -> ToolContext {
        ToolContext {
            tool_call_id: "id".into(),
            tool_name: "write".into(),
            cancel: tokio_util::sync::CancellationToken::new(),
            on_update: None,
            on_progress: None,
        }
    }

    /// Creates a uniquely named scratch directory under the system temp dir.
    fn tmp_dir() -> std::path::PathBuf {
        let d = std::env::temp_dir().join(format!("rab-write-test-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&d).unwrap();
        d
    }

    /// A real-filesystem tool rooted at a fresh scratch directory.
    fn make_tool() -> (WriteTool, std::path::PathBuf) {
        let tmp = tmp_dir();
        let tool = WriteTool {
            cwd: tmp.clone(),
            operations: Arc::new(DefaultWriteOperations),
        };
        (tool, tmp)
    }

    /// Runs the tool, expecting success, and returns its text output.
    async fn exec_ok(tool: &WriteTool, args: serde_json::Value) -> String {
        let result = tool.execute(args, tool_ctx()).await.unwrap();
        yo_msg_text(&result.content)
    }

    /// Concatenates all text parts of a tool result.
    fn yo_msg_text(content: &[yoagent::types::Content]) -> String {
        content
            .iter()
            .filter_map(|c| {
                if let yoagent::types::Content::Text { text } = c {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
            .join("")
    }

    #[tokio::test]
    async fn writes_file_content() {
        let (tool, tmp) = make_tool();
        let path = tmp.join("test.txt");
        let result = exec_ok(
            &tool,
            serde_json::json!({"path": path.to_str().unwrap(), "content": "hello world\n"}),
        )
        .await;

        assert!(result.contains("Successfully wrote"));
        assert!(result.contains("12 bytes"));
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello world\n");
    }

    #[tokio::test]
    async fn creates_parent_directories() {
        let (tool, tmp) = make_tool();
        let path = tmp.join("subdir/nested/file.txt");
        let result = exec_ok(
            &tool,
            serde_json::json!({"path": path.to_str().unwrap(), "content": "nested\n"}),
        )
        .await;

        assert!(result.contains("Successfully wrote"));
        assert!(path.exists());
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "nested\n");
    }

    #[tokio::test]
    async fn missing_path_errors() {
        let (tool, _tmp) = make_tool();
        let result = tool
            .execute(serde_json::json!({"content": "hello"}), tool_ctx())
            .await;
        assert!(matches!(
            result,
            Err(yoagent::types::ToolError::InvalidArgs(_))
        ));
    }

    #[tokio::test]
    async fn missing_content_errors() {
        let (tool, tmp) = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"path": tmp.join("test.txt").to_str().unwrap()}),
                tool_ctx(),
            )
            .await;
        assert!(matches!(
            result,
            Err(yoagent::types::ToolError::InvalidArgs(_))
        ));
    }

    #[tokio::test]
    async fn handles_empty_content() {
        let (tool, tmp) = make_tool();
        let path = tmp.join("empty.txt");
        let result = exec_ok(
            &tool,
            serde_json::json!({"path": path.to_str().unwrap(), "content": ""}),
        )
        .await;

        assert!(result.contains("Successfully wrote"));
        assert!(result.contains("0 bytes"));
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "");
    }

    #[tokio::test]
    async fn cancel_aborts_write() {
        let (tool, tmp) = make_tool();
        let path = tmp.join("cancelled.txt");
        let cancel = tokio_util::sync::CancellationToken::new();
        cancel.cancel();

        let result = tool
            .execute(
                serde_json::json!({"path": path.to_str().unwrap(), "content": "hello"}),
                ToolContext {
                    cancel,
                    ..tool_ctx()
                },
            )
            .await;
        assert!(matches!(result, Err(yoagent::types::ToolError::Cancelled)));
        assert!(!path.exists());
    }

    #[test]
    fn prepare_args_coerces_scalars_and_structures() {
        let out = prepare_write_args(serde_json::json!({"path": 42, "content": true})).unwrap();
        assert_eq!(out["path"], "42");
        assert_eq!(out["content"], "true");

        let out =
            prepare_write_args(serde_json::json!({"path": "a.txt", "content": null})).unwrap();
        assert_eq!(out["content"], "");

        let out =
            prepare_write_args(serde_json::json!({"path": "a.txt", "content": [1, 2]})).unwrap();
        assert_eq!(out["content"], "[1,2]");
    }

    #[test]
    fn prepare_args_rejects_null_path() {
        assert!(prepare_write_args(serde_json::json!({"path": null, "content": "x"})).is_err());
    }

    #[test]
    fn test_highlight_single_line_empty() {
        let result = highlight_single_line("", "rust");
        assert_eq!(result, "");
    }

    #[test]
    fn test_trim_trailing_empty_lines() {
        let lines = vec![
            "a".to_string(),
            "b".to_string(),
            "".to_string(),
            "".to_string(),
        ];
        let trimmed = trim_trailing_empty_lines(&lines);
        assert_eq!(trimmed, &["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_trim_no_trailing_empty_lines() {
        let lines = vec!["a".to_string(), "b".to_string()];
        let trimmed = trim_trailing_empty_lines(&lines);
        assert_eq!(trimmed, &["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_trim_all_empty() {
        let lines = vec!["".to_string(), "".to_string()];
        let trimmed = trim_trailing_empty_lines(&lines);
        assert!(trimmed.is_empty());
    }

    #[test]
    fn test_trim_empty_input() {
        let lines: Vec<String> = vec![];
        let trimmed = trim_trailing_empty_lines(&lines);
        assert!(trimmed.is_empty());
    }

    #[test]
    fn test_rebuild_cache_unknown_lang() {
        let result = rebuild_highlight_cache(Some("foo.unknown"), "hello");
        assert!(result.is_none());
    }

    #[test]
    fn test_rebuild_cache_known_lang() {
        let result = rebuild_highlight_cache(Some("foo.rs"), "fn main() {}");
        assert!(result.is_some());
        let cache = result.unwrap();
        assert_eq!(cache.lang, "rust");
        assert_eq!(cache.raw_content, "fn main() {}");
    }

    #[test]
    fn test_incremental_update_extends_content() {
        let cache = rebuild_highlight_cache(Some("foo.rs"), "fn main()");
        assert!(cache.is_some());
        let cache = cache.unwrap();
        assert_eq!(cache.normalized_lines.len(), 1);

        let updated =
            update_highlight_cache_incremental(Some(cache), Some("foo.rs"), "fn main() {}");
        assert!(updated.is_some());
        let updated = updated.unwrap();
        assert_eq!(updated.raw_content, "fn main() {}");
    }

    #[tokio::test]
    async fn relative_path_resolves_to_cwd() {
        let (tool, tmp) = make_tool();
        let result = exec_ok(
            &tool,
            serde_json::json!({"path": "relative.txt", "content": "hello\n"}),
        )
        .await;

        assert!(result.contains("Successfully wrote"));
        let abs_path = tmp.join("relative.txt");
        assert!(abs_path.exists());
        assert_eq!(std::fs::read_to_string(&abs_path).unwrap(), "hello\n");
    }

    #[tokio::test]
    async fn absolute_path_is_resolved_correctly() {
        let (tool, _tmp) = make_tool();
        let tmp2 = tmp_dir();
        let path = tmp2.join("abs.txt");
        let result = exec_ok(
            &tool,
            serde_json::json!({"path": path.to_str().unwrap(), "content": "absolute\n"}),
        )
        .await;

        assert!(result.contains("Successfully wrote"));
        assert!(path.exists());
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "absolute\n");
    }
}