commonware_macros/
lib.rs

1//! Augment the development of primitives with procedural macros.
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6    parse::{Parse, ParseStream, Result},
7    parse_macro_input,
8    spanned::Spanned,
9    Block, Error, Expr, Ident, ItemFn, LitStr, Pat, Token,
10};
11
12/// Run a test function asynchronously.
13///
14/// This macro is powered by the [futures](https://docs.rs/futures) crate
15/// and is not bound to a particular executor or context.
16///
17/// # Example
18/// ```rust
19/// use commonware_macros::test_async;
20///
21/// #[test_async]
22/// async fn test_async_fn() {
23///    assert_eq!(2 + 2, 4);
24/// }
25/// ```
26#[proc_macro_attribute]
27pub fn test_async(_: TokenStream, item: TokenStream) -> TokenStream {
28    // Parse the input tokens into a syntax tree
29    let input = parse_macro_input!(item as ItemFn);
30
31    // Extract function components
32    let attrs = input.attrs;
33    let vis = input.vis;
34    let mut sig = input.sig;
35    let block = input.block;
36
37    // Remove 'async' from the function signature (#[test] only
38    // accepts sync functions)
39    sig.asyncness
40        .take()
41        .expect("test_async macro can only be used with async functions");
42
43    // Generate output tokens
44    let expanded = quote! {
45        #[test]
46        #(#attrs)*
47        #vis #sig {
48            futures::executor::block_on(async #block);
49        }
50    };
51    TokenStream::from(expanded)
52}
53
54/// Capture logs (based on the provided log level) from a test run using
55/// [libtest's output capture functionality](https://doc.rust-lang.org/book/ch11-02-running-tests.html#showing-function-output).
56///
57/// This macro defaults to a log level of `DEBUG` if no level is provided.
58///
59/// This macro is powered by the [tracing](https://docs.rs/tracing) and
60/// [tracing-subscriber](https://docs.rs/tracing-subscriber) crates.
61///
62/// # Example
63/// ```rust
64/// use commonware_macros::test_traced;
65/// use tracing::{debug, info};
66///
67/// #[test_traced("INFO")]
68/// fn test_info_level() {
69///     info!("This is an info log");
70///     debug!("This is a debug log (won't be shown)");
71///     assert_eq!(2 + 2, 4);
72/// }
73/// ```
74#[proc_macro_attribute]
75pub fn test_traced(attr: TokenStream, item: TokenStream) -> TokenStream {
76    // Parse the input tokens into a syntax tree
77    let input = parse_macro_input!(item as ItemFn);
78
79    // Parse the attribute argument for log level
80    let log_level = if attr.is_empty() {
81        // Default log level is DEBUG
82        quote! { tracing::Level::DEBUG }
83    } else {
84        // Parse the attribute as a string literal
85        let level_str = parse_macro_input!(attr as LitStr);
86        let level_ident = level_str.value().to_uppercase();
87        match level_ident.as_str() {
88            "TRACE" => quote! { tracing::Level::TRACE },
89            "DEBUG" => quote! { tracing::Level::DEBUG },
90            "INFO" => quote! { tracing::Level::INFO },
91            "WARN" => quote! { tracing::Level::WARN },
92            "ERROR" => quote! { tracing::Level::ERROR },
93            _ => {
94                // Return a compile error for invalid log levels
95                return Error::new_spanned(
96                    level_str,
97                    "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
98                )
99                .to_compile_error()
100                .into();
101            }
102        }
103    };
104
105    // Extract function components
106    let attrs = input.attrs;
107    let vis = input.vis;
108    let sig = input.sig;
109    let block = input.block;
110
111    // Generate output tokens
112    let expanded = quote! {
113        #[test]
114        #(#attrs)*
115        #vis #sig {
116            // Create a subscriber and dispatcher with the specified log level
117            let subscriber = tracing_subscriber::fmt()
118                .with_test_writer()
119                .with_max_level(#log_level)
120                .with_line_number(true)
121                .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
122                .finish();
123            let dispatcher = tracing::Dispatch::new(subscriber);
124
125            // Set the subscriber for the scope of the test
126            tracing::dispatcher::with_default(&dispatcher, || {
127                #block
128            });
129        }
130    };
131    TokenStream::from(expanded)
132}
133
134struct SelectInput {
135    branches: Vec<Branch>,
136}
137
138struct Branch {
139    pattern: Pat,
140    future: Expr,
141    block: Block,
142}
143
144impl Parse for SelectInput {
145    fn parse(input: ParseStream) -> Result<Self> {
146        let mut branches = Vec::new();
147
148        while !input.is_empty() {
149            let pattern: Pat = input.parse()?;
150            input.parse::<Token![=]>()?;
151            let future: Expr = input.parse()?;
152            input.parse::<Token![=>]>()?;
153            let block: Block = input.parse()?;
154
155            branches.push(Branch {
156                pattern,
157                future,
158                block,
159            });
160
161            if input.peek(Token![,]) {
162                input.parse::<Token![,]>()?;
163            } else {
164                break;
165            }
166        }
167
168        Ok(SelectInput { branches })
169    }
170}
171
172/// Select the first future that completes (biased by order).
173///
174/// This macro is powered by the [futures](https://docs.rs/futures) crate
175/// and is not bound to a particular executor or context.
176///
177/// # Fusing
178///
179/// This macro handles both the [fusing](https://docs.rs/futures/latest/futures/future/trait.FutureExt.html#method.fuse)
180/// and [pinning](https://docs.rs/futures/latest/futures/macro.pin_mut.html) of (fused) futures in
181/// a `select`-specific scope.
182///
183/// # Example
184///
185/// ```rust
186/// use std::time::Duration;
187/// use commonware_macros::select;
188/// use futures::executor::block_on;
189/// use futures_timer::Delay;
190///
191/// async fn task() -> usize {
192///     42
193/// }
194//
195/// block_on(async move {
196///     select! {
197///         _ = Delay::new(Duration::from_secs(1)) => {
198///             println!("timeout fired");
199///         },
200///         v = task() => {
201///             println!("task completed with value: {}", v);
202///         },
203///     };
204/// });
205/// ```
206#[proc_macro]
207pub fn select(input: TokenStream) -> TokenStream {
208    // Parse the input tokens
209    let SelectInput { branches } = parse_macro_input!(input as SelectInput);
210
211    // Generate code from provided statements
212    let mut stmts = Vec::new();
213    let mut select_branches = Vec::new();
214    for (
215        index,
216        Branch {
217            pattern,
218            future,
219            block,
220        },
221    ) in branches.into_iter().enumerate()
222    {
223        // Generate a unique identifier for each future
224        let future_ident = Ident::new(&format!("__select_future_{index}"), pattern.span());
225
226        // Fuse and pin each future
227        let stmt = quote! {
228            let #future_ident = (#future).fuse();
229            futures::pin_mut!(#future_ident);
230        };
231        stmts.push(stmt);
232
233        // Generate branch for `select_biased!` macro
234        let branch_code = quote! {
235            #pattern = #future_ident => #block,
236        };
237        select_branches.push(branch_code);
238    }
239
240    // Generate the final output code
241    quote! {
242        {
243            use futures::FutureExt as _;
244            #(#stmts)*
245
246            futures::select_biased! {
247                #(#select_branches)*
248            }
249        }
250    }
251    .into()
252}