commonware_macros/lib.rs
1//! Augment the development of primitives with procedural macros.
2
3use proc_macro::TokenStream;
4use quote::quote;
5use syn::{
6 parse::{Parse, ParseStream, Result},
7 parse_macro_input,
8 spanned::Spanned,
9 Block, Error, Expr, Ident, ItemFn, LitStr, Pat, Token,
10};
11
12/// Run a test function asynchronously.
13///
14/// This macro is powered by the [futures](https://docs.rs/futures) crate
15/// and is not bound to a particular executor or context.
16///
17/// # Example
18/// ```rust
19/// use commonware_macros::test_async;
20///
21/// #[test_async]
22/// async fn test_async_fn() {
23/// assert_eq!(2 + 2, 4);
24/// }
25/// ```
26#[proc_macro_attribute]
27pub fn test_async(_: TokenStream, item: TokenStream) -> TokenStream {
28 // Parse the input tokens into a syntax tree
29 let input = parse_macro_input!(item as ItemFn);
30
31 // Extract function components
32 let attrs = input.attrs;
33 let vis = input.vis;
34 let mut sig = input.sig;
35 let block = input.block;
36
37 // Remove 'async' from the function signature (#[test] only
38 // accepts sync functions)
39 sig.asyncness
40 .take()
41 .expect("test_async macro can only be used with async functions");
42
43 // Generate output tokens
44 let expanded = quote! {
45 #[test]
46 #(#attrs)*
47 #vis #sig {
48 futures::executor::block_on(async #block);
49 }
50 };
51 TokenStream::from(expanded)
52}
53
54/// Capture logs (based on the provided log level) from a test run using
55/// [libtest's output capture functionality](https://doc.rust-lang.org/book/ch11-02-running-tests.html#showing-function-output).
56///
57/// This macro defaults to a log level of `DEBUG` if no level is provided.
58///
59/// This macro is powered by the [tracing](https://docs.rs/tracing) and
60/// [tracing-subscriber](https://docs.rs/tracing-subscriber) crates.
61///
62/// # Example
63/// ```rust
64/// use commonware_macros::test_traced;
65/// use tracing::{debug, info};
66///
67/// #[test_traced("INFO")]
68/// fn test_info_level() {
69/// info!("This is an info log");
70/// debug!("This is a debug log (won't be shown)");
71/// assert_eq!(2 + 2, 4);
72/// }
73/// ```
74#[proc_macro_attribute]
75pub fn test_traced(attr: TokenStream, item: TokenStream) -> TokenStream {
76 // Parse the input tokens into a syntax tree
77 let input = parse_macro_input!(item as ItemFn);
78
79 // Parse the attribute argument for log level
80 let log_level = if attr.is_empty() {
81 // Default log level is DEBUG
82 quote! { tracing::Level::DEBUG }
83 } else {
84 // Parse the attribute as a string literal
85 let level_str = parse_macro_input!(attr as LitStr);
86 let level_ident = level_str.value().to_uppercase();
87 match level_ident.as_str() {
88 "TRACE" => quote! { tracing::Level::TRACE },
89 "DEBUG" => quote! { tracing::Level::DEBUG },
90 "INFO" => quote! { tracing::Level::INFO },
91 "WARN" => quote! { tracing::Level::WARN },
92 "ERROR" => quote! { tracing::Level::ERROR },
93 _ => {
94 // Return a compile error for invalid log levels
95 return Error::new_spanned(
96 level_str,
97 "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
98 )
99 .to_compile_error()
100 .into();
101 }
102 }
103 };
104
105 // Extract function components
106 let attrs = input.attrs;
107 let vis = input.vis;
108 let sig = input.sig;
109 let block = input.block;
110
111 // Generate output tokens
112 let expanded = quote! {
113 #[test]
114 #(#attrs)*
115 #vis #sig {
116 // Create a subscriber and dispatcher with the specified log level
117 let subscriber = tracing_subscriber::fmt()
118 .with_test_writer()
119 .with_max_level(#log_level)
120 .with_line_number(true)
121 .finish();
122 let dispatcher = tracing::Dispatch::new(subscriber);
123
124 // Set the subscriber for the scope of the test
125 tracing::dispatcher::with_default(&dispatcher, || {
126 #block
127 });
128 }
129 };
130 TokenStream::from(expanded)
131}
132
133struct SelectInput {
134 branches: Vec<Branch>,
135}
136
137struct Branch {
138 pattern: Pat,
139 future: Expr,
140 block: Block,
141}
142
143impl Parse for SelectInput {
144 fn parse(input: ParseStream) -> Result<Self> {
145 let mut branches = Vec::new();
146
147 while !input.is_empty() {
148 let pattern: Pat = input.parse()?;
149 input.parse::<Token![=]>()?;
150 let future: Expr = input.parse()?;
151 input.parse::<Token![=>]>()?;
152 let block: Block = input.parse()?;
153
154 branches.push(Branch {
155 pattern,
156 future,
157 block,
158 });
159
160 if input.peek(Token![,]) {
161 input.parse::<Token![,]>()?;
162 } else {
163 break;
164 }
165 }
166
167 Ok(SelectInput { branches })
168 }
169}
170
171/// Select the first future that completes (biased by order).
172///
173/// This macro is powered by the [futures](https://docs.rs/futures) crate
174/// and is not bound to a particular executor or context.
175///
176/// # Fusing
177///
178/// This macro handles both the [fusing](https://docs.rs/futures/latest/futures/future/trait.FutureExt.html#method.fuse)
179/// and [pinning](https://docs.rs/futures/latest/futures/macro.pin_mut.html) of (fused) futures in
180/// a `select`-specific scope.
181///
182/// # Example
183///
184/// ```rust
185/// use std::time::Duration;
186/// use commonware_macros::select;
187/// use futures::executor::block_on;
188/// use futures_timer::Delay;
189///
190/// async fn task() -> usize {
191/// 42
192/// }
193//
194/// block_on(async move {
195/// select! {
196/// _ = Delay::new(Duration::from_secs(1)) => {
197/// println!("timeout fired");
198/// },
199/// v = task() => {
200/// println!("task completed with value: {}", v);
201/// },
202/// };
203/// });
204/// ```
205#[proc_macro]
206pub fn select(input: TokenStream) -> TokenStream {
207 // Parse the input tokens
208 let SelectInput { branches } = parse_macro_input!(input as SelectInput);
209
210 // Generate code from provided statements
211 let mut stmts = Vec::new();
212 let mut select_branches = Vec::new();
213 for (
214 index,
215 Branch {
216 pattern,
217 future,
218 block,
219 },
220 ) in branches.into_iter().enumerate()
221 {
222 // Generate a unique identifier for each future
223 let future_ident = Ident::new(&format!("__select_future_{}", index), pattern.span());
224
225 // Fuse and pin each future
226 let stmt = quote! {
227 let #future_ident = (#future).fuse();
228 futures::pin_mut!(#future_ident);
229 };
230 stmts.push(stmt);
231
232 // Generate branch for `select_biased!` macro
233 let branch_code = quote! {
234 #pattern = #future_ident => #block,
235 };
236 select_branches.push(branch_code);
237 }
238
239 // Generate the final output code
240 quote! {
241 {
242 use futures::FutureExt as _;
243 #(#stmts)*
244
245 futures::select_biased! {
246 #(#select_branches)*
247 }
248 }
249 }
250 .into()
251}