use std::path::Path;
use walkdir::WalkDir;
/// Test functions discovered in a single source file.
#[derive(Debug, Clone)]
pub struct FileTests {
/// Path of the file; the scanner in this module stores it relative to the
/// scanned root with `/` separators.
pub file: String,
/// Names of the functions in the file that carry a test attribute.
pub tests: Vec<String>,
}
/// Doc-test count for a single source file.
#[derive(Debug, Clone)]
pub struct FileDocTests {
/// Path of the file; the scanner in this module stores it relative to the
/// scanned root with `/` separators.
pub file: String,
/// Number of doc-comment code fences in the file that were counted as doc tests.
pub count: usize,
}
/// Result of a scan: per-file test listings grouped by kind of test.
#[derive(Debug, Default, Clone)]
pub struct TestInventory {
/// Test functions found in files under a crate's `src` directory.
pub unit: Vec<FileTests>,
/// Test functions found in files under a crate's `tests` directory.
pub integration: Vec<FileTests>,
/// Doc-test counts for files under a crate's `src` directory.
pub doc: Vec<FileDocTests>,
}
impl TestInventory {
pub fn unit_total(&self) -> usize {
self.unit.iter().map(|f| f.tests.len()).sum()
}
pub fn integration_total(&self) -> usize {
self.integration.iter().map(|f| f.tests.len()).sum()
}
pub fn doc_total(&self) -> usize {
self.doc.iter().map(|f| f.count).sum()
}
}
/// Reports whether a file stem looks like a benchmark rather than a test.
///
/// The comparison is ASCII case-insensitive. A stem matches when it ends with
/// `_bench` or `_benches`, starts with `perf_` or `bench_`, or contains
/// `throughput` or `benchmark` anywhere.
pub fn is_benchmark_file(stem: &str) -> bool {
    const SUFFIXES: [&str; 2] = ["_bench", "_benches"];
    const PREFIXES: [&str; 2] = ["perf_", "bench_"];
    const FRAGMENTS: [&str; 2] = ["throughput", "benchmark"];

    let lowered = stem.to_ascii_lowercase();
    SUFFIXES.iter().any(|&suffix| lowered.ends_with(suffix))
        || PREFIXES.iter().any(|&prefix| lowered.starts_with(prefix))
        || FRAGMENTS.iter().any(|&fragment| lowered.contains(fragment))
}
/// Scans `repo_root` for tests, leaving out files that look like benchmarks.
///
/// Equivalent to calling `scan_opts` with `include_benches` set to `false`.
/// `crate_filter`, when given, is forwarded unchanged.
pub fn scan(repo_root: &Path, crate_filter: Option<&str>) -> TestInventory {
    let include_benches = false;
    scan_opts(repo_root, crate_filter, include_benches)
}
/// Walks `repo_root`, locates every `Cargo.toml` that declares a package
/// name, and inventories the tests of each matching crate.
///
/// * `crate_filter` - when `Some`, only crates whose package name contains
///   this substring are scanned.
/// * `include_benches` - when `false`, files whose stem looks like a
///   benchmark are left out of the inventory.
///
/// Manifests that cannot be read, fail to parse as TOML, or lack
/// `package.name` (for example a pure workspace manifest) are skipped
/// silently; the scan is best-effort. Each list in the returned inventory is
/// sorted by file path.
pub fn scan_opts(repo_root: &Path, crate_filter: Option<&str>, include_benches: bool) -> TestInventory {
    let mut inv = TestInventory::default();
    for entry in WalkDir::new(repo_root)
        .into_iter()
        // The predicate also runs on the walk root (depth 0). Without the
        // exemption, a `repo_root` that is itself named like a skipped
        // directory would prune the entire scan.
        .filter_entry(|e| e.depth() == 0 || !is_skipped_dir(&e.file_name().to_string_lossy()))
        .flatten()
    {
        if entry.file_name() != "Cargo.toml" {
            continue;
        }
        let Ok(text) = std::fs::read_to_string(entry.path()) else { continue };
        let Ok(doc) = text.parse::<toml::Value>() else { continue };
        let Some(crate_name) = doc
            .get("package")
            .and_then(|p| p.get("name"))
            .and_then(|n| n.as_str())
        else {
            continue;
        };
        // Substring match, so a filter of "alph" selects a crate named "alpha".
        if crate_filter.is_some_and(|pat| !crate_name.contains(pat)) {
            continue;
        }
        let Some(crate_dir) = entry.path().parent() else { continue };
        scan_crate(repo_root, crate_dir, &mut inv, include_benches);
    }
    // Directory traversal order is platform dependent; sort for stable output.
    inv.unit.sort_by(|a, b| a.file.cmp(&b.file));
    inv.integration.sort_by(|a, b| a.file.cmp(&b.file));
    inv.doc.sort_by(|a, b| a.file.cmp(&b.file));
    inv
}
fn scan_crate(repo_root: &Path, crate_dir: &Path, inv: &mut TestInventory, include_benches: bool) {
for (sub, integration) in [("src", false), ("tests", true)] {
let dir = crate_dir.join(sub);
if !dir.is_dir() {
continue;
}
for entry in WalkDir::new(&dir)
.into_iter()
.filter_entry(|e| !is_skipped_dir(&e.file_name().to_string_lossy()))
.flatten()
{
let path = entry.path();
if !entry.file_type().is_file()
|| path.extension().and_then(|e| e.to_str()) != Some("rs")
{
continue;
}
if !include_benches {
if let Some(stem) = path.file_stem().and_then(|s| s.to_str()) {
if is_benchmark_file(stem) {
continue;
}
}
}
let Ok(src) = std::fs::read_to_string(path) else { continue };
let rel = rel_path(repo_root, path);
let names = test_fn_names(&src);
if !names.is_empty() {
let ft = FileTests { file: rel.clone(), tests: names };
if integration {
inv.integration.push(ft);
} else {
inv.unit.push(ft);
}
}
if !integration {
let count = count_doc_tests(&src);
if count > 0 {
inv.doc.push(FileDocTests { file: rel, count });
}
}
}
}
}
/// Renders `path` relative to `repo_root` as a `/`-separated string.
///
/// When `path` does not lie under `repo_root` it is rendered unchanged.
/// Non-UTF-8 sequences are replaced lossily and every backslash becomes a
/// forward slash.
fn rel_path(repo_root: &Path, path: &Path) -> String {
    let relative = match path.strip_prefix(repo_root) {
        Ok(stripped) => stripped,
        Err(_) => path,
    };
    let rendered = relative.to_string_lossy();
    rendered.replace('\\', "/")
}
/// Parses `src` as a Rust file and returns the names of its test functions.
///
/// A function counts as a test when one of its attributes has a path whose
/// final segment is `test`, which covers `#[test]` as well as namespaced
/// forms such as `#[tokio::test]`. Functions nested in modules or in other
/// function bodies are visited too. Source that fails to parse yields an
/// empty list.
fn test_fn_names(src: &str) -> Vec<String> {
    use syn::visit::{self, Visit};

    /// Syntax-tree visitor that records the identifier of each test function.
    struct TestFnVisitor {
        found: Vec<String>,
    }

    impl<'ast> Visit<'ast> for TestFnVisitor {
        fn visit_item_fn(&mut self, item: &'ast syn::ItemFn) {
            let is_test = item.attrs.iter().any(|attr| {
                matches!(attr.path().segments.last(), Some(seg) if seg.ident == "test")
            });
            if is_test {
                self.found.push(item.sig.ident.to_string());
            }
            // Keep descending so functions nested inside this one are seen.
            visit::visit_item_fn(self, item);
        }
    }

    match syn::parse_file(src) {
        Ok(ast) => {
            let mut visitor = TestFnVisitor { found: Vec::new() };
            visitor.visit_file(&ast);
            visitor.found
        }
        Err(_) => Vec::new(),
    }
}
/// Counts the fenced code blocks in `src`'s line doc comments that rustdoc
/// would treat as Rust doc tests and that are not marked `ignore`.
///
/// Only `///` and `//!` line comments are inspected; lines starting with four
/// or more slashes are ordinary comments and are skipped. A fence is a doc
/// line whose text starts with three backticks or three tildes. Fences simply
/// alternate between opening and closing, so the info string is read from
/// every opening fence only.
///
/// An opening fence is counted when its info string is empty, consists solely
/// of rustdoc test attributes, or names `rust` explicitly. A fence tagged with
/// another language (`toml`, `sh`, `text`, ...) is not a doc test, and neither
/// is one tagged `ignore`. Indented (unfenced) code blocks, block doc comments
/// and `#[doc = ...]` attributes are not considered.
fn count_doc_tests(src: &str) -> usize {
    /// Decides whether a fence info string denotes a non-ignored Rust block.
    fn is_runnable_rust(info: &str) -> bool {
        let info = info.to_ascii_lowercase();
        let mut explicit_rust = false;
        let mut foreign_tag = false;
        let tokens = info
            .split(|c: char| c == ',' || c.is_whitespace())
            .filter(|t| !t.is_empty());
        for token in tokens {
            match token {
                "ignore" => return false,
                "rust" => explicit_rust = true,
                "should_panic" | "no_run" | "compile_fail" | "test_harness"
                | "standalone_crate" => {}
                t if t.starts_with("ignore-") || t.starts_with("edition") => {}
                // Compiler error codes such as `E0308` (already lowercased here).
                t if t.len() == 5
                    && t.starts_with('e')
                    && t[1..].bytes().all(|b| b.is_ascii_digit()) => {}
                // Anything else names another language or a custom tag.
                _ => foreign_tag = true,
            }
        }
        explicit_rust || !foreign_tag
    }

    let mut count = 0;
    let mut in_fence = false;
    for raw in src.lines() {
        let line = raw.trim_start();
        // `////...` is a plain comment, not a doc comment.
        if line.starts_with("////") {
            continue;
        }
        let Some(doc) = line
            .strip_prefix("///")
            .or_else(|| line.strip_prefix("//!"))
        else {
            continue;
        };
        let doc = doc.trim_start();
        let Some(info) = doc.strip_prefix("```").or_else(|| doc.strip_prefix("~~~")) else {
            continue;
        };
        if in_fence {
            in_fence = false;
            continue;
        }
        in_fence = true;
        // Fences may be longer than three characters; drop the extra markers
        // so they are not mistaken for part of the language tag.
        let info = info.trim_start_matches(|c: char| c == '`' || c == '~');
        if is_runnable_rust(info) {
            count += 1;
        }
    }
    count
}
/// Reports whether a directory entry name is one the scanner never descends
/// into: build output, VCS metadata, JavaScript dependencies, or `.nornir`.
fn is_skipped_dir(name: &str) -> bool {
    const SKIPPED: [&str; 4] = ["target", ".git", "node_modules", ".nornir"];
    SKIPPED.contains(&name)
}
// Tests for the scanner. Each test that touches the file system builds a
// throwaway crate layout inside a `tempfile` temporary directory.
#[cfg(test)]
mod tests {
use super::*;
// Writes `s` to `p`, creating the parent directories first. Panics on any
// I/O error, which is acceptable inside tests.
fn write(p: &Path, s: &str) {
std::fs::create_dir_all(p.parent().unwrap()).unwrap();
std::fs::write(p, s).unwrap();
}
// One crate with a library file and an integration test file: checks the
// three totals and that a `#[bench]` function is not listed as a test.
#[test]
fn scans_unit_integration_and_doc() {
let t = tempfile::tempdir().unwrap();
let root = t.path();
write(&root.join("Cargo.toml"), "[package]\nname='demo'\nversion='0.1.0'\n");
write(
&root.join("src/lib.rs"),
r#"
/// Adds.
/// ```
/// assert_eq!(demo::add(1,2), 3);
/// ```
/// ```ignore
/// not_run();
/// ```
pub fn add(a: i32, b: i32) -> i32 { a + b }
#[cfg(test)]
mod tests {
#[test]
fn unit_a() {}
#[tokio::test]
async fn unit_b() {}
#[bench]
fn not_a_test() {}
}
"#,
);
write(
&root.join("tests/it.rs"),
"#[test]\nfn integration_one() {}\n",
);
let inv = scan(root, None);
// `unit_a` and `unit_b` count; `not_a_test` carries `#[bench]` and does not.
assert_eq!(inv.unit_total(), 2, "unit: {:?}", inv.unit);
assert_eq!(inv.integration_total(), 1);
assert_eq!(inv.doc_total(), 1, "only the runnable fence counts");
assert!(inv.unit[0].tests.contains(&"unit_a".to_string()));
assert!(!inv.unit[0].tests.contains(&"not_a_test".to_string()));
}
// Files whose stem looks like a benchmark are excluded by `scan` and
// included again when `scan_opts` is called with `include_benches = true`.
#[test]
fn skips_benchmark_files_by_default() {
let t = tempfile::tempdir().unwrap();
let root = t.path();
write(&root.join("Cargo.toml"), "[package]\nname='demo'\nversion='0.1.0'\n");
write(&root.join("tests/integration_test.rs"), "#[test]\nfn real_one() {}\n");
write(&root.join("tests/compress_dir_bench.rs"), "#[test]\nfn compress_dir_real_jars() {}\n");
write(&root.join("tests/perf_bench.rs"), "#[test]\nfn perf_suite() {}\n");
let inv = scan(root, None);
assert_eq!(inv.integration_total(), 1, "only the real test counts: {:?}", inv.integration);
assert_eq!(inv.integration[0].file, "tests/integration_test.rs");
let inv2 = scan_opts(root, None, true);
assert_eq!(inv2.integration_total(), 3);
}
// Pins the stem classifier on representative benchmark and non-benchmark names.
#[test]
fn benchmark_file_classifier() {
assert!(is_benchmark_file("compress_dir_bench"));
assert!(is_benchmark_file("perf_bench"));
assert!(is_benchmark_file("maven_bench"));
assert!(is_benchmark_file("jar_throughput"));
assert!(is_benchmark_file("decompress_benchmark"));
assert!(!is_benchmark_file("integration_test"));
assert!(!is_benchmark_file("maven_artifacts"));
assert!(!is_benchmark_file("codec"));
}
// Two sibling crates; the filter is a substring match on the package name,
// so "alph" selects only `alpha`.
#[test]
fn crate_filter_limits_scope() {
let t = tempfile::tempdir().unwrap();
let root = t.path();
write(&root.join("a/Cargo.toml"), "[package]\nname='alpha'\nversion='0.1.0'\n");
write(&root.join("a/src/lib.rs"), "#[test]\nfn ta() {}\n");
write(&root.join("b/Cargo.toml"), "[package]\nname='beta'\nversion='0.1.0'\n");
write(&root.join("b/src/lib.rs"), "#[test]\nfn tb() {}\n");
let inv = scan(root, Some("alph"));
assert_eq!(inv.unit_total(), 1);
assert_eq!(inv.unit[0].file, "a/src/lib.rs");
}
}