bamboo_tools/tools/
write.rs1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolExecutionContext, ToolResult};
3use serde::Deserialize;
4use serde_json::json;
5use std::path::Path;
6
7use super::read_tracker::ReadState;
8use super::{content_diagnostics, file_change, read_tracker};
9
10#[derive(Debug, Deserialize)]
11struct WriteArgs {
12 file_path: String,
13 content: String,
14}
15
/// Stateless tool, exposed as `Write`, that creates a local file or replaces its full contents.
#[derive(Default)]
pub struct WriteTool;

impl WriteTool {
    /// Returns a new `WriteTool`. The type carries no state, so every instance is equivalent.
    pub fn new() -> Self {
        WriteTool
    }
}
29
30#[async_trait]
31impl Tool for WriteTool {
32 fn name(&self) -> &str {
33 "Write"
34 }
35
36 fn description(&self) -> &str {
37 "Write a local file (create or replace full content). IMPORTANT: for existing files, call Read first in this session or Write will fail."
38 }
39
40 fn parameters_schema(&self) -> serde_json::Value {
41 json!({
42 "type": "object",
43 "properties": {
44 "file_path": {
45 "type": "string",
46 "description": "The absolute path to the file to write"
47 },
48 "content": {
49 "type": "string",
50 "description": "The content to write to the file"
51 }
52 },
53 "required": ["file_path", "content"],
54 "additionalProperties": false
55 })
56 }
57
58 async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
59 self.execute_with_context(args, ToolExecutionContext::none("Write"))
60 .await
61 }
62
63 async fn execute_with_context(
64 &self,
65 args: serde_json::Value,
66 ctx: ToolExecutionContext<'_>,
67 ) -> Result<ToolResult, ToolError> {
68 let parsed: WriteArgs = serde_json::from_value(args)
69 .map_err(|e| ToolError::InvalidArguments(format!("Invalid Write args: {}", e)))?;
70
71 let file_path = parsed.file_path.trim();
72 let path = Path::new(file_path);
73
74 if !path.is_absolute() {
75 return Err(ToolError::InvalidArguments(
76 "file_path must be an absolute path".to_string(),
77 ));
78 }
79
80 if path.exists() {
81 if let Some(session_id) = ctx.session_id {
82 match read_tracker::read_state(session_id, file_path).await {
83 ReadState::Unread => {
84 return Err(ToolError::Execution(
85 "Write requires reading the target file first via Read".to_string(),
86 ));
87 }
88 ReadState::Stale => {
89 return Err(ToolError::Execution(
90 "Target file changed after last Read; call Read again before Write"
91 .to_string(),
92 ));
93 }
94 ReadState::Fresh => {}
95 }
96 }
97 }
98
99 let previous_bytes = file_change::read_existing_bytes(path).await?;
100 let checkpoint = file_change::create_checkpoint(path, previous_bytes.as_deref()).await?;
101 let next_content = parsed.content;
102
103 file_change::atomic_write_text(path, &next_content).await?;
104
105 let previous_text = file_change::bytes_to_lossy_text(previous_bytes.as_deref());
106 let mut payload = file_change::build_file_change_payload_value(
107 "Write",
108 path,
109 format!("Wrote file: {}", file_path),
110 checkpoint,
111 &previous_text,
112 &next_content,
113 );
114 content_diagnostics::attach_file_diagnostics(&mut payload, path, &next_content);
115
116 Ok(ToolResult {
117 success: true,
118 result: payload.to_string(),
119 display_preference: Some("Default".to_string()),
120 images: Vec::new(),
121 })
122 }
123}
124
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ReadTool;
    use serde_json::json;

    /// Builds a context bound to `session_id`, so Write consults that session's read tracker.
    fn ctx<'a>(session_id: &'a str) -> ToolExecutionContext<'a> {
        ToolExecutionContext {
            session_id: Some(session_id),
            tool_call_id: "call_1",
            event_tx: None,
            available_tool_schemas: None,
        }
    }

    /// An existing file cannot be written before a Read, nor after an external change
    /// following that Read. It becomes writable again once re-Read.
    #[tokio::test]
    async fn write_requires_fresh_read_for_existing_files() {
        let file = tempfile::NamedTempFile::new().unwrap();
        tokio::fs::write(file.path(), "v1").await.unwrap();
        let write_tool = WriteTool::new();
        let read_tool = ReadTool::new();

        let denied = write_tool
            .execute_with_context(
                json!({"file_path": file.path(), "content": "v2"}),
                ctx("session_a"),
            )
            .await;
        assert!(matches!(denied, Err(ToolError::Execution(_))));

        let _ = read_tool
            .execute_with_context(json!({"file_path": file.path()}), ctx("session_a"))
            .await
            .unwrap();

        tokio::fs::write(file.path(), "external change")
            .await
            .unwrap();

        let stale = write_tool
            .execute_with_context(
                json!({"file_path": file.path(), "content": "v3"}),
                ctx("session_a"),
            )
            .await;
        assert!(matches!(stale, Err(ToolError::Execution(msg)) if msg.contains("changed")));

        let _ = read_tool
            .execute_with_context(json!({"file_path": file.path()}), ctx("session_a"))
            .await
            .unwrap();
        let ok = write_tool
            .execute_with_context(
                json!({"file_path": file.path(), "content": "final"}),
                ctx("session_a"),
            )
            .await
            .unwrap();
        assert!(ok.success);
        // The successful write must actually land on disk.
        assert_eq!(
            tokio::fs::read_to_string(file.path()).await.unwrap(),
            "final"
        );
    }

    /// Relative paths are rejected as invalid arguments before any filesystem access.
    #[tokio::test]
    async fn write_rejects_relative_paths() {
        let result = WriteTool::new()
            .execute(json!({
                "file_path": "relative/test.txt",
                "content": "hello"
            }))
            .await;
        assert!(
            matches!(result, Err(ToolError::InvalidArguments(msg)) if msg.contains("absolute"))
        );
    }

    /// A file that does not exist yet can be created within a session without a prior Read.
    #[tokio::test]
    async fn write_creates_new_file_without_prior_read() {
        let dir = tempfile::tempdir().unwrap();
        let target = dir.path().join("new.txt");

        let result = WriteTool::new()
            .execute_with_context(
                json!({"file_path": &target, "content": "fresh"}),
                ctx("session_b"),
            )
            .await
            .unwrap();
        assert!(result.success);
        assert_eq!(tokio::fs::read_to_string(&target).await.unwrap(), "fresh");
    }

    /// Paths that go through a symlinked directory component are refused.
    #[cfg(unix)]
    #[tokio::test]
    async fn write_rejects_symlinked_path_components() {
        use std::os::unix::fs::symlink;
        let dir = tempfile::tempdir().unwrap();
        let real = dir.path().join("real");
        let link = dir.path().join("link");
        tokio::fs::create_dir_all(&real).await.unwrap();
        symlink(&real, &link).unwrap();

        let write_tool = WriteTool::new();
        let result = write_tool
            .execute(json!({
                "file_path": link.join("test.txt"),
                "content": "hello"
            }))
            .await;
        assert!(matches!(result, Err(ToolError::Execution(msg)) if msg.contains("symlinked")));
    }

    /// Writing malformed JSON to a `.json` file still succeeds, and the payload reports it.
    #[tokio::test]
    async fn write_includes_json_diagnostics_for_invalid_content() {
        let file = tempfile::Builder::new().suffix(".json").tempfile().unwrap();
        let write_tool = WriteTool::new();

        let result = write_tool
            .execute(json!({
                "file_path": file.path(),
                "content": "{"
            }))
            .await
            .unwrap();

        let payload: serde_json::Value = serde_json::from_str(&result.result).unwrap();
        assert_eq!(payload["diagnostics"]["format"], "json");
        assert_eq!(payload["diagnostics"]["valid"], false);
    }
}