apollo_proc_macros 0.18.0-dev.2

Procedural macros for the Papyrus node
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
use std::collections::HashMap;
use std::hash::{DefaultHasher, Hash, Hasher};
use std::sync::atomic::{AtomicU16, Ordering};
use std::sync::{Mutex, OnceLock};
use std::time::Instant;

use lazy_static::lazy_static;
use proc_macro::TokenStream;
use quote::{quote, ToTokens};
use syn::parse::{Parse, ParseStream};
use syn::punctuated::Punctuated;
use syn::{
    parse,
    parse2,
    parse_macro_input,
    parse_str,
    Error,
    Expr,
    ExprLit,
    Field,
    Fields,
    Ident,
    Item,
    ItemConst,
    ItemEnum,
    ItemExternCrate,
    ItemFn,
    ItemMod,
    ItemStatic,
    ItemStruct,
    ItemTrait,
    ItemTraitAlias,
    ItemType,
    ItemUnion,
    ItemUse,
    LitBool,
    LitStr,
    Meta,
    Token,
    TraitItem,
    Visibility,
};

/// This macro is a wrapper around the "rpc" macro supplied by the jsonrpsee library that generates
/// a server and client traits from a given trait definition. The wrapper gets a version id and
/// prepend the version id to the trait name and to every method name (note method name refers to
/// the name the API has for the function not the actual function name). We need this in order to be
/// able to merge multiple versions of jsonrpc APIs into one server and not have a clash in method
/// resolution.
///
/// # Example:
///
/// Given this code:
/// ```rust,ignore
/// #[versioned_rpc("V0_6_0")]
/// pub trait JsonRpc {
///     #[method(name = "blockNumber")]
///     fn block_number(&self) -> Result<BlockNumber, Error>;
/// }
/// ```
///
/// The macro will generate this code:
/// ```rust,ignore
/// #[rpc(server, client, namespace = "starknet")]
/// pub trait JsonRpcV0_6_0 {
///     #[method(name = "V0_6_0_blockNumber")]
///     fn block_number(&self) -> Result<BlockNumber, Error>;
/// }
/// ```
#[proc_macro_attribute]
pub fn versioned_rpc(attr: TokenStream, input: TokenStream) -> TokenStream {
    // The attribute argument is the version string, e.g. "V0_6_0".
    let version = parse_macro_input!(attr as syn::LitStr);
    let item_trait = parse_macro_input!(input as ItemTrait);

    let trait_name = &item_trait.ident;
    let visibility = &item_trait.vis;

    // generate the new method signatures with the version prefix
    let versioned_methods = item_trait
        .items
        .iter()
        .map(|item| {
            if let TraitItem::Fn(method) = item {
                let new_method = syn::TraitItemFn {
                    // NOTE(review): `Meta::NameValue` attributes are filtered out here. Doc
                    // comments lower to `#[doc = "..."]` (a NameValue), so method docs are
                    // stripped from the generated trait as well — confirm this is intended.
                    attrs: method
                        .attrs
                        .iter()
                        .filter(|attr| !matches!(attr.meta, Meta::NameValue(_)))
                        .map(|attr| {
                            let mut new_attr = attr.clone();
                            if attr.path().is_ident("method") {
                                // Rewrite `#[method(name = "x")]` into
                                // `#[method(name = "<version>_x")]`. Errors from
                                // `parse_nested_meta` are discarded (`let _`), which leaves
                                // the attribute unchanged on malformed input.
                                let _ = attr.parse_nested_meta(|meta| {
                                    if meta.path.is_ident("name") {
                                        let value = meta.value()?;
                                        let method_name: LitStr = value.parse()?;
                                        let new_meta_str = format!(
                                            "method(name = \"{}_{}\")",
                                            version.value(),
                                            method_name.value()
                                        );
                                        new_attr.meta = syn::parse_str::<Meta>(&new_meta_str)?;
                                    }
                                    Ok(())
                                });
                            }
                            new_attr
                        })
                        .collect::<Vec<_>>(),
                    sig: method.sig.clone(),
                    default: method.default.clone(),
                    semi_token: method.semi_token,
                };
                new_method.into()
            } else {
                // Non-function trait items (associated types/consts) pass through unchanged.
                item.clone()
            }
        })
        .collect::<Vec<TraitItem>>();

    // generate the versioned trait with the new method signatures
    let versioned_trait = syn::ItemTrait {
        // The jsonrpsee `rpc` attribute on the output trait generates the actual
        // server and client traits.
        attrs: vec![syn::parse_quote!(#[rpc(server, client, namespace = "starknet")])],
        vis: visibility.clone(),
        unsafety: None,
        auto_token: None,
        // e.g. `JsonRpc` + "V0_6_0" -> `JsonRpcV0_6_0`.
        ident: syn::Ident::new(&format!("{}{}", trait_name, version.value()), trait_name.span()),
        colon_token: None,
        supertraits: item_trait.supertraits.clone(),
        brace_token: item_trait.brace_token,
        items: versioned_methods,
        restriction: item_trait.restriction.clone(),
        generics: item_trait.generics.clone(),
        trait_token: item_trait.trait_token,
    };

    versioned_trait.to_token_stream().into()
}

/// This macro will emit a histogram metric with the given name and the latency of the function.
/// In addition, also a debug log with the metric name and the execution time will be emitted.
/// The macro also receives a boolean for whether it will be emitted only when
/// profiling is activated or at all times.
///
/// # Example
/// Given this code:
///
/// ```rust,ignore
/// #[latency_histogram("metric_name", false)]
/// fn foo() {
///     // Some code ...
/// }
/// ```
/// Every call to foo will update the histogram metric with the name “metric_name” with the time it
/// took to execute foo. In addition, a debug log with the following format will be emitted:
/// “<metric_name>: <execution_time>”
/// The metric will be emitted regardless of the value of the profiling configuration,
/// since the config value is false.
#[proc_macro_attribute]
pub fn latency_histogram(attr: TokenStream, input: TokenStream) -> TokenStream {
    // Parse `("metric_name", <bool>)` from the attribute plus the annotated function.
    let parsed = parse_latency_histogram_attributes::<ExprLit>(
        attr,
        input,
        "Expecting a string literal for metric name",
    );
    let (metric_name, control_with_config, input_fn) = parsed;

    // TODO(DanB): consider naming the input value instead of providing a bool
    // TODO(DanB): consider adding support for metrics levels (e.g. debug, info, warn, error)
    // instead of boolean

    create_modified_function(
        control_with_config,
        input_fn,
        // Record the elapsed time into the `metrics` crate histogram of the given name.
        quote! { ::metrics::histogram!(#metric_name).record(exec_time); },
        // Global flag consulted when `control_with_config` is true.
        quote! { papyrus_common::metrics::COLLECT_PROFILING_METRICS },
    )
}

/// This macro will emit a histogram metric with the given name and the latency of the function.
/// In addition, also a debug log with the metric name and the execution time will be emitted.
/// The macro also receives a boolean for whether it will be emitted only when
/// profiling is activated or at all times.
///
/// # Example
/// Given this code:
///
/// ```rust,ignore
/// use apollo_metrics::metrics::{MetricHistogram, MetricScope};
///
/// const FOO_HISTOGRAM_METRIC: MetricHistogram = MetricHistogram::new(
///     MetricScope::Infra,
///     "foo_histogram_metric",
///     "foo function latency histogram metrics",
/// );
///
/// #[sequencer_latency_histogram(FOO_HISTOGRAM_METRIC, false)]
/// fn foo() {
///     // Some code ...
/// }
/// ```
/// Every call to foo will update the histogram metric FOO_HISTOGRAM_METRIC with the time it
/// took to execute foo. In addition, a debug log with the following format will be emitted:
/// “<metric_name>: <execution_time>”
/// The metric will be emitted regardless of the value of the profiling configuration,
/// since the config value is false.
#[proc_macro_attribute]
pub fn sequencer_latency_histogram(attr: TokenStream, input: TokenStream) -> TokenStream {
    // Parse `(METRIC_IDENT, <bool>)` from the attribute plus the annotated function.
    let parsed = parse_latency_histogram_attributes::<Ident>(
        attr,
        input,
        "Expecting an identifier for metric name",
    );
    let (metric_name, control_with_config, input_fn) = parsed;

    create_modified_function(
        control_with_config,
        input_fn,
        // Record directly on the named metric constant.
        quote! { #metric_name.record(exec_time); },
        // Global flag consulted when `control_with_config` is true.
        quote! { apollo_metrics::metrics::COLLECT_SEQUENCER_PROFILING_METRICS },
    )
}

/// Helper function to parse the attributes and input for the latency histogram macros.
///
/// Expects `attr` to contain exactly two comma-separated parts: a metric name (parsed as `T`)
/// and a `bool` literal that controls whether metric collection is gated by the profiling
/// configuration. Returns the parsed metric name, the flag, and the annotated function.
///
/// # Panics
/// Panics (surfacing as a compile error at the macro callsite) when the attribute does not have
/// exactly two parts, when either part fails to parse, or when the annotated item is not a
/// function.
fn parse_latency_histogram_attributes<T: Parse>(
    attr: TokenStream,
    input: TokenStream,
    err_msg: &str,
) -> (T, LitBool, ItemFn) {
    let binding = attr.to_string();
    // Iterate the comma-separated parts lazily instead of collecting into a Vec.
    let mut parts = binding.split(',');
    let metric_name_string = parts
        .next()
        .expect("attribute should include metric name and control with config boolean")
        .trim()
        .to_string();
    let control_with_config_string = parts
        .next()
        .expect("attribute should include metric name and control with config boolean")
        .trim()
        .to_string();
    // Previously, any extra comma-separated arguments were silently ignored; reject them so
    // typos like `(name, true, extra)` fail loudly at compile time.
    assert!(
        parts.next().is_none(),
        "attribute should include exactly a metric name and a control with config boolean"
    );

    let control_with_config = parse_str::<LitBool>(&control_with_config_string)
        .expect("Expecting a boolean value for control with config");
    let metric_name = parse_str::<T>(&metric_name_string).expect(err_msg);

    let input_fn = parse::<ItemFn>(input).expect("Failed to parse input as ItemFn");

    (metric_name, control_with_config, input_fn)
}

/// Helper function to create the expanded block and modified function.
///
/// Wraps the function body so that, when enabled, the elapsed execution time is measured
/// and fed to `metric_recording_logic` (which may reference the local `exec_time`).
/// When `control_with_config` is true, measurement only happens if the global
/// `collect_metric_flag` is set to true.
fn create_modified_function(
    control_with_config: LitBool,
    mut input_fn: ItemFn,
    metric_recording_logic: impl ToTokens,
    collect_metric_flag: impl ToTokens,
) -> TokenStream {
    // Wrap the original body: optionally start a timer, run the body, then record.
    let original_body = &input_fn.block;
    let instrumented_body = quote! {
        {
            let mut start_function_time = None;
            if !#control_with_config || (#control_with_config && *(#collect_metric_flag.get().unwrap_or(&false))) {
                start_function_time = Some(std::time::Instant::now());
            }
            let return_value = #original_body;
            if let Some(start_time) = start_function_time {
                let exec_time = start_time.elapsed().as_secs_f64();
                #metric_recording_logic
            }
            return_value
        }
    };

    // Swap in the instrumented body; everything else about the function is untouched.
    input_fn.block =
        parse2(instrumented_body).expect("Parse tokens in latency_histogram attribute.");
    input_fn.to_token_stream().into()
}

/// Builds an identifier of the form `__<prefix>_<hash>` that is unique per macro callsite,
/// by hashing the debug representation of the call-site span.
fn get_uniq_identifier_for_call_site(identifier_prefix: &str) -> Ident {
    // The span's Debug output differs per callsite, giving us a deterministic,
    // per-invocation-site discriminator.
    let span_repr = format!("{:?}", proc_macro::Span::call_site());

    let mut hasher = DefaultHasher::new();
    span_repr.hash(&mut hasher);
    let suffix = format!("{:x}", hasher.finish()); // short identifier

    Ident::new(&format!("__{identifier_prefix}_{suffix}"), proc_macro2::Span::call_site())
}

/// Parsed input of the `log_every_n!` proc macro:
/// `log_every_n!(<log macro path>, <n>, <format args...>)`.
struct LogEveryNMacroInput {
    // Path of the logging macro to invoke (e.g. `tracing::info`).
    log_macro: syn::Path,
    // Log once every `n` invocations.
    n: Expr,
    // Remaining arguments, forwarded verbatim to the logging macro.
    args: Punctuated<Expr, Token![,]>,
}

impl Parse for LogEveryNMacroInput {
    /// Parses `<log macro path>, <n>, <args...>` in order.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let log_macro = input.parse::<syn::Path>()?;
        input.parse::<Token![,]>()?;
        let n = input.parse::<Expr>()?;
        input.parse::<Token![,]>()?;
        // Everything after the second comma is the forwarded argument list.
        let args = Punctuated::parse_terminated(input)?;
        Ok(Self { log_macro, n, args })
    }
}

/// An internal helper macro for logging a message every `n` calls to the macro.
/// Do not use this directly. Instead use the `info_every_n!`, `debug_every_n!`, etc. macros.
#[proc_macro]
pub fn log_every_n(input: TokenStream) -> TokenStream {
    let LogEveryNMacroInput { log_macro, n, args, .. } =
        parse_macro_input!(input as LogEveryNMacroInput);

    // Use call site span for uniqueness
    let span = proc_macro::Span::call_site();
    let span_str = format!("{span:?}");

    let mut hasher = DefaultHasher::new();
    span_str.hash(&mut hasher);

    let hash_id = format!("{:x}", hasher.finish()); // short identifier
    let ident_str = format!("__TRACING_COUNT_{hash_id}");
    let ident = Ident::new(&ident_str, proc_macro2::Span::call_site());

    let args = args.into_iter().collect::<Vec<_>>();

    let expanded = quote! {
        {
            static #ident: ::std::sync::OnceLock<::std::sync::atomic::AtomicUsize> = ::std::sync::OnceLock::new();
            let counter = #ident.get_or_init(|| ::std::sync::atomic::AtomicUsize::new(0));
            let current_count = counter.fetch_add(1, ::std::sync::atomic::Ordering::Relaxed);

            if current_count.is_multiple_of(#n) {
                #log_macro!(#(#args),*);
            }
        }
    };

    TokenStream::from(expanded)
}

/// Parsed input of the `log_every_n_ms!` proc macro:
/// `log_every_n_ms!(<log macro path>, <n millis>, <format args...>)`.
struct LogEveryNSecMacroInput {
    // Path of the logging macro to invoke (e.g. `tracing::info`).
    log_macro: syn::Path,
    // Minimum number of milliseconds between emitted logs.
    n: Expr,
    // Remaining arguments, forwarded verbatim to the logging macro.
    args: Punctuated<Expr, Token![,]>,
}

impl Parse for LogEveryNSecMacroInput {
    /// Parses `<log macro path>, <n millis>, <args...>` in order.
    fn parse(input: ParseStream<'_>) -> syn::Result<Self> {
        let log_macro = input.parse::<syn::Path>()?;
        input.parse::<Token![,]>()?;
        let n = input.parse::<Expr>()?;
        input.parse::<Token![,]>()?;
        // Everything after the second comma is the forwarded argument list.
        let args = Punctuated::parse_terminated(input)?;
        Ok(Self { log_macro, n, args })
    }
}

lazy_static! {
    // NOTE(review): this static appears unused within this file — `log_every_n_ms` generates
    // its own per-callsite `OnceLock<Instant>` start time instead. Confirm whether it can be
    // removed (and note std's `LazyLock` now covers the `lazy_static!` use case).
    static ref LOG_EVERY_N_MS_CLOCK_START: Instant = Instant::now();
}

/// An internal helper macro for logging a message at most once every `n` milliseconds.
/// Do not use this directly. Instead use the `info_every_n_ms!`, `debug_every_n_ms!`, etc.
/// macros.
#[proc_macro]
pub fn log_every_n_ms(input: TokenStream) -> TokenStream {
    let LogEveryNSecMacroInput { log_macro, n, args, .. } =
        parse_macro_input!(input as LogEveryNSecMacroInput);

    // Per-callsite identifiers for the two statics generated below.
    let ident_last_log_time = get_uniq_identifier_for_call_site("TRACING_LAST_LOG_TIME");
    let ident_start_time = get_uniq_identifier_for_call_site("TRACING_START_TIME");

    let args = args.into_iter().collect::<Vec<_>>();

    let expanded = quote! {
        {
            // We use this to measure the passage of time. We don't use the system time since
            // it can go backwards (e.g. when the system clock is updated).
            static #ident_start_time: ::std::sync::OnceLock<::std::time::Instant> = ::std::sync::OnceLock::new();

            static #ident_last_log_time: ::std::sync::OnceLock<::std::sync::atomic::AtomicU64> = ::std::sync::OnceLock::new();
            let last_log_u64 = #ident_last_log_time.get_or_init(|| ::std::sync::atomic::AtomicU64::new(0));

            match last_log_u64.fetch_update(
                ::std::sync::atomic::Ordering::Relaxed,
                ::std::sync::atomic::Ordering::Relaxed,
                |curr_val : u64| {
                    // We use millis and not secs to avoid having any rounding issues (e.g. 1.9
                    // seconds).
                    let now_with_zero : u64 = #ident_start_time.get_or_init(|| ::std::time::Instant::now())
                        .elapsed().as_millis().try_into()
                        .expect("Timestamp in millis is larger than u64::MAX");
                    // We add +1 to avoid having a value of 0 which can be confused with the first
                    // call.
                    let now : u64 = now_with_zero + 1;

                    if curr_val == 0 {
                        // First call, update the time to start counting from now.
                        return Some(now);
                    }
                    if curr_val + (#n) <= now {
                        // We should log. Next log should be after n milliseconds from now.
                        return Some(now);
                    }
                    None
                }
            ) {
                // `Ok(_)` rather than `Ok(old_now)`: the previous value was never read, and
                // binding it produced an `unused_variables` warning at every expansion site.
                Ok(_) => {
                    // We updated the last log time, meaning we should log.
                    #log_macro!(#(#args),*);
                }
                Err(_) => {
                    // We should not log.
                }
            };
        }
    };

    TokenStream::from(expanded)
}

// Next id to hand out (saturates at u16::MAX; see `alloc_for`).
static NEXT: AtomicU16 = AtomicU16::new(0);
// Memoizes the id already assigned to each callsite key, so repeated expansions of the
// same callsite get the same value.
static MAP: OnceLock<Mutex<HashMap<String, u16>>> = OnceLock::new();

/// Returns a stable `u16` id for `key`, allocating a fresh one the first time the key is seen.
///
/// # Panics
/// Panics once the id space is exhausted (more than `u16::MAX` unique keys).
fn alloc_for(key: String) -> u16 {
    let map = MAP.get_or_init(|| Mutex::new(HashMap::new()));
    let mut map = map.lock().unwrap();

    if let Some(&id) = map.get(&key) {
        return id;
    }

    // `fetch_update` + `checked_add` makes the counter saturate instead of wrapping.
    // The previous `fetch_add` wrapped NEXT back to 0 on overflow, so after one
    // exhaustion panic, later macro invocations would silently hand out duplicate ids.
    let id = NEXT
        .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |v| v.checked_add(1))
        .unwrap_or_else(|_| {
            panic!("unique_u16 exhausted: more than {} unique callsites in this crate", u16::MAX)
        });

    map.insert(key, id);
    id
}

/// Expands to a `u16`-suffixed integer literal that is unique per callsite within the
/// expanding crate (the same callsite always yields the same value).
#[proc_macro]
pub fn unique_u16(_input: TokenStream) -> TokenStream {
    // NOTE: Use proc_macro::Span (stable APIs on 1.92)
    // Prefer file() for a stable-ish key (local_file may be None under remapping / RA).
    // line()/column() are 1-indexed.
    let span = proc_macro::Span::call_site();
    let key = format!("{}:{}:{}", span.file(), span.line(), span.column());

    let id = alloc_for(key);

    // Emit a single u16-suffixed literal so it works in const context.
    TokenStream::from(proc_macro::TokenTree::Literal(proc_macro::Literal::u16_suffixed(id)))
}

#[proc_macro_attribute]
pub fn make_visibility(attrs: TokenStream, input: TokenStream) -> TokenStream {
    let visibility: Visibility = parse_macro_input!(attrs);
    let mut input: Item = parse_macro_input!(input);

    match input {
        Item::Const(ItemConst { ref mut vis, .. })
        | Item::Enum(ItemEnum { ref mut vis, .. })
        | Item::ExternCrate(ItemExternCrate { ref mut vis, .. })
        | Item::Fn(ItemFn { ref mut vis, .. })
        | Item::Mod(ItemMod { ref mut vis, .. })
        | Item::Static(ItemStatic { ref mut vis, .. })
        | Item::Struct(ItemStruct { ref mut vis, .. })
        | Item::Trait(ItemTrait { ref mut vis, .. })
        | Item::TraitAlias(ItemTraitAlias { ref mut vis, .. })
        | Item::Type(ItemType { ref mut vis, .. })
        | Item::Union(ItemUnion { ref mut vis, .. })
        | Item::Use(ItemUse { ref mut vis, .. }) => *vis = visibility,
        _ => {
            return Error::new_spanned(&input, "Cannot override the `#[visibility]` of this item")
                .to_compile_error()
                .into();
        }
    }

    input.into_token_stream().into()
}

/// Upgrades the visibility of all fields of a struct to at least the specified visibility.
/// Visibility can only be upgraded (private -> pub(crate) -> pub), never downgraded.
///
/// # Example
///
/// ```rust,ignore
/// use apollo_proc_macros::upgrade_fields_visibility;
///
/// #[upgrade_fields_visibility(pub(crate))]
/// pub struct MyStruct {
///     field1: i32,           // private -> pub(crate)
///     pub(crate) field2: String,  // pub(crate) -> pub(crate) (unchanged)
///     pub field3: bool,      // pub -> pub (unchanged, not downgraded)
/// }
///
/// // After macro expansion, the struct becomes:
/// // pub struct MyStruct {
/// //     pub(crate) field1: i32,
/// //     pub(crate) field2: String,
/// //     pub field3: bool,  // remains pub, not downgraded to pub(crate)
/// // }
/// ```
#[proc_macro_attribute]
pub fn upgrade_fields_visibility(attrs: TokenStream, input: TokenStream) -> TokenStream {
    let target_visibility: Visibility = parse_macro_input!(attrs);
    let mut input: Item = parse_macro_input!(input);

    let Item::Struct(ItemStruct { ref mut fields, .. }) = &mut input else {
        return Error::new_spanned(
            &input,
            "`upgrade_fields_visibility` can only be applied to structs",
        )
        .to_compile_error()
        .into();
    };

    // Reject unit structs up front: there are no fields to upgrade.
    if matches!(fields, Fields::Unit) {
        return Error::new_spanned(
            &input,
            "`upgrade_fields_visibility` can only be applied to structs with fields",
        )
        .to_compile_error()
        .into();
    }

    // `Fields::iter_mut` walks named and unnamed (tuple) fields alike.
    for field in fields.iter_mut() {
        upgrade_field_visibility(field, &target_visibility);
    }

    input.into_token_stream().into()
}

/// Upgrades the visibility of a field to the target visibility if the current visibility is lower.
fn upgrade_field_visibility(field: &mut Field, target_visibility: &Visibility) {
    let current = visibility_level(&field.vis);
    let target = visibility_level(target_visibility);
    // Only ever widen visibility; an already-wider field is left alone.
    if current < target {
        field.vis = target_visibility.clone();
    }
}

/// Maps a visibility to an ordinal for comparison; a larger number means more visible.
/// private (Inherited) = 0, restricted (`pub(crate)`, `pub(in ...)`) = 1, `pub` = 2.
fn visibility_level(vis: &Visibility) -> u8 {
    match vis {
        Visibility::Public(_) => 2,
        // Any restricted form is treated uniformly; pub(path) is similar to pub(crate) in scope.
        Visibility::Restricted(_) => 1,
        Visibility::Inherited => 0,
    }
}