Skip to main content

commonware_macros_impl/
lib.rs

1//! Proc-macro implementation for `commonware-macros`.
2//!
3//! This is an internal crate. Use [`commonware-macros`](https://docs.rs/commonware-macros)
4//! instead.
5
6use crate::nextest::configured_test_groups;
7use proc_macro::TokenStream;
8use proc_macro2::Span;
9use proc_macro_crate::{crate_name, FoundCrate};
10use quote::{format_ident, quote};
11use syn::{
12    braced,
13    parse::{Parse, ParseStream, Result},
14    parse_macro_input, Error, Expr, Ident, ItemFn, LitInt, LitStr, Pat, Token, Visibility,
15};
16
17mod nextest;
18
19/// Stability level input that accepts either a literal integer (0-4) or a named constant
20/// (ALPHA, BETA, GAMMA, DELTA, EPSILON).
21struct StabilityLevel {
22    value: u8,
23}
24
25impl Parse for StabilityLevel {
26    fn parse(input: ParseStream<'_>) -> Result<Self> {
27        let lookahead = input.lookahead1();
28        if lookahead.peek(LitInt) {
29            let lit: LitInt = input.parse()?;
30            let value: u8 = lit
31                .base10_parse()
32                .map_err(|_| Error::new(lit.span(), "stability level must be 0, 1, 2, 3, or 4"))?;
33            if value > 4 {
34                return Err(Error::new(
35                    lit.span(),
36                    "stability level must be 0, 1, 2, 3, or 4",
37                ));
38            }
39            Ok(Self { value })
40        } else if lookahead.peek(Ident) {
41            let ident: Ident = input.parse()?;
42            let value = match ident.to_string().as_str() {
43                "ALPHA" => 0,
44                "BETA" => 1,
45                "GAMMA" => 2,
46                "DELTA" => 3,
47                "EPSILON" => 4,
48                _ => {
49                    return Err(Error::new(
50                        ident.span(),
51                        "expected stability level: ALPHA, BETA, GAMMA, DELTA, EPSILON, or 0-4",
52                    ));
53                }
54            };
55            Ok(Self { value })
56        } else {
57            Err(lookahead.error())
58        }
59    }
60}
61
/// Maps a numeric stability level (`0..=4`) to its constant name.
///
/// # Panics
///
/// Panics (via `unreachable!`) for any level above 4. The parser never produces such a level.
fn level_name(level: u8) -> &'static str {
    const NAMES: [&str; 5] = ["ALPHA", "BETA", "GAMMA", "DELTA", "EPSILON"];
    match NAMES.get(usize::from(level)) {
        Some(name) => name,
        None => unreachable!(),
    }
}
72
73/// Generates cfg identifiers that should exclude an item at the given stability level.
74///
75/// The stability system works by excluding items when building at higher stability levels.
76/// For example, an item marked `#[stability(BETA)]` (level 1) should be excluded when
77/// building with `--cfg commonware_stability_GAMMA` (level 2) or higher.
78///
79/// This function returns identifiers for all levels above the given level, plus `RESERVED`.
80/// The generated `#[cfg(not(any(...)))]` attribute ensures the item is included only when
81/// none of the higher-level cfgs are set.
82///
83/// ```text
84/// Level 0 (ALPHA)   -> excludes at: BETA, GAMMA, DELTA, EPSILON, RESERVED
85/// Level 1 (BETA)    -> excludes at: GAMMA, DELTA, EPSILON, RESERVED
86/// Level 2 (GAMMA)   -> excludes at: DELTA, EPSILON, RESERVED
87/// Level 3 (DELTA)   -> excludes at: EPSILON, RESERVED
88/// Level 4 (EPSILON) -> excludes at: RESERVED
89/// ```
90///
91/// `RESERVED` is a special level used by `scripts/find_unstable_public.sh` to exclude ALL
92/// stability-marked items, leaving only unmarked public items visible in rustdoc output.
93fn exclusion_cfg_names(level: u8) -> Vec<proc_macro2::Ident> {
94    let mut names: Vec<_> = ((level + 1)..=4)
95        .map(|l| format_ident!("commonware_stability_{}", level_name(l)))
96        .collect();
97
98    names.push(format_ident!("commonware_stability_RESERVED"));
99    names
100}
101
102#[proc_macro_attribute]
103pub fn stability(attr: TokenStream, item: TokenStream) -> TokenStream {
104    let level = parse_macro_input!(attr as StabilityLevel);
105    let exclude_names = exclusion_cfg_names(level.value);
106
107    let item2: proc_macro2::TokenStream = item.into();
108    let expanded = quote! {
109        #[cfg(not(any(#(#exclude_names),*)))]
110        #item2
111    };
112
113    TokenStream::from(expanded)
114}
115
116/// Input for the `stability_mod!` macro: `level, visibility mod name`
117struct StabilityModInput {
118    level: StabilityLevel,
119    visibility: Visibility,
120    name: Ident,
121}
122
123impl Parse for StabilityModInput {
124    fn parse(input: ParseStream<'_>) -> Result<Self> {
125        let level: StabilityLevel = input.parse()?;
126        input.parse::<Token![,]>()?;
127        let visibility: Visibility = input.parse()?;
128        input.parse::<Token![mod]>()?;
129        let name: Ident = input.parse()?;
130        Ok(Self {
131            level,
132            visibility,
133            name,
134        })
135    }
136}
137
138#[proc_macro]
139pub fn stability_mod(input: TokenStream) -> TokenStream {
140    let StabilityModInput {
141        level,
142        visibility,
143        name,
144    } = parse_macro_input!(input as StabilityModInput);
145
146    let exclude_names = exclusion_cfg_names(level.value);
147
148    let expanded = quote! {
149        #[cfg(not(any(#(#exclude_names),*)))]
150        #visibility mod #name;
151    };
152
153    TokenStream::from(expanded)
154}
155
156/// Input for the `stability_scope!` macro: `level [, cfg(predicate)] { items... }`
157struct StabilityScopeInput {
158    level: StabilityLevel,
159    predicate: Option<syn::Meta>,
160    items: Vec<syn::Item>,
161}
162
163impl Parse for StabilityScopeInput {
164    fn parse(input: ParseStream<'_>) -> Result<Self> {
165        let level: StabilityLevel = input.parse()?;
166
167        // Check for optional cfg predicate
168        let predicate = if input.peek(Token![,]) {
169            input.parse::<Token![,]>()?;
170
171            // Parse `cfg(...)` - expect the literal identifier "cfg" followed by parenthesized content
172            let cfg_ident: Ident = input.parse()?;
173            if cfg_ident != "cfg" {
174                return Err(Error::new(cfg_ident.span(), "expected `cfg`"));
175            }
176            let cfg_content;
177            syn::parenthesized!(cfg_content in input);
178            Some(cfg_content.parse()?)
179        } else {
180            None
181        };
182
183        let content;
184        braced!(content in input);
185
186        let mut items = Vec::new();
187        while !content.is_empty() {
188            items.push(content.parse()?);
189        }
190
191        Ok(Self {
192            level,
193            predicate,
194            items,
195        })
196    }
197}
198
199#[proc_macro]
200pub fn stability_scope(input: TokenStream) -> TokenStream {
201    let StabilityScopeInput {
202        level,
203        predicate,
204        items,
205    } = parse_macro_input!(input as StabilityScopeInput);
206
207    let exclude_names = exclusion_cfg_names(level.value);
208
209    let cfg_attr = predicate.map_or_else(
210        || quote! { #[cfg(not(any(#(#exclude_names),*)))] },
211        |pred| quote! { #[cfg(all(#pred, not(any(#(#exclude_names),*))))] },
212    );
213
214    let expanded_items: Vec<_> = items
215        .into_iter()
216        .map(|item| {
217            quote! {
218                #cfg_attr
219                #item
220            }
221        })
222        .collect();
223
224    let expanded = quote! {
225        #(#expanded_items)*
226    };
227
228    TokenStream::from(expanded)
229}
230
231#[proc_macro_attribute]
232pub fn test_async(_: TokenStream, item: TokenStream) -> TokenStream {
233    // Parse the input tokens into a syntax tree
234    let input = parse_macro_input!(item as ItemFn);
235
236    // Extract function components
237    let attrs = input.attrs;
238    let vis = input.vis;
239    let mut sig = input.sig;
240    let block = input.block;
241
242    // Remove 'async' from the function signature (#[test] only
243    // accepts sync functions)
244    sig.asyncness
245        .take()
246        .expect("test_async macro can only be used with async functions");
247
248    // Generate output tokens
249    let expanded = quote! {
250        #[test]
251        #(#attrs)*
252        #vis #sig {
253            futures::executor::block_on(async #block);
254        }
255    };
256    TokenStream::from(expanded)
257}
258
259#[proc_macro_attribute]
260pub fn test_traced(attr: TokenStream, item: TokenStream) -> TokenStream {
261    // Parse the input tokens into a syntax tree
262    let input = parse_macro_input!(item as ItemFn);
263
264    // Parse the attribute argument for log level
265    let log_level = if attr.is_empty() {
266        // Default log level is DEBUG
267        quote! { tracing::Level::DEBUG }
268    } else {
269        // Parse the attribute as a string literal
270        let level_str = parse_macro_input!(attr as LitStr);
271        let level_ident = level_str.value().to_uppercase();
272        match level_ident.as_str() {
273            "TRACE" => quote! { tracing::Level::TRACE },
274            "DEBUG" => quote! { tracing::Level::DEBUG },
275            "INFO" => quote! { tracing::Level::INFO },
276            "WARN" => quote! { tracing::Level::WARN },
277            "ERROR" => quote! { tracing::Level::ERROR },
278            _ => {
279                // Return a compile error for invalid log levels
280                return Error::new_spanned(
281                    level_str,
282                    "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
283                )
284                .to_compile_error()
285                .into();
286            }
287        }
288    };
289
290    // Extract function components
291    let attrs = input.attrs;
292    let vis = input.vis;
293    let sig = input.sig;
294    let block = input.block;
295
296    // Generate output tokens
297    let expanded = quote! {
298        #[test]
299        #(#attrs)*
300        #vis #sig {
301            // Create a subscriber and dispatcher with the specified log level
302            let subscriber = tracing_subscriber::fmt()
303                .with_test_writer()
304                .with_max_level(#log_level)
305                .with_line_number(true)
306                .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
307                .finish();
308            let dispatcher = tracing::Dispatch::new(subscriber);
309
310            // Set the subscriber for the scope of the test
311            tracing::dispatcher::with_default(&dispatcher, || {
312                #block
313            });
314        }
315    };
316    TokenStream::from(expanded)
317}
318
319#[proc_macro_attribute]
320pub fn test_group(attr: TokenStream, item: TokenStream) -> TokenStream {
321    if attr.is_empty() {
322        return Error::new(
323            Span::call_site(),
324            "test_group requires a string literal filter group name",
325        )
326        .to_compile_error()
327        .into();
328    }
329
330    let mut input = parse_macro_input!(item as ItemFn);
331    let group_literal = parse_macro_input!(attr as LitStr);
332
333    let group = match nextest::sanitize_group_literal(&group_literal) {
334        Ok(group) => group,
335        Err(err) => return err.to_compile_error().into(),
336    };
337    let groups = match configured_test_groups() {
338        Ok(groups) => groups,
339        Err(_) => {
340            // Don't fail the compilation if the file isn't found; just return the original input.
341            return TokenStream::from(quote!(#input));
342        }
343    };
344
345    if let Err(err) = nextest::ensure_group_known(groups, &group, group_literal.span()) {
346        return err.to_compile_error().into();
347    }
348
349    let original_name = input.sig.ident.to_string();
350    let new_ident = Ident::new(&format!("{original_name}_{group}_"), input.sig.ident.span());
351
352    input.sig.ident = new_ident;
353
354    TokenStream::from(quote!(#input))
355}
356
357#[proc_macro_attribute]
358pub fn test_collect_traces(attr: TokenStream, item: TokenStream) -> TokenStream {
359    let input = parse_macro_input!(item as ItemFn);
360
361    // Parse the attribute argument for log level
362    let log_level = if attr.is_empty() {
363        // Default log level is DEBUG
364        quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG }
365    } else {
366        // Parse the attribute as a string literal
367        let level_str = parse_macro_input!(attr as LitStr);
368        let level_ident = level_str.value().to_uppercase();
369        match level_ident.as_str() {
370            "TRACE" => quote! { ::tracing_subscriber::filter::LevelFilter::TRACE },
371            "DEBUG" => quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG },
372            "INFO" => quote! { ::tracing_subscriber::filter::LevelFilter::INFO },
373            "WARN" => quote! { ::tracing_subscriber::filter::LevelFilter::WARN },
374            "ERROR" => quote! { ::tracing_subscriber::filter::LevelFilter::ERROR },
375            _ => {
376                // Return a compile error for invalid log levels
377                return Error::new_spanned(
378                    level_str,
379                    "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
380                )
381                .to_compile_error()
382                .into();
383            }
384        }
385    };
386
387    let attrs = input.attrs;
388    let vis = input.vis;
389    let sig = input.sig;
390    let block = input.block;
391
392    // Create the signature of the inner function that takes the TraceStorage.
393    let inner_ident = format_ident!("__{}_inner_traced", sig.ident);
394    let mut inner_sig = sig.clone();
395    inner_sig.ident = inner_ident.clone();
396
397    // Create the signature of the outer test function.
398    let mut outer_sig = sig;
399    outer_sig.inputs.clear();
400
401    // Detect the path of the `commonware-runtime` crate. If it has been renamed or
402    // this macro is being used within the `commonware-runtime` crate itself, adjust
403    // the path accordingly.
404    let rt_path = match crate_name("commonware-runtime") {
405        Ok(FoundCrate::Itself) => quote!(crate),
406        Ok(FoundCrate::Name(name)) => {
407            let ident = syn::Ident::new(&name, Span::call_site());
408            quote!(#ident)
409        }
410        Err(_) => quote!(::commonware_runtime), // fallback
411    };
412
413    let expanded = quote! {
414        // Inner test function runs the actual test logic, accepting the TraceStorage
415        // created by the harness.
416        #(#attrs)*
417        #vis #inner_sig #block
418
419        #[test]
420        #vis #outer_sig {
421            use ::tracing_subscriber::{Layer, fmt, Registry, layer::SubscriberExt, util::SubscriberInitExt};
422            use ::tracing::{Dispatch, dispatcher};
423            use #rt_path::telemetry::traces::collector::{CollectingLayer, TraceStorage};
424
425            let trace_store = TraceStorage::default();
426            let collecting_layer = CollectingLayer::new(trace_store.clone());
427
428            let fmt_layer = fmt::layer()
429                .with_test_writer()
430                .with_line_number(true)
431                .with_span_events(fmt::format::FmtSpan::CLOSE)
432                .with_filter(#log_level);
433
434            let subscriber = Registry::default().with(collecting_layer).with(fmt_layer);
435            let dispatcher = Dispatch::new(subscriber);
436            dispatcher::with_default(&dispatcher, || {
437                #inner_ident(trace_store);
438            });
439        }
440    };
441
442    TokenStream::from(expanded)
443}
444
445struct SelectInput {
446    branches: Vec<Branch>,
447}
448
449struct Branch {
450    pattern: Pat,
451    future: Expr,
452    body: Expr,
453}
454
455impl Parse for SelectInput {
456    fn parse(input: ParseStream<'_>) -> Result<Self> {
457        let mut branches = Vec::new();
458
459        while !input.is_empty() {
460            let pattern = Pat::parse_single(input)?;
461            input.parse::<Token![=]>()?;
462            let future: Expr = input.parse()?;
463            input.parse::<Token![=>]>()?;
464            let body: Expr = input.parse()?;
465
466            branches.push(Branch {
467                pattern,
468                future,
469                body,
470            });
471
472            if input.peek(Token![,]) {
473                input.parse::<Token![,]>()?;
474            } else {
475                break;
476            }
477        }
478
479        Ok(Self { branches })
480    }
481}
482
483#[proc_macro]
484pub fn select(input: TokenStream) -> TokenStream {
485    // Parse the input tokens
486    let SelectInput { branches } = parse_macro_input!(input as SelectInput);
487
488    // Generate code from provided statements
489    let mut select_branches = Vec::new();
490    for Branch {
491        pattern,
492        future,
493        body,
494    } in branches.into_iter()
495    {
496        // Generate branch for `select!` macro
497        let branch_code = quote! {
498            #pattern = #future => #body,
499        };
500        select_branches.push(branch_code);
501    }
502
503    // Generate the final output code
504    quote! {
505        {
506            ::commonware_macros::__reexport::tokio::select! {
507                biased;
508                #(#select_branches)*
509            }
510        }
511    }
512    .into()
513}
514
515/// Input for [select_loop!].
516///
517/// Parses: `context, [on_start => expr,] on_stopped => expr, branches... [, on_end => expr]`
518struct SelectLoopInput {
519    context: Expr,
520    start_expr: Option<Expr>,
521    shutdown_expr: Expr,
522    branches: Vec<Branch>,
523    end_expr: Option<Expr>,
524}
525
526impl Parse for SelectLoopInput {
527    fn parse(input: ParseStream<'_>) -> Result<Self> {
528        // Parse context expression
529        let context: Expr = input.parse()?;
530        input.parse::<Token![,]>()?;
531
532        // Check for optional `on_start =>`
533        let start_expr = if input.peek(Ident) {
534            let ident: Ident = input.fork().parse()?;
535            if ident == "on_start" {
536                input.parse::<Ident>()?; // consume the ident
537                input.parse::<Token![=>]>()?;
538                let expr: Expr = input.parse()?;
539                input.parse::<Token![,]>()?;
540                Some(expr)
541            } else {
542                None
543            }
544        } else {
545            None
546        };
547
548        // Parse `on_stopped =>`
549        let on_stopped_ident: Ident = input.parse()?;
550        if on_stopped_ident != "on_stopped" {
551            return Err(Error::new(
552                on_stopped_ident.span(),
553                "expected `on_stopped` keyword",
554            ));
555        }
556        input.parse::<Token![=>]>()?;
557
558        // Parse shutdown expression
559        let shutdown_expr: Expr = input.parse()?;
560
561        // Parse comma after shutdown expression
562        input.parse::<Token![,]>()?;
563
564        // Parse branches directly (no surrounding braces)
565        // Stop when we see `on_end` or reach end of input
566        let mut branches = Vec::new();
567        while !input.is_empty() {
568            // Check if next token is `on_end`
569            if input.peek(Ident) {
570                let ident: Ident = input.fork().parse()?;
571                if ident == "on_end" {
572                    break;
573                }
574            }
575
576            let pattern = Pat::parse_single(input)?;
577            input.parse::<Token![=]>()?;
578            let future: Expr = input.parse()?;
579            input.parse::<Token![=>]>()?;
580            let body: Expr = input.parse()?;
581
582            branches.push(Branch {
583                pattern,
584                future,
585                body,
586            });
587
588            if input.peek(Token![,]) {
589                input.parse::<Token![,]>()?;
590            } else {
591                break;
592            }
593        }
594
595        // Check for optional `on_end =>`
596        let end_expr = if !input.is_empty() && input.peek(Ident) {
597            let ident: Ident = input.parse()?;
598            if ident == "on_end" {
599                input.parse::<Token![=>]>()?;
600                let expr: Expr = input.parse()?;
601                if input.peek(Token![,]) {
602                    input.parse::<Token![,]>()?;
603                }
604                Some(expr)
605            } else {
606                return Err(Error::new(ident.span(), "expected `on_end` keyword"));
607            }
608        } else {
609            None
610        };
611
612        Ok(Self {
613            context,
614            start_expr,
615            shutdown_expr,
616            branches,
617            end_expr,
618        })
619    }
620}
621
622#[proc_macro]
623pub fn select_loop(input: TokenStream) -> TokenStream {
624    let SelectLoopInput {
625        context,
626        start_expr,
627        shutdown_expr,
628        branches,
629        end_expr,
630    } = parse_macro_input!(input as SelectLoopInput);
631
632    // Convert branches to tokens for the inner select!
633    let branch_tokens: Vec<_> = branches
634        .iter()
635        .map(|b| {
636            let pattern = &b.pattern;
637            let future = &b.future;
638            let body = &b.body;
639            quote! { #pattern = #future => #body, }
640        })
641        .collect();
642
643    // Helper to convert an expression to tokens, inlining block contents
644    // to preserve variable scope
645    fn expr_to_tokens(expr: &Expr) -> proc_macro2::TokenStream {
646        match expr {
647            Expr::Block(block) => {
648                let stmts = &block.block.stmts;
649                quote! { #(#stmts)* }
650            }
651            other => quote! { #other; },
652        }
653    }
654
655    // Generate on_start and on_end tokens if present
656    let on_start_tokens = start_expr.as_ref().map(expr_to_tokens);
657    let on_end_tokens = end_expr.as_ref().map(expr_to_tokens);
658    let shutdown_tokens = expr_to_tokens(&shutdown_expr);
659
660    quote! {
661        {
662            let mut shutdown = #context.stopped();
663            loop {
664                #on_start_tokens
665
666                commonware_macros::select! {
667                    _ = &mut shutdown => {
668                        #shutdown_tokens
669
670                        // Break the loop after handling shutdown. Some implementations
671                        // may divert control flow themselves, so this may be unused.
672                        #[allow(unreachable_code)]
673                        break;
674                    },
675                    #(#branch_tokens)*
676                }
677
678                #on_end_tokens
679            }
680        }
681    }
682    .into()
683}