hotpath_macros/
lib.rs

1use proc_macro::TokenStream;
2use quote::quote;
3use syn::parse::Parser;
4use syn::{parse_macro_input, ItemFn, LitInt, LitStr};
5
6#[derive(Clone, Copy)]
7enum Format {
8    Table,
9    Json,
10    JsonPretty,
11}
12
13impl Format {
14    fn to_tokens(self) -> proc_macro2::TokenStream {
15        match self {
16            Format::Table => quote!(hotpath::Format::Table),
17            Format::Json => quote!(hotpath::Format::Json),
18            Format::JsonPretty => quote!(hotpath::Format::JsonPretty),
19        }
20    }
21}
22
23#[proc_macro_attribute]
24pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream {
25    let input = parse_macro_input!(item as ItemFn);
26    let vis = &input.vis;
27    let sig = &input.sig;
28    let block = &input.block;
29
30    // Defaults
31    let mut percentiles: Vec<u8> = vec![95];
32    let mut format = Format::Table;
33
34    // Parse named args like: percentiles=[..], format=".."
35    if !attr.is_empty() {
36        let parser = syn::meta::parser(|meta| {
37            if meta.path.is_ident("percentiles") {
38                meta.input.parse::<syn::Token![=]>()?;
39                let content;
40                syn::bracketed!(content in meta.input);
41                let mut vals = Vec::new();
42                while !content.is_empty() {
43                    let li: LitInt = content.parse()?;
44                    let v: u8 = li.base10_parse()?;
45                    if !(0..=100).contains(&v) {
46                        return Err(
47                            meta.error(format!("Invalid percentile {} (must be 0..=100)", v))
48                        );
49                    }
50                    vals.push(v);
51                    if !content.is_empty() {
52                        content.parse::<syn::Token![,]>()?;
53                    }
54                }
55                if vals.is_empty() {
56                    return Err(meta.error("At least one percentile must be specified"));
57                }
58                percentiles = vals;
59                return Ok(());
60            }
61
62            if meta.path.is_ident("format") {
63                meta.input.parse::<syn::Token![=]>()?;
64                let lit: LitStr = meta.input.parse()?;
65                format =
66                    match lit.value().as_str() {
67                        "table" => Format::Table,
68                        "json" => Format::Json,
69                        "json-pretty" => Format::JsonPretty,
70                        other => return Err(meta.error(format!(
71                            "Unknown format {:?}. Expected one of: \"table\", \"json\", \"json-pretty\"",
72                            other
73                        ))),
74                    };
75                return Ok(());
76            }
77
78            Err(meta.error("Unknown parameter. Supported: percentiles=[..], format=\"..\""))
79        });
80
81        if let Err(e) = parser.parse2(proc_macro2::TokenStream::from(attr)) {
82            return e.to_compile_error().into();
83        }
84    }
85
86    let percentiles_array = quote! { &[#(#percentiles),*] };
87    let format_token = format.to_tokens();
88
89    let output = quote! {
90        #vis #sig {
91            let _hotpath = {
92                fn __caller_fn() {}
93                let caller_name = std::any::type_name_of_val(&__caller_fn)
94                    .strip_suffix("::__caller_fn")
95                    .unwrap_or(std::any::type_name_of_val(&__caller_fn))
96                    .replace("::{{closure}}", "");
97                hotpath::init(caller_name.to_string(), #percentiles_array, #format_token)
98            };
99
100            #block
101        }
102    };
103
104    output.into()
105}
106
107#[proc_macro_attribute]
108pub fn measure(_attr: TokenStream, item: TokenStream) -> TokenStream {
109    let input = parse_macro_input!(item as ItemFn);
110    let vis = &input.vis;
111    let sig = &input.sig;
112    let block = &input.block;
113
114    let name = sig.ident.to_string();
115    let asyncness = sig.asyncness.is_some();
116
117    let output = if asyncness {
118        quote! {
119            #vis #sig {
120                async {
121                    hotpath::cfg_if! {
122                        if #[cfg(feature = "hotpath-off")] {
123                            // No-op when hotpath-off is enabled
124                        } else if #[cfg(any(
125                            feature = "hotpath-alloc-bytes-total",
126                            feature = "hotpath-alloc-bytes-max",
127                            feature = "hotpath-alloc-count-total",
128                            feature = "hotpath-alloc-count-max"
129                        ))] {
130                            use hotpath::{Handle, RuntimeFlavor};
131                            let runtime_flavor = Handle::try_current().ok().map(|h| h.runtime_flavor());
132
133                            let _guard = match runtime_flavor {
134                                Some(RuntimeFlavor::CurrentThread) => {
135                                    hotpath::AllocGuardType::AllocGuard(hotpath::AllocGuard::new(concat!(module_path!(), "::", #name)))
136                                }
137                                _ => {
138                                    hotpath::AllocGuardType::NoopAsyncAllocGuard(hotpath::NoopAsyncAllocGuard::new(concat!(module_path!(), "::", #name)))
139                                }
140                            };
141                        } else {
142                            let _guard = hotpath::TimeGuard::new(concat!(module_path!(), "::", #name));
143                        }
144                    }
145
146                    #block
147                }.await
148            }
149        }
150    } else {
151        quote! {
152            #vis #sig {
153                hotpath::cfg_if! {
154                    if #[cfg(feature = "hotpath-off")] {
155                        // No-op when hotpath-off is enabled
156                    } else if #[cfg(any(
157                        feature = "hotpath-alloc-bytes-total",
158                        feature = "hotpath-alloc-bytes-max",
159                        feature = "hotpath-alloc-count-total",
160                        feature = "hotpath-alloc-count-max"
161                    ))] {
162                        let _guard = hotpath::AllocGuard::new(concat!(module_path!(), "::", #name));
163                    } else {
164                        let _guard = hotpath::TimeGuard::new(concat!(module_path!(), "::", #name));
165                    }
166                }
167
168                #block
169            }
170        }
171    };
172
173    output.into()
174}