canbench_rs_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::{quote, ToTokens};
3use syn::{parse_macro_input, AttributeArgs, ItemFn, NestedMeta, ReturnType};
4
5/// A macro for declaring a benchmark where only some part of the function is
6/// benchmarked.
7#[proc_macro_attribute]
8pub fn bench(arg_tokens: TokenStream, item: TokenStream) -> TokenStream {
9    // Parse the input as a function
10    let input = parse_macro_input!(item as ItemFn);
11
12    // Parse the attribute arguments
13    let args = parse_macro_input!(arg_tokens as AttributeArgs);
14
15    // Extract function name, inputs, and output
16    let func_name = &input.sig.ident;
17    let inputs = &input.sig.inputs;
18    let output = &input.sig.output;
19
20    // Check that there are no function arguments
21    if !inputs.is_empty() {
22        return syn::Error::new_spanned(inputs, "Benchmark should not take any arguments")
23            .to_compile_error()
24            .into();
25    }
26
27    // Prefix the benchmark name with "__canbench__".
28    // This is to inform that the `canbench` binary that this query is a benchmark
29    // that it should run.
30    let renamed_func_name =
31        syn::Ident::new(&format!("__canbench__{}", func_name), func_name.span());
32    let tracing_func_name = syn::Ident::new(&format!("__tracing__{}", func_name), func_name.span());
33
34    // Validate the argument and generate code accordingly
35    let expanded = match args.as_slice() {
36        [NestedMeta::Meta(meta)] if meta.path().is_ident("raw") => {
37            // If the argument is "raw", validate that the function returns BenchResult
38            if let ReturnType::Type(_, ty) = output {
39                if ty.to_token_stream().to_string() != quote!(BenchResult).to_string()
40                    && ty.to_token_stream().to_string()
41                        != quote!(canbench_rs::BenchResult).to_string()
42                {
43                    // If the return type is not BenchResult, generate a compile-time error
44                    return syn::Error::new_spanned(ty, "Raw benchmark should return BenchResult.")
45                        .to_compile_error()
46                        .into();
47                }
48            } else {
49                // If there is no return type, generate a compile-time error
50                return syn::Error::new_spanned(output, "Raw benchmark should return BenchResult.")
51                    .to_compile_error()
52                    .into();
53            }
54
55            quote! {
56                #input
57
58                #[ic_cdk::query]
59                #[allow(non_snake_case)]
60                fn #renamed_func_name() -> canbench_rs::BenchResult {
61                    #func_name()
62                }
63
64                #[ic_cdk::query]
65                #[allow(non_snake_case)]
66                fn #tracing_func_name(bench_instructions: u64) -> Result<Vec<(i32, i64)>, String> {
67                    #func_name();
68                    canbench_rs::get_traces(bench_instructions)
69                }
70            }
71        }
72        [] => {
73            // If there is no argument, validate that the function returns nothing
74            if let ReturnType::Type(_, ty) = &input.sig.output {
75                // If the return type is not empty, generate a compile-time error
76                return syn::Error::new_spanned(ty, "Benchmark should not return any values.")
77                    .to_compile_error()
78                    .into();
79            }
80
81            quote! {
82                #input
83
84                #[ic_cdk::query]
85                #[allow(non_snake_case)]
86                fn #renamed_func_name() -> canbench_rs::BenchResult {
87                    canbench_rs::bench_fn(|| {
88                        #func_name();
89                    })
90                }
91
92                #[ic_cdk::query]
93                #[allow(non_snake_case)]
94                fn #tracing_func_name(bench_instructions: u64) -> Result<Vec<(i32, i64)>, String> {
95                    canbench_rs::bench_fn(|| {
96                        #func_name();
97                    });
98                    canbench_rs::get_traces(bench_instructions)
99                }
100            }
101        }
102        _ => {
103            // If there is any other argument, generate a compile-time error
104            let args_tokens = args
105                .iter()
106                .map(|arg| quote!(#arg).to_token_stream())
107                .collect::<proc_macro2::TokenStream>();
108
109            return syn::Error::new_spanned(
110                args_tokens,
111                "Invalid argument. Use 'raw' or no argument.",
112            )
113            .to_compile_error()
114            .into();
115        }
116    };
117
118    TokenStream::from(expanded)
119}