Skip to main content

codetether_agent/tool/
codesearch.rs

1//! Code Search Tool - Search code in the workspace using ripgrep-style patterns.
2
3use super::{Tool, ToolResult};
4use anyhow::{Context, Result};
5use async_trait::async_trait;
6use serde::Deserialize;
7use serde_json::{Value, json};
8use std::path::PathBuf;
9use walkdir::WalkDir;
10
/// Number of collected matches at which a search stops walking further files,
/// and the maximum number of matches rendered into the tool output.
const MAX_RESULTS: usize = 50;
/// Upper bound applied to the caller-requested number of context lines shown
/// before and after each matching line.
const MAX_CONTEXT_LINES: usize = 3;
13
/// Tool that searches text files beneath a root directory for regex matches.
#[derive(Debug, Clone)]
pub struct CodeSearchTool {
    /// Directory that searches start from and that reported paths are made
    /// relative to.
    root: PathBuf,
}
17
18impl Default for CodeSearchTool {
19    fn default() -> Self {
20        Self::new()
21    }
22}
23
24#[allow(dead_code)]
25impl CodeSearchTool {
26    pub fn new() -> Self {
27        Self {
28            root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
29        }
30    }
31
32    pub fn with_root(root: PathBuf) -> Self {
33        Self { root }
34    }
35
36    fn should_skip(&self, path: &std::path::Path) -> bool {
37        let skip_dirs = [
38            ".git",
39            "node_modules",
40            "target",
41            "dist",
42            ".next",
43            "__pycache__",
44            ".venv",
45            "vendor",
46        ];
47        path.components()
48            .any(|c| skip_dirs.contains(&c.as_os_str().to_str().unwrap_or("")))
49    }
50
51    fn is_text_file(&self, path: &std::path::Path) -> bool {
52        let text_exts = [
53            "rs", "ts", "js", "tsx", "jsx", "py", "go", "java", "c", "cpp", "h", "hpp", "md",
54            "txt", "json", "yaml", "yml", "toml", "sh", "bash", "zsh", "html", "css", "scss",
55        ];
56        path.extension()
57            .and_then(|e| e.to_str())
58            .map(|e| text_exts.contains(&e))
59            .unwrap_or(false)
60    }
61
62    fn search_file(
63        &self,
64        path: &std::path::Path,
65        pattern: &regex::Regex,
66        context: usize,
67    ) -> Result<Vec<Match>> {
68        let content = std::fs::read_to_string(path)?;
69        let lines: Vec<&str> = content.lines().collect();
70        let mut matches = Vec::new();
71
72        for (idx, line) in lines.iter().enumerate() {
73            if pattern.is_match(line) {
74                let start = idx.saturating_sub(context);
75                let end = (idx + context + 1).min(lines.len());
76                let context_lines: Vec<String> = lines[start..end]
77                    .iter()
78                    .enumerate()
79                    .map(|(i, l)| {
80                        let line_num = start + i + 1;
81                        let marker = if start + i == idx { ">" } else { " " };
82                        format!("{} {:4}: {}", marker, line_num, l)
83                    })
84                    .collect();
85
86                matches.push(Match {
87                    path: path
88                        .strip_prefix(&self.root)
89                        .unwrap_or(path)
90                        .to_string_lossy()
91                        .to_string(),
92                    line: idx + 1,
93                    content: line.to_string(),
94                    context: context_lines.join("\n"),
95                });
96            }
97        }
98        Ok(matches)
99    }
100}
101
/// A single matching line found by a search, together with its rendered
/// surrounding context.
#[derive(Debug)]
struct Match {
    /// Path of the file containing the match.
    path: String,
    /// 1-based line number of the matching line.
    line: usize,
    /// Text of the matching line itself.
    // Not read anywhere in this file (the rendered `context` already includes
    // the matching line), hence the allow.
    #[allow(dead_code)]
    content: String,
    /// Matching line plus neighbouring lines, joined with newlines.
    context: String,
}
110
111#[derive(Deserialize)]
112struct Params {
113    pattern: String,
114    #[serde(default)]
115    path: Option<String>,
116    #[serde(default)]
117    file_pattern: Option<String>,
118    #[serde(default = "default_context")]
119    context_lines: usize,
120    #[serde(default)]
121    case_sensitive: bool,
122}
123
/// Serde default for `Params::context_lines`: two lines of context.
fn default_context() -> usize {
    const DEFAULT_CONTEXT_LINES: usize = 2;
    DEFAULT_CONTEXT_LINES
}
127
128#[async_trait]
129impl Tool for CodeSearchTool {
130    fn id(&self) -> &str {
131        "codesearch"
132    }
133    fn name(&self) -> &str {
134        "Code Search"
135    }
136    fn description(&self) -> &str {
137        "Search for code patterns in the workspace. Supports regex."
138    }
139    fn parameters(&self) -> Value {
140        json!({
141            "type": "object",
142            "properties": {
143                "pattern": {"type": "string", "description": "Search pattern (regex supported)"},
144                "path": {"type": "string", "description": "Subdirectory to search in"},
145                "file_pattern": {"type": "string", "description": "Glob pattern for files (e.g., *.rs)"},
146                "context_lines": {"type": "integer", "default": 2, "description": "Lines of context"},
147                "case_sensitive": {"type": "boolean", "default": false}
148            },
149            "required": ["pattern"]
150        })
151    }
152
153    async fn execute(&self, params: Value) -> Result<ToolResult> {
154        let p: Params = serde_json::from_value(params).context("Invalid params")?;
155
156        let regex = regex::RegexBuilder::new(&p.pattern)
157            .case_insensitive(!p.case_sensitive)
158            .build()
159            .context("Invalid regex pattern")?;
160
161        let search_root = match &p.path {
162            Some(subpath) => self.root.join(subpath),
163            None => self.root.clone(),
164        };
165
166        let file_glob = p
167            .file_pattern
168            .as_ref()
169            .and_then(|pat| glob::Pattern::new(pat).ok());
170
171        let mut all_matches = Vec::new();
172
173        for entry in WalkDir::new(&search_root)
174            .into_iter()
175            .filter_map(|e| e.ok())
176        {
177            let path = entry.path();
178            if !path.is_file() || self.should_skip(path) || !self.is_text_file(path) {
179                continue;
180            }
181
182            if let Some(ref glob) = file_glob {
183                if !glob.matches_path(path) {
184                    continue;
185                }
186            }
187
188            if let Ok(matches) =
189                self.search_file(path, &regex, p.context_lines.min(MAX_CONTEXT_LINES))
190            {
191                all_matches.extend(matches);
192                if all_matches.len() >= MAX_RESULTS {
193                    break;
194                }
195            }
196        }
197
198        if all_matches.is_empty() {
199            return Ok(ToolResult::success(format!(
200                "No matches found for pattern: {}",
201                p.pattern
202            )));
203        }
204
205        let output = all_matches
206            .iter()
207            .take(MAX_RESULTS)
208            .map(|m| format!("{}:{}\n{}", m.path, m.line, m.context))
209            .collect::<Vec<_>>()
210            .join("\n\n");
211
212        let truncated = all_matches.len() > MAX_RESULTS;
213        let msg = if truncated {
214            format!(
215                "Found {} matches (showing first {}):\n\n{}",
216                all_matches.len(),
217                MAX_RESULTS,
218                output
219            )
220        } else {
221            format!("Found {} matches:\n\n{}", all_matches.len(), output)
222        };
223
224        Ok(ToolResult::success(msg).with_metadata("match_count", json!(all_matches.len())))
225    }
226}