#![cfg_attr(test, allow(clippy::expect_used, clippy::unwrap_used))]
use std::future::Future;
use std::path::PathBuf;
use std::pin::Pin;
use std::sync::Arc;
use motosan_agent_tool::{Tool, ToolContext, ToolDef, ToolResult};
use serde_json::{json, Value};
use crate::tools::ToolCtx;
/// The `write` tool: creates a file, or replaces the contents of one that the
/// session has already read.
///
/// Holds a shared handle to the session's [`ToolCtx`], which provides the
/// working directory used to resolve relative paths and the record of files
/// that have been read.
pub struct WriteTool {
    /// Shared per-session tool state.
    ctx: Arc<ToolCtx>,
}

impl WriteTool {
    /// Builds a `write` tool bound to the given shared context.
    pub fn new(ctx: Arc<ToolCtx>) -> Self {
        WriteTool { ctx }
    }
}
impl Tool for WriteTool {
fn def(&self) -> ToolDef {
ToolDef {
name: "write".into(),
description: "Create or overwrite a file with the given contents.".into(),
input_schema: json!({
"type": "object",
"properties": {
"path": { "type": "string" },
"content": { "type": "string" }
},
"required": ["path", "content"]
}),
}
}
fn call(
&self,
args: Value,
_ctx: &ToolContext,
) -> Pin<Box<dyn Future<Output = ToolResult> + Send + '_>> {
let ctx = Arc::clone(&self.ctx);
Box::pin(async move {
let path = match args.get("path").and_then(|v| v.as_str()) {
Some(path) => PathBuf::from(path),
None => return ToolResult::error("missing 'path'"),
};
let content = match args.get("content").and_then(|v| v.as_str()) {
Some(content) => content.to_string(),
None => return ToolResult::error("missing 'content'"),
};
let abs = if path.is_absolute() {
path
} else {
ctx.cwd.join(&path)
};
if is_hard_blocked(&abs) {
return ToolResult::error(format!(
"write blocked: {} is inside a protected directory",
abs.display()
));
}
if abs.exists() {
let canonical = tokio::fs::canonicalize(&abs)
.await
.unwrap_or_else(|_| abs.clone());
if !ctx.has_been_read(&canonical).await && !ctx.has_been_read(&abs).await {
return ToolResult::error(format!(
"refusing to overwrite {} without reading it first",
abs.display()
));
}
}
if let Err(err) = tokio::fs::write(&abs, &content).await {
return ToolResult::error(format!("write failed: {err}"));
}
let canonical = tokio::fs::canonicalize(&abs)
.await
.unwrap_or_else(|_| abs.clone());
ctx.mark_read(&canonical).await;
ToolResult::text(format!(
"{{\"path\":\"{}\",\"bytes\":{}}}",
abs.display(),
content.len()
))
})
}
}
/// Returns `true` when `path` must never be written by the agent.
///
/// A path is blocked when either of these holds:
/// - any of its components is exactly one of the protected names `.git`,
///   `node_modules`, `target` or `.ssh`;
/// - its file name is `.env` or starts with `.env.` (e.g. `.env.local`).
///
/// Matching uses whole path components rather than raw substrings. This has
/// three effects:
/// - it works with any platform separator (e.g. `\` on Windows);
/// - it does not misfire on look-alike names such as `build_target/` or
///   `repo.git/`;
/// - it also catches a protected name used as the final component.
///
/// The check is purely lexical: symlinks are not resolved and `..` segments
/// are not normalized.
pub(super) fn is_hard_blocked(path: &std::path::Path) -> bool {
    const PROTECTED_NAMES: [&str; 4] = [".git", "node_modules", "target", ".ssh"];

    let touches_protected = path.components().any(|component| match component {
        // Non-UTF-8 components cannot equal any of the ASCII names above.
        std::path::Component::Normal(name) => {
            matches!(name.to_str(), Some(name) if PROTECTED_NAMES.contains(&name))
        }
        _ => false,
    });
    if touches_protected {
        return true;
    }

    let name = path
        .file_name()
        .map(|name| name.to_string_lossy().to_string())
        .unwrap_or_default();
    name == ".env" || name.starts_with(".env.")
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::permissions::NoOpPermissionGate;
    use std::path::Path;
    use tempfile::tempdir;
    use tokio::sync::mpsc;

    /// Builds a `ToolCtx` rooted at `cwd` with a no-op permission gate.
    ///
    /// The channel receiver is dropped when this returns, so anything the
    /// context sends on the channel is not observed by these tests.
    fn test_ctx(cwd: &Path) -> Arc<ToolCtx> {
        let (tx, _rx) = mpsc::channel(8);
        Arc::new(ToolCtx::new(cwd, Arc::new(NoOpPermissionGate), tx))
    }

    /// Invokes `tool` with `args` and a default `ToolContext`.
    async fn run(tool: &WriteTool, args: Value) -> ToolResult {
        tool.call(args, &ToolContext::default()).await
    }

    #[tokio::test]
    async fn writes_new_file() {
        let dir = tempdir().unwrap();
        let tool = WriteTool::new(test_ctx(dir.path()));
        let result = run(&tool, json!({ "path": "hello.txt", "content": "hi" })).await;
        assert!(!result.is_error, "{result:?}");
        let body = tokio::fs::read_to_string(dir.path().join("hello.txt"))
            .await
            .unwrap();
        assert_eq!(body, "hi");
    }

    #[tokio::test]
    async fn rejects_missing_arguments() {
        let dir = tempdir().unwrap();
        let tool = WriteTool::new(test_ctx(dir.path()));

        let result = run(&tool, json!({ "content": "hi" })).await;
        assert!(result.is_error, "{result:?}");
        assert!(format!("{result:?}").contains("missing 'path'"), "{result:?}");

        let result = run(&tool, json!({ "path": "hello.txt" })).await;
        assert!(result.is_error, "{result:?}");
        assert!(format!("{result:?}").contains("missing 'content'"), "{result:?}");
        assert!(!dir.path().join("hello.txt").exists());
    }

    #[tokio::test]
    async fn refuses_env_file() {
        let dir = tempdir().unwrap();
        let tool = WriteTool::new(test_ctx(dir.path()));
        let result = run(&tool, json!({ "path": ".env", "content": "SECRET=1" })).await;
        assert!(result.is_error, "{result:?}");
        let debug = format!("{result:?}");
        assert!(debug.to_lowercase().contains("protected"), "{debug}");
        assert!(!dir.path().join(".env").exists(), "blocked file was written");
    }

    #[tokio::test]
    async fn refuses_path_inside_git() {
        let dir = tempdir().unwrap();
        std::fs::create_dir_all(dir.path().join(".git")).unwrap();
        let tool = WriteTool::new(test_ctx(dir.path()));
        let result = run(&tool, json!({ "path": ".git/config", "content": "x" })).await;
        assert!(result.is_error, "{result:?}");
        let debug = format!("{result:?}");
        assert!(debug.to_lowercase().contains("protected"), "{debug}");
        assert!(
            !dir.path().join(".git").join("config").exists(),
            "blocked file was written"
        );
    }

    #[tokio::test]
    async fn refuses_overwrite_without_prior_read() {
        let dir = tempdir().unwrap();
        let file = dir.path().join("doc.md");
        tokio::fs::write(&file, "old").await.unwrap();
        let tool = WriteTool::new(test_ctx(dir.path()));
        let result = run(&tool, json!({ "path": "doc.md", "content": "new" })).await;
        assert!(result.is_error, "{result:?}");
        let debug = format!("{result:?}");
        assert!(
            debug.to_lowercase().contains("without reading"),
            "expected refusal, got: {debug}"
        );
        // The refusal must leave the existing file untouched.
        assert_eq!(tokio::fs::read_to_string(&file).await.unwrap(), "old");
    }

    #[tokio::test]
    async fn permits_overwrite_when_read_first() {
        let dir = tempdir().unwrap();
        let file = dir.path().join("doc.md");
        tokio::fs::write(&file, "old").await.unwrap();
        let ctx = test_ctx(dir.path());
        let canonical = tokio::fs::canonicalize(&file).await.unwrap();
        // Record the read through the same API the tool itself uses.
        ctx.mark_read(&canonical).await;
        let tool = WriteTool::new(ctx);
        let result = run(&tool, json!({ "path": "doc.md", "content": "new" })).await;
        assert!(!result.is_error, "{result:?}");
        assert_eq!(tokio::fs::read_to_string(&file).await.unwrap(), "new");
    }

    #[tokio::test]
    async fn permits_overwrite_of_file_it_wrote() {
        let dir = tempdir().unwrap();
        let tool = WriteTool::new(test_ctx(dir.path()));
        let first = run(&tool, json!({ "path": "doc.md", "content": "v1" })).await;
        assert!(!first.is_error, "{first:?}");
        // A successful write marks the file as read, so no re-read is needed.
        let second = run(&tool, json!({ "path": "doc.md", "content": "v2" })).await;
        assert!(!second.is_error, "{second:?}");
        let body = tokio::fs::read_to_string(dir.path().join("doc.md"))
            .await
            .unwrap();
        assert_eq!(body, "v2");
    }
}