use proc_macro::TokenStream;
use quote::quote;
use syn::{
Ident, ItemFn, ReturnType, Token,
parse::{Parse, ParseStream},
parse_macro_input,
punctuated::Punctuated,
};
/// Parsed arguments of the `#[benchmark(...)]` attribute.
///
/// Populated by the `Parse` impl below; cross-argument constraints
/// (teardown requires setup; per_iteration excludes teardown) are
/// enforced at parse time.
struct BenchmarkArgs {
    // Setup function named via `setup = ident`; its result is passed to
    // the benchmark body.
    setup: Option<Ident>,
    // Teardown function named via `teardown = ident`; only valid when
    // `setup` is also present.
    teardown: Option<Ident>,
    // Set by the bare `per_iteration` flag; selects the per-iteration
    // setup runner in `generate_runner`.
    per_iteration: bool,
}
impl Parse for BenchmarkArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let mut setup = None;
let mut teardown = None;
let mut per_iteration = false;
if input.is_empty() {
return Ok(Self {
setup,
teardown,
per_iteration,
});
}
let args = Punctuated::<BenchmarkArg, Token![,]>::parse_terminated(input)?;
for arg in args {
match arg {
BenchmarkArg::Setup(ident) => {
if setup.is_some() {
return Err(syn::Error::new_spanned(ident, "duplicate setup argument"));
}
setup = Some(ident);
}
BenchmarkArg::Teardown(ident) => {
if teardown.is_some() {
return Err(syn::Error::new_spanned(
ident,
"duplicate teardown argument",
));
}
teardown = Some(ident);
}
BenchmarkArg::PerIteration => {
per_iteration = true;
}
}
}
if teardown.is_some() && setup.is_none() {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"teardown requires setup to be specified",
));
}
if per_iteration && teardown.is_some() {
return Err(syn::Error::new(
proc_macro2::Span::call_site(),
"per_iteration mode is not compatible with teardown",
));
}
Ok(Self {
setup,
teardown,
per_iteration,
})
}
}
/// One argument of the `#[benchmark(...)]` attribute list.
enum BenchmarkArg {
    // `setup = ident` — function producing the benchmark's input.
    Setup(Ident),
    // `teardown = ident` — function consuming the input after the run.
    Teardown(Ident),
    // Bare `per_iteration` flag — re-run setup before every iteration.
    PerIteration,
}
impl Parse for BenchmarkArg {
fn parse(input: ParseStream) -> syn::Result<Self> {
let name: Ident = input.parse()?;
match name.to_string().as_str() {
"setup" => {
input.parse::<Token![=]>()?;
let value: Ident = input.parse()?;
Ok(BenchmarkArg::Setup(value))
}
"teardown" => {
input.parse::<Token![=]>()?;
let value: Ident = input.parse()?;
Ok(BenchmarkArg::Teardown(value))
}
"per_iteration" => Ok(BenchmarkArg::PerIteration),
_ => Err(syn::Error::new_spanned(
name,
"expected 'setup', 'teardown', or 'per_iteration'",
)),
}
}
}
#[proc_macro_attribute]
pub fn benchmark(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as BenchmarkArgs);
let input_fn = parse_macro_input!(item as ItemFn);
let fn_name = &input_fn.sig.ident;
let fn_name_str = fn_name.to_string();
let vis = &input_fn.vis;
let sig = &input_fn.sig;
let block = &input_fn.block;
let attrs = &input_fn.attrs;
if input_fn.sig.asyncness.is_some() {
return syn::Error::new_spanned(
input_fn.sig.asyncness,
"#[benchmark] does not support async fn. Move async work into a synchronous runtime boundary so the benchmark measures execution instead of future creation.",
)
.to_compile_error()
.into();
}
if args.setup.is_some() {
if input_fn.sig.inputs.len() != 1 {
let param_count = input_fn.sig.inputs.len();
return syn::Error::new_spanned(
&input_fn.sig,
format!(
"#[benchmark(setup = ...)] functions must take exactly one parameter.\n\
Found {} parameter(s).\n\n\
Example:\n\
fn setup_data() -> MyData {{ ... }}\n\n\
#[benchmark(setup = setup_data)]\n\
fn {}(input: &MyData) {{\n\
// input is the result of setup_data()\n\
}}",
param_count, fn_name_str
),
)
.to_compile_error()
.into();
}
} else {
if !input_fn.sig.inputs.is_empty() {
let param_count = input_fn.sig.inputs.len();
let param_names: Vec<String> = input_fn
.sig
.inputs
.iter()
.map(|arg| match arg {
syn::FnArg::Receiver(_) => "self".to_string(),
syn::FnArg::Typed(pat) => quote!(#pat).to_string(),
})
.collect();
return syn::Error::new_spanned(
&input_fn.sig.inputs,
format!(
"#[benchmark] functions must take no parameters.\n\
Found {} parameter(s): {}\n\n\
If you need setup data, use the setup attribute:\n\n\
fn setup_data() -> MyData {{ ... }}\n\n\
#[benchmark(setup = setup_data)]\n\
fn {}(input: &MyData) {{\n\
// Your benchmark code using input\n\
}}",
param_count,
param_names.join(", "),
fn_name_str
),
)
.to_compile_error()
.into();
}
}
match &input_fn.sig.output {
ReturnType::Default => {} ReturnType::Type(_, return_type) => {
let type_str = quote!(#return_type).to_string();
if type_str.trim() != "()" {
return syn::Error::new_spanned(
return_type,
format!(
"#[benchmark] functions must return () (unit type).\n\
Found return type: {}\n\n\
Benchmark results should be consumed with std::hint::black_box() \
rather than returned:\n\n\
#[benchmark]\n\
fn {}() {{\n\
let result = compute_something();\n\
std::hint::black_box(result); // Prevents optimization\n\
}}",
type_str, fn_name_str
),
)
.to_compile_error()
.into();
}
}
}
let runner = generate_runner(fn_name, &args);
let expanded = quote! {
#(#attrs)*
#vis #sig {
#block
}
::inventory::submit! {
::mobench_sdk::registry::BenchFunction {
name: ::std::concat!(::std::module_path!(), "::", #fn_name_str),
runner: #runner,
}
}
};
TokenStream::from(expanded)
}
/// Builds the runner closure registered for a benchmark, dispatching on
/// the (setup, teardown, per_iteration) configuration.
///
/// The outer closure signature is identical for every configuration, so
/// it is emitted once here; only the inner timing-call differs per arm.
/// Configurations already rejected by `BenchmarkArgs::parse` fall into a
/// defensive `compile_error!` arm.
fn generate_runner(fn_name: &Ident, args: &BenchmarkArgs) -> proc_macro2::TokenStream {
    let invocation = match (&args.setup, &args.teardown, args.per_iteration) {
        // No setup: per_iteration is irrelevant, the body takes no input.
        (None, None, _) => quote! {
            ::mobench_sdk::timing::run_closure(spec, || {
                #fn_name();
                Ok(())
            })
        },
        // Setup once, share the input across iterations.
        (Some(setup), None, false) => quote! {
            ::mobench_sdk::timing::run_closure_with_setup(
                spec,
                || #setup(),
                |input| {
                    #fn_name(input);
                    Ok(())
                },
            )
        },
        // Fresh setup before every iteration.
        (Some(setup), None, true) => quote! {
            ::mobench_sdk::timing::run_closure_with_setup_per_iter(
                spec,
                || #setup(),
                |input| {
                    #fn_name(input);
                    Ok(())
                },
            )
        },
        // Setup once, tear the input down afterwards.
        (Some(setup), Some(teardown), false) => quote! {
            ::mobench_sdk::timing::run_closure_with_setup_teardown(
                spec,
                || #setup(),
                |input| {
                    #fn_name(input);
                    Ok(())
                },
                |input| #teardown(input),
            )
        },
        // Unreachable in practice: the parser rejects these combinations.
        (None, Some(_), _) | (Some(_), Some(_), true) => {
            return quote! { compile_error!("invalid benchmark configuration") };
        }
    };

    quote! {
        |spec: ::mobench_sdk::timing::BenchSpec| -> ::std::result::Result<::mobench_sdk::timing::BenchReport, ::mobench_sdk::timing::TimingError> {
            #invocation
        }
    }
}