use proc_macro::TokenStream;
use quote::quote;
use syn::parse::Parser;
use syn::{parse_macro_input, ImplItem, Item, ItemFn, LitInt, LitStr};
/// Report output format selected through the `format = ".."` macro parameter.
///
/// Each variant mirrors the variant of the same name on `hotpath::Format`,
/// which is what the generated code refers to.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub(crate) enum Format {
    /// Selected by `format = "table"`; also the default when no format is given.
    Table,
    /// Selected by `format = "json"`.
    Json,
    /// Selected by `format = "json-pretty"`.
    JsonPretty,
    /// Selected by `format = "none"`.
    None,
}
impl Format {
    /// Returns the tokens of the fully qualified `hotpath::Format` variant
    /// that corresponds to `self`, ready to be interpolated into generated code.
    pub(crate) fn to_tokens(self) -> proc_macro2::TokenStream {
        match self {
            Self::Table => quote! { hotpath::Format::Table },
            Self::Json => quote! { hotpath::Format::Json },
            Self::JsonPretty => quote! { hotpath::Format::JsonPretty },
            Self::None => quote! { hotpath::Format::None },
        }
    }
}
/// Expands the attribute applied to a program entry function.
///
/// The function body is prefixed with code that constructs a
/// `hotpath::HotpathGuardBuilder` named after the function
/// (`module_path!()::<fn name>`) and configures it from the attribute
/// arguments:
///
/// * `percentiles = [..]` - integer literals in `0..=100`; defaults to `[95]`.
///   An empty list is rejected.
/// * `format = ".."` - one of `"table"`, `"json"`, `"json-pretty"`, `"none"`;
///   defaults to `"table"`.
/// * `limit = N` - forwarded to `with_functions_limit`; defaults to `15`.
/// * `output_path = ".."` - forwarded to `output_path` when present.
/// * `report = ".."` - comma-separated section names. `"all"` anywhere in the
///   list selects `hotpath::Section::all()`; otherwise the recognised names are
///   passed to `with_sections`. Unrecognised names are ignored.
///
/// At run time the generated code reads `HOTPATH_SHUTDOWN_MS`: when it parses
/// as a `u64`, `build_with_shutdown` is called with that many milliseconds and
/// no guard is kept; otherwise `build()` is called and the guard is held in a
/// local for the rest of the body.
///
/// For `async fn`s the whole body (guard setup included) is wrapped in
/// `async { .. }.await`.
///
/// Attributes present on the input function (doc comments, lints, other
/// attribute macros listed after this one) are re-emitted on the output, in
/// line with the other attribute expansions in this file.
///
/// Any argument parse failure is returned as a compile error.
pub fn main_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
    let input = parse_macro_input!(item as ItemFn);
    let attrs = &input.attrs;
    let vis = &input.vis;
    let sig = &input.sig;
    let block = &input.block;

    // Defaults used when the corresponding argument is not supplied.
    let mut percentiles: Vec<u8> = vec![95];
    let mut format = Format::Table;
    let mut functions_limit: usize = 15;
    let mut output_path: Option<String> = None;
    let mut report_sections: Option<String> = None;

    if !attr.is_empty() {
        let parser = syn::meta::parser(|meta| {
            if meta.path.is_ident("percentiles") {
                meta.input.parse::<syn::Token![=]>()?;
                let content;
                syn::bracketed!(content in meta.input);
                let mut vals = Vec::new();
                while !content.is_empty() {
                    let li: LitInt = content.parse()?;
                    let v: u8 = li.base10_parse()?;
                    // `u8` has no negative values, so only the upper bound needs checking.
                    if v > 100 {
                        return Err(
                            meta.error(format!("Invalid percentile {} (must be 0..=100)", v))
                        );
                    }
                    vals.push(v);
                    // A separator is required between values; a trailing comma is allowed.
                    if !content.is_empty() {
                        content.parse::<syn::Token![,]>()?;
                    }
                }
                if vals.is_empty() {
                    return Err(meta.error("At least one percentile must be specified"));
                }
                percentiles = vals;
                return Ok(());
            }
            if meta.path.is_ident("format") {
                meta.input.parse::<syn::Token![=]>()?;
                let lit: LitStr = meta.input.parse()?;
                format = match lit.value().as_str() {
                    "table" => Format::Table,
                    "json" => Format::Json,
                    "json-pretty" => Format::JsonPretty,
                    "none" => Format::None,
                    other => {
                        return Err(meta.error(format!(
                            "Unknown format {:?}. Expected one of: \"table\", \"json\", \"json-pretty\", \"none\"",
                            other
                        )))
                    }
                };
                return Ok(());
            }
            if meta.path.is_ident("limit") {
                meta.input.parse::<syn::Token![=]>()?;
                let li: LitInt = meta.input.parse()?;
                functions_limit = li.base10_parse()?;
                return Ok(());
            }
            if meta.path.is_ident("output_path") {
                meta.input.parse::<syn::Token![=]>()?;
                let lit: LitStr = meta.input.parse()?;
                output_path = Some(lit.value());
                return Ok(());
            }
            if meta.path.is_ident("report") {
                meta.input.parse::<syn::Token![=]>()?;
                let lit: LitStr = meta.input.parse()?;
                report_sections = Some(lit.value());
                return Ok(());
            }
            Err(meta.error(
                "Unknown parameter. Supported: percentiles=[..], format=\"..\", limit=N, output_path=\"..\", report=\"..\"",
            ))
        });
        if let Err(e) = parser.parse2(proc_macro2::TokenStream::from(attr)) {
            return e.to_compile_error().into();
        }
    }

    let percentiles_array = quote! { &[#(#percentiles),*] };
    let format_token = format.to_tokens();
    let asyncness = sig.asyncness.is_some();
    let fn_name = &sig.ident;

    let output_path_call = match &output_path {
        Some(path) => quote! { .output_path(#path) },
        None => quote! {},
    };

    let sections_call = match report_sections.as_deref() {
        Some(report_str) => {
            // Split once; both the "all" check and the per-section mapping use it.
            let requested: Vec<&str> = report_str.split(',').map(str::trim).collect();
            if requested.contains(&"all") {
                // "all" takes precedence over any individually named sections.
                quote! { .with_sections(hotpath::Section::all()) }
            } else {
                let section_tokens: Vec<proc_macro2::TokenStream> = requested
                    .iter()
                    .filter_map(|name| match *name {
                        "functions-timing" => Some(quote! { hotpath::Section::FunctionsTiming }),
                        "functions-alloc" => Some(quote! { hotpath::Section::FunctionsAlloc }),
                        "channels" => Some(quote! { hotpath::Section::Channels }),
                        "streams" => Some(quote! { hotpath::Section::Streams }),
                        "futures" => Some(quote! { hotpath::Section::Futures }),
                        "threads" => Some(quote! { hotpath::Section::Threads }),
                        // Unrecognised names are skipped rather than rejected.
                        _ => None,
                    })
                    .collect();
                if section_tokens.is_empty() {
                    quote! {}
                } else {
                    quote! { .with_sections(vec![#(#section_tokens),*]) }
                }
            }
        }
        None => quote! {},
    };

    let caller_name_init = quote! {
        let caller_name: &'static str =
            concat!(module_path!(), "::", stringify!(#fn_name));
    };

    let builder_chain = quote! {
        hotpath::HotpathGuardBuilder::new(caller_name)
            .percentiles(#percentiles_array)
            .with_functions_limit(#functions_limit)
            .format(#format_token)
            #output_path_call
            #sections_call
    };

    let guard_init = quote! {
        let _hotpath: Option<hotpath::HotpathGuard> = {
            #caller_name_init
            let builder = #builder_chain;
            match std::env::var("HOTPATH_SHUTDOWN_MS").ok().and_then(|v| v.parse::<u64>().ok()) {
                Some(ms) => {
                    builder.build_with_shutdown(std::time::Duration::from_millis(ms));
                    None
                }
                None => Some(builder.build()),
            }
        };
    };

    let body = quote! {
        #guard_init
        #block
    };

    let wrapped_body = if asyncness {
        quote! { async { #body }.await }
    } else {
        body
    };

    // Re-emit the original attributes so nothing attached to the function is lost.
    let output = quote! {
        #(#attrs)*
        #vis #sig {
            #wrapped_body
        }
    };
    output.into()
}
pub fn measure_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemFn);
let attrs = &input.attrs;
let vis = &input.vis;
let sig = &input.sig;
let block = &input.block;
let fn_ident = &sig.ident;
let is_async_fn = sig.asyncness.is_some();
let mut enable_result_logging = false;
let mut enable_future_tracking = false;
if !attr.is_empty() {
let parser = syn::meta::parser(|meta| {
if meta.path.is_ident("log") {
meta.input.parse::<syn::Token![=]>()?;
let lit: syn::LitBool = meta.input.parse()?;
enable_result_logging = lit.value();
return Ok(());
}
if meta.path.is_ident("future") {
meta.input.parse::<syn::Token![=]>()?;
let lit: syn::LitBool = meta.input.parse()?;
enable_future_tracking = lit.value();
return Ok(());
}
Err(meta.error("Unknown parameter. Supported: log = true, future = true"))
});
if let Err(e) = parser.parse2(proc_macro2::TokenStream::from(attr)) {
return e.to_compile_error().into();
}
}
if enable_future_tracking && !is_async_fn {
return syn::Error::new_spanned(
sig.fn_token,
"future = true can only be used on async functions",
)
.to_compile_error()
.into();
}
let measurement_loc = quote! { concat!(module_path!(), "::", stringify!(#fn_ident)) };
let wrapped_body = if !is_async_fn {
if enable_result_logging {
quote! {
hotpath::functions::measure_sync_log(#measurement_loc, || #block)
}
} else {
quote! {
let _guard = hotpath::functions::build_measurement_guard_sync(#measurement_loc, false);
#block
}
}
} else if enable_future_tracking {
if enable_result_logging {
quote! {
hotpath::functions::measure_async_future_log(#measurement_loc, async #block).await
}
} else {
quote! {
hotpath::functions::measure_async_future(#measurement_loc, async #block).await
}
}
} else if enable_result_logging {
quote! {
hotpath::functions::measure_async_log(#measurement_loc, async #block).await
}
} else {
quote! {
hotpath::functions::measure_async(#measurement_loc, async #block).await
}
};
let output = quote! {
#(#attrs)*
#vis #sig {
#wrapped_body
}
};
output.into()
}
pub fn future_fn_impl(attr: TokenStream, item: TokenStream) -> TokenStream {
let input = parse_macro_input!(item as ItemFn);
let attrs = &input.attrs;
let vis = &input.vis;
let sig = &input.sig;
let block = &input.block;
if sig.asyncness.is_none() {
return syn::Error::new_spanned(
sig.fn_token,
"The #[future_fn] attribute can only be applied to async functions",
)
.to_compile_error()
.into();
}
let mut log_result = false;
if !attr.is_empty() {
let parser = syn::meta::parser(|meta| {
if meta.path.is_ident("log") {
meta.input.parse::<syn::Token![=]>()?;
let lit: syn::LitBool = meta.input.parse()?;
log_result = lit.value();
return Ok(());
}
Err(meta.error("Unknown parameter. Supported: log = true"))
});
if let Err(e) = parser.parse2(proc_macro2::TokenStream::from(attr)) {
return e.to_compile_error().into();
}
}
let fn_name = &sig.ident;
let wrapped_body = if log_result {
quote! {
{
const FUTURE_LOC: &'static str = concat!(module_path!(), "::", stringify!(#fn_name));
hotpath::futures::init_futures_state();
hotpath::InstrumentFutureLog::instrument_future_log(
async #block,
FUTURE_LOC,
None
).await
}
}
} else {
quote! {
{
const FUTURE_LOC: &'static str = concat!(module_path!(), "::", stringify!(#fn_name));
hotpath::futures::init_futures_state();
hotpath::InstrumentFuture::instrument_future(
async #block,
FUTURE_LOC,
None
).await
}
}
};
let output = quote! {
#(#attrs)*
#vis #sig {
#wrapped_body
}
};
output.into()
}
/// Expansion for the `skip` attribute: returns the annotated item unchanged.
///
/// The attribute arguments are ignored. Within this file the attribute is only
/// looked for by `has_hotpath_skip_or_measure`, which `measure_all_impl` uses
/// to leave such functions uninstrumented.
pub fn skip_impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
item
}
pub fn measure_all_impl(_attr: TokenStream, item: TokenStream) -> TokenStream {
let parsed_item = parse_macro_input!(item as Item);
match parsed_item {
Item::Mod(mut module) => {
if let Some((_brace, items)) = &mut module.content {
for it in items.iter_mut() {
if let Item::Fn(func) = it {
if !has_hotpath_skip_or_measure(&func.attrs) {
let func_tokens = TokenStream::from(quote!(#func));
let transformed = measure_impl(TokenStream::new(), func_tokens);
*func = syn::parse_macro_input!(transformed as ItemFn);
}
}
}
}
TokenStream::from(quote!(#module))
}
Item::Impl(mut impl_block) => {
for item in impl_block.items.iter_mut() {
if let ImplItem::Fn(method) = item {
if !has_hotpath_skip_or_measure(&method.attrs) {
let func_tokens = TokenStream::from(quote!(#method));
let transformed = measure_impl(TokenStream::new(), func_tokens);
*method = syn::parse_macro_input!(transformed as syn::ImplItemFn);
}
}
}
TokenStream::from(quote!(#impl_block))
}
_ => panic!("measure_all can only be applied to modules or impl blocks"),
}
}
/// Reports whether any attribute in `attrs` marks the function as already
/// handled, i.e. it should not be instrumented again.
///
/// An attribute counts when either:
/// * its path is exactly two segments, `hotpath::skip` or `hotpath::measure`; or
/// * it is a `cfg_attr` whose token text contains `hotpath` together with
///   `skip` or `measure`. This is a plain substring check on the rendered
///   attribute, not a structural parse.
fn has_hotpath_skip_or_measure(attrs: &[syn::Attribute]) -> bool {
    for attribute in attrs {
        let path = attribute.path();

        // Direct form: a two-segment path rooted at `hotpath`.
        let mut segments = path.segments.iter();
        if let (Some(root), Some(leaf), None) =
            (segments.next(), segments.next(), segments.next())
        {
            if root.ident == "hotpath" && (leaf.ident == "skip" || leaf.ident == "measure") {
                return true;
            }
        }

        // Conditional form: inspect the rendered text of the `cfg_attr`.
        if path.is_ident("cfg_attr") {
            let rendered = quote!(#attribute).to_string();
            let mentions_marker = rendered.contains("skip") || rendered.contains("measure");
            if rendered.contains("hotpath") && mentions_marker {
                return true;
            }
        }
    }
    false
}