Skip to main content

bamboo_tools/tools/
write.rs

1use async_trait::async_trait;
2use bamboo_agent_core::{Tool, ToolError, ToolExecutionContext, ToolResult};
3use serde::Deserialize;
4use serde_json::json;
5use std::path::Path;
6
7use super::read_tracker::ReadState;
8use super::{content_diagnostics, file_change, read_tracker};
9
10#[derive(Debug, Deserialize)]
11struct WriteArgs {
12    file_path: String,
13    content: String,
14}
15
/// Tool that creates a file or replaces its full content.
///
/// The type carries no state, so it is cheap to copy and every instance is
/// interchangeable. `Default` is derived and is equivalent to [`WriteTool::new`].
#[derive(Debug, Clone, Copy, Default)]
pub struct WriteTool;

impl WriteTool {
    /// Creates a new `WriteTool`.
    #[must_use]
    pub const fn new() -> Self {
        Self
    }
}
29
30#[async_trait]
31impl Tool for WriteTool {
32    fn name(&self) -> &str {
33        "Write"
34    }
35
36    fn description(&self) -> &str {
37        "Write a local file (create or replace full content). IMPORTANT: for existing files, call Read first in this session or Write will fail."
38    }
39
40    fn parameters_schema(&self) -> serde_json::Value {
41        json!({
42            "type": "object",
43            "properties": {
44                "file_path": {
45                    "type": "string",
46                    "description": "The absolute path to the file to write"
47                },
48                "content": {
49                    "type": "string",
50                    "description": "The content to write to the file"
51                }
52            },
53            "required": ["file_path", "content"],
54            "additionalProperties": false
55        })
56    }
57
58    async fn execute(&self, args: serde_json::Value) -> Result<ToolResult, ToolError> {
59        self.execute_with_context(args, ToolExecutionContext::none("Write"))
60            .await
61    }
62
63    async fn execute_with_context(
64        &self,
65        args: serde_json::Value,
66        ctx: ToolExecutionContext<'_>,
67    ) -> Result<ToolResult, ToolError> {
68        let parsed: WriteArgs = serde_json::from_value(args)
69            .map_err(|e| ToolError::InvalidArguments(format!("Invalid Write args: {}", e)))?;
70
71        let file_path = parsed.file_path.trim();
72        let path = Path::new(file_path);
73
74        if !path.is_absolute() {
75            return Err(ToolError::InvalidArguments(
76                "file_path must be an absolute path".to_string(),
77            ));
78        }
79
80        if path.exists() {
81            if let Some(session_id) = ctx.session_id {
82                match read_tracker::read_state(session_id, file_path).await {
83                    ReadState::Unread => {
84                        return Err(ToolError::Execution(
85                            "Write requires reading the target file first via Read".to_string(),
86                        ));
87                    }
88                    ReadState::Stale => {
89                        return Err(ToolError::Execution(
90                            "Target file changed after last Read; call Read again before Write"
91                                .to_string(),
92                        ));
93                    }
94                    ReadState::Fresh => {}
95                }
96            }
97        }
98
99        let previous_bytes = file_change::read_existing_bytes(path).await?;
100        let checkpoint = file_change::create_checkpoint(path, previous_bytes.as_deref()).await?;
101        let next_content = parsed.content;
102
103        file_change::atomic_write_text(path, &next_content).await?;
104
105        let previous_text = file_change::bytes_to_lossy_text(previous_bytes.as_deref());
106        let mut payload = file_change::build_file_change_payload_value(
107            "Write",
108            path,
109            format!("Wrote file: {}", file_path),
110            checkpoint,
111            &previous_text,
112            &next_content,
113        );
114        content_diagnostics::attach_file_diagnostics(&mut payload, path, &next_content);
115
116        Ok(ToolResult {
117            success: true,
118            result: payload.to_string(),
119            display_preference: Some("Default".to_string()),
120        })
121    }
122}
123
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::ReadTool;
    use serde_json::json;

    /// Session id shared by the read-tracking assertions.
    const SESSION: &str = "session_a";

    /// Builds an execution context bound to `session_id` with no event channel
    /// and no tool schemas.
    fn ctx<'a>(session_id: &'a str) -> ToolExecutionContext<'a> {
        ToolExecutionContext {
            session_id: Some(session_id),
            tool_call_id: "call_1",
            event_tx: None,
            available_tool_schemas: None,
        }
    }

    /// Builds the JSON arguments for a `Write` call.
    fn write_request(path: &Path, content: &str) -> serde_json::Value {
        json!({"file_path": path, "content": content})
    }

    /// Builds the JSON arguments for a `Read` call.
    fn read_request(path: &Path) -> serde_json::Value {
        json!({"file_path": path})
    }

    #[tokio::test]
    async fn write_requires_fresh_read_for_existing_files() {
        let file = tempfile::NamedTempFile::new().unwrap();
        let target = file.path();
        tokio::fs::write(target, "v1").await.unwrap();
        let writer = WriteTool::new();
        let reader = ReadTool::new();

        // The file exists but was never read in this session.
        let unread_attempt = writer
            .execute_with_context(write_request(target, "v2"), ctx(SESSION))
            .await;
        assert!(matches!(unread_attempt, Err(ToolError::Execution(_))));

        let _ = reader
            .execute_with_context(read_request(target), ctx(SESSION))
            .await
            .unwrap();

        // Modify the file behind the session's back after it was read.
        tokio::fs::write(target, "external change").await.unwrap();

        let stale_attempt = writer
            .execute_with_context(write_request(target, "v3"), ctx(SESSION))
            .await;
        assert!(
            matches!(stale_attempt, Err(ToolError::Execution(msg)) if msg.contains("changed"))
        );

        // Reading again makes the next write acceptable.
        let _ = reader
            .execute_with_context(read_request(target), ctx(SESSION))
            .await
            .unwrap();
        let accepted = writer
            .execute_with_context(write_request(target, "final"), ctx(SESSION))
            .await
            .unwrap();
        assert!(accepted.success);
    }

    #[cfg(unix)]
    #[tokio::test]
    async fn write_rejects_symlinked_path_components() {
        use std::os::unix::fs::symlink;

        let sandbox = tempfile::tempdir().unwrap();
        let real_dir = sandbox.path().join("real");
        let link_dir = sandbox.path().join("link");
        tokio::fs::create_dir_all(&real_dir).await.unwrap();
        symlink(&real_dir, &link_dir).unwrap();

        let through_link = link_dir.join("test.txt");
        let outcome = WriteTool::new()
            .execute(write_request(&through_link, "hello"))
            .await;
        assert!(matches!(outcome, Err(ToolError::Execution(msg)) if msg.contains("symlinked")));
    }

    #[tokio::test]
    async fn write_includes_json_diagnostics_for_invalid_content() {
        let file = tempfile::Builder::new().suffix(".json").tempfile().unwrap();

        let outcome = WriteTool::new()
            .execute(write_request(file.path(), "{"))
            .await
            .unwrap();

        let payload: serde_json::Value = serde_json::from_str(&outcome.result).unwrap();
        let diagnostics = &payload["diagnostics"];
        assert_eq!(diagnostics["format"], "json");
        assert_eq!(diagnostics["valid"], false);
    }
}