use super::RustTestType;
use quote::ToTokens;
use std::path::Path;
use syn::ItemFn;
/// Heuristic classifier that decides whether a parsed Rust function is a test
/// and, if so, which kind of test it is.
///
/// The classifier carries no state; construct it with [`TestClassifier::new`]
/// or [`Default::default`].
#[derive(Debug, Clone, Copy)]
pub struct TestClassifier;
impl TestClassifier {
pub fn new() -> Self {
Self
}
pub fn is_test_function(&self, func: &ItemFn) -> bool {
self.has_test_attribute(func)
|| self.is_benchmark_test(func)
|| self.has_test_name_pattern(&func.sig.ident.to_string())
}
pub fn classify_test_type(&self, func: &ItemFn, file_path: &Path) -> Option<RustTestType> {
if !self.is_test_function(func) {
return None;
}
if self.is_benchmark_test(func) {
return Some(RustTestType::BenchmarkTest);
}
if self.is_property_test(func) {
return Some(RustTestType::PropertyTest);
}
if self.is_integration_test_path(file_path) {
return Some(RustTestType::IntegrationTest);
}
Some(RustTestType::UnitTest)
}
fn has_test_attribute(&self, func: &ItemFn) -> bool {
func.attrs.iter().any(|attr| {
if attr.path().is_ident("test") {
return true;
}
if let Some(last_segment) = attr.path().segments.last() {
if last_segment.ident == "test" {
return true;
}
}
if attr.path().is_ident("cfg") {
let tokens = attr.meta.to_token_stream().to_string();
if tokens.contains("test") {
return true;
}
}
false
})
}
fn has_test_name_pattern(&self, name: &str) -> bool {
const TEST_PREFIXES: &[&str] = &["test_", "it_", "should_"];
const MOCK_PATTERNS: &[&str] = &["mock", "stub", "fake"];
let name_lower = name.to_lowercase();
if TEST_PREFIXES.iter().any(|prefix| name.starts_with(prefix)) {
return true;
}
if MOCK_PATTERNS
.iter()
.any(|pattern| name_lower.contains(pattern))
{
return true;
}
false
}
fn is_benchmark_test(&self, func: &ItemFn) -> bool {
func.attrs.iter().any(|attr| {
attr.path().is_ident("bench")
|| attr.path().segments.iter().any(|seg| seg.ident == "bench")
})
}
fn is_property_test(&self, func: &ItemFn) -> bool {
let has_proptest = func.attrs.iter().any(|attr| {
attr.path().segments.iter().any(|seg| {
let ident_str = seg.ident.to_string();
ident_str.contains("proptest") || ident_str.contains("quickcheck")
})
});
if has_proptest {
return true;
}
let tokens = quote::quote!(#func).to_string();
tokens.contains("proptest!") || tokens.contains("quickcheck!")
}
fn is_integration_test_path(&self, path: &Path) -> bool {
let path_str = path.to_string_lossy();
path_str.contains("/tests/")
|| path_str.contains("\\tests\\")
|| path_str.starts_with("tests/")
|| path_str.starts_with("tests\\")
}
}
impl Default for TestClassifier {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use syn::parse_quote;

    /// Asks a fresh classifier whether `func` is a test function.
    fn detects(func: &ItemFn) -> bool {
        TestClassifier::new().is_test_function(func)
    }

    /// Classifies `func` as if it were declared in the file at `path`.
    fn classify_at(func: &ItemFn, path: &str) -> Option<RustTestType> {
        TestClassifier::new().classify_test_type(func, &PathBuf::from(path))
    }

    #[test]
    fn test_detect_standard_test() {
        let item: ItemFn = parse_quote! {
            #[test]
            fn test_something() {
                assert_eq!(1, 1);
            }
        };
        assert!(detects(&item));
    }

    #[test]
    fn test_detect_tokio_test() {
        let item: ItemFn = parse_quote! {
            #[tokio::test]
            async fn test_async() {
                assert_eq!(1, 1);
            }
        };
        assert!(detects(&item));
    }

    #[test]
    fn test_detect_test_by_name() {
        // No attribute at all: only the `test_` name prefix marks this as a test.
        let item: ItemFn = parse_quote! {
            fn test_something() {
                assert_eq!(1, 1);
            }
        };
        assert!(detects(&item));
    }

    #[test]
    fn test_classify_unit_test() {
        let item: ItemFn = parse_quote! {
            #[test]
            fn test_unit() {
                assert_eq!(1, 1);
            }
        };
        assert_eq!(classify_at(&item, "src/lib.rs"), Some(RustTestType::UnitTest));
    }

    #[test]
    fn test_classify_integration_test() {
        let item: ItemFn = parse_quote! {
            #[test]
            fn test_integration() {
                assert_eq!(1, 1);
            }
        };
        assert_eq!(
            classify_at(&item, "tests/integration_test.rs"),
            Some(RustTestType::IntegrationTest)
        );
    }

    #[test]
    fn test_classify_benchmark_test() {
        let item: ItemFn = parse_quote! {
            #[bench]
            fn bench_something(b: &mut Bencher) {
                b.iter(|| 2 + 2);
            }
        };
        assert_eq!(
            classify_at(&item, "benches/bench.rs"),
            Some(RustTestType::BenchmarkTest)
        );
    }

    #[test]
    fn test_non_test_function() {
        let item: ItemFn = parse_quote! {
            fn regular_function() {
                println!("not a test");
            }
        };
        assert!(!detects(&item));
    }
}