// rho_tools/write.rs

1use async_trait::async_trait;
2use rho_core::tool::{AgentTool, ToolError};
3use rho_core::types::{Content, ToolResult};
4use serde_json::Value;
5use tokio_util::sync::CancellationToken;
6
/// Agent tool that writes a string to a file, creating any missing parent directories.
#[derive(Debug, Clone)]
pub struct WriteTool {
    /// Base directory that relative paths are joined onto; `None` leaves relative
    /// paths exactly as the caller supplied them.
    working_dir: Option<std::path::PathBuf>,
}
10
11impl WriteTool {
12    pub fn new() -> Self {
13        WriteTool { working_dir: None }
14    }
15
16    pub fn with_cwd(cwd: std::path::PathBuf) -> Self {
17        WriteTool {
18            working_dir: Some(cwd),
19        }
20    }
21
22    fn resolve_path(&self, path: &str) -> std::path::PathBuf {
23        let p = std::path::Path::new(path);
24        if p.is_absolute() {
25            p.to_path_buf()
26        } else if let Some(ref cwd) = self.working_dir {
27            cwd.join(p)
28        } else {
29            p.to_path_buf()
30        }
31    }
32}
33
34impl Default for WriteTool {
35    fn default() -> Self {
36        Self::new()
37    }
38}
39
40#[async_trait]
41impl AgentTool for WriteTool {
42    fn name(&self) -> &str {
43        "write"
44    }
45
46    fn label(&self) -> String {
47        "Write File".to_string()
48    }
49
50    fn description(&self) -> String {
51        "Write content to a file at the specified path. Creates parent directories if they don't exist.".to_string()
52    }
53
54    fn parameters_schema(&self) -> Value {
55        serde_json::json!({
56            "type": "object",
57            "properties": {
58                "path": {
59                    "type": "string",
60                    "description": "The absolute path to the file to write"
61                },
62                "content": {
63                    "type": "string",
64                    "description": "The content to write to the file"
65                }
66            },
67            "required": ["path", "content"]
68        })
69    }
70
71    async fn execute(
72        &self,
73        _tool_call_id: &str,
74        params: Value,
75        _cancel: CancellationToken,
76    ) -> Result<ToolResult, ToolError> {
77        let path = params
78            .get("path")
79            .and_then(|v| v.as_str())
80            .ok_or_else(|| ToolError::InvalidParameters("missing or invalid 'path' parameter".into()))?;
81
82        let content = params
83            .get("content")
84            .and_then(|v| v.as_str())
85            .ok_or_else(|| {
86                ToolError::InvalidParameters("missing or invalid 'content' parameter".into())
87            })?;
88
89        let file_path = self.resolve_path(path);
90
91        if let Some(parent) = file_path.parent() {
92            tokio::fs::create_dir_all(parent)
93                .await
94                .map_err(|e| ToolError::ExecutionFailed(e.to_string()))?;
95        }
96
97        let bytes = content.len();
98        tokio::fs::write(&file_path, content)
99            .await
100            .map_err(|e| ToolError::ExecutionFailed(e.to_string()))?;
101
102        crate::git_helpers::auto_commit_file(&file_path, "write").await;
103
104        Ok(ToolResult {
105            content: vec![Content::Text {
106                text: format!("Successfully wrote {} bytes to {}", bytes, path),
107            }],
108            details: serde_json::json!({}),
109        })
110    }
111}
112
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn write_to_temp_file() {
        let tool = WriteTool::new();
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("test.txt");

        let params = serde_json::json!({
            "path": file_path.to_str().unwrap(),
            "content": "hello world"
        });

        let result = tool
            .execute("call_1", params, CancellationToken::new())
            .await
            .unwrap();

        assert_eq!(result.content.len(), 1);
        match &result.content[0] {
            Content::Text { text } => {
                assert!(text.contains("11 bytes"));
                assert!(text.contains("test.txt"));
            }
            _ => panic!("expected Text content"),
        }

        let written = std::fs::read_to_string(&file_path).unwrap();
        assert_eq!(written, "hello world");
    }

    #[tokio::test]
    async fn missing_path_parameter() {
        let tool = WriteTool::new();
        let params = serde_json::json!({
            "content": "hello"
        });

        let err = tool
            .execute("call_2", params, CancellationToken::new())
            .await
            .unwrap_err();

        match err {
            ToolError::InvalidParameters(msg) => assert!(msg.contains("path")),
            _ => panic!("expected InvalidParameters"),
        }
    }

    #[tokio::test]
    async fn missing_content_parameter() {
        let tool = WriteTool::new();
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("never_written.txt");

        let params = serde_json::json!({
            "path": file_path.to_str().unwrap()
        });

        let err = tool
            .execute("call_4", params, CancellationToken::new())
            .await
            .unwrap_err();

        match err {
            ToolError::InvalidParameters(msg) => assert!(msg.contains("content")),
            _ => panic!("expected InvalidParameters"),
        }
        // Parameter validation must fail before anything touches the filesystem.
        assert!(!file_path.exists());
    }

    #[tokio::test]
    async fn creates_parent_directories() {
        let tool = WriteTool::new();
        let dir = tempfile::tempdir().unwrap();
        let file_path = dir.path().join("a").join("b").join("c").join("test.txt");

        assert!(!dir.path().join("a").exists());

        let params = serde_json::json!({
            "path": file_path.to_str().unwrap(),
            "content": "nested content"
        });

        let result = tool
            .execute("call_3", params, CancellationToken::new())
            .await
            .unwrap();

        assert_eq!(result.content.len(), 1);
        let written = std::fs::read_to_string(&file_path).unwrap();
        assert_eq!(written, "nested content");
    }

    #[tokio::test]
    async fn relative_path_resolves_against_working_dir() {
        let dir = tempfile::tempdir().unwrap();
        let tool = WriteTool::with_cwd(dir.path().to_path_buf());

        let params = serde_json::json!({
            "path": "sub/rel.txt",
            "content": "relative"
        });

        tool.execute("call_5", params, CancellationToken::new())
            .await
            .unwrap();

        let written = std::fs::read_to_string(dir.path().join("sub").join("rel.txt")).unwrap();
        assert_eq!(written, "relative");
    }
}