Skip to main content

codetether_agent/tool/
codesearch.rs

1//! Code Search Tool - Search code in the workspace using ripgrep-style patterns.
2
3use anyhow::{Context, Result};
4use async_trait::async_trait;
5use serde::Deserialize;
6use serde_json::{json, Value};
7use std::path::PathBuf;
8use walkdir::WalkDir;
9use super::{Tool, ToolResult};
10
/// Maximum number of matches rendered into the tool output; the file walk
/// also stops collecting once at least this many matches have been gathered.
const MAX_RESULTS: usize = 50;
/// Upper bound applied to the caller-supplied `context_lines` value before
/// it is passed to the per-file search.
const MAX_CONTEXT_LINES: usize = 3;
13
/// Tool that searches text files under a workspace root for a regex pattern.
///
/// Derives `Debug` and `Clone` so the tool can be logged and duplicated
/// cheaply; its only state is the root path.
#[derive(Debug, Clone)]
pub struct CodeSearchTool {
    /// Directory that searches start from and that reported paths are made
    /// relative to.
    root: PathBuf,
}
17
18impl Default for CodeSearchTool {
19    fn default() -> Self { Self::new() }
20}
21
22#[allow(dead_code)]
23impl CodeSearchTool {
24    pub fn new() -> Self {
25        Self { root: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")) }
26    }
27
28    pub fn with_root(root: PathBuf) -> Self {
29        Self { root }
30    }
31
32    fn should_skip(&self, path: &std::path::Path) -> bool {
33        let skip_dirs = [".git", "node_modules", "target", "dist", ".next", "__pycache__", ".venv", "vendor"];
34        path.components().any(|c| skip_dirs.contains(&c.as_os_str().to_str().unwrap_or("")))
35    }
36
37    fn is_text_file(&self, path: &std::path::Path) -> bool {
38        let text_exts = ["rs", "ts", "js", "tsx", "jsx", "py", "go", "java", "c", "cpp", "h", "hpp", 
39                         "md", "txt", "json", "yaml", "yml", "toml", "sh", "bash", "zsh", "html", "css", "scss"];
40        path.extension().and_then(|e| e.to_str()).map(|e| text_exts.contains(&e)).unwrap_or(false)
41    }
42
43    fn search_file(&self, path: &std::path::Path, pattern: &regex::Regex, context: usize) -> Result<Vec<Match>> {
44        let content = std::fs::read_to_string(path)?;
45        let lines: Vec<&str> = content.lines().collect();
46        let mut matches = Vec::new();
47
48        for (idx, line) in lines.iter().enumerate() {
49            if pattern.is_match(line) {
50                let start = idx.saturating_sub(context);
51                let end = (idx + context + 1).min(lines.len());
52                let context_lines: Vec<String> = lines[start..end].iter().enumerate().map(|(i, l)| {
53                    let line_num = start + i + 1;
54                    let marker = if start + i == idx { ">" } else { " " };
55                    format!("{} {:4}: {}", marker, line_num, l)
56                }).collect();
57                
58                matches.push(Match {
59                    path: path.strip_prefix(&self.root).unwrap_or(path).to_string_lossy().to_string(),
60                    line: idx + 1,
61                    content: line.to_string(),
62                    context: context_lines.join("\n"),
63                });
64            }
65        }
66        Ok(matches)
67    }
68}
69
/// A single matching line found by the per-file search.
#[derive(Debug)]
struct Match {
    /// File path, relative to the tool root when the file lies under it.
    path: String,
    /// 1-based line number of the matching line.
    line: usize,
    /// Text of the matching line itself; stored but not read in this file,
    /// hence the lint allowance.
    #[allow(dead_code)]
    content: String,
    /// Pre-rendered context: the matching line plus its neighbours, one per
    /// line, each with a marker and line number.
    context: String,
}
78
/// Arguments accepted by the tool, deserialized from the JSON passed to
/// `execute`.
#[derive(Deserialize)]
struct Params {
    /// Regex pattern to search for (required).
    pattern: String,
    /// Optional sub-path, joined onto the tool root to pick the search start.
    #[serde(default)]
    path: Option<String>,
    /// Optional glob pattern used to filter which files are searched.
    #[serde(default)]
    file_pattern: Option<String>,
    /// Requested lines of context around each match; defaults to
    /// `default_context()` and is capped at `MAX_CONTEXT_LINES` by `execute`.
    #[serde(default = "default_context")]
    context_lines: usize,
    /// When `false` (the default) the regex is built case-insensitively.
    #[serde(default)]
    case_sensitive: bool,
}
91
/// Serde default for `Params::context_lines` when the caller omits it.
fn default_context() -> usize {
    const DEFAULT_CONTEXT_LINES: usize = 2;
    DEFAULT_CONTEXT_LINES
}
93
94#[async_trait]
95impl Tool for CodeSearchTool {
96    fn id(&self) -> &str { "codesearch" }
97    fn name(&self) -> &str { "Code Search" }
98    fn description(&self) -> &str { "Search for code patterns in the workspace. Supports regex." }
99    fn parameters(&self) -> Value {
100        json!({
101            "type": "object",
102            "properties": {
103                "pattern": {"type": "string", "description": "Search pattern (regex supported)"},
104                "path": {"type": "string", "description": "Subdirectory to search in"},
105                "file_pattern": {"type": "string", "description": "Glob pattern for files (e.g., *.rs)"},
106                "context_lines": {"type": "integer", "default": 2, "description": "Lines of context"},
107                "case_sensitive": {"type": "boolean", "default": false}
108            },
109            "required": ["pattern"]
110        })
111    }
112
113    async fn execute(&self, params: Value) -> Result<ToolResult> {
114        let p: Params = serde_json::from_value(params).context("Invalid params")?;
115        
116        let regex = regex::RegexBuilder::new(&p.pattern)
117            .case_insensitive(!p.case_sensitive)
118            .build()
119            .context("Invalid regex pattern")?;
120        
121        let search_root = match &p.path {
122            Some(subpath) => self.root.join(subpath),
123            None => self.root.clone(),
124        };
125        
126        let file_glob = p.file_pattern.as_ref().and_then(|pat| glob::Pattern::new(pat).ok());
127        
128        let mut all_matches = Vec::new();
129        
130        for entry in WalkDir::new(&search_root).into_iter().filter_map(|e| e.ok()) {
131            let path = entry.path();
132            if !path.is_file() || self.should_skip(path) || !self.is_text_file(path) {
133                continue;
134            }
135            
136            if let Some(ref glob) = file_glob {
137                if !glob.matches_path(path) {
138                    continue;
139                }
140            }
141            
142            if let Ok(matches) = self.search_file(path, &regex, p.context_lines.min(MAX_CONTEXT_LINES)) {
143                all_matches.extend(matches);
144                if all_matches.len() >= MAX_RESULTS {
145                    break;
146                }
147            }
148        }
149        
150        if all_matches.is_empty() {
151            return Ok(ToolResult::success(format!("No matches found for pattern: {}", p.pattern)));
152        }
153        
154        let output = all_matches.iter().take(MAX_RESULTS).map(|m| {
155            format!("{}:{}\n{}", m.path, m.line, m.context)
156        }).collect::<Vec<_>>().join("\n\n");
157        
158        let truncated = all_matches.len() > MAX_RESULTS;
159        let msg = if truncated {
160            format!("Found {} matches (showing first {}):\n\n{}", all_matches.len(), MAX_RESULTS, output)
161        } else {
162            format!("Found {} matches:\n\n{}", all_matches.len(), output)
163        };
164        
165        Ok(ToolResult::success(msg).with_metadata("match_count", json!(all_matches.len())))
166    }
167}