mod case_result;
mod find_file;
mod reporter;
mod snapshot;
mod test_case;
use crate::config::ProjectConfig;
use crate::lang::SgLang;
use crate::print::ColorArg;
use crate::utils::{ErrorContext, RuleOverwrite};
use crate::verify::reporter::TestReportStyle;
use anyhow::{anyhow, Result};
use ast_grep_config::RuleCollection;
use clap::Args;
use regex::Regex;
use serde_yaml::to_string;
use std::collections::HashMap;
use std::io::Write;
use std::path::PathBuf;
use std::process::ExitCode;
use std::sync::{Arc, Mutex};
use std::thread;
use case_result::{CaseResult, CaseStatus};
use find_file::TestHarness;
use reporter::{DefaultReporter, InteractiveReporter, Reporter};
use snapshot::{SnapshotAction, SnapshotCollection, TestSnapshots};
use test_case::TestCase;
/// Runs `filter_mapper` over every element of `cases` on scoped worker threads and
/// gathers the `Some` results, keeping the original element order.
///
/// The slice is cut into contiguous chunks, with at most one chunk per worker. The
/// worker count is `available_parallelism` capped at 12, falling back to 1 when it
/// cannot be determined. A panic inside a worker is re-raised on the calling thread.
fn parallel_collect<'a, T, R, F>(cases: &'a [T], filter_mapper: F) -> Vec<R>
where
    T: Sync,
    R: Send,
    F: FnMut(&'a T) -> Option<R> + Send + Copy,
{
    let worker_count = thread::available_parallelism()
        .map(|n| n.get())
        .unwrap_or(1)
        .min(12);
    // Always >= 1, so `chunks` never panics, even for an empty slice.
    let per_worker = (cases.len() + worker_count) / worker_count;
    thread::scope(|scope| {
        // Spawn every worker before joining any of them so they run concurrently.
        let handles: Vec<_> = cases
            .chunks(per_worker)
            .map(|chunk| {
                scope.spawn(move || chunk.iter().filter_map(filter_mapper).collect::<Vec<R>>())
            })
            .collect();
        let mut collected = Vec::new();
        for handle in handles {
            collected.extend(handle.join().unwrap());
        }
        collected
    })
}
/// Runs every rule test case with `reporter` and turns the outcome into an exit code.
///
/// Test cases come from `--test-dir` when it is given, otherwise from the project
/// configuration. Cases are verified in parallel. A case whose id has no matching
/// rule is announced as "Configuration not found!" and left out of the results.
///
/// # Errors
/// Fails when rules or test cases cannot be loaded, when reporting or snapshot
/// writing fails, or when the tests fail (`ErrorContext::TestFail` /
/// `ErrorContext::TestSnapshotMismatch`).
fn run_test_rule_impl<R: Reporter + Send>(
    arg: TestArg,
    reporter: R,
    project: ProjectConfig,
) -> Result<ExitCode> {
    let filter = arg.filter.as_ref();
    let rule_overwrite = RuleOverwrite::new_for_verify(filter, arg.include_off);
    let collections = &project.find_rules(rule_overwrite)?.0;
    // An explicit test directory takes precedence over tests declared in the project.
    let harness = match arg.test_dir {
        Some(test_dirname) => {
            TestHarness::from_dir(&test_dirname, arg.snapshot_dir.as_deref(), filter)?
        }
        None => TestHarness::from_config(project, filter)?,
    };
    let TestHarness {
        test_cases,
        snapshots,
        path_map,
    } = harness;
    let snapshots = if arg.skip_snapshot_tests {
        None
    } else {
        Some(snapshots)
    };
    // Shared with the worker threads, which only write to it for cases whose rule
    // could not be found.
    let shared_reporter = &Arc::new(Mutex::new(reporter));
    // The guard is a temporary of this statement and is released at the semicolon.
    shared_reporter.lock().unwrap().before_report(&test_cases)?;
    let check_one_case = |case| {
        let outcome = verify_test_case_simple(case, collections, snapshots.as_ref());
        if outcome.is_none() {
            let mut guard = shared_reporter.lock().unwrap();
            writeln!(guard.get_output(), "Configuration not found! {}", case.id).unwrap();
        }
        outcome
    };
    let mut results = parallel_collect(&test_cases, check_one_case);
    // Every worker has been joined, so this lock is held for the rest of the run.
    let mut reporter = shared_reporter.lock().unwrap();
    reporter.report_failed_cases(&mut results)?;
    let snapshot_action = reporter.collect_snapshot_action();
    apply_snapshot_action(snapshot_action, &results, snapshots, path_map)?;
    reporter.report_summaries(&results)?;
    let verdict = reporter.after_report(&results)?;
    match verdict {
        reporter::TestResult::Success { message } => {
            writeln!(reporter.get_output(), "{message}")?;
            Ok(ExitCode::SUCCESS)
        }
        reporter::TestResult::RuleFail { message } => {
            Err(anyhow!(ErrorContext::TestFail(message)))
        }
        reporter::TestResult::MismatchSnapshotOnly { message } => {
            Err(anyhow!(ErrorContext::TestSnapshotMismatch(message)))
        }
    }
}
/// Persists the snapshot changes selected by the reporter's `action`.
///
/// Does nothing when snapshot testing is disabled (`snapshots` is `None`) or when
/// the action yields no merged collection. Otherwise the merged snapshots are
/// written to the directories listed in `path_map`.
fn apply_snapshot_action(
    action: SnapshotAction,
    results: &[CaseResult],
    snapshots: Option<SnapshotCollection>,
    path_map: HashMap<String, PathBuf>,
) -> Result<()> {
    let merged =
        snapshots.and_then(|collection| action.update_snapshot_collection(collection, results));
    match merged {
        Some(merged) => write_merged_to_disk(merged, path_map),
        None => Ok(()),
    }
}
fn write_merged_to_disk(
merged: SnapshotCollection,
path_map: HashMap<String, PathBuf>,
) -> Result<()> {
for (id, snaps) in merged {
let path = &path_map[&id];
if !path.exists() {
std::fs::create_dir(path)?;
}
let file = path.join(format!("{id}-snapshot.yml"));
std::fs::write(file, to_string(&snaps)?)?;
}
Ok(())
}
/// Verifies one test case against the rule that has the same id.
///
/// Returns `None` when `rules` contains no rule for `test_case.id`. When
/// `snapshots` is provided, the case is checked with `verify_with_snapshot`,
/// passing the stored snapshot for this id if there is one. Otherwise it is
/// checked with `verify_rule` alone.
fn verify_test_case_simple<'a>(
    test_case: &'a TestCase,
    rules: &RuleCollection<SgLang>,
    snapshots: Option<&SnapshotCollection>,
) -> Option<CaseResult<'a>> {
    let id = &test_case.id;
    let rule_config = rules.get_rule(id)?;
    let result = match snapshots {
        Some(collection) => test_case.verify_with_snapshot(rule_config, collection.get(id)),
        None => test_case.verify_rule(rule_config),
    };
    Some(result)
}
// Command-line arguments of the rule test command. Field doc comments (`///`)
// become clap's `--help` text; a struct-level `///` is deliberately omitted so the
// subcommand description defined elsewhere is not overridden.
#[derive(Args)]
pub struct TestArg {
    /// Directory to load rule test cases from, instead of the project configuration.
    #[clap(short, long)]
    test_dir: Option<PathBuf>,
    /// Directory name for snapshot files. Only used together with --test-dir.
    #[clap(long)]
    snapshot_dir: Option<PathBuf>,
    /// Skip snapshot tests and verify test cases against rules only.
    /// Conflicts with --update-all.
    #[clap(long, conflicts_with = "update_all")]
    skip_snapshot_tests: bool,
    /// Update all changed snapshots. Ignored with --interactive.
    /// Conflicts with --skip-snapshot-tests.
    #[clap(short = 'U', long)]
    update_all: bool,
    /// Review test results interactively.
    #[clap(short, long)]
    interactive: bool,
    /// Only run tests of rules matching REGEX.
    #[clap(short, long, value_name = "REGEX")]
    filter: Option<Regex>,
    /// Include rules with severity `off`.
    #[clap(long)]
    include_off: bool,
    /// When to use colored output.
    #[clap(long, default_value = "auto", value_name = "WHEN")]
    color: ColorArg,
}
/// Entry point for running rule tests.
///
/// Uses the interactive reporter when `--interactive` is set and the default
/// reporter otherwise. Both write to stdout, colored according to `--color`.
///
/// # Errors
/// Propagates a failed `project` load and any error, including test failures,
/// returned by the test run.
pub fn run_test_rule(arg: TestArg, project: Result<ProjectConfig>) -> Result<ExitCode> {
    let project = project?;
    let use_color = arg.color.should_use_color();
    let style = if use_color {
        TestReportStyle::colored()
    } else {
        TestReportStyle::default()
    };
    if !arg.interactive {
        let reporter = DefaultReporter {
            output: std::io::stdout(),
            update_all: arg.update_all,
            style,
        };
        return run_test_rule_impl(arg, reporter, project);
    }
    let reporter = InteractiveReporter {
        output: std::io::stdout(),
        should_accept_all: false,
        style,
    };
    run_test_rule_impl(arg, reporter, project)
}
#[cfg(test)]
pub mod test {
    use super::*;
    use ast_grep_config::{from_str, GlobalRules, RuleConfig};

    /// Rule id used by every fixture rule and test case in this module.
    pub const TEST_RULE: &str = "test-rule";

    /// Builds the YAML for a TypeScript rule with id [`TEST_RULE`] whose `rule:`
    /// body is `rule`.
    ///
    /// Only the first line of `rule` is indented under `rule:`; any further lines
    /// end up at the top level of the document (this is how `TRANSFORM_TEXT`
    /// adds `transform`/`fix`).
    fn get_rule_text(rule: &str) -> String {
        format!(
            "
id: {TEST_RULE}
message: test
severity: hint
language: TypeScript
rule:
  {rule}
"
        )
    }

    /// Parses the YAML produced by `get_rule_text(rule)` into a `RuleConfig`
    /// using default globals. Panics if the YAML or the rule is invalid.
    pub fn get_rule_config(rule: &str) -> RuleConfig<SgLang> {
        let globals = GlobalRules::default();
        let inner = from_str(&get_rule_text(rule)).unwrap();
        RuleConfig::try_from(inner, &globals).unwrap()
    }

    /// Wraps a single rule config in a `RuleCollection`.
    fn single_rule_collection(rule: RuleConfig<SgLang>) -> RuleCollection<SgLang> {
        RuleCollection::try_new(vec![rule]).expect("RuleCollection must be valid")
    }

    /// Rule matching `number` nodes, so it reports on the fixture code `123`.
    fn always_report_rule() -> RuleCollection<SgLang> {
        single_rule_collection(get_rule_config("all: [kind: number]"))
    }

    /// Rule matching `string` nodes, so it never reports on the fixture code `123`.
    fn never_report_rule() -> RuleCollection<SgLang> {
        single_rule_collection(get_rule_config("any: [kind: string]"))
    }

    /// Test case expecting `123` to be valid (not reported).
    fn valid_case() -> TestCase {
        TestCase {
            id: TEST_RULE.into(),
            valid: vec!["123".into()],
            invalid: vec![],
        }
    }

    /// Test case expecting `123` to be invalid (reported).
    fn invalid_case() -> TestCase {
        TestCase {
            id: TEST_RULE.into(),
            valid: vec![],
            invalid: vec!["123".into()],
        }
    }

    /// Expected result of verifying a single-snippet case for [`TEST_RULE`].
    fn test_case_result(status: CaseStatus) -> Option<CaseResult> {
        Some(CaseResult {
            id: TEST_RULE,
            cases: vec![status],
        })
    }

    #[test]
    fn test_validated() {
        let rule = never_report_rule();
        let case = valid_case();
        let ret = verify_test_case_simple(&case, &rule, None);
        assert_eq!(ret, test_case_result(CaseStatus::Validated));
    }

    #[test]
    fn test_reported() {
        let case = invalid_case();
        let rule = always_report_rule();
        let ret = verify_test_case_simple(&case, &rule, None);
        assert_eq!(ret, test_case_result(CaseStatus::Reported));
    }

    #[test]
    fn test_noisy() {
        let case = valid_case();
        let rule = always_report_rule();
        let ret = verify_test_case_simple(&case, &rule, None);
        assert_eq!(ret, test_case_result(CaseStatus::Noisy("123")));
    }

    #[test]
    fn test_missing() {
        let case = invalid_case();
        let rule = never_report_rule();
        let ret = verify_test_case_simple(&case, &rule, None);
        assert_eq!(ret, test_case_result(CaseStatus::Missing("123")));
    }

    #[test]
    fn test_no_such_rule() {
        let case = TestCase {
            id: "no-such-rule".into(),
            valid: vec![],
            invalid: vec![],
        };
        let rule = never_report_rule();
        let ret = verify_test_case_simple(&case, &rule, None);
        assert!(ret.is_none());
    }

    #[test]
    fn test_run_verify_error() {
        let arg = TestArg {
            interactive: false,
            skip_snapshot_tests: true,
            snapshot_dir: None,
            test_dir: None,
            update_all: false,
            filter: None,
            include_off: false,
            color: ColorArg::Never,
        };
        assert!(run_test_rule(arg, Err(anyhow!("error"))).is_err());
    }

    // Appended after the `rule:` body; starts with a newline so `transform` and
    // `fix` are top-level keys of the rule YAML.
    const TRANSFORM_TEXT: &str = "
transform:
  B:
    substring:
      source: $A
      startChar: 1
      endChar: -1
fix: 'log($B)'";

    #[test]
    fn test_verify_transform() {
        let rule = get_rule_config(&format!("pattern: console.log($A)\n{TRANSFORM_TEXT}"));
        let rules = single_rule_collection(rule);
        let case = TestCase {
            id: TEST_RULE.into(),
            valid: vec![],
            invalid: vec!["console.log(123)".to_string()],
        };
        let snapshots = SnapshotCollection::new();
        let mut ret = verify_test_case_simple(&case, &rules, Some(&snapshots)).unwrap();
        let status = ret.cases.pop().unwrap();
        match status {
            // substring(1, -1) of `123` is `2`.
            CaseStatus::Wrong { actual, .. } => {
                assert_eq!(actual.fixed.unwrap(), "log(2)");
            }
            _ => {
                panic!("wrong case status");
            }
        }
    }
}