1use anyhow::{Context, Result};
2use ast_grep_core::AstGrep;
3use ast_grep_language::SupportLang;
4use colored::*;
5use ignore::Walk;
6use probe_code::path_resolver::resolve_path;
7use rayon::prelude::*; use std::fs;
9use std::path::{Path, PathBuf};
10use std::time::Instant;
11
/// A single structural match found in a source file.
///
/// Positions are 1-based. `line_start`/`column_start` locate the first character of
/// the match; `line_end`/`column_end` locate the match's (exclusive) end offset.
/// Columns count characters, not bytes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AstMatch {
    /// Path of the file the match was found in.
    pub file_path: PathBuf,
    /// Line on which the match starts.
    pub line_start: usize,
    /// Line on which the match ends.
    pub line_end: usize,
    /// Column at which the match starts.
    pub column_start: usize,
    /// Column of the match's end offset.
    pub column_end: usize,
    /// Exact source text covered by the match.
    pub matched_text: String,
}
21
/// Parameters for a structural (ast-grep) query over a file or directory tree.
///
/// Every field is a borrow or a small value, so the options are cheap to copy.
#[derive(Debug, Clone, Copy)]
pub struct QueryOptions<'a> {
    /// File or directory to search.
    pub path: &'a Path,
    /// ast-grep pattern matched against each file's syntax tree.
    pub pattern: &'a str,
    /// Language name to parse with; `None` infers the language per file from its extension.
    pub language: Option<&'a str>,
    /// Substrings; a file whose path contains any of them is skipped.
    pub ignore: &'a [String],
    /// When `false`, paths that look like test files are skipped.
    pub allow_tests: bool,
    /// Upper bound on the number of matches returned, if any.
    pub max_results: Option<usize>,
    /// Requested output format. The query itself never reads it (printing is done by a
    /// separate formatter that takes the format directly), hence the `dead_code` allowance.
    #[allow(dead_code)]
    pub format: &'a str,
}
33
34fn get_language(lang: &str) -> Option<SupportLang> {
36 match lang.to_lowercase().as_str() {
37 "rust" => Some(SupportLang::Rust),
38 "javascript" => Some(SupportLang::JavaScript),
39 "typescript" => Some(SupportLang::TypeScript),
40 "python" => Some(SupportLang::Python),
41 "go" => Some(SupportLang::Go),
42 "c" => Some(SupportLang::C),
43 "cpp" => Some(SupportLang::Cpp),
44 "java" => Some(SupportLang::Java),
45 "ruby" => Some(SupportLang::Ruby),
46 "php" => Some(SupportLang::Php),
47 "swift" => Some(SupportLang::Swift),
48 "csharp" => Some(SupportLang::CSharp),
49 _ => None,
50 }
51}
52
/// Returns the file extensions associated with a language name.
///
/// Each extension includes its leading `.`. The name is matched case-insensitively, and
/// unknown languages yield an empty list. The extensions are `'static` literals, so the
/// result does not borrow from `lang` and may outlive it.
fn get_file_extension(lang: &str) -> Vec<&'static str> {
    match lang.to_lowercase().as_str() {
        "rust" => vec![".rs"],
        "javascript" => vec![".js", ".jsx", ".mjs"],
        "typescript" => vec![".ts", ".tsx"],
        "python" => vec![".py"],
        "go" => vec![".go"],
        "c" => vec![".c", ".h"],
        "cpp" => vec![".cpp", ".hpp", ".cc", ".hh", ".cxx", ".hxx"],
        "java" => vec![".java"],
        "ruby" => vec![".rb"],
        "php" => vec![".php"],
        "swift" => vec![".swift"],
        "csharp" => vec![".cs"],
        _ => vec![],
    }
}
71
72fn should_ignore_file(file_path: &Path, options: &QueryOptions) -> bool {
74 let path_str = file_path.to_string_lossy();
75
76 if !options.allow_tests
78 && (path_str.contains("/test/")
79 || path_str.contains("/tests/")
80 || path_str.contains("_test.")
81 || path_str.contains("_spec.")
82 || path_str.contains(".test.")
83 || path_str.contains(".spec."))
84 {
85 return true;
86 }
87
88 for pattern in options.ignore {
90 if path_str.contains(pattern) {
91 return true;
92 }
93 }
94
95 false
96}
97
98fn query_file(file_path: &Path, options: &QueryOptions) -> Result<Vec<AstMatch>> {
100 let file_ext = file_path.extension().and_then(|e| e.to_str()).unwrap_or("");
102
103 if let Some(language) = options.language {
105 let extensions = get_file_extension(language);
106 let has_matching_ext = extensions
107 .iter()
108 .any(|&ext| file_path.to_string_lossy().ends_with(ext));
109
110 if !has_matching_ext {
111 return Ok(vec![]);
112 }
113 }
114
115 let content = fs::read_to_string(file_path)
117 .with_context(|| format!("Failed to read file: {}", file_path.display()))?;
118
119 let lang = if let Some(language) = options.language {
121 match get_language(language) {
123 Some(lang) => lang,
124 None => return Ok(vec![]),
125 }
126 } else {
127 let inferred_lang = match file_ext {
129 "rs" => Some(SupportLang::Rust),
130 "js" | "jsx" | "mjs" => Some(SupportLang::JavaScript),
131 "ts" | "tsx" => Some(SupportLang::TypeScript),
132 "py" => Some(SupportLang::Python),
133 "go" => Some(SupportLang::Go),
134 "c" | "h" => Some(SupportLang::C),
135 "cpp" | "hpp" | "cc" | "hh" | "cxx" | "hxx" => Some(SupportLang::Cpp),
136 "java" => Some(SupportLang::Java),
137 "rb" => Some(SupportLang::Ruby),
138 "php" => Some(SupportLang::Php),
139 "swift" => Some(SupportLang::Swift),
140 "cs" => Some(SupportLang::CSharp),
141 _ => None, };
143
144 match inferred_lang {
145 Some(lang) => lang,
146 None => return Ok(vec![]), }
148 };
149
150 let grep = AstGrep::new(&content, lang);
152
153 let matches = match std::panic::catch_unwind(|| {
155 grep.root().find_all(options.pattern).collect::<Vec<_>>()
156 }) {
157 Ok(matches) => matches,
158 Err(_) => {
159 if options.language.is_some() {
162 eprintln!(
163 "Error parsing pattern: '{}' is not a valid ast-grep pattern",
164 options.pattern
165 );
166 }
167 return Ok(vec![]);
168 }
169 };
170
171 let mut ast_matches = Vec::new();
173 for node in matches {
174 let range = node.range();
175
176 let mut line_start = 1;
178 let mut column_start = 1;
179 let mut line_end = 1;
180 let mut column_end = 1;
181
182 let mut current_line = 1;
183 let mut current_column = 1;
184
185 for (i, c) in content.char_indices() {
186 if i == range.start {
187 line_start = current_line;
188 column_start = current_column;
189 }
190 if i == range.end {
191 line_end = current_line;
192 column_end = current_column;
193 break;
194 }
195
196 if c == '\n' {
197 current_line += 1;
198 current_column = 1;
199 } else {
200 current_column += 1;
201 }
202 }
203
204 ast_matches.push(AstMatch {
205 file_path: file_path.to_path_buf(),
206 line_start,
207 line_end,
208 column_start,
209 column_end,
210 matched_text: node.text().to_string(),
211 });
212 }
213
214 Ok(ast_matches)
215}
216
217pub fn perform_query(options: &QueryOptions) -> Result<Vec<AstMatch>> {
218 let suppress_output = options.language.is_none();
220
221 let original_hook = if suppress_output {
223 let original_hook = std::panic::take_hook();
224 std::panic::set_hook(Box::new(|_| {
225 }));
227 Some(original_hook)
228 } else {
229 None
230 };
231
232 let resolved_path = if let Some(path_str) = options.path.to_str() {
234 match resolve_path(path_str) {
235 Ok(resolved_path) => {
236 if std::env::var("DEBUG").unwrap_or_default() == "1" {
237 println!(
238 "DEBUG: Resolved path '{}' to '{}'",
239 path_str,
240 resolved_path.display()
241 );
242 }
243 resolved_path
244 }
245 Err(err) => {
246 if std::env::var("DEBUG").unwrap_or_default() == "1" {
247 println!("DEBUG: Failed to resolve path '{path_str}': {err}");
248 }
249 options.path.to_path_buf()
251 }
252 }
253 } else {
254 options.path.to_path_buf()
256 };
257
258 let file_paths: Vec<PathBuf> = Walk::new(&resolved_path)
260 .filter_map(|entry| entry.ok())
261 .filter(|entry| entry.file_type().is_some_and(|ft| ft.is_file()))
262 .filter(|entry| !should_ignore_file(entry.path(), options))
263 .map(|entry| entry.path().to_path_buf())
264 .collect();
265
266 let all_matches: Vec<AstMatch> = file_paths
268 .par_iter()
269 .flat_map(|path| {
270 std::panic::catch_unwind(|| query_file(path, options))
271 .unwrap_or_else(|_| {
272 Ok(vec![])
274 })
275 .unwrap_or_else(|_| {
276 vec![]
278 })
279 })
280 .collect();
281
282 if let Some(hook) = original_hook {
284 std::panic::set_hook(hook);
285 }
286
287 let mut all_matches = all_matches;
289 if let Some(max) = options.max_results {
290 all_matches.truncate(max);
291 }
292
293 Ok(all_matches)
294}
295
/// Escapes the five XML special characters so `s` can be embedded in element text or
/// attribute values.
///
/// `&` becomes `&amp;`, `<` becomes `&lt;`, `>` becomes `&gt;`, `"` becomes `&quot;`
/// and `'` becomes `&apos;`. All other characters pass through unchanged. Because the
/// input is processed in a single pass, an existing entity such as `&amp;` is escaped
/// again (to `&amp;amp;`) rather than left alone.
fn escape_xml(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&apos;"),
            other => escaped.push(other),
        }
    }
    escaped
}
304
305pub fn format_and_print_query_results(matches: &[AstMatch], format: &str) -> Result<()> {
307 match format {
308 "color" | "terminal" => {
309 for m in matches {
310 println!(
311 "{}",
312 format!(
313 "{}:{}:{}",
314 m.file_path.display(),
315 m.line_start,
316 m.column_start
317 )
318 .cyan()
319 );
320 println!("{}", m.matched_text.trim());
321 println!();
322 }
323 }
324 "plain" => {
325 for m in matches {
326 println!(
327 "{}:{}:{}",
328 m.file_path.display(),
329 m.line_start,
330 m.column_start
331 );
332 println!("{}", m.matched_text.trim());
333 println!();
334 }
335 }
336 "markdown" => {
337 for m in matches {
338 println!(
339 "**{}:{}:{}**",
340 m.file_path.display(),
341 m.line_start,
342 m.column_start
343 );
344
345 let lang = m
347 .file_path
348 .extension()
349 .and_then(|e| e.to_str())
350 .unwrap_or("");
351
352 println!("```{lang}");
353 println!("{}", m.matched_text.trim());
354 println!("```");
355 println!();
356 }
357 }
358 "json" => {
359 use probe_code::search::search_tokens::count_tokens;
361 let total_tokens = matches
362 .iter()
363 .map(|m| count_tokens(&m.matched_text))
364 .sum::<usize>();
365
366 let json_matches_standardized: Vec<_> = matches
368 .iter()
369 .map(|m| {
370 serde_json::json!({
371 "file": m.file_path.to_string_lossy(),
372 "lines": [m.line_start, m.line_end],
373 "node_type": "match",
374 "content": m.matched_text,
375 "column_start": m.column_start,
376 "column_end": m.column_end
377 })
378 })
379 .collect();
380
381 let wrapper = serde_json::json!({
383 "results": json_matches_standardized,
384 "summary": {
385 "count": matches.len(),
386 "total_bytes": matches.iter().map(|m| m.matched_text.len()).sum::<usize>(),
387 "total_tokens": total_tokens
388 }
389 });
390
391 println!("{}", serde_json::to_string_pretty(&wrapper)?);
392 }
393 "xml" => {
394 println!("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
395 println!("<probe_results>");
396
397 for m in matches {
398 println!(" <result>");
399 println!(
400 " <file>{}</file>",
401 escape_xml(&m.file_path.to_string_lossy())
402 );
403 println!(" <lines>{}-{}</lines>", m.line_start, m.line_end);
404 println!(" <node_type>match</node_type>");
405 println!(" <column_start>{}</column_start>", m.column_start);
406 println!(" <column_end>{}</column_end>", m.column_end);
407 println!(" <code><![CDATA[{}]]></code>", m.matched_text.trim());
408 println!(" </result>");
409 }
410
411 println!(" <summary>");
413 println!(" <count>{}</count>", matches.len());
414 println!(
415 " <total_bytes>{}",
416 matches.iter().map(|m| m.matched_text.len()).sum::<usize>()
417 );
418
419 use probe_code::search::search_tokens::count_tokens;
421 println!(
422 " <total_tokens>{}",
423 matches
424 .iter()
425 .map(|m| count_tokens(&m.matched_text))
426 .sum::<usize>()
427 );
428 println!(" </summary>");
429
430 println!("</probe_results>");
431 }
432 _ => {
433 format_and_print_query_results(matches, "color")?;
435 }
436 }
437
438 Ok(())
439}
440
441pub fn handle_query(
443 pattern: &str,
444 path: &Path,
445 language: Option<&str>,
446 ignore: &[String],
447 allow_tests: bool,
448 max_results: Option<usize>,
449 format: &str,
450) -> Result<()> {
451 if format != "json" && format != "xml" {
453 println!("{} {}", "Pattern:".bold().green(), pattern);
454 println!("{} {}", "Path:".bold().green(), path.display());
455
456 if let Some(lang) = language {
458 println!("{} {}", "Language:".bold().green(), lang);
459 } else {
460 println!("{} auto-detect", "Language:".bold().green());
461 }
462
463 let mut advanced_options = Vec::<String>::new();
465 if allow_tests {
466 advanced_options.push("Including tests".to_string());
467 }
468 if let Some(max) = max_results {
469 advanced_options.push(format!("Max results: {max}"));
470 }
471
472 if !advanced_options.is_empty() {
473 println!(
474 "{} {}",
475 "Options:".bold().green(),
476 advanced_options.join(", ")
477 );
478 }
479 }
480
481 let start_time = Instant::now();
482
483 let options = QueryOptions {
484 path,
485 pattern,
486 language,
487 ignore,
488 allow_tests,
489 max_results,
490 format,
491 };
492
493 let matches = perform_query(&options)?;
494
495 let duration = start_time.elapsed();
497
498 if matches.is_empty() {
499 if format == "json" || format == "xml" {
501 format_and_print_query_results(&matches, format)?;
502 } else {
503 println!("{}", "No results found.".yellow().bold());
505 println!("Search completed in {duration:.2?}");
506 }
507 } else {
508 if format != "json" && format != "xml" {
510 println!("Found {} matches in {:.2?}", matches.len(), duration);
511 println!();
512 }
513
514 format_and_print_query_results(&matches, format)?;
515
516 if format != "json" && format != "xml" {
518 let total_bytes: usize = matches.iter().map(|m| m.matched_text.len()).sum();
520 let total_tokens: usize = matches
521 .iter()
522 .map(|m| {
523 use probe_code::search::search_tokens::count_tokens;
525 count_tokens(&m.matched_text)
526 })
527 .sum();
528
529 println!("Total bytes returned: {total_bytes}");
530 println!("Total tokens returned: {total_tokens}");
531 }
532 }
533
534 Ok(())
535}