commonware_macros/
lib.rs

1//! Augment the development of primitives with procedural macros.
2
3#![doc(
4    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5    html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::{
11    parse::{Parse, ParseStream, Result},
12    parse_macro_input,
13    spanned::Spanned,
14    Block, Error, Expr, Ident, ItemFn, LitStr, Pat, Token,
15};
16
17/// Run a test function asynchronously.
18///
19/// This macro is powered by the [futures](https://docs.rs/futures) crate
20/// and is not bound to a particular executor or context.
21///
22/// # Example
23/// ```rust
24/// use commonware_macros::test_async;
25///
26/// #[test_async]
27/// async fn test_async_fn() {
28///    assert_eq!(2 + 2, 4);
29/// }
30/// ```
31#[proc_macro_attribute]
32pub fn test_async(_: TokenStream, item: TokenStream) -> TokenStream {
33    // Parse the input tokens into a syntax tree
34    let input = parse_macro_input!(item as ItemFn);
35
36    // Extract function components
37    let attrs = input.attrs;
38    let vis = input.vis;
39    let mut sig = input.sig;
40    let block = input.block;
41
42    // Remove 'async' from the function signature (#[test] only
43    // accepts sync functions)
44    sig.asyncness
45        .take()
46        .expect("test_async macro can only be used with async functions");
47
48    // Generate output tokens
49    let expanded = quote! {
50        #[test]
51        #(#attrs)*
52        #vis #sig {
53            futures::executor::block_on(async #block);
54        }
55    };
56    TokenStream::from(expanded)
57}
58
59/// Capture logs (based on the provided log level) from a test run using
60/// [libtest's output capture functionality](https://doc.rust-lang.org/book/ch11-02-running-tests.html#showing-function-output).
61///
62/// This macro defaults to a log level of `DEBUG` if no level is provided.
63///
64/// This macro is powered by the [tracing](https://docs.rs/tracing) and
65/// [tracing-subscriber](https://docs.rs/tracing-subscriber) crates.
66///
67/// # Example
68/// ```rust
69/// use commonware_macros::test_traced;
70/// use tracing::{debug, info};
71///
72/// #[test_traced("INFO")]
73/// fn test_info_level() {
74///     info!("This is an info log");
75///     debug!("This is a debug log (won't be shown)");
76///     assert_eq!(2 + 2, 4);
77/// }
78/// ```
79#[proc_macro_attribute]
80pub fn test_traced(attr: TokenStream, item: TokenStream) -> TokenStream {
81    // Parse the input tokens into a syntax tree
82    let input = parse_macro_input!(item as ItemFn);
83
84    // Parse the attribute argument for log level
85    let log_level = if attr.is_empty() {
86        // Default log level is DEBUG
87        quote! { tracing::Level::DEBUG }
88    } else {
89        // Parse the attribute as a string literal
90        let level_str = parse_macro_input!(attr as LitStr);
91        let level_ident = level_str.value().to_uppercase();
92        match level_ident.as_str() {
93            "TRACE" => quote! { tracing::Level::TRACE },
94            "DEBUG" => quote! { tracing::Level::DEBUG },
95            "INFO" => quote! { tracing::Level::INFO },
96            "WARN" => quote! { tracing::Level::WARN },
97            "ERROR" => quote! { tracing::Level::ERROR },
98            _ => {
99                // Return a compile error for invalid log levels
100                return Error::new_spanned(
101                    level_str,
102                    "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
103                )
104                .to_compile_error()
105                .into();
106            }
107        }
108    };
109
110    // Extract function components
111    let attrs = input.attrs;
112    let vis = input.vis;
113    let sig = input.sig;
114    let block = input.block;
115
116    // Generate output tokens
117    let expanded = quote! {
118        #[test]
119        #(#attrs)*
120        #vis #sig {
121            // Create a subscriber and dispatcher with the specified log level
122            let subscriber = tracing_subscriber::fmt()
123                .with_test_writer()
124                .with_max_level(#log_level)
125                .with_line_number(true)
126                .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
127                .finish();
128            let dispatcher = tracing::Dispatch::new(subscriber);
129
130            // Set the subscriber for the scope of the test
131            tracing::dispatcher::with_default(&dispatcher, || {
132                #block
133            });
134        }
135    };
136    TokenStream::from(expanded)
137}
138
139struct SelectInput {
140    branches: Vec<Branch>,
141}
142
143struct Branch {
144    pattern: Pat,
145    future: Expr,
146    block: Block,
147}
148
149impl Parse for SelectInput {
150    fn parse(input: ParseStream) -> Result<Self> {
151        let mut branches = Vec::new();
152
153        while !input.is_empty() {
154            let pattern: Pat = input.parse()?;
155            input.parse::<Token![=]>()?;
156            let future: Expr = input.parse()?;
157            input.parse::<Token![=>]>()?;
158            let block: Block = input.parse()?;
159
160            branches.push(Branch {
161                pattern,
162                future,
163                block,
164            });
165
166            if input.peek(Token![,]) {
167                input.parse::<Token![,]>()?;
168            } else {
169                break;
170            }
171        }
172
173        Ok(SelectInput { branches })
174    }
175}
176
177/// Select the first future that completes (biased by order).
178///
179/// This macro is powered by the [futures](https://docs.rs/futures) crate
180/// and is not bound to a particular executor or context.
181///
182/// # Fusing
183///
184/// This macro handles both the [fusing](https://docs.rs/futures/latest/futures/future/trait.FutureExt.html#method.fuse)
185/// and [pinning](https://docs.rs/futures/latest/futures/macro.pin_mut.html) of (fused) futures in
186/// a `select`-specific scope.
187///
188/// # Example
189///
190/// ```rust
191/// use std::time::Duration;
192/// use commonware_macros::select;
193/// use futures::executor::block_on;
194/// use futures_timer::Delay;
195///
196/// async fn task() -> usize {
197///     42
198/// }
199//
200/// block_on(async move {
201///     select! {
202///         _ = Delay::new(Duration::from_secs(1)) => {
203///             println!("timeout fired");
204///         },
205///         v = task() => {
206///             println!("task completed with value: {}", v);
207///         },
208///     };
209/// });
210/// ```
211#[proc_macro]
212pub fn select(input: TokenStream) -> TokenStream {
213    // Parse the input tokens
214    let SelectInput { branches } = parse_macro_input!(input as SelectInput);
215
216    // Generate code from provided statements
217    let mut stmts = Vec::new();
218    let mut select_branches = Vec::new();
219    for (
220        index,
221        Branch {
222            pattern,
223            future,
224            block,
225        },
226    ) in branches.into_iter().enumerate()
227    {
228        // Generate a unique identifier for each future
229        let future_ident = Ident::new(&format!("__select_future_{index}"), pattern.span());
230
231        // Fuse and pin each future
232        let stmt = quote! {
233            let #future_ident = (#future).fuse();
234            futures::pin_mut!(#future_ident);
235        };
236        stmts.push(stmt);
237
238        // Generate branch for `select_biased!` macro
239        let branch_code = quote! {
240            #pattern = #future_ident => #block,
241        };
242        select_branches.push(branch_code);
243    }
244
245    // Generate the final output code
246    quote! {
247        {
248            use futures::FutureExt as _;
249            #(#stmts)*
250
251            futures::select_biased! {
252                #(#select_branches)*
253            }
254        }
255    }
256    .into()
257}