use crate::tool::Tool;
use async_trait::async_trait;
use serde_json::json;
use std::path::Path;
/// Largest file size, in bytes, that `FileReadTool` will read (10 MiB).
///
/// Compared against the file's metadata length before the contents are loaded.
const MAX_FILE_SIZE_BYTES: u64 = 10 * 1024 * 1024;
/// Tool (`file_read`) that reads a text file and returns its contents as JSON.
#[derive(Default)]
pub struct FileReadTool;

impl FileReadTool {
    /// Creates a new `FileReadTool`.
    ///
    /// The tool carries no state, so this yields the same value as [`Default::default`].
    pub fn new() -> Self {
        FileReadTool
    }
}
#[async_trait]
impl Tool for FileReadTool {
    fn name(&self) -> &str {
        "file_read"
    }

    fn description(&self) -> &str {
        "Read the contents of a file in the workspace"
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "Relative path to the file within the workspace"
                }
            },
            "required": ["path"]
        })
    }

    /// Reading a local file never needs network access.
    fn requires_network(&self) -> bool {
        false
    }

    /// Reads the UTF-8 text file named by `args["path"]`.
    ///
    /// On success returns `{"success": true, "content", "path", "size"}`, where `size` is
    /// the file length in bytes as reported by filesystem metadata.
    ///
    /// # Errors
    ///
    /// Returns an error if `path` is missing or not a string, contains a `..` component,
    /// does not exist, cannot be stat'ed, is not a regular file, is larger than
    /// `MAX_FILE_SIZE_BYTES`, or cannot be read as UTF-8 text.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
        let path = args
            .get("path")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?;
        let path = Path::new(path);

        // NOTE(review): only `..` components are rejected. Absolute paths and symlinks are
        // still followed, so this check alone does not confine reads to a workspace —
        // confirm whether the caller enforces that.
        if path
            .components()
            .any(|c| matches!(c, std::path::Component::ParentDir))
        {
            anyhow::bail!("Path cannot contain '..' (directory traversal not allowed)");
        }

        // A single async stat replaces the blocking `Path::exists` / `Path::is_file` calls,
        // which would stall the executor thread. `Path::exists` also maps *every* error to
        // `false`, which made e.g. permission errors look like "not found".
        let metadata = match tokio::fs::metadata(path).await {
            Ok(metadata) => metadata,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                anyhow::bail!("File not found: {}", path.display())
            }
            Err(e) => anyhow::bail!("Failed to access {}: {}", path.display(), e),
        };
        if !metadata.is_file() {
            anyhow::bail!("Path is not a file: {}", path.display());
        }
        if metadata.len() > MAX_FILE_SIZE_BYTES {
            anyhow::bail!(
                "File too large: {} bytes (max: {} bytes)",
                metadata.len(),
                MAX_FILE_SIZE_BYTES
            );
        }

        // NOTE(review): the size check above is advisory — a file that grows between the
        // stat and this read is still read in full.
        let content = tokio::fs::read_to_string(path)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to read {}: {}", path.display(), e))?;
        Ok(json!({
            "success": true,
            "content": content,
            "path": path.to_string_lossy().to_string(),
            "size": metadata.len()
        }))
    }
}
/// Tool (`file_write`) that writes text content to a file, creating or overwriting it.
#[derive(Default)]
pub struct FileWriteTool;

impl FileWriteTool {
    /// Creates a new `FileWriteTool`.
    ///
    /// The tool carries no state, so this yields the same value as [`Default::default`].
    pub fn new() -> Self {
        FileWriteTool
    }
}
#[async_trait]
impl Tool for FileWriteTool {
    fn name(&self) -> &str {
        "file_write"
    }

    fn description(&self) -> &str {
        "Write content to a file in the workspace (creates or overwrites)"
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "Relative path to the file within the workspace"
                },
                "content": {
                    "type": "string",
                    "description": "Content to write to the file"
                }
            },
            "required": ["path", "content"]
        })
    }

    /// Writing a local file never needs network access.
    fn requires_network(&self) -> bool {
        false
    }

    /// Writes `args["content"]` to the file at `args["path"]`, creating any missing parent
    /// directories first.
    ///
    /// On success returns `{"success": true, "path", "size", "message"}`, where `size` is
    /// the number of bytes written.
    ///
    /// # Errors
    ///
    /// Returns an error if `path` or `content` is missing or not a string, if `path`
    /// contains a `..` component, or if creating the parent directories or writing the
    /// file fails.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
        let path = args
            .get("path")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?;
        let content = args
            .get("content")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'content' parameter"))?;
        let path = Path::new(path);

        // NOTE(review): only `..` components are rejected. Absolute paths and symlinks are
        // still followed, so this check alone does not confine writes to a workspace —
        // confirm whether the caller enforces that.
        if path
            .components()
            .any(|c| matches!(c, std::path::Component::ParentDir))
        {
            anyhow::bail!("Path cannot contain '..' (directory traversal not allowed)");
        }

        // `parent()` is `Some("")` for a bare file name; there is nothing to create then.
        if let Some(parent) = path.parent().filter(|p| !p.as_os_str().is_empty()) {
            tokio::fs::create_dir_all(parent).await.map_err(|e| {
                anyhow::anyhow!("Failed to create directory {}: {}", parent.display(), e)
            })?;
        }
        tokio::fs::write(path, content)
            .await
            .map_err(|e| anyhow::anyhow!("Failed to write {}: {}", path.display(), e))?;

        // The bytes just written are exactly `content`, so report its length instead of
        // re-stat'ing the file: that extra call could fail after a successful write or
        // observe a concurrent writer's changes.
        Ok(json!({
            "success": true,
            "path": path.to_string_lossy().to_string(),
            "size": content.len(),
            "message": "File written successfully"
        }))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;

    /// Returns a scratch path in the platform temp directory, namespaced by process id so
    /// the tests are portable and do not collide with concurrently running test binaries.
    fn scratch_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("file_tools_{}_{}", std::process::id(), name))
    }

    /// Borrows a path as the `&str` form the tools' JSON `path` argument expects.
    fn path_arg(path: &Path) -> &str {
        path.to_str().expect("temp dir path should be valid UTF-8")
    }

    #[tokio::test]
    async fn test_file_read_success() {
        let tool = FileReadTool::new();
        let test_content = "Hello, World!";
        let file = scratch_path("read.txt");
        tokio::fs::write(&file, test_content).await.unwrap();
        let result = tool.execute(json!({ "path": path_arg(&file) })).await;
        // Clean up before asserting so a failure does not leak the temp file.
        tokio::fs::remove_file(&file).await.ok();
        let result = result.unwrap();
        assert_eq!(result["success"], true);
        assert_eq!(result["content"], test_content);
        assert_eq!(result["size"], json!(test_content.len()));
    }

    #[tokio::test]
    async fn test_file_write_success() {
        let tool = FileWriteTool::new();
        let file = scratch_path("write.txt");
        let result = tool
            .execute(json!({
                "path": path_arg(&file),
                "content": "Test content"
            }))
            .await;
        let written = tokio::fs::read_to_string(&file).await;
        tokio::fs::remove_file(&file).await.ok();
        let result = result.unwrap();
        assert_eq!(result["success"], true);
        assert_eq!(result["size"], json!("Test content".len()));
        assert_eq!(written.unwrap(), "Test content");
    }

    #[tokio::test]
    async fn test_file_read_not_found() {
        let tool = FileReadTool::new();
        let file = scratch_path("nonexistent_file_xyz.txt");
        let result = tool.execute(json!({ "path": path_arg(&file) })).await;
        assert!(result.unwrap_err().to_string().contains("File not found"));
    }

    #[tokio::test]
    async fn test_file_read_rejects_directory() {
        let tool = FileReadTool::new();
        let dir = std::env::temp_dir();
        let result = tool.execute(json!({ "path": path_arg(&dir) })).await;
        assert!(result.unwrap_err().to_string().contains("not a file"));
    }

    #[tokio::test]
    async fn test_file_read_missing_path_parameter() {
        let tool = FileReadTool::new();
        let result = tool.execute(json!({})).await;
        assert!(result.unwrap_err().to_string().contains("Missing 'path'"));
    }

    #[tokio::test]
    async fn test_file_read_traversal_prevention() {
        let tool = FileReadTool::new();
        let result = tool.execute(json!({ "path": "../etc/passwd" })).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("directory traversal"));
    }

    #[tokio::test]
    async fn test_file_write_traversal_prevention() {
        let tool = FileWriteTool::new();
        let result = tool
            .execute(json!({ "path": "../escape.txt", "content": "x" }))
            .await;
        assert!(result.unwrap_err().to_string().contains("directory traversal"));
    }
}