commonware_macros/lib.rs
1//! Augment the development of primitives with procedural macros.
2
3#![doc(
4 html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5 html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7
8use proc_macro::TokenStream;
9use quote::quote;
10use syn::{
11 parse::{Parse, ParseStream, Result},
12 parse_macro_input,
13 spanned::Spanned,
14 Block, Error, Expr, Ident, ItemFn, LitStr, Pat, Token,
15};
16
17/// Run a test function asynchronously.
18///
19/// This macro is powered by the [futures](https://docs.rs/futures) crate
20/// and is not bound to a particular executor or context.
21///
22/// # Example
23/// ```rust
24/// use commonware_macros::test_async;
25///
26/// #[test_async]
27/// async fn test_async_fn() {
28/// assert_eq!(2 + 2, 4);
29/// }
30/// ```
31#[proc_macro_attribute]
32pub fn test_async(_: TokenStream, item: TokenStream) -> TokenStream {
33 // Parse the input tokens into a syntax tree
34 let input = parse_macro_input!(item as ItemFn);
35
36 // Extract function components
37 let attrs = input.attrs;
38 let vis = input.vis;
39 let mut sig = input.sig;
40 let block = input.block;
41
42 // Remove 'async' from the function signature (#[test] only
43 // accepts sync functions)
44 sig.asyncness
45 .take()
46 .expect("test_async macro can only be used with async functions");
47
48 // Generate output tokens
49 let expanded = quote! {
50 #[test]
51 #(#attrs)*
52 #vis #sig {
53 futures::executor::block_on(async #block);
54 }
55 };
56 TokenStream::from(expanded)
57}
58
59/// Capture logs (based on the provided log level) from a test run using
60/// [libtest's output capture functionality](https://doc.rust-lang.org/book/ch11-02-running-tests.html#showing-function-output).
61///
62/// This macro defaults to a log level of `DEBUG` if no level is provided.
63///
64/// This macro is powered by the [tracing](https://docs.rs/tracing) and
65/// [tracing-subscriber](https://docs.rs/tracing-subscriber) crates.
66///
67/// # Example
68/// ```rust
69/// use commonware_macros::test_traced;
70/// use tracing::{debug, info};
71///
72/// #[test_traced("INFO")]
73/// fn test_info_level() {
74/// info!("This is an info log");
75/// debug!("This is a debug log (won't be shown)");
76/// assert_eq!(2 + 2, 4);
77/// }
78/// ```
79#[proc_macro_attribute]
80pub fn test_traced(attr: TokenStream, item: TokenStream) -> TokenStream {
81 // Parse the input tokens into a syntax tree
82 let input = parse_macro_input!(item as ItemFn);
83
84 // Parse the attribute argument for log level
85 let log_level = if attr.is_empty() {
86 // Default log level is DEBUG
87 quote! { tracing::Level::DEBUG }
88 } else {
89 // Parse the attribute as a string literal
90 let level_str = parse_macro_input!(attr as LitStr);
91 let level_ident = level_str.value().to_uppercase();
92 match level_ident.as_str() {
93 "TRACE" => quote! { tracing::Level::TRACE },
94 "DEBUG" => quote! { tracing::Level::DEBUG },
95 "INFO" => quote! { tracing::Level::INFO },
96 "WARN" => quote! { tracing::Level::WARN },
97 "ERROR" => quote! { tracing::Level::ERROR },
98 _ => {
99 // Return a compile error for invalid log levels
100 return Error::new_spanned(
101 level_str,
102 "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
103 )
104 .to_compile_error()
105 .into();
106 }
107 }
108 };
109
110 // Extract function components
111 let attrs = input.attrs;
112 let vis = input.vis;
113 let sig = input.sig;
114 let block = input.block;
115
116 // Generate output tokens
117 let expanded = quote! {
118 #[test]
119 #(#attrs)*
120 #vis #sig {
121 // Create a subscriber and dispatcher with the specified log level
122 let subscriber = tracing_subscriber::fmt()
123 .with_test_writer()
124 .with_max_level(#log_level)
125 .with_line_number(true)
126 .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
127 .finish();
128 let dispatcher = tracing::Dispatch::new(subscriber);
129
130 // Set the subscriber for the scope of the test
131 tracing::dispatcher::with_default(&dispatcher, || {
132 #block
133 });
134 }
135 };
136 TokenStream::from(expanded)
137}
138
139struct SelectInput {
140 branches: Vec<Branch>,
141}
142
143struct Branch {
144 pattern: Pat,
145 future: Expr,
146 block: Block,
147}
148
149impl Parse for SelectInput {
150 fn parse(input: ParseStream) -> Result<Self> {
151 let mut branches = Vec::new();
152
153 while !input.is_empty() {
154 let pattern: Pat = input.parse()?;
155 input.parse::<Token![=]>()?;
156 let future: Expr = input.parse()?;
157 input.parse::<Token![=>]>()?;
158 let block: Block = input.parse()?;
159
160 branches.push(Branch {
161 pattern,
162 future,
163 block,
164 });
165
166 if input.peek(Token![,]) {
167 input.parse::<Token![,]>()?;
168 } else {
169 break;
170 }
171 }
172
173 Ok(SelectInput { branches })
174 }
175}
176
177/// Select the first future that completes (biased by order).
178///
179/// This macro is powered by the [futures](https://docs.rs/futures) crate
180/// and is not bound to a particular executor or context.
181///
182/// # Fusing
183///
184/// This macro handles both the [fusing](https://docs.rs/futures/latest/futures/future/trait.FutureExt.html#method.fuse)
185/// and [pinning](https://docs.rs/futures/latest/futures/macro.pin_mut.html) of (fused) futures in
186/// a `select`-specific scope.
187///
188/// # Example
189///
190/// ```rust
191/// use std::time::Duration;
192/// use commonware_macros::select;
193/// use futures::executor::block_on;
194/// use futures_timer::Delay;
195///
196/// async fn task() -> usize {
197/// 42
198/// }
199//
200/// block_on(async move {
201/// select! {
202/// _ = Delay::new(Duration::from_secs(1)) => {
203/// println!("timeout fired");
204/// },
205/// v = task() => {
206/// println!("task completed with value: {}", v);
207/// },
208/// };
209/// });
210/// ```
211#[proc_macro]
212pub fn select(input: TokenStream) -> TokenStream {
213 // Parse the input tokens
214 let SelectInput { branches } = parse_macro_input!(input as SelectInput);
215
216 // Generate code from provided statements
217 let mut stmts = Vec::new();
218 let mut select_branches = Vec::new();
219 for (
220 index,
221 Branch {
222 pattern,
223 future,
224 block,
225 },
226 ) in branches.into_iter().enumerate()
227 {
228 // Generate a unique identifier for each future
229 let future_ident = Ident::new(&format!("__select_future_{index}"), pattern.span());
230
231 // Fuse and pin each future
232 let stmt = quote! {
233 let #future_ident = (#future).fuse();
234 futures::pin_mut!(#future_ident);
235 };
236 stmts.push(stmt);
237
238 // Generate branch for `select_biased!` macro
239 let branch_code = quote! {
240 #pattern = #future_ident => #block,
241 };
242 select_branches.push(branch_code);
243 }
244
245 // Generate the final output code
246 quote! {
247 {
248 use futures::FutureExt as _;
249 #(#stmts)*
250
251 futures::select_biased! {
252 #(#select_branches)*
253 }
254 }
255 }
256 .into()
257}