git_iris/agents/tools/
code_search.rs

1//! Code search tool
2//!
3//! This tool provides Iris with the ability to search for code patterns,
4//! functions, classes, and related files in the repository.
5
6use anyhow::Result;
7use rig::completion::ToolDefinition;
8use rig::tool::Tool;
9use serde::{Deserialize, Serialize};
10use std::path::Path;
11use std::process::Command;
12
13use super::common::parameters_schema;
14use crate::define_tool_error;
15
// Declares `CodeSearchError`, the error type this tool reports through `Tool::Error`.
// Its definition comes from the crate's `define_tool_error!` macro; below it is built
// via `From` (I/O and anyhow errors) and directly from a message `String`.
define_tool_error!(CodeSearchError);
17
18/// Code search tool for finding related files and functions
19#[derive(Debug, Clone, Serialize, Deserialize)]
20pub struct CodeSearch;
21
22impl Default for CodeSearch {
23    fn default() -> Self {
24        Self
25    }
26}
27
28impl CodeSearch {
29    pub fn new() -> Self {
30        Self
31    }
32
33    /// Execute a ripgrep search for patterns
34    fn execute_ripgrep_search(
35        query: &str,
36        repo_path: &Path,
37        file_pattern: Option<&str>,
38        search_type: &str,
39        max_results: usize,
40    ) -> Result<Vec<SearchResult>> {
41        let mut cmd = Command::new("rg");
42
43        // Configure ripgrep based on search type
44        match search_type {
45            "function" => {
46                cmd.args(["--type", "rust", "--type", "javascript", "--type", "python"]);
47                cmd.args([
48                    "-e",
49                    &format!(r"fn\s+{query}|function\s+{query}|def\s+{query}"),
50                ]);
51            }
52            "class" => {
53                cmd.args(["--type", "rust", "--type", "javascript", "--type", "python"]);
54                cmd.args(["-e", &format!(r"struct\s+{query}|class\s+{query}")]);
55            }
56            "variable" => {
57                cmd.args(["-e", &format!(r"let\s+{query}|var\s+{query}|{query}\s*=")]);
58            }
59            "pattern" => {
60                cmd.args(["-e", query]);
61            }
62            _ => {
63                cmd.args(["-i", query]); // case-insensitive text search
64            }
65        }
66
67        // Add file pattern if specified
68        if let Some(pattern) = file_pattern {
69            cmd.args(["-g", pattern]);
70        }
71
72        // Limit results and add context
73        cmd.args(["-n", "--color", "never", "-A", "3", "-B", "1"]);
74        cmd.current_dir(repo_path);
75
76        let output = cmd.output()?;
77        let stdout = String::from_utf8_lossy(&output.stdout);
78
79        let mut results = Vec::new();
80        let mut current_file = String::new();
81        let mut line_number = 0;
82        let mut content_lines = Vec::new();
83
84        for line in stdout.lines().take(max_results * 4) {
85            // rough estimate with context
86            if line.contains(':') && !line.starts_with('-') {
87                // Parse file:line:content format
88                let parts: Vec<&str> = line.splitn(3, ':').collect();
89                if parts.len() >= 3 {
90                    let file_path = parts[0].to_string();
91                    if let Ok(line_num) = parts[1].parse::<usize>() {
92                        let content = parts[2].to_string();
93
94                        if file_path != current_file && !current_file.is_empty() {
95                            // Finalize previous result
96                            results.push(SearchResult {
97                                file_path: current_file.clone(),
98                                line_number,
99                                content: content_lines.join("\n"),
100                                match_type: search_type.to_string(),
101                                context_lines: content_lines.len(),
102                            });
103                            content_lines.clear();
104                        }
105
106                        current_file = file_path;
107                        line_number = line_num;
108                        content_lines.push(content);
109
110                        if results.len() >= max_results {
111                            break;
112                        }
113                    }
114                }
115            } else if !line.starts_with('-') && !current_file.is_empty() {
116                // Context line
117                content_lines.push(line.to_string());
118            }
119        }
120
121        // Add final result
122        if !current_file.is_empty() && results.len() < max_results {
123            results.push(SearchResult {
124                file_path: current_file,
125                line_number,
126                content: content_lines.join("\n"),
127                match_type: search_type.to_string(),
128                context_lines: content_lines.len(),
129            });
130        }
131
132        Ok(results)
133    }
134}
135
/// One code search hit: the output ripgrep printed for a single file.
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchResult {
    /// File the hit was found in, as reported by ripgrep.
    pub file_path: String,
    /// Line number of a matching line in `file_path`.
    pub line_number: usize,
    /// Matched line(s) plus any captured context lines, joined with `\n`.
    pub content: String,
    /// Search type that produced this hit (e.g. `"function"`, `"text"`).
    pub match_type: String,
    /// Number of lines joined into `content` (match lines included).
    pub context_lines: usize,
}
144
// NOTE: the `///` doc comments on this enum and its variants are collected by the
// `schemars::JsonSchema` derive into the tool's parameter schema. Treat them as
// user-facing text: editing them changes what the tool advertises.
// Serialized as lowercase strings ("function", "class", "variable", "text",
// "pattern"), the same strings `as_str` returns.
/// Search type for code search
#[derive(Debug, Clone, Serialize, Deserialize, schemars::JsonSchema, Default)]
#[serde(rename_all = "lowercase")]
pub enum SearchType {
    /// Search for function definitions
    Function,
    /// Search for class/struct definitions
    Class,
    /// Search for variable assignments
    Variable,
    // Used when `search_type` is omitted (via `#[serde(default)]` on the args field).
    /// General text search (case-insensitive)
    #[default]
    Text,
    /// Regex pattern search
    Pattern,
}
161
162impl SearchType {
163    fn as_str(&self) -> &'static str {
164        match self {
165            SearchType::Function => "function",
166            SearchType::Class => "class",
167            SearchType::Variable => "variable",
168            SearchType::Text => "text",
169            SearchType::Pattern => "pattern",
170        }
171    }
172}
173
// Arguments deserialized for each `code_search` tool call (`Tool::Args`).
// The `///` field docs below are collected by the `schemars::JsonSchema` derive into
// the tool's parameter schema, so they double as user-facing descriptions.
#[derive(Debug, Deserialize, Serialize, schemars::JsonSchema)]
pub struct CodeSearchArgs {
    /// Search query - function name, class name, variable, or pattern
    pub query: String,
    // Falls back to `SearchType::Text` (the enum's `#[default]`) when omitted.
    /// Type of search to perform
    #[serde(default)]
    pub search_type: SearchType,
    /// Optional file glob pattern to limit scope (e.g., "*.rs", "*.js")
    #[serde(default)]
    pub file_pattern: Option<String>,
    // Defaults to `default_max_results()`; the tool's `call` clamps it to at most 100.
    /// Maximum results to return (default: 20, max: 100)
    #[serde(default = "default_max_results")]
    pub max_results: usize,
}
188
/// Serde fallback for `CodeSearchArgs::max_results` when the field is omitted.
fn default_max_results() -> usize {
    const DEFAULT_MAX_RESULTS: usize = 20;
    DEFAULT_MAX_RESULTS
}
192
193impl Tool for CodeSearch {
194    const NAME: &'static str = "code_search";
195    type Error = CodeSearchError;
196    type Args = CodeSearchArgs;
197    type Output = String;
198
199    async fn definition(&self, _: String) -> ToolDefinition {
200        ToolDefinition {
201            name: "code_search".to_string(),
202            description: "Search for code patterns, functions, classes, and related files in the repository using ripgrep. Supports multiple search types and file filtering.".to_string(),
203            parameters: parameters_schema::<CodeSearchArgs>(),
204        }
205    }
206
207    async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
208        let current_dir = std::env::current_dir().map_err(CodeSearchError::from)?;
209        let max_results = args.max_results.min(100); // Cap at 100
210
211        let results = Self::execute_ripgrep_search(
212            &args.query,
213            &current_dir,
214            args.file_pattern.as_deref(),
215            args.search_type.as_str(),
216            max_results,
217        )
218        .map_err(CodeSearchError::from)?;
219
220        let result = serde_json::json!({
221            "query": args.query,
222            "search_type": args.search_type.as_str(),
223            "results": results,
224            "total_found": results.len(),
225            "max_results": max_results,
226        });
227
228        serde_json::to_string_pretty(&result).map_err(|e| CodeSearchError(e.to_string()))
229    }
230}