1use async_trait::async_trait;
2use rho_core::tool::{AgentTool, ToolError};
3use rho_core::types::{Content, ToolResult};
4use serde_json::Value;
5use tokio_util::sync::CancellationToken;
6
/// Agent tool that writes a string to a file, creating any missing parent
/// directories before writing.
#[derive(Debug, Clone)]
pub struct WriteTool {
    /// Base directory that relative `path` arguments are joined onto.
    /// When `None`, relative paths are used exactly as given.
    working_dir: Option<std::path::PathBuf>,
}
10
11impl WriteTool {
12 pub fn new() -> Self {
13 WriteTool { working_dir: None }
14 }
15
16 pub fn with_cwd(cwd: std::path::PathBuf) -> Self {
17 WriteTool {
18 working_dir: Some(cwd),
19 }
20 }
21
22 fn resolve_path(&self, path: &str) -> std::path::PathBuf {
23 let p = std::path::Path::new(path);
24 if p.is_absolute() {
25 p.to_path_buf()
26 } else if let Some(ref cwd) = self.working_dir {
27 cwd.join(p)
28 } else {
29 p.to_path_buf()
30 }
31 }
32}
33
34impl Default for WriteTool {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40#[async_trait]
41impl AgentTool for WriteTool {
42 fn name(&self) -> &str {
43 "write"
44 }
45
46 fn label(&self) -> String {
47 "Write File".to_string()
48 }
49
50 fn description(&self) -> String {
51 "Write content to a file at the specified path. Creates parent directories if they don't exist.".to_string()
52 }
53
54 fn parameters_schema(&self) -> Value {
55 serde_json::json!({
56 "type": "object",
57 "properties": {
58 "path": {
59 "type": "string",
60 "description": "The absolute path to the file to write"
61 },
62 "content": {
63 "type": "string",
64 "description": "The content to write to the file"
65 }
66 },
67 "required": ["path", "content"]
68 })
69 }
70
71 async fn execute(
72 &self,
73 _tool_call_id: &str,
74 params: Value,
75 _cancel: CancellationToken,
76 ) -> Result<ToolResult, ToolError> {
77 let path = params
78 .get("path")
79 .and_then(|v| v.as_str())
80 .ok_or_else(|| ToolError::InvalidParameters("missing or invalid 'path' parameter".into()))?;
81
82 let content = params
83 .get("content")
84 .and_then(|v| v.as_str())
85 .ok_or_else(|| {
86 ToolError::InvalidParameters("missing or invalid 'content' parameter".into())
87 })?;
88
89 let file_path = self.resolve_path(path);
90
91 if let Some(parent) = file_path.parent() {
92 tokio::fs::create_dir_all(parent)
93 .await
94 .map_err(|e| ToolError::ExecutionFailed(e.to_string()))?;
95 }
96
97 let bytes = content.len();
98 tokio::fs::write(&file_path, content)
99 .await
100 .map_err(|e| ToolError::ExecutionFailed(e.to_string()))?;
101
102 crate::git_helpers::auto_commit_file(&file_path, "write").await;
103
104 Ok(ToolResult {
105 content: vec![Content::Text {
106 text: format!("Successfully wrote {} bytes to {}", bytes, path),
107 }],
108 details: serde_json::json!({}),
109 })
110 }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn write_to_temp_file() {
        let tool = WriteTool::new();
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("test.txt");

        let params = serde_json::json!({
            "path": file_path.to_str().unwrap(),
            "content": "hello world"
        });

        let result = tool
            .execute("call_1", params, CancellationToken::new())
            .await
            .unwrap();

        assert_eq!(result.content.len(), 1);
        match &result.content[0] {
            Content::Text { text } => {
                assert!(text.contains("11 bytes"));
                assert!(text.contains("test.txt"));
            }
            _ => panic!("expected Text content"),
        }

        let written = std::fs::read_to_string(&file_path).unwrap();
        assert_eq!(written, "hello world");
    }

    #[tokio::test]
    async fn missing_path_parameter() {
        let tool = WriteTool::new();
        let params = serde_json::json!({
            "content": "hello"
        });

        let err = tool
            .execute("call_2", params, CancellationToken::new())
            .await
            .unwrap_err();

        match err {
            ToolError::InvalidParameters(msg) => assert!(msg.contains("path")),
            _ => panic!("expected InvalidParameters"),
        }
    }

    // A non-string `content` must be rejected before any filesystem access.
    #[tokio::test]
    async fn invalid_content_parameter() {
        let tool = WriteTool::new();
        let params = serde_json::json!({
            "path": "never_written.txt",
            "content": 42
        });

        let err = tool
            .execute("call_4", params, CancellationToken::new())
            .await
            .unwrap_err();

        match err {
            ToolError::InvalidParameters(msg) => assert!(msg.contains("content")),
            _ => panic!("expected InvalidParameters"),
        }
    }

    #[tokio::test]
    async fn creates_parent_directories() {
        let tool = WriteTool::new();
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("a").join("b").join("c").join("test.txt");

        assert!(!dir.path().join("a").exists());

        let params = serde_json::json!({
            "path": file_path.to_str().unwrap(),
            "content": "nested content"
        });

        let result = tool
            .execute("call_3", params, CancellationToken::new())
            .await
            .unwrap();

        assert_eq!(result.content.len(), 1);
        let written = std::fs::read_to_string(&file_path).unwrap();
        assert_eq!(written, "nested content");
    }

    #[tokio::test]
    async fn relative_path_resolves_against_working_dir() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool::with_cwd(dir.path().to_path_buf());

        let params = serde_json::json!({
            "path": "sub/rel.txt",
            "content": "relative"
        });

        tool.execute("call_5", params, CancellationToken::new())
            .await
            .unwrap();

        let written = std::fs::read_to_string(dir.path().join("sub").join("rel.txt")).unwrap();
        assert_eq!(written, "relative");
    }

    #[tokio::test]
    async fn absolute_path_ignores_working_dir() {
        let cwd = tempfile::tempdir().unwrap();
        let target = tempfile::tempdir().unwrap();
        let tool = WriteTool::with_cwd(cwd.path().to_path_buf());
        let file_path = target.path().join("abs.txt");

        let params = serde_json::json!({
            "path": file_path.to_str().unwrap(),
            "content": "absolute"
        });

        tool.execute("call_6", params, CancellationToken::new())
            .await
            .unwrap();

        assert_eq!(std::fs::read_to_string(&file_path).unwrap(), "absolute");
        assert!(!cwd.path().join("abs.txt").exists());
    }
}