use crate::tool::Tool;
use async_trait::async_trait;
use serde_json::json;
use std::path::PathBuf;
/// Upper bound, in bytes, on the size of a file the read tool will return
/// (10 MiB). Files whose metadata reports a larger length are rejected.
const MAX_FILE_SIZE_BYTES: u64 = 10 * 1024 * 1024;
/// Expands a leading `~` in `path` to the current user's home directory.
///
/// Surrounding whitespace is trimmed first. Only the forms `~`, `~/...` and
/// `~\...` are expanded; anything else (including `~user/...`) is returned
/// unchanged. The home directory is read from `HOME`, falling back to
/// `USERPROFILE`; if neither is set the trimmed input is returned as-is.
///
/// Exactly one `~` prefix is removed, so `~/~/x` maps to `<home>/~/x`. Extra
/// separators directly after the prefix are dropped so that `~//x` stays
/// under the home directory instead of turning into the absolute path `/x`.
fn expand_tilde(path: &str) -> PathBuf {
    let path = path.trim();

    // `None` means a bare "~"; `Some(rest)` is the part below the home directory.
    let relative = if path == "~" {
        None
    } else if let Some(rest) = path.strip_prefix("~/") {
        // `PathBuf::join` with an absolute argument discards the base, so the
        // remainder must not start with a separator.
        Some(rest.trim_start_matches('/'))
    } else if let Some(rest) = path.strip_prefix("~\\") {
        Some(rest.trim_start_matches('\\'))
    } else {
        return PathBuf::from(path);
    };

    let home = std::env::var("HOME")
        .ok()
        .or_else(|| std::env::var("USERPROFILE").ok());

    match (home, relative) {
        (Some(home), None) => PathBuf::from(home),
        (Some(home), Some(relative)) => PathBuf::from(home).join(relative),
        // No known home directory: leave the path untouched.
        (None, _) => PathBuf::from(path),
    }
}
/// Field-less handle for the `file_read` tool.
pub struct FileReadTool;

impl FileReadTool {
    /// Creates the tool. The type carries no state, so every value is
    /// interchangeable.
    pub fn new() -> Self {
        FileReadTool
    }
}

impl Default for FileReadTool {
    /// Same value as [`FileReadTool::new`].
    fn default() -> Self {
        FileReadTool
    }
}
#[async_trait]
impl Tool for FileReadTool {
fn name(&self) -> &str {
"file_read"
}
fn description(&self) -> &str {
"Read the contents of a file in the workspace"
}
fn parameters_schema(&self) -> serde_json::Value {
json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "Path to the file. Supports ~ for home directory (e.g. ~/.enact/config.yaml). Relative or absolute paths allowed."
}
},
"required": ["path"]
})
}
fn requires_network(&self) -> bool {
false
}
async fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
let path_str = args
.get("path")
.and_then(|v| v.as_str())
.ok_or_else(|| anyhow::anyhow!("Missing 'path' parameter"))?;
let path = expand_tilde(path_str);
if path
.components()
.any(|c| matches!(c, std::path::Component::ParentDir))
{
anyhow::bail!("Path cannot contain '..' (directory traversal not allowed)");
}
if !path.exists() {
anyhow::bail!("File not found: {}", path.display());
}
if !path.is_file() {
anyhow::bail!("Path is not a file: {}", path.display());
}
let metadata = tokio::fs::metadata(&path).await?;
if metadata.len() > MAX_FILE_SIZE_BYTES {
anyhow::bail!(
"File too large: {} bytes (max: {} bytes)",
metadata.len(),
MAX_FILE_SIZE_BYTES
);
}
let content = tokio::fs::read_to_string(&path).await?;
Ok(json!({
"success": true,
"content": content,
"path": path.to_string_lossy().to_string(),
"size": metadata.len()
}))
}
}
/// Field-less handle for the `file_write` tool.
pub struct FileWriteTool;

impl FileWriteTool {
    /// Creates the tool. The type carries no state, so every value is
    /// interchangeable.
    pub fn new() -> Self {
        FileWriteTool
    }
}

impl Default for FileWriteTool {
    /// Same value as [`FileWriteTool::new`].
    fn default() -> Self {
        FileWriteTool
    }
}
/// `Tool` implementation that writes a string to a file, creating missing
/// parent directories and replacing any existing content.
#[async_trait]
impl Tool for FileWriteTool {
    /// Name under which the tool is exposed.
    fn name(&self) -> &str {
        "file_write"
    }

    /// One-line human-readable summary of the tool.
    fn description(&self) -> &str {
        "Write content to a file in the workspace (creates or overwrites)"
    }

    /// JSON Schema for the arguments: required strings `path` and `content`.
    // NOTE(review): the `path` description says "relative", yet `execute`
    // also accepts absolute and `~`-prefixed paths — confirm the wording.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "Relative path to the file within the workspace"
                },
                "content": {
                    "type": "string",
                    "description": "Content to write to the file"
                }
            },
            "required": ["path", "content"]
        })
    }

    /// This tool only touches the local filesystem.
    fn requires_network(&self) -> bool {
        false
    }

    /// Writes `args["content"]` to the file named by `args["path"]` and
    /// returns `{ success, path, size, message }`.
    ///
    /// # Errors
    ///
    /// Fails when either argument is missing or not a string, when the
    /// expanded path contains a `..` component, or when creating the parent
    /// directories, writing, or reading back the metadata fails.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
        // `path` is validated before `content`, so a call missing both
        // reports the path first.
        let raw_path = match args.get("path").and_then(serde_json::Value::as_str) {
            Some(value) => value,
            None => anyhow::bail!("Missing 'path' parameter"),
        };
        let body = match args.get("content").and_then(serde_json::Value::as_str) {
            Some(value) => value,
            None => anyhow::bail!("Missing 'content' parameter"),
        };

        let target = expand_tilde(raw_path);
        let has_parent_component = target
            .components()
            .any(|part| part == std::path::Component::ParentDir);
        if has_parent_component {
            anyhow::bail!("Path cannot contain '..' (directory traversal not allowed)");
        }

        if let Some(dir) = target.parent() {
            tokio::fs::create_dir_all(dir).await?;
        }
        tokio::fs::write(&target, body).await?;

        // The reported size is whatever the filesystem says after the write.
        let written_len = tokio::fs::metadata(&target).await?.len();
        Ok(json!({
            "success": true,
            "path": target.to_string_lossy().into_owned(),
            "size": written_len,
            "message": "File written successfully"
        }))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a path inside the platform temp directory whose file name
    /// embeds the current process id, so concurrent test runs do not share
    /// fixtures and the tests do not depend on `/tmp` existing.
    fn scratch_path(name: &str) -> std::path::PathBuf {
        std::env::temp_dir().join(format!("enact_file_tool_{}_{}", std::process::id(), name))
    }

    #[tokio::test]
    async fn test_file_read_success() {
        let tool = FileReadTool::new();
        let test_content = "Hello, World!";
        let test_file = scratch_path("read.txt");
        let path_arg = test_file.to_string_lossy().into_owned();
        tokio::fs::write(&test_file, test_content).await.unwrap();

        let result = tool.execute(json!({ "path": path_arg })).await;
        // Clean up before asserting so a failure does not leave the fixture behind.
        tokio::fs::remove_file(&test_file).await.ok();

        let result = result.unwrap();
        assert_eq!(result["success"], true);
        assert_eq!(result["content"], test_content);
    }

    #[tokio::test]
    async fn test_file_write_success() {
        let tool = FileWriteTool::new();
        let test_file = scratch_path("write.txt");
        let path_arg = test_file.to_string_lossy().into_owned();

        let result = tool
            .execute(json!({
                "path": path_arg,
                "content": "Test content"
            }))
            .await;
        let written = tokio::fs::read_to_string(&test_file).await;
        tokio::fs::remove_file(&test_file).await.ok();

        assert_eq!(result.unwrap()["success"], true);
        assert_eq!(written.unwrap(), "Test content");
    }

    #[tokio::test]
    async fn test_file_read_not_found() {
        let tool = FileReadTool::new();
        let missing = scratch_path("nonexistent_file_xyz.txt");
        let path_arg = missing.to_string_lossy().into_owned();

        let result = tool.execute(json!({ "path": path_arg })).await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_file_read_traversal_prevention() {
        let tool = FileReadTool::new();
        let result = tool.execute(json!({"path": "../etc/passwd"})).await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("directory traversal"));
    }

    #[tokio::test]
    async fn test_file_write_traversal_prevention() {
        let tool = FileWriteTool::new();
        let result = tool
            .execute(json!({"path": "../escape.txt", "content": "x"}))
            .await;
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("directory traversal"));
    }

    #[tokio::test]
    async fn test_file_read_expands_tilde() {
        let home = std::env::var("HOME")
            .or_else(|_| std::env::var("USERPROFILE"))
            .expect("HOME or USERPROFILE");
        let file_name = format!(".enact_file_read_tilde_test_{}", std::process::id());
        let test_file = std::path::PathBuf::from(&home).join(&file_name);
        let test_content = "tilde expansion works";
        tokio::fs::write(&test_file, test_content).await.unwrap();

        let tool = FileReadTool::new();
        let result = tool
            .execute(json!({ "path": format!("~/{}", file_name) }))
            .await;
        tokio::fs::remove_file(&test_file).await.ok();

        let result = result.unwrap();
        assert_eq!(result["success"], true);
        assert_eq!(result["content"], test_content);
    }
}