mod completions;
mod config;
mod lang;
mod lsp;
mod new;
mod print;
mod run;
mod scan;
mod utils;
mod verify;
use anyhow::Result;
use clap::{Parser, Subcommand};
use std::{path::PathBuf, process::ExitCode};
use completions::{run_shell_completion, CompletionsArg};
use config::ProjectConfig;
use lsp::{run_language_server, LspArg};
use new::{run_create_new, NewArg};
use run::{run_with_pattern, RunArg};
use scan::{run_with_config, ScanArg};
use utils::exit_with_error;
use verify::{run_test_rule, TestArg};
/// Long help text for the CLI, passed to clap as `long_about` on [`App`].
///
/// The first line is a one-sentence summary (the CLI tests check that `--help`
/// output contains it); it is followed by an ASCII-art banner.
// NOTE(review): the banner's internal spacing looks collapsed in this copy, so the
// art may not line up when printed — confirm against the intended rendering.
const LOGO: &str = r#"
Search and Rewrite code at large scale using AST pattern.
__
____ ______/ /_ ____ _________ ____
/ __ `/ ___/ __/_____/ __ `/ ___/ _ \/ __ \
/ /_/ (__ ) /_/_____/ /_/ / / / __/ /_/ /
\__,_/____/\__/ \__, /_/ \___/ .___/
/____/ /_/
"#;
// Top-level command-line interface of `sg`.
//
// Plain `//` comments are used on this clap-derived type on purpose: clap's derive
// reads `///` doc comments on fields as help text, so adding them here would change
// the `--help` output.
#[derive(Parser)]
#[clap(author, version, about, long_about = LOGO)]
struct App {
    // The subcommand to execute; dispatched in `main_with_args`.
    #[clap(subcommand)]
    command: Commands,
    // `-c` / `--config <CONFIG_FILE>`. Marked `global` so it is also accepted after a
    // subcommand. The parsed value is never read by `main_with_args`; the raw
    // arguments are pre-scanned by `setup_project_is_possible` instead, so the
    // project can be set up before clap parses the command line.
    #[clap(short, long, global = true, value_name = "CONFIG_FILE")]
    config: Option<PathBuf>,
}
// Subcommands of `sg`, dispatched in `main_with_args`.
//
// Plain `//` comments are used deliberately: clap's derive reads `///` doc comments
// on variants as subcommand help text, so adding them would change `--help` output.
#[derive(Subcommand)]
enum Commands {
    // `sg run`: handled by `run_with_pattern`. Also reached without the `run` keyword
    // when `try_default_run` detects the implicit form (e.g. `sg -p 'Some($A)'`).
    Run(RunArg),
    // `sg scan`: handled by `run_with_config`.
    Scan(ScanArg),
    // `sg test`: handled by `run_test_rule`.
    Test(TestArg),
    // `sg new`: handled by `run_create_new`.
    New(NewArg),
    // `sg lsp`: handled by `run_language_server`; success maps to `ExitCode::SUCCESS`.
    Lsp(LspArg),
    // `sg completions`: handled by `run_shell_completion::<App>`; takes no project.
    Completions(CompletionsArg),
    // Only compiled in debug builds; selecting it currently panics via `todo!`.
    #[cfg(debug_assertions)]
    Docs,
}
/// Runs the CLI with this process's command-line arguments.
///
/// A successful exit code from [`main_with_args`] is returned as-is. Any error is
/// handed to `exit_with_error`, and its result is returned in place of the error.
pub fn execute_main() -> Result<ExitCode> {
    let process_args = std::env::args();
    match main_with_args(process_args) {
        Ok(exit_code) => Ok(exit_code),
        // Errors are not propagated directly; `exit_with_error` decides how to report them.
        Err(failure) => exit_with_error(failure),
    }
}
/// Returns `true` when the raw command-line token `arg` is the flag named `command`.
///
/// The long form is `--<command>`. The short form is a single `-` followed by exactly
/// the first character of `command`, so `"pattern"` matches `-p`. An attached value
/// (`--config=foo`, `-c=foo`) is ignored when matching.
///
/// Positional tokens (no leading `-`) never match, and an empty `command` never matches.
fn is_command(arg: &str, command: &str) -> bool {
    // Only the part before the first `=` names the flag.
    let flag = arg.split('=').next().unwrap_or(arg);
    if let Some(long) = flag.strip_prefix("--") {
        // Strip exactly one `--`, so `----pattern` is not mistaken for `--pattern`.
        long == command
    } else if let Some(short) = flag.strip_prefix('-') {
        // Compare by `char` instead of slicing `command[..1]`. That slice panics on an
        // empty `command` or on one whose first character is multi-byte.
        let Some(initial) = command.chars().next() else {
            return false;
        };
        let mut rest = short.chars();
        rest.next() == Some(initial) && rest.next().is_none()
    } else {
        false
    }
}
/// Detects the implicit `run` form, where no subcommand is given (e.g.
/// `sg -p 'Some($A)' -l rs`).
///
/// The implicit form applies when the first argument after the program name starts
/// with `-` and some argument is `-p`/`--pattern`. In that case the whole argument
/// list is parsed as a [`RunArg`]. Otherwise `Ok(None)` is returned and the caller
/// falls back to regular subcommand parsing.
///
/// # Errors
/// Returns the parse error if the implicit `run` arguments are invalid.
fn try_default_run(args: &[String]) -> Result<Option<RunArg>> {
    // Everything after the program name; empty when no arguments were given.
    let user_args: &[String] = args.get(1..).unwrap_or_default();
    let has_pattern = user_args.iter().any(|token| is_command(token, "pattern"));
    let leads_with_flag = matches!(user_args.first(), Some(first) if first.starts_with('-'));
    if !(has_pattern && leads_with_flag) {
        return Ok(None);
    }
    let run_arg = RunArg::try_parse_from(args)?;
    Ok(Some(run_arg))
}
/// Pre-scans the raw arguments for `-c`/`--config` so the project can be set up
/// before clap parses the command line.
///
/// Both the spaced form (`--config FILE`, `-c FILE`) and the attached form
/// (`--config=FILE`, `-c=FILE`) are recognized. In the attached form everything after
/// the first `=` is the file, so paths containing `=` are kept intact. If the flag
/// appears more than once, the last occurrence wins for both forms. With no flag,
/// `ProjectConfig::setup(None)` is used.
///
/// # Errors
/// The outer `Result` is `Err` when a spaced `-c`/`--config` is the last argument or
/// is immediately followed by another flag. Any outer error from
/// `ProjectConfig::setup` is also returned. The inner `Result` is passed through
/// unchanged from `ProjectConfig::setup`.
fn setup_project_is_possible(args: &[String]) -> Result<Result<ProjectConfig>> {
    let mut config = None;
    for (index, arg) in args.iter().enumerate() {
        if !is_command(arg, "config") {
            continue;
        }
        let config_file = match arg.split_once('=') {
            // `split_once` keeps `a=b.yml` whole, where `split('=').nth(1)` would give `a`.
            Some((_, value)) => value,
            None => match args.get(index + 1) {
                Some(next) if !next.starts_with('-') => next.as_str(),
                _ => return Err(anyhow::anyhow!("missing config file after -c")),
            },
        };
        // No `break`: a later occurrence overrides an earlier one in both forms.
        config = Some(config_file.into());
    }
    ProjectConfig::setup(config)
}
pub fn main_with_args(args: impl Iterator<Item = String>) -> Result<ExitCode> {
let args: Vec<_> = args.collect();
let project = setup_project_is_possible(&args);
if let Some(arg) = try_default_run(&args)? {
return run_with_pattern(arg, project?);
}
let app = App::try_parse_from(args)?;
let project = project?; match app.command {
Commands::Run(arg) => run_with_pattern(arg, project),
Commands::Scan(arg) => run_with_config(arg, project),
Commands::Test(arg) => run_test_rule(arg, project),
Commands::New(arg) => run_create_new(arg, project),
Commands::Lsp(arg) => run_language_server(arg, project).map(|_| ExitCode::SUCCESS),
Commands::Completions(arg) => run_shell_completion::<App>(arg),
#[cfg(debug_assertions)]
Commands::Docs => todo!("todo, generate rule docs based on current config"),
}
}
#[cfg(test)]
mod test_cli {
    use super::*;

    /// Parses `sg <args>` as the top-level `App`; `args` is split on single spaces.
    fn sg(args: &str) -> Result<App> {
        let app = App::try_parse_from(
            std::iter::once("sg".into()).chain(args.split(' ').map(|s| s.to_string())),
        )?;
        Ok(app)
    }

    /// Asserts that `sg <args>` parses, naming the arguments on failure.
    fn ok(args: &str) -> App {
        sg(args).unwrap_or_else(|err| panic!("`sg {args}` should parse: {err}"))
    }

    /// Asserts that `sg <args>` fails to parse with a clap error and returns that error.
    fn error(args: &str) -> clap::Error {
        let Err(err) = sg(args) else {
            panic!("app parsing should fail for `sg {args}`!")
        };
        err.downcast::<clap::Error>()
            .expect("should have clap::Error")
    }

    /// Builds the raw argument vector for `sg <args>`, split on single spaces.
    fn argv(args: &str) -> Vec<String> {
        std::iter::once("sg".into())
            .chain(args.split(' ').map(|s| s.to_string()))
            .collect()
    }

    /// Asserts that `sg <args>` is handled by the implicit `run` command.
    fn default_run(args: &str) {
        assert!(
            matches!(try_default_run(&argv(args)), Ok(Some(_))),
            "`sg {args}` should use the default run command"
        );
    }

    /// Asserts that `sg <args>` is NOT handled by the implicit `run` command.
    fn no_default_run(args: &str) {
        assert!(
            matches!(try_default_run(&argv(args)), Ok(None)),
            "`sg {args}` should not use the default run command"
        );
    }

    #[test]
    fn test_wrong_usage() {
        error("");
        error("Some($A) -l rs");
        error("-l rs");
    }

    #[test]
    fn test_version_and_help() {
        let version = error("--version");
        assert!(version.to_string().starts_with("ast-grep"));
        let version = error("-V");
        assert!(version.to_string().starts_with("ast-grep"));
        let help = error("--help");
        assert!(help.to_string().contains("Search and Rewrite code"));
    }

    #[test]
    fn test_no_arg_run() {
        let ret = main_with_args(["sg".to_owned()].into_iter());
        let err = ret.unwrap_err();
        assert!(err.to_string().contains("sg [OPTIONS] <COMMAND>"));
    }

    #[test]
    fn test_is_command() {
        assert!(is_command("--pattern", "pattern"));
        assert!(is_command("--pattern=Some($A)", "pattern"));
        assert!(is_command("-p", "pattern"));
        assert!(is_command("-c=sgconfig.yml", "config"));
        assert!(!is_command("--pat", "pattern"));
        assert!(!is_command("-pattern", "pattern"));
        assert!(!is_command("-l", "pattern"));
        assert!(!is_command("pattern", "pattern"));
    }

    #[test]
    fn test_default_subcommand() {
        default_run("-p Some($A) -l rs");
        default_run("-p Some($A)");
        default_run("-p Some($A) -l rs -r $A.unwrap()");
    }

    #[test]
    fn test_no_default_subcommand() {
        // An explicit subcommand comes first.
        no_default_run("run -p Some($A)");
        // No pattern flag at all.
        no_default_run("-l rs");
        no_default_run("scan");
        // Only the program name.
        assert!(matches!(try_default_run(&["sg".to_string()]), Ok(None)));
    }

    #[test]
    fn test_run() {
        ok("run -p test -i");
        ok("run -p test --interactive dir");
        ok("run -p test -r Test dir");
        ok("run -p test -l rs --debug-query");
        ok("run -p test -l rs --debug-query not");
        ok("run -p test -l rs --debug-query=ast");
        ok("run -p test -l rs --debug-query=cst");
        ok("run -p test -l rs --color always");
        ok("run -p test -l rs --heading always");
        ok("run -p test dir1 dir2 dir3");
        ok("run -p testm -r restm -U");
        ok("run -p testm -r restm --update-all");
        ok("run -p test --json compact");
        ok("run -p test --json=pretty dir");
        ok("run -p test --json dir");
        ok("run -p test --strictness ast");
        ok("run -p test --strictness relaxed");
        ok("run -p test --selector identifier");
        ok("run -p test --selector identifier -l js");
        ok("run -p test --follow");
        ok("run -p test --globs '*.js'");
        ok("run -p test --globs '*.{js, ts}'");
        ok("run -p test --globs '*.js' --globs '*.ts'");
        ok("run -p fubuki -j8");
        ok("run -p test --threads 12");
        ok("run -p test --files-with-matches");
        ok("run -p test -l rs -c config.yml");
        error("run test");
        error("run --debug-query test");
        error("run -r Test dir");
        error("run -p test -i --json dir");
        error("run -p test -U");
        error("run -p test --update-all");
        error("run -p test --strictness not");
        error("run -p test -l rs --debug-query=not");
        error("run -p test --selector");
        error("run -p test --threads");
        error("run -p test --files-with-matches -r test -U");
        error("run -p test --files-with-matches --json");
    }

    #[test]
    fn test_scan() {
        ok("scan");
        ok("scan dir");
        ok("scan -r test-rule.yml dir");
        ok("scan -c test-rule.yml dir");
        ok("scan -c test-rule.yml");
        ok("scan --report-style short");
        ok("scan dir1 dir2 dir3");
        ok("scan -r test.yml --format github");
        ok("scan --format github");
        ok("scan --interactive");
        ok("scan --follow");
        ok("scan --json --include-metadata");
        ok("scan -r test.yml -c test.yml --json dir");
        ok("scan --globs '*.js'");
        ok("scan --globs '*.{js, ts}'");
        ok("scan --globs '*.js' --globs '*.ts'");
        ok("scan -j 12");
        ok("scan --threads 12");
        ok("scan -A 12");
        ok("scan --after 12");
        ok("scan --context 1");
        ok("scan --max-results=10");
        ok("scan --max-results 5");
        error("scan --max-results=10 --interactive");
        error("scan -i --json dir");
        error("scan --report-style rich --json dir");
        error("scan -r test.yml --inline-rules '{}'");
        error("scan --format gitlab");
        error("scan --format github -i");
        error("scan --format local");
        error("scan --json=dir");
        error("scan --json= not-pretty");
        error("scan -j");
        error("scan --include-metadata");
        error("scan --threads");
        error("scan --files-with-matches -U");
        error("scan --files-with-matches --json");
    }

    #[test]
    fn test_test() {
        ok("test");
        ok("test -c sgconfig.yml");
        ok("test --skip-snapshot-tests");
        ok("test -U");
        ok("test --update-all");
        error("test --update-all --skip-snapshot-tests");
    }

    #[test]
    fn test_new() {
        ok("new");
        ok("new project");
        ok("new -c sgconfig.yml rule");
        ok("new rule -y");
        ok("new test -y");
        ok("new util -y");
        ok("new rule -c sgconfig.yml");
        error("new --base-dir");
    }

    #[test]
    fn test_shell() {
        ok("completions");
        ok("completions zsh");
        ok("completions fish");
        error("completions not-shell");
        error("completions --shell fish");
    }
}