Skip to main content

capo_agent/tools/
write.rs

1#![cfg_attr(test, allow(clippy::expect_used, clippy::unwrap_used))]
2
3use std::future::Future;
4use std::path::PathBuf;
5use std::pin::Pin;
6use std::sync::Arc;
7
8use motosan_agent_tool::{Tool, ToolContext, ToolDef, ToolResult};
9use serde_json::{json, Value};
10
11use crate::tools::ToolCtx;
12
13pub struct WriteTool {
14    ctx: Arc<ToolCtx>,
15}
16
17impl WriteTool {
18    pub fn new(ctx: Arc<ToolCtx>) -> Self {
19        Self { ctx }
20    }
21}
22
23impl Tool for WriteTool {
24    fn def(&self) -> ToolDef {
25        ToolDef {
26            name: "write".into(),
27            description: "Create or overwrite a file with the given contents.".into(),
28            input_schema: json!({
29                "type": "object",
30                "properties": {
31                    "path": { "type": "string" },
32                    "content": { "type": "string" }
33                },
34                "required": ["path", "content"]
35            }),
36        }
37    }
38
39    fn call(
40        &self,
41        args: Value,
42        _ctx: &ToolContext,
43    ) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
44        let ctx = Arc::clone(&self.ctx);
45        Box::pin(async move {
46            let path = match args.get("path").and_then(|v| v.as_str()) {
47                Some(path) => PathBuf::from(path),
48                None => return ToolResult::error("missing 'path'"),
49            };
50            let content = match args.get("content").and_then(|v| v.as_str()) {
51                Some(content) => content.to_string(),
52                None => return ToolResult::error("missing 'content'"),
53            };
54            let abs = if path.is_absolute() {
55                path
56            } else {
57                ctx.cwd.join(&path)
58            };
59
60            if is_hard_blocked(&abs) {
61                return ToolResult::error(format!(
62                    "write blocked: {} is inside a protected directory",
63                    abs.display()
64                ));
65            }
66
67            if abs.exists() {
68                let canonical = tokio::fs::canonicalize(&abs)
69                    .await
70                    .unwrap_or_else(|_| abs.clone());
71                if !ctx.has_been_read(&canonical).await && !ctx.has_been_read(&abs).await {
72                    return ToolResult::error(format!(
73                        "refusing to overwrite {} without reading it first",
74                        abs.display()
75                    ));
76                }
77            }
78
79            if let Err(err) = tokio::fs::write(&abs, &content).await {
80                return ToolResult::error(format!("write failed: {err}"));
81            }
82            let canonical = tokio::fs::canonicalize(&abs)
83                .await
84                .unwrap_or_else(|_| abs.clone());
85            ctx.mark_read(&canonical).await;
86            ToolResult::text(format!(
87                "{{\"path\":\"{}\",\"bytes\":{}}}",
88                abs.display(),
89                content.len()
90            ))
91        })
92    }
93}
94
95pub(super) fn is_hard_blocked(path: &std::path::Path) -> bool {
96    let text = path.to_string_lossy();
97    let patterns = [".git/", "node_modules/", "target/", ".ssh/"];
98    for pattern in patterns {
99        if text.contains(pattern) {
100            return true;
101        }
102    }
103
104    let name = path
105        .file_name()
106        .map(|name| name.to_string_lossy().to_string())
107        .unwrap_or_default();
108    name == ".env" || name.starts_with(".env.")
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::permissions::NoOpPermissionGate;
    use std::path::Path;
    use tempfile::tempdir;
    use tokio::sync::mpsc;

    /// Builds a shared `ToolCtx` rooted at `cwd` that uses a no-op permission gate.
    fn test_ctx(cwd: &Path) -> Arc<ToolCtx> {
        let (tx, _rx) = mpsc::channel(8);
        Arc::new(ToolCtx::new(cwd, Arc::new(NoOpPermissionGate), tx))
    }

    /// Runs the write tool once with `args` and a default per-call `ToolContext`.
    async fn invoke(tool: &WriteTool, args: Value) -> ToolResult {
        tool.call(args, &ToolContext::default()).await
    }

    #[tokio::test]
    async fn writes_new_file() {
        let dir = tempdir().unwrap();
        let tool = WriteTool::new(test_ctx(dir.path()));

        let result = invoke(&tool, json!({ "path": "hello.txt", "content": "hi" })).await;
        assert!(!result.is_error, "{result:?}");

        let on_disk = tokio::fs::read_to_string(dir.path().join("hello.txt"))
            .await
            .unwrap();
        assert_eq!(on_disk, "hi");
    }

    #[tokio::test]
    async fn refuses_env_file() {
        let dir = tempdir().unwrap();
        let tool = WriteTool::new(test_ctx(dir.path()));

        let result = invoke(&tool, json!({ "path": ".env", "content": "SECRET=1" })).await;

        let debug = format!("{result:?}");
        assert!(debug.to_lowercase().contains("protected"), "{debug}");
    }

    #[tokio::test]
    async fn refuses_path_inside_git() {
        let dir = tempdir().unwrap();
        std::fs::create_dir_all(dir.path().join(".git")).unwrap();
        let tool = WriteTool::new(test_ctx(dir.path()));

        let result = invoke(&tool, json!({ "path": ".git/config", "content": "x" })).await;

        let debug = format!("{result:?}");
        assert!(debug.to_lowercase().contains("protected"), "{debug}");
    }

    #[tokio::test]
    async fn refuses_overwrite_without_prior_read() {
        let dir = tempdir().unwrap();
        let existing = dir.path().join("doc.md");
        tokio::fs::write(&existing, "old").await.unwrap();
        let tool = WriteTool::new(test_ctx(dir.path()));

        let result = invoke(&tool, json!({ "path": "doc.md", "content": "new" })).await;

        let debug = format!("{result:?}");
        assert!(
            debug.to_lowercase().contains("without reading"),
            "expected refusal, got: {debug}"
        );
    }

    #[tokio::test]
    async fn permits_overwrite_when_read_first() {
        let dir = tempdir().unwrap();
        let existing = dir.path().join("doc.md");
        tokio::fs::write(&existing, "old").await.unwrap();

        // Simulate a prior read by registering the canonical path directly.
        let ctx = test_ctx(dir.path());
        let canonical = tokio::fs::canonicalize(&existing).await.unwrap();
        ctx.read_files.lock().await.insert(canonical);
        let tool = WriteTool::new(ctx);

        let result = invoke(&tool, json!({ "path": "doc.md", "content": "new" })).await;

        assert!(!result.is_error, "{result:?}");
        assert_eq!(tokio::fs::read_to_string(&existing).await.unwrap(), "new");
    }
}