Skip to main content

commonware_macros_impl/
lib.rs

1//! Proc-macro implementation for `commonware-macros`.
2//!
3//! This is an internal crate. Use [`commonware-macros`](https://docs.rs/commonware-macros)
4//! instead.
5
6#![doc(
7    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
8    html_favicon_url = "https://commonware.xyz/favicon.ico"
9)]
10
11use crate::nextest::configured_test_groups;
12use proc_macro::TokenStream;
13use proc_macro2::Span;
14use proc_macro_crate::{crate_name, FoundCrate};
15use quote::{format_ident, quote};
16use syn::{
17    braced,
18    parse::{Parse, ParseStream, Result},
19    parse_macro_input, Error, Expr, Ident, ItemFn, LitInt, LitStr, Pat, Token, Visibility,
20};
21
22mod nextest;
23
24/// Stability level input that accepts either a literal integer (0-4) or a named constant
25/// (ALPHA, BETA, GAMMA, DELTA, EPSILON).
26struct StabilityLevel {
27    value: u8,
28}
29
30impl Parse for StabilityLevel {
31    fn parse(input: ParseStream<'_>) -> Result<Self> {
32        let lookahead = input.lookahead1();
33        if lookahead.peek(LitInt) {
34            let lit: LitInt = input.parse()?;
35            let value: u8 = lit
36                .base10_parse()
37                .map_err(|_| Error::new(lit.span(), "stability level must be 0, 1, 2, 3, or 4"))?;
38            if value > 4 {
39                return Err(Error::new(
40                    lit.span(),
41                    "stability level must be 0, 1, 2, 3, or 4",
42                ));
43            }
44            Ok(Self { value })
45        } else if lookahead.peek(Ident) {
46            let ident: Ident = input.parse()?;
47            let value = match ident.to_string().as_str() {
48                "ALPHA" => 0,
49                "BETA" => 1,
50                "GAMMA" => 2,
51                "DELTA" => 3,
52                "EPSILON" => 4,
53                _ => {
54                    return Err(Error::new(
55                        ident.span(),
56                        "expected stability level: ALPHA, BETA, GAMMA, DELTA, EPSILON, or 0-4",
57                    ));
58                }
59            };
60            Ok(Self { value })
61        } else {
62            Err(lookahead.error())
63        }
64    }
65}
66
/// Returns the named constant for a numeric stability level.
///
/// # Panics
///
/// Panics (via `unreachable!`) when `level` is greater than 4; every caller in this file
/// passes a value in `0..=4`.
fn level_name(level: u8) -> &'static str {
    const NAMES: [&str; 5] = ["ALPHA", "BETA", "GAMMA", "DELTA", "EPSILON"];
    NAMES
        .get(usize::from(level))
        .copied()
        .unwrap_or_else(|| unreachable!())
}
77
78/// Generates cfg identifiers that should exclude an item at the given stability level.
79///
80/// The stability system works by excluding items when building at higher stability levels.
81/// For example, an item marked `#[stability(BETA)]` (level 1) should be excluded when
82/// building with `--cfg commonware_stability_GAMMA` (level 2) or higher.
83///
84/// This function returns identifiers for all levels above the given level, plus `RESERVED`.
85/// The generated `#[cfg(not(any(...)))]` attribute ensures the item is included only when
86/// none of the higher-level cfgs are set.
87///
88/// ```text
89/// Level 0 (ALPHA)   -> excludes at: BETA, GAMMA, DELTA, EPSILON, RESERVED
90/// Level 1 (BETA)    -> excludes at: GAMMA, DELTA, EPSILON, RESERVED
91/// Level 2 (GAMMA)   -> excludes at: DELTA, EPSILON, RESERVED
92/// Level 3 (DELTA)   -> excludes at: EPSILON, RESERVED
93/// Level 4 (EPSILON) -> excludes at: RESERVED
94/// ```
95///
96/// `RESERVED` is a special level used by `scripts/find_unstable_public.sh` to exclude ALL
97/// stability-marked items, leaving only unmarked public items visible in rustdoc output.
98fn exclusion_cfg_names(level: u8) -> Vec<proc_macro2::Ident> {
99    let mut names: Vec<_> = ((level + 1)..=4)
100        .map(|l| format_ident!("commonware_stability_{}", level_name(l)))
101        .collect();
102
103    names.push(format_ident!("commonware_stability_RESERVED"));
104    names
105}
106
107#[proc_macro_attribute]
108pub fn stability(attr: TokenStream, item: TokenStream) -> TokenStream {
109    let level = parse_macro_input!(attr as StabilityLevel);
110    let exclude_names = exclusion_cfg_names(level.value);
111
112    let item2: proc_macro2::TokenStream = item.into();
113    let expanded = quote! {
114        #[cfg(not(any(#(#exclude_names),*)))]
115        #item2
116    };
117
118    TokenStream::from(expanded)
119}
120
121/// Input for the `stability_mod!` macro: `level, visibility mod name`
122struct StabilityModInput {
123    level: StabilityLevel,
124    visibility: Visibility,
125    name: Ident,
126}
127
128impl Parse for StabilityModInput {
129    fn parse(input: ParseStream<'_>) -> Result<Self> {
130        let level: StabilityLevel = input.parse()?;
131        input.parse::<Token![,]>()?;
132        let visibility: Visibility = input.parse()?;
133        input.parse::<Token![mod]>()?;
134        let name: Ident = input.parse()?;
135        Ok(Self {
136            level,
137            visibility,
138            name,
139        })
140    }
141}
142
143#[proc_macro]
144pub fn stability_mod(input: TokenStream) -> TokenStream {
145    let StabilityModInput {
146        level,
147        visibility,
148        name,
149    } = parse_macro_input!(input as StabilityModInput);
150
151    let exclude_names = exclusion_cfg_names(level.value);
152
153    let expanded = quote! {
154        #[cfg(not(any(#(#exclude_names),*)))]
155        #visibility mod #name;
156    };
157
158    TokenStream::from(expanded)
159}
160
161/// Input for the `stability_scope!` macro: `level [, cfg(predicate)] { items... }`
162struct StabilityScopeInput {
163    level: StabilityLevel,
164    predicate: Option<syn::Meta>,
165    items: Vec<syn::Item>,
166}
167
168impl Parse for StabilityScopeInput {
169    fn parse(input: ParseStream<'_>) -> Result<Self> {
170        let level: StabilityLevel = input.parse()?;
171
172        // Check for optional cfg predicate
173        let predicate = if input.peek(Token![,]) {
174            input.parse::<Token![,]>()?;
175
176            // Parse `cfg(...)` - expect the literal identifier "cfg" followed by parenthesized content
177            let cfg_ident: Ident = input.parse()?;
178            if cfg_ident != "cfg" {
179                return Err(Error::new(cfg_ident.span(), "expected `cfg`"));
180            }
181            let cfg_content;
182            syn::parenthesized!(cfg_content in input);
183            Some(cfg_content.parse()?)
184        } else {
185            None
186        };
187
188        let content;
189        braced!(content in input);
190
191        let mut items = Vec::new();
192        while !content.is_empty() {
193            items.push(content.parse()?);
194        }
195
196        Ok(Self {
197            level,
198            predicate,
199            items,
200        })
201    }
202}
203
204#[proc_macro]
205pub fn stability_scope(input: TokenStream) -> TokenStream {
206    let StabilityScopeInput {
207        level,
208        predicate,
209        items,
210    } = parse_macro_input!(input as StabilityScopeInput);
211
212    let exclude_names = exclusion_cfg_names(level.value);
213
214    let cfg_attr = predicate.map_or_else(
215        || quote! { #[cfg(not(any(#(#exclude_names),*)))] },
216        |pred| quote! { #[cfg(all(#pred, not(any(#(#exclude_names),*))))] },
217    );
218
219    let expanded_items: Vec<_> = items
220        .into_iter()
221        .map(|item| {
222            quote! {
223                #cfg_attr
224                #item
225            }
226        })
227        .collect();
228
229    let expanded = quote! {
230        #(#expanded_items)*
231    };
232
233    TokenStream::from(expanded)
234}
235
236#[proc_macro_attribute]
237pub fn test_async(_: TokenStream, item: TokenStream) -> TokenStream {
238    // Parse the input tokens into a syntax tree
239    let input = parse_macro_input!(item as ItemFn);
240
241    // Extract function components
242    let attrs = input.attrs;
243    let vis = input.vis;
244    let mut sig = input.sig;
245    let block = input.block;
246
247    // Remove 'async' from the function signature (#[test] only
248    // accepts sync functions)
249    sig.asyncness
250        .take()
251        .expect("test_async macro can only be used with async functions");
252
253    // Generate output tokens
254    let expanded = quote! {
255        #[test]
256        #(#attrs)*
257        #vis #sig {
258            futures::executor::block_on(async #block);
259        }
260    };
261    TokenStream::from(expanded)
262}
263
264#[proc_macro_attribute]
265pub fn test_traced(attr: TokenStream, item: TokenStream) -> TokenStream {
266    // Parse the input tokens into a syntax tree
267    let input = parse_macro_input!(item as ItemFn);
268
269    // Parse the attribute argument for log level
270    let log_level = if attr.is_empty() {
271        // Default log level is DEBUG
272        quote! { tracing::Level::DEBUG }
273    } else {
274        // Parse the attribute as a string literal
275        let level_str = parse_macro_input!(attr as LitStr);
276        let level_ident = level_str.value().to_uppercase();
277        match level_ident.as_str() {
278            "TRACE" => quote! { tracing::Level::TRACE },
279            "DEBUG" => quote! { tracing::Level::DEBUG },
280            "INFO" => quote! { tracing::Level::INFO },
281            "WARN" => quote! { tracing::Level::WARN },
282            "ERROR" => quote! { tracing::Level::ERROR },
283            _ => {
284                // Return a compile error for invalid log levels
285                return Error::new_spanned(
286                    level_str,
287                    "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
288                )
289                .to_compile_error()
290                .into();
291            }
292        }
293    };
294
295    // Extract function components
296    let attrs = input.attrs;
297    let vis = input.vis;
298    let sig = input.sig;
299    let block = input.block;
300
301    // Generate output tokens
302    let expanded = quote! {
303        #[test]
304        #(#attrs)*
305        #vis #sig {
306            // Create a subscriber and dispatcher with the specified log level
307            let subscriber = tracing_subscriber::fmt()
308                .with_test_writer()
309                .with_max_level(#log_level)
310                .with_line_number(true)
311                .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
312                .finish();
313            let dispatcher = tracing::Dispatch::new(subscriber);
314
315            // Set the subscriber for the scope of the test
316            tracing::dispatcher::with_default(&dispatcher, || {
317                #block
318            });
319        }
320    };
321    TokenStream::from(expanded)
322}
323
324#[proc_macro_attribute]
325pub fn test_group(attr: TokenStream, item: TokenStream) -> TokenStream {
326    if attr.is_empty() {
327        return Error::new(
328            Span::call_site(),
329            "test_group requires a string literal filter group name",
330        )
331        .to_compile_error()
332        .into();
333    }
334
335    let mut input = parse_macro_input!(item as ItemFn);
336    let group_literal = parse_macro_input!(attr as LitStr);
337
338    let group = match nextest::sanitize_group_literal(&group_literal) {
339        Ok(group) => group,
340        Err(err) => return err.to_compile_error().into(),
341    };
342    let groups = match configured_test_groups() {
343        Ok(groups) => groups,
344        Err(_) => {
345            // Don't fail the compilation if the file isn't found; just return the original input.
346            return TokenStream::from(quote!(#input));
347        }
348    };
349
350    if let Err(err) = nextest::ensure_group_known(groups, &group, group_literal.span()) {
351        return err.to_compile_error().into();
352    }
353
354    let original_name = input.sig.ident.to_string();
355    let new_ident = Ident::new(&format!("{original_name}_{group}_"), input.sig.ident.span());
356
357    input.sig.ident = new_ident;
358
359    TokenStream::from(quote!(#input))
360}
361
362#[proc_macro_attribute]
363pub fn test_collect_traces(attr: TokenStream, item: TokenStream) -> TokenStream {
364    let input = parse_macro_input!(item as ItemFn);
365
366    // Parse the attribute argument for log level
367    let log_level = if attr.is_empty() {
368        // Default log level is DEBUG
369        quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG }
370    } else {
371        // Parse the attribute as a string literal
372        let level_str = parse_macro_input!(attr as LitStr);
373        let level_ident = level_str.value().to_uppercase();
374        match level_ident.as_str() {
375            "TRACE" => quote! { ::tracing_subscriber::filter::LevelFilter::TRACE },
376            "DEBUG" => quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG },
377            "INFO" => quote! { ::tracing_subscriber::filter::LevelFilter::INFO },
378            "WARN" => quote! { ::tracing_subscriber::filter::LevelFilter::WARN },
379            "ERROR" => quote! { ::tracing_subscriber::filter::LevelFilter::ERROR },
380            _ => {
381                // Return a compile error for invalid log levels
382                return Error::new_spanned(
383                    level_str,
384                    "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
385                )
386                .to_compile_error()
387                .into();
388            }
389        }
390    };
391
392    let attrs = input.attrs;
393    let vis = input.vis;
394    let sig = input.sig;
395    let block = input.block;
396
397    // Create the signature of the inner function that takes the TraceStorage.
398    let inner_ident = format_ident!("__{}_inner_traced", sig.ident);
399    let mut inner_sig = sig.clone();
400    inner_sig.ident = inner_ident.clone();
401
402    // Create the signature of the outer test function.
403    let mut outer_sig = sig;
404    outer_sig.inputs.clear();
405
406    // Detect the path of the `commonware-runtime` crate. If it has been renamed or
407    // this macro is being used within the `commonware-runtime` crate itself, adjust
408    // the path accordingly.
409    let rt_path = match crate_name("commonware-runtime") {
410        Ok(FoundCrate::Itself) => quote!(crate),
411        Ok(FoundCrate::Name(name)) => {
412            let ident = syn::Ident::new(&name, Span::call_site());
413            quote!(#ident)
414        }
415        Err(_) => quote!(::commonware_runtime), // fallback
416    };
417
418    let expanded = quote! {
419        // Inner test function runs the actual test logic, accepting the TraceStorage
420        // created by the harness.
421        #(#attrs)*
422        #vis #inner_sig #block
423
424        #[test]
425        #vis #outer_sig {
426            use ::tracing_subscriber::{Layer, fmt, Registry, layer::SubscriberExt, util::SubscriberInitExt};
427            use ::tracing::{Dispatch, dispatcher};
428            use #rt_path::telemetry::traces::collector::{CollectingLayer, TraceStorage};
429
430            let trace_store = TraceStorage::default();
431            let collecting_layer = CollectingLayer::new(trace_store.clone());
432
433            let fmt_layer = fmt::layer()
434                .with_test_writer()
435                .with_line_number(true)
436                .with_span_events(fmt::format::FmtSpan::CLOSE)
437                .with_filter(#log_level);
438
439            let subscriber = Registry::default().with(collecting_layer).with(fmt_layer);
440            let dispatcher = Dispatch::new(subscriber);
441            dispatcher::with_default(&dispatcher, || {
442                #inner_ident(trace_store);
443            });
444        }
445    };
446
447    TokenStream::from(expanded)
448}
449
450struct SelectInput {
451    branches: Vec<Branch>,
452}
453
454struct Branch {
455    pattern: Pat,
456    future: Expr,
457    body: Expr,
458}
459
460/// Branch for [select_loop!] with optional `else` clause for `Some` patterns.
461struct SelectLoopBranch {
462    pattern: Pat,
463    future: Expr,
464    else_body: Option<Expr>,
465    body: Expr,
466}
467
468impl Parse for SelectInput {
469    fn parse(input: ParseStream<'_>) -> Result<Self> {
470        let mut branches = Vec::new();
471
472        while !input.is_empty() {
473            let pattern = Pat::parse_single(input)?;
474            input.parse::<Token![=]>()?;
475            let future: Expr = input.parse()?;
476            input.parse::<Token![=>]>()?;
477            let body: Expr = input.parse()?;
478
479            branches.push(Branch {
480                pattern,
481                future,
482                body,
483            });
484
485            if input.peek(Token![,]) {
486                input.parse::<Token![,]>()?;
487            } else {
488                break;
489            }
490        }
491
492        Ok(Self { branches })
493    }
494}
495
496#[proc_macro]
497pub fn select(input: TokenStream) -> TokenStream {
498    // Parse the input tokens
499    let SelectInput { branches } = parse_macro_input!(input as SelectInput);
500
501    // Generate code from provided statements
502    let mut select_branches = Vec::new();
503    for Branch {
504        pattern,
505        future,
506        body,
507    } in branches.into_iter()
508    {
509        // Generate branch for `select!` macro
510        let branch_code = quote! {
511            #pattern = #future => #body,
512        };
513        select_branches.push(branch_code);
514    }
515
516    // Generate the final output code
517    quote! {
518        {
519            ::commonware_macros::__reexport::tokio::select! {
520                biased;
521                #(#select_branches)*
522            }
523        }
524    }
525    .into()
526}
527
528/// Input for [select_loop!].
529///
530/// Parses: `context, [on_start => expr,] on_stopped => expr, branches... [, on_end => expr]`
531struct SelectLoopInput {
532    context: Expr,
533    start_expr: Option<Expr>,
534    shutdown_expr: Expr,
535    branches: Vec<SelectLoopBranch>,
536    end_expr: Option<Expr>,
537}
538
539impl Parse for SelectLoopInput {
540    fn parse(input: ParseStream<'_>) -> Result<Self> {
541        // Parse context expression
542        let context: Expr = input.parse()?;
543        input.parse::<Token![,]>()?;
544
545        // Check for optional `on_start =>`
546        let start_expr = if input.peek(Ident) {
547            let ident: Ident = input.fork().parse()?;
548            if ident == "on_start" {
549                input.parse::<Ident>()?; // consume the ident
550                input.parse::<Token![=>]>()?;
551                let expr: Expr = input.parse()?;
552                input.parse::<Token![,]>()?;
553                Some(expr)
554            } else {
555                None
556            }
557        } else {
558            None
559        };
560
561        // Parse `on_stopped =>`
562        let on_stopped_ident: Ident = input.parse()?;
563        if on_stopped_ident != "on_stopped" {
564            return Err(Error::new(
565                on_stopped_ident.span(),
566                "expected `on_stopped` keyword",
567            ));
568        }
569        input.parse::<Token![=>]>()?;
570
571        // Parse shutdown expression
572        let shutdown_expr: Expr = input.parse()?;
573
574        // Parse comma after shutdown expression
575        input.parse::<Token![,]>()?;
576
577        // Parse branches directly (no surrounding braces)
578        // Stop when we see `on_end` or reach end of input
579        let mut branches = Vec::new();
580        while !input.is_empty() {
581            // Check if next token is `on_end`
582            if input.peek(Ident) {
583                let ident: Ident = input.fork().parse()?;
584                if ident == "on_end" {
585                    break;
586                }
587            }
588
589            let pattern = Pat::parse_single(input)?;
590            input.parse::<Token![=]>()?;
591            let future: Expr = input.parse()?;
592
593            // Parse optional else clause: `else expr`
594            let else_body = if input.peek(Token![else]) {
595                input.parse::<Token![else]>()?;
596                Some(input.parse::<Expr>()?)
597            } else {
598                None
599            };
600
601            input.parse::<Token![=>]>()?;
602            let body: Expr = input.parse()?;
603
604            branches.push(SelectLoopBranch {
605                pattern,
606                future,
607                else_body,
608                body,
609            });
610
611            if input.peek(Token![,]) {
612                input.parse::<Token![,]>()?;
613            } else {
614                break;
615            }
616        }
617
618        // Check for optional `on_end =>`
619        let end_expr = if !input.is_empty() && input.peek(Ident) {
620            let ident: Ident = input.parse()?;
621            if ident == "on_end" {
622                input.parse::<Token![=>]>()?;
623                let expr: Expr = input.parse()?;
624                if input.peek(Token![,]) {
625                    input.parse::<Token![,]>()?;
626                }
627                Some(expr)
628            } else {
629                return Err(Error::new(ident.span(), "expected `on_end` keyword"));
630            }
631        } else {
632            None
633        };
634
635        Ok(Self {
636            context,
637            start_expr,
638            shutdown_expr,
639            branches,
640            end_expr,
641        })
642    }
643}
644
645#[proc_macro]
646pub fn select_loop(input: TokenStream) -> TokenStream {
647    let SelectLoopInput {
648        context,
649        start_expr,
650        shutdown_expr,
651        branches,
652        end_expr,
653    } = parse_macro_input!(input as SelectLoopInput);
654
655    fn is_irrefutable(pat: &Pat) -> bool {
656        match pat {
657            Pat::Wild(_) | Pat::Rest(_) => true,
658            Pat::Ident(i) => i.subpat.as_ref().is_none_or(|(_, p)| is_irrefutable(p)),
659            Pat::Type(t) => is_irrefutable(&t.pat),
660            Pat::Tuple(t) => t.elems.iter().all(is_irrefutable),
661            Pat::Reference(r) => is_irrefutable(&r.pat),
662            Pat::Paren(p) => is_irrefutable(&p.pat),
663            _ => false,
664        }
665    }
666
667    for b in &branches {
668        if b.else_body.is_none() && !is_irrefutable(&b.pattern) {
669            return Error::new_spanned(
670                &b.pattern,
671                "refutable patterns require an else clause: \
672                 `Some(msg) = future else break => { ... }`",
673            )
674            .to_compile_error()
675            .into();
676        }
677    }
678
679    // Convert branches to tokens for the inner select!
680    let branch_tokens: Vec<_> = branches
681        .iter()
682        .map(|b| {
683            let pattern = &b.pattern;
684            let future = &b.future;
685            let body = &b.body;
686
687            // If else clause is present, use let-else to unwrap
688            b.else_body.as_ref().map_or_else(
689                // No else: normal pattern binding (already validated as irrefutable)
690                || quote! { #pattern = #future => #body, },
691                // With else: use let-else for refutable patterns
692                |else_expr| {
693                    quote! {
694                        __select_result = #future => {
695                            let #pattern = __select_result else { #else_expr };
696                            #body
697                        },
698                    }
699                },
700            )
701        })
702        .collect();
703
704    // Helper to convert an expression to tokens, inlining block contents
705    // to preserve variable scope
706    fn expr_to_tokens(expr: &Expr) -> proc_macro2::TokenStream {
707        match expr {
708            Expr::Block(block) => {
709                let stmts = &block.block.stmts;
710                quote! { #(#stmts)* }
711            }
712            other => quote! { #other; },
713        }
714    }
715
716    // Generate on_start and on_end tokens if present
717    let on_start_tokens = start_expr.as_ref().map(expr_to_tokens);
718    let on_end_tokens = end_expr.as_ref().map(expr_to_tokens);
719    let shutdown_tokens = expr_to_tokens(&shutdown_expr);
720
721    quote! {
722        {
723            let mut shutdown = #context.stopped();
724            loop {
725                #on_start_tokens
726
727                commonware_macros::select! {
728                    _ = &mut shutdown => {
729                        #shutdown_tokens
730
731                        // Break the loop after handling shutdown. Some implementations
732                        // may divert control flow themselves, so this may be unused.
733                        #[allow(unreachable_code)]
734                        break;
735                    },
736                    #(#branch_tokens)*
737                }
738
739                #on_end_tokens
740            }
741        }
742    }
743    .into()
744}