1use super::{Tool, ToolResult};
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use serde::Deserialize;
7use serde_json::{Value, json};
8use std::path::PathBuf;
9use walkdir::WalkDir;
10
/// Maximum number of matches shown per search; the directory walk stops once at
/// least this many matches have been collected.
const MAX_RESULTS: usize = 50;
/// Upper bound applied to the caller-requested `context_lines`.
const MAX_CONTEXT_LINES: usize = 3;
13
/// Tool that searches text files below a workspace root for regex matches.
///
/// Match locations are reported relative to `root` whenever the matched file lies
/// inside it.
#[derive(Debug, Clone)]
pub struct CodeSearchTool {
    /// Directory searches start from; also the base that result paths are made
    /// relative to.
    root: PathBuf,
}
17
18impl Default for CodeSearchTool {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24#[allow(dead_code)]
25impl CodeSearchTool {
26 pub fn new() -> Self {
27 Self {
28 root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
29 }
30 }
31
32 pub fn with_root(root: PathBuf) -> Self {
33 Self { root }
34 }
35
36 fn should_skip(&self, path: &std::path::Path) -> bool {
37 let skip_dirs = [
38 ".git",
39 "node_modules",
40 "target",
41 "dist",
42 ".next",
43 "__pycache__",
44 ".venv",
45 "vendor",
46 ];
47 path.components()
48 .any(|c| skip_dirs.contains(&c.as_os_str().to_str().unwrap_or("")))
49 }
50
51 fn is_text_file(&self, path: &std::path::Path) -> bool {
52 let text_exts = [
53 "rs", "ts", "js", "tsx", "jsx", "py", "go", "java", "c", "cpp", "h", "hpp", "md",
54 "txt", "json", "yaml", "yml", "toml", "sh", "bash", "zsh", "html", "css", "scss",
55 ];
56 path.extension()
57 .and_then(|e| e.to_str())
58 .map(|e| text_exts.contains(&e))
59 .unwrap_or(false)
60 }
61
62 fn search_file(
63 &self,
64 path: &std::path::Path,
65 pattern: ®ex::Regex,
66 context: usize,
67 ) -> Result<Vec<Match>> {
68 let content = std::fs::read_to_string(path)?;
69 let lines: Vec<&str> = content.lines().collect();
70 let mut matches = Vec::new();
71
72 for (idx, line) in lines.iter().enumerate() {
73 if pattern.is_match(line) {
74 let start = idx.saturating_sub(context);
75 let end = (idx + context + 1).min(lines.len());
76 let context_lines: Vec<String> = lines[start..end]
77 .iter()
78 .enumerate()
79 .map(|(i, l)| {
80 let line_num = start + i + 1;
81 let marker = if start + i == idx { ">" } else { " " };
82 format!("{} {:4}: {}", marker, line_num, l)
83 })
84 .collect();
85
86 matches.push(Match {
87 path: path
88 .strip_prefix(&self.root)
89 .unwrap_or(path)
90 .to_string_lossy()
91 .to_string(),
92 line: idx + 1,
93 matched_line: line.to_string(),
94 context: context_lines.join("\n"),
95 });
96 }
97 }
98 Ok(matches)
99 }
100}
101
/// One regex hit inside a searched file.
#[derive(Debug)]
struct Match {
    /// File path, relative to the search root when the root prefix could be stripped.
    path: String,
    /// 1-based line number of the hit.
    line: usize,
    /// Full text of the matching line.
    matched_line: String,
    /// Newline-joined surrounding lines, with the hit line marked by `>`.
    context: String,
}

/// Renders a match grep-style as `path:line: matched text`.
impl std::fmt::Display for Match {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Match {
            path,
            line,
            matched_line,
            ..
        } = self;
        write!(f, "{path}:{line}: {matched_line}")
    }
}
115
116#[derive(Deserialize)]
117struct Params {
118 pattern: String,
119 #[serde(default)]
120 path: Option<String>,
121 #[serde(default)]
122 file_pattern: Option<String>,
123 #[serde(default = "default_context")]
124 context_lines: usize,
125 #[serde(default)]
126 case_sensitive: bool,
127}
128
/// Value used for `Params::context_lines` when the caller does not supply one.
fn default_context() -> usize {
    const DEFAULT_CONTEXT_LINES: usize = 2;
    DEFAULT_CONTEXT_LINES
}
132
133#[async_trait]
134impl Tool for CodeSearchTool {
135 fn id(&self) -> &str {
136 "codesearch"
137 }
138 fn name(&self) -> &str {
139 "Code Search"
140 }
141 fn description(&self) -> &str {
142 "Search for code patterns in the workspace. Supports regex."
143 }
144 fn parameters(&self) -> Value {
145 json!({
146 "type": "object",
147 "properties": {
148 "pattern": {"type": "string", "description": "Search pattern (regex supported)"},
149 "path": {"type": "string", "description": "Subdirectory to search in"},
150 "file_pattern": {"type": "string", "description": "Glob pattern for files (e.g., *.rs)"},
151 "context_lines": {"type": "integer", "default": 2, "description": "Lines of context"},
152 "case_sensitive": {"type": "boolean", "default": false}
153 },
154 "required": ["pattern"]
155 })
156 }
157
158 async fn execute(&self, params: Value) -> Result<ToolResult> {
159 let p: Params = serde_json::from_value(params).context("Invalid params")?;
160
161 let regex = regex::RegexBuilder::new(&p.pattern)
162 .case_insensitive(!p.case_sensitive)
163 .build()
164 .context("Invalid regex pattern")?;
165
166 let search_root = match &p.path {
167 Some(subpath) => self.root.join(subpath),
168 None => self.root.clone(),
169 };
170
171 let file_glob = p
172 .file_pattern
173 .as_ref()
174 .and_then(|pat| glob::Pattern::new(pat).ok());
175
176 let mut all_matches = Vec::new();
177
178 for entry in WalkDir::new(&search_root)
179 .into_iter()
180 .filter_map(|e| e.ok())
181 {
182 let path = entry.path();
183 if !path.is_file() || self.should_skip(path) || !self.is_text_file(path) {
184 continue;
185 }
186
187 if let Some(ref glob) = file_glob
188 && !glob.matches_path(path)
189 {
190 continue;
191 }
192
193 if let Ok(matches) =
194 self.search_file(path, ®ex, p.context_lines.min(MAX_CONTEXT_LINES))
195 {
196 all_matches.extend(matches);
197 if all_matches.len() >= MAX_RESULTS {
198 break;
199 }
200 }
201 }
202
203 if all_matches.is_empty() {
204 return Ok(ToolResult::success(format!(
205 "No matches found for pattern: {}",
206 p.pattern
207 )));
208 }
209
210 let output = all_matches
211 .iter()
212 .take(MAX_RESULTS)
213 .map(|m| format!("{}:{}\n{}", m.path, m.line, m.context))
214 .collect::<Vec<_>>()
215 .join("\n\n");
216
217 let truncated = all_matches.len() > MAX_RESULTS;
218 let msg = if truncated {
219 format!(
220 "Found {} matches (showing first {}):\n\n{}",
221 all_matches.len(),
222 MAX_RESULTS,
223 output
224 )
225 } else {
226 format!("Found {} matches:\n\n{}", all_matches.len(), output)
227 };
228
229 Ok(ToolResult::success(msg).with_metadata("match_count", json!(all_matches.len())))
230 }
231}