1use super::{Tool, ToolResult};
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use serde::Deserialize;
7use serde_json::{Value, json};
8use std::path::PathBuf;
9use walkdir::WalkDir;
10
/// Upper bound on the number of matches rendered into a single tool response.
const MAX_RESULTS: usize = 50;
/// Upper bound applied to the caller-requested number of context lines per match.
const MAX_CONTEXT_LINES: usize = 3;
13
/// Tool that searches text files below a root directory for lines matching a regex.
pub struct CodeSearchTool {
    /// Directory the search is anchored at; reported match paths are made
    /// relative to it when possible.
    root: PathBuf,
}
17
18impl Default for CodeSearchTool {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24#[allow(dead_code)]
25impl CodeSearchTool {
26 pub fn new() -> Self {
27 Self {
28 root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
29 }
30 }
31
32 pub fn with_root(root: PathBuf) -> Self {
33 Self { root }
34 }
35
36 fn should_skip(&self, path: &std::path::Path) -> bool {
37 let skip_dirs = [
38 ".git",
39 "node_modules",
40 "target",
41 "dist",
42 ".next",
43 "__pycache__",
44 ".venv",
45 "vendor",
46 ];
47 path.components()
48 .any(|c| skip_dirs.contains(&c.as_os_str().to_str().unwrap_or("")))
49 }
50
51 fn is_text_file(&self, path: &std::path::Path) -> bool {
52 let text_exts = [
53 "rs", "ts", "js", "tsx", "jsx", "py", "go", "java", "c", "cpp", "h", "hpp", "md",
54 "txt", "json", "yaml", "yml", "toml", "sh", "bash", "zsh", "html", "css", "scss",
55 ];
56 path.extension()
57 .and_then(|e| e.to_str())
58 .map(|e| text_exts.contains(&e))
59 .unwrap_or(false)
60 }
61
62 fn search_file(
63 &self,
64 path: &std::path::Path,
65 pattern: ®ex::Regex,
66 context: usize,
67 ) -> Result<Vec<Match>> {
68 let content = std::fs::read_to_string(path)?;
69 let lines: Vec<&str> = content.lines().collect();
70 let mut matches = Vec::new();
71
72 for (idx, line) in lines.iter().enumerate() {
73 if pattern.is_match(line) {
74 let start = idx.saturating_sub(context);
75 let end = (idx + context + 1).min(lines.len());
76 let context_lines: Vec<String> = lines[start..end]
77 .iter()
78 .enumerate()
79 .map(|(i, l)| {
80 let line_num = start + i + 1;
81 let marker = if start + i == idx { ">" } else { " " };
82 format!("{} {:4}: {}", marker, line_num, l)
83 })
84 .collect();
85
86 matches.push(Match {
87 path: path
88 .strip_prefix(&self.root)
89 .unwrap_or(path)
90 .to_string_lossy()
91 .to_string(),
92 line: idx + 1,
93 content: line.to_string(),
94 context: context_lines.join("\n"),
95 });
96 }
97 }
98 Ok(matches)
99 }
100}
101
/// A single matching line found by `CodeSearchTool::search_file`.
#[derive(Debug)]
struct Match {
    /// Path of the file, relative to the tool root when the prefix can be stripped.
    path: String,
    /// 1-based line number of the matching line.
    line: usize,
    /// Text of the matching line itself; stored but not read by the code in this file.
    #[allow(dead_code)]
    content: String,
    /// Pre-rendered block of the matching line plus surrounding lines, one per row.
    context: String,
}
110
/// Parameters accepted by the tool's `execute`, deserialized from JSON.
#[derive(Deserialize)]
struct Params {
    /// Regular expression to search for (required).
    pattern: String,
    /// Optional path joined onto the tool root to narrow the search.
    #[serde(default)]
    path: Option<String>,
    /// Optional glob used to filter which files are searched.
    #[serde(default)]
    file_pattern: Option<String>,
    /// Requested number of context lines; defaults to `default_context()`.
    #[serde(default = "default_context")]
    context_lines: usize,
    /// Whether matching is case-sensitive; defaults to `false` when omitted.
    #[serde(default)]
    case_sensitive: bool,
}
123
/// Serde default for `Params::context_lines`: two lines of context around
/// each match when the caller does not specify a value.
fn default_context() -> usize {
    const DEFAULT_CONTEXT_LINES: usize = 2;
    DEFAULT_CONTEXT_LINES
}
127
128#[async_trait]
129impl Tool for CodeSearchTool {
130 fn id(&self) -> &str {
131 "codesearch"
132 }
133 fn name(&self) -> &str {
134 "Code Search"
135 }
136 fn description(&self) -> &str {
137 "Search for code patterns in the workspace. Supports regex."
138 }
139 fn parameters(&self) -> Value {
140 json!({
141 "type": "object",
142 "properties": {
143 "pattern": {"type": "string", "description": "Search pattern (regex supported)"},
144 "path": {"type": "string", "description": "Subdirectory to search in"},
145 "file_pattern": {"type": "string", "description": "Glob pattern for files (e.g., *.rs)"},
146 "context_lines": {"type": "integer", "default": 2, "description": "Lines of context"},
147 "case_sensitive": {"type": "boolean", "default": false}
148 },
149 "required": ["pattern"]
150 })
151 }
152
153 async fn execute(&self, params: Value) -> Result<ToolResult> {
154 let p: Params = serde_json::from_value(params).context("Invalid params")?;
155
156 let regex = regex::RegexBuilder::new(&p.pattern)
157 .case_insensitive(!p.case_sensitive)
158 .build()
159 .context("Invalid regex pattern")?;
160
161 let search_root = match &p.path {
162 Some(subpath) => self.root.join(subpath),
163 None => self.root.clone(),
164 };
165
166 let file_glob = p
167 .file_pattern
168 .as_ref()
169 .and_then(|pat| glob::Pattern::new(pat).ok());
170
171 let mut all_matches = Vec::new();
172
173 for entry in WalkDir::new(&search_root)
174 .into_iter()
175 .filter_map(|e| e.ok())
176 {
177 let path = entry.path();
178 if !path.is_file() || self.should_skip(path) || !self.is_text_file(path) {
179 continue;
180 }
181
182 if let Some(ref glob) = file_glob {
183 if !glob.matches_path(path) {
184 continue;
185 }
186 }
187
188 if let Ok(matches) =
189 self.search_file(path, ®ex, p.context_lines.min(MAX_CONTEXT_LINES))
190 {
191 all_matches.extend(matches);
192 if all_matches.len() >= MAX_RESULTS {
193 break;
194 }
195 }
196 }
197
198 if all_matches.is_empty() {
199 return Ok(ToolResult::success(format!(
200 "No matches found for pattern: {}",
201 p.pattern
202 )));
203 }
204
205 let output = all_matches
206 .iter()
207 .take(MAX_RESULTS)
208 .map(|m| format!("{}:{}\n{}", m.path, m.line, m.context))
209 .collect::<Vec<_>>()
210 .join("\n\n");
211
212 let truncated = all_matches.len() > MAX_RESULTS;
213 let msg = if truncated {
214 format!(
215 "Found {} matches (showing first {}):\n\n{}",
216 all_matches.len(),
217 MAX_RESULTS,
218 output
219 )
220 } else {
221 format!("Found {} matches:\n\n{}", all_matches.len(), output)
222 };
223
224 Ok(ToolResult::success(msg).with_metadata("match_count", json!(all_matches.len())))
225 }
226}