Skip to main content

commonware_macros_impl/
lib.rs

1//! Proc-macro implementation for `commonware-macros`.
2//!
3//! This is an internal crate. Use [`commonware-macros`](https://docs.rs/commonware-macros)
4//! instead.
5
6#![doc(
7    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
8    html_favicon_url = "https://commonware.xyz/favicon.ico"
9)]
10
11use crate::nextest::configured_test_groups;
12use proc_macro::TokenStream;
13use proc_macro2::Span;
14use proc_macro_crate::{crate_name, FoundCrate};
15use quote::{format_ident, quote};
16use syn::{
17    braced,
18    parse::{Parse, ParseStream, Result},
19    parse_macro_input, Error, Expr, Ident, ItemFn, LitInt, LitStr, Pat, Token, Visibility,
20};
21
22mod nextest;
23
24/// Stability level input that accepts either a literal integer (0-4) or a named constant
25/// (ALPHA, BETA, GAMMA, DELTA, EPSILON).
26struct StabilityLevel {
27    value: u8,
28}
29
30impl Parse for StabilityLevel {
31    fn parse(input: ParseStream<'_>) -> Result<Self> {
32        let lookahead = input.lookahead1();
33        if lookahead.peek(LitInt) {
34            let lit: LitInt = input.parse()?;
35            let value: u8 = lit
36                .base10_parse()
37                .map_err(|_| Error::new(lit.span(), "stability level must be 0, 1, 2, 3, or 4"))?;
38            if value > 4 {
39                return Err(Error::new(
40                    lit.span(),
41                    "stability level must be 0, 1, 2, 3, or 4",
42                ));
43            }
44            Ok(Self { value })
45        } else if lookahead.peek(Ident) {
46            let ident: Ident = input.parse()?;
47            let value = match ident.to_string().as_str() {
48                "ALPHA" => 0,
49                "BETA" => 1,
50                "GAMMA" => 2,
51                "DELTA" => 3,
52                "EPSILON" => 4,
53                _ => {
54                    return Err(Error::new(
55                        ident.span(),
56                        "expected stability level: ALPHA, BETA, GAMMA, DELTA, EPSILON, or 0-4",
57                    ));
58                }
59            };
60            Ok(Self { value })
61        } else {
62            Err(lookahead.error())
63        }
64    }
65}
66
/// Returns the canonical name of a stability level: `0` → `"ALPHA"`, `1` → `"BETA"`,
/// `2` → `"GAMMA"`, `3` → `"DELTA"`, `4` → `"EPSILON"`.
///
/// # Panics
///
/// Panics via `unreachable!` when `level > 4`. Callers in this file only pass values in
/// `0..=4`.
fn level_name(level: u8) -> &'static str {
    const NAMES: [&str; 5] = ["ALPHA", "BETA", "GAMMA", "DELTA", "EPSILON"];
    NAMES
        .get(usize::from(level))
        .copied()
        .unwrap_or_else(|| unreachable!())
}
77
78/// Generates cfg identifiers that should exclude an item at the given stability level.
79///
80/// The stability system works by excluding items when building at higher stability levels.
81/// For example, an item marked `#[stability(BETA)]` (level 1) should be excluded when
82/// building with `--cfg commonware_stability_GAMMA` (level 2) or higher.
83///
84/// This function returns identifiers for all levels above the given level, plus `RESERVED`.
85/// The generated `#[cfg(not(any(...)))]` attribute ensures the item is included only when
86/// none of the higher-level cfgs are set.
87///
88/// ```text
89/// Level 0 (ALPHA)   -> excludes at: BETA, GAMMA, DELTA, EPSILON, RESERVED
90/// Level 1 (BETA)    -> excludes at: GAMMA, DELTA, EPSILON, RESERVED
91/// Level 2 (GAMMA)   -> excludes at: DELTA, EPSILON, RESERVED
92/// Level 3 (DELTA)   -> excludes at: EPSILON, RESERVED
93/// Level 4 (EPSILON) -> excludes at: RESERVED
94/// ```
95///
96/// `RESERVED` is a special level used by `scripts/find_unstable_public.sh` to exclude ALL
97/// stability-marked items, leaving only unmarked public items visible in rustdoc output.
98fn exclusion_cfg_names(level: u8) -> Vec<proc_macro2::Ident> {
99    let mut names: Vec<_> = ((level + 1)..=4)
100        .map(|l| format_ident!("commonware_stability_{}", level_name(l)))
101        .collect();
102
103    names.push(format_ident!("commonware_stability_RESERVED"));
104    names
105}
106
107#[proc_macro_attribute]
108pub fn stability(attr: TokenStream, item: TokenStream) -> TokenStream {
109    let level = parse_macro_input!(attr as StabilityLevel);
110    let exclude_names = exclusion_cfg_names(level.value);
111
112    let item2: proc_macro2::TokenStream = item.into();
113    let expanded = quote! {
114        #[cfg(not(any(#(#exclude_names),*)))]
115        #item2
116    };
117
118    TokenStream::from(expanded)
119}
120
121/// Input for the `stability_mod!` macro: `level, visibility mod name`
122struct StabilityModInput {
123    level: StabilityLevel,
124    visibility: Visibility,
125    name: Ident,
126}
127
128impl Parse for StabilityModInput {
129    fn parse(input: ParseStream<'_>) -> Result<Self> {
130        let level: StabilityLevel = input.parse()?;
131        input.parse::<Token![,]>()?;
132        let visibility: Visibility = input.parse()?;
133        input.parse::<Token![mod]>()?;
134        let name: Ident = input.parse()?;
135        Ok(Self {
136            level,
137            visibility,
138            name,
139        })
140    }
141}
142
143#[proc_macro]
144pub fn stability_mod(input: TokenStream) -> TokenStream {
145    let StabilityModInput {
146        level,
147        visibility,
148        name,
149    } = parse_macro_input!(input as StabilityModInput);
150
151    let exclude_names = exclusion_cfg_names(level.value);
152
153    let expanded = quote! {
154        #[cfg(not(any(#(#exclude_names),*)))]
155        #visibility mod #name;
156    };
157
158    TokenStream::from(expanded)
159}
160
161/// Input for the `stability_scope!` macro: `level [, cfg(predicate)] { items... }`
162struct StabilityScopeInput {
163    level: StabilityLevel,
164    predicate: Option<syn::Meta>,
165    items: Vec<syn::Item>,
166}
167
168impl Parse for StabilityScopeInput {
169    fn parse(input: ParseStream<'_>) -> Result<Self> {
170        let level: StabilityLevel = input.parse()?;
171
172        // Check for optional cfg predicate
173        let predicate = if input.peek(Token![,]) {
174            input.parse::<Token![,]>()?;
175
176            // Parse `cfg(...)` - expect the literal identifier "cfg" followed by parenthesized content
177            let cfg_ident: Ident = input.parse()?;
178            if cfg_ident != "cfg" {
179                return Err(Error::new(cfg_ident.span(), "expected `cfg`"));
180            }
181            let cfg_content;
182            syn::parenthesized!(cfg_content in input);
183            Some(cfg_content.parse()?)
184        } else {
185            None
186        };
187
188        let content;
189        braced!(content in input);
190
191        let mut items = Vec::new();
192        while !content.is_empty() {
193            items.push(content.parse()?);
194        }
195
196        Ok(Self {
197            level,
198            predicate,
199            items,
200        })
201    }
202}
203
204#[proc_macro]
205pub fn stability_scope(input: TokenStream) -> TokenStream {
206    let StabilityScopeInput {
207        level,
208        predicate,
209        items,
210    } = parse_macro_input!(input as StabilityScopeInput);
211
212    let exclude_names = exclusion_cfg_names(level.value);
213
214    let cfg_attr = predicate.map_or_else(
215        || quote! { #[cfg(not(any(#(#exclude_names),*)))] },
216        |pred| quote! { #[cfg(all(#pred, not(any(#(#exclude_names),*))))] },
217    );
218
219    let expanded_items: Vec<_> = items
220        .into_iter()
221        .map(|item| {
222            quote! {
223                #cfg_attr
224                #item
225            }
226        })
227        .collect();
228
229    let expanded = quote! {
230        #(#expanded_items)*
231    };
232
233    TokenStream::from(expanded)
234}
235
236#[proc_macro_attribute]
237pub fn test_async(_: TokenStream, item: TokenStream) -> TokenStream {
238    // Parse the input tokens into a syntax tree
239    let input = parse_macro_input!(item as ItemFn);
240
241    // Extract function components
242    let attrs = input.attrs;
243    let vis = input.vis;
244    let mut sig = input.sig;
245    let block = input.block;
246
247    // Remove 'async' from the function signature (#[test] only
248    // accepts sync functions)
249    sig.asyncness
250        .take()
251        .expect("test_async macro can only be used with async functions");
252
253    // Generate output tokens
254    let expanded = quote! {
255        #[test]
256        #(#attrs)*
257        #vis #sig {
258            futures::executor::block_on(async #block);
259        }
260    };
261    TokenStream::from(expanded)
262}
263
264#[proc_macro_attribute]
265pub fn test_traced(attr: TokenStream, item: TokenStream) -> TokenStream {
266    // Parse the input tokens into a syntax tree
267    let input = parse_macro_input!(item as ItemFn);
268
269    // Parse the attribute argument for default log level
270    let default_level = if attr.is_empty() {
271        "debug".to_string()
272    } else {
273        let level_str = parse_macro_input!(attr as LitStr);
274        let level_ident = level_str.value().to_lowercase();
275        match level_ident.as_str() {
276            "trace" | "debug" | "info" | "warn" | "error" => level_ident,
277            _ => {
278                return Error::new_spanned(
279                    level_str,
280                    "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
281                )
282                .to_compile_error()
283                .into();
284            }
285        }
286    };
287
288    // Extract function components
289    let attrs = input.attrs;
290    let vis = input.vis;
291    let sig = input.sig;
292    let block = input.block;
293
294    // Generate output tokens
295    let expanded = quote! {
296        #[test]
297        #(#attrs)*
298        #vis #sig {
299            use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};
300
301            // Use RUST_LOG if set, otherwise fall back to the macro's default level
302            let filter = EnvFilter::try_from_default_env()
303                .unwrap_or_else(|_| EnvFilter::new(#default_level));
304            let subscriber = tracing_subscriber::Registry::default()
305                .with(
306                    tracing_subscriber::fmt::layer()
307                        .with_test_writer()
308                        .with_line_number(true)
309                        .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
310                )
311                .with(filter);
312            let dispatcher = tracing::Dispatch::new(subscriber);
313
314            // Set the subscriber for the scope of the test
315            tracing::dispatcher::with_default(&dispatcher, || {
316                #block
317            });
318        }
319    };
320    TokenStream::from(expanded)
321}
322
323#[proc_macro_attribute]
324pub fn test_group(attr: TokenStream, item: TokenStream) -> TokenStream {
325    if attr.is_empty() {
326        return Error::new(
327            Span::call_site(),
328            "test_group requires a string literal filter group name",
329        )
330        .to_compile_error()
331        .into();
332    }
333
334    let mut input = parse_macro_input!(item as ItemFn);
335    let group_literal = parse_macro_input!(attr as LitStr);
336
337    let group = match nextest::sanitize_group_literal(&group_literal) {
338        Ok(group) => group,
339        Err(err) => return err.to_compile_error().into(),
340    };
341    let groups = match configured_test_groups() {
342        Ok(groups) => groups,
343        Err(_) => {
344            // Don't fail the compilation if the file isn't found; just return the original input.
345            return TokenStream::from(quote!(#input));
346        }
347    };
348
349    if let Err(err) = nextest::ensure_group_known(groups, &group, group_literal.span()) {
350        return err.to_compile_error().into();
351    }
352
353    let original_name = input.sig.ident.to_string();
354    let new_ident = Ident::new(&format!("{original_name}_{group}_"), input.sig.ident.span());
355
356    input.sig.ident = new_ident;
357
358    TokenStream::from(quote!(#input))
359}
360
361#[proc_macro_attribute]
362pub fn test_collect_traces(attr: TokenStream, item: TokenStream) -> TokenStream {
363    let input = parse_macro_input!(item as ItemFn);
364
365    // Parse the attribute argument for log level
366    let log_level = if attr.is_empty() {
367        // Default log level is DEBUG
368        quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG }
369    } else {
370        // Parse the attribute as a string literal
371        let level_str = parse_macro_input!(attr as LitStr);
372        let level_ident = level_str.value().to_uppercase();
373        match level_ident.as_str() {
374            "TRACE" => quote! { ::tracing_subscriber::filter::LevelFilter::TRACE },
375            "DEBUG" => quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG },
376            "INFO" => quote! { ::tracing_subscriber::filter::LevelFilter::INFO },
377            "WARN" => quote! { ::tracing_subscriber::filter::LevelFilter::WARN },
378            "ERROR" => quote! { ::tracing_subscriber::filter::LevelFilter::ERROR },
379            _ => {
380                // Return a compile error for invalid log levels
381                return Error::new_spanned(
382                    level_str,
383                    "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
384                )
385                .to_compile_error()
386                .into();
387            }
388        }
389    };
390
391    let attrs = input.attrs;
392    let vis = input.vis;
393    let sig = input.sig;
394    let block = input.block;
395
396    // Create the signature of the inner function that takes the TraceStorage.
397    let inner_ident = format_ident!("__{}_inner_traced", sig.ident);
398    let mut inner_sig = sig.clone();
399    inner_sig.ident = inner_ident.clone();
400
401    // Create the signature of the outer test function.
402    let mut outer_sig = sig;
403    outer_sig.inputs.clear();
404
405    // Detect the path of the `commonware-runtime` crate. If it has been renamed or
406    // this macro is being used within the `commonware-runtime` crate itself, adjust
407    // the path accordingly.
408    let rt_path = match crate_name("commonware-runtime") {
409        Ok(FoundCrate::Itself) => quote!(crate),
410        Ok(FoundCrate::Name(name)) => {
411            let ident = syn::Ident::new(&name, Span::call_site());
412            quote!(#ident)
413        }
414        Err(_) => quote!(::commonware_runtime), // fallback
415    };
416
417    let expanded = quote! {
418        // Inner test function runs the actual test logic, accepting the TraceStorage
419        // created by the harness.
420        #(#attrs)*
421        #vis #inner_sig #block
422
423        #[test]
424        #vis #outer_sig {
425            use ::tracing_subscriber::{Layer, fmt, Registry, layer::SubscriberExt, util::SubscriberInitExt};
426            use ::tracing::{Dispatch, dispatcher};
427            use #rt_path::telemetry::traces::collector::{CollectingLayer, TraceStorage};
428
429            let trace_store = TraceStorage::default();
430            let collecting_layer = CollectingLayer::new(trace_store.clone());
431
432            let fmt_layer = fmt::layer()
433                .with_test_writer()
434                .with_line_number(true)
435                .with_span_events(fmt::format::FmtSpan::CLOSE)
436                .with_filter(#log_level);
437
438            let subscriber = Registry::default().with(collecting_layer).with(fmt_layer);
439            let dispatcher = Dispatch::new(subscriber);
440            dispatcher::with_default(&dispatcher, || {
441                #inner_ident(trace_store);
442            });
443        }
444    };
445
446    TokenStream::from(expanded)
447}
448
449struct SelectInput {
450    branches: Vec<Branch>,
451}
452
453struct Branch {
454    pattern: Pat,
455    future: Expr,
456    body: Expr,
457}
458
459/// Branch for [select_loop!] with optional `else` clause for `Some` patterns.
460struct SelectLoopBranch {
461    pattern: Pat,
462    future: Expr,
463    else_body: Option<Expr>,
464    body: Expr,
465}
466
467impl Parse for SelectInput {
468    fn parse(input: ParseStream<'_>) -> Result<Self> {
469        let mut branches = Vec::new();
470
471        while !input.is_empty() {
472            let pattern = Pat::parse_single(input)?;
473            input.parse::<Token![=]>()?;
474            let future: Expr = input.parse()?;
475            input.parse::<Token![=>]>()?;
476            let body: Expr = input.parse()?;
477
478            branches.push(Branch {
479                pattern,
480                future,
481                body,
482            });
483
484            if input.peek(Token![,]) {
485                input.parse::<Token![,]>()?;
486            } else {
487                break;
488            }
489        }
490
491        Ok(Self { branches })
492    }
493}
494
495#[proc_macro]
496pub fn select(input: TokenStream) -> TokenStream {
497    // Parse the input tokens
498    let SelectInput { branches } = parse_macro_input!(input as SelectInput);
499
500    // Generate code from provided statements
501    let mut select_branches = Vec::new();
502    for Branch {
503        pattern,
504        future,
505        body,
506    } in branches.into_iter()
507    {
508        // Generate branch for `select!` macro
509        let branch_code = quote! {
510            #pattern = #future => #body,
511        };
512        select_branches.push(branch_code);
513    }
514
515    // Generate the final output code
516    quote! {
517        {
518            ::commonware_macros::__reexport::tokio::select! {
519                biased;
520                #(#select_branches)*
521            }
522        }
523    }
524    .into()
525}
526
527/// Input for [select_loop!].
528///
529/// Parses: `context, [on_start => expr,] on_stopped => expr, branches... [, on_end => expr]`
530struct SelectLoopInput {
531    context: Expr,
532    start_expr: Option<Expr>,
533    shutdown_expr: Expr,
534    branches: Vec<SelectLoopBranch>,
535    end_expr: Option<Expr>,
536}
537
538impl Parse for SelectLoopInput {
539    fn parse(input: ParseStream<'_>) -> Result<Self> {
540        // Parse context expression
541        let context: Expr = input.parse()?;
542        input.parse::<Token![,]>()?;
543
544        // Check for optional `on_start =>`
545        let start_expr = if input.peek(Ident) {
546            let ident: Ident = input.fork().parse()?;
547            if ident == "on_start" {
548                input.parse::<Ident>()?; // consume the ident
549                input.parse::<Token![=>]>()?;
550                let expr: Expr = input.parse()?;
551                input.parse::<Token![,]>()?;
552                Some(expr)
553            } else {
554                None
555            }
556        } else {
557            None
558        };
559
560        // Parse `on_stopped =>`
561        let on_stopped_ident: Ident = input.parse()?;
562        if on_stopped_ident != "on_stopped" {
563            return Err(Error::new(
564                on_stopped_ident.span(),
565                "expected `on_stopped` keyword",
566            ));
567        }
568        input.parse::<Token![=>]>()?;
569
570        // Parse shutdown expression
571        let shutdown_expr: Expr = input.parse()?;
572
573        // Parse comma after shutdown expression
574        input.parse::<Token![,]>()?;
575
576        // Parse branches directly (no surrounding braces)
577        // Stop when we see `on_end` or reach end of input
578        let mut branches = Vec::new();
579        while !input.is_empty() {
580            // Check if next token is `on_end`
581            if input.peek(Ident) {
582                let ident: Ident = input.fork().parse()?;
583                if ident == "on_end" {
584                    break;
585                }
586            }
587
588            let pattern = Pat::parse_single(input)?;
589            input.parse::<Token![=]>()?;
590            let future: Expr = input.parse()?;
591
592            // Parse optional else clause: `else expr`
593            let else_body = if input.peek(Token![else]) {
594                input.parse::<Token![else]>()?;
595                Some(input.parse::<Expr>()?)
596            } else {
597                None
598            };
599
600            input.parse::<Token![=>]>()?;
601            let body: Expr = input.parse()?;
602
603            branches.push(SelectLoopBranch {
604                pattern,
605                future,
606                else_body,
607                body,
608            });
609
610            if input.peek(Token![,]) {
611                input.parse::<Token![,]>()?;
612            } else {
613                break;
614            }
615        }
616
617        // Check for optional `on_end =>`
618        let end_expr = if !input.is_empty() && input.peek(Ident) {
619            let ident: Ident = input.parse()?;
620            if ident == "on_end" {
621                input.parse::<Token![=>]>()?;
622                let expr: Expr = input.parse()?;
623                if input.peek(Token![,]) {
624                    input.parse::<Token![,]>()?;
625                }
626                Some(expr)
627            } else {
628                return Err(Error::new(ident.span(), "expected `on_end` keyword"));
629            }
630        } else {
631            None
632        };
633
634        Ok(Self {
635            context,
636            start_expr,
637            shutdown_expr,
638            branches,
639            end_expr,
640        })
641    }
642}
643
644#[proc_macro]
645pub fn select_loop(input: TokenStream) -> TokenStream {
646    let SelectLoopInput {
647        context,
648        start_expr,
649        shutdown_expr,
650        branches,
651        end_expr,
652    } = parse_macro_input!(input as SelectLoopInput);
653
654    fn is_irrefutable(pat: &Pat) -> bool {
655        match pat {
656            Pat::Wild(_) | Pat::Rest(_) => true,
657            Pat::Ident(i) => i.subpat.as_ref().is_none_or(|(_, p)| is_irrefutable(p)),
658            Pat::Type(t) => is_irrefutable(&t.pat),
659            Pat::Tuple(t) => t.elems.iter().all(is_irrefutable),
660            Pat::Reference(r) => is_irrefutable(&r.pat),
661            Pat::Paren(p) => is_irrefutable(&p.pat),
662            _ => false,
663        }
664    }
665
666    for b in &branches {
667        if b.else_body.is_none() && !is_irrefutable(&b.pattern) {
668            return Error::new_spanned(
669                &b.pattern,
670                "refutable patterns require an else clause: \
671                 `Some(msg) = future else break => { ... }`",
672            )
673            .to_compile_error()
674            .into();
675        }
676    }
677
678    // Convert branches to tokens for the inner select!
679    let branch_tokens: Vec<_> = branches
680        .iter()
681        .map(|b| {
682            let pattern = &b.pattern;
683            let future = &b.future;
684            let body = &b.body;
685
686            // If else clause is present, use let-else to unwrap
687            b.else_body.as_ref().map_or_else(
688                // No else: normal pattern binding (already validated as irrefutable)
689                || quote! { #pattern = #future => #body, },
690                // With else: use let-else for refutable patterns
691                |else_expr| {
692                    quote! {
693                        __select_result = #future => {
694                            let #pattern = __select_result else { #else_expr };
695                            #body
696                        },
697                    }
698                },
699            )
700        })
701        .collect();
702
703    // Helper to convert an expression to tokens, inlining block contents
704    // to preserve variable scope
705    fn expr_to_tokens(expr: &Expr) -> proc_macro2::TokenStream {
706        match expr {
707            Expr::Block(block) => {
708                let stmts = &block.block.stmts;
709                quote! { #(#stmts)* }
710            }
711            other => quote! { #other; },
712        }
713    }
714
715    // Generate on_start and on_end tokens if present
716    let on_start_tokens = start_expr.as_ref().map(expr_to_tokens);
717    let on_end_tokens = end_expr.as_ref().map(expr_to_tokens);
718    let shutdown_tokens = expr_to_tokens(&shutdown_expr);
719
720    quote! {
721        {
722            let mut shutdown = #context.stopped();
723            loop {
724                #on_start_tokens
725
726                commonware_macros::select! {
727                    _ = &mut shutdown => {
728                        #shutdown_tokens
729
730                        // Break the loop after handling shutdown. Some implementations
731                        // may divert control flow themselves, so this may be unused.
732                        #[allow(unreachable_code)]
733                        break;
734                    },
735                    #(#branch_tokens)*
736                }
737
738                #on_end_tokens
739            }
740        }
741    }
742    .into()
743}