bma_benchmark_proc/
lib.rs

//! Procedural macros for <https://crates.io/crates/bma-benchmark>
2use proc_macro::{TokenStream, TokenTree};
3use quote::{quote, ToTokens};
4use std::panic::panic_any;
5
6const ERR_INVALID_OPTIONS: &str = "Invalid options";
7
8#[proc_macro_attribute]
9/// Wraps functions for a staged benchmark
10///
11/// Attribute options:
12///
13/// * **i** number of iterations, required
14/// * **name** custom stage name (the default is function name)
15/// * **check** check for the result, the function body MUST (not return but) END with a bool
16///
17/// If a function name starts with *test_* or *benchmark_*, the prefix is automatically stripped.
18///
19/// Example:
20///
21/// ```rust
22/// #[benchmark_stage(i=1_000)]
23/// fn test1() {
24///     // do something
25/// }
26/// ```
27///
28/// ```rust
29/// #[benchmark_stage(i=1_000,name=stage1)]
30/// fn test1() {
31///     // do something
32/// }
33/// ```
34///
35/// ```rust
36/// #[benchmark_stage(i=1_000,name=stage1,check)]
37/// fn test1() {
38///     File::create("/tmp/test123").is_ok()
39/// }
40/// ```
41///
42/// # Panics
43///
44/// Will panic on invalid options
45pub fn benchmark_stage(args: TokenStream, input: TokenStream) -> TokenStream {
46    let mut item: syn::Item = syn::parse(input).expect("Invalid input");
47    let mut args_iter = args.into_iter();
48    let mut opt_i: Option<u32> = None;
49    let mut opt_name: Option<String> = None;
50    let mut checked = false;
51    macro_rules! parse_opt {
52        ($c: block) => {{
53            let v = args_iter.next().expect(ERR_INVALID_OPTIONS);
54            if let TokenTree::Punct(c) = v {
55                if c.as_char() == '=' {
56                    $c
57                } else {
58                    panic_any(ERR_INVALID_OPTIONS);
59                }
60            } else {
61                panic_any(ERR_INVALID_OPTIONS);
62            }
63        }};
64    }
65    while let Some(v) = args_iter.next() {
66        if let TokenTree::Ident(i) = v {
67            let s = i.to_string();
68            match s.as_str() {
69                "i" => parse_opt!({
70                    if let TokenTree::Literal(v) =
71                        args_iter.next().expect("Option value not specified")
72                    {
73                        opt_i = Some(
74                            v.to_string()
75                                .replace('_', "")
76                                .parse()
77                                .expect("Invalid integer"),
78                        );
79                    } else {
80                        panic!("Invalid value for \"i\"");
81                    }
82                }),
83                "name" => parse_opt!({
84                    match args_iter.next().unwrap() {
85                        TokenTree::Literal(v) => opt_name = Some(v.to_string()),
86                        TokenTree::Ident(v) => opt_name = Some(v.to_string()),
87                        _ => panic!("Invalid value for \"name\""),
88                    }
89                }),
90                "check" => checked = true,
91                _ => panic!("Invalid parameter: {}", s),
92            }
93        }
94    }
95    let iterations = opt_i.expect("Iterations not specified");
96    let fn_item = match &mut item {
97        syn::Item::Fn(fn_item) => fn_item,
98        _ => panic!("expected fn"),
99    };
100    let mut name = opt_name.unwrap_or_else(|| {
101        let n = fn_item.sig.ident.to_string();
102        if n.starts_with("test_") {
103            n.strip_prefix("test_").unwrap().to_owned()
104        } else if n.starts_with("benchmark_") {
105            n.strip_prefix("benchmark_").unwrap().to_owned()
106        } else {
107            n
108        }
109    });
110    if name.starts_with('"') && name.ends_with('"') {
111        name = name[1..name.len() - 1].to_owned();
112    }
113    let fn_block = &fn_item.block;
114    if checked {
115        fn_item.block.stmts = vec![syn::parse(
116            quote!(bma_benchmark::staged_benchmark_check!(#name, #iterations, #fn_block);).into(),
117        )
118        .unwrap()];
119    } else {
120        fn_item.block.stmts = vec![syn::parse(
121            quote!(bma_benchmark::staged_benchmark!(#name, #iterations, #fn_block);).into(),
122        )
123        .unwrap()];
124    }
125    item.into_token_stream().into()
126}