1use super::{Tool, ToolResult};
4use crate::workspace_scan::path_has_pruned_component;
5use anyhow::{Context, Result};
6use async_trait::async_trait;
7use serde::Deserialize;
8use serde_json::{Value, json};
9use std::path::PathBuf;
10use walkdir::WalkDir;
11
/// Match-count threshold for a single search: the file walk stops once at
/// least this many matches have been collected, and at most this many are
/// rendered in the tool output.
const MAX_RESULTS: usize = 50;
/// Upper bound applied to the caller-supplied `context_lines` value before
/// it is handed to the per-file search.
const MAX_CONTEXT_LINES: usize = 3;
14
/// Regex-based code search tool over a directory tree.
///
/// All searches are resolved relative to `root`, and reported paths are made
/// relative to it when possible.
#[derive(Debug, Clone)]
pub struct CodeSearchTool {
    /// Directory that searches start from and that result paths are shown
    /// relative to.
    root: PathBuf,
}
18
19impl Default for CodeSearchTool {
20 fn default() -> Self {
21 Self::new()
22 }
23}
24
25#[allow(dead_code)]
26impl CodeSearchTool {
27 pub fn new() -> Self {
28 Self {
29 root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
30 }
31 }
32
33 pub fn with_root(root: PathBuf) -> Self {
34 Self { root }
35 }
36
37 fn should_skip(&self, path: &std::path::Path) -> bool {
38 path_has_pruned_component(path)
39 }
40
41 fn is_text_file(&self, path: &std::path::Path) -> bool {
42 let text_exts = [
43 "rs", "ts", "js", "tsx", "jsx", "py", "go", "java", "c", "cpp", "h", "hpp", "md",
44 "txt", "json", "yaml", "yml", "toml", "sh", "bash", "zsh", "html", "css", "scss",
45 ];
46 path.extension()
47 .and_then(|e| e.to_str())
48 .map(|e| text_exts.contains(&e))
49 .unwrap_or(false)
50 }
51
52 fn search_file(
53 &self,
54 path: &std::path::Path,
55 pattern: ®ex::Regex,
56 context: usize,
57 ) -> Result<Vec<Match>> {
58 let content = std::fs::read_to_string(path)?;
59 let lines: Vec<&str> = content.lines().collect();
60 let mut matches = Vec::new();
61
62 for (idx, line) in lines.iter().enumerate() {
63 if pattern.is_match(line) {
64 let start = idx.saturating_sub(context);
65 let end = (idx + context + 1).min(lines.len());
66 let context_lines: Vec<String> = lines[start..end]
67 .iter()
68 .enumerate()
69 .map(|(i, l)| {
70 let line_num = start + i + 1;
71 let marker = if start + i == idx { ">" } else { " " };
72 format!("{} {:4}: {}", marker, line_num, l)
73 })
74 .collect();
75
76 matches.push(Match {
77 path: path
78 .strip_prefix(&self.root)
79 .unwrap_or(path)
80 .to_string_lossy()
81 .to_string(),
82 line: idx + 1,
83 matched_line: line.to_string(),
84 context: context_lines.join("\n"),
85 });
86 }
87 }
88 Ok(matches)
89 }
90}
91
/// One matching line found in a file, together with its rendered context.
#[derive(Debug)]
struct Match {
    /// Path of the file the match was found in.
    path: String,
    /// 1-based line number of the matching line.
    line: usize,
    /// Full text of the matching line.
    matched_line: String,
    /// Pre-rendered block of the lines surrounding the match.
    context: String,
}

/// Compact one-line rendering: `path:line: matched_line`.
impl std::fmt::Display for Match {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `context` is deliberately left out of the one-line form.
        let Self {
            path,
            line,
            matched_line,
            ..
        } = self;
        write!(f, "{}:{}: {}", path, line, matched_line)
    }
}
105
/// Arguments of the `codesearch` tool, deserialized from the JSON value
/// passed to `execute`.
#[derive(Deserialize)]
struct Params {
    /// Regular expression to search for (required).
    pattern: String,
    /// Optional path joined onto the tool root to choose where the walk starts.
    #[serde(default)]
    path: Option<String>,
    /// Optional glob used to filter which files are searched.
    #[serde(default)]
    file_pattern: Option<String>,
    /// Requested lines of context around each match; defaults to the value
    /// returned by `default_context` when omitted.
    #[serde(default = "default_context")]
    context_lines: usize,
    /// Whether matching is case-sensitive; `false` when omitted, which makes
    /// the search case-insensitive.
    #[serde(default)]
    case_sensitive: bool,
}
118
/// Serde default for `Params::context_lines`, used when the caller omits it.
fn default_context() -> usize {
    const DEFAULT_CONTEXT_LINES: usize = 2;
    DEFAULT_CONTEXT_LINES
}
122
123#[async_trait]
124impl Tool for CodeSearchTool {
125 fn id(&self) -> &str {
126 "codesearch"
127 }
128 fn name(&self) -> &str {
129 "Code Search"
130 }
131 fn description(&self) -> &str {
132 "Search for code patterns in the workspace. Supports regex."
133 }
134 fn parameters(&self) -> Value {
135 json!({
136 "type": "object",
137 "properties": {
138 "pattern": {"type": "string", "description": "Search pattern (regex supported)"},
139 "path": {"type": "string", "description": "Subdirectory to search in"},
140 "file_pattern": {"type": "string", "description": "Glob pattern for files (e.g., *.rs)"},
141 "context_lines": {"type": "integer", "default": 2, "description": "Lines of context"},
142 "case_sensitive": {"type": "boolean", "default": false}
143 },
144 "required": ["pattern"]
145 })
146 }
147
148 async fn execute(&self, params: Value) -> Result<ToolResult> {
149 let p: Params = serde_json::from_value(params).context("Invalid params")?;
150
151 let regex = regex::RegexBuilder::new(&p.pattern)
152 .case_insensitive(!p.case_sensitive)
153 .build()
154 .context("Invalid regex pattern")?;
155
156 let search_root = match &p.path {
157 Some(subpath) => self.root.join(subpath),
158 None => self.root.clone(),
159 };
160
161 let file_glob = p
162 .file_pattern
163 .as_ref()
164 .and_then(|pat| glob::Pattern::new(pat).ok());
165
166 let mut all_matches = Vec::new();
167
168 for entry in WalkDir::new(&search_root)
169 .into_iter()
170 .filter_entry(|entry| !self.should_skip(entry.path()))
171 .filter_map(|e| e.ok())
172 {
173 let path = entry.path();
174 if !path.is_file() || self.should_skip(path) || !self.is_text_file(path) {
175 continue;
176 }
177
178 if let Some(ref glob) = file_glob
179 && !glob.matches_path(path)
180 {
181 continue;
182 }
183
184 if let Ok(matches) =
185 self.search_file(path, ®ex, p.context_lines.min(MAX_CONTEXT_LINES))
186 {
187 all_matches.extend(matches);
188 if all_matches.len() >= MAX_RESULTS {
189 break;
190 }
191 }
192 }
193
194 if all_matches.is_empty() {
195 return Ok(ToolResult::success(format!(
196 "No matches found for pattern: {}",
197 p.pattern
198 )));
199 }
200
201 let output = all_matches
202 .iter()
203 .take(MAX_RESULTS)
204 .map(|m| format!("{}:{}\n{}", m.path, m.line, m.context))
205 .collect::<Vec<_>>()
206 .join("\n\n");
207
208 let truncated = all_matches.len() > MAX_RESULTS;
209 let msg = if truncated {
210 format!(
211 "Found {} matches (showing first {}):\n\n{}",
212 all_matches.len(),
213 MAX_RESULTS,
214 output
215 )
216 } else {
217 format!("Found {} matches:\n\n{}", all_matches.len(), output)
218 };
219
220 Ok(ToolResult::success(msg).with_metadata("match_count", json!(all_matches.len())))
221 }
222}