metrics_fn_codegen/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::{Span, TokenStream as TokenStream2};
3use quote::{quote_spanned, ToTokens};
4use syn::{parse_macro_input, AttributeArgs};
5
6mod call_type;
7mod function;
8mod return_type_classification;
9
10use function::*;
11
12use crate::return_type_classification::ReturnTypeClassification;
13
14#[proc_macro_attribute]
15pub fn dummy(_attr: TokenStream, item: TokenStream) -> TokenStream {
16	item
17}
18
19#[proc_macro_attribute]
20pub fn measure(attrs: TokenStream, item: TokenStream) -> TokenStream {
21	let span = proc_macro2::Span::call_site();
22
23	let attrs = parse_macro_input!(attrs as AttributeArgs);
24	if attrs.len() > 0 {
25		return syn::Error::new(span, "#[measure] does not take arguments.")
26			.to_compile_error()
27			.into();
28	}
29
30	let original_fn = parse_macro_input!(item as Function);
31	let wrapped_fn =
32		original_fn.rename(format!("{}__{}", original_fn.function.sig.ident.clone().to_string(), "wrapped").as_str());
33
34	let wrapped_attrs_tokens = original_fn.attributes_tokens();
35	let wrapped_call_tokens = wrapped_fn.call(span);
36	let wrapped_call_fn_name = original_fn.function.sig.ident.clone().to_string();
37	let wrapped_sig_tokens = wrapped_fn.function.sig.into_token_stream();
38	let wrapped_body_tokens = original_fn.function.block.clone().into_token_stream();
39	let wrapper_sig_tokens = original_fn.signature_full();
40	let record_call_tokens = build_record_call(span, original_fn, wrapped_call_fn_name);
41
42	let output = quote_spanned! { span =>
43		#wrapped_attrs_tokens
44		#wrapper_sig_tokens {
45
46			let start__ = std::time::Instant::now();
47			let output__ = #wrapped_call_tokens;
48			let end__ = std::time::Instant::now();
49
50			let module_name = module_path!();
51			#record_call_tokens
52
53			return output__;
54		}
55
56		#[allow(non_snake_case)]
57		#wrapped_sig_tokens
58		#wrapped_body_tokens
59	};
60
61	output.into()
62}
63
64fn build_record_call(span: Span, original_fn: Function, wrapped_call_fn_name: String) -> TokenStream2 {
65	match original_fn.return_type() {
66		ReturnTypeClassification::Result => {
67			quote_spanned! { span =>
68				let result__: core::result::Result<(), ()> = if output__.is_ok() {
69					Ok(())
70				} else {
71					Err(())
72				};
73				metrics_fn::record(module_name, #wrapped_call_fn_name, result__, end__.duration_since(start__).as_secs_f64());
74			}
75		},
76		_ => {
77			quote_spanned! { span =>
78				metrics_fn::record(module_name, #wrapped_call_fn_name, Ok(()), end__.duration_since(start__).as_secs_f64());
79			}
80		},
81	}
82}