metrics_fn_codegen/
lib.rs1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::{quote_spanned, ToTokens};
4use syn::{parse_macro_input, AttributeArgs};
5
6mod call_type;
7mod function;
8mod return_type_classification;
9
10use function::*;
11
12use crate::return_type_classification::ReturnTypeClassification;
13
14#[proc_macro_attribute]
15pub fn dummy(_attr: TokenStream, item: TokenStream) -> TokenStream {
16 item
17}
18
19#[proc_macro_attribute]
20pub fn measure(attrs: TokenStream, item: TokenStream) -> TokenStream {
21 let span = proc_macro2::Span::call_site();
22
23 let attrs = parse_macro_input!(attrs as AttributeArgs);
24 if attrs.len() > 0 {
25 return syn::Error::new(span, "#[measure] does not take arguments.")
26 .to_compile_error()
27 .into();
28 }
29
30 let original_fn = parse_macro_input!(item as Function);
31 let wrapped_fn =
32 original_fn.rename(format!("{}__{}", original_fn.function.sig.ident.clone().to_string(), "wrapped").as_str());
33
34 let wrapped_attrs_tokens = original_fn.attributes_tokens();
35 let wrapped_call_tokens = wrapped_fn.call(span);
36 let wrapped_call_fn_name = original_fn.function.sig.ident.clone().to_string();
37 let wrapped_sig_tokens = wrapped_fn.function.sig.into_token_stream();
38 let wrapped_body_tokens = original_fn.function.block.clone().into_token_stream();
39 let wrapper_sig_tokens = original_fn.signature_full();
40 let record_call_tokens = build_record_call(span, original_fn, wrapped_call_fn_name);
41
42 let output = quote_spanned! { span =>
43 #wrapped_attrs_tokens
44 #wrapper_sig_tokens {
45
46 let start__ = std::time::Instant::now();
47 let output__ = #wrapped_call_tokens;
48 let end__ = std::time::Instant::now();
49
50 let module_name = module_path!();
51 #record_call_tokens
52
53 return output__;
54 }
55
56 #[allow(non_snake_case)]
57 #wrapped_sig_tokens
58 #wrapped_body_tokens
59 };
60
61 output.into()
62}
63
64fn build_record_call(span: Span, original_fn: Function, wrapped_call_fn_name: String) -> TokenStream2 {
65 match original_fn.return_type() {
66 ReturnTypeClassification::Result => {
67 quote_spanned! { span =>
68 let result__: core::result::Result<(), ()> = if output__.is_ok() {
69 Ok(())
70 } else {
71 Err(())
72 };
73 metrics_fn::record(module_name, #wrapped_call_fn_name, result__, end__.duration_since(start__).as_secs_f64());
74 }
75 },
76 _ => {
77 quote_spanned! { span =>
78 metrics_fn::record(module_name, #wrapped_call_fn_name, Ok(()), end__.duration_since(start__).as_secs_f64());
79 }
80 },
81 }
82}