commonware_macros/
lib.rs

1//! Augment the development of primitives with procedural macros.
2
3#![doc(
4    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5    html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7
8use crate::nextest::configured_test_groups;
9use proc_macro::TokenStream;
10use proc_macro2::Span;
11use proc_macro_crate::{crate_name, FoundCrate};
12use quote::{format_ident, quote, ToTokens};
13use syn::{
14    parse::{Parse, ParseStream, Result},
15    parse_macro_input, Block, Error, Expr, Ident, ItemFn, LitStr, Pat, Token,
16};
17
18mod nextest;
19
20/// Run a test function asynchronously.
21///
22/// This macro is powered by the [futures](https://docs.rs/futures) crate
23/// and is not bound to a particular executor or context.
24///
25/// # Example
26/// ```rust
27/// use commonware_macros::test_async;
28///
29/// #[test_async]
30/// async fn test_async_fn() {
31///    assert_eq!(2 + 2, 4);
32/// }
33/// ```
34#[proc_macro_attribute]
35pub fn test_async(_: TokenStream, item: TokenStream) -> TokenStream {
36    // Parse the input tokens into a syntax tree
37    let input = parse_macro_input!(item as ItemFn);
38
39    // Extract function components
40    let attrs = input.attrs;
41    let vis = input.vis;
42    let mut sig = input.sig;
43    let block = input.block;
44
45    // Remove 'async' from the function signature (#[test] only
46    // accepts sync functions)
47    sig.asyncness
48        .take()
49        .expect("test_async macro can only be used with async functions");
50
51    // Generate output tokens
52    let expanded = quote! {
53        #[test]
54        #(#attrs)*
55        #vis #sig {
56            futures::executor::block_on(async #block);
57        }
58    };
59    TokenStream::from(expanded)
60}
61
62/// Capture logs (based on the provided log level) from a test run using
63/// [libtest's output capture functionality](https://doc.rust-lang.org/book/ch11-02-running-tests.html#showing-function-output).
64///
65/// This macro defaults to a log level of `DEBUG` if no level is provided.
66///
67/// This macro is powered by the [tracing](https://docs.rs/tracing) and
68/// [tracing-subscriber](https://docs.rs/tracing-subscriber) crates.
69///
70/// # Example
71/// ```rust
72/// use commonware_macros::test_traced;
73/// use tracing::{debug, info};
74///
75/// #[test_traced("INFO")]
76/// fn test_info_level() {
77///     info!("This is an info log");
78///     debug!("This is a debug log (won't be shown)");
79///     assert_eq!(2 + 2, 4);
80/// }
81/// ```
82#[proc_macro_attribute]
83pub fn test_traced(attr: TokenStream, item: TokenStream) -> TokenStream {
84    // Parse the input tokens into a syntax tree
85    let input = parse_macro_input!(item as ItemFn);
86
87    // Parse the attribute argument for log level
88    let log_level = if attr.is_empty() {
89        // Default log level is DEBUG
90        quote! { tracing::Level::DEBUG }
91    } else {
92        // Parse the attribute as a string literal
93        let level_str = parse_macro_input!(attr as LitStr);
94        let level_ident = level_str.value().to_uppercase();
95        match level_ident.as_str() {
96            "TRACE" => quote! { tracing::Level::TRACE },
97            "DEBUG" => quote! { tracing::Level::DEBUG },
98            "INFO" => quote! { tracing::Level::INFO },
99            "WARN" => quote! { tracing::Level::WARN },
100            "ERROR" => quote! { tracing::Level::ERROR },
101            _ => {
102                // Return a compile error for invalid log levels
103                return Error::new_spanned(
104                    level_str,
105                    "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
106                )
107                .to_compile_error()
108                .into();
109            }
110        }
111    };
112
113    // Extract function components
114    let attrs = input.attrs;
115    let vis = input.vis;
116    let sig = input.sig;
117    let block = input.block;
118
119    // Generate output tokens
120    let expanded = quote! {
121        #[test]
122        #(#attrs)*
123        #vis #sig {
124            // Create a subscriber and dispatcher with the specified log level
125            let subscriber = tracing_subscriber::fmt()
126                .with_test_writer()
127                .with_max_level(#log_level)
128                .with_line_number(true)
129                .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
130                .finish();
131            let dispatcher = tracing::Dispatch::new(subscriber);
132
133            // Set the subscriber for the scope of the test
134            tracing::dispatcher::with_default(&dispatcher, || {
135                #block
136            });
137        }
138    };
139    TokenStream::from(expanded)
140}
141
142/// Prefix a test name with a nextest filter group.
143///
144/// This renames `test_some_behavior` into `test_some_behavior_<group>_`, making
145/// it easy to filter tests by group postfixes in nextest.
146#[proc_macro_attribute]
147pub fn test_group(attr: TokenStream, item: TokenStream) -> TokenStream {
148    if attr.is_empty() {
149        return Error::new(
150            Span::call_site(),
151            "test_group requires a string literal filter group name",
152        )
153        .to_compile_error()
154        .into();
155    }
156
157    let mut input = parse_macro_input!(item as ItemFn);
158    let group_literal = parse_macro_input!(attr as LitStr);
159
160    let group = match nextest::sanitize_group_literal(&group_literal) {
161        Ok(group) => group,
162        Err(err) => return err.to_compile_error().into(),
163    };
164    let groups = match configured_test_groups() {
165        Ok(groups) => groups,
166        Err(_) => {
167            // Don't fail the compilation if the file isn't found; just return the original input.
168            return TokenStream::from(quote!(#input));
169        }
170    };
171
172    if let Err(err) = nextest::ensure_group_known(groups, &group, group_literal.span()) {
173        return err.to_compile_error().into();
174    }
175
176    let original_name = input.sig.ident.to_string();
177    let new_ident = Ident::new(&format!("{original_name}_{group}_"), input.sig.ident.span());
178
179    input.sig.ident = new_ident;
180
181    TokenStream::from(quote!(#input))
182}
183
184/// Capture logs from a test run into an in-memory store.
185///
186/// This macro defaults to a log level of `DEBUG` on the [mod@tracing_subscriber::fmt] layer if no level is provided.
187///
188/// This macro is powered by the [tracing](https://docs.rs/tracing),
189/// [tracing-subscriber](https://docs.rs/tracing-subscriber), and
190/// [commonware-runtime](https://docs.rs/commonware-runtime) crates.
191///
192/// # Note
193///
194/// This macro requires the resolution of the `commonware-runtime`, `tracing`, and `tracing_subscriber` crates.
195///
196/// # Example
197/// ```rust,ignore
198/// use commonware_macros::test_collect_traces;
199/// use commonware_runtime::telemetry::traces::collector::TraceStorage;
200/// use tracing::{debug, info};
201///
202/// #[test_collect_traces("INFO")]
203/// fn test_info_level(traces: TraceStorage) {
204///     // Filter applies to console output (FmtLayer)
205///     info!("This is an info log");
206///     debug!("This is a debug log (won't be shown in console output)");
207///
208///     // All traces are collected, regardless of level, by the CollectingLayer.
209///     assert_eq!(traces.get_all().len(), 2);
210/// }
211/// ```
212#[proc_macro_attribute]
213pub fn test_collect_traces(attr: TokenStream, item: TokenStream) -> TokenStream {
214    let input = parse_macro_input!(item as ItemFn);
215
216    // Parse the attribute argument for log level
217    let log_level = if attr.is_empty() {
218        // Default log level is DEBUG
219        quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG }
220    } else {
221        // Parse the attribute as a string literal
222        let level_str = parse_macro_input!(attr as LitStr);
223        let level_ident = level_str.value().to_uppercase();
224        match level_ident.as_str() {
225            "TRACE" => quote! { ::tracing_subscriber::filter::LevelFilter::TRACE },
226            "DEBUG" => quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG },
227            "INFO" => quote! { ::tracing_subscriber::filter::LevelFilter::INFO },
228            "WARN" => quote! { ::tracing_subscriber::filter::LevelFilter::WARN },
229            "ERROR" => quote! { ::tracing_subscriber::filter::LevelFilter::ERROR },
230            _ => {
231                // Return a compile error for invalid log levels
232                return Error::new_spanned(
233                    level_str,
234                    "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
235                )
236                .to_compile_error()
237                .into();
238            }
239        }
240    };
241
242    let attrs = input.attrs;
243    let vis = input.vis;
244    let sig = input.sig;
245    let block = input.block;
246
247    // Create the signature of the inner function that takes the TraceStorage.
248    let inner_ident = format_ident!("__{}_inner_traced", sig.ident);
249    let mut inner_sig = sig.clone();
250    inner_sig.ident = inner_ident.clone();
251
252    // Create the signature of the outer test function.
253    let mut outer_sig = sig;
254    outer_sig.inputs.clear();
255
256    // Detect the path of the `commonware-runtime` crate. If it has been renamed or
257    // this macro is being used within the `commonware-runtime` crate itself, adjust
258    // the path accordingly.
259    let rt_path = match crate_name("commonware-runtime") {
260        Ok(FoundCrate::Itself) => quote!(crate),
261        Ok(FoundCrate::Name(name)) => {
262            let ident = syn::Ident::new(&name, Span::call_site());
263            quote!(#ident)
264        }
265        Err(_) => quote!(::commonware_runtime), // fallback
266    };
267
268    let expanded = quote! {
269        // Inner test function runs the actual test logic, accepting the TraceStorage
270        // created by the harness.
271        #(#attrs)*
272        #vis #inner_sig #block
273
274        #[test]
275        #vis #outer_sig {
276            use ::tracing_subscriber::{Layer, fmt, Registry, layer::SubscriberExt, util::SubscriberInitExt};
277            use ::tracing::{Dispatch, dispatcher};
278            use #rt_path::telemetry::traces::collector::{CollectingLayer, TraceStorage};
279
280            let trace_store = TraceStorage::default();
281            let collecting_layer = CollectingLayer::new(trace_store.clone());
282
283            let fmt_layer = fmt::layer()
284                .with_test_writer()
285                .with_line_number(true)
286                .with_span_events(fmt::format::FmtSpan::CLOSE)
287                .with_filter(#log_level);
288
289            let subscriber = Registry::default().with(collecting_layer).with(fmt_layer);
290            let dispatcher = Dispatch::new(subscriber);
291            dispatcher::with_default(&dispatcher, || {
292                #inner_ident(trace_store);
293            });
294        }
295    };
296
297    TokenStream::from(expanded)
298}
299
300struct SelectInput {
301    branches: Vec<Branch>,
302}
303
304struct Branch {
305    pattern: Pat,
306    future: Expr,
307    block: Block,
308}
309
310impl Parse for SelectInput {
311    fn parse(input: ParseStream<'_>) -> Result<Self> {
312        let mut branches = Vec::new();
313
314        while !input.is_empty() {
315            let pattern = Pat::parse_single(input)?;
316            input.parse::<Token![=]>()?;
317            let future: Expr = input.parse()?;
318            input.parse::<Token![=>]>()?;
319            let block: Block = input.parse()?;
320
321            branches.push(Branch {
322                pattern,
323                future,
324                block,
325            });
326
327            if input.peek(Token![,]) {
328                input.parse::<Token![,]>()?;
329            } else {
330                break;
331            }
332        }
333
334        Ok(Self { branches })
335    }
336}
337
338impl ToTokens for SelectInput {
339    fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
340        for branch in &self.branches {
341            let pattern = &branch.pattern;
342            let future = &branch.future;
343            let block = &branch.block;
344
345            tokens.extend(quote! {
346                #pattern = #future => #block,
347            });
348        }
349    }
350}
351
352/// Select the first future that completes (biased by order).
353///
354/// This macro is powered by the [futures](https://docs.rs/futures) crate
355/// and is not bound to a particular executor or context.
356///
357/// # Fusing
358///
359/// This macro handles the [fusing](https://docs.rs/futures/latest/futures/future/trait.FutureExt.html#method.fuse)
360/// futures in a `select`-specific scope.
361///
362/// # Example
363///
364/// ```rust
365/// use std::time::Duration;
366/// use commonware_macros::select;
367/// use futures::executor::block_on;
368/// use futures_timer::Delay;
369///
370/// async fn task() -> usize {
371///     42
372/// }
373//
374/// block_on(async move {
375///     select! {
376///         _ = Delay::new(Duration::from_secs(1)) => {
377///             println!("timeout fired");
378///         },
379///         v = task() => {
380///             println!("task completed with value: {}", v);
381///         },
382///     };
383/// });
384/// ```
385#[proc_macro]
386pub fn select(input: TokenStream) -> TokenStream {
387    // Parse the input tokens
388    let SelectInput { branches } = parse_macro_input!(input as SelectInput);
389
390    // Generate code from provided statements
391    let mut select_branches = Vec::new();
392    for Branch {
393        pattern,
394        future,
395        block,
396    } in branches.into_iter()
397    {
398        // Generate branch for `select_biased!` macro
399        let branch_code = quote! {
400            #pattern = (#future).fuse() => #block,
401        };
402        select_branches.push(branch_code);
403    }
404
405    // Generate the final output code
406    quote! {
407        {
408            use futures::FutureExt as _;
409
410            futures::select_biased! {
411                #(#select_branches)*
412            }
413        }
414    }
415    .into()
416}
417
418/// Input for [select_loop!].
419///
420/// Parses: `context, on_stopped => { block }, { branches... }`
421struct SelectLoopInput {
422    context: Expr,
423    shutdown_block: Block,
424    branches: Vec<Branch>,
425}
426
427impl Parse for SelectLoopInput {
428    fn parse(input: ParseStream<'_>) -> Result<Self> {
429        // Parse context expression
430        let context: Expr = input.parse()?;
431        input.parse::<Token![,]>()?;
432
433        // Parse `on_stopped =>`
434        let on_stopped_ident: Ident = input.parse()?;
435        if on_stopped_ident != "on_stopped" {
436            return Err(Error::new(
437                on_stopped_ident.span(),
438                "expected `on_stopped` keyword",
439            ));
440        }
441        input.parse::<Token![=>]>()?;
442
443        // Parse shutdown block
444        let shutdown_block: Block = input.parse()?;
445
446        // Parse comma after shutdown block
447        input.parse::<Token![,]>()?;
448
449        // Parse branches directly (no surrounding braces)
450        let mut branches = Vec::new();
451        while !input.is_empty() {
452            let pattern = Pat::parse_single(input)?;
453            input.parse::<Token![=]>()?;
454            let future: Expr = input.parse()?;
455            input.parse::<Token![=>]>()?;
456            let block: Block = input.parse()?;
457
458            branches.push(Branch {
459                pattern,
460                future,
461                block,
462            });
463
464            if input.peek(Token![,]) {
465                input.parse::<Token![,]>()?;
466            } else {
467                break;
468            }
469        }
470
471        Ok(Self {
472            context,
473            shutdown_block,
474            branches,
475        })
476    }
477}
478
479/// Convenience macro to continuously [select!] over a set of futures in biased order,
480/// with a required shutdown handler.
481///
482/// This macro automatically creates a shutdown future from the provided context and requires a
483/// shutdown handler block. The shutdown future is created outside the loop, allowing it to
484/// persist across iterations until shutdown is signaled. The shutdown branch is always checked
485/// first (biased).
486///
487/// After the shutdown block is executed, the loop breaks by default. If different control flow
488/// is desired (such as returning from the enclosing function), it must be handled explicitly.
489///
490/// # Syntax
491///
492/// ```rust,ignore
493/// select_loop! {
494///     context,
495///     on_stopped => { cleanup },
496///     pattern = future => block,
497///     // ...
498/// }
499/// ```
500///
501/// The `shutdown` variable (the future from `context.stopped()`) is accessible in the
502/// shutdown block, allowing explicit cleanup such as `drop(shutdown)` before breaking or returning.
503///
504/// # Example
505///
506/// ```rust,ignore
507/// use commonware_macros::select_loop;
508///
509/// async fn run(context: impl commonware_runtime::Spawner) {
510///     select_loop! {
511///         context,
512///         on_stopped => {
513///             println!("shutting down");
514///             drop(shutdown);
515///         },
516///         msg = receiver.recv() => {
517///             println!("received: {:?}", msg);
518///         },
519///     }
520/// }
521/// ```
522#[proc_macro]
523pub fn select_loop(input: TokenStream) -> TokenStream {
524    let SelectLoopInput {
525        context,
526        shutdown_block,
527        branches,
528    } = parse_macro_input!(input as SelectLoopInput);
529
530    // Convert branches to tokens for the inner select!
531    let branch_tokens: Vec<_> = branches
532        .iter()
533        .map(|b| {
534            let pattern = &b.pattern;
535            let future = &b.future;
536            let block = &b.block;
537            quote! { #pattern = #future => #block, }
538        })
539        .collect();
540
541    quote! {
542        {
543            let mut shutdown = #context.stopped();
544            loop {
545                commonware_macros::select! {
546                    _ = &mut shutdown => {
547                        #shutdown_block
548
549                        // Break the loop after handling shutdown. Some implementations
550                        // may divert control flow themselves, so this may be unused.
551                        #[allow(unreachable_code)]
552                        break;
553                    },
554                    #(#branch_tokens)*
555                }
556            }
557        }
558    }
559    .into()
560}