use std::path::{Path, PathBuf};
use std::process::ExitCode;
use anyhow::{Context, Result};
use ast_grep_config::{from_yaml_string, CombinedScan, RuleCollection, RuleConfig, Severity};
use ast_grep_core::{tree_sitter::StrDoc, NodeMatch};
use ast_grep_language::SupportLang;
use clap::Args;
use ignore::WalkParallel;
use crate::config::{read_rule_file, with_rule_stats, ProjectConfig};
use crate::lang::SgLang;
use crate::print::{
CloudPrinter, ColoredPrinter, Diff, FileNamePrinter, InteractivePrinter, JSONPrinter, Platform,
PrintProcessor, Printer, ReportStyle, SimpleFile,
};
use crate::utils::RuleOverwrite;
use crate::utils::{filter_file_rule, ContextArgs, InputArgs, OutputArgs, OverwriteArgs};
use crate::utils::{ErrorContext as EC, MaxItemCounter};
use crate::utils::{FileTrace, ScanTrace};
use crate::utils::{Items, PathWorker, StdInWorker, Worker};
use std::collections::HashSet;
use std::sync::atomic::{AtomicUsize, Ordering};
#[derive(Args)]
pub struct ScanArg {
#[clap(short, long, value_name = "RULE_FILE")]
rule: Option<PathBuf>,
#[clap(long, conflicts_with = "rule", value_name = "RULE_TEXT")]
inline_rules: Option<String>,
#[clap(long, conflicts_with = "json", conflicts_with = "interactive")]
format: Option<Platform>,
#[clap(long, default_value = "rich", conflicts_with = "json")]
report_style: ReportStyle,
#[clap(long, requires = "json")]
include_metadata: bool,
#[clap(flatten)]
overwrite: OverwriteArgs,
#[clap(flatten)]
input: InputArgs,
#[clap(flatten)]
output: OutputArgs,
#[clap(flatten)]
context: ContextArgs,
#[clap(long, conflicts_with = "interactive", value_name = "NUM")]
max_results: Option<u16>,
}
impl ScanArg {
fn include_all_rules(&self) -> bool {
self.overwrite.include_all_rules() && self.rule.is_none() && self.inline_rules.is_none()
}
}
pub fn run_with_config(arg: ScanArg, project: Result<ProjectConfig>) -> Result<ExitCode> {
let project_trace = arg.output.inspect.project_trace();
project_trace.print_project(&project)?;
let context = arg.context.get();
if arg.output.files_with_matches {
let printer = FileNamePrinter::stdout(arg.output.color);
return run_scan(arg, printer, project);
}
if let Some(format) = &arg.format {
let printer = CloudPrinter::stdout(format.clone());
return run_scan(arg, printer, project);
}
if let Some(json) = arg.output.json {
let printer = JSONPrinter::stdout(json).include_metadata(arg.include_metadata);
return run_scan(arg, printer, project);
}
let printer = ColoredPrinter::stdout(arg.output.color)
.style(arg.report_style)
.context(context);
let interactive = arg.output.needs_interactive();
if interactive {
let from_stdin = arg.input.stdin;
let printer = InteractivePrinter::new(printer, arg.output.update_all, from_stdin)?;
run_scan(arg, printer, project)
} else {
run_scan(arg, printer, project)
}
}
fn run_scan<P: Printer + 'static>(
arg: ScanArg,
printer: P,
project: Result<ProjectConfig>,
) -> Result<ExitCode> {
if arg.input.stdin {
let worker = ScanStdin::try_new(arg)?;
worker.run_std_in(printer)
} else {
let worker = ScanWithConfig::try_new(arg, project)?;
worker.run_path(printer)
}
}
struct ScanWithConfig {
arg: ScanArg,
configs: RuleCollection<SgLang>,
unused_suppression_rule: RuleConfig<SgLang>,
trace: ScanTrace,
proj_dir: PathBuf,
error_count: AtomicUsize,
max_item_counter: Option<MaxItemCounter>,
}
impl ScanWithConfig {
fn try_new(arg: ScanArg, project: Result<ProjectConfig>) -> Result<Self> {
let overwrite = RuleOverwrite::new(&arg.overwrite)?;
let unused_suppression_rule = unused_suppression_rule_config(&arg, &overwrite);
let mut proj_dir = PathBuf::from(".");
let (configs, rule_trace) = if let Some(path) = &arg.rule {
let rules = read_rule_file(path, None)?;
proj_dir = path.parent().unwrap_or(Path::new(".")).to_path_buf();
with_rule_stats(rules)?
} else if let Some(text) = &arg.inline_rules {
let rules = from_yaml_string(text, &Default::default())
.with_context(|| EC::ParseRule("INLINE_RULES".into()))?;
with_rule_stats(rules)?
} else {
let project_config = project?;
proj_dir = project_config.project_dir.clone();
project_config.find_rules(overwrite)?
};
let trace = arg.output.inspect.scan_trace(rule_trace);
trace.print_rules(&configs)?;
let absolute_proj_dir = proj_dir
.canonicalize()
.or_else(|_| std::env::current_dir())?;
let max_item_counter = arg.max_results.map(MaxItemCounter::new);
Ok(Self {
arg,
configs,
unused_suppression_rule,
trace,
proj_dir: absolute_proj_dir,
error_count: AtomicUsize::new(0),
max_item_counter,
})
}
}
impl Worker for ScanWithConfig {
fn consume_items<P: Printer>(
&self,
items: Items<P::Processed>,
mut printer: P,
) -> Result<ExitCode> {
printer.before_print()?;
for item in items {
printer.process(item)?;
}
printer.after_print()?;
self.trace.print()?;
let error_count = self.error_count.load(Ordering::Acquire);
if error_count > 0 {
Err(anyhow::anyhow!(EC::DiagnosticError(error_count)))
} else {
Ok(ExitCode::SUCCESS)
}
}
}
fn default_unused_suppression_rule_severity(arg: &ScanArg) -> Severity {
if arg.include_all_rules() {
Severity::Hint
} else {
Severity::Off
}
}
fn unused_suppression_rule_config(arg: &ScanArg, overwrite: &RuleOverwrite) -> RuleConfig<SgLang> {
let severity = overwrite
.find("unused-suppression")
.severity
.unwrap_or_else(|| default_unused_suppression_rule_severity(arg));
CombinedScan::unused_config(severity, SupportLang::Rust.into())
}
impl PathWorker for ScanWithConfig {
fn get_trace(&self) -> &FileTrace {
&self.trace.inner.file_trace
}
fn build_walk(&self) -> Result<WalkParallel> {
let mut langs = HashSet::new();
self.configs.for_each_rule(|rule| {
langs.insert(rule.language);
});
self.arg.input.walk_langs(langs.into_iter())
}
fn produce_item<P: Printer>(
&self,
path: &Path,
processor: &P::Processor,
) -> Result<Vec<P::Processed>> {
let items = filter_file_rule(path, &self.configs, &self.trace)?;
let mut error_count = 0usize;
let mut ret = vec![];
for grep in items {
let file_content = grep.source();
let abs_path = path.canonicalize()?;
let normalized_path = abs_path.strip_prefix(&self.proj_dir).unwrap_or(path);
let rules = self
.configs
.get_rule_from_lang(normalized_path, *grep.lang());
let mut combined = CombinedScan::new(rules);
combined.set_unused_suppression_rule(&self.unused_suppression_rule);
let interactive = self.arg.output.needs_interactive();
let scanned = combined.scan(&grep, interactive);
if interactive {
let diffs = scanned.diffs;
let processed = match_rule_diff_on_file(path, diffs, processor)?;
ret.push(processed);
}
for (rule, matches) in scanned.matches {
let matches: Vec<_> = if let Some(counter) = &self.max_item_counter {
let wanted = matches.len();
let claimed = counter.claim(wanted);
if claimed == 0 {
break;
}
matches.into_iter().take(claimed).collect()
} else {
matches
};
if matches.is_empty() {
continue;
}
let match_count = matches.len();
if matches!(rule.severity, Severity::Error) {
error_count = error_count.saturating_add(match_count);
}
let processed = match_rule_on_file(path, matches, rule, file_content, processor)?;
ret.push(processed);
}
}
self.error_count.fetch_add(error_count, Ordering::AcqRel);
Ok(ret)
}
fn should_stop(&self) -> bool {
match &self.max_item_counter {
Some(max) => max.reached_max(),
None => false,
}
}
}
struct ScanStdin {
rules: Vec<RuleConfig<SgLang>>,
error_count: AtomicUsize,
max_diagnostics_shown: Option<usize>,
}
impl ScanStdin {
fn try_new(arg: ScanArg) -> Result<Self> {
let rules = if let Some(path) = &arg.rule {
read_rule_file(path, None)?
} else if let Some(text) = &arg.inline_rules {
from_yaml_string(text, &Default::default())
.with_context(|| EC::ParseRule("INLINE_RULES".into()))?
} else {
return Err(anyhow::anyhow!(EC::RuleNotSpecified));
};
Ok(Self {
rules,
error_count: AtomicUsize::new(0),
max_diagnostics_shown: arg.max_results.map(usize::from),
})
}
}
impl Worker for ScanStdin {
fn consume_items<P: Printer>(
&self,
items: Items<P::Processed>,
mut printer: P,
) -> Result<ExitCode> {
printer.before_print()?;
for item in items {
printer.process(item)?;
}
printer.after_print()?;
let error_count = self.error_count.load(Ordering::Acquire);
if error_count > 0 {
Err(anyhow::anyhow!(EC::DiagnosticError(error_count)))
} else {
Ok(ExitCode::SUCCESS)
}
}
}
impl StdInWorker for ScanStdin {
fn parse_stdin<P: Printer>(
&self,
src: String,
processor: &P::Processor,
) -> Result<Vec<P::Processed>> {
use ast_grep_core::tree_sitter::LanguageExt;
let lang = self.rules[0].language;
let combined = CombinedScan::new(self.rules.iter().collect());
let grep = lang.ast_grep(src);
let path = Path::new("STDIN");
let file_content = grep.source();
let scanned = combined.scan(&grep, false);
let mut error_count = 0usize;
let mut diagnostic_count = 0usize;
let mut ret = vec![];
for (rule, matches) in scanned.matches {
let matches: Vec<_> = if let Some(max) = self.max_diagnostics_shown {
let remaining = max.saturating_sub(diagnostic_count);
if remaining == 0 {
break;
}
matches.into_iter().take(remaining).collect()
} else {
matches
};
if matches.is_empty() {
continue;
}
let match_count = matches.len();
diagnostic_count += match_count;
if matches!(rule.severity, Severity::Error) {
error_count = error_count.saturating_add(match_count);
}
let processed = match_rule_on_file(path, matches, rule, file_content, processor)?;
ret.push(processed);
}
self.error_count.fetch_add(error_count, Ordering::AcqRel);
Ok(ret)
}
}
fn match_rule_diff_on_file<T>(
path: &Path,
matches: Vec<(&RuleConfig<SgLang>, NodeMatch<StrDoc<SgLang>>)>,
processor: &impl PrintProcessor<T>,
) -> Result<T> {
let diffs = matches
.into_iter()
.filter_map(|(rule, m)| {
let fixers = &rule.matcher.fixer;
let diff = Diff::multiple(m, &rule.matcher, fixers)?;
Some((diff, rule))
})
.collect();
let processed = processor.print_rule_diffs(diffs, path)?;
Ok(processed)
}
fn match_rule_on_file<T>(
path: &Path,
matches: Vec<NodeMatch<StrDoc<SgLang>>>,
rule: &RuleConfig<SgLang>,
file_content: &str,
processor: &impl PrintProcessor<T>,
) -> Result<T> {
let file = SimpleFile::new(path.to_string_lossy(), file_content);
let processed = if let Some(fixer) = &rule.matcher.fixer.first() {
let diffs = matches
.into_iter()
.map(|m| (Diff::generate(m, &rule.matcher, fixer), rule))
.collect();
processor.print_rule_diffs(diffs, path)?
} else {
processor.print_rule(matches, file, rule)?
};
Ok(processed)
}
#[cfg(test)]
mod test {
use super::*;
use crate::print::ColorArg;
use std::fs::File;
use std::io::Write;
use tempfile::TempDir;
const RULE: &str = r#"
id: test
message: Add your rule message here....
severity: error # error, warning, hint, info
language: Rust
rule:
pattern: Some(123)
"#;
pub fn create_test_files<'a>(
names_and_contents: impl IntoIterator<Item = (&'a str, &'a str)>,
) -> TempDir {
let dir = TempDir::new().unwrap();
for (name, contents) in names_and_contents {
let path = dir.path().join(name);
let mut file = File::create(path.clone()).unwrap();
file.write_all(contents.as_bytes()).unwrap();
file.sync_all().unwrap();
}
dir
}
fn default_scan_arg() -> ScanArg {
ScanArg {
rule: None,
inline_rules: None,
report_style: ReportStyle::Rich,
include_metadata: false,
input: InputArgs {
no_ignore: vec![],
paths: vec![PathBuf::from(".")],
stdin: false,
follow: false,
globs: vec![],
threads: 0,
},
overwrite: OverwriteArgs {
filter: None,
error: None,
warning: None,
info: None,
hint: None,
off: None,
},
output: OutputArgs {
interactive: false,
json: None,
files_with_matches: false,
update_all: false,
color: ColorArg::Never,
inspect: Default::default(),
},
context: ContextArgs {
before: 0,
after: 0,
context: 0,
},
format: None,
max_results: None,
}
}
#[test]
fn test_run_with_config() {
let dir = create_test_files([("sgconfig.yml", "ruleDirs: [rules]")]);
std::fs::create_dir_all(dir.path().join("rules")).unwrap();
let mut file = File::create(dir.path().join("rules/test.yml")).unwrap();
file.write_all(RULE.as_bytes()).unwrap();
let mut file = File::create(dir.path().join("test.rs")).unwrap();
file
.write_all("fn test() { Some(123) }".as_bytes())
.unwrap();
file.sync_all().unwrap();
let project_config = ProjectConfig::setup(Some(dir.path().join("sgconfig.yml"))).unwrap();
let arg = default_scan_arg();
assert!(run_with_config(arg, project_config).is_ok());
}
#[test]
fn test_scan_with_inline_rules() {
let inline_rules = "{id: test, language: ts, rule: {pattern: readFileSync}}".to_string();
let arg = ScanArg {
inline_rules: Some(inline_rules),
..default_scan_arg()
};
assert!(run_with_config(arg, Err(anyhow::anyhow!("not found"))).is_ok());
}
#[test]
fn test_scan_with_inline_rules_diff() {
let inline_rules =
"{id: test, language: ts, rule: {pattern: readFileSync}, fix: 'nnn'}".to_string();
let arg = ScanArg {
inline_rules: Some(inline_rules),
..default_scan_arg()
};
assert!(run_with_config(arg, Err(anyhow::anyhow!("not found"))).is_ok());
}
#[test]
fn test_scan_with_inline_rules_error() {
let inline_rules = "nonsense".to_string();
let arg = ScanArg {
inline_rules: Some(inline_rules),
..default_scan_arg()
};
let err = run_with_config(arg, Err(anyhow::anyhow!("not found"))).expect_err("should error");
assert!(err.is::<EC>());
assert_eq!(err.to_string(), "Cannot parse rule INLINE_RULES");
}
}