1use super::CodeTool;
7use super::ToolError;
8use ast_grep_language::SupportLang;
9use dashmap::DashMap;
10use std::path::Path;
11use std::path::PathBuf;
12use std::sync::Arc;
13
14#[derive(Debug, Clone)]
16pub struct AstGrep {
17 result_cache: Arc<DashMap<String, Vec<AgMatch>>>,
19}
20
/// A search request accepted by [`AstGrep`].
///
/// Derives `PartialEq`/`Eq`/`Hash` so queries can be compared in tests and used as keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AgQuery {
    /// Search for a code pattern across `paths`.
    Pattern {
        /// Optional language name (e.g. `"rust"`, `"ts"`). When `None`, the language is
        /// inferred from the extension of the first entry in `paths`.
        language: Option<String>,
        /// Pattern text; `$_` placeholders are stripped before matching.
        pattern: String,
        /// Files to search. Entries that are not regular files are skipped.
        paths: Vec<PathBuf>,
    },
    /// Search using an ast-grep YAML rule (not implemented yet; yields an error).
    Rule {
        /// The YAML rule document.
        yaml_rule: String,
        /// Files the rule would be applied to.
        paths: Vec<PathBuf>,
    },
}
36
/// A single search hit reported by [`AstGrep`].
///
/// Derives `PartialEq`/`Eq` so whole results can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AgMatch {
    /// File the match was found in, exactly as it appeared in the query's `paths`.
    pub file: PathBuf,
    /// 1-based line number of the match.
    pub line: u32,
    /// 1-based start column; the textual backend always reports the start of the line.
    pub column: u32,
    /// 1-based line number where the match ends (same as `line` for single-line hits).
    pub end_line: u32,
    /// End column; the textual backend reports the matched line's length in bytes.
    pub end_column: u32,
    /// Text of the match (the whole matching line for the textual backend).
    pub matched_text: String,
    /// Up to three lines immediately preceding the match, oldest first.
    pub context_before: Vec<String>,
    /// Up to three lines immediately following the match.
    pub context_after: Vec<String>,
    /// Captured metavariable bindings; the textual backend leaves this empty.
    pub metavariables: std::collections::HashMap<String, String>,
}
51
52impl Default for AstGrep {
53 fn default() -> Self {
54 Self::new()
55 }
56}
57
58impl AstGrep {
59 pub fn new() -> Self {
61 Self {
62 result_cache: Arc::new(DashMap::new()),
63 }
64 }
65
66 fn detect_language(
68 &self,
69 lang_hint: Option<&str>,
70 file_path: &Path,
71 ) -> Result<SupportLang, ToolError> {
72 if let Some(lang) = lang_hint {
74 return self.parse_language_string(lang);
75 }
76
77 let extension = file_path
79 .extension()
80 .and_then(|e| e.to_str())
81 .ok_or_else(|| {
82 ToolError::InvalidQuery("Cannot detect language from file path".to_string())
83 })?;
84
85 match extension {
86 "rs" => Ok(SupportLang::Rust),
87 "py" | "pyi" => Ok(SupportLang::Python),
88 "js" | "mjs" | "cjs" => Ok(SupportLang::JavaScript),
89 "ts" | "mts" | "cts" | "tsx" | "jsx" => Ok(SupportLang::TypeScript),
90 "go" => Ok(SupportLang::Go),
91 "java" => Ok(SupportLang::Java),
92 "c" | "h" => Ok(SupportLang::C),
93 "cpp" | "cc" | "cxx" | "hpp" | "hxx" | "c++" => Ok(SupportLang::Cpp),
94 "cs" => Ok(SupportLang::CSharp),
95 "sh" | "bash" | "zsh" => Ok(SupportLang::Bash),
96 "rb" => Ok(SupportLang::Ruby),
97 "php" => Ok(SupportLang::Php),
98 "lua" => Ok(SupportLang::Lua),
99 "hs" | "lhs" => Ok(SupportLang::Haskell),
100 "ex" | "exs" => Ok(SupportLang::Elixir),
101 "scala" | "sc" => Ok(SupportLang::Scala),
102 "swift" => Ok(SupportLang::Swift),
103 "kt" | "kts" => Ok(SupportLang::Kotlin),
104 "html" | "htm" => Ok(SupportLang::Html),
105 "css" | "scss" | "sass" => Ok(SupportLang::Css),
106 "json" => Ok(SupportLang::Json),
107 _ => Err(ToolError::UnsupportedLanguage(extension.to_string())),
108 }
109 }
110
111 fn parse_language_string(&self, lang: &str) -> Result<SupportLang, ToolError> {
113 match lang.to_lowercase().as_str() {
114 "rust" | "rs" => Ok(SupportLang::Rust),
115 "python" | "py" => Ok(SupportLang::Python),
116 "javascript" | "js" => Ok(SupportLang::JavaScript),
117 "typescript" | "ts" => Ok(SupportLang::TypeScript),
118 "go" | "golang" => Ok(SupportLang::Go),
119 "java" => Ok(SupportLang::Java),
120 "c" => Ok(SupportLang::C),
121 "cpp" | "c++" | "cxx" => Ok(SupportLang::Cpp),
122 "csharp" | "c#" | "cs" => Ok(SupportLang::CSharp),
123 "bash" | "sh" => Ok(SupportLang::Bash),
124 "ruby" | "rb" => Ok(SupportLang::Ruby),
125 "php" => Ok(SupportLang::Php),
126 "lua" => Ok(SupportLang::Lua),
127 "haskell" | "hs" => Ok(SupportLang::Haskell),
128 "elixir" | "ex" => Ok(SupportLang::Elixir),
129 "scala" => Ok(SupportLang::Scala),
130 "swift" => Ok(SupportLang::Swift),
131 "kotlin" | "kt" => Ok(SupportLang::Kotlin),
132 "html" => Ok(SupportLang::Html),
133 "css" => Ok(SupportLang::Css),
134 "json" => Ok(SupportLang::Json),
135 _ => Err(ToolError::UnsupportedLanguage(lang.to_string())),
136 }
137 }
138
139 const fn language_to_string(&self, lang: SupportLang) -> &'static str {
141 match lang {
142 SupportLang::Rust => "rust",
143 SupportLang::Python => "python",
144 SupportLang::JavaScript => "javascript",
145 SupportLang::TypeScript => "typescript",
146 SupportLang::Go => "go",
147 SupportLang::Java => "java",
148 SupportLang::C => "c",
149 SupportLang::Cpp => "cpp",
150 SupportLang::CSharp => "csharp",
151 SupportLang::Bash => "bash",
152 SupportLang::Ruby => "ruby",
153 SupportLang::Php => "php",
154 SupportLang::Lua => "lua",
155 SupportLang::Haskell => "haskell",
156 SupportLang::Elixir => "elixir",
157 SupportLang::Scala => "scala",
158 SupportLang::Swift => "swift",
159 SupportLang::Kotlin => "kotlin",
160 SupportLang::Html => "html",
161 SupportLang::Css => "css",
162 SupportLang::Json => "json",
163 _ => "text", }
165 }
166
167 fn search_with_pattern(
169 &self,
170 pattern: &str,
171 language: SupportLang,
172 paths: &[PathBuf],
173 ) -> Result<Vec<AgMatch>, ToolError> {
174 if paths.is_empty() {
175 return Ok(Vec::new());
176 }
177
178 let mut matches = Vec::new();
181 let _lang_str = self.language_to_string(language);
182
183 for path in paths {
184 if !path.exists() || !path.is_file() {
185 continue;
186 }
187
188 let content = std::fs::read_to_string(path).map_err(ToolError::Io)?;
190
191 if self.simple_pattern_match(&content, pattern) {
194 let lines: Vec<&str> = content.lines().collect();
195
196 for (line_idx, line) in lines.iter().enumerate() {
198 if line.contains(&pattern.replace("$_", "")) {
199 let line_num = (line_idx + 1) as u32;
201
202 let context_before = if line_idx > 0 {
204 lines[line_idx.saturating_sub(3)..line_idx]
205 .iter()
206 .map(|s| (*s).to_string())
207 .collect()
208 } else {
209 Vec::new()
210 };
211
212 let context_after = if line_idx < lines.len() - 1 {
213 lines[line_idx + 1..std::cmp::min(line_idx + 4, lines.len())]
214 .iter()
215 .map(|s| (*s).to_string())
216 .collect()
217 } else {
218 Vec::new()
219 };
220
221 matches.push(AgMatch {
222 file: path.clone(),
223 line: line_num,
224 column: 1, end_line: line_num,
226 end_column: line.len() as u32,
227 matched_text: (*line).to_string(),
228 context_before,
229 context_after,
230 metavariables: std::collections::HashMap::new(),
231 });
232 }
233 }
234 }
235 }
236
237 Ok(matches)
238 }
239
240 fn simple_pattern_match(&self, content: &str, pattern: &str) -> bool {
242 let simplified_pattern = pattern.replace("$_", "").replace("($_)", "()");
245 content.contains(&simplified_pattern)
246 }
247
248 const fn search_with_rule(
250 &self,
251 _yaml_rule: &str,
252 _paths: &[PathBuf],
253 ) -> Result<Vec<AgMatch>, ToolError> {
254 Err(ToolError::NotImplemented(
257 "YAML rule support - use simple patterns instead",
258 ))
259 }
260
261 pub fn cache_stats(&self) -> usize {
263 self.result_cache.len()
264 }
265
266 pub fn clear_cache(&self) {
268 self.result_cache.clear();
269 }
270}
271
272impl CodeTool for AstGrep {
273 type Query = AgQuery;
274 type Output = Vec<AgMatch>;
275
276 fn search(&self, query: Self::Query) -> Result<Self::Output, ToolError> {
277 match query {
278 AgQuery::Pattern {
279 language,
280 pattern,
281 paths,
282 } => {
283 let cache_key = format!(
285 "{}:{}:{}",
286 language.as_deref().unwrap_or("auto"),
287 pattern,
288 paths.len()
289 );
290
291 if let Some(cached) = self.result_cache.get(&cache_key) {
293 return Ok(cached.clone());
294 }
295
296 let lang = if let Some(lang_str) = language.as_deref() {
298 self.parse_language_string(lang_str)?
299 } else if let Some(first_path) = paths.first() {
300 self.detect_language(None, first_path)?
301 } else {
302 return Err(ToolError::InvalidQuery(
303 "No language specified and no files provided".to_string(),
304 ));
305 };
306
307 let results = self.search_with_pattern(&pattern, lang, &paths)?;
309
310 self.result_cache.insert(cache_key, results.clone());
312
313 Ok(results)
314 }
315 AgQuery::Rule { yaml_rule, paths } => self.search_with_rule(&yaml_rule, &paths),
316 }
317 }
318}
319
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    /// Builds a single-file pattern query with an explicit language.
    fn pattern_query(language: &str, pattern: &str, path: &Path) -> AgQuery {
        AgQuery::Pattern {
            language: Some(language.to_string()),
            pattern: pattern.to_string(),
            paths: vec![path.to_path_buf()],
        }
    }

    #[test]
    fn test_language_detection() {
        let ast_grep = AstGrep::new();

        assert_eq!(
            ast_grep
                .detect_language(None, Path::new("test.rs"))
                .unwrap(),
            SupportLang::Rust
        );
        assert_eq!(
            ast_grep
                .detect_language(None, Path::new("test.py"))
                .unwrap(),
            SupportLang::Python
        );
        assert_eq!(
            ast_grep
                .detect_language(None, Path::new("test.js"))
                .unwrap(),
            SupportLang::JavaScript
        );

        // An explicit hint overrides the file extension.
        assert_eq!(
            ast_grep
                .detect_language(Some("rust"), Path::new("unknown.txt"))
                .unwrap(),
            SupportLang::Rust
        );
    }

    #[test]
    fn test_simple_pattern_matching() {
        let ast_grep = AstGrep::new();

        assert!(ast_grep.simple_pattern_match("console.log('hello')", "console.log"));
        assert!(ast_grep.simple_pattern_match("println!(\"test\")", "println!"));
        assert!(!ast_grep.simple_pattern_match("print('test')", "console.log"));
    }

    #[test]
    fn test_search_pattern() {
        let ast_grep = AstGrep::new();
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.js");

        fs::write(
            &file_path,
            r#"
function test() {
    console.log("hello");
    console.error("error");
    alert("world");
}
"#,
        )
        .unwrap();

        let results = ast_grep
            .search(pattern_query("javascript", "console.log", &file_path))
            .unwrap();
        assert!(!results.is_empty());

        let match_result = &results[0];
        assert_eq!(match_result.file, file_path);
        assert!(match_result.matched_text.contains("console.log"));
    }

    #[test]
    fn test_context_lines() {
        let ast_grep = AstGrep::new();
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("ctx.rs");

        fs::write(&file_path, "a\nb\nc\nneedle\nd\ne\nf\ng\n").unwrap();

        let results = ast_grep
            .search(pattern_query("rust", "needle", &file_path))
            .unwrap();
        assert_eq!(results.len(), 1);

        let hit = &results[0];
        assert_eq!(hit.line, 4);
        assert_eq!(hit.end_line, 4);
        assert_eq!(hit.context_before, vec!["a", "b", "c"]);
        assert_eq!(hit.context_after, vec!["d", "e", "f"]);
    }

    #[test]
    fn test_missing_file_is_skipped() {
        let ast_grep = AstGrep::new();
        let dir = tempdir().unwrap();
        let missing = dir.path().join("missing.rs");

        let results = ast_grep
            .search(pattern_query("rust", "fn", &missing))
            .unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn test_unsupported_language() {
        let ast_grep = AstGrep::new();

        let result = ast_grep.detect_language(None, Path::new("test.unknown"));
        assert!(
            matches!(result, Err(ToolError::UnsupportedLanguage(_))),
            "Expected UnsupportedLanguage error"
        );
    }

    #[test]
    fn test_yaml_rule_placeholder() {
        let ast_grep = AstGrep::new();
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.rs");

        fs::write(&file_path, "fn main() {}").unwrap();

        let query = AgQuery::Rule {
            yaml_rule: "id: test\nlanguage: rust\nrule:\n  pattern: 'fn $_() {}'".to_string(),
            paths: vec![file_path],
        };

        let result = ast_grep.search(query);
        assert!(
            matches!(result, Err(ToolError::NotImplemented(_))),
            "Expected NotImplemented error for YAML rules"
        );
    }

    #[test]
    fn test_cache_management() {
        let ast_grep = AstGrep::new();
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.rs");

        fs::write(&file_path, "fn main() {}").unwrap();

        let _ = ast_grep
            .search(pattern_query("rust", "fn", &file_path))
            .unwrap();
        assert_eq!(ast_grep.cache_stats(), 1);

        // An identical query is served from the cache and adds no new entry.
        let _ = ast_grep
            .search(pattern_query("rust", "fn", &file_path))
            .unwrap();
        assert_eq!(ast_grep.cache_stats(), 1);

        ast_grep.clear_cache();
        assert_eq!(ast_grep.cache_stats(), 0);
    }

    #[test]
    fn test_language_conversion() {
        let ast_grep = AstGrep::new();

        assert_eq!(ast_grep.language_to_string(SupportLang::Rust), "rust");
        assert_eq!(ast_grep.language_to_string(SupportLang::Python), "python");
        assert_eq!(
            ast_grep.language_to_string(SupportLang::JavaScript),
            "javascript"
        );
    }
}