use async_trait::async_trait;
use serde_json::{json, Value};
#[cfg(unix)]
use std::os::unix::fs::OpenOptionsExt;
use std::path::Path;
#[cfg(unix)]
use std::{fs::OpenOptions, io::Write as _, os::unix::fs::MetadataExt};
use unicode_normalization::UnicodeNormalization;
use crate::error::{Result, ZeptoError};
#[cfg(not(unix))]
use crate::security::check_hardlink_write;
use crate::security::{ensure_directory_chain_secure, revalidate_path, validate_path_in_workspace};
use crate::tools::diff::apply_unified_diff;
use super::output::{truncate_tool_output, DEFAULT_MAX_BYTES, DEFAULT_MAX_LINES};
use super::{Tool, ToolCategory, ToolContext, ToolOutput};
/// Resolve a tool-supplied path against the configured workspace.
///
/// Returns the validated path (lossily converted to a `String`) together
/// with the workspace root. Fails with a `SecurityViolation` when no
/// workspace is configured, and propagates any error from
/// `validate_path_in_workspace`.
fn resolve_path(path: &str, ctx: &ToolContext) -> Result<(String, String)> {
    let workspace = match ctx.workspace.as_ref() {
        Some(ws) => ws,
        None => {
            return Err(ZeptoError::SecurityViolation(
                "Workspace not configured; filesystem tools require a workspace for safety"
                    .to_string(),
            ))
        }
    };
    let safe_path = validate_path_in_workspace(path, workspace)?;
    let resolved = safe_path.as_path().to_string_lossy().into_owned();
    Ok((resolved, workspace.clone()))
}
/// Securely write `content` to `path`, which must live inside `workspace`
/// (Unix implementation).
///
/// Layered defenses against symlink/hardlink escapes:
/// 1. the ancestor directory chain is secured and the parent revalidated
///    (presumably `ensure_directory_chain_secure` also creates missing
///    directories — the write tool's nested-path behavior relies on it);
/// 2. the full path is revalidated immediately before opening to narrow
///    the check-to-use window;
/// 3. the open uses `O_NOFOLLOW`, so a symlink planted at the final
///    component cannot redirect the write;
/// 4. the link count of the *opened handle* is inspected — a count > 1
///    means the inode aliases another path, possibly outside the
///    workspace, so the write is refused before any truncation.
///
/// # Errors
/// `SecurityViolation` on revalidation failure or multi-link targets;
/// `Tool` for any I/O failure.
#[cfg(unix)]
fn write_file_secure_blocking(path: &Path, workspace: &str, content: &[u8]) -> Result<()> {
    if let Some(parent) = path.parent() {
        if !parent.as_os_str().is_empty() {
            ensure_directory_chain_secure(parent, workspace)?;
            revalidate_path(parent, workspace)?;
        }
    }
    // Re-check the full path right before opening (TOCTOU mitigation).
    revalidate_path(path, workspace)?;
    let mut options = OpenOptions::new();
    // O_NOFOLLOW: fail the open if the final component is a symlink.
    options
        .write(true)
        .create(true)
        .custom_flags(libc::O_NOFOLLOW);
    let mut file = options.open(path).map_err(|e| {
        ZeptoError::Tool(format!(
            "Failed to securely open file '{}': {}",
            path.display(),
            e
        ))
    })?;
    // Inspect the opened handle, not the path, so this check cannot race.
    let metadata = file.metadata().map_err(|e| {
        ZeptoError::Tool(format!(
            "Failed to inspect opened file '{}': {}",
            path.display(),
            e
        ))
    })?;
    if metadata.is_file() && metadata.nlink() > 1 {
        return Err(ZeptoError::SecurityViolation(format!(
            "Write blocked: '{}' has {} hard links and may alias content outside workspace",
            path.display(),
            metadata.nlink()
        )));
    }
    // Truncate only after all checks pass, then write the new content.
    file.set_len(0).map_err(|e| {
        ZeptoError::Tool(format!(
            "Failed to truncate file '{}': {}",
            path.display(),
            e
        ))
    })?;
    file.write_all(content).map_err(|e| {
        ZeptoError::Tool(format!("Failed to write file '{}': {}", path.display(), e))
    })?;
    Ok(())
}
/// Securely write `content` to `path` (non-Unix fallback).
///
/// Without `O_NOFOLLOW` and handle-based link inspection, this relies on
/// path revalidation plus a pre-open hardlink check; note a small
/// check-to-use window remains between `check_hardlink_write` and
/// `fs::write`, unlike the Unix variant which checks the opened handle.
#[cfg(not(unix))]
fn write_file_secure_blocking(path: &Path, workspace: &str, content: &[u8]) -> Result<()> {
    // Secure/create the directory chain, then re-check the parent.
    if let Some(parent) = path.parent() {
        if !parent.as_os_str().is_empty() {
            ensure_directory_chain_secure(parent, workspace)?;
            revalidate_path(parent, workspace)?;
        }
    }
    revalidate_path(path, workspace)?;
    // Path-based refusal of multi-link targets before writing.
    check_hardlink_write(path)?;
    std::fs::write(path, content).map_err(|e| {
        ZeptoError::Tool(format!("Failed to write file '{}': {}", path.display(), e))
    })?;
    Ok(())
}
/// Async wrapper around [`write_file_secure_blocking`].
///
/// Clones the inputs into owned values and runs the blocking, hardened
/// write on Tokio's blocking thread pool so the async runtime is not
/// stalled by filesystem I/O.
async fn write_file_secure(path: &Path, workspace: &str, content: &[u8]) -> Result<()> {
    let (owned_path, owned_workspace, owned_content) =
        (path.to_path_buf(), workspace.to_string(), content.to_vec());
    let handle = tokio::task::spawn_blocking(move || {
        write_file_secure_blocking(&owned_path, &owned_workspace, &owned_content)
    });
    handle
        .await
        .map_err(|e| ZeptoError::Tool(format!("Secure write task failed: {}", e)))?
}
pub struct ReadFileTool;
#[async_trait]
impl Tool for ReadFileTool {
fn name(&self) -> &str {
"read_file"
}
fn description(&self) -> &str {
"Read the contents of a file at the specified path"
}
fn compact_description(&self) -> &str {
"Read file"
}
fn category(&self) -> ToolCategory {
ToolCategory::FilesystemRead
}
fn parameters(&self) -> Value {
json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The path to the file to read"
}
},
"required": ["path"]
})
}
async fn execute(&self, args: Value, ctx: &ToolContext) -> Result<ToolOutput> {
let path = args
.get("path")
.and_then(|v| v.as_str())
.ok_or_else(|| ZeptoError::Tool("Missing 'path' argument".into()))?;
let (full_path, workspace) = resolve_path(path, ctx)?;
revalidate_path(Path::new(&full_path), &workspace)?;
let content = tokio::fs::read_to_string(&full_path)
.await
.map_err(|e| ZeptoError::Tool(format!("Failed to read file '{}': {}", full_path, e)))?;
Ok(ToolOutput::llm_only(truncate_tool_output(
&content,
DEFAULT_MAX_LINES,
DEFAULT_MAX_BYTES,
)))
}
}
pub struct WriteFileTool;
#[async_trait]
impl Tool for WriteFileTool {
fn name(&self) -> &str {
"write_file"
}
fn description(&self) -> &str {
"Write content to a file at the specified path, creating it if necessary"
}
fn compact_description(&self) -> &str {
"Write file"
}
fn category(&self) -> ToolCategory {
ToolCategory::FilesystemWrite
}
fn parameters(&self) -> Value {
json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The path to the file to write"
},
"content": {
"type": "string",
"description": "The content to write to the file"
}
},
"required": ["path", "content"]
})
}
async fn execute(&self, args: Value, ctx: &ToolContext) -> Result<ToolOutput> {
let path = args
.get("path")
.and_then(|v| v.as_str())
.ok_or_else(|| ZeptoError::Tool("Missing 'path' argument".into()))?;
let content = args
.get("content")
.and_then(|v| v.as_str())
.ok_or_else(|| ZeptoError::Tool("Missing 'content' argument".into()))?;
let (full_path, workspace) = resolve_path(path, ctx)?;
let full_path_ref = Path::new(&full_path);
write_file_secure(full_path_ref, &workspace, content.as_bytes()).await?;
Ok(ToolOutput::llm_only(format!(
"Successfully wrote {} bytes to {}",
content.len(),
full_path
)))
}
}
pub struct ListDirTool;
#[async_trait]
impl Tool for ListDirTool {
fn name(&self) -> &str {
"list_dir"
}
fn description(&self) -> &str {
"List the contents of a directory at the specified path"
}
fn compact_description(&self) -> &str {
"List directory"
}
fn category(&self) -> ToolCategory {
ToolCategory::FilesystemRead
}
fn parameters(&self) -> Value {
json!({
"type": "object",
"properties": {
"path": {
"type": "string",
"description": "The path to the directory to list"
}
},
"required": ["path"]
})
}
async fn execute(&self, args: Value, ctx: &ToolContext) -> Result<ToolOutput> {
let path = args
.get("path")
.and_then(|v| v.as_str())
.ok_or_else(|| ZeptoError::Tool("Missing 'path' argument".into()))?;
let (full_path, workspace) = resolve_path(path, ctx)?;
revalidate_path(Path::new(&full_path), &workspace)?;
let mut entries = tokio::fs::read_dir(&full_path).await.map_err(|e| {
ZeptoError::Tool(format!("Failed to read directory '{}': {}", full_path, e))
})?;
let mut items = Vec::new();
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| ZeptoError::Tool(format!("Failed to read directory entry: {}", e)))?
{
let file_name = entry.file_name().to_string_lossy().to_string();
let file_type = entry.file_type().await.ok();
let type_indicator = match file_type {
Some(ft) if ft.is_dir() => "/",
Some(ft) if ft.is_symlink() => "@",
_ => "",
};
items.push(format!("{}{}", file_name, type_indicator));
}
items.sort();
let joined = items.join("\n");
Ok(ToolOutput::llm_only(truncate_tool_output(
&joined,
DEFAULT_MAX_LINES,
DEFAULT_MAX_BYTES,
)))
}
}
/// Tool that edits a workspace file either by exact/fuzzy string
/// replacement (`old_text`/`new_text`) or by applying a unified diff
/// (`diff`). The two modes are mutually exclusive.
pub struct EditFileTool;

#[async_trait]
impl Tool for EditFileTool {
    fn name(&self) -> &str {
        "edit_file"
    }

    fn description(&self) -> &str {
        "Edit a file using either exact string replacement (old_text/new_text) or a unified diff patch (diff). String replacements must resolve to a single match unless expected_replacements is provided."
    }

    fn compact_description(&self) -> &str {
        "Edit file"
    }

    fn category(&self) -> ToolCategory {
        ToolCategory::FilesystemWrite
    }

    fn parameters(&self) -> Value {
        // FIX: "expected_replacements" was declared twice in this schema;
        // serde_json's json! keeps only the last entry, so the first
        // description was dead. A single, merged entry is kept.
        json!({
            "type": "object",
            "properties": {
                "path": {
                    "type": "string",
                    "description": "The path to the file to edit"
                },
                "old_text": {
                    "type": "string",
                    "description": "The text to search for and replace. Must resolve to a single match unless expected_replacements is provided."
                },
                "new_text": {
                    "type": "string",
                    "description": "The text to replace it with"
                },
                "diff": {
                    "type": "string",
                    "description": "A unified diff patch to apply. Use standard @@ hunk headers with +/- lines. Mutually exclusive with old_text/new_text."
                },
                "expected_replacements": {
                    "type": "integer",
                    "description": "Exact number of occurrences to replace. When provided, all exact matches are replaced with count validation. When omitted, the match must be unique (fuzzy matching is used as fallback)."
                }
            },
            "required": ["path"]
        })
    }

    /// Apply the requested edit.
    ///
    /// - `diff` mode: read the file, apply the unified diff, write back.
    /// - replacement mode with `expected_replacements`: replace all exact
    ///   matches, but only if the match count equals the expectation.
    /// - replacement mode without it: the match must be unique; fuzzy
    ///   (NFC / whitespace-normalized) matching is used as a fallback.
    ///
    /// All writes go through the hardened `write_file_secure` helper.
    async fn execute(&self, args: Value, ctx: &ToolContext) -> Result<ToolOutput> {
        let path = args
            .get("path")
            .and_then(|v| v.as_str())
            .ok_or_else(|| ZeptoError::Tool("Missing 'path' argument".into()))?;
        let diff_param = args.get("diff").and_then(|v| v.as_str());
        let old_text = args.get("old_text").and_then(|v| v.as_str());
        let new_text = args.get("new_text").and_then(|v| v.as_str());
        let expected_replacements = args
            .get("expected_replacements")
            .and_then(|v| v.as_u64())
            .map(|n| n as usize);
        // The two edit modes are mutually exclusive, and exactly one
        // complete mode must be supplied.
        if diff_param.is_some() && (old_text.is_some() || new_text.is_some()) {
            return Err(ZeptoError::Tool(
                "Provide either 'diff' or 'old_text'/'new_text', not both.".into(),
            ));
        }
        if diff_param.is_none() && (old_text.is_none() || new_text.is_none()) {
            return Err(ZeptoError::Tool(
                "Provide either 'diff' or 'old_text'/'new_text'".into(),
            ));
        }
        let (full_path, workspace) = resolve_path(path, ctx)?;
        let full_path_ref = Path::new(&full_path);
        if let Some(diff_str) = diff_param {
            // Diff mode: re-validate, read, patch, write back.
            revalidate_path(full_path_ref, &workspace)?;
            let content = tokio::fs::read_to_string(&full_path).await.map_err(|e| {
                ZeptoError::Tool(format!("Failed to read file '{}': {}", full_path, e))
            })?;
            let (new_content, summary) = apply_unified_diff(&content, diff_str)
                .map_err(|e| ZeptoError::Tool(format!("Diff apply failed: {}", e)))?;
            write_file_secure(full_path_ref, &workspace, new_content.as_bytes()).await?;
            Ok(ToolOutput::llm_only(format!(
                "Applied {} hunk(s): +{} -{} in {}",
                summary.hunks_applied, summary.lines_added, summary.lines_removed, full_path
            )))
        } else if let (Some(old_text), Some(new_text)) = (old_text, new_text) {
            revalidate_path(full_path_ref, &workspace)?;
            // An empty needle would match everywhere; reject it up front.
            if old_text.is_empty() {
                return Err(ZeptoError::Tool("'old_text' must not be empty".into()));
            }
            let content = tokio::fs::read_to_string(&full_path).await.map_err(|e| {
                ZeptoError::Tool(format!("Failed to read file '{}': {}", full_path, e))
            })?;
            if let Some(expected) = expected_replacements {
                // Counted mode: exact matches only, count must be exact.
                let replacements = content.matches(old_text).count();
                if replacements == 0 {
                    return Err(ZeptoError::Tool(format!(
                        "Text '{}' not found in file '{}'",
                        crate::utils::string::preview(old_text, 50),
                        full_path
                    )));
                }
                if replacements != expected {
                    return Err(ZeptoError::Tool(format!(
                        "Expected {} replacement(s) for '{}' in '{}', found {}",
                        expected,
                        crate::utils::string::preview(old_text, 50),
                        full_path,
                        replacements
                    )));
                }
                let new_content = content.replace(old_text, new_text);
                write_file_secure(full_path_ref, &workspace, new_content.as_bytes()).await?;
                Ok(ToolOutput::llm_only(format!(
                    "Successfully replaced {} occurrence(s) in {}",
                    replacements, full_path
                )))
            } else {
                // Unique-match mode with tiered fuzzy fallback.
                match find_unique_match(&content, old_text) {
                    Ok(m) => {
                        // Splice the replacement into the original byte range.
                        let mut new_content = String::with_capacity(content.len());
                        new_content.push_str(&content[..m.start]);
                        new_content.push_str(new_text);
                        new_content.push_str(&content[m.end..]);
                        write_file_secure(full_path_ref, &workspace, new_content.as_bytes())
                            .await?;
                        Ok(ToolOutput::llm_only(format!(
                            "Successfully replaced 1 occurrence ({} match) in {}",
                            m.tier, full_path
                        )))
                    }
                    Err(EditMatchError::MultipleMatches(n)) => {
                        Err(ZeptoError::Tool(format!(
                            "Found {} occurrences of text in '{}'. Provide more surrounding context to uniquely identify the location.",
                            n, full_path
                        )))
                    }
                    Err(EditMatchError::NotFound) => {
                        Err(ZeptoError::Tool(format!(
                            "Text '{}' not found in file '{}'",
                            crate::utils::string::preview(old_text, 50),
                            full_path
                        )))
                    }
                }
            }
        } else {
            // Unreachable given the validation above; kept as a defensive
            // fallback so a future refactor cannot fall through silently.
            Err(ZeptoError::Tool(
                "Provide either 'diff' or 'old_text'/'new_text'".into(),
            ))
        }
    }
}
/// Which matching strategy located the edit target.
#[derive(Debug)]
enum MatchTier {
    Exact,
    UnicodeNormalized,
    WhitespaceNormalized,
}

impl std::fmt::Display for MatchTier {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            MatchTier::Exact => "exact",
            MatchTier::UnicodeNormalized => "unicode-normalized",
            MatchTier::WhitespaceNormalized => "whitespace-normalized",
        };
        f.write_str(label)
    }
}
/// Byte range in the original content that a search uniquely resolved to,
/// plus the matching tier that produced it.
#[derive(Debug)]
struct UniqueMatch {
    // Start byte offset (inclusive) into the original content.
    start: usize,
    // End byte offset (exclusive) into the original content.
    end: usize,
    // Strategy (exact / unicode / whitespace) that located this range.
    tier: MatchTier,
}
/// Why a string-replacement search failed to resolve a unique target.
#[derive(Debug)]
enum EditMatchError {
    NotFound,
    MultipleMatches(usize),
}

impl std::fmt::Display for EditMatchError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            EditMatchError::NotFound => f.write_str("not found"),
            EditMatchError::MultipleMatches(count) => write!(f, "Found {} occurrences", count),
        }
    }
}
/// Normalize a string for whitespace-insensitive matching: CRLF becomes
/// LF, runs of spaces/tabs collapse to a single space, and trailing
/// whitespace is stripped from every line. A trailing newline (if any)
/// is dropped by the line-wise reassembly.
fn normalize_whitespace(s: &str) -> String {
    let unified = s.replace("\r\n", "\n");
    let normalized: Vec<String> = unified
        .lines()
        .map(|line| {
            let mut out = String::with_capacity(line.len());
            let mut in_run = false;
            for ch in line.chars() {
                if matches!(ch, ' ' | '\t') {
                    // Emit one space per run of blanks.
                    if !in_run {
                        out.push(' ');
                    }
                    in_run = true;
                } else {
                    out.push(ch);
                    in_run = false;
                }
            }
            let keep = out.trim_end().len();
            out.truncate(keep);
            out
        })
        .collect();
    normalized.join("\n")
}
/// Byte offsets of every non-overlapping occurrence of `needle` in
/// `haystack`, in ascending order. An empty needle yields no matches.
fn find_all_occurrences(haystack: &str, needle: &str) -> Vec<usize> {
    let mut positions = Vec::new();
    if needle.is_empty() {
        return positions;
    }
    let mut cursor = 0;
    while let Some(rel) = haystack[cursor..].find(needle) {
        let abs = cursor + rel;
        positions.push(abs);
        // Resume past this match so occurrences never overlap.
        cursor = abs + needle.len();
    }
    positions
}
fn map_nfc_range_to_original(
original: &str,
_nfc: &str,
nfc_start: usize,
nfc_end: usize,
) -> (usize, usize) {
use unicode_normalization::UnicodeNormalization;
let orig_indices: Vec<(usize, usize)> = original
.char_indices()
.map(|(i, ch)| (i, i + ch.len_utf8()))
.collect();
let mut nfc_to_orig_start: Vec<usize> = Vec::new();
let mut nfc_to_orig_end: Vec<usize> = Vec::new();
let mut orig_idx = 0;
for nfc_ch in original.nfc() {
let orig_start_byte = orig_indices
.get(orig_idx)
.map(|&(s, _)| s)
.unwrap_or(original.len());
loop {
orig_idx += 1;
let scan_end = orig_indices
.get(orig_idx)
.map(|&(s, _)| s)
.unwrap_or(original.len());
let slice = &original[orig_start_byte..scan_end];
let normalized: String = slice.nfc().collect();
if (normalized.len() == nfc_ch.len_utf8() && normalized.starts_with(nfc_ch))
|| orig_idx >= orig_indices.len()
{
break;
}
}
let orig_end_byte = orig_indices
.get(orig_idx)
.map(|&(s, _)| s)
.unwrap_or(original.len());
for _ in 0..nfc_ch.len_utf8() {
nfc_to_orig_start.push(orig_start_byte);
nfc_to_orig_end.push(orig_end_byte);
}
}
let orig_start = nfc_to_orig_start.get(nfc_start).copied().unwrap_or(0);
let orig_end = if nfc_end > 0 {
nfc_to_orig_end
.get(nfc_end - 1)
.copied()
.unwrap_or(original.len())
} else {
0
};
(orig_start, orig_end)
}
/// Map a byte range in the whitespace-normalized string back to the
/// corresponding byte range in the original string.
///
/// Walks both strings in lockstep, advancing `orig_i` past the original
/// bytes that `normalize_whitespace` collapsed or dropped:
/// - "\r\n" in the original corresponds to a single '\n' in normalized;
/// - a run of spaces/tabs corresponds to a single ' ' in normalized;
/// - trailing blanks before a newline were dropped entirely.
///
/// NOTE(review): the walk is byte-wise; this appears safe because only
/// ASCII whitespace is treated specially and multi-byte UTF-8 sequences
/// fall through the final branch one byte at a time — confirm if inputs
/// can place `norm_start`/`norm_end` mid-codepoint.
fn map_ws_range_to_original(
    original: &str,
    normalized: &str,
    norm_start: usize,
    norm_end: usize,
) -> (usize, usize) {
    let orig_bytes = original.as_bytes();
    let norm_bytes = normalized.as_bytes();
    let mut orig_i = 0;
    let mut norm_i = 0;
    let mut result_start = 0;
    let mut result_end = original.len();
    while norm_i < norm_bytes.len() && orig_i < orig_bytes.len() {
        // Record boundaries the moment the normalized cursor reaches them.
        if norm_i == norm_start {
            result_start = orig_i;
        }
        if norm_i == norm_end {
            result_end = orig_i;
            break;
        }
        // CRLF in the original collapsed to LF in the normalized string.
        if orig_bytes[orig_i] == b'\r'
            && orig_i + 1 < orig_bytes.len()
            && orig_bytes[orig_i + 1] == b'\n'
            && norm_bytes[norm_i] == b'\n'
        {
            orig_i += 2;
            norm_i += 1;
            continue;
        }
        // A run of spaces/tabs collapsed to a single space.
        if (orig_bytes[orig_i] == b' ' || orig_bytes[orig_i] == b'\t') && norm_bytes[norm_i] == b' '
        {
            orig_i += 1;
            norm_i += 1;
            while orig_i < orig_bytes.len()
                && (orig_bytes[orig_i] == b' ' || orig_bytes[orig_i] == b'\t')
            {
                orig_i += 1;
            }
            continue;
        }
        // Trailing blanks before a newline (or end) were dropped; skip them
        // in the original without consuming normalized bytes.
        if (orig_bytes[orig_i] == b' ' || orig_bytes[orig_i] == b'\t')
            && (norm_i >= norm_bytes.len() || norm_bytes[norm_i] == b'\n')
        {
            orig_i += 1;
            continue;
        }
        // Ordinary byte: advance both cursors together.
        orig_i += 1;
        norm_i += 1;
    }
    // Range extending to (or past) the end of the normalized string maps
    // to the end of the original.
    if norm_end >= norm_bytes.len() && result_end == original.len() {
        result_end = original.len();
    }
    (result_start, result_end)
}
/// Locate `old_text` in `content` using a three-tier strategy:
/// 1. exact byte match;
/// 2. Unicode NFC-normalized match (only attempted when normalization
///    actually changes either string);
/// 3. whitespace-normalized match.
///
/// Each tier must yield exactly one occurrence: multiple matches abort
/// immediately with `MultipleMatches`, zero matches fall through to the
/// next tier, and exhausting all tiers yields `NotFound`. The returned
/// range always refers to byte offsets in the ORIGINAL content.
fn find_unique_match(
    content: &str,
    old_text: &str,
) -> std::result::Result<UniqueMatch, EditMatchError> {
    // Tier 1: exact match.
    let positions = find_all_occurrences(content, old_text);
    match positions.len() {
        1 => {
            return Ok(UniqueMatch {
                start: positions[0],
                end: positions[0] + old_text.len(),
                tier: MatchTier::Exact,
            });
        }
        n if n > 1 => return Err(EditMatchError::MultipleMatches(n)),
        _ => {}
    }
    // Tier 2: NFC-normalized match, mapped back to original offsets.
    let content_nfc: String = content.nfc().collect();
    let search_nfc: String = old_text.nfc().collect();
    if content_nfc != content || search_nfc != old_text {
        let positions = find_all_occurrences(&content_nfc, &search_nfc);
        match positions.len() {
            1 => {
                let (orig_start, orig_end) = map_nfc_range_to_original(
                    content,
                    &content_nfc,
                    positions[0],
                    positions[0] + search_nfc.len(),
                );
                return Ok(UniqueMatch {
                    start: orig_start,
                    end: orig_end,
                    tier: MatchTier::UnicodeNormalized,
                });
            }
            n if n > 1 => return Err(EditMatchError::MultipleMatches(n)),
            _ => {}
        }
    }
    // Tier 3: whitespace-normalized match, mapped back to original offsets.
    let content_ws = normalize_whitespace(content);
    let search_ws = normalize_whitespace(old_text);
    let positions = find_all_occurrences(&content_ws, &search_ws);
    match positions.len() {
        1 => {
            let (orig_start, orig_end) = map_ws_range_to_original(
                content,
                &content_ws,
                positions[0],
                positions[0] + search_ws.len(),
            );
            Ok(UniqueMatch {
                start: orig_start,
                end: orig_end,
                tier: MatchTier::WhitespaceNormalized,
            })
        }
        n if n > 1 => Err(EditMatchError::MultipleMatches(n)),
        _ => Err(EditMatchError::NotFound),
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::fs;
use tempfile::tempdir;
#[tokio::test]
async fn test_read_file_tool() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("zeptoclaw_test_read.txt");
fs::write(&file_path, "test content").unwrap();
let tool = ReadFileTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool
.execute(json!({"path": "zeptoclaw_test_read.txt"}), &ctx)
.await;
assert!(result.is_ok());
assert_eq!(result.unwrap().for_llm, "test content");
}
#[tokio::test]
async fn test_read_file_tool_not_found() {
let dir = tempdir().unwrap();
let canonical = dir.path().canonicalize().unwrap();
let tool = ReadFileTool;
let ctx = ToolContext::new().with_workspace(canonical.to_str().unwrap());
let result = tool
.execute(json!({"path": "nonexistent_file.txt"}), &ctx)
.await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Failed to read file"));
}
#[tokio::test]
async fn test_read_file_tool_missing_path() {
let tool = ReadFileTool;
let ctx = ToolContext::new().with_workspace("/tmp");
let result = tool.execute(json!({}), &ctx).await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("Missing 'path'"));
}
#[tokio::test]
async fn test_read_file_tool_rejects_no_workspace() {
let tool = ReadFileTool;
let ctx = ToolContext::new();
let result = tool
.execute(json!({"path": "/tmp/some_file.txt"}), &ctx)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("Workspace not configured"));
}
#[tokio::test]
async fn test_read_file_tool_with_workspace() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("test.txt");
fs::write(&file_path, "workspace content").unwrap();
let tool = ReadFileTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool.execute(json!({"path": "test.txt"}), &ctx).await;
assert!(result.is_ok());
assert_eq!(result.unwrap().for_llm, "workspace content");
}
#[tokio::test]
async fn test_write_file_tool() {
let dir = tempdir().unwrap();
let canonical = dir.path().canonicalize().unwrap();
let tool = WriteFileTool;
let ctx = ToolContext::new().with_workspace(canonical.to_str().unwrap());
let result = tool
.execute(
json!({"path": "write_test.txt", "content": "written content"}),
&ctx,
)
.await;
assert!(result.is_ok());
assert!(result.unwrap().for_llm.contains("Successfully wrote"));
assert_eq!(
fs::read_to_string(canonical.join("write_test.txt")).unwrap(),
"written content"
);
}
#[tokio::test]
async fn test_write_file_tool_creates_parent_dirs() {
let dir = tempdir().unwrap();
let canonical = dir.path().canonicalize().unwrap();
let tool = WriteFileTool;
let ctx = ToolContext::new().with_workspace(canonical.to_str().unwrap());
let result = tool
.execute(json!({"path": "a/b/c/test.txt", "content": "nested"}), &ctx)
.await;
assert!(result.is_ok());
assert_eq!(
fs::read_to_string(canonical.join("a/b/c/test.txt")).unwrap(),
"nested"
);
}
#[tokio::test]
async fn test_write_file_tool_missing_content() {
let tool = WriteFileTool;
let ctx = ToolContext::new().with_workspace("/tmp");
let result = tool.execute(json!({"path": "test.txt"}), &ctx).await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Missing 'content'"));
}
#[tokio::test]
async fn test_list_dir_tool() {
let dir = tempdir().unwrap();
fs::write(dir.path().join("file1.txt"), "").unwrap();
fs::write(dir.path().join("file2.txt"), "").unwrap();
fs::create_dir(dir.path().join("subdir")).unwrap();
let tool = ListDirTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool.execute(json!({"path": "."}), &ctx).await;
assert!(result.is_ok());
let output = result.unwrap().for_llm;
assert!(output.contains("file1.txt"));
assert!(output.contains("file2.txt"));
assert!(output.contains("subdir/"));
}
#[tokio::test]
async fn test_list_dir_tool_not_found() {
let dir = tempdir().unwrap();
let canonical = dir.path().canonicalize().unwrap();
let tool = ListDirTool;
let ctx = ToolContext::new().with_workspace(canonical.to_str().unwrap());
let result = tool.execute(json!({"path": "nonexistent_dir"}), &ctx).await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Failed to read directory"));
}
#[tokio::test]
async fn test_list_dir_tool_with_workspace() {
let dir = tempdir().unwrap();
let subdir = dir.path().join("mydir");
fs::create_dir(&subdir).unwrap();
fs::write(subdir.join("inner.txt"), "").unwrap();
let tool = ListDirTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool.execute(json!({"path": "mydir"}), &ctx).await;
assert!(result.is_ok());
assert!(result.unwrap().for_llm.contains("inner.txt"));
}
#[tokio::test]
async fn test_edit_file_tool() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("edit_test.txt");
fs::write(&file_path, "Hello World").unwrap();
let tool = EditFileTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool
.execute(
json!({
"path": "edit_test.txt",
"old_text": "World",
"new_text": "Rust"
}),
&ctx,
)
.await;
assert!(result.is_ok());
assert!(result
.unwrap()
.for_llm
.contains("Successfully replaced 1 occurrence"));
assert_eq!(fs::read_to_string(&file_path).unwrap(), "Hello Rust");
}
#[tokio::test]
async fn test_edit_file_tool_multiple_occurrences() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("edit_multi.txt");
fs::write(&file_path, "foo bar foo baz foo").unwrap();
let tool = EditFileTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool
.execute(
json!({
"path": "edit_multi.txt",
"old_text": "foo",
"new_text": "qux",
"expected_replacements": 3
}),
&ctx,
)
.await;
assert!(result.is_ok());
assert!(result.unwrap().for_llm.contains("replaced 3 occurrence"));
assert_eq!(
fs::read_to_string(&file_path).unwrap(),
"qux bar qux baz qux"
);
}
#[tokio::test]
async fn test_edit_file_tool_text_not_found() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("edit_notfound.txt");
fs::write(&file_path, "Hello World").unwrap();
let tool = EditFileTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool
.execute(
json!({
"path": "edit_notfound.txt",
"old_text": "NotPresent",
"new_text": "Replacement"
}),
&ctx,
)
.await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("not found in file"));
}
#[tokio::test]
async fn test_edit_file_tool_rejects_empty_old_text() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("edit_empty_old.txt");
fs::write(&file_path, "Hello World").unwrap();
let tool = EditFileTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool
.execute(
json!({
"path": "edit_empty_old.txt",
"old_text": "",
"new_text": "Replacement"
}),
&ctx,
)
.await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("must not be empty"));
assert_eq!(fs::read_to_string(&file_path).unwrap(), "Hello World");
}
#[tokio::test]
async fn test_edit_file_tool_expected_replacements_mismatch() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("edit_expected_count.txt");
fs::write(&file_path, "foo bar foo").unwrap();
let tool = EditFileTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool
.execute(
json!({
"path": "edit_expected_count.txt",
"old_text": "foo",
"new_text": "qux",
"expected_replacements": 1
}),
&ctx,
)
.await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Expected 1 replacement(s)"));
assert_eq!(fs::read_to_string(&file_path).unwrap(), "foo bar foo");
}
#[tokio::test]
async fn test_edit_file_tool_expected_replacements_match() {
let dir = tempdir().unwrap();
let file_path = dir.path().join("edit_expected_ok.txt");
fs::write(&file_path, "foo bar foo").unwrap();
let tool = EditFileTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool
.execute(
json!({
"path": "edit_expected_ok.txt",
"old_text": "foo",
"new_text": "qux",
"expected_replacements": 2
}),
&ctx,
)
.await;
assert!(result.is_ok());
assert_eq!(fs::read_to_string(&file_path).unwrap(), "qux bar qux");
}
#[tokio::test]
async fn test_edit_file_tool_missing_args() {
let tool = EditFileTool;
let ctx = ToolContext::new().with_workspace("/tmp");
let result = tool
.execute(json!({"path": "test.txt", "new_text": "new"}), &ctx)
.await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Provide either 'diff' or 'old_text'/'new_text'"));
let result = tool
.execute(json!({"path": "test.txt", "old_text": "old"}), &ctx)
.await;
assert!(result.is_err());
assert!(result
.unwrap_err()
.to_string()
.contains("Provide either 'diff' or 'old_text'/'new_text'"));
}
#[test]
fn test_resolve_path_rejects_without_workspace() {
let ctx = ToolContext::new();
let result = resolve_path("relative/path", &ctx);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("Workspace not configured"));
}
#[test]
fn test_resolve_path_relative_with_workspace() {
let dir = tempdir().unwrap();
std::fs::create_dir_all(dir.path().join("relative")).unwrap();
std::fs::write(dir.path().join("relative/path"), "").unwrap();
let workspace = dir.path().to_str().unwrap();
let ctx = ToolContext::new().with_workspace(workspace);
let result = resolve_path("relative/path", &ctx);
assert!(result.is_ok());
let (resolved, _ws) = result.unwrap();
assert!(resolved.contains("relative/path") || resolved.ends_with("relative/path"));
}
#[test]
fn test_resolve_path_blocks_absolute_outside_workspace() {
let dir = tempdir().unwrap();
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = resolve_path("/etc/passwd", &ctx);
assert!(result.is_err());
}
#[test]
fn test_tool_names() {
assert_eq!(ReadFileTool.name(), "read_file");
assert_eq!(WriteFileTool.name(), "write_file");
assert_eq!(ListDirTool.name(), "list_dir");
assert_eq!(EditFileTool.name(), "edit_file");
}
#[test]
fn test_tool_descriptions() {
assert!(!ReadFileTool.description().is_empty());
assert!(!WriteFileTool.description().is_empty());
assert!(!ListDirTool.description().is_empty());
assert!(!EditFileTool.description().is_empty());
}
#[test]
fn test_tool_parameters() {
for tool in [
&ReadFileTool as &dyn Tool,
&WriteFileTool,
&ListDirTool,
&EditFileTool,
] {
let params = tool.parameters();
assert!(params.is_object());
assert_eq!(params["type"], "object");
assert!(params["properties"].is_object());
assert!(params["required"].is_array());
}
}
#[tokio::test]
async fn test_path_traversal_blocked() {
let dir = tempdir().unwrap();
let tool = ReadFileTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool
.execute(json!({"path": "../../../etc/passwd"}), &ctx)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(err.contains("Security violation") || err.contains("escapes workspace"));
}
#[tokio::test]
async fn test_absolute_path_outside_workspace_blocked() {
let dir = tempdir().unwrap();
let tool = ReadFileTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool.execute(json!({"path": "/etc/passwd"}), &ctx).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_write_tool_rejects_traversal_outside_workspace() {
let dir = tempdir().unwrap();
let tool = WriteFileTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool
.execute(
json!({"path": "../../etc/shadow", "content": "pwned"}),
&ctx,
)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("Security violation") || err.contains("traversal"),
"Expected security error, got: {}",
err
);
}
#[tokio::test]
async fn test_list_dir_rejects_absolute_outside_workspace() {
let dir = tempdir().unwrap();
let tool = ListDirTool;
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = tool.execute(json!({"path": "/etc"}), &ctx).await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("Security violation") || err.contains("escapes workspace"),
"Expected security error, got: {}",
err
);
}
#[tokio::test]
async fn test_edit_tool_rejects_no_workspace() {
let tool = EditFileTool;
let ctx = ToolContext::new();
let result = tool
.execute(
json!({
"path": "/tmp/test.txt",
"old_text": "a",
"new_text": "b"
}),
&ctx,
)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("Workspace not configured"),
"Expected workspace error, got: {}",
err
);
}
#[test]
fn test_resolve_path_blocks_url_encoded_traversal() {
let dir = tempdir().unwrap();
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = resolve_path("%2e%2e/etc/passwd", &ctx);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("Security violation") || err.contains("traversal"),
"Expected security error for URL-encoded traversal, got: {}",
err
);
}
#[test]
fn test_resolve_path_blocks_double_encoded_traversal() {
let dir = tempdir().unwrap();
let ctx = ToolContext::new().with_workspace(dir.path().to_str().unwrap());
let result = resolve_path("%252e%252e/etc/passwd", &ctx);
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("Security violation") || err.contains("traversal"),
"Expected security error for double-encoded traversal, got: {}",
err
);
}
#[tokio::test]
async fn test_write_blocks_hardlinked_file() {
let dir = tempdir().unwrap();
let canonical = dir.path().canonicalize().unwrap();
let workspace = canonical.to_str().unwrap();
let original = canonical.join("original.txt");
fs::write(&original, "original content").unwrap();
let hardlink = canonical.join("hardlink.txt");
fs::hard_link(&original, &hardlink).unwrap();
let tool = WriteFileTool;
let ctx = ToolContext::new().with_workspace(workspace);
let result = tool
.execute(
json!({"path": "hardlink.txt", "content": "malicious"}),
&ctx,
)
.await;
assert!(result.is_err());
let err = result.unwrap_err().to_string();
assert!(
err.contains("hard links"),
"Expected hardlink error, got: {}",
err
);
}
#[tokio::test]
async fn test_edit_blocks_hardlinked_file() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().canonicalize().unwrap();
    let target = root.join("editable.txt");
    fs::write(&target, "Hello World").unwrap();
    // Second directory entry for the same inode.
    fs::hard_link(&target, root.join("edit_link.txt")).unwrap();
    let ctx = ToolContext::new().with_workspace(root.to_str().unwrap());
    let outcome = EditFileTool
        .execute(
            json!({
                "path": "edit_link.txt",
                "old_text": "Hello",
                "new_text": "Goodbye"
            }),
            &ctx,
        )
        .await;
    // Edits rewrite the file, so the same hardlink protection as write applies.
    assert!(outcome.is_err());
    let message = outcome.unwrap_err().to_string();
    assert!(
        message.contains("hard links"),
        "Expected hardlink error, got: {}",
        message
    );
}
#[tokio::test]
async fn test_write_allows_single_link_file() {
    let tmp = tempdir().unwrap();
    let root = tmp.path().canonicalize().unwrap();
    let target = root.join("normal.txt");
    fs::write(&target, "original").unwrap();
    let ctx = ToolContext::new().with_workspace(root.to_str().unwrap());
    let outcome = WriteFileTool
        .execute(json!({"path": "normal.txt", "content": "updated"}), &ctx)
        .await;
    // A regular single-link file is the happy path: the write succeeds and
    // replaces the content in place.
    assert!(outcome.is_ok());
    assert_eq!(fs::read_to_string(&target).unwrap(), "updated");
}
#[tokio::test]
async fn test_edit_file_diff_mode_simple() {
    let tmp = tempdir().unwrap();
    let target = tmp.path().join("diff_test.txt");
    fs::write(&target, "line one\nline two\nline three\n").unwrap();
    let ctx = ToolContext::new().with_workspace(tmp.path().to_str().unwrap());
    let outcome = EditFileTool
        .execute(
            json!({
                "path": "diff_test.txt",
                "diff": "@@ -1,3 +1,3 @@\n line one\n-line two\n+LINE TWO\n line three"
            }),
            &ctx,
        )
        .await;
    assert!(outcome.is_ok());
    // The tool reports the hunk count and the middle line is replaced on disk.
    let summary = outcome.unwrap().for_llm;
    assert!(summary.contains("Applied 1 hunk"));
    assert_eq!(
        fs::read_to_string(&target).unwrap(),
        "line one\nLINE TWO\nline three\n"
    );
}
#[tokio::test]
async fn test_edit_file_diff_mode_context_mismatch() {
    let tmp = tempdir().unwrap();
    fs::write(tmp.path().join("diff_mismatch.txt"), "foo\nbar\nbaz\n").unwrap();
    let ctx = ToolContext::new().with_workspace(tmp.path().to_str().unwrap());
    // The hunk's middle context line says "WRONG" where the file has "bar".
    let outcome = EditFileTool
        .execute(
            json!({
                "path": "diff_mismatch.txt",
                "diff": "@@ -1,3 +1,3 @@\n foo\n WRONG\n baz"
            }),
            &ctx,
        )
        .await;
    assert!(outcome.is_err());
    assert!(outcome.unwrap_err().to_string().contains("context mismatch"));
}
#[tokio::test]
async fn test_edit_file_diff_and_old_text_mutually_exclusive() {
let tool = EditFileTool;
let ctx = ToolContext::new().with_workspace("/tmp");
let result = tool
.execute(
json!({
"path": "test.txt",
"diff": "@@ -1,1 +1,1 @@\n-a\n+b",
"old_text": "a",
"new_text": "b"
}),
&ctx,
)
.await;
assert!(result.is_err());
assert!(result.unwrap_err().to_string().contains("not both"));
}
#[tokio::test]
async fn test_edit_file_tool_multi_match_without_expected_errors() {
    let tmp = tempdir().unwrap();
    // Three occurrences of "foo" make the edit target ambiguous.
    fs::write(tmp.path().join("test.txt"), "foo bar foo baz foo").unwrap();
    let ctx = ToolContext::new().with_workspace(tmp.path().to_str().unwrap());
    let outcome = EditFileTool
        .execute(
            json!({
                "path": "test.txt",
                "old_text": "foo",
                "new_text": "qux"
            }),
            &ctx,
        )
        .await;
    assert!(outcome.is_err());
    // The error should count the matches and coach the caller to disambiguate.
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("3 occurrences"));
    assert!(message.contains("more surrounding context"));
}
#[tokio::test]
async fn test_edit_file_tool_fuzzy_whitespace_match() {
    let tmp = tempdir().unwrap();
    let target = tmp.path().join("test.txt");
    // The file uses tabs while the search text uses spaces.
    fs::write(&target, "fn\tmain()\t{}").unwrap();
    let ctx = ToolContext::new().with_workspace(tmp.path().to_str().unwrap());
    let outcome = EditFileTool
        .execute(
            json!({
                "path": "test.txt",
                "old_text": "fn main() {}",
                "new_text": "fn run() {}"
            }),
            &ctx,
        )
        .await;
    // Whitespace-tolerant matching should still locate and replace the
    // tabbed original.
    assert!(outcome.is_ok());
    assert_eq!(fs::read_to_string(&target).unwrap(), "fn run() {}");
}
use super::{find_unique_match, MatchTier};
#[test]
fn test_exact_single_match() {
    let haystack = "fn main() {}";
    let m = find_unique_match(haystack, "main").unwrap();
    // An unambiguous literal hit is reported at the Exact tier with the
    // literal byte range.
    assert!(matches!(m.tier, MatchTier::Exact));
    assert_eq!(&haystack[m.start..m.end], "main");
}
#[test]
fn test_exact_multi_match_errors() {
    // Three exact hits — ambiguous, so the lookup must refuse.
    let outcome = find_unique_match("foo bar foo baz foo", "foo");
    assert!(outcome.is_err());
    // The error should report how many candidates were found.
    assert!(outcome.unwrap_err().to_string().contains("3"));
}
#[test]
fn test_not_found_errors() {
    // A needle absent from the haystack is an error, not an empty match.
    assert!(find_unique_match("fn main() {}", "nonexistent").is_err());
}
#[test]
fn test_unicode_nfc_fallback() {
    // Decomposed "café" (e + combining acute) in the haystack, precomposed
    // "é" in the needle — only NFC normalization can line them up.
    let haystack = "caf\u{0065}\u{0301}";
    let needle = "caf\u{00E9}";
    let m = find_unique_match(haystack, needle).unwrap();
    assert!(matches!(m.tier, MatchTier::UnicodeNormalized));
    assert_eq!((m.start, m.end), (0, haystack.len()));
}
#[test]
fn test_unicode_nfc_mid_string() {
    let haystack = "hello caf\u{0065}\u{0301} world";
    let m = find_unique_match(haystack, "caf\u{00E9}").unwrap();
    assert!(matches!(m.tier, MatchTier::UnicodeNormalized));
    // Offsets must land on the decomposed original bytes, not the NFC form.
    assert_eq!(m.start, 6);
    assert_eq!(&haystack[m.start..m.end], "caf\u{0065}\u{0301}");
}
#[test]
fn test_whitespace_tabs_vs_spaces() {
    // Tabs in the haystack, single spaces in the needle.
    let haystack = "fn\tmain()\t{}";
    let m = find_unique_match(haystack, "fn main() {}").unwrap();
    assert!(matches!(m.tier, MatchTier::WhitespaceNormalized));
    assert_eq!((m.start, m.end), (0, haystack.len()));
}
#[test]
fn test_whitespace_trailing() {
    // A trailing space before the newline should not defeat the match.
    let haystack = "hello \nworld";
    let m = find_unique_match(haystack, "hello\nworld").unwrap();
    assert!(matches!(m.tier, MatchTier::WhitespaceNormalized));
    assert_eq!((m.start, m.end), (0, haystack.len()));
}
#[test]
fn test_whitespace_crlf_normalization() {
    // CRLF line endings in the haystack, bare LF in the needle.
    let haystack = "line1\r\nline2";
    let m = find_unique_match(haystack, "line1\nline2").unwrap();
    assert!(matches!(m.tier, MatchTier::WhitespaceNormalized));
    assert_eq!((m.start, m.end), (0, haystack.len()));
}
#[test]
fn test_fuzzy_multi_match_errors() {
    // Both lines normalize to the same text, so even the whitespace-tolerant
    // tier sees two candidates and must refuse.
    let haystack = "fn\tmain() {}\nfn\t\tmain() {}";
    assert!(find_unique_match(haystack, "fn main() {}").is_err());
}
#[test]
fn test_empty_content() {
    // Searching an empty haystack can never succeed.
    assert!(find_unique_match("", "search").is_err());
}
#[test]
fn test_nfc_offset_map_combining_accent() {
    use unicode_normalization::UnicodeNormalization;
    // "e" + combining acute collapses to a single precomposed char under NFC;
    // the map must still cover the full decomposed range in the original.
    let decomposed = "e\u{0301}";
    let composed: String = decomposed.nfc().collect();
    assert_eq!(composed, "\u{00E9}");
    let (lo, hi) = map_nfc_range_to_original(decomposed, &composed, 0, composed.len());
    assert_eq!(lo, 0);
    assert_eq!(hi, decomposed.len());
}
#[test]
fn test_nfc_offset_map_no_drift_after_composition() {
    use unicode_normalization::UnicodeNormalization;
    // Two decomposed accents: composition shortens the NFC string, so byte
    // offsets after the first accent drift relative to the original.
    let original = "cafe\u{0301} x nai\u{0308}ve";
    let nfc: String = original.nfc().collect();
    assert_eq!(nfc, "café x naïve");
    let nfc_start = nfc.find("naïve").unwrap();
    let nfc_end = nfc_start + "naïve".len();
    let (start, end) = map_nfc_range_to_original(original, &nfc, nfc_start, nfc_end);
    // The mapped range must cover the decomposed "nai\u{0308}ve" exactly.
    assert_eq!(start, original.find("nai").unwrap());
    assert_eq!(end, original.len());
}
#[test]
fn test_nfc_offset_map_ascii_identity() {
    use unicode_normalization::UnicodeNormalization;
    // Pure ASCII is already in NFC, so normalization must be a no-op and the
    // offset map must be the identity on any sub-range. The original test
    // never verified the no-op precondition, which the identity assertions
    // silently rely on.
    let original = "hello world";
    let nfc: String = original.nfc().collect();
    assert_eq!(nfc, original);
    let (start, end) = map_nfc_range_to_original(original, &nfc, 6, 11);
    assert_eq!(start, 6);
    assert_eq!(end, 11);
}
}