capo_agent/tools/
write.rs1#![cfg_attr(test, allow(clippy::expect_used, clippy::unwrap_used))]
2
3use std::future::Future;
4use std::path::PathBuf;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use motosan_agent_tool::{Tool, ToolContext, ToolDef, ToolResult};
9use serde_json::{json, Value};
10
11use crate::tools::ToolCtx;
12
13pub struct WriteTool {
14 ctx: Arc<ToolCtx>,
15}
16
17impl WriteTool {
18 pub fn new(ctx: Arc<ToolCtx>) -> Self {
19 Self { ctx }
20 }
21}
22
23impl Tool for WriteTool {
24 fn def(&self) -> ToolDef {
25 ToolDef {
26 name: "write".into(),
27 description: "Create or overwrite a file with the given contents.".into(),
28 input_schema: json!({
29 "type": "object",
30 "properties": {
31 "path": { "type": "string" },
32 "content": { "type": "string" }
33 },
34 "required": ["path", "content"]
35 }),
36 }
37 }
38
39 fn call(
40 &self,
41 args: Value,
42 _ctx: &ToolContext,
43 ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
44 let ctx = Arc::clone(&self.ctx);
45 Box::pin(async move {
46 let path = match args.get("path").and_then(|v| v.as_str()) {
47 Some(path) => PathBuf::from(path),
48 None => return ToolResult::error("missing 'path'"),
49 };
50 let content = match args.get("content").and_then(|v| v.as_str()) {
51 Some(content) => content.to_string(),
52 None => return ToolResult::error("missing 'content'"),
53 };
54 let abs = if path.is_absolute() {
55 path
56 } else {
57 ctx.cwd.join(&path)
58 };
59
60 if is_hard_blocked(&abs) {
61 return ToolResult::error(format!(
62 "write blocked: {} is inside a protected directory",
63 abs.display()
64 ));
65 }
66
67 if abs.exists() {
68 let canonical = tokio::fs::canonicalize(&abs)
69 .await
70 .unwrap_or_else(|_| abs.clone());
71 if !ctx.has_been_read(&canonical).await && !ctx.has_been_read(&abs).await {
72 return ToolResult::error(format!(
73 "refusing to overwrite {} without reading it first",
74 abs.display()
75 ));
76 }
77 }
78
79 if let Err(err) = tokio::fs::write(&abs, &content).await {
80 return ToolResult::error(format!("write failed: {err}"));
81 }
82 let canonical = tokio::fs::canonicalize(&abs)
83 .await
84 .unwrap_or_else(|_| abs.clone());
85 ctx.mark_read(&canonical).await;
86 ToolResult::text(format!(
87 "{{\"path\":\"{}\",\"bytes\":{}}}",
88 abs.display(),
89 content.len()
90 ))
91 })
92 }
93}
94
95pub(super) fn is_hard_blocked(path: &std::path::Path) -> bool {
96 let text = path.to_string_lossy();
97 let patterns = [".git/", "node_modules/", "target/", ".ssh/"];
98 for pattern in patterns {
99 if text.contains(pattern) {
100 return true;
101 }
102 }
103
104 let name = path
105 .file_name()
106 .map(|name| name.to_string_lossy().to_string())
107 .unwrap_or_default();
108 name == ".env" || name.starts_with(".env.")
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::permissions::NoOpPermissionGate;
    use std::path::Path;
    use tempfile::tempdir;
    use tokio::sync::mpsc;

    /// Builds a `ToolCtx` rooted at `cwd` with a `NoOpPermissionGate`.
    ///
    /// The event receiver is dropped when this function returns; none of these
    /// tests look at emitted events.
    fn test_ctx(cwd: &Path) -> Arc<ToolCtx> {
        let (events, _unused_rx) = mpsc::channel(8);
        Arc::new(ToolCtx::new(cwd, Arc::new(NoOpPermissionGate), events))
    }

    /// Runs `tool` once with `args` and a default `ToolContext`.
    async fn invoke(tool: &WriteTool, args: Value) -> ToolResult {
        tool.call(args, &ToolContext::default()).await
    }

    /// Writes "old" into `<dir>/doc.md` and returns that file's path.
    async fn seed_doc(dir: &Path) -> PathBuf {
        let file = dir.join("doc.md");
        tokio::fs::write(&file, "old").await.unwrap();
        file
    }

    #[tokio::test]
    async fn writes_new_file() {
        let dir = tempdir().unwrap();
        let tool = WriteTool::new(test_ctx(dir.path()));

        let result = invoke(&tool, json!({ "path": "hello.txt", "content": "hi" })).await;
        assert!(!result.is_error, "{result:?}");

        let written = dir.path().join("hello.txt");
        assert_eq!(tokio::fs::read_to_string(written).await.unwrap(), "hi");
    }

    #[tokio::test]
    async fn refuses_env_file() {
        let dir = tempdir().unwrap();
        let tool = WriteTool::new(test_ctx(dir.path()));

        let result = invoke(&tool, json!({ "path": ".env", "content": "SECRET=1" })).await;
        let debug = format!("{result:?}");
        assert!(debug.to_lowercase().contains("protected"), "{debug}");
    }

    #[tokio::test]
    async fn refuses_path_inside_git() {
        let dir = tempdir().unwrap();
        std::fs::create_dir_all(dir.path().join(".git")).unwrap();
        let tool = WriteTool::new(test_ctx(dir.path()));

        let result = invoke(&tool, json!({ "path": ".git/config", "content": "x" })).await;
        let debug = format!("{result:?}");
        assert!(debug.to_lowercase().contains("protected"), "{debug}");
    }

    #[tokio::test]
    async fn refuses_overwrite_without_prior_read() {
        let dir = tempdir().unwrap();
        seed_doc(dir.path()).await;
        let tool = WriteTool::new(test_ctx(dir.path()));

        let result = invoke(&tool, json!({ "path": "doc.md", "content": "new" })).await;
        let debug = format!("{result:?}");
        assert!(
            debug.to_lowercase().contains("without reading"),
            "expected refusal, got: {debug}"
        );
    }

    #[tokio::test]
    async fn permits_overwrite_when_read_first() {
        let dir = tempdir().unwrap();
        let file = seed_doc(dir.path()).await;

        // Simulate an earlier read by recording the canonical path directly.
        let ctx = test_ctx(dir.path());
        let canonical = tokio::fs::canonicalize(&file).await.unwrap();
        ctx.read_files.lock().await.insert(canonical);

        let tool = WriteTool::new(ctx);
        let result = invoke(&tool, json!({ "path": "doc.md", "content": "new" })).await;
        assert!(!result.is_error, "{result:?}");
        assert_eq!(tokio::fs::read_to_string(&file).await.unwrap(), "new");
    }
}