#![doc = include_str!("../README.md")]
use proc_macro2::{Ident, Span, TokenStream};
use quote::{format_ident, quote};
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
use std::env;
use std::ffi::OsString;
use std::path::{Path, PathBuf};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::token::Async;
use syn::{bracketed, parse_macro_input, Expr, LitStr, Meta, Token};
use unicode_ident::{is_xid_continue, is_xid_start};
/// Parsed arguments of `test_each_file!` / `test_each_path!`, listed in the
/// order in which they appear in the macro input.
struct TestEachArgs {
    /// Attributes from the optional leading `#[...]` list; copied onto every test.
    attributes: Vec<Meta>,
    /// The optional `async` keyword; when present the generated test is an
    /// `async fn` and the call is `.await`ed.
    async_fn: Option<Async>,
    /// Extensions from the optional `for [...]` clause; empty means every file is used.
    extensions: Vec<String>,
    /// Directory literal that follows `in`.
    path: LitStr,
    /// Optional module name from `as`; the tests are wrapped in a
    /// `#[cfg(test)]` module with this name.
    module: Option<Ident>,
    /// Expression after `=>` that every generated test calls.
    function: Expr,
    /// Optional `ignore { "name" => "reason", ... }` table.
    ignore_patterns: IgnorePatterns,
}
/// Returns early from a function producing `syn::Result`, yielding a
/// `syn::Error` with the given message at the given span.
macro_rules! abort {
    ($at:expr, $msg:expr) => {
        return ::core::result::Result::Err(syn::Error::new($at, $msg))
    };
}
/// Returns early from a proc-macro entry point, yielding the error rendered
/// as compile-error tokens at the given span.
macro_rules! abort_token_stream {
    ($at:expr, $msg:expr) => {
        return syn::Error::new($at, $msg).into_compile_error().into()
    };
}
/// Tests to be marked `#[ignore]`, parsed from `ignore { "name" => "reason", ... }`.
#[derive(Default)]
struct IgnorePatterns {
    /// Maps a test name (written without any `r#` prefix) to the reason used
    /// in `#[ignore = "..."]`.
    patterns: HashMap<String, String>,
}
impl Parse for IgnorePatterns {
fn parse(input: ParseStream) -> syn::Result<Self> {
if !input
.fork()
.parse::<Ident>()
.ok()
.is_some_and(|id| id == "ignore")
{
return Ok(IgnorePatterns::default());
}
let _: Ident = input.parse().unwrap();
let _ = input.parse::<Token![:]>();
let content;
syn::braced!(content in input);
let mut patterns = HashMap::new();
while !content.is_empty() {
let key: LitStr = match content.parse() {
Ok(k) => k,
Err(e) => abort!(
e.span(),
"Expected a string literal for ignore pattern name."
),
};
if let Err(e) = content.parse::<Token![=>]>() {
abort!(e.span(), "Expected `=>` after ignore pattern name.");
}
let reason: LitStr = match content.parse() {
Ok(r) => r,
Err(e) => abort!(e.span(), "Expected a string literal for ignore reason."),
};
patterns.insert(key.value(), reason.value());
let _ = content.parse::<Token![,]>();
}
Ok(IgnorePatterns { patterns })
}
}
impl IgnorePatterns {
    /// Looks up the ignore reason for the test called `name`.
    ///
    /// A leading `r#` (as on raw identifiers) is removed before the lookup, so
    /// patterns are written without it. Returns `None` when the test should run.
    fn should_ignore(&self, name: &str) -> Option<String> {
        let key = match name.strip_prefix("r#") {
            Some(unprefixed) => unprefixed,
            None => name,
        };
        self.patterns.get(key).map(String::clone)
    }
}
impl Parse for TestEachArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let attributes: Vec<Meta> = input
.parse::<Token![#]>()
.and_then(|_| {
let content;
bracketed!(content in input);
match Punctuated::<Meta, Token![,]>::parse_separated_nonempty(&content) {
Ok(attributes) => Ok(attributes.into_iter().collect()),
Err(e) => abort!(e.span(), "Expected at least one attribute to be given."),
}
})
.unwrap_or_default();
let async_span = input.span();
let async_fn = match input.parse::<Token![async]>() {
Ok(token) => {
if attributes.is_empty() {
abort!(async_span, "Expected at least one attribute (e.g., `#[tokio::test]`) when `async` is given.");
}
Some(token)
}
Err(_) => None,
};
let extensions = input
.parse::<Token![for]>()
.and_then(|_| {
let content;
bracketed!(content in input);
match Punctuated::<LitStr, Token![,]>::parse_separated_nonempty(&content) {
Ok(extensions) => Ok(extensions
.into_iter()
.map(|extension| extension.value())
.collect()),
Err(e) => abort!(e.span(), "Expected at least one extension to be given."),
}
})
.unwrap_or_default();
if let Err(e) = input.parse::<Token![in]>() {
abort!(e.span(), "Expected the keyword `in` before the path.");
};
let path = match input.parse::<LitStr>() {
Ok(path) => path,
Err(e) => abort!(e.span(), "Expected a path after the keyword 'in'."),
};
let module = input
.parse::<Token![as]>()
.and_then(|_| match input.parse::<Ident>() {
Ok(module) => Ok(module),
Err(e) => abort!(e.span(), "Expected a module to be given."),
})
.ok();
if let Err(e) = input.parse::<Token![=>]>() {
abort!(e.span(), "Expected `=>` before the function to call.");
};
let function = match input.parse::<Expr>() {
Ok(function) => function,
Err(e) => abort!(e.span(), "Expected a function to call after `=>`."),
};
let ignore_patterns = IgnorePatterns::parse(input)?;
Ok(Self {
path,
module,
function,
extensions,
attributes,
async_fn,
ignore_patterns,
})
}
}
/// Recursive listing of the directory the tests are generated from.
///
/// `children` maps each sub-directory path to its own listing; `here` holds
/// the files directly inside this directory. When extensions are given, each
/// matching file is stored with its extension removed, so files differing only
/// in extension (e.g. `a.in` and `a.out`) collapse into a single entry.
#[derive(Default)]
struct Tree {
    children: BTreeMap<PathBuf, Tree>,
    here: BTreeSet<PathBuf>,
}
impl Tree {
    /// Lists `base` recursively.
    ///
    /// With an empty `extensions` every file is kept as-is. Otherwise only
    /// files whose extension appears in `extensions` are kept, and they are
    /// stored without that extension.
    ///
    /// # Errors
    /// Returns a message if a directory or one of its entries cannot be read,
    /// or if an entry is neither a file nor a directory (e.g. a dangling symlink).
    fn new(base: &Path, extensions: &[String]) -> Result<Self, String> {
        let mut tree = Self::default();
        let entries = base
            .read_dir()
            .map_err(|e| format!("Failed to read directory {}: {e}", base.display()))?;
        for entry in entries {
            let mut entry = entry
                .map_err(|e| format!("Failed to read an entry of {}: {e}", base.display()))?
                .path();
            if entry.is_file() {
                if !extensions.is_empty() {
                    // Compare as `OsStr`: a non-UTF-8 extension is simply a non-match.
                    let wanted = entry.extension().is_some_and(|extension| {
                        extensions.iter().any(|test_extension| extension == test_extension.as_str())
                    });
                    if !wanted {
                        continue;
                    }
                    entry.set_extension("");
                }
                tree.here.insert(entry);
            } else if entry.is_dir() {
                let subtree = Self::new(&entry, extensions)?;
                tree.children.insert(entry, subtree);
            } else {
                return Err(format!("Unsupported path: {entry:#?}."));
            }
        }
        Ok(tree)
    }
}
/// How each matched file is handed to the test function.
enum Type {
    /// Pass a `std::path::Path` pointing at the file.
    Path,
    /// Pass the file's contents, embedded with `include_str!`.
    File,
}
fn sanitize_ident(input: &str) -> Ident {
let name: String = input
.chars()
.map(|c| if is_xid_continue(c) { c } else { '_' })
.collect();
if !is_xid_start(name.chars().next().expect("Name is not empty")) {
format_ident!("test_{name}")
} else {
Ident::new_raw(&name, Span::call_site())
}
}
fn generate_name(starting_name: Ident, taken_names: &mut HashSet<Ident>) -> Ident {
if taken_names.insert(starting_name.clone()) {
return starting_name;
}
for i in 2.. {
let new_name = format_ident!("{starting_name}_{i}");
if taken_names.insert(new_name.clone()) {
return new_name;
}
}
unreachable!()
}
fn generate_from_tree(
tree: &Tree,
parsed: &TestEachArgs,
stream: &mut TokenStream,
invocation_type: &Type,
) -> Result<(), String> {
let mut taken_names_folders = HashSet::new();
for (name, directory) in tree.children.iter() {
let file_name = name.file_name().unwrap().to_str().unwrap();
let file_name = sanitize_ident(file_name);
let file_name = generate_name(file_name, &mut taken_names_folders);
let mut sub_stream = TokenStream::new();
generate_from_tree(directory, parsed, &mut sub_stream, invocation_type)?;
stream.extend(quote! {
mod #file_name {
use super::*;
#sub_stream
}
});
}
let mut taken_names_files = HashSet::new();
for file in tree.here.iter() {
let file_name = file.file_stem().unwrap().to_str().unwrap();
let file_name = sanitize_ident(file_name);
let file_name = generate_name(file_name, &mut taken_names_files);
let function = &parsed.function;
let arguments: TokenStream = if parsed.extensions.is_empty() {
let input = file.canonicalize().unwrap();
let input = input.to_str().unwrap();
match invocation_type {
Type::File => quote!(include_str!(#input)),
Type::Path => quote!(std::path::Path::new(#input)),
}
} else {
let mut arguments = TokenStream::new();
for extension in &parsed.extensions {
let mut file: OsString = file.clone().into();
file.push(".");
file.push(extension);
let file: PathBuf = file.into();
let input = match file.canonicalize() {
Ok(path) => path,
Err(e) => {
return Err(format!(
"Failed to read expected file {}.{extension}: {e}",
file.display()
))
}
};
let input = input.to_str().unwrap();
arguments.extend(match invocation_type {
Type::File => quote!(include_str!(#input),),
Type::Path => quote!(std::path::Path::new(#input),),
});
}
quote!([#arguments])
};
for attribute in &parsed.attributes {
stream.extend(quote! {
#[#attribute]
});
}
if let Some(reason) = parsed.ignore_patterns.should_ignore(&file_name.to_string()) {
stream.extend(quote! {
#[ignore = #reason]
});
}
if let Some(async_keyword) = &parsed.async_fn {
stream.extend(quote! {
#async_keyword fn #file_name() {
(#function)(#arguments).await
}
});
} else {
stream.extend(quote! {
#[test]
fn #file_name() {
(#function)(#arguments)
}
});
}
}
Ok(())
}
/// Shared implementation behind `test_each_file!` and `test_each_path!`.
///
/// Parses the macro input, lists the requested directory and generates one
/// test per matched file. The tests are wrapped in a `#[cfg(test)]` module
/// when `as module` was given. Problems with the directory are reported as a
/// compile error on the path literal.
fn test_each(input: proc_macro::TokenStream, invocation_type: &Type) -> proc_macro::TokenStream {
    let parsed = parse_macro_input!(input as TestEachArgs);
    let path_span = parsed.path.span();
    let directory = parsed.path.value();
    let directory = Path::new(&directory);

    if !directory.is_dir() {
        // Show the absolute location that was checked; if the current
        // directory is unavailable, the path simply stays relative.
        let abs_path: PathBuf = env::current_dir()
            .unwrap_or_default()
            .join(directory)
            .components()
            .collect();
        abort_token_stream!(
            path_span,
            format!("Given directory does not exist: {abs_path:?}")
        );
    }

    let tree = match Tree::new(directory, &parsed.extensions) {
        Ok(tree) => tree,
        Err(message) => abort_token_stream!(path_span, message),
    };

    let mut tests = TokenStream::new();
    if let Err(message) = generate_from_tree(&tree, &parsed, &mut tests, invocation_type) {
        abort_token_stream!(path_span, message);
    }

    let output = match parsed.module {
        Some(module) => quote! {
            #[cfg(test)]
            mod #module {
                use super::*;
                #tests
            }
        },
        None => tests,
    };
    output.into()
}
/// Generates a test for every file under the given directory, recursing into
/// sub-directories (which become modules). Each test passes the test function
/// the file's contents via `include_str!`. When `for [...]` is given, it passes
/// an array of contents instead, one per extension.
#[proc_macro]
pub fn test_each_file(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let invocation_type = Type::File;
    test_each(input, &invocation_type)
}
/// Generates a test for every file under the given directory, recursing into
/// sub-directories (which become modules). Each test passes the test function
/// a `std::path::Path` to the file. When `for [...]` is given, it passes an
/// array of paths instead, one per extension.
#[proc_macro]
pub fn test_each_path(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
    let invocation_type = Type::Path;
    test_each(input, &invocation_type)
}