use super::{AgentTool, AgentToolResult, ToolError};
use async_trait::async_trait;
use regex::RegexBuilder;
use serde_json::{json, Value};
use std::path::Path;
use tokio::fs;
use tokio::sync::oneshot;
/// Maximum number of characters of a single line included in grep output.
const GREP_MAX_LINE_LENGTH: usize = 500;

/// Shortens `line` to at most [`GREP_MAX_LINE_LENGTH`] characters.
///
/// Returns the (possibly shortened) text and whether truncation happened.
/// Truncated text gets a `... [truncated]` marker appended. The cut is made on
/// a `char` boundary, so multi-byte UTF-8 text can never cause a slicing panic.
fn truncate_line(line: &str) -> (String, bool) {
    // Fast path: a line of at most N bytes has at most N characters.
    if line.len() <= GREP_MAX_LINE_LENGTH {
        return (line.to_string(), false);
    }
    match line.char_indices().nth(GREP_MAX_LINE_LENGTH) {
        // The byte offset of the (N+1)-th character is exactly where the
        // first N characters end, and it is always a valid char boundary.
        Some((cut, _)) => (format!("{}... [truncated]", &line[..cut]), true),
        // Long in bytes but not in characters (multi-byte text): keep it all.
        None => (line.to_string(), false),
    }
}
pub struct GrepTool;
impl GrepTool {
pub fn new() -> Self {
Self
}
fn matches_glob(file_name: &str, pattern: &str) -> bool {
if pattern.starts_with("*.") {
let ext = &pattern[2..];
file_name.ends_with(ext)
} else if pattern.contains('*') {
let parts: Vec<&str> = pattern.split('*').collect();
if parts.len() == 2 {
file_name.starts_with(parts[0]) && file_name.ends_with(parts[1])
} else {
file_name == pattern
}
} else {
file_name == pattern
}
}
async fn grep_impl(
pattern: &str,
path: &str,
case_insensitive: bool,
literal: bool,
context_before: usize,
context_after: usize,
include: Option<&str>,
max_results: usize,
) -> Result<(String, bool), ToolError> {
let root = Path::new(path);
if root.components().any(|c| c.as_os_str() == "..") {
return Err("Path traversal not allowed".to_string());
}
if !root.exists() {
return Err(format!("Path not found: {}", path));
}
let pattern = if literal {
regex::escape(pattern)
} else {
pattern.to_string()
};
let re = RegexBuilder::new(&pattern)
.case_insensitive(case_insensitive)
.build()
.map_err(|e| format!("Invalid pattern '{}': {}", pattern, e))?;
let mut matches: Vec<String> = Vec::new();
let mut lines_truncated = false;
Self::grep_walk(
root,
root,
&re,
include,
context_before,
context_after,
max_results,
&mut matches,
&mut lines_truncated,
)
.await?;
if matches.is_empty() {
Ok(("No matches found".to_string(), false))
} else {
let header = format!("Found {} matches:\n", matches.len());
Ok((header + &matches.join("\n"), lines_truncated))
}
}
async fn read_file_lines(path: &Path) -> Result<Vec<String>, ToolError> {
match fs::read_to_string(path).await {
Ok(content) => {
let normalized = content.replace("\r\n", "\n").replace('\r', "\n");
Ok(normalized.lines().map(|s| s.to_string()).collect())
}
Err(e) => Err(format!("Cannot read file: {}", e)),
}
}
async fn grep_walk(
root: &Path,
current: &Path,
re: ®ex::Regex,
include: Option<&str>,
context_before: usize,
context_after: usize,
max_results: usize,
matches: &mut Vec<String>,
lines_truncated: &mut bool,
) -> Result<(), ToolError> {
if matches.len() >= max_results {
return Ok(());
}
if current.is_file() {
if let Some(glob) = include {
let file_name = current
.file_name()
.map(|n| n.to_string_lossy().to_string())
.unwrap_or_default();
if !Self::matches_glob(&file_name, glob) {
return Ok(());
}
}
match Self::read_file_lines(current).await {
Ok(lines) => {
let relative = current
.strip_prefix(root)
.unwrap_or(current)
.display();
for (i, line) in lines.iter().enumerate() {
if re.is_match(line) {
let context_lines_count = if context_before > 0 || context_after > 0 {
let start = if context_before > 0 {
i.saturating_sub(context_before)
} else {
i
};
let end = std::cmp::min(lines.len(), i + context_after + 1);
end - start
} else {
1
};
if matches.len() + context_lines_count > max_results {
return Ok(());
}
if context_before > 0 && i > 0 {
let start = i.saturating_sub(context_before);
for j in start..i {
let (truncated_text, was_truncated) = truncate_line(&lines[j]);
if was_truncated {
*lines_truncated = true;
}
matches.push(format!("{}-{}- {}", relative, j + 1, truncated_text));
}
}
let (truncated_text, was_truncated) = truncate_line(line);
if was_truncated {
*lines_truncated = true;
}
matches.push(format!("{}:{}: {}", relative, i + 1, truncated_text));
if context_after > 0 {
let end = std::cmp::min(lines.len(), i + context_after + 1);
for j in (i + 1)..end {
let (truncated_text, was_truncated) = truncate_line(&lines[j]);
if was_truncated {
*lines_truncated = true;
}
matches.push(format!("{}-{}- {}", relative, j + 1, truncated_text));
}
}
if matches.len() >= max_results {
return Ok(());
}
}
}
}
Err(_) => {
}
}
return Ok(());
}
let mut entries = fs::read_dir(current)
.await
.map_err(|e| format!("Cannot read directory {}: {}", current.display(), e))?;
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| format!("Error reading entry: {}", e))?
{
let entry_path = entry.path();
if entry_path
.file_name()
.map(|n| n.to_string_lossy().starts_with('.'))
.unwrap_or(false)
{
continue;
}
if entry_path.is_dir() {
let dir_name = entry_path
.file_name()
.map(|n| n.to_string_lossy().to_string())
.unwrap_or_default();
if matches!(
dir_name.as_str(),
"node_modules"
| "target"
| ".git"
| "dist"
| "build"
| "__pycache__"
| ".venv"
| "venv"
) {
continue;
}
}
Box::pin(Self::grep_walk(
root,
&entry_path,
re,
include,
context_before,
context_after,
max_results,
matches,
lines_truncated,
))
.await?;
}
Ok(())
}
}
impl Default for GrepTool {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl AgentTool for GrepTool {
fn name(&self) -> &str {
"grep"
}
fn label(&self) -> &str {
"Grep"
}
fn description(&self) -> &str {
"Search files for a pattern. Returns matching lines with file paths and line numbers. Use literal=true to treat pattern as a literal string. Use context=n to show n lines before and after matches. Long lines are truncated to 500 chars."
}
fn parameters_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"pattern": {
"type": "string",
"description": "The pattern to search for (regex by default, or literal string if literal=true)"
},
"path": {
"type": "string",
"description": "The directory or file to search in",
"default": "."
},
"case_insensitive": {
"type": "boolean",
"description": "If true, perform case-insensitive search",
"default": false
},
"literal": {
"type": "boolean",
"description": "If true, treat pattern as a literal string instead of regex",
"default": false
},
"context": {
"type": "integer",
"description": "Number of lines to show before and after each match",
"default": 0
},
"include": {
"type": "string",
"description": "Glob pattern to filter files (e.g., '*.rs', '*.ts')"
},
"max_results": {
"type": "integer",
"description": "Maximum number of results to return",
"default": 100
}
},
"required": ["pattern"]
})
}
async fn execute(
&self,
_tool_call_id: &str,
params: Value,
_signal: Option<oneshot::Receiver<()>>,
) -> Result<AgentToolResult, ToolError> {
let pattern = params
.get("pattern")
.and_then(|v: &Value| v.as_str())
.ok_or_else(|| "Missing required parameter: pattern".to_string())?;
let path = params
.get("path")
.and_then(|v: &Value| v.as_str())
.unwrap_or(".");
let case_insensitive = params
.get("case_insensitive")
.and_then(|v: &Value| v.as_bool())
.unwrap_or(false);
let literal = params
.get("literal")
.and_then(|v: &Value| v.as_bool())
.unwrap_or(false);
let context = params
.get("context")
.and_then(|v: &Value| v.as_u64())
.unwrap_or(0) as usize;
let include = params.get("include").and_then(|v: &Value| v.as_str());
let max_results = params
.get("max_results")
.and_then(|v: &Value| v.as_u64())
.unwrap_or(100) as usize;
match Self::grep_impl(
pattern,
path,
case_insensitive,
literal,
context,
context,
include,
max_results,
)
.await
{
Ok((output, lines_truncated)) => {
let mut result = AgentToolResult::success(output);
if lines_truncated {
result.metadata = Some(json!({
"lines_truncated": true,
"message": "Some lines truncated to 500 chars. Use read tool to see full lines."
}));
}
Ok(result)
}
Err(e) => Ok(AgentToolResult::error(e)),
}
}
}