use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::{Mutex, OnceLock};
use std::time::Instant;
use lazy_static::lazy_static;
use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{
parse,
parse2,
parse_macro_input,
parse_str,
Error,
Expr,
ExprLit,
Field,
Fields,
Ident,
Item,
ItemConst,
ItemEnum,
ItemExternCrate,
ItemFn,
ItemMod,
ItemStatic,
ItemStruct,
ItemTrait,
ItemTraitAlias,
ItemType,
ItemUnion,
ItemUse,
LitBool,
LitStr,
Meta,
Token,
TraitItem,
Visibility,
};
/// Attribute macro that clones an RPC trait definition into a version-specific variant.
///
/// `attr` must be a string literal naming the version (e.g. `"V0_6"`); `input` must be
/// a trait. The macro emits a new trait named `<TraitName><version>` in which every
/// `#[method(name = "...")]` attribute is rewritten to
/// `#[method(name = "<version>_<name>")]`, and the whole trait is annotated with
/// `#[rpc(server, client, namespace = "starknet")]`.
/// NOTE(review): the `rpc`/`method` attributes are presumably consumed by jsonrpsee's
/// `rpc` macro in the calling crate — confirm against the consumer.
#[proc_macro_attribute]
pub fn versioned_rpc(attr: TokenStream, input: TokenStream) -> TokenStream {
    let version = parse_macro_input!(attr as syn::LitStr);
    let item_trait = parse_macro_input!(input as ItemTrait);
    let trait_name = &item_trait.ident;
    let visibility = &item_trait.vis;
    // Rewrite each trait method; non-method trait items are passed through unchanged.
    let versioned_methods = item_trait
        .items
        .iter()
        .map(|item| {
            if let TraitItem::Fn(method) = item {
                let new_method = syn::TraitItemFn {
                    attrs: method
                        .attrs
                        .iter()
                        // NOTE(review): this filter drops every name-value attribute,
                        // which includes `#[doc = "..."]` doc comments on the methods —
                        // confirm that losing method docs is intentional.
                        .filter(|attr| !matches!(attr.meta, Meta::NameValue(_)))
                        .map(|attr| {
                            let mut new_attr = attr.clone();
                            if attr.path().is_ident("method") {
                                // Parse `#[method(name = "...")]` and prefix the name
                                // with the version string. Errors from the nested parse
                                // are deliberately ignored (`let _`), leaving the
                                // attribute unchanged in that case.
                                let _ = attr.parse_nested_meta(|meta| {
                                    if meta.path.is_ident("name") {
                                        let value = meta.value()?;
                                        let method_name: LitStr = value.parse()?;
                                        let new_meta_str = format!(
                                            "method(name = \"{}_{}\")",
                                            version.value(),
                                            method_name.value()
                                        );
                                        new_attr.meta = syn::parse_str::<Meta>(&new_meta_str)?;
                                    }
                                    Ok(())
                                });
                            }
                            new_attr
                        })
                        .collect::<Vec<_>>(),
                    sig: method.sig.clone(),
                    default: method.default.clone(),
                    semi_token: method.semi_token,
                };
                new_method.into()
            } else {
                item.clone()
            }
        })
        .collect::<Vec<TraitItem>>();
    // Rebuild the trait under the versioned name, keeping visibility, generics and
    // supertraits, but replacing the original trait-level attributes with `rpc`.
    let versioned_trait = syn::ItemTrait {
        attrs: vec![syn::parse_quote!(#[rpc(server, client, namespace = "starknet")])],
        vis: visibility.clone(),
        unsafety: None,
        auto_token: None,
        ident: syn::Ident::new(&format!("{}{}", trait_name, version.value()), trait_name.span()),
        colon_token: None,
        supertraits: item_trait.supertraits.clone(),
        brace_token: item_trait.brace_token,
        items: versioned_methods,
        restriction: item_trait.restriction.clone(),
        generics: item_trait.generics.clone(),
        trait_token: item_trait.trait_token,
    };
    versioned_trait.to_token_stream().into()
}
/// Attribute macro that wraps a function body with wall-clock latency measurement,
/// recording elapsed seconds into a `metrics` histogram looked up by name.
///
/// `attr` must contain a string literal (the histogram name) followed by a boolean
/// literal that decides whether recording is gated by the global
/// `COLLECT_PROFILING_METRICS` configuration flag.
#[proc_macro_attribute]
pub fn latency_histogram(attr: TokenStream, input: TokenStream) -> TokenStream {
    let (metric_name, control_with_config, input_fn) = parse_latency_histogram_attributes::<ExprLit>(
        attr,
        input,
        "Expecting a string literal for metric name",
    );
    // Record via the `metrics` crate, resolving the histogram by its string name.
    let record = quote! {
        ::metrics::histogram!(#metric_name).record(exec_time);
    };
    let gate = quote! {
        papyrus_common::metrics::COLLECT_PROFILING_METRICS
    };
    create_modified_function(control_with_config, input_fn, record, gate)
}
/// Attribute macro that wraps a function body with wall-clock latency measurement,
/// recording elapsed seconds directly on a pre-defined metric identifier.
///
/// `attr` must contain an identifier (the metric) followed by a boolean literal that
/// decides whether recording is gated by the global
/// `COLLECT_SEQUENCER_PROFILING_METRICS` configuration flag.
#[proc_macro_attribute]
pub fn sequencer_latency_histogram(attr: TokenStream, input: TokenStream) -> TokenStream {
    let (metric_name, control_with_config, input_fn) = parse_latency_histogram_attributes::<Ident>(
        attr,
        input,
        "Expecting an identifier for metric name",
    );
    // Record on the named metric object rather than a string-keyed histogram.
    let record = quote! {
        #metric_name.record(exec_time);
    };
    let gate = quote! {
        apollo_metrics::metrics::COLLECT_SEQUENCER_PROFILING_METRICS
    };
    create_modified_function(control_with_config, input_fn, record, gate)
}
/// Splits a `latency_histogram`-style attribute into its two comma-separated parts —
/// a metric name (parsed as `T`) and a `bool` literal — and parses the annotated
/// item as a function.
///
/// Panics (surfacing as a compile error at the macro call site) when either part is
/// missing or fails to parse; `err_msg` customizes the metric-name parse failure.
fn parse_latency_histogram_attributes<T: Parse>(
    attr: TokenStream,
    input: TokenStream,
    err_msg: &str,
) -> (T, LitBool, ItemFn) {
    let attr_text = attr.to_string();
    // `split` always yields at least one piece, so only the second `next` can panic.
    let mut pieces = attr_text.split(',');
    let metric_name_raw = pieces
        .next()
        .expect("attribute should include metric name and control with config boolean")
        .trim();
    let control_with_config_raw = pieces
        .next()
        .expect("attribute should include metric name and control with config boolean")
        .trim();
    let control_with_config = parse_str::<LitBool>(control_with_config_raw)
        .expect("Expecting a boolean value for control with config");
    let metric_name = parse_str::<T>(metric_name_raw).expect(err_msg);
    let input_fn = parse::<ItemFn>(input).expect("Failed to parse input as ItemFn");
    (metric_name, control_with_config, input_fn)
}
/// Rewraps `input_fn`'s body so that its execution time is measured and fed to
/// `metric_recording_logic` (a token fragment that may reference the local
/// `exec_time: f64`, in seconds).
///
/// When `control_with_config` is `true`, timing only happens if the runtime flag
/// named by `collect_metric_flag` is set; the flag is read with
/// `.get().unwrap_or(&false)`, so it is presumably a `OnceLock<bool>`-style global —
/// confirm against the flag definitions in the metrics crates.
fn create_modified_function(
    control_with_config: LitBool,
    input_fn: ItemFn,
    metric_recording_logic: impl ToTokens,
    collect_metric_flag: impl ToTokens,
) -> TokenStream {
    let origin_block = &input_fn.block;
    // Generated body: optionally start a timer, run the original block, then record
    // the elapsed time. `return_value` preserves the original block's result.
    let expanded_block = quote! {
        {
            let mut start_function_time = None;
            if !#control_with_config || (#control_with_config && *(#collect_metric_flag.get().unwrap_or(&false))) {
                start_function_time = Some(std::time::Instant::now());
            }
            let return_value = #origin_block;
            if let Some(start_time) = start_function_time {
                let exec_time = start_time.elapsed().as_secs_f64();
                #metric_recording_logic
            }
            return_value
        }
    };
    // Keep the original signature/attributes/visibility; only the block changes.
    let modified_function = ItemFn {
        block: parse2(expanded_block).expect("Parse tokens in latency_histogram attribute."),
        ..input_fn
    };
    modified_function.to_token_stream().into()
}
/// Builds an identifier of the form `__<prefix>_<hash>` that is unique per macro
/// call site, by hashing the `Debug` representation of the call-site `Span`.
/// NOTE(review): this relies on `Span`'s `Debug` output distinguishing call sites —
/// an implementation detail of `proc_macro`, not a documented guarantee.
fn get_uniq_identifier_for_call_site(identifier_prefix: &str) -> Ident {
    let call_site_repr = format!("{:?}", proc_macro::Span::call_site());
    let mut hasher = DefaultHasher::new();
    call_site_repr.hash(&mut hasher);
    let ident_str = format!("__{identifier_prefix}_{:x}", hasher.finish());
    Ident::new(&ident_str, proc_macro2::Span::call_site())
}
/// Parsed input of the `log_every_n` macro:
/// `log_every_n!(log_macro, n, args...)`.
struct LogEveryNMacroInput {
    // Path of the logging macro to invoke, e.g. `tracing::info`.
    log_macro: syn::Path,
    // Expression for the logging period (log once every `n` invocations).
    n: Expr,
    // Remaining arguments, forwarded verbatim to `log_macro!`.
    args: Punctuated<Expr, Token![,]>,
}
impl Parse for LogEveryNMacroInput {
    /// Parses `<macro path> , <expr> , <expr>*` — the trailing argument list may be
    /// empty and may carry a trailing comma.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let log_macro = input.parse::<syn::Path>()?;
        input.parse::<Token![,]>()?;
        let n = input.parse::<Expr>()?;
        input.parse::<Token![,]>()?;
        let args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
        Ok(Self { log_macro, n, args })
    }
}
#[proc_macro]
pub fn log_every_n(input: TokenStream) -> TokenStream {
let LogEveryNMacroInput { log_macro, n, args, .. } =
parse_macro_input!(input as LogEveryNMacroInput);
let span = proc_macro::Span::call_site();
let span_str = format!("{span:?}");
let mut hasher = DefaultHasher::new();
span_str.hash(&mut hasher);
let hash_id = format!("{:x}", hasher.finish()); let ident_str = format!("__TRACING_COUNT_{hash_id}");
let ident = Ident::new(&ident_str, proc_macro2::Span::call_site());
let args = args.into_iter().collect::<Vec<_>>();
let expanded = quote! {
{
static #ident: ::std::sync::OnceLock<::std::sync::atomic::AtomicUsize> = ::std::sync::OnceLock::new();
let counter = #ident.get_or_init(|| ::std::sync::atomic::AtomicUsize::new(0));
let current_count = counter.fetch_add(1, ::std::sync::atomic::Ordering::Relaxed);
if current_count.is_multiple_of(#n) {
#log_macro!(#(#args),*);
}
}
};
TokenStream::from(expanded)
}
/// Parsed input of the `log_every_n_ms` macro:
/// `log_every_n_ms!(log_macro, n_ms, args...)`.
/// NOTE(review): structurally identical to `LogEveryNMacroInput`; kept separate,
/// presumably for per-macro clarity.
struct LogEveryNSecMacroInput {
    // Path of the logging macro to invoke, e.g. `tracing::warn`.
    log_macro: syn::Path,
    // Expression for the minimum interval, in milliseconds, between logs.
    n: Expr,
    // Remaining arguments, forwarded verbatim to `log_macro!`.
    args: Punctuated<Expr, Token![,]>,
}
impl Parse for LogEveryNSecMacroInput {
    /// Parses `<macro path> , <expr> , <expr>*` — the trailing argument list may be
    /// empty and may carry a trailing comma.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let log_macro = input.parse::<syn::Path>()?;
        input.parse::<Token![,]>()?;
        let n = input.parse::<Expr>()?;
        input.parse::<Token![,]>()?;
        let args = Punctuated::<Expr, Token![,]>::parse_terminated(input)?;
        Ok(Self { log_macro, n, args })
    }
}
lazy_static! {
    // NOTE(review): this clock does not appear to be referenced anywhere in this
    // file — `log_every_n_ms` measures from a per-call-site `OnceLock<Instant>` in
    // its generated code instead. Candidate for removal (potentially along with
    // the `lazy_static` dependency); confirm no other module in this crate uses it.
    static ref LOG_EVERY_N_MS_CLOCK_START: Instant = Instant::now();
}
/// Function-like macro expanding to a block that executes `log_macro!(args...)` at
/// most once every `n` milliseconds per call site. The first invocation always logs.
///
/// Notes on the generated code:
/// - Elapsed milliseconds are measured from a lazily-initialized per-call-site
///   `Instant`, then shifted by `+1` so that `0` can serve as the "never logged yet"
///   sentinel stored in the `AtomicU64`.
/// - `fetch_update` makes the check-and-set atomic, so concurrent callers race
///   safely: the winners of the CAS log, the losers take the `Err(_)` arm.
#[proc_macro]
pub fn log_every_n_ms(input: TokenStream) -> TokenStream {
    let LogEveryNSecMacroInput { log_macro, n, args, .. } =
        parse_macro_input!(input as LogEveryNSecMacroInput);
    let ident_last_log_time = get_uniq_identifier_for_call_site("TRACING_LAST_LOG_TIME");
    let ident_start_time = get_uniq_identifier_for_call_site("TRACING_START_TIME");
    let args = args.into_iter().collect::<Vec<_>>();
    // Bug fix: the success arm previously bound `Ok(old_now)` without using it,
    // which emitted an `unused_variables` warning at every macro call site; the
    // binding is now `Ok(_)`.
    let expanded = quote! {
        {
            static #ident_start_time: ::std::sync::OnceLock<::std::time::Instant> = ::std::sync::OnceLock::new();
            static #ident_last_log_time: ::std::sync::OnceLock<::std::sync::atomic::AtomicU64> = ::std::sync::OnceLock::new();
            let last_log_u64 = #ident_last_log_time.get_or_init(|| ::std::sync::atomic::AtomicU64::new(0));
            match last_log_u64.fetch_update(
                ::std::sync::atomic::Ordering::Relaxed,
                ::std::sync::atomic::Ordering::Relaxed,
                |curr_val : u64| {
                    let now_with_zero : u64 = #ident_start_time.get_or_init(|| ::std::time::Instant::now())
                        .elapsed().as_millis().try_into()
                        .expect("Timestamp in millis is larger than u64::MAX");
                    let now : u64 = now_with_zero + 1;
                    if curr_val == 0 {
                        return Some(now);
                    }
                    if curr_val + (#n) <= now {
                        return Some(now);
                    }
                    None
                }
            ) {
                Ok(_) => {
                    #log_macro!(#(#args),*);
                }
                Err(_) => {
                }
            };
        }
    };
    TokenStream::from(expanded)
}
// Next id to hand out; monotonically increasing for the lifetime of the process.
static NEXT: AtomicU16 = AtomicU16::new(0);
// Memoizes the id assigned to each call-site key so repeated lookups are stable.
static MAP: OnceLock<Mutex<HashMap<String, u16>>> = OnceLock::new();
/// Returns the unique `u16` associated with `key`, allocating the next free id on
/// first sight of the key. Panics once the id space runs out.
fn alloc_for(key: String) -> u16 {
    let mut ids = MAP.get_or_init(|| Mutex::new(HashMap::new())).lock().unwrap();
    // The entry API performs a single lookup; the closure (and hence the counter
    // increment) only runs for keys not seen before. The whole operation happens
    // under the map's mutex, so ids are assigned race-free.
    *ids.entry(key).or_insert_with(|| {
        let id = NEXT.fetch_add(1, Ordering::Relaxed);
        if id == u16::MAX {
            panic!("unique_u16 exhausted: > 65536 unique callsites in this crate");
        }
        id
    })
}
/// Function-like macro expanding to a suffixed `u16` literal (e.g. `3u16`) that is
/// unique per call site — keyed by file, line and column — within one compilation
/// using this proc-macro crate.
#[proc_macro]
pub fn unique_u16(_input: TokenStream) -> TokenStream {
    let call_site = proc_macro::Span::call_site();
    let key = format!("{}:{}:{}", call_site.file(), call_site.line(), call_site.column());
    let literal = proc_macro::Literal::u16_suffixed(alloc_for(key));
    TokenStream::from(proc_macro::TokenTree::Literal(literal))
}
#[proc_macro_attribute]
pub fn make_visibility(attrs: TokenStream, input: TokenStream) -> TokenStream {
let visibility: Visibility = parse_macro_input!(attrs);
let mut input: Item = parse_macro_input!(input);
match input {
Item::Const(ItemConst { ref mut vis, .. })
| Item::Enum(ItemEnum { ref mut vis, .. })
| Item::ExternCrate(ItemExternCrate { ref mut vis, .. })
| Item::Fn(ItemFn { ref mut vis, .. })
| Item::Mod(ItemMod { ref mut vis, .. })
| Item::Static(ItemStatic { ref mut vis, .. })
| Item::Struct(ItemStruct { ref mut vis, .. })
| Item::Trait(ItemTrait { ref mut vis, .. })
| Item::TraitAlias(ItemTraitAlias { ref mut vis, .. })
| Item::Type(ItemType { ref mut vis, .. })
| Item::Union(ItemUnion { ref mut vis, .. })
| Item::Use(ItemUse { ref mut vis, .. }) => *vis = visibility,
_ => {
return Error::new_spanned(&input, "Cannot override the `#[visibility]` of this item")
.to_compile_error()
.into();
}
}
input.into_token_stream().into()
}
/// Attribute macro that raises the visibility of every field of a struct to at least
/// the visibility given in the attribute; fields that are already as visible (or
/// more visible) are left untouched.
///
/// Non-struct items and unit structs produce a compile error.
#[proc_macro_attribute]
pub fn upgrade_fields_visibility(attrs: TokenStream, input: TokenStream) -> TokenStream {
    let target_visibility: Visibility = parse_macro_input!(attrs);
    let mut input: Item = parse_macro_input!(input);
    let Item::Struct(ItemStruct { ref mut fields, .. }) = &mut input else {
        return Error::new_spanned(
            &input,
            "`upgrade_fields_visibility` can only be applied to structs",
        )
        .to_compile_error()
        .into();
    };
    // A unit struct has no fields to upgrade; reject it explicitly.
    if matches!(fields, Fields::Unit) {
        return Error::new_spanned(
            &input,
            "`upgrade_fields_visibility` can only be applied to structs with fields",
        )
        .to_compile_error()
        .into();
    }
    // `Fields::iter_mut` walks named and unnamed (tuple) fields alike, so one loop
    // replaces the former per-variant duplication.
    for field in fields.iter_mut() {
        upgrade_field_visibility(field, &target_visibility);
    }
    input.into_token_stream().into()
}
/// Raises `field`'s visibility to `target_visibility` when the field is currently
/// less visible; otherwise leaves the field unchanged.
fn upgrade_field_visibility(field: &mut Field, target_visibility: &Visibility) {
    let current = visibility_level(&field.vis);
    let target = visibility_level(target_visibility);
    if current < target {
        field.vis = target_visibility.clone();
    }
}
/// Totally orders visibilities for comparison: private (inherited) < restricted
/// (`pub(crate)`, `pub(super)`, ...) < fully public.
fn visibility_level(vis: &Visibility) -> u8 {
    match vis {
        Visibility::Public(_) => 2,
        Visibility::Restricted(_) => 1,
        Visibility::Inherited => 0,
    }
}