1use anyhow::Result;
7use rig::completion::ToolDefinition;
8use rig::tool::Tool;
9use serde::{Deserialize, Serialize};
10use std::path::Path;
11use std::process::Command;
12
13use super::common::parameters_schema;
14use crate::define_tool_error;
15
// Declares `CodeSearchError`, the error type of the `CodeSearch` tool.
// Judging from its uses in this file, it is a tuple struct wrapping a `String` message
// and is convertible via `From` from `std::io::Error` and `anyhow::Error`.
// Those impls are presumably generated by the macro; confirm in `define_tool_error!`.
define_tool_error!(CodeSearchError);
17
18#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct CodeSearch;
21
22impl Default for CodeSearch {
23 fn default() -> Self {
24 Self
25 }
26}
27
28impl CodeSearch {
29 pub fn new() -> Self {
30 Self
31 }
32
33 fn execute_ripgrep_search(
35 query: &str,
36 repo_path: &Path,
37 file_pattern: Option<&str>,
38 search_type: &str,
39 max_results: usize,
40 ) -> Result<Vec<SearchResult>> {
41 let mut cmd = Command::new("rg");
42
43 match search_type {
45 "function" => {
46 cmd.args(["--type", "rust", "--type", "javascript", "--type", "python"]);
47 cmd.args([
48 "-e",
49 &format!(r"fn\s+{query}|function\s+{query}|def\s+{query}"),
50 ]);
51 }
52 "class" => {
53 cmd.args(["--type", "rust", "--type", "javascript", "--type", "python"]);
54 cmd.args(["-e", &format!(r"struct\s+{query}|class\s+{query}")]);
55 }
56 "variable" => {
57 cmd.args(["-e", &format!(r"let\s+{query}|var\s+{query}|{query}\s*=")]);
58 }
59 "pattern" => {
60 cmd.args(["-e", query]);
61 }
62 _ => {
63 cmd.args(["-i", query]); }
65 }
66
67 if let Some(pattern) = file_pattern {
69 cmd.args(["-g", pattern]);
70 }
71
72 cmd.args(["-n", "--color", "never", "-A", "3", "-B", "1"]);
74 cmd.current_dir(repo_path);
75
76 let output = cmd.output()?;
77 let stdout = String::from_utf8_lossy(&output.stdout);
78
79 let mut results = Vec::new();
80 let mut current_file = String::new();
81 let mut line_number = 0;
82 let mut content_lines = Vec::new();
83
84 for line in stdout.lines().take(max_results * 4) {
85 if line.contains(':') && !line.starts_with('-') {
87 let parts: Vec<&str> = line.splitn(3, ':').collect();
89 if parts.len() >= 3 {
90 let file_path = parts[0].to_string();
91 if let Ok(line_num) = parts[1].parse::<usize>() {
92 let content = parts[2].to_string();
93
94 if file_path != current_file && !current_file.is_empty() {
95 results.push(SearchResult {
97 file_path: current_file.clone(),
98 line_number,
99 content: content_lines.join("\n"),
100 match_type: search_type.to_string(),
101 context_lines: content_lines.len(),
102 });
103 content_lines.clear();
104 }
105
106 current_file = file_path;
107 line_number = line_num;
108 content_lines.push(content);
109
110 if results.len() >= max_results {
111 break;
112 }
113 }
114 }
115 } else if !line.starts_with('-') && !current_file.is_empty() {
116 content_lines.push(line.to_string());
118 }
119 }
120
121 if !current_file.is_empty() && results.len() < max_results {
123 results.push(SearchResult {
124 file_path: current_file,
125 line_number,
126 content: content_lines.join("\n"),
127 match_type: search_type.to_string(),
128 context_lines: content_lines.len(),
129 });
130 }
131
132 Ok(results)
133 }
134}
135
136#[derive(Debug, Serialize, Deserialize)]
137pub struct SearchResult {
138 pub file_path: String,
139 pub line_number: usize,
140 pub content: String,
141 pub match_type: String,
142 pub context_lines: usize,
143}
144
145#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema, Default)]
147#[serde(rename_all = "lowercase")]
148pub enum SearchType {
149 Function,
151 Class,
153 Variable,
155 #[default]
157 Text,
158 Pattern,
160}
161
162impl SearchType {
163 fn as_str(&self) -> &'static str {
164 match self {
165 SearchType::Function => "function",
166 SearchType::Class => "class",
167 SearchType::Variable => "variable",
168 SearchType::Text => "text",
169 SearchType::Pattern => "pattern",
170 }
171 }
172}
173
174#[derive(Debug, Deserialize, Serialize, schemars::JsonSchema)]
175pub struct CodeSearchArgs {
176 pub query: String,
178 #[serde(default)]
180 pub search_type: SearchType,
181 #[serde(default)]
183 pub file_pattern: Option<String>,
184 #[serde(default = "default_max_results")]
186 pub max_results: usize,
187}
188
/// Serde default for `CodeSearchArgs::max_results` when the caller omits it.
fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 20;
    DEFAULT_MAX_RESULTS
}
192
193impl Tool for CodeSearch {
194 const NAME: &'static str = "code_search";
195 type Error = CodeSearchError;
196 type Args = CodeSearchArgs;
197 type Output = String;
198
199 async fn definition(&self, _: String) -> ToolDefinition {
200 ToolDefinition {
201 name: "code_search".to_string(),
202 description: "Search for code patterns, functions, classes, and related files in the repository using ripgrep. Supports multiple search types and file filtering.".to_string(),
203 parameters: parameters_schema::<CodeSearchArgs>(),
204 }
205 }
206
207 async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
208 let current_dir = std::env::current_dir().map_err(CodeSearchError::from)?;
209 let max_results = args.max_results.min(100); let results = Self::execute_ripgrep_search(
212 &args.query,
213 ¤t_dir,
214 args.file_pattern.as_deref(),
215 args.search_type.as_str(),
216 max_results,
217 )
218 .map_err(CodeSearchError::from)?;
219
220 let result = serde_json::json!({
221 "query": args.query,
222 "search_type": args.search_type.as_str(),
223 "results": results,
224 "total_found": results.len(),
225 "max_results": max_results,
226 });
227
228 serde_json::to_string_pretty(&result).map_err(|e| CodeSearchError(e.to_string()))
229 }
230}