1use crate::agent::extension::{Extension, ToolDefinition};
2use crate::agent::extension::{ToolRenderContext, ToolRenderer};
3use crate::builtin;
4use crate::tui::Theme;
5use crate::tui::ThemeKey;
6
7use std::borrow::Cow;
8use std::path::Path;
9use std::sync::Arc;
10
11pub fn prepare_write_args(mut args: serde_json::Value) -> Result<serde_json::Value, String> {
14 if let Some(val) = args.get("path")
16 && !val.is_string()
17 {
18 if val.is_number() || val.is_boolean() {
19 args["path"] = serde_json::Value::String(match val {
20 serde_json::Value::Number(n) => n.to_string(),
21 serde_json::Value::Bool(b) => b.to_string(),
22 _ => unreachable!(),
23 });
24 } else if val.is_null() {
25 return Err("Missing 'path' argument".to_string());
26 }
27 }
28
29 if let Some(val) = args.get("content")
31 && !val.is_string()
32 {
33 if val.is_number() || val.is_boolean() || val.is_null() {
34 args["content"] = serde_json::Value::String(match val {
35 serde_json::Value::Number(n) => n.to_string(),
36 serde_json::Value::Bool(b) => b.to_string(),
37 serde_json::Value::Null => String::new(),
38 _ => unreachable!(),
39 });
40 } else if val.is_array() || val.is_object() {
41 args["content"] =
43 serde_json::Value::String(serde_json::to_string(val).unwrap_or_default());
44 }
45 }
46
47 Ok(args)
48}
49
/// Maximum number of content lines shown in a collapsed (non-expanded) call
/// preview; the rest is summarized by a "... (N more lines ...)" hint.
const PREVIEW_LINES: usize = 10;

/// Number of leading lines that `refresh_highlight_prefix` re-highlights as a
/// single joined block after each incremental cache update. Lines beyond this
/// count keep the per-line highlighting they received when appended.
const PARTIAL_FULL_HIGHLIGHT_LINES: usize = 50;
56
/// Filesystem backend used by the `write` tool.
///
/// The tool calls `mkdir` on the target's parent directory and then
/// `write_file` on the resolved absolute path. `DefaultWriteOperations` maps
/// these onto `std::fs`; other implementations can be injected via
/// `WriteExtension::with_operations`.
pub trait WriteOperations: Send + Sync {
    /// Writes `content` to `absolute_path`. The tool describes itself as
    /// overwriting existing files, so implementations are expected to replace
    /// any previous contents.
    fn write_file(&self, absolute_path: &Path, content: &str) -> anyhow::Result<()>;
    /// Ensures `dir` exists. The tool relies on this to create nested parent
    /// directories, so implementations are expected to create missing ancestors too.
    fn mkdir(&self, dir: &Path) -> anyhow::Result<()>;
}
67
68impl<F1, F2> WriteOperations for (F1, F2)
69where
70 F1: Send + Sync + Fn(&Path, &str) -> anyhow::Result<()>,
71 F2: Send + Sync + Fn(&Path) -> anyhow::Result<()>,
72{
73 fn write_file(&self, absolute_path: &Path, content: &str) -> anyhow::Result<()> {
74 (self.0)(absolute_path, content)
75 }
76 fn mkdir(&self, dir: &Path) -> anyhow::Result<()> {
77 (self.1)(dir)
78 }
79}
80
81struct DefaultWriteOperations;
82
83impl WriteOperations for DefaultWriteOperations {
84 fn write_file(&self, absolute_path: &Path, content: &str) -> anyhow::Result<()> {
85 Ok(std::fs::write(absolute_path, content)?)
86 }
87 fn mkdir(&self, dir: &Path) -> anyhow::Result<()> {
88 Ok(std::fs::create_dir_all(dir)?)
89 }
90}
91
92pub struct WriteExtension {
95 cwd: std::path::PathBuf,
96 operations: Arc<dyn WriteOperations>,
97}
98
99impl WriteExtension {
100 pub fn new(cwd: std::path::PathBuf) -> Self {
101 Self {
102 cwd,
103 operations: Arc::new(DefaultWriteOperations),
104 }
105 }
106
107 pub fn with_operations(mut self, operations: Arc<dyn WriteOperations>) -> Self {
109 self.operations = operations;
110 self
111 }
112}
113
114impl Extension for WriteExtension {
115 fn name(&self) -> Cow<'static, str> {
116 "write".into()
117 }
118
119 fn as_any(&self) -> &dyn std::any::Any {
120 self
121 }
122
123 fn tools(&self) -> Vec<ToolDefinition> {
124 vec![ToolDefinition {
125 tool: Box::new(WriteTool {
126 cwd: self.cwd.clone(),
127 operations: self.operations.clone(),
128 }),
129 snippet: "Create or overwrite files",
130 guidelines: &["Use write only for new files or complete rewrites."],
131 prepare_arguments: Some(prepare_write_args),
132 before_tool_call: None,
133 after_tool_call: None,
134 renderer: Some(std::sync::Arc::new(WriteRenderer::new())),
135 }]
136 }
137}
138
/// The `write` tool: writes a file's full content, creating parent directories first.
struct WriteTool {
    /// Base directory passed to `builtin::resolve_path` for the target path.
    cwd: std::path::PathBuf,
    /// Backend performing the actual `mkdir` / `write_file` calls.
    operations: Arc<dyn WriteOperations>,
}
145
/// Syntax-highlighting state for the content previewed by `WriteRenderer`,
/// maintained incrementally by `update_highlight_cache_incremental`.
struct WriteHighlightCache {
    /// Path the cache was built for; a different path forces a rebuild.
    raw_path: Option<String>,
    /// Language name resolved from `raw_path` via `path_to_language`.
    lang: String,
    /// Content exactly as received (before normalization). New content must
    /// start with this for an incremental update; otherwise the cache is rebuilt.
    raw_content: String,
    /// Display-normalized lines (`\r` removed, tabs replaced with spaces).
    normalized_lines: Vec<String>,
    /// Highlighted rendering of each entry of `normalized_lines`; the two
    /// vectors are indexed in parallel.
    highlighted_lines: Vec<String>,
}
157
/// Highlights one line of code in isolation (no multi-line context).
///
/// With the `syntect` feature, returns the first line produced by the
/// highlighter. Without the feature, or if the highlighter produces nothing,
/// the line is returned unchanged.
fn highlight_single_line(line: &str, lang: &str) -> String {
    #[cfg(feature = "syntect")]
    {
        if let Some(first) = crate::tui::components::highlight_code(line, Some(lang))
            .into_iter()
            .next()
        {
            return first;
        }
    }
    #[cfg(not(feature = "syntect"))]
    let _ = lang;
    line.to_string()
}
169
170fn refresh_highlight_prefix(cache: &mut WriteHighlightCache) {
173 let prefix_count = PARTIAL_FULL_HIGHLIGHT_LINES.min(cache.normalized_lines.len());
174 if prefix_count == 0 {
175 return;
176 }
177 let prefix_source: Vec<&str> = cache.normalized_lines[..prefix_count]
178 .iter()
179 .map(|s| s.as_str())
180 .collect();
181 let prefix_text = prefix_source.join("\n");
182 #[cfg(feature = "syntect")]
183 {
184 let prefix_highlighted =
185 crate::tui::components::highlight_code(&prefix_text, Some(&cache.lang));
186 for i in 0..prefix_count {
187 cache.highlighted_lines[i] = prefix_highlighted
188 .get(i)
189 .cloned()
190 .unwrap_or_else(|| highlight_single_line(&cache.normalized_lines[i], &cache.lang));
191 }
192 }
193 #[cfg(not(feature = "syntect"))]
194 {
195 let _ = prefix_text;
196 for i in 0..prefix_count {
197 cache.highlighted_lines[i] = cache.normalized_lines[i].clone();
198 }
199 }
200}
201
202fn rebuild_highlight_cache(
204 raw_path: Option<&str>,
205 file_content: &str,
206) -> Option<WriteHighlightCache> {
207 let lang = raw_path
208 .and_then(crate::tui::components::path_to_language)
209 .map(|s| s.to_string());
210 let lang = lang?;
211
212 let display_content = file_content.replace('\r', "");
213 let normalized = display_content.replace('\t', " ");
214 let normalized_lines: Vec<String> = normalized.lines().map(|l| l.to_string()).collect();
215
216 #[cfg(feature = "syntect")]
217 let highlighted_lines = crate::tui::components::highlight_code(&normalized, Some(&lang));
218 #[cfg(not(feature = "syntect"))]
219 let highlighted_lines = normalized_lines.clone();
220
221 Some(WriteHighlightCache {
222 raw_path: raw_path.map(|s| s.to_string()),
223 lang,
224 raw_content: file_content.to_string(),
225 normalized_lines,
226 highlighted_lines,
227 })
228}
229
230fn update_highlight_cache_incremental(
233 cache: Option<WriteHighlightCache>,
234 raw_path: Option<&str>,
235 file_content: &str,
236) -> Option<WriteHighlightCache> {
237 let lang = raw_path
238 .and_then(crate::tui::components::path_to_language)
239 .map(|s| s.to_string());
240 let lang = lang?;
241
242 let mut cache = match cache {
243 Some(c) => c,
244 None => return rebuild_highlight_cache(raw_path, file_content),
245 };
246
247 if cache.lang != lang || cache.raw_path.as_deref() != raw_path {
249 return rebuild_highlight_cache(raw_path, file_content);
250 }
251
252 if !file_content.starts_with(&cache.raw_content) {
254 return rebuild_highlight_cache(raw_path, file_content);
255 }
256
257 if file_content.len() == cache.raw_content.len() {
259 return Some(cache);
260 }
261
262 let delta_raw = &file_content[cache.raw_content.len()..];
264 let delta_display = delta_raw.replace('\r', "");
265 let delta_normalized = delta_display.replace('\t', " ");
266
267 cache.raw_content = file_content.to_string();
268
269 if cache.normalized_lines.is_empty() {
270 cache.normalized_lines.push(String::new());
271 cache.highlighted_lines.push(String::new());
272 }
273
274 let segments: Vec<&str> = delta_normalized.split('\n').collect();
275 if segments.is_empty() {
276 return Some(cache);
277 }
278
279 let last_idx = cache.normalized_lines.len() - 1;
281 cache.normalized_lines[last_idx].push_str(segments[0]);
282 cache.highlighted_lines[last_idx] =
283 highlight_single_line(&cache.normalized_lines[last_idx], &cache.lang);
284
285 for &seg in &segments[1..] {
287 cache.normalized_lines.push(seg.to_string());
288 cache
289 .highlighted_lines
290 .push(highlight_single_line(seg, &cache.lang));
291 }
292
293 refresh_highlight_prefix(&mut cache);
295
296 Some(cache)
297}
298
/// Returns `lines` without its trailing run of empty strings.
///
/// Only lines that are exactly empty are dropped; whitespace-only lines are kept.
fn trim_trailing_empty_lines(lines: &[String]) -> &[String] {
    let keep = lines
        .iter()
        .rposition(|line| !line.is_empty())
        .map_or(0, |last_non_empty| last_non_empty + 1);
    &lines[..keep]
}
307
308#[async_trait::async_trait]
309impl yoagent::types::AgentTool for WriteTool {
310 fn name(&self) -> &str {
311 "write"
312 }
313 fn label(&self) -> &str {
314 "write"
315 }
316 fn description(&self) -> &str {
317 "Write content to a file. Creates the file if it doesn't exist, overwrites if it does. \
318 Automatically creates parent directories."
319 }
320 fn parameters_schema(&self) -> serde_json::Value {
321 serde_json::json!({
322 "type": "object",
323 "required": ["path", "content"],
324 "properties": {
325 "path": {
326 "type": "string",
327 "description": "Path to the file to write"
328 },
329 "content": {
330 "type": "string",
331 "description": "Content to write to the file"
332 }
333 }
334 })
335 }
336 async fn execute(
337 &self,
338 params: serde_json::Value,
339 ctx: yoagent::types::ToolContext,
340 ) -> std::result::Result<yoagent::types::ToolResult, yoagent::types::ToolError> {
341 let path = params["path"]
342 .as_str()
343 .ok_or_else(|| {
344 yoagent::types::ToolError::InvalidArgs("Missing 'path' argument".into())
345 })?
346 .to_string();
347 let content = params["content"]
348 .as_str()
349 .ok_or_else(|| {
350 yoagent::types::ToolError::InvalidArgs("Missing 'content' argument".into())
351 })?
352 .to_string();
353
354 if ctx.cancel.is_cancelled() {
355 return Err(yoagent::types::ToolError::Cancelled);
356 }
357
358 let cwd = self.cwd.clone();
359 let cancel = ctx.cancel.clone();
360 let ops = self.operations.clone();
361 let path_for_queue = path.clone();
362 let cwd_for_closure = cwd.clone();
363 let content_for_closure = content.clone();
364
365 let result = crate::builtin::file_mutation_queue::with_file_mutation_queue(
366 &path_for_queue,
367 &cwd,
368 || async move {
369 let abs_path = builtin::resolve_path(&path, &cwd_for_closure);
370
371 if let Some(parent) = abs_path.parent() {
373 ops.mkdir(parent).map_err(|e| {
374 anyhow::anyhow!("Failed to create dir {}: {}", parent.display(), e)
375 })?;
376 }
377
378 if cancel.is_cancelled() {
379 anyhow::bail!("Operation cancelled");
380 }
381
382 ops.write_file(&abs_path, &content_for_closure)
384 .map_err(|e| {
385 anyhow::anyhow!("Failed to write {}: {}", abs_path.display(), e)
386 })?;
387
388 Ok::<_, anyhow::Error>(format!(
389 "Successfully wrote {} bytes to {}",
390 content_for_closure.len(),
391 path
392 ))
393 },
394 )
395 .await
396 .map_err(|e| yoagent::types::ToolError::Failed(e.to_string()))?;
397
398 Ok(yoagent::types::ToolResult {
399 content: vec![yoagent::types::Content::Text { text: result }],
400 details: serde_json::Value::Null,
401 })
402 }
403}
404
405struct WriteRenderer {
411 cache: std::sync::Mutex<Option<WriteHighlightCache>>,
412}
413
414impl WriteRenderer {
415 fn new() -> Self {
416 Self {
417 cache: std::sync::Mutex::new(None),
418 }
419 }
420}
421
422impl ToolRenderer for WriteRenderer {
423 fn render_call(
424 &self,
425 args: &serde_json::Value,
426 _width: usize,
427 theme: &dyn Theme,
428 ctx: &ToolRenderContext,
429 ) -> Vec<String> {
430 let raw_path = args
431 .get("file_path")
432 .or_else(|| args.get("path"))
433 .and_then(|v| v.as_str());
434 let content = args.get("content");
435
436 let path_display = if let Some(p) = raw_path {
439 let short = builtin::shorten_path(p);
440 let cwd = if ctx.cwd.is_empty() {
441 std::path::Path::new(".")
442 } else {
443 std::path::Path::new(&ctx.cwd)
444 };
445 builtin::link_path(&theme.fg_key(ThemeKey::Accent, &short), p, cwd)
446 } else {
447 String::new()
448 };
449
450 let header = format!(
451 "{} {}",
452 theme.fg_key(ThemeKey::ToolTitle, &theme.bold("write")),
453 path_display
454 );
455
456 let mut lines = vec![header];
457
458 let content_str = match content {
460 Some(content_val) => content_val.as_str(),
461 None => Some(""),
462 };
463
464 match content_str {
465 None => {
466 lines.push(String::new());
467 lines
468 .push(theme.fg_key(ThemeKey::Error, "[invalid content arg - expected string]"));
469 }
470 Some("") => {}
471 Some(text) => {
472 let mut cache_guard = self.cache.lock().unwrap();
474 *cache_guard =
475 update_highlight_cache_incremental(cache_guard.take(), raw_path, text);
476
477 let lang = raw_path.and_then(crate::tui::components::path_to_language);
478
479 let rendered_lines: Vec<String> = if let Some(ref cache) = *cache_guard {
481 cache.highlighted_lines.clone()
482 } else if lang.is_some() {
483 let normalized = text.replace('\r', "").replace('\t', " ");
485 #[cfg(feature = "syntect")]
486 {
487 let hl = crate::tui::components::highlight_code(&normalized, lang);
488 if !hl.is_empty() {
489 hl
490 } else {
491 normalized.lines().map(|l| l.to_string()).collect()
492 }
493 }
494 #[cfg(not(feature = "syntect"))]
495 {
496 normalized.lines().map(|l| l.to_string()).collect()
497 }
498 } else {
499 text.replace('\r', "")
501 .split('\n')
502 .map(|l| l.to_string())
503 .collect()
504 };
505
506 let trimmed = trim_trailing_empty_lines(&rendered_lines);
508 let total_lines = trimmed.len();
509 let max_lines = if ctx.expanded {
510 total_lines
511 } else {
512 PREVIEW_LINES
513 };
514 let display_lines = if total_lines > max_lines {
515 &trimmed[..max_lines]
516 } else {
517 trimmed
518 };
519 let remaining = total_lines.saturating_sub(max_lines);
520
521 let has_highlighting = cache_guard.is_some();
522
523 lines.push(String::new());
525
526 for line in display_lines {
527 let styled = if has_highlighting {
528 line.clone()
529 } else {
530 theme.fg_key(ThemeKey::ToolOutput, &line.replace('\t', " "))
531 };
532 lines.push(styled);
533 }
534
535 if remaining > 0 {
540 let dim_key = theme.fg_key(ThemeKey::Dim, &ctx.expand_key);
541 let muted_rest = theme.fg_key(
542 ThemeKey::Muted,
543 &format!("... ({} more lines, {} total, ", remaining, total_lines),
544 );
545 let muted_to_expand = theme.fg_key(ThemeKey::Muted, " to expand");
546 let muted_paren = theme.fg_key(ThemeKey::Muted, ")");
547 lines.push(format!(
548 "{}{}{}{}",
549 muted_rest, dim_key, muted_to_expand, muted_paren
550 ));
551 }
552 }
553 }
554
555 lines
556 }
557
558 fn render_result(
559 &self,
560 content: &str,
561 _width: usize,
562 theme: &dyn Theme,
563 ctx: &ToolRenderContext,
564 ) -> Vec<String> {
565 if !ctx.is_error || content.is_empty() {
568 return vec![];
569 }
570 vec![theme.fg_key(ThemeKey::Error, content)]
571 }
572}
573
#[cfg(test)]
mod tests {
    use super::*;
    use yoagent::AgentTool;
    use yoagent::types::ToolContext;

    /// A fresh, non-cancelled tool context.
    fn tool_ctx() -> ToolContext {
        ToolContext {
            tool_call_id: "id".into(),
            tool_name: "write".into(),
            cancel: tokio_util::sync::CancellationToken::new(),
            on_update: None,
            on_progress: None,
        }
    }

    /// A unique, freshly created directory under the system temp dir.
    fn tmp_dir() -> std::path::PathBuf {
        let d = std::env::temp_dir().join(format!("rab-write-test-{}", uuid::Uuid::new_v4()));
        std::fs::create_dir_all(&d).unwrap();
        d
    }

    fn make_tool() -> (WriteTool, std::path::PathBuf) {
        let tmp = tmp_dir();
        let tool = WriteTool {
            cwd: tmp.clone(),
            operations: Arc::new(DefaultWriteOperations),
        };
        (tool, tmp)
    }

    async fn exec_ok(tool: &WriteTool, args: serde_json::Value) -> String {
        let result = tool.execute(args, tool_ctx()).await.unwrap();
        yo_msg_text(&result.content)
    }

    /// Concatenates all text parts of a tool result.
    fn yo_msg_text(content: &[yoagent::types::Content]) -> String {
        content
            .iter()
            .filter_map(|c| {
                if let yoagent::types::Content::Text { text } = c {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect::<Vec<_>>()
            .join("")
    }

    #[tokio::test]
    async fn writes_file_content() {
        let (tool, tmp) = make_tool();
        let path = tmp.join("test.txt");
        let result = exec_ok(
            &tool,
            serde_json::json!({"path": path.to_str().unwrap(), "content": "hello world\n"}),
        )
        .await;

        assert!(result.contains("Successfully wrote"));
        assert!(result.contains("12 bytes"));
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "hello world\n");
    }

    #[tokio::test]
    async fn creates_parent_directories() {
        let (tool, tmp) = make_tool();
        let path = tmp.join("subdir/nested/file.txt");
        let result = exec_ok(
            &tool,
            serde_json::json!({"path": path.to_str().unwrap(), "content": "nested\n"}),
        )
        .await;

        assert!(result.contains("Successfully wrote"));
        assert!(path.exists());
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "nested\n");
    }

    #[tokio::test]
    async fn missing_path_errors() {
        let (tool, _tmp) = make_tool();
        let result = tool
            .execute(serde_json::json!({"content": "hello"}), tool_ctx())
            .await;
        assert!(matches!(
            result,
            Err(yoagent::types::ToolError::InvalidArgs(_))
        ));
    }

    #[tokio::test]
    async fn missing_content_errors() {
        let (tool, tmp) = make_tool();
        let result = tool
            .execute(
                serde_json::json!({"path": tmp.join("test.txt").to_str().unwrap()}),
                tool_ctx(),
            )
            .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn handles_empty_content() {
        let (tool, tmp) = make_tool();
        let path = tmp.join("empty.txt");
        let result = exec_ok(
            &tool,
            serde_json::json!({"path": path.to_str().unwrap(), "content": ""}),
        )
        .await;

        assert!(result.contains("Successfully wrote"));
        assert!(result.contains("0 bytes"));
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "");
    }

    #[tokio::test]
    async fn cancel_aborts_write() {
        let (tool, tmp) = make_tool();
        let path = tmp.join("cancelled.txt");
        let cancel = tokio_util::sync::CancellationToken::new();
        cancel.cancel();

        let result = tool
            .execute(
                serde_json::json!({"path": path.to_str().unwrap(), "content": "hello"}),
                ToolContext {
                    tool_call_id: "id".into(),
                    tool_name: "write".into(),
                    cancel,
                    on_update: None,
                    on_progress: None,
                },
            )
            .await;
        assert!(matches!(result, Err(yoagent::types::ToolError::Cancelled)));
        assert!(!path.exists());
    }

    #[test]
    fn test_highlight_single_line_empty() {
        let result = highlight_single_line("", "rust");
        assert_eq!(result, "");
    }

    #[test]
    fn test_trim_trailing_empty_lines() {
        let lines = vec![
            "a".to_string(),
            "b".to_string(),
            "".to_string(),
            "".to_string(),
        ];
        let trimmed = trim_trailing_empty_lines(&lines);
        assert_eq!(trimmed, &["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_trim_no_trailing_empty_lines() {
        let lines = vec!["a".to_string(), "b".to_string()];
        let trimmed = trim_trailing_empty_lines(&lines);
        assert_eq!(trimmed, &["a".to_string(), "b".to_string()]);
    }

    #[test]
    fn test_trim_all_empty() {
        let lines = vec!["".to_string(), "".to_string()];
        let trimmed = trim_trailing_empty_lines(&lines);
        assert!(trimmed.is_empty());
    }

    #[test]
    fn test_trim_empty_input() {
        let lines: Vec<String> = vec![];
        let trimmed = trim_trailing_empty_lines(&lines);
        assert!(trimmed.is_empty());
    }

    #[test]
    fn test_rebuild_cache_unknown_lang() {
        let result = rebuild_highlight_cache(Some("foo.unknown"), "hello");
        assert!(result.is_none());
    }

    #[test]
    fn test_rebuild_cache_known_lang() {
        let result = rebuild_highlight_cache(Some("foo.rs"), "fn main() {}");
        assert!(result.is_some());
        let cache = result.unwrap();
        assert_eq!(cache.lang, "rust");
        assert_eq!(cache.raw_content, "fn main() {}");
    }

    #[test]
    fn test_incremental_update_extends_content() {
        let cache = rebuild_highlight_cache(Some("foo.rs"), "fn main()");
        assert!(cache.is_some());
        let cache = cache.unwrap();
        assert_eq!(cache.normalized_lines.len(), 1);

        let updated =
            update_highlight_cache_incremental(Some(cache), Some("foo.rs"), "fn main() {}");
        assert!(updated.is_some());
        let updated = updated.unwrap();
        assert_eq!(updated.raw_content, "fn main() {}");
        assert_eq!(updated.normalized_lines, vec!["fn main() {}".to_string()]);
    }

    #[tokio::test]
    async fn relative_path_resolves_to_cwd() {
        let (tool, tmp) = make_tool();
        let result = exec_ok(
            &tool,
            serde_json::json!({"path": "relative.txt", "content": "hello\n"}),
        )
        .await;

        assert!(result.contains("Successfully wrote"));
        let abs_path = tmp.join("relative.txt");
        assert!(abs_path.exists());
    }

    #[tokio::test]
    async fn absolute_path_is_resolved_correctly() {
        let (tool, _tmp) = make_tool();
        let tmp2 = tmp_dir();
        let path = tmp2.join("abs.txt");
        let result = exec_ok(
            &tool,
            serde_json::json!({"path": path.to_str().unwrap(), "content": "absolute\n"}),
        )
        .await;

        assert!(result.contains("Successfully wrote"));
        assert!(path.exists());
    }
}