Skip to main content

git_iris/agents/tools/
code_search.rs

//! Code search tool
//!
//! This tool provides Iris with the ability to search for code patterns,
//! functions, classes, and related files in the repository using ripgrep (`rg`).
5
6use anyhow::Result;
7use regex::escape as regex_escape;
8use rig::completion::ToolDefinition;
9use rig::tool::Tool;
10use serde::{Deserialize, Serialize};
11use std::path::Path;
12use std::process::Command;
13
14use super::common::{get_current_repo, parameters_schema};
15use crate::define_tool_error;
16
17define_tool_error!(CodeSearchError);
18
19/// Code search tool for finding related files and functions
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct CodeSearch;
22
23impl Default for CodeSearch {
24    fn default() -> Self {
25        Self
26    }
27}
28
29impl CodeSearch {
30    #[must_use]
31    pub fn new() -> Self {
32        Self
33    }
34
35    /// Execute a ripgrep search for patterns
36    fn execute_ripgrep_search(
37        query: &str,
38        repo_path: &Path,
39        file_pattern: Option<&str>,
40        search_type: &str,
41        max_results: usize,
42    ) -> Result<Vec<SearchResult>> {
43        let mut cmd = Command::new("rg");
44
45        // Configure ripgrep based on search type
46        // For structured types (function/class/variable), escape the query to prevent
47        // regex injection — the surrounding pattern provides the regex structure.
48        // For "pattern" type, the user explicitly wants regex, so we add a timeout instead.
49        match search_type {
50            "function" => {
51                let escaped = regex_escape(query);
52                cmd.args(["--type", "rust", "--type", "javascript", "--type", "python"]);
53                cmd.args([
54                    "-e",
55                    &format!(r"fn\s+{escaped}|function\s+{escaped}|def\s+{escaped}"),
56                ]);
57            }
58            "class" => {
59                let escaped = regex_escape(query);
60                cmd.args(["--type", "rust", "--type", "javascript", "--type", "python"]);
61                cmd.args(["-e", &format!(r"struct\s+{escaped}|class\s+{escaped}")]);
62            }
63            "variable" => {
64                let escaped = regex_escape(query);
65                cmd.args([
66                    "-e",
67                    &format!(r"let\s+{escaped}|var\s+{escaped}|{escaped}\s*="),
68                ]);
69            }
70            "pattern" => {
71                // User-supplied regex — enforce a timeout to prevent ReDoS
72                cmd.args(["--regex-size-limit", "1M", "--dfa-size-limit", "1M"]);
73                cmd.args(["-e", query]);
74            }
75            _ => {
76                // Fixed-string literal search — -F disables regex interpretation
77                cmd.args(["-F", "-i", query]);
78            }
79        }
80
81        // Add file pattern if specified (reject path traversal)
82        if let Some(pattern) = file_pattern {
83            if pattern.contains("..") {
84                return Err(anyhow::anyhow!(
85                    "File pattern must not contain '..' path traversal"
86                ));
87            }
88            cmd.args(["-g", pattern]);
89        }
90
91        // Limit results and add context
92        cmd.args(["-n", "--color", "never", "-A", "3", "-B", "1"]);
93        cmd.current_dir(repo_path);
94
95        let output = cmd.output()?;
96        let stdout = String::from_utf8_lossy(&output.stdout);
97
98        let mut results = Vec::new();
99        let mut current_file = String::new();
100        let mut line_number = 0;
101        let mut content_lines = Vec::new();
102
103        for line in stdout.lines().take(max_results * 4) {
104            // rough estimate with context
105            if line.contains(':') && !line.starts_with('-') {
106                // Parse file:line:content format
107                let parts: Vec<&str> = line.splitn(3, ':').collect();
108                if parts.len() >= 3 {
109                    let file_path = parts[0].to_string();
110                    if let Ok(line_num) = parts[1].parse::<usize>() {
111                        let content = parts[2].to_string();
112
113                        if file_path != current_file && !current_file.is_empty() {
114                            // Finalize previous result
115                            results.push(SearchResult {
116                                file_path: current_file.clone(),
117                                line_number,
118                                content: content_lines.join("\n"),
119                                match_type: search_type.to_string(),
120                                context_lines: content_lines.len(),
121                            });
122                            content_lines.clear();
123                        }
124
125                        current_file = file_path;
126                        line_number = line_num;
127                        content_lines.push(content);
128
129                        if results.len() >= max_results {
130                            break;
131                        }
132                    }
133                }
134            } else if !line.starts_with('-') && !current_file.is_empty() {
135                // Context line
136                content_lines.push(line.to_string());
137            }
138        }
139
140        // Add final result
141        if !current_file.is_empty() && results.len() < max_results {
142            results.push(SearchResult {
143                file_path: current_file,
144                line_number,
145                content: content_lines.join("\n"),
146                match_type: search_type.to_string(),
147                context_lines: content_lines.len(),
148            });
149        }
150
151        Ok(results)
152    }
153}
154
155#[derive(Debug, Serialize, Deserialize)]
156pub struct SearchResult {
157    pub file_path: String,
158    pub line_number: usize,
159    pub content: String,
160    pub match_type: String,
161    pub context_lines: usize,
162}
163
164/// Search type for code search
165#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema, Default)]
166#[serde(rename_all = "lowercase")]
167pub enum SearchType {
168    /// Search for function definitions
169    Function,
170    /// Search for class/struct definitions
171    Class,
172    /// Search for variable assignments
173    Variable,
174    /// General text search (case-insensitive)
175    #[default]
176    Text,
177    /// Regex pattern search
178    Pattern,
179}
180
181impl SearchType {
182    fn as_str(&self) -> &'static str {
183        match self {
184            SearchType::Function => "function",
185            SearchType::Class => "class",
186            SearchType::Variable => "variable",
187            SearchType::Text => "text",
188            SearchType::Pattern => "pattern",
189        }
190    }
191}
192
193#[derive(Debug, Deserialize, Serialize, schemars::JsonSchema)]
194pub struct CodeSearchArgs {
195    /// Search query - function name, class name, variable, or pattern
196    pub query: String,
197    /// Type of search to perform
198    #[serde(default)]
199    pub search_type: SearchType,
200    /// Optional file glob pattern to limit scope (e.g., "*.rs", "*.js")
201    #[serde(default)]
202    pub file_pattern: Option<String>,
203    /// Maximum results to return (default: 20, max: 100)
204    #[serde(default = "default_max_results")]
205    pub max_results: usize,
206}
207
208fn default_max_results() -> usize {
209    20
210}
211
212impl Tool for CodeSearch {
213    const NAME: &'static str = "code_search";
214    type Error = CodeSearchError;
215    type Args = CodeSearchArgs;
216    type Output = String;
217
218    async fn definition(&self, _: String) -> ToolDefinition {
219        ToolDefinition {
220            name: "code_search".to_string(),
221            description: "Search for code patterns, functions, classes, and related files in the repository using ripgrep. Supports multiple search types and file filtering.".to_string(),
222            parameters: parameters_schema::<CodeSearchArgs>(),
223        }
224    }
225
226    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
227        let repo = get_current_repo().map_err(CodeSearchError::from)?;
228        let repo_path = repo.repo_path().clone();
229        let max_results = args.max_results.min(100); // Cap at 100
230
231        let results = Self::execute_ripgrep_search(
232            &args.query,
233            &repo_path,
234            args.file_pattern.as_deref(),
235            args.search_type.as_str(),
236            max_results,
237        )
238        .map_err(CodeSearchError::from)?;
239
240        let result = serde_json::json!({
241            "query": args.query,
242            "search_type": args.search_type.as_str(),
243            "results": results,
244            "total_found": results.len(),
245            "max_results": max_results,
246        });
247
248        serde_json::to_string_pretty(&result).map_err(|e| CodeSearchError(e.to_string()))
249    }
250}