probe_code/
query.rs

1use anyhow::{Context, Result};
2use ast_grep_core::AstGrep;
3use ast_grep_language::SupportLang;
4use colored::*;
5use ignore::Walk;
6use probe_code::path_resolver::resolve_path;
7use rayon::prelude::*; // Added import
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::time::Instant;
11
/// A single match found by ast-grep.
///
/// Positions are 1-based. Columns count chars (Unicode scalar values) from the start of the
/// line, not bytes. `line_end`/`column_end` are computed at the match's end byte offset, i.e.
/// they point just past the last matched character.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AstMatch {
    /// File in which the match was found.
    pub file_path: PathBuf,
    /// Line on which the match starts (1-based).
    pub line_start: usize,
    /// Line containing the position just past the end of the match (1-based).
    pub line_end: usize,
    /// Column at which the match starts (1-based, counted in chars).
    pub column_start: usize,
    /// Column just past the end of the match (1-based, counted in chars).
    pub column_end: usize,
    /// Source text of the matched node.
    pub matched_text: String,
}
21
/// Options controlling an ast-grep query run by `perform_query`.
#[derive(Debug, Clone, Copy)]
pub struct QueryOptions<'a> {
    /// Root directory or file to search; special forms are resolved with `resolve_path`.
    pub path: &'a Path,
    /// ast-grep pattern to match.
    pub pattern: &'a str,
    /// Language name (e.g. `"rust"`); `None` infers the language from each file's extension.
    pub language: Option<&'a str>,
    /// Substrings; any file whose path contains one of them is skipped.
    pub ignore: &'a [String],
    /// When `false`, files whose path looks like a test file are skipped.
    pub allow_tests: bool,
    /// Maximum number of matches to return; `None` means unlimited.
    pub max_results: Option<usize>,
    /// Output format requested by the caller.
    // The query engine itself never reads this field (formatting happens in
    // `format_and_print_query_results`), hence the lint allowance.
    #[allow(dead_code)]
    pub format: &'a str,
}
33
34/// Convert a language string to the corresponding SupportLang
35fn get_language(lang: &str) -> Option<SupportLang> {
36    match lang.to_lowercase().as_str() {
37        "rust" => Some(SupportLang::Rust),
38        "javascript" => Some(SupportLang::JavaScript),
39        "typescript" => Some(SupportLang::TypeScript),
40        "python" => Some(SupportLang::Python),
41        "go" => Some(SupportLang::Go),
42        "c" => Some(SupportLang::C),
43        "cpp" => Some(SupportLang::Cpp),
44        "java" => Some(SupportLang::Java),
45        "ruby" => Some(SupportLang::Ruby),
46        "php" => Some(SupportLang::Php),
47        "swift" => Some(SupportLang::Swift),
48        "csharp" => Some(SupportLang::CSharp),
49        _ => None,
50    }
51}
52
/// Return the file extensions (each including the leading `.`) associated with a language name.
///
/// Matching is case-insensitive and accepts the same names as `get_language`; unknown names
/// yield an empty vector. The entries are string literals, so the result is `'static` and does
/// not keep `lang` borrowed.
fn get_file_extension(lang: &str) -> Vec<&'static str> {
    match lang.to_lowercase().as_str() {
        "rust" => vec![".rs"],
        "javascript" => vec![".js", ".jsx", ".mjs"],
        "typescript" => vec![".ts", ".tsx"],
        "python" => vec![".py"],
        "go" => vec![".go"],
        "c" => vec![".c", ".h"],
        "cpp" => vec![".cpp", ".hpp", ".cc", ".hh", ".cxx", ".hxx"],
        "java" => vec![".java"],
        "ruby" => vec![".rb"],
        "php" => vec![".php"],
        "swift" => vec![".swift"],
        "csharp" => vec![".cs"],
        _ => vec![],
    }
}
71
72/// Check if a file should be ignored based on its path
73fn should_ignore_file(file_path: &Path, options: &QueryOptions) -> bool {
74    let path_str = file_path.to_string_lossy();
75
76    // Skip test files if allow_tests is false
77    if !options.allow_tests
78        && (path_str.contains("/test/")
79            || path_str.contains("/tests/")
80            || path_str.contains("_test.")
81            || path_str.contains("_spec.")
82            || path_str.contains(".test.")
83            || path_str.contains(".spec."))
84    {
85        return true;
86    }
87
88    // Skip files that match custom ignore patterns
89    for pattern in options.ignore {
90        if path_str.contains(pattern) {
91            return true;
92        }
93    }
94
95    false
96}
97
98/// Perform an ast-grep query on a single file
99fn query_file(file_path: &Path, options: &QueryOptions) -> Result<Vec<AstMatch>> {
100    // Get the file extension
101    let file_ext = file_path.extension().and_then(|e| e.to_str()).unwrap_or("");
102
103    // If language is provided, check if the file has the correct extension
104    if let Some(language) = options.language {
105        let extensions = get_file_extension(language);
106        let has_matching_ext = extensions
107            .iter()
108            .any(|&ext| file_path.to_string_lossy().ends_with(ext));
109
110        if !has_matching_ext {
111            return Ok(vec![]);
112        }
113    }
114
115    // Read the file content
116    let content = fs::read_to_string(file_path)
117        .with_context(|| format!("Failed to read file: {}", file_path.display()))?;
118
119    // Get the language for ast-grep
120    let lang = if let Some(language) = options.language {
121        // If language is specified, use it
122        match get_language(language) {
123            Some(lang) => lang,
124            None => return Ok(vec![]),
125        }
126    } else {
127        // If language is not specified, try to infer from file extension
128        let inferred_lang = match file_ext {
129            "rs" => Some(SupportLang::Rust),
130            "js" | "jsx" | "mjs" => Some(SupportLang::JavaScript),
131            "ts" | "tsx" => Some(SupportLang::TypeScript),
132            "py" => Some(SupportLang::Python),
133            "go" => Some(SupportLang::Go),
134            "c" | "h" => Some(SupportLang::C),
135            "cpp" | "hpp" | "cc" | "hh" | "cxx" | "hxx" => Some(SupportLang::Cpp),
136            "java" => Some(SupportLang::Java),
137            "rb" => Some(SupportLang::Ruby),
138            "php" => Some(SupportLang::Php),
139            "swift" => Some(SupportLang::Swift),
140            "cs" => Some(SupportLang::CSharp),
141            _ => None, // Unsupported extension
142        };
143
144        match inferred_lang {
145            Some(lang) => lang,
146            None => return Ok(vec![]), // Skip files with unsupported extensions
147        }
148    };
149
150    // Create the document and grep instance
151    let grep = AstGrep::new(&content, lang);
152
153    // Create the pattern and find all matches
154    let matches = match std::panic::catch_unwind(|| {
155        grep.root().find_all(options.pattern).collect::<Vec<_>>()
156    }) {
157        Ok(matches) => matches,
158        Err(_) => {
159            // Only print error if language is explicitly specified
160            // This suppresses errors during auto-detection
161            if options.language.is_some() {
162                eprintln!(
163                    "Error parsing pattern: '{}' is not a valid ast-grep pattern",
164                    options.pattern
165                );
166            }
167            return Ok(vec![]);
168        }
169    };
170
171    // Convert matches to AstMatch structs
172    let mut ast_matches = Vec::new();
173    for node in matches {
174        let range = node.range();
175
176        // Convert byte offsets to line and column numbers
177        let mut line_start = 1;
178        let mut column_start = 1;
179        let mut line_end = 1;
180        let mut column_end = 1;
181
182        let mut current_line = 1;
183        let mut current_column = 1;
184
185        for (i, c) in content.char_indices() {
186            if i == range.start {
187                line_start = current_line;
188                column_start = current_column;
189            }
190            if i == range.end {
191                line_end = current_line;
192                column_end = current_column;
193                break;
194            }
195
196            if c == '\n' {
197                current_line += 1;
198                current_column = 1;
199            } else {
200                current_column += 1;
201            }
202        }
203
204        ast_matches.push(AstMatch {
205            file_path: file_path.to_path_buf(),
206            line_start,
207            line_end,
208            column_start,
209            column_end,
210            matched_text: node.text().to_string(),
211        });
212    }
213
214    Ok(ast_matches)
215}
216
217pub fn perform_query(options: &QueryOptions) -> Result<Vec<AstMatch>> {
218    // Suppress panic output if language is not specified
219    let suppress_output = options.language.is_none();
220
221    // Set a custom panic hook to suppress panic messages if needed
222    let original_hook = if suppress_output {
223        let original_hook = std::panic::take_hook();
224        std::panic::set_hook(Box::new(|_| {
225            // Do nothing, effectively suppressing the panic message
226        }));
227        Some(original_hook)
228    } else {
229        None
230    };
231
232    // Resolve the path if it's a special format (e.g., "go:github.com/user/repo")
233    let resolved_path = if let Some(path_str) = options.path.to_str() {
234        match resolve_path(path_str) {
235            Ok(resolved_path) => {
236                if std::env::var("DEBUG").unwrap_or_default() == "1" {
237                    println!(
238                        "DEBUG: Resolved path '{}' to '{}'",
239                        path_str,
240                        resolved_path.display()
241                    );
242                }
243                resolved_path
244            }
245            Err(err) => {
246                if std::env::var("DEBUG").unwrap_or_default() == "1" {
247                    println!("DEBUG: Failed to resolve path '{path_str}': {err}");
248                }
249                // Fall back to the original path
250                options.path.to_path_buf()
251            }
252        }
253    } else {
254        // If we can't convert the path to a string, use it as is
255        options.path.to_path_buf()
256    };
257
258    // Collect file paths
259    let file_paths: Vec<PathBuf> = Walk::new(&resolved_path)
260        .filter_map(|entry| entry.ok())
261        .filter(|entry| entry.file_type().is_some_and(|ft| ft.is_file()))
262        .filter(|entry| !should_ignore_file(entry.path(), options))
263        .map(|entry| entry.path().to_path_buf())
264        .collect();
265
266    // Process files in parallel
267    let all_matches: Vec<AstMatch> = file_paths
268        .par_iter()
269        .flat_map(|path| {
270            std::panic::catch_unwind(|| query_file(path, options))
271                .unwrap_or_else(|_| {
272                    // Panic was caught, return empty results
273                    Ok(vec![])
274                })
275                .unwrap_or_else(|_| {
276                    // Error was caught, return empty results
277                    vec![]
278                })
279        })
280        .collect();
281
282    // Restore the original panic hook if we changed it
283    if let Some(hook) = original_hook {
284        std::panic::set_hook(hook);
285    }
286
287    // Apply max_results limit
288    let mut all_matches = all_matches;
289    if let Some(max) = options.max_results {
290        all_matches.truncate(max);
291    }
292
293    Ok(all_matches)
294}
295
/// Escape the five XML special characters (`&`, `<`, `>`, `"`, `'`) as predefined entities.
///
/// Every other character is copied through unchanged.
fn escape_xml(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '"' => escaped.push_str("&quot;"),
            '\'' => escaped.push_str("&apos;"),
            other => escaped.push(other),
        }
    }
    escaped
}
304
305/// Format and print the query results
306pub fn format_and_print_query_results(matches: &[AstMatch], format: &str) -> Result<()> {
307    match format {
308        "color" | "terminal" => {
309            for m in matches {
310                println!(
311                    "{}",
312                    format!(
313                        "{}:{}:{}",
314                        m.file_path.display(),
315                        m.line_start,
316                        m.column_start
317                    )
318                    .cyan()
319                );
320                println!("{}", m.matched_text.trim());
321                println!();
322            }
323        }
324        "plain" => {
325            for m in matches {
326                println!(
327                    "{}:{}:{}",
328                    m.file_path.display(),
329                    m.line_start,
330                    m.column_start
331                );
332                println!("{}", m.matched_text.trim());
333                println!();
334            }
335        }
336        "markdown" => {
337            for m in matches {
338                println!(
339                    "**{}:{}:{}**",
340                    m.file_path.display(),
341                    m.line_start,
342                    m.column_start
343                );
344
345                // Determine language for code block
346                let lang = m
347                    .file_path
348                    .extension()
349                    .and_then(|e| e.to_str())
350                    .unwrap_or("");
351
352                println!("```{lang}");
353                println!("{}", m.matched_text.trim());
354                println!("```");
355                println!();
356            }
357        }
358        "json" => {
359            // Import the count_tokens function locally
360            use probe_code::search::search_tokens::count_tokens;
361            let total_tokens = matches
362                .iter()
363                .map(|m| count_tokens(&m.matched_text))
364                .sum::<usize>();
365
366            // Create standardized results
367            let json_matches_standardized: Vec<_> = matches
368                .iter()
369                .map(|m| {
370                    serde_json::json!({
371                        "file": m.file_path.to_string_lossy(),
372                        "lines": [m.line_start, m.line_end],
373                        "node_type": "match",
374                        "content": m.matched_text,
375                        "column_start": m.column_start,
376                        "column_end": m.column_end
377                    })
378                })
379                .collect();
380
381            // Create the wrapper object
382            let wrapper = serde_json::json!({
383                "results": json_matches_standardized,
384                "summary": {
385                    "count": matches.len(),
386                    "total_bytes": matches.iter().map(|m| m.matched_text.len()).sum::<usize>(),
387                    "total_tokens": total_tokens
388                }
389            });
390
391            println!("{}", serde_json::to_string_pretty(&wrapper)?);
392        }
393        "xml" => {
394            println!("<?xml version=\"1.0\" encoding=\"UTF-8\"?>");
395            println!("<probe_results>");
396
397            for m in matches {
398                println!("  <result>");
399                println!(
400                    "    <file>{}</file>",
401                    escape_xml(&m.file_path.to_string_lossy())
402                );
403                println!("    <lines>{}-{}</lines>", m.line_start, m.line_end);
404                println!("    <node_type>match</node_type>");
405                println!("    <column_start>{}</column_start>", m.column_start);
406                println!("    <column_end>{}</column_end>", m.column_end);
407                println!("    <code><![CDATA[{}]]></code>", m.matched_text.trim());
408                println!("  </result>");
409            }
410
411            // Add summary section
412            println!("  <summary>");
413            println!("    <count>{}</count>", matches.len());
414            println!(
415                "    <total_bytes>{}",
416                matches.iter().map(|m| m.matched_text.len()).sum::<usize>()
417            );
418
419            // Import the count_tokens function locally to avoid unused import warning
420            use probe_code::search::search_tokens::count_tokens;
421            println!(
422                "    <total_tokens>{}",
423                matches
424                    .iter()
425                    .map(|m| count_tokens(&m.matched_text))
426                    .sum::<usize>()
427            );
428            println!("  </summary>");
429
430            println!("</probe_results>");
431        }
432        _ => {
433            // Default to color format
434            format_and_print_query_results(matches, "color")?;
435        }
436    }
437
438    Ok(())
439}
440
441/// Handle the query command
442pub fn handle_query(
443    pattern: &str,
444    path: &Path,
445    language: Option<&str>,
446    ignore: &[String],
447    allow_tests: bool,
448    max_results: Option<usize>,
449    format: &str,
450) -> Result<()> {
451    // Only print information for non-JSON/XML formats
452    if format != "json" && format != "xml" {
453        println!("{} {}", "Pattern:".bold().green(), pattern);
454        println!("{} {}", "Path:".bold().green(), path.display());
455
456        // Print language if provided, otherwise show auto-detect
457        if let Some(lang) = language {
458            println!("{} {}", "Language:".bold().green(), lang);
459        } else {
460            println!("{} auto-detect", "Language:".bold().green());
461        }
462
463        // Show advanced options if they differ from defaults
464        let mut advanced_options = Vec::<String>::new();
465        if allow_tests {
466            advanced_options.push("Including tests".to_string());
467        }
468        if let Some(max) = max_results {
469            advanced_options.push(format!("Max results: {max}"));
470        }
471
472        if !advanced_options.is_empty() {
473            println!(
474                "{} {}",
475                "Options:".bold().green(),
476                advanced_options.join(", ")
477            );
478        }
479    }
480
481    let start_time = Instant::now();
482
483    let options = QueryOptions {
484        path,
485        pattern,
486        language,
487        ignore,
488        allow_tests,
489        max_results,
490        format,
491    };
492
493    let matches = perform_query(&options)?;
494
495    // Calculate search time
496    let duration = start_time.elapsed();
497
498    if matches.is_empty() {
499        // For JSON and XML formats, still call format_and_print_query_results
500        if format == "json" || format == "xml" {
501            format_and_print_query_results(&matches, format)?;
502        } else {
503            // For other formats, print the "No results found" message
504            println!("{}", "No results found.".yellow().bold());
505            println!("Search completed in {duration:.2?}");
506        }
507    } else {
508        // For non-JSON/XML formats, print search time
509        if format != "json" && format != "xml" {
510            println!("Found {} matches in {:.2?}", matches.len(), duration);
511            println!();
512        }
513
514        format_and_print_query_results(&matches, format)?;
515
516        // Skip summary for JSON and XML formats
517        if format != "json" && format != "xml" {
518            // Calculate and display total bytes and tokens
519            let total_bytes: usize = matches.iter().map(|m| m.matched_text.len()).sum();
520            let total_tokens: usize = matches
521                .iter()
522                .map(|m| {
523                    // Import the count_tokens function locally to avoid unused import warning
524                    use probe_code::search::search_tokens::count_tokens;
525                    count_tokens(&m.matched_text)
526                })
527                .sum();
528
529            println!("Total bytes returned: {total_bytes}");
530            println!("Total tokens returned: {total_tokens}");
531        }
532    }
533
534    Ok(())
535}