commonware_macros/lib.rs
1//! Augment the development of primitives with procedural macros.
2
3#![doc(
4 html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
5 html_favicon_url = "https://commonware.xyz/favicon.ico"
6)]
7
8use crate::nextest::configured_test_groups;
9use proc_macro::TokenStream;
10use proc_macro2::Span;
11use proc_macro_crate::{crate_name, FoundCrate};
12use quote::{format_ident, quote, ToTokens};
13use syn::{
14 parse::{Parse, ParseStream, Result},
15 parse_macro_input, Block, Error, Expr, Ident, ItemFn, LitStr, Pat, Token,
16};
17
18mod nextest;
19
20/// Run a test function asynchronously.
21///
22/// This macro is powered by the [futures](https://docs.rs/futures) crate
23/// and is not bound to a particular executor or context.
24///
25/// # Example
26/// ```rust
27/// use commonware_macros::test_async;
28///
29/// #[test_async]
30/// async fn test_async_fn() {
31/// assert_eq!(2 + 2, 4);
32/// }
33/// ```
34#[proc_macro_attribute]
35pub fn test_async(_: TokenStream, item: TokenStream) -> TokenStream {
36 // Parse the input tokens into a syntax tree
37 let input = parse_macro_input!(item as ItemFn);
38
39 // Extract function components
40 let attrs = input.attrs;
41 let vis = input.vis;
42 let mut sig = input.sig;
43 let block = input.block;
44
45 // Remove 'async' from the function signature (#[test] only
46 // accepts sync functions)
47 sig.asyncness
48 .take()
49 .expect("test_async macro can only be used with async functions");
50
51 // Generate output tokens
52 let expanded = quote! {
53 #[test]
54 #(#attrs)*
55 #vis #sig {
56 futures::executor::block_on(async #block);
57 }
58 };
59 TokenStream::from(expanded)
60}
61
62/// Capture logs (based on the provided log level) from a test run using
63/// [libtest's output capture functionality](https://doc.rust-lang.org/book/ch11-02-running-tests.html#showing-function-output).
64///
65/// This macro defaults to a log level of `DEBUG` if no level is provided.
66///
67/// This macro is powered by the [tracing](https://docs.rs/tracing) and
68/// [tracing-subscriber](https://docs.rs/tracing-subscriber) crates.
69///
70/// # Example
71/// ```rust
72/// use commonware_macros::test_traced;
73/// use tracing::{debug, info};
74///
75/// #[test_traced("INFO")]
76/// fn test_info_level() {
77/// info!("This is an info log");
78/// debug!("This is a debug log (won't be shown)");
79/// assert_eq!(2 + 2, 4);
80/// }
81/// ```
82#[proc_macro_attribute]
83pub fn test_traced(attr: TokenStream, item: TokenStream) -> TokenStream {
84 // Parse the input tokens into a syntax tree
85 let input = parse_macro_input!(item as ItemFn);
86
87 // Parse the attribute argument for log level
88 let log_level = if attr.is_empty() {
89 // Default log level is DEBUG
90 quote! { tracing::Level::DEBUG }
91 } else {
92 // Parse the attribute as a string literal
93 let level_str = parse_macro_input!(attr as LitStr);
94 let level_ident = level_str.value().to_uppercase();
95 match level_ident.as_str() {
96 "TRACE" => quote! { tracing::Level::TRACE },
97 "DEBUG" => quote! { tracing::Level::DEBUG },
98 "INFO" => quote! { tracing::Level::INFO },
99 "WARN" => quote! { tracing::Level::WARN },
100 "ERROR" => quote! { tracing::Level::ERROR },
101 _ => {
102 // Return a compile error for invalid log levels
103 return Error::new_spanned(
104 level_str,
105 "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
106 )
107 .to_compile_error()
108 .into();
109 }
110 }
111 };
112
113 // Extract function components
114 let attrs = input.attrs;
115 let vis = input.vis;
116 let sig = input.sig;
117 let block = input.block;
118
119 // Generate output tokens
120 let expanded = quote! {
121 #[test]
122 #(#attrs)*
123 #vis #sig {
124 // Create a subscriber and dispatcher with the specified log level
125 let subscriber = tracing_subscriber::fmt()
126 .with_test_writer()
127 .with_max_level(#log_level)
128 .with_line_number(true)
129 .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
130 .finish();
131 let dispatcher = tracing::Dispatch::new(subscriber);
132
133 // Set the subscriber for the scope of the test
134 tracing::dispatcher::with_default(&dispatcher, || {
135 #block
136 });
137 }
138 };
139 TokenStream::from(expanded)
140}
141
142/// Prefix a test name with a nextest filter group.
143///
144/// This renames `test_some_behavior` into `test_some_behavior_<group>_`, making
145/// it easy to filter tests by group postfixes in nextest.
146#[proc_macro_attribute]
147pub fn test_group(attr: TokenStream, item: TokenStream) -> TokenStream {
148 if attr.is_empty() {
149 return Error::new(
150 Span::call_site(),
151 "test_group requires a string literal filter group name",
152 )
153 .to_compile_error()
154 .into();
155 }
156
157 let mut input = parse_macro_input!(item as ItemFn);
158 let group_literal = parse_macro_input!(attr as LitStr);
159
160 let group = match nextest::sanitize_group_literal(&group_literal) {
161 Ok(group) => group,
162 Err(err) => return err.to_compile_error().into(),
163 };
164 let groups = match configured_test_groups() {
165 Ok(groups) => groups,
166 Err(_) => {
167 // Don't fail the compilation if the file isn't found; just return the original input.
168 return TokenStream::from(quote!(#input));
169 }
170 };
171
172 if let Err(err) = nextest::ensure_group_known(groups, &group, group_literal.span()) {
173 return err.to_compile_error().into();
174 }
175
176 let original_name = input.sig.ident.to_string();
177 let new_ident = Ident::new(&format!("{original_name}_{group}_"), input.sig.ident.span());
178
179 input.sig.ident = new_ident;
180
181 TokenStream::from(quote!(#input))
182}
183
184/// Capture logs from a test run into an in-memory store.
185///
186/// This macro defaults to a log level of `DEBUG` on the [mod@tracing_subscriber::fmt] layer if no level is provided.
187///
188/// This macro is powered by the [tracing](https://docs.rs/tracing),
189/// [tracing-subscriber](https://docs.rs/tracing-subscriber), and
190/// [commonware-runtime](https://docs.rs/commonware-runtime) crates.
191///
192/// # Note
193///
194/// This macro requires the resolution of the `commonware-runtime`, `tracing`, and `tracing_subscriber` crates.
195///
196/// # Example
197/// ```rust,ignore
198/// use commonware_macros::test_collect_traces;
199/// use commonware_runtime::telemetry::traces::collector::TraceStorage;
200/// use tracing::{debug, info};
201///
202/// #[test_collect_traces("INFO")]
203/// fn test_info_level(traces: TraceStorage) {
204/// // Filter applies to console output (FmtLayer)
205/// info!("This is an info log");
206/// debug!("This is a debug log (won't be shown in console output)");
207///
208/// // All traces are collected, regardless of level, by the CollectingLayer.
209/// assert_eq!(traces.get_all().len(), 2);
210/// }
211/// ```
212#[proc_macro_attribute]
213pub fn test_collect_traces(attr: TokenStream, item: TokenStream) -> TokenStream {
214 let input = parse_macro_input!(item as ItemFn);
215
216 // Parse the attribute argument for log level
217 let log_level = if attr.is_empty() {
218 // Default log level is DEBUG
219 quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG }
220 } else {
221 // Parse the attribute as a string literal
222 let level_str = parse_macro_input!(attr as LitStr);
223 let level_ident = level_str.value().to_uppercase();
224 match level_ident.as_str() {
225 "TRACE" => quote! { ::tracing_subscriber::filter::LevelFilter::TRACE },
226 "DEBUG" => quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG },
227 "INFO" => quote! { ::tracing_subscriber::filter::LevelFilter::INFO },
228 "WARN" => quote! { ::tracing_subscriber::filter::LevelFilter::WARN },
229 "ERROR" => quote! { ::tracing_subscriber::filter::LevelFilter::ERROR },
230 _ => {
231 // Return a compile error for invalid log levels
232 return Error::new_spanned(
233 level_str,
234 "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
235 )
236 .to_compile_error()
237 .into();
238 }
239 }
240 };
241
242 let attrs = input.attrs;
243 let vis = input.vis;
244 let sig = input.sig;
245 let block = input.block;
246
247 // Create the signature of the inner function that takes the TraceStorage.
248 let inner_ident = format_ident!("__{}_inner_traced", sig.ident);
249 let mut inner_sig = sig.clone();
250 inner_sig.ident = inner_ident.clone();
251
252 // Create the signature of the outer test function.
253 let mut outer_sig = sig;
254 outer_sig.inputs.clear();
255
256 // Detect the path of the `commonware-runtime` crate. If it has been renamed or
257 // this macro is being used within the `commonware-runtime` crate itself, adjust
258 // the path accordingly.
259 let rt_path = match crate_name("commonware-runtime") {
260 Ok(FoundCrate::Itself) => quote!(crate),
261 Ok(FoundCrate::Name(name)) => {
262 let ident = syn::Ident::new(&name, Span::call_site());
263 quote!(#ident)
264 }
265 Err(_) => quote!(::commonware_runtime), // fallback
266 };
267
268 let expanded = quote! {
269 // Inner test function runs the actual test logic, accepting the TraceStorage
270 // created by the harness.
271 #(#attrs)*
272 #vis #inner_sig #block
273
274 #[test]
275 #vis #outer_sig {
276 use ::tracing_subscriber::{Layer, fmt, Registry, layer::SubscriberExt, util::SubscriberInitExt};
277 use ::tracing::{Dispatch, dispatcher};
278 use #rt_path::telemetry::traces::collector::{CollectingLayer, TraceStorage};
279
280 let trace_store = TraceStorage::default();
281 let collecting_layer = CollectingLayer::new(trace_store.clone());
282
283 let fmt_layer = fmt::layer()
284 .with_test_writer()
285 .with_line_number(true)
286 .with_span_events(fmt::format::FmtSpan::CLOSE)
287 .with_filter(#log_level);
288
289 let subscriber = Registry::default().with(collecting_layer).with(fmt_layer);
290 let dispatcher = Dispatch::new(subscriber);
291 dispatcher::with_default(&dispatcher, || {
292 #inner_ident(trace_store);
293 });
294 }
295 };
296
297 TokenStream::from(expanded)
298}
299
300struct SelectInput {
301 branches: Vec<Branch>,
302}
303
304struct Branch {
305 pattern: Pat,
306 future: Expr,
307 block: Block,
308}
309
310impl Parse for SelectInput {
311 fn parse(input: ParseStream<'_>) -> Result<Self> {
312 let mut branches = Vec::new();
313
314 while !input.is_empty() {
315 let pattern = Pat::parse_single(input)?;
316 input.parse::<Token![=]>()?;
317 let future: Expr = input.parse()?;
318 input.parse::<Token![=>]>()?;
319 let block: Block = input.parse()?;
320
321 branches.push(Branch {
322 pattern,
323 future,
324 block,
325 });
326
327 if input.peek(Token![,]) {
328 input.parse::<Token![,]>()?;
329 } else {
330 break;
331 }
332 }
333
334 Ok(Self { branches })
335 }
336}
337
338impl ToTokens for SelectInput {
339 fn to_tokens(&self, tokens: &mut proc_macro2::TokenStream) {
340 for branch in &self.branches {
341 let pattern = &branch.pattern;
342 let future = &branch.future;
343 let block = &branch.block;
344
345 tokens.extend(quote! {
346 #pattern = #future => #block,
347 });
348 }
349 }
350}
351
352/// Select the first future that completes (biased by order).
353///
354/// This macro is powered by the [futures](https://docs.rs/futures) crate
355/// and is not bound to a particular executor or context.
356///
357/// # Fusing
358///
359/// This macro handles the [fusing](https://docs.rs/futures/latest/futures/future/trait.FutureExt.html#method.fuse)
360/// futures in a `select`-specific scope.
361///
362/// # Example
363///
364/// ```rust
365/// use std::time::Duration;
366/// use commonware_macros::select;
367/// use futures::executor::block_on;
368/// use futures_timer::Delay;
369///
370/// async fn task() -> usize {
371/// 42
372/// }
373//
374/// block_on(async move {
375/// select! {
376/// _ = Delay::new(Duration::from_secs(1)) => {
377/// println!("timeout fired");
378/// },
379/// v = task() => {
380/// println!("task completed with value: {}", v);
381/// },
382/// };
383/// });
384/// ```
385#[proc_macro]
386pub fn select(input: TokenStream) -> TokenStream {
387 // Parse the input tokens
388 let SelectInput { branches } = parse_macro_input!(input as SelectInput);
389
390 // Generate code from provided statements
391 let mut select_branches = Vec::new();
392 for Branch {
393 pattern,
394 future,
395 block,
396 } in branches.into_iter()
397 {
398 // Generate branch for `select_biased!` macro
399 let branch_code = quote! {
400 #pattern = (#future).fuse() => #block,
401 };
402 select_branches.push(branch_code);
403 }
404
405 // Generate the final output code
406 quote! {
407 {
408 use futures::FutureExt as _;
409
410 futures::select_biased! {
411 #(#select_branches)*
412 }
413 }
414 }
415 .into()
416}
417
418/// Input for [select_loop!].
419///
420/// Parses: `context, on_stopped => { block }, { branches... }`
421struct SelectLoopInput {
422 context: Expr,
423 shutdown_block: Block,
424 branches: Vec<Branch>,
425}
426
427impl Parse for SelectLoopInput {
428 fn parse(input: ParseStream<'_>) -> Result<Self> {
429 // Parse context expression
430 let context: Expr = input.parse()?;
431 input.parse::<Token![,]>()?;
432
433 // Parse `on_stopped =>`
434 let on_stopped_ident: Ident = input.parse()?;
435 if on_stopped_ident != "on_stopped" {
436 return Err(Error::new(
437 on_stopped_ident.span(),
438 "expected `on_stopped` keyword",
439 ));
440 }
441 input.parse::<Token![=>]>()?;
442
443 // Parse shutdown block
444 let shutdown_block: Block = input.parse()?;
445
446 // Parse comma after shutdown block
447 input.parse::<Token![,]>()?;
448
449 // Parse branches directly (no surrounding braces)
450 let mut branches = Vec::new();
451 while !input.is_empty() {
452 let pattern = Pat::parse_single(input)?;
453 input.parse::<Token![=]>()?;
454 let future: Expr = input.parse()?;
455 input.parse::<Token![=>]>()?;
456 let block: Block = input.parse()?;
457
458 branches.push(Branch {
459 pattern,
460 future,
461 block,
462 });
463
464 if input.peek(Token![,]) {
465 input.parse::<Token![,]>()?;
466 } else {
467 break;
468 }
469 }
470
471 Ok(Self {
472 context,
473 shutdown_block,
474 branches,
475 })
476 }
477}
478
479/// Convenience macro to continuously [select!] over a set of futures in biased order,
480/// with a required shutdown handler.
481///
482/// This macro automatically creates a shutdown future from the provided context and requires a
483/// shutdown handler block. The shutdown future is created outside the loop, allowing it to
484/// persist across iterations until shutdown is signaled. The shutdown branch is always checked
485/// first (biased).
486///
487/// After the shutdown block is executed, the loop breaks by default. If different control flow
488/// is desired (such as returning from the enclosing function), it must be handled explicitly.
489///
490/// # Syntax
491///
492/// ```rust,ignore
493/// select_loop! {
494/// context,
495/// on_stopped => { cleanup },
496/// pattern = future => block,
497/// // ...
498/// }
499/// ```
500///
501/// The `shutdown` variable (the future from `context.stopped()`) is accessible in the
502/// shutdown block, allowing explicit cleanup such as `drop(shutdown)` before breaking or returning.
503///
504/// # Example
505///
506/// ```rust,ignore
507/// use commonware_macros::select_loop;
508///
509/// async fn run(context: impl commonware_runtime::Spawner) {
510/// select_loop! {
511/// context,
512/// on_stopped => {
513/// println!("shutting down");
514/// drop(shutdown);
515/// },
516/// msg = receiver.recv() => {
517/// println!("received: {:?}", msg);
518/// },
519/// }
520/// }
521/// ```
522#[proc_macro]
523pub fn select_loop(input: TokenStream) -> TokenStream {
524 let SelectLoopInput {
525 context,
526 shutdown_block,
527 branches,
528 } = parse_macro_input!(input as SelectLoopInput);
529
530 // Convert branches to tokens for the inner select!
531 let branch_tokens: Vec<_> = branches
532 .iter()
533 .map(|b| {
534 let pattern = &b.pattern;
535 let future = &b.future;
536 let block = &b.block;
537 quote! { #pattern = #future => #block, }
538 })
539 .collect();
540
541 quote! {
542 {
543 let mut shutdown = #context.stopped();
544 loop {
545 commonware_macros::select! {
546 _ = &mut shutdown => {
547 #shutdown_block
548
549 // Break the loop after handling shutdown. Some implementations
550 // may divert control flow themselves, so this may be unused.
551 #[allow(unreachable_code)]
552 break;
553 },
554 #(#branch_tokens)*
555 }
556 }
557 }
558 }
559 .into()
560}