1#![doc(
7 html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
8 html_favicon_url = "https://commonware.xyz/favicon.ico"
9)]
10
11use crate::nextest::configured_test_groups;
12use proc_macro::TokenStream;
13use proc_macro2::Span;
14use proc_macro_crate::{crate_name, FoundCrate};
15use quote::{format_ident, quote};
16use syn::{
17 braced,
18 parse::{Parse, ParseStream, Result},
19 parse_macro_input, Error, Expr, Ident, ItemFn, LitInt, LitStr, Pat, Token, Visibility,
20};
21
22mod nextest;
23
/// A stability level accepted by the `stability*` macros.
///
/// The [`Parse`] implementation only produces values in `0..=4`, where `0`
/// corresponds to `ALPHA` and `4` to `EPSILON`.
struct StabilityLevel {
    /// Numeric form of the level (`0` through `4` when produced by parsing).
    value: u8,
}
29
30impl Parse for StabilityLevel {
31 fn parse(input: ParseStream<'_>) -> Result<Self> {
32 let lookahead = input.lookahead1();
33 if lookahead.peek(LitInt) {
34 let lit: LitInt = input.parse()?;
35 let value: u8 = lit
36 .base10_parse()
37 .map_err(|_| Error::new(lit.span(), "stability level must be 0, 1, 2, 3, or 4"))?;
38 if value > 4 {
39 return Err(Error::new(
40 lit.span(),
41 "stability level must be 0, 1, 2, 3, or 4",
42 ));
43 }
44 Ok(Self { value })
45 } else if lookahead.peek(Ident) {
46 let ident: Ident = input.parse()?;
47 let value = match ident.to_string().as_str() {
48 "ALPHA" => 0,
49 "BETA" => 1,
50 "GAMMA" => 2,
51 "DELTA" => 3,
52 "EPSILON" => 4,
53 _ => {
54 return Err(Error::new(
55 ident.span(),
56 "expected stability level: ALPHA, BETA, GAMMA, DELTA, EPSILON, or 0-4",
57 ));
58 }
59 };
60 Ok(Self { value })
61 } else {
62 Err(lookahead.error())
63 }
64 }
65}
66
/// Returns the upper-case name (`ALPHA` .. `EPSILON`) for a numeric level.
///
/// # Panics
///
/// Panics via `unreachable!()` if `level` is greater than 4.
fn level_name(level: u8) -> &'static str {
    const NAMES: [&str; 5] = ["ALPHA", "BETA", "GAMMA", "DELTA", "EPSILON"];

    NAMES
        .get(usize::from(level))
        .copied()
        .unwrap_or_else(|| unreachable!())
}
77
78fn exclusion_cfg_names(level: u8) -> Vec<proc_macro2::Ident> {
99 let mut names: Vec<_> = ((level + 1)..=4)
100 .map(|l| format_ident!("commonware_stability_{}", level_name(l)))
101 .collect();
102
103 names.push(format_ident!("commonware_stability_RESERVED"));
104 names
105}
106
107#[proc_macro_attribute]
108pub fn stability(attr: TokenStream, item: TokenStream) -> TokenStream {
109 let level = parse_macro_input!(attr as StabilityLevel);
110 let exclude_names = exclusion_cfg_names(level.value);
111
112 let item2: proc_macro2::TokenStream = item.into();
113 let expanded = quote! {
114 #[cfg(not(any(#(#exclude_names),*)))]
115 #item2
116 };
117
118 TokenStream::from(expanded)
119}
120
121struct StabilityModInput {
123 level: StabilityLevel,
124 visibility: Visibility,
125 name: Ident,
126}
127
128impl Parse for StabilityModInput {
129 fn parse(input: ParseStream<'_>) -> Result<Self> {
130 let level: StabilityLevel = input.parse()?;
131 input.parse::<Token![,]>()?;
132 let visibility: Visibility = input.parse()?;
133 input.parse::<Token![mod]>()?;
134 let name: Ident = input.parse()?;
135 Ok(Self {
136 level,
137 visibility,
138 name,
139 })
140 }
141}
142
143#[proc_macro]
144pub fn stability_mod(input: TokenStream) -> TokenStream {
145 let StabilityModInput {
146 level,
147 visibility,
148 name,
149 } = parse_macro_input!(input as StabilityModInput);
150
151 let exclude_names = exclusion_cfg_names(level.value);
152
153 let expanded = quote! {
154 #[cfg(not(any(#(#exclude_names),*)))]
155 #visibility mod #name;
156 };
157
158 TokenStream::from(expanded)
159}
160
161struct StabilityScopeInput {
163 level: StabilityLevel,
164 predicate: Option<syn::Meta>,
165 items: Vec<syn::Item>,
166}
167
168impl Parse for StabilityScopeInput {
169 fn parse(input: ParseStream<'_>) -> Result<Self> {
170 let level: StabilityLevel = input.parse()?;
171
172 let predicate = if input.peek(Token![,]) {
174 input.parse::<Token![,]>()?;
175
176 let cfg_ident: Ident = input.parse()?;
178 if cfg_ident != "cfg" {
179 return Err(Error::new(cfg_ident.span(), "expected `cfg`"));
180 }
181 let cfg_content;
182 syn::parenthesized!(cfg_content in input);
183 Some(cfg_content.parse()?)
184 } else {
185 None
186 };
187
188 let content;
189 braced!(content in input);
190
191 let mut items = Vec::new();
192 while !content.is_empty() {
193 items.push(content.parse()?);
194 }
195
196 Ok(Self {
197 level,
198 predicate,
199 items,
200 })
201 }
202}
203
204#[proc_macro]
205pub fn stability_scope(input: TokenStream) -> TokenStream {
206 let StabilityScopeInput {
207 level,
208 predicate,
209 items,
210 } = parse_macro_input!(input as StabilityScopeInput);
211
212 let exclude_names = exclusion_cfg_names(level.value);
213
214 let cfg_attr = predicate.map_or_else(
215 || quote! { #[cfg(not(any(#(#exclude_names),*)))] },
216 |pred| quote! { #[cfg(all(#pred, not(any(#(#exclude_names),*))))] },
217 );
218
219 let expanded_items: Vec<_> = items
220 .into_iter()
221 .map(|item| {
222 quote! {
223 #cfg_attr
224 #item
225 }
226 })
227 .collect();
228
229 let expanded = quote! {
230 #(#expanded_items)*
231 };
232
233 TokenStream::from(expanded)
234}
235
236#[proc_macro_attribute]
237pub fn test_async(_: TokenStream, item: TokenStream) -> TokenStream {
238 let input = parse_macro_input!(item as ItemFn);
240
241 let attrs = input.attrs;
243 let vis = input.vis;
244 let mut sig = input.sig;
245 let block = input.block;
246
247 sig.asyncness
250 .take()
251 .expect("test_async macro can only be used with async functions");
252
253 let expanded = quote! {
255 #[test]
256 #(#attrs)*
257 #vis #sig {
258 futures::executor::block_on(async #block);
259 }
260 };
261 TokenStream::from(expanded)
262}
263
264#[proc_macro_attribute]
265pub fn test_traced(attr: TokenStream, item: TokenStream) -> TokenStream {
266 let input = parse_macro_input!(item as ItemFn);
268
269 let default_level = if attr.is_empty() {
271 "debug".to_string()
272 } else {
273 let level_str = parse_macro_input!(attr as LitStr);
274 let level_ident = level_str.value().to_lowercase();
275 match level_ident.as_str() {
276 "trace" | "debug" | "info" | "warn" | "error" => level_ident,
277 _ => {
278 return Error::new_spanned(
279 level_str,
280 "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
281 )
282 .to_compile_error()
283 .into();
284 }
285 }
286 };
287
288 let attrs = input.attrs;
290 let vis = input.vis;
291 let sig = input.sig;
292 let block = input.block;
293
294 let expanded = quote! {
296 #[test]
297 #(#attrs)*
298 #vis #sig {
299 use tracing_subscriber::{EnvFilter, layer::SubscriberExt, util::SubscriberInitExt};
300
301 let filter = EnvFilter::try_from_default_env()
303 .unwrap_or_else(|_| EnvFilter::new(#default_level));
304 let subscriber = tracing_subscriber::Registry::default()
305 .with(
306 tracing_subscriber::fmt::layer()
307 .with_test_writer()
308 .with_line_number(true)
309 .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
310 )
311 .with(filter);
312 let dispatcher = tracing::Dispatch::new(subscriber);
313
314 tracing::dispatcher::with_default(&dispatcher, || {
316 #block
317 });
318 }
319 };
320 TokenStream::from(expanded)
321}
322
323#[proc_macro_attribute]
324pub fn test_group(attr: TokenStream, item: TokenStream) -> TokenStream {
325 if attr.is_empty() {
326 return Error::new(
327 Span::call_site(),
328 "test_group requires a string literal filter group name",
329 )
330 .to_compile_error()
331 .into();
332 }
333
334 let mut input = parse_macro_input!(item as ItemFn);
335 let group_literal = parse_macro_input!(attr as LitStr);
336
337 let group = match nextest::sanitize_group_literal(&group_literal) {
338 Ok(group) => group,
339 Err(err) => return err.to_compile_error().into(),
340 };
341 let groups = match configured_test_groups() {
342 Ok(groups) => groups,
343 Err(_) => {
344 return TokenStream::from(quote!(#input));
346 }
347 };
348
349 if let Err(err) = nextest::ensure_group_known(groups, &group, group_literal.span()) {
350 return err.to_compile_error().into();
351 }
352
353 let original_name = input.sig.ident.to_string();
354 let new_ident = Ident::new(&format!("{original_name}_{group}_"), input.sig.ident.span());
355
356 input.sig.ident = new_ident;
357
358 TokenStream::from(quote!(#input))
359}
360
361#[proc_macro_attribute]
362pub fn test_collect_traces(attr: TokenStream, item: TokenStream) -> TokenStream {
363 let input = parse_macro_input!(item as ItemFn);
364
365 let log_level = if attr.is_empty() {
367 quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG }
369 } else {
370 let level_str = parse_macro_input!(attr as LitStr);
372 let level_ident = level_str.value().to_uppercase();
373 match level_ident.as_str() {
374 "TRACE" => quote! { ::tracing_subscriber::filter::LevelFilter::TRACE },
375 "DEBUG" => quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG },
376 "INFO" => quote! { ::tracing_subscriber::filter::LevelFilter::INFO },
377 "WARN" => quote! { ::tracing_subscriber::filter::LevelFilter::WARN },
378 "ERROR" => quote! { ::tracing_subscriber::filter::LevelFilter::ERROR },
379 _ => {
380 return Error::new_spanned(
382 level_str,
383 "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
384 )
385 .to_compile_error()
386 .into();
387 }
388 }
389 };
390
391 let attrs = input.attrs;
392 let vis = input.vis;
393 let sig = input.sig;
394 let block = input.block;
395
396 let inner_ident = format_ident!("__{}_inner_traced", sig.ident);
398 let mut inner_sig = sig.clone();
399 inner_sig.ident = inner_ident.clone();
400
401 let mut outer_sig = sig;
403 outer_sig.inputs.clear();
404
405 let rt_path = match crate_name("commonware-runtime") {
409 Ok(FoundCrate::Itself) => quote!(crate),
410 Ok(FoundCrate::Name(name)) => {
411 let ident = syn::Ident::new(&name, Span::call_site());
412 quote!(#ident)
413 }
414 Err(_) => quote!(::commonware_runtime), };
416
417 let expanded = quote! {
418 #(#attrs)*
421 #vis #inner_sig #block
422
423 #[test]
424 #vis #outer_sig {
425 use ::tracing_subscriber::{Layer, fmt, Registry, layer::SubscriberExt, util::SubscriberInitExt};
426 use ::tracing::{Dispatch, dispatcher};
427 use #rt_path::telemetry::traces::collector::{CollectingLayer, TraceStorage};
428
429 let trace_store = TraceStorage::default();
430 let collecting_layer = CollectingLayer::new(trace_store.clone());
431
432 let fmt_layer = fmt::layer()
433 .with_test_writer()
434 .with_line_number(true)
435 .with_span_events(fmt::format::FmtSpan::CLOSE)
436 .with_filter(#log_level);
437
438 let subscriber = Registry::default().with(collecting_layer).with(fmt_layer);
439 let dispatcher = Dispatch::new(subscriber);
440 dispatcher::with_default(&dispatcher, || {
441 #inner_ident(trace_store);
442 });
443 }
444 };
445
446 TokenStream::from(expanded)
447}
448
449struct SelectInput {
450 branches: Vec<Branch>,
451}
452
453struct Branch {
454 pattern: Pat,
455 future: Expr,
456 body: Expr,
457}
458
459struct SelectLoopBranch {
461 pattern: Pat,
462 future: Expr,
463 else_body: Option<Expr>,
464 body: Expr,
465}
466
467impl Parse for SelectInput {
468 fn parse(input: ParseStream<'_>) -> Result<Self> {
469 let mut branches = Vec::new();
470
471 while !input.is_empty() {
472 let pattern = Pat::parse_single(input)?;
473 input.parse::<Token![=]>()?;
474 let future: Expr = input.parse()?;
475 input.parse::<Token![=>]>()?;
476 let body: Expr = input.parse()?;
477
478 branches.push(Branch {
479 pattern,
480 future,
481 body,
482 });
483
484 if input.peek(Token![,]) {
485 input.parse::<Token![,]>()?;
486 } else {
487 break;
488 }
489 }
490
491 Ok(Self { branches })
492 }
493}
494
495#[proc_macro]
496pub fn select(input: TokenStream) -> TokenStream {
497 let SelectInput { branches } = parse_macro_input!(input as SelectInput);
499
500 let mut select_branches = Vec::new();
502 for Branch {
503 pattern,
504 future,
505 body,
506 } in branches.into_iter()
507 {
508 let branch_code = quote! {
510 #pattern = #future => #body,
511 };
512 select_branches.push(branch_code);
513 }
514
515 quote! {
517 {
518 ::commonware_macros::__reexport::tokio::select! {
519 biased;
520 #(#select_branches)*
521 }
522 }
523 }
524 .into()
525}
526
527struct SelectLoopInput {
531 context: Expr,
532 start_expr: Option<Expr>,
533 shutdown_expr: Expr,
534 branches: Vec<SelectLoopBranch>,
535 end_expr: Option<Expr>,
536}
537
538impl Parse for SelectLoopInput {
539 fn parse(input: ParseStream<'_>) -> Result<Self> {
540 let context: Expr = input.parse()?;
542 input.parse::<Token![,]>()?;
543
544 let start_expr = if input.peek(Ident) {
546 let ident: Ident = input.fork().parse()?;
547 if ident == "on_start" {
548 input.parse::<Ident>()?; input.parse::<Token![=>]>()?;
550 let expr: Expr = input.parse()?;
551 input.parse::<Token![,]>()?;
552 Some(expr)
553 } else {
554 None
555 }
556 } else {
557 None
558 };
559
560 let on_stopped_ident: Ident = input.parse()?;
562 if on_stopped_ident != "on_stopped" {
563 return Err(Error::new(
564 on_stopped_ident.span(),
565 "expected `on_stopped` keyword",
566 ));
567 }
568 input.parse::<Token![=>]>()?;
569
570 let shutdown_expr: Expr = input.parse()?;
572
573 input.parse::<Token![,]>()?;
575
576 let mut branches = Vec::new();
579 while !input.is_empty() {
580 if input.peek(Ident) {
582 let ident: Ident = input.fork().parse()?;
583 if ident == "on_end" {
584 break;
585 }
586 }
587
588 let pattern = Pat::parse_single(input)?;
589 input.parse::<Token![=]>()?;
590 let future: Expr = input.parse()?;
591
592 let else_body = if input.peek(Token![else]) {
594 input.parse::<Token![else]>()?;
595 Some(input.parse::<Expr>()?)
596 } else {
597 None
598 };
599
600 input.parse::<Token![=>]>()?;
601 let body: Expr = input.parse()?;
602
603 branches.push(SelectLoopBranch {
604 pattern,
605 future,
606 else_body,
607 body,
608 });
609
610 if input.peek(Token![,]) {
611 input.parse::<Token![,]>()?;
612 } else {
613 break;
614 }
615 }
616
617 let end_expr = if !input.is_empty() && input.peek(Ident) {
619 let ident: Ident = input.parse()?;
620 if ident == "on_end" {
621 input.parse::<Token![=>]>()?;
622 let expr: Expr = input.parse()?;
623 if input.peek(Token![,]) {
624 input.parse::<Token![,]>()?;
625 }
626 Some(expr)
627 } else {
628 return Err(Error::new(ident.span(), "expected `on_end` keyword"));
629 }
630 } else {
631 None
632 };
633
634 Ok(Self {
635 context,
636 start_expr,
637 shutdown_expr,
638 branches,
639 end_expr,
640 })
641 }
642}
643
644#[proc_macro]
645pub fn select_loop(input: TokenStream) -> TokenStream {
646 let SelectLoopInput {
647 context,
648 start_expr,
649 shutdown_expr,
650 branches,
651 end_expr,
652 } = parse_macro_input!(input as SelectLoopInput);
653
654 fn is_irrefutable(pat: &Pat) -> bool {
655 match pat {
656 Pat::Wild(_) | Pat::Rest(_) => true,
657 Pat::Ident(i) => i.subpat.as_ref().is_none_or(|(_, p)| is_irrefutable(p)),
658 Pat::Type(t) => is_irrefutable(&t.pat),
659 Pat::Tuple(t) => t.elems.iter().all(is_irrefutable),
660 Pat::Reference(r) => is_irrefutable(&r.pat),
661 Pat::Paren(p) => is_irrefutable(&p.pat),
662 _ => false,
663 }
664 }
665
666 for b in &branches {
667 if b.else_body.is_none() && !is_irrefutable(&b.pattern) {
668 return Error::new_spanned(
669 &b.pattern,
670 "refutable patterns require an else clause: \
671 `Some(msg) = future else break => { ... }`",
672 )
673 .to_compile_error()
674 .into();
675 }
676 }
677
678 let branch_tokens: Vec<_> = branches
680 .iter()
681 .map(|b| {
682 let pattern = &b.pattern;
683 let future = &b.future;
684 let body = &b.body;
685
686 b.else_body.as_ref().map_or_else(
688 || quote! { #pattern = #future => #body, },
690 |else_expr| {
692 quote! {
693 __select_result = #future => {
694 let #pattern = __select_result else { #else_expr };
695 #body
696 },
697 }
698 },
699 )
700 })
701 .collect();
702
703 fn expr_to_tokens(expr: &Expr) -> proc_macro2::TokenStream {
706 match expr {
707 Expr::Block(block) => {
708 let stmts = &block.block.stmts;
709 quote! { #(#stmts)* }
710 }
711 other => quote! { #other; },
712 }
713 }
714
715 let on_start_tokens = start_expr.as_ref().map(expr_to_tokens);
717 let on_end_tokens = end_expr.as_ref().map(expr_to_tokens);
718 let shutdown_tokens = expr_to_tokens(&shutdown_expr);
719
720 quote! {
721 {
722 let mut shutdown = #context.stopped();
723 loop {
724 #on_start_tokens
725
726 commonware_macros::select! {
727 _ = &mut shutdown => {
728 #shutdown_tokens
729
730 #[allow(unreachable_code)]
733 break;
734 },
735 #(#branch_tokens)*
736 }
737
738 #on_end_tokens
739 }
740 }
741 }
742 .into()
743}