use crate::tools::types::{Tool, ToolContext, ToolOutput};
use crate::tools::MAX_OUTPUT_SIZE;
use anyhow::Result;
use async_trait::async_trait;
use ignore::WalkBuilder;
use regex::Regex;
use std::path::PathBuf;
/// Tool that searches files for lines matching a regular expression and
/// reports them with workspace-relative paths and 1-based line numbers.
///
/// The type carries no state, so it is freely copyable and default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct GrepTool;
#[async_trait]
impl Tool for GrepTool {
fn name(&self) -> &str {
"grep"
}
fn description(&self) -> &str {
"Search for a pattern in files using ripgrep. Returns matching lines with file paths and line numbers."
}
fn parameters(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"additionalProperties": false,
"properties": {
"pattern": {
"type": "string",
"description": "Required. Regular expression pattern to search for. Always provide this exact field name: 'pattern'."
},
"path": {
"type": "string",
"description": "Optional. Directory or file to search in. Default: workspace root."
},
"glob": {
"type": "string",
"description": "Optional. Glob pattern to filter files, for example '*.rs' or '*.{ts,tsx}'."
},
"context": {
"type": "integer",
"description": "Optional. Number of context lines to show before and after matches."
},
"-i": {
"type": "boolean",
"description": "Optional. Case insensitive search."
}
},
"required": ["pattern"],
"examples": [
{
"pattern": "TODO"
},
{
"pattern": "fn main",
"path": "src",
"glob": "*.rs",
"context": 2
}
]
})
}
async fn execute(&self, args: &serde_json::Value, ctx: &ToolContext) -> Result<ToolOutput> {
let pattern_str = match args.get("pattern").and_then(|v| v.as_str()) {
Some(p) => p,
None => return Ok(ToolOutput::error("pattern parameter is required")),
};
let case_insensitive = args.get("-i").and_then(|v| v.as_bool()).unwrap_or(false);
let regex_pattern = if case_insensitive {
format!("(?i){}", pattern_str)
} else {
pattern_str.to_string()
};
let regex = match Regex::new(®ex_pattern) {
Ok(r) => r,
Err(e) => {
return Ok(ToolOutput::error(format!(
"Invalid regex pattern '{}': {}",
pattern_str, e
)))
}
};
let search_path = match args.get("path").and_then(|v| v.as_str()) {
Some(p) => {
if std::path::Path::new(p).is_absolute() {
PathBuf::from(p)
} else {
ctx.workspace.join(p)
}
}
None => ctx.workspace.clone(),
};
let glob_filter = args.get("glob").and_then(|v| v.as_str());
let context_lines = args.get("context").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
let mut builder = WalkBuilder::new(&search_path);
builder.hidden(false).git_ignore(true).git_global(true);
if let Some(glob_pat) = glob_filter {
let mut types = ignore::types::TypesBuilder::new();
types.add("custom", glob_pat).ok();
types.select("custom");
if let Ok(built) = types.build() {
builder.types(built);
}
}
let mut output = String::new();
let mut match_count = 0;
let mut file_count = 0;
let mut total_size = 0;
for entry in builder.build().flatten() {
if !entry.file_type().map(|ft| ft.is_file()).unwrap_or(false) {
continue;
}
let file_path = entry.path();
let content = match std::fs::read_to_string(file_path) {
Ok(c) => c,
Err(_) => continue, };
let lines: Vec<&str> = content.lines().collect();
let mut file_matches = Vec::new();
for (line_idx, line) in lines.iter().enumerate() {
if regex.is_match(line) {
file_matches.push(line_idx);
}
}
if file_matches.is_empty() {
continue;
}
file_count += 1;
let file_display = file_path.to_string_lossy().replace('\\', "/");
let workspace_display = ctx.workspace.to_string_lossy().replace('\\', "/");
let rel_path = if let Some(stripped) = file_display.strip_prefix(&workspace_display) {
stripped.trim_start_matches('/').to_string()
} else {
file_path
.strip_prefix(&ctx.workspace)
.unwrap_or(file_path)
.to_string_lossy()
.to_string()
};
for &match_idx in &file_matches {
if total_size > MAX_OUTPUT_SIZE {
output.push_str("\n... (output truncated)\n");
return Ok(ToolOutput::success(format!(
"{}Found {} matches in {} files (output truncated)",
output, match_count, file_count
)));
}
match_count += 1;
let start = match_idx.saturating_sub(context_lines);
let end = (match_idx + context_lines + 1).min(lines.len());
for (i, line) in lines[start..end].iter().enumerate() {
let abs_i = start + i;
let prefix = if abs_i == match_idx { ">" } else { " " };
let line_str = format!("{}{}:{}: {}\n", prefix, rel_path, abs_i + 1, line);
total_size += line_str.len();
output.push_str(&line_str);
}
if context_lines > 0 {
output.push_str("--\n");
total_size += 3;
}
}
}
if match_count == 0 {
Ok(ToolOutput::success(format!(
"No matches found for pattern: {}",
pattern_str
)))
} else {
output.push_str(&format!(
"\n{} match(es) in {} file(s)",
match_count, file_count
));
Ok(ToolOutput::success(output))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic search: every matching line is reported with the summary count.
    #[tokio::test]
    async fn test_grep_find_pattern() {
        let temp = tempfile::tempdir().unwrap();
        std::fs::write(
            temp.path().join("a.txt"),
            "hello world\nfoo bar\nhello again",
        )
        .unwrap();
        std::fs::write(temp.path().join("b.txt"), "no match here").unwrap();
        let tool = GrepTool;
        let ctx = ToolContext::new(temp.path().to_path_buf());
        let result = tool
            .execute(&serde_json::json!({"pattern": "hello"}), &ctx)
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.content.contains("hello world"));
        assert!(result.content.contains("hello again"));
        assert!(result.content.contains("2 match(es)"));
    }

    /// A pattern with no hits is a successful "No matches found" result.
    #[tokio::test]
    async fn test_grep_no_match() {
        let temp = tempfile::tempdir().unwrap();
        std::fs::write(temp.path().join("a.txt"), "hello").unwrap();
        let tool = GrepTool;
        let ctx = ToolContext::new(temp.path().to_path_buf());
        let result = tool
            .execute(&serde_json::json!({"pattern": "xyz"}), &ctx)
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.content.contains("No matches found"));
    }

    /// The `-i` flag matches regardless of letter case.
    #[tokio::test]
    async fn test_grep_case_insensitive() {
        let temp = tempfile::tempdir().unwrap();
        std::fs::write(temp.path().join("a.txt"), "Hello World\nhello world").unwrap();
        let tool = GrepTool;
        let ctx = ToolContext::new(temp.path().to_path_buf());
        let result = tool
            .execute(&serde_json::json!({"pattern": "hello", "-i": true}), &ctx)
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.content.contains("2 match(es)"));
    }

    /// A malformed regex is reported as a tool error.
    #[tokio::test]
    async fn test_grep_invalid_regex() {
        let tool = GrepTool;
        let ctx = ToolContext::new(PathBuf::from("/tmp"));
        let result = tool
            .execute(&serde_json::json!({"pattern": "[invalid"}), &ctx)
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.content.contains("Invalid regex"));
    }

    /// Omitting `pattern` is reported as a tool error.
    #[tokio::test]
    async fn test_grep_missing_pattern() {
        let tool = GrepTool;
        let ctx = ToolContext::new(PathBuf::from("/tmp"));
        let result = tool.execute(&serde_json::json!({}), &ctx).await.unwrap();
        assert!(!result.success);
    }

    /// The `glob` argument restricts the search to matching file names.
    #[tokio::test]
    async fn test_grep_glob_filter() {
        let temp = tempfile::tempdir().unwrap();
        std::fs::write(temp.path().join("a.rs"), "fn main() {}").unwrap();
        std::fs::write(temp.path().join("b.txt"), "fn main() {}").unwrap();
        let tool = GrepTool;
        let ctx = ToolContext::new(temp.path().to_path_buf());
        let result = tool
            .execute(
                &serde_json::json!({"pattern": "fn main", "glob": "*.rs"}),
                &ctx,
            )
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.content.contains("a.rs"));
        assert!(!result.content.contains("b.txt"));
        assert!(result.content.contains("1 match(es) in 1 file(s)"));
    }

    /// A relative `path` is resolved against the workspace, and results are
    /// shown relative to the workspace root.
    #[tokio::test]
    async fn test_grep_path_subdirectory() {
        let temp = tempfile::tempdir().unwrap();
        std::fs::create_dir(temp.path().join("src")).unwrap();
        std::fs::write(temp.path().join("src").join("a.txt"), "needle").unwrap();
        std::fs::write(temp.path().join("b.txt"), "needle").unwrap();
        let tool = GrepTool;
        let ctx = ToolContext::new(temp.path().to_path_buf());
        let result = tool
            .execute(
                &serde_json::json!({"pattern": "needle", "path": "src"}),
                &ctx,
            )
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.content.contains(">src/a.txt:1: needle"));
        assert!(!result.content.contains("b.txt"));
        assert!(result.content.contains("1 match(es) in 1 file(s)"));
    }

    /// `context` prints surrounding lines with a space marker, and only
    /// `context` lines on each side of the match.
    #[tokio::test]
    async fn test_grep_context_lines() {
        let temp = tempfile::tempdir().unwrap();
        std::fs::write(
            temp.path().join("a.txt"),
            "one\ntwo\nthree\nfour\nfive",
        )
        .unwrap();
        let tool = GrepTool;
        let ctx = ToolContext::new(temp.path().to_path_buf());
        let result = tool
            .execute(
                &serde_json::json!({"pattern": "three", "context": 1}),
                &ctx,
            )
            .await
            .unwrap();
        assert!(result.success);
        assert!(result.content.contains(" a.txt:2: two"));
        assert!(result.content.contains(">a.txt:3: three"));
        assert!(result.content.contains(" a.txt:4: four"));
        assert!(!result.content.contains("a.txt:1:"));
        assert!(!result.content.contains("a.txt:5:"));
        assert!(result.content.contains("1 match(es) in 1 file(s)"));
    }

    /// The schema stays strict and advertises `pattern` (not `query`).
    #[test]
    fn test_grep_schema_is_canonical() {
        let tool = GrepTool;
        let params = tool.parameters();
        assert_eq!(params["additionalProperties"], false);
        assert_eq!(params["required"], serde_json::json!(["pattern"]));
        let examples = params["examples"].as_array().unwrap();
        assert_eq!(examples[0]["pattern"], "TODO");
        assert!(examples[0].get("query").is_none());
    }
}