agcodex_core/code_tools/
ast_grep.rs

//! AST-Grep integration for structural code search.
//! Exposes an ast-grep-style pattern API. The current matcher is a line-based
//! textual placeholder; it does not yet parse files with ast-grep-core.
//! Note: Tree-sitter is the primary structural engine; AST-Grep is offered as
//! optional internal tooling.
5
6use super::CodeTool;
7use super::ToolError;
8use ast_grep_language::SupportLang;
9use dashmap::DashMap;
10use std::path::Path;
11use std::path::PathBuf;
12use std::sync::Arc;
13
14/// AST-Grep search tool using command-line interface
15#[derive(Debug, Clone)]
16pub struct AstGrep {
17    /// Cache for search results to avoid recomputation
18    result_cache: Arc<DashMap<String, Vec<AgMatch>>>,
19}
20
/// Query types for AST-Grep.
///
/// Queries are plain data: they can be compared and hashed, e.g. to deduplicate
/// requests.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AgQuery {
    /// Pattern search (e.g., `console.log($_)`).
    ///
    /// When `language` is `None`, the language is inferred from the extension of
    /// the first entry in `paths`.
    Pattern {
        /// Language name or alias (case-insensitive, e.g. `"rust"`, `"ts"`);
        /// `None` to auto-detect.
        language: Option<String>,
        /// Search pattern in ast-grep syntax.
        pattern: String,
        /// Files to search; entries that are not regular files are skipped.
        paths: Vec<PathBuf>,
    },
    /// YAML rule-based search for complex patterns.
    ///
    /// Not supported yet: searching with this variant returns
    /// `ToolError::NotImplemented`.
    Rule {
        /// ast-grep rule document in YAML form.
        yaml_rule: String,
        /// Files the rule would be applied to.
        paths: Vec<PathBuf>,
    },
}
36
/// Match result from an AST-Grep search.
///
/// Line and column numbers are 1-based. The current placeholder matcher reports
/// the whole matching line as the match span and never fills `metavariables`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgMatch {
    /// File the match was found in.
    pub file: PathBuf,
    /// 1-based line where the match starts.
    pub line: u32,
    /// 1-based column where the match starts.
    pub column: u32,
    /// 1-based line where the match ends.
    pub end_line: u32,
    /// Column where the match ends.
    pub end_column: u32,
    /// Source text of the match.
    pub matched_text: String,
    /// Lines immediately preceding the match, in file order.
    pub context_before: Vec<String>,
    /// Lines immediately following the match, in file order.
    pub context_after: Vec<String>,
    /// Captured metavariables (e.g., `$_` -> captured value).
    pub metavariables: std::collections::HashMap<String, String>,
}
51
52impl Default for AstGrep {
53    fn default() -> Self {
54        Self::new()
55    }
56}
57
58impl AstGrep {
59    /// Create a new AST-Grep instance with empty caches
60    pub fn new() -> Self {
61        Self {
62            result_cache: Arc::new(DashMap::new()),
63        }
64    }
65
66    /// Detect ast-grep language from file extension or language string
67    fn detect_language(
68        &self,
69        lang_hint: Option<&str>,
70        file_path: &Path,
71    ) -> Result<SupportLang, ToolError> {
72        // If language hint is provided, try to use it
73        if let Some(lang) = lang_hint {
74            return self.parse_language_string(lang);
75        }
76
77        // Otherwise detect from file extension
78        let extension = file_path
79            .extension()
80            .and_then(|e| e.to_str())
81            .ok_or_else(|| {
82                ToolError::InvalidQuery("Cannot detect language from file path".to_string())
83            })?;
84
85        match extension {
86            "rs" => Ok(SupportLang::Rust),
87            "py" | "pyi" => Ok(SupportLang::Python),
88            "js" | "mjs" | "cjs" => Ok(SupportLang::JavaScript),
89            "ts" | "mts" | "cts" | "tsx" | "jsx" => Ok(SupportLang::TypeScript),
90            "go" => Ok(SupportLang::Go),
91            "java" => Ok(SupportLang::Java),
92            "c" | "h" => Ok(SupportLang::C),
93            "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "c++" => Ok(SupportLang::Cpp),
94            "cs" => Ok(SupportLang::CSharp),
95            "sh" | "bash" | "zsh" => Ok(SupportLang::Bash),
96            "rb" => Ok(SupportLang::Ruby),
97            "php" => Ok(SupportLang::Php),
98            "lua" => Ok(SupportLang::Lua),
99            "hs" | "lhs" => Ok(SupportLang::Haskell),
100            "ex" | "exs" => Ok(SupportLang::Elixir),
101            "scala" | "sc" => Ok(SupportLang::Scala),
102            "swift" => Ok(SupportLang::Swift),
103            "kt" | "kts" => Ok(SupportLang::Kotlin),
104            "html" | "htm" => Ok(SupportLang::Html),
105            "css" | "scss" | "sass" => Ok(SupportLang::Css),
106            "json" => Ok(SupportLang::Json),
107            _ => Err(ToolError::UnsupportedLanguage(extension.to_string())),
108        }
109    }
110
111    /// Parse language string to SupportLang
112    fn parse_language_string(&self, lang: &str) -> Result<SupportLang, ToolError> {
113        match lang.to_lowercase().as_str() {
114            "rust" | "rs" => Ok(SupportLang::Rust),
115            "python" | "py" => Ok(SupportLang::Python),
116            "javascript" | "js" => Ok(SupportLang::JavaScript),
117            "typescript" | "ts" => Ok(SupportLang::TypeScript),
118            "go" | "golang" => Ok(SupportLang::Go),
119            "java" => Ok(SupportLang::Java),
120            "c" => Ok(SupportLang::C),
121            "cpp" | "c++" | "cxx" => Ok(SupportLang::Cpp),
122            "csharp" | "c#" | "cs" => Ok(SupportLang::CSharp),
123            "bash" | "sh" => Ok(SupportLang::Bash),
124            "ruby" | "rb" => Ok(SupportLang::Ruby),
125            "php" => Ok(SupportLang::Php),
126            "lua" => Ok(SupportLang::Lua),
127            "haskell" | "hs" => Ok(SupportLang::Haskell),
128            "elixir" | "ex" => Ok(SupportLang::Elixir),
129            "scala" => Ok(SupportLang::Scala),
130            "swift" => Ok(SupportLang::Swift),
131            "kotlin" | "kt" => Ok(SupportLang::Kotlin),
132            "html" => Ok(SupportLang::Html),
133            "css" => Ok(SupportLang::Css),
134            "json" => Ok(SupportLang::Json),
135            _ => Err(ToolError::UnsupportedLanguage(lang.to_string())),
136        }
137    }
138
139    /// Convert SupportLang to string for CLI usage
140    const fn language_to_string(&self, lang: SupportLang) -> &'static str {
141        match lang {
142            SupportLang::Rust => "rust",
143            SupportLang::Python => "python",
144            SupportLang::JavaScript => "javascript",
145            SupportLang::TypeScript => "typescript",
146            SupportLang::Go => "go",
147            SupportLang::Java => "java",
148            SupportLang::C => "c",
149            SupportLang::Cpp => "cpp",
150            SupportLang::CSharp => "csharp",
151            SupportLang::Bash => "bash",
152            SupportLang::Ruby => "ruby",
153            SupportLang::Php => "php",
154            SupportLang::Lua => "lua",
155            SupportLang::Haskell => "haskell",
156            SupportLang::Elixir => "elixir",
157            SupportLang::Scala => "scala",
158            SupportLang::Swift => "swift",
159            SupportLang::Kotlin => "kotlin",
160            SupportLang::Html => "html",
161            SupportLang::Css => "css",
162            SupportLang::Json => "json",
163            _ => "text", // fallback
164        }
165    }
166
167    /// Search files using ast-grep command line tool
168    fn search_with_pattern(
169        &self,
170        pattern: &str,
171        language: SupportLang,
172        paths: &[PathBuf],
173    ) -> Result<Vec<AgMatch>, ToolError> {
174        if paths.is_empty() {
175            return Ok(Vec::new());
176        }
177
178        // For now, implement a simple pattern-based search
179        // In a full implementation, we would use ast-grep CLI or library directly
180        let mut matches = Vec::new();
181        let _lang_str = self.language_to_string(language);
182
183        for path in paths {
184            if !path.exists() || !path.is_file() {
185                continue;
186            }
187
188            // Read file content
189            let content = std::fs::read_to_string(path).map_err(ToolError::Io)?;
190
191            // Simple pattern matching for demonstration
192            // In a real implementation, this would use ast-grep-core properly
193            if self.simple_pattern_match(&content, pattern) {
194                let lines: Vec<&str> = content.lines().collect();
195
196                // Find the matching line(s)
197                for (line_idx, line) in lines.iter().enumerate() {
198                    if line.contains(&pattern.replace("$_", "")) {
199                        // Simple pattern matching
200                        let line_num = (line_idx + 1) as u32;
201
202                        // Extract context
203                        let context_before = if line_idx > 0 {
204                            lines[line_idx.saturating_sub(3)..line_idx]
205                                .iter()
206                                .map(|s| (*s).to_string())
207                                .collect()
208                        } else {
209                            Vec::new()
210                        };
211
212                        let context_after = if line_idx < lines.len() - 1 {
213                            lines[line_idx + 1..std::cmp::min(line_idx + 4, lines.len())]
214                                .iter()
215                                .map(|s| (*s).to_string())
216                                .collect()
217                        } else {
218                            Vec::new()
219                        };
220
221                        matches.push(AgMatch {
222                            file: path.clone(),
223                            line: line_num,
224                            column: 1, // Simplified column detection
225                            end_line: line_num,
226                            end_column: line.len() as u32,
227                            matched_text: (*line).to_string(),
228                            context_before,
229                            context_after,
230                            metavariables: std::collections::HashMap::new(),
231                        });
232                    }
233                }
234            }
235        }
236
237        Ok(matches)
238    }
239
240    /// Simple pattern matching (placeholder for real AST-based matching)
241    fn simple_pattern_match(&self, content: &str, pattern: &str) -> bool {
242        // This is a very simplified pattern matcher
243        // In a real implementation, this would use proper AST parsing
244        let simplified_pattern = pattern.replace("$_", "").replace("($_)", "()");
245        content.contains(&simplified_pattern)
246    }
247
248    /// Search using YAML rule (placeholder implementation)
249    const fn search_with_rule(
250        &self,
251        _yaml_rule: &str,
252        _paths: &[PathBuf],
253    ) -> Result<Vec<AgMatch>, ToolError> {
254        // YAML rule support would require parsing YAML and converting to ast-grep rules
255        // This is a complex feature that would need proper implementation
256        Err(ToolError::NotImplemented(
257            "YAML rule support - use simple patterns instead",
258        ))
259    }
260
261    /// Get cache statistics
262    pub fn cache_stats(&self) -> usize {
263        self.result_cache.len()
264    }
265
266    /// Clear all caches
267    pub fn clear_cache(&self) {
268        self.result_cache.clear();
269    }
270}
271
272impl CodeTool for AstGrep {
273    type Query = AgQuery;
274    type Output = Vec<AgMatch>;
275
276    fn search(&self, query: Self::Query) -> Result<Self::Output, ToolError> {
277        match query {
278            AgQuery::Pattern {
279                language,
280                pattern,
281                paths,
282            } => {
283                // Generate cache key
284                let cache_key = format!(
285                    "{}:{}:{}",
286                    language.as_deref().unwrap_or("auto"),
287                    pattern,
288                    paths.len()
289                );
290
291                // Check cache first
292                if let Some(cached) = self.result_cache.get(&cache_key) {
293                    return Ok(cached.clone());
294                }
295
296                // Determine language from first file if not specified
297                let lang = if let Some(lang_str) = language.as_deref() {
298                    self.parse_language_string(lang_str)?
299                } else if let Some(first_path) = paths.first() {
300                    self.detect_language(None, first_path)?
301                } else {
302                    return Err(ToolError::InvalidQuery(
303                        "No language specified and no files provided".to_string(),
304                    ));
305                };
306
307                // Search files
308                let results = self.search_with_pattern(&pattern, lang, &paths)?;
309
310                // Cache results
311                self.result_cache.insert(cache_key, results.clone());
312
313                Ok(results)
314            }
315            AgQuery::Rule { yaml_rule, paths } => self.search_with_rule(&yaml_rule, &paths),
316        }
317    }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Build a single-file pattern query with an explicit language.
    fn pattern_query(language: &str, pattern: &str, path: &Path) -> AgQuery {
        AgQuery::Pattern {
            language: Some(language.to_string()),
            pattern: pattern.to_string(),
            paths: vec![path.to_path_buf()],
        }
    }

    #[test]
    fn test_language_detection() {
        let tool = AstGrep::new();

        // Extension-based detection.
        for (name, expected) in [
            ("test.rs", SupportLang::Rust),
            ("test.py", SupportLang::Python),
            ("test.js", SupportLang::JavaScript),
        ] {
            let detected = tool.detect_language(None, Path::new(name)).unwrap();
            assert_eq!(detected, expected);
        }

        // An explicit hint wins over an unrecognised extension.
        let hinted = tool
            .detect_language(Some("rust"), Path::new("unknown.txt"))
            .unwrap();
        assert_eq!(hinted, SupportLang::Rust);
    }

    #[test]
    fn test_simple_pattern_matching() {
        let tool = AstGrep::new();
        let cases = [
            ("console.log('hello')", "console.log", true),
            ("println!(\"test\")", "println!", true),
            ("print('test')", "console.log", false),
        ];
        for (content, pattern, expected) in cases {
            assert_eq!(tool.simple_pattern_match(content, pattern), expected);
        }
    }

    #[test]
    fn test_search_pattern() {
        let tool = AstGrep::new();
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.js");

        // JavaScript fixture containing a single console.log call.
        let source = r#"
function test() {
    console.log("hello");
    console.error("error");
    alert("world");
}
"#;
        fs::write(&file_path, source).unwrap();

        let results = tool
            .search(pattern_query("javascript", "console.log", &file_path))
            .unwrap();
        let first = results.first().expect("expected at least one match");
        assert_eq!(first.file, file_path);
        assert!(first.matched_text.contains("console.log"));
    }

    #[test]
    fn test_unsupported_language() {
        let outcome = AstGrep::new().detect_language(None, Path::new("test.unknown"));
        assert!(
            matches!(outcome, Err(ToolError::UnsupportedLanguage(_))),
            "Expected UnsupportedLanguage error"
        );
    }

    #[test]
    fn test_yaml_rule_placeholder() {
        let tool = AstGrep::new();
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(&file_path, "fn main() {}").unwrap();

        let rule_query = AgQuery::Rule {
            yaml_rule: "id: test\nlanguage: rust\nrule:\n  pattern: 'fn $_() {}'".to_string(),
            paths: vec![file_path],
        };

        // YAML rules are not implemented yet.
        assert!(
            matches!(tool.search(rule_query), Err(ToolError::NotImplemented(_))),
            "Expected NotImplemented error for YAML rules"
        );
    }

    #[test]
    fn test_cache_management() {
        let tool = AstGrep::new();
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.rs");
        fs::write(&file_path, "fn main() {}").unwrap();

        // The first search populates the cache...
        tool.search(pattern_query("rust", "fn", &file_path)).unwrap();
        assert_eq!(tool.cache_stats(), 1);

        // ...and an identical search reuses that entry instead of adding one.
        tool.search(pattern_query("rust", "fn", &file_path)).unwrap();
        assert_eq!(tool.cache_stats(), 1);

        tool.clear_cache();
        assert_eq!(tool.cache_stats(), 0);
    }

    #[test]
    fn test_language_conversion() {
        let tool = AstGrep::new();
        for (lang, name) in [
            (SupportLang::Rust, "rust"),
            (SupportLang::Python, "python"),
            (SupportLang::JavaScript, "javascript"),
        ] {
            assert_eq!(tool.language_to_string(lang), name);
        }
    }
}