Skip to main content

codetether_agent/tool/
codesearch.rs

1//! Code Search Tool - Search code in the workspace using ripgrep-style patterns.
2
3use super::{Tool, ToolResult};
4use crate::workspace_scan::path_has_pruned_component;
5use anyhow::{Context, Result};
6use async_trait::async_trait;
7use serde::Deserialize;
8use serde_json::{Value, json};
9use std::path::PathBuf;
10use walkdir::WalkDir;
11
12const MAX_RESULTS: usize = 50;
13const MAX_CONTEXT_LINES: usize = 3;
14
15pub struct CodeSearchTool {
16    root: PathBuf,
17}
18
19impl Default for CodeSearchTool {
20    fn default() -> Self {
21        Self::new()
22    }
23}
24
25#[allow(dead_code)]
26impl CodeSearchTool {
27    pub fn new() -> Self {
28        Self {
29            root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
30        }
31    }
32
33    pub fn with_root(root: PathBuf) -> Self {
34        Self { root }
35    }
36
37    fn should_skip(&self, path: &std::path::Path) -> bool {
38        path_has_pruned_component(path)
39    }
40
41    fn is_text_file(&self, path: &std::path::Path) -> bool {
42        let text_exts = [
43            "rs", "ts", "js", "tsx", "jsx", "py", "go", "java", "c", "cpp", "h", "hpp", "md",
44            "txt", "json", "yaml", "yml", "toml", "sh", "bash", "zsh", "html", "css", "scss",
45        ];
46        path.extension()
47            .and_then(|e| e.to_str())
48            .map(|e| text_exts.contains(&e))
49            .unwrap_or(false)
50    }
51
52    fn search_file(
53        &self,
54        path: &std::path::Path,
55        pattern: &regex::Regex,
56        context: usize,
57    ) -> Result<Vec<Match>> {
58        let content = std::fs::read_to_string(path)?;
59        let lines: Vec<&str> = content.lines().collect();
60        let mut matches = Vec::new();
61
62        for (idx, line) in lines.iter().enumerate() {
63            if pattern.is_match(line) {
64                let start = idx.saturating_sub(context);
65                let end = (idx + context + 1).min(lines.len());
66                let context_lines: Vec<String> = lines[start..end]
67                    .iter()
68                    .enumerate()
69                    .map(|(i, l)| {
70                        let line_num = start + i + 1;
71                        let marker = if start + i == idx { ">" } else { " " };
72                        format!("{} {:4}: {}", marker, line_num, l)
73                    })
74                    .collect();
75
76                matches.push(Match {
77                    path: path
78                        .strip_prefix(&self.root)
79                        .unwrap_or(path)
80                        .to_string_lossy()
81                        .to_string(),
82                    line: idx + 1,
83                    matched_line: line.to_string(),
84                    context: context_lines.join("\n"),
85                });
86            }
87        }
88        Ok(matches)
89    }
90}
91
/// One line that matched the search pattern, plus its surrounding context.
#[derive(Debug)]
struct Match {
    /// File path, relative to the tool's root when the file lies under it.
    path: String,
    /// 1-based line number of the matching line.
    line: usize,
    /// Full text of the matching line.
    matched_line: String,
    /// Pre-rendered block of neighbouring lines, each prefixed with its line number.
    context: String,
}

impl std::fmt::Display for Match {
    /// Renders the match grep-style as `path:line: text`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Match {
            path,
            line,
            matched_line,
            ..
        } = self;
        write!(f, "{path}:{line}: {matched_line}")
    }
}
104}
105
106#[derive(Deserialize)]
107struct Params {
108    pattern: String,
109    #[serde(default)]
110    path: Option<String>,
111    #[serde(default)]
112    file_pattern: Option<String>,
113    #[serde(default = "default_context")]
114    context_lines: usize,
115    #[serde(default)]
116    case_sensitive: bool,
117}
118
119fn default_context() -> usize {
120    2
121}
122
123#[async_trait]
124impl Tool for CodeSearchTool {
125    fn id(&self) -> &str {
126        "codesearch"
127    }
128    fn name(&self) -> &str {
129        "Code Search"
130    }
131    fn description(&self) -> &str {
132        "Search for code patterns in the workspace. Supports regex."
133    }
134    fn parameters(&self) -> Value {
135        json!({
136            "type": "object",
137            "properties": {
138                "pattern": {"type": "string", "description": "Search pattern (regex supported)"},
139                "path": {"type": "string", "description": "Subdirectory to search in"},
140                "file_pattern": {"type": "string", "description": "Glob pattern for files (e.g., *.rs)"},
141                "context_lines": {"type": "integer", "default": 2, "description": "Lines of context"},
142                "case_sensitive": {"type": "boolean", "default": false}
143            },
144            "required": ["pattern"]
145        })
146    }
147
148    async fn execute(&self, params: Value) -> Result<ToolResult> {
149        let p: Params = serde_json::from_value(params).context("Invalid params")?;
150
151        let regex = regex::RegexBuilder::new(&p.pattern)
152            .case_insensitive(!p.case_sensitive)
153            .build()
154            .context("Invalid regex pattern")?;
155
156        let search_root = match &p.path {
157            Some(subpath) => self.root.join(subpath),
158            None => self.root.clone(),
159        };
160
161        let file_glob = p
162            .file_pattern
163            .as_ref()
164            .and_then(|pat| glob::Pattern::new(pat).ok());
165
166        let mut all_matches = Vec::new();
167
168        for entry in WalkDir::new(&search_root)
169            .into_iter()
170            .filter_entry(|entry| !self.should_skip(entry.path()))
171            .filter_map(|e| e.ok())
172        {
173            let path = entry.path();
174            if !path.is_file() || self.should_skip(path) || !self.is_text_file(path) {
175                continue;
176            }
177
178            if let Some(ref glob) = file_glob
179                && !glob.matches_path(path)
180            {
181                continue;
182            }
183
184            if let Ok(matches) =
185                self.search_file(path, &regex, p.context_lines.min(MAX_CONTEXT_LINES))
186            {
187                all_matches.extend(matches);
188                if all_matches.len() >= MAX_RESULTS {
189                    break;
190                }
191            }
192        }
193
194        if all_matches.is_empty() {
195            return Ok(ToolResult::success(format!(
196                "No matches found for pattern: {}",
197                p.pattern
198            )));
199        }
200
201        let output = all_matches
202            .iter()
203            .take(MAX_RESULTS)
204            .map(|m| format!("{}:{}\n{}", m.path, m.line, m.context))
205            .collect::<Vec<_>>()
206            .join("\n\n");
207
208        let truncated = all_matches.len() > MAX_RESULTS;
209        let msg = if truncated {
210            format!(
211                "Found {} matches (showing first {}):\n\n{}",
212                all_matches.len(),
213                MAX_RESULTS,
214                output
215            )
216        } else {
217            format!("Found {} matches:\n\n{}", all_matches.len(), output)
218        };
219
220        Ok(ToolResult::success(msg).with_metadata("match_count", json!(all_matches.len())))
221    }
222}