1use anyhow::Result;
7use regex::escape as regex_escape;
8use rig::completion::ToolDefinition;
9use rig::tool::Tool;
10use serde::{Deserialize, Serialize};
11use std::path::Path;
12use std::process::Command;
13
14use super::common::{get_current_repo, parameters_schema};
15use crate::define_tool_error;
16
// Declares `CodeSearchError`, the error type returned by this tool, via the crate-level
// `define_tool_error!` macro. NOTE(review): judging by its uses in this file it wraps a message
// (`CodeSearchError(String)`) and is constructible via `From` from other error types — confirm
// against the macro definition.
define_tool_error!(CodeSearchError);
18
19#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct CodeSearch;
22
23impl Default for CodeSearch {
24 fn default() -> Self {
25 Self
26 }
27}
28
29impl CodeSearch {
30 #[must_use]
31 pub fn new() -> Self {
32 Self
33 }
34
35 fn execute_ripgrep_search(
37 query: &str,
38 repo_path: &Path,
39 file_pattern: Option<&str>,
40 search_type: &str,
41 max_results: usize,
42 ) -> Result<Vec<SearchResult>> {
43 let mut cmd = Command::new("rg");
44
45 match search_type {
50 "function" => {
51 let escaped = regex_escape(query);
52 cmd.args(["--type", "rust", "--type", "javascript", "--type", "python"]);
53 cmd.args([
54 "-e",
55 &format!(r"fn\s+{escaped}|function\s+{escaped}|def\s+{escaped}"),
56 ]);
57 }
58 "class" => {
59 let escaped = regex_escape(query);
60 cmd.args(["--type", "rust", "--type", "javascript", "--type", "python"]);
61 cmd.args(["-e", &format!(r"struct\s+{escaped}|class\s+{escaped}")]);
62 }
63 "variable" => {
64 let escaped = regex_escape(query);
65 cmd.args([
66 "-e",
67 &format!(r"let\s+{escaped}|var\s+{escaped}|{escaped}\s*="),
68 ]);
69 }
70 "pattern" => {
71 cmd.args(["--regex-size-limit", "1M", "--dfa-size-limit", "1M"]);
73 cmd.args(["-e", query]);
74 }
75 _ => {
76 cmd.args(["-F", "-i", query]);
78 }
79 }
80
81 if let Some(pattern) = file_pattern {
83 if pattern.contains("..") {
84 return Err(anyhow::anyhow!(
85 "File pattern must not contain '..' path traversal"
86 ));
87 }
88 cmd.args(["-g", pattern]);
89 }
90
91 cmd.args(["-n", "--color", "never", "-A", "3", "-B", "1"]);
93 cmd.current_dir(repo_path);
94
95 let output = cmd.output()?;
96 let stdout = String::from_utf8_lossy(&output.stdout);
97
98 let mut results = Vec::new();
99 let mut current_file = String::new();
100 let mut line_number = 0;
101 let mut content_lines = Vec::new();
102
103 for line in stdout.lines().take(max_results * 4) {
104 if line.contains(':') && !line.starts_with('-') {
106 let parts: Vec<&str> = line.splitn(3, ':').collect();
108 if parts.len() >= 3 {
109 let file_path = parts[0].to_string();
110 if let Ok(line_num) = parts[1].parse::<usize>() {
111 let content = parts[2].to_string();
112
113 if file_path != current_file && !current_file.is_empty() {
114 results.push(SearchResult {
116 file_path: current_file.clone(),
117 line_number,
118 content: content_lines.join("\n"),
119 match_type: search_type.to_string(),
120 context_lines: content_lines.len(),
121 });
122 content_lines.clear();
123 }
124
125 current_file = file_path;
126 line_number = line_num;
127 content_lines.push(content);
128
129 if results.len() >= max_results {
130 break;
131 }
132 }
133 }
134 } else if !line.starts_with('-') && !current_file.is_empty() {
135 content_lines.push(line.to_string());
137 }
138 }
139
140 if !current_file.is_empty() && results.len() < max_results {
142 results.push(SearchResult {
143 file_path: current_file,
144 line_number,
145 content: content_lines.join("\n"),
146 match_type: search_type.to_string(),
147 context_lines: content_lines.len(),
148 });
149 }
150
151 Ok(results)
152 }
153}
154
/// One entry of the search output; a list of these is serialized into the tool's JSON response
/// under the `results` key.
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchResult {
    /// File containing the hit, as printed by ripgrep.
    pub file_path: String,
    /// Line number of a match within `file_path`.
    pub line_number: usize,
    /// Matched line(s), possibly with surrounding context lines, joined with `'\n'`.
    pub content: String,
    /// Search type that produced this hit (e.g. `"function"`, `"text"`).
    pub match_type: String,
    /// Number of lines joined into `content`.
    pub context_lines: usize,
}
163
164#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema, Default)]
166#[serde(rename_all = "lowercase")]
167pub enum SearchType {
168 Function,
170 Class,
172 Variable,
174 #[default]
176 Text,
177 Pattern,
179}
180
181impl SearchType {
182 fn as_str(&self) -> &'static str {
183 match self {
184 SearchType::Function => "function",
185 SearchType::Class => "class",
186 SearchType::Variable => "variable",
187 SearchType::Text => "text",
188 SearchType::Pattern => "pattern",
189 }
190 }
191}
192
193#[derive(Debug, Deserialize, Serialize, schemars::JsonSchema)]
194pub struct CodeSearchArgs {
195 pub query: String,
197 #[serde(default)]
199 pub search_type: SearchType,
200 #[serde(default)]
202 pub file_pattern: Option<String>,
203 #[serde(default = "default_max_results")]
205 pub max_results: usize,
206}
207
/// Value used for `CodeSearchArgs::max_results` when the caller omits it.
fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 20;
    DEFAULT_MAX_RESULTS
}
211
212impl Tool for CodeSearch {
213 const NAME: &'static str = "code_search";
214 type Error = CodeSearchError;
215 type Args = CodeSearchArgs;
216 type Output = String;
217
218 async fn definition(&self, _: String) -> ToolDefinition {
219 ToolDefinition {
220 name: "code_search".to_string(),
221 description: "Search for code patterns, functions, classes, and related files in the repository using ripgrep. Supports multiple search types and file filtering.".to_string(),
222 parameters: parameters_schema::<CodeSearchArgs>(),
223 }
224 }
225
226 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
227 let repo = get_current_repo().map_err(CodeSearchError::from)?;
228 let repo_path = repo.repo_path().clone();
229 let max_results = args.max_results.min(100); let results = Self::execute_ripgrep_search(
232 &args.query,
233 &repo_path,
234 args.file_pattern.as_deref(),
235 args.search_type.as_str(),
236 max_results,
237 )
238 .map_err(CodeSearchError::from)?;
239
240 let result = serde_json::json!({
241 "query": args.query,
242 "search_type": args.search_type.as_str(),
243 "results": results,
244 "total_found": results.len(),
245 "max_results": max_results,
246 });
247
248 serde_json::to_string_pretty(&result).map_err(|e| CodeSearchError(e.to_string()))
249 }
250}