commonware_macros/
lib.rs

1//! Augment the development of primitives with procedural macros.
2
3#![doc(
4    html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5    html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7
8use proc_macro::TokenStream;
9use proc_macro2::Span;
10use proc_macro_crate::{crate_name, FoundCrate};
11use quote::{format_ident, quote};
12use syn::{
13    parse::{Parse, ParseStream, Result},
14    parse_macro_input,
15    spanned::Spanned,
16    Block, Error, Expr, Ident, ItemFn, LitStr, Pat, Token,
17};
18
19/// Run a test function asynchronously.
20///
21/// This macro is powered by the [futures](https://docs.rs/futures) crate
22/// and is not bound to a particular executor or context.
23///
24/// # Example
25/// ```rust
26/// use commonware_macros::test_async;
27///
28/// #[test_async]
29/// async fn test_async_fn() {
30///    assert_eq!(2 + 2, 4);
31/// }
32/// ```
33#[proc_macro_attribute]
34pub fn test_async(_: TokenStream, item: TokenStream) -> TokenStream {
35    // Parse the input tokens into a syntax tree
36    let input = parse_macro_input!(item as ItemFn);
37
38    // Extract function components
39    let attrs = input.attrs;
40    let vis = input.vis;
41    let mut sig = input.sig;
42    let block = input.block;
43
44    // Remove 'async' from the function signature (#[test] only
45    // accepts sync functions)
46    sig.asyncness
47        .take()
48        .expect("test_async macro can only be used with async functions");
49
50    // Generate output tokens
51    let expanded = quote! {
52        #[test]
53        #(#attrs)*
54        #vis #sig {
55            futures::executor::block_on(async #block);
56        }
57    };
58    TokenStream::from(expanded)
59}
60
61/// Capture logs (based on the provided log level) from a test run using
62/// [libtest's output capture functionality](https://doc.rust-lang.org/book/ch11-02-running-tests.html#showing-function-output).
63///
64/// This macro defaults to a log level of `DEBUG` if no level is provided.
65///
66/// This macro is powered by the [tracing](https://docs.rs/tracing) and
67/// [tracing-subscriber](https://docs.rs/tracing-subscriber) crates.
68///
69/// # Example
70/// ```rust
71/// use commonware_macros::test_traced;
72/// use tracing::{debug, info};
73///
74/// #[test_traced("INFO")]
75/// fn test_info_level() {
76///     info!("This is an info log");
77///     debug!("This is a debug log (won't be shown)");
78///     assert_eq!(2 + 2, 4);
79/// }
80/// ```
81#[proc_macro_attribute]
82pub fn test_traced(attr: TokenStream, item: TokenStream) -> TokenStream {
83    // Parse the input tokens into a syntax tree
84    let input = parse_macro_input!(item as ItemFn);
85
86    // Parse the attribute argument for log level
87    let log_level = if attr.is_empty() {
88        // Default log level is DEBUG
89        quote! { tracing::Level::DEBUG }
90    } else {
91        // Parse the attribute as a string literal
92        let level_str = parse_macro_input!(attr as LitStr);
93        let level_ident = level_str.value().to_uppercase();
94        match level_ident.as_str() {
95            "TRACE" => quote! { tracing::Level::TRACE },
96            "DEBUG" => quote! { tracing::Level::DEBUG },
97            "INFO" => quote! { tracing::Level::INFO },
98            "WARN" => quote! { tracing::Level::WARN },
99            "ERROR" => quote! { tracing::Level::ERROR },
100            _ => {
101                // Return a compile error for invalid log levels
102                return Error::new_spanned(
103                    level_str,
104                    "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
105                )
106                .to_compile_error()
107                .into();
108            }
109        }
110    };
111
112    // Extract function components
113    let attrs = input.attrs;
114    let vis = input.vis;
115    let sig = input.sig;
116    let block = input.block;
117
118    // Generate output tokens
119    let expanded = quote! {
120        #[test]
121        #(#attrs)*
122        #vis #sig {
123            // Create a subscriber and dispatcher with the specified log level
124            let subscriber = tracing_subscriber::fmt()
125                .with_test_writer()
126                .with_max_level(#log_level)
127                .with_line_number(true)
128                .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
129                .finish();
130            let dispatcher = tracing::Dispatch::new(subscriber);
131
132            // Set the subscriber for the scope of the test
133            tracing::dispatcher::with_default(&dispatcher, || {
134                #block
135            });
136        }
137    };
138    TokenStream::from(expanded)
139}
140
141/// Capture logs from a test run into an in-memory store.
142///
143/// This macro defaults to a log level of `DEBUG` on the [mod@tracing_subscriber::fmt] layer if no level is provided.
144///
145/// This macro is powered by the [tracing](https://docs.rs/tracing),
146/// [tracing-subscriber](https://docs.rs/tracing-subscriber), and
147/// [commonware-runtime](https://docs.rs/commonware-runtime) crates.
148///
149/// # Note
150///
151/// This macro requires the resolution of the `commonware-runtime`, `tracing`, and `tracing_subscriber` crates.
152///
153/// # Example
154/// ```rust,ignore
155/// use commonware_macros::test_collect_traces;
156/// use commonware_runtime::telemetry::traces::collector::TraceStorage;
157/// use tracing::{debug, info};
158///
159/// #[test_collect_traces("INFO")]
160/// fn test_info_level(traces: TraceStorage) {
161///     // Filter applies to console output (FmtLayer)
162///     info!("This is an info log");
163///     debug!("This is a debug log (won't be shown in console output)");
164///
165///     // All traces are collected, regardless of level, by the CollectingLayer.
166///     assert_eq!(traces.get_all().len(), 2);
167/// }
168/// ```
169#[proc_macro_attribute]
170pub fn test_collect_traces(attr: TokenStream, item: TokenStream) -> TokenStream {
171    let input = parse_macro_input!(item as ItemFn);
172
173    // Parse the attribute argument for log level
174    let log_level = if attr.is_empty() {
175        // Default log level is DEBUG
176        quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG }
177    } else {
178        // Parse the attribute as a string literal
179        let level_str = parse_macro_input!(attr as LitStr);
180        let level_ident = level_str.value().to_uppercase();
181        match level_ident.as_str() {
182            "TRACE" => quote! { ::tracing_subscriber::filter::LevelFilter::TRACE },
183            "DEBUG" => quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG },
184            "INFO" => quote! { ::tracing_subscriber::filter::LevelFilter::INFO },
185            "WARN" => quote! { ::tracing_subscriber::filter::LevelFilter::WARN },
186            "ERROR" => quote! { ::tracing_subscriber::filter::LevelFilter::ERROR },
187            _ => {
188                // Return a compile error for invalid log levels
189                return Error::new_spanned(
190                    level_str,
191                    "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
192                )
193                .to_compile_error()
194                .into();
195            }
196        }
197    };
198
199    let attrs = input.attrs;
200    let vis = input.vis;
201    let sig = input.sig;
202    let block = input.block;
203
204    // Create the signature of the inner function that takes the TraceStorage.
205    let inner_ident = format_ident!("__{}_inner_traced", sig.ident);
206    let mut inner_sig = sig.clone();
207    inner_sig.ident = inner_ident.clone();
208
209    // Create the signature of the outer test function.
210    let mut outer_sig = sig;
211    outer_sig.inputs.clear();
212
213    // Detect the path of the `commonware-runtime` crate. If it has been renamed or
214    // this macro is being used within the `commonware-runtime` crate itself, adjust
215    // the path accordingly.
216    let rt_path = match crate_name("commonware-runtime") {
217        Ok(FoundCrate::Itself) => quote!(crate),
218        Ok(FoundCrate::Name(name)) => {
219            let ident = syn::Ident::new(&name, Span::call_site());
220            quote!(#ident)
221        }
222        Err(_) => quote!(::commonware_runtime), // fallback
223    };
224
225    let expanded = quote! {
226        // Inner test function runs the actual test logic, accepting the TraceStorage
227        // created by the harness.
228        #(#attrs)*
229        #vis #inner_sig #block
230
231        #[test]
232        #vis #outer_sig {
233            use ::tracing_subscriber::{Layer, fmt, Registry, layer::SubscriberExt, util::SubscriberInitExt};
234            use ::tracing::{Dispatch, dispatcher};
235            use #rt_path::telemetry::traces::collector::{CollectingLayer, TraceStorage};
236
237            let trace_store = TraceStorage::default();
238            let collecting_layer = CollectingLayer::new(trace_store.clone());
239
240            let fmt_layer = fmt::layer()
241                .with_test_writer()
242                .with_line_number(true)
243                .with_span_events(fmt::format::FmtSpan::CLOSE)
244                .with_filter(#log_level);
245
246            let subscriber = Registry::default().with(collecting_layer).with(fmt_layer);
247            let dispatcher = Dispatch::new(subscriber);
248            dispatcher::with_default(&dispatcher, || {
249                #inner_ident(trace_store);
250            });
251        }
252    };
253
254    TokenStream::from(expanded)
255}
256
257struct SelectInput {
258    branches: Vec<Branch>,
259}
260
261struct Branch {
262    pattern: Pat,
263    future: Expr,
264    block: Block,
265}
266
267impl Parse for SelectInput {
268    fn parse(input: ParseStream) -> Result<Self> {
269        let mut branches = Vec::new();
270
271        while !input.is_empty() {
272            let pattern: Pat = input.parse()?;
273            input.parse::<Token![=]>()?;
274            let future: Expr = input.parse()?;
275            input.parse::<Token![=>]>()?;
276            let block: Block = input.parse()?;
277
278            branches.push(Branch {
279                pattern,
280                future,
281                block,
282            });
283
284            if input.peek(Token![,]) {
285                input.parse::<Token![,]>()?;
286            } else {
287                break;
288            }
289        }
290
291        Ok(SelectInput { branches })
292    }
293}
294
295/// Select the first future that completes (biased by order).
296///
297/// This macro is powered by the [futures](https://docs.rs/futures) crate
298/// and is not bound to a particular executor or context.
299///
300/// # Fusing
301///
302/// This macro handles both the [fusing](https://docs.rs/futures/latest/futures/future/trait.FutureExt.html#method.fuse)
303/// and [pinning](https://docs.rs/futures/latest/futures/macro.pin_mut.html) of (fused) futures in
304/// a `select`-specific scope.
305///
306/// # Example
307///
308/// ```rust
309/// use std::time::Duration;
310/// use commonware_macros::select;
311/// use futures::executor::block_on;
312/// use futures_timer::Delay;
313///
314/// async fn task() -> usize {
315///     42
316/// }
317//
318/// block_on(async move {
319///     select! {
320///         _ = Delay::new(Duration::from_secs(1)) => {
321///             println!("timeout fired");
322///         },
323///         v = task() => {
324///             println!("task completed with value: {}", v);
325///         },
326///     };
327/// });
328/// ```
329#[proc_macro]
330pub fn select(input: TokenStream) -> TokenStream {
331    // Parse the input tokens
332    let SelectInput { branches } = parse_macro_input!(input as SelectInput);
333
334    // Generate code from provided statements
335    let mut stmts = Vec::new();
336    let mut select_branches = Vec::new();
337    for (
338        index,
339        Branch {
340            pattern,
341            future,
342            block,
343        },
344    ) in branches.into_iter().enumerate()
345    {
346        // Generate a unique identifier for each future
347        let future_ident = Ident::new(&format!("__select_future_{index}"), pattern.span());
348
349        // Fuse and pin each future
350        let stmt = quote! {
351            let #future_ident = (#future).fuse();
352            futures::pin_mut!(#future_ident);
353        };
354        stmts.push(stmt);
355
356        // Generate branch for `select_biased!` macro
357        let branch_code = quote! {
358            #pattern = #future_ident => #block,
359        };
360        select_branches.push(branch_code);
361    }
362
363    // Generate the final output code
364    quote! {
365        {
366            use futures::FutureExt as _;
367            #(#stmts)*
368
369            futures::select_biased! {
370                #(#select_branches)*
371            }
372        }
373    }
374    .into()
375}