1#![doc(
7 html_logo_url = "https://commonware.xyz/imgs/rustdoc_logo.svg",
8 html_favicon_url = "https://commonware.xyz/favicon.ico"
9)]
10
11use crate::nextest::configured_test_groups;
12use proc_macro::TokenStream;
13use proc_macro2::Span;
14use proc_macro_crate::{crate_name, FoundCrate};
15use quote::{format_ident, quote};
16use syn::{
17 braced,
18 parse::{Parse, ParseStream, Result},
19 parse_macro_input, Error, Expr, Ident, ItemFn, LitInt, LitStr, Pat, Token, Visibility,
20};
21
22mod nextest;
23
24struct StabilityLevel {
27 value: u8,
28}
29
30impl Parse for StabilityLevel {
31 fn parse(input: ParseStream<'_>) -> Result<Self> {
32 let lookahead = input.lookahead1();
33 if lookahead.peek(LitInt) {
34 let lit: LitInt = input.parse()?;
35 let value: u8 = lit
36 .base10_parse()
37 .map_err(|_| Error::new(lit.span(), "stability level must be 0, 1, 2, 3, or 4"))?;
38 if value > 4 {
39 return Err(Error::new(
40 lit.span(),
41 "stability level must be 0, 1, 2, 3, or 4",
42 ));
43 }
44 Ok(Self { value })
45 } else if lookahead.peek(Ident) {
46 let ident: Ident = input.parse()?;
47 let value = match ident.to_string().as_str() {
48 "ALPHA" => 0,
49 "BETA" => 1,
50 "GAMMA" => 2,
51 "DELTA" => 3,
52 "EPSILON" => 4,
53 _ => {
54 return Err(Error::new(
55 ident.span(),
56 "expected stability level: ALPHA, BETA, GAMMA, DELTA, EPSILON, or 0-4",
57 ));
58 }
59 };
60 Ok(Self { value })
61 } else {
62 Err(lookahead.error())
63 }
64 }
65}
66
/// Returns the canonical name for a stability `level` (0 = `ALPHA` … 4 = `EPSILON`).
///
/// # Panics
///
/// Panics (via `unreachable!`) if `level` is greater than 4. Within this file it is only
/// called with levels in `0..=4`.
fn level_name(level: u8) -> &'static str {
    const NAMES: [&str; 5] = ["ALPHA", "BETA", "GAMMA", "DELTA", "EPSILON"];

    match NAMES.get(usize::from(level)).copied() {
        Some(name) => name,
        None => unreachable!(),
    }
}
77
78fn exclusion_cfg_names(level: u8) -> Vec<proc_macro2::Ident> {
99 let mut names: Vec<_> = ((level + 1)..=4)
100 .map(|l| format_ident!("commonware_stability_{}", level_name(l)))
101 .collect();
102
103 names.push(format_ident!("commonware_stability_RESERVED"));
104 names
105}
106
107#[proc_macro_attribute]
108pub fn stability(attr: TokenStream, item: TokenStream) -> TokenStream {
109 let level = parse_macro_input!(attr as StabilityLevel);
110 let exclude_names = exclusion_cfg_names(level.value);
111
112 let item2: proc_macro2::TokenStream = item.into();
113 let expanded = quote! {
114 #[cfg(not(any(#(#exclude_names),*)))]
115 #item2
116 };
117
118 TokenStream::from(expanded)
119}
120
121struct StabilityModInput {
123 level: StabilityLevel,
124 visibility: Visibility,
125 name: Ident,
126}
127
128impl Parse for StabilityModInput {
129 fn parse(input: ParseStream<'_>) -> Result<Self> {
130 let level: StabilityLevel = input.parse()?;
131 input.parse::<Token![,]>()?;
132 let visibility: Visibility = input.parse()?;
133 input.parse::<Token![mod]>()?;
134 let name: Ident = input.parse()?;
135 Ok(Self {
136 level,
137 visibility,
138 name,
139 })
140 }
141}
142
143#[proc_macro]
144pub fn stability_mod(input: TokenStream) -> TokenStream {
145 let StabilityModInput {
146 level,
147 visibility,
148 name,
149 } = parse_macro_input!(input as StabilityModInput);
150
151 let exclude_names = exclusion_cfg_names(level.value);
152
153 let expanded = quote! {
154 #[cfg(not(any(#(#exclude_names),*)))]
155 #visibility mod #name;
156 };
157
158 TokenStream::from(expanded)
159}
160
161struct StabilityScopeInput {
163 level: StabilityLevel,
164 predicate: Option<syn::Meta>,
165 items: Vec<syn::Item>,
166}
167
168impl Parse for StabilityScopeInput {
169 fn parse(input: ParseStream<'_>) -> Result<Self> {
170 let level: StabilityLevel = input.parse()?;
171
172 let predicate = if input.peek(Token![,]) {
174 input.parse::<Token![,]>()?;
175
176 let cfg_ident: Ident = input.parse()?;
178 if cfg_ident != "cfg" {
179 return Err(Error::new(cfg_ident.span(), "expected `cfg`"));
180 }
181 let cfg_content;
182 syn::parenthesized!(cfg_content in input);
183 Some(cfg_content.parse()?)
184 } else {
185 None
186 };
187
188 let content;
189 braced!(content in input);
190
191 let mut items = Vec::new();
192 while !content.is_empty() {
193 items.push(content.parse()?);
194 }
195
196 Ok(Self {
197 level,
198 predicate,
199 items,
200 })
201 }
202}
203
204#[proc_macro]
205pub fn stability_scope(input: TokenStream) -> TokenStream {
206 let StabilityScopeInput {
207 level,
208 predicate,
209 items,
210 } = parse_macro_input!(input as StabilityScopeInput);
211
212 let exclude_names = exclusion_cfg_names(level.value);
213
214 let cfg_attr = predicate.map_or_else(
215 || quote! { #[cfg(not(any(#(#exclude_names),*)))] },
216 |pred| quote! { #[cfg(all(#pred, not(any(#(#exclude_names),*))))] },
217 );
218
219 let expanded_items: Vec<_> = items
220 .into_iter()
221 .map(|item| {
222 quote! {
223 #cfg_attr
224 #item
225 }
226 })
227 .collect();
228
229 let expanded = quote! {
230 #(#expanded_items)*
231 };
232
233 TokenStream::from(expanded)
234}
235
236#[proc_macro_attribute]
237pub fn test_async(_: TokenStream, item: TokenStream) -> TokenStream {
238 let input = parse_macro_input!(item as ItemFn);
240
241 let attrs = input.attrs;
243 let vis = input.vis;
244 let mut sig = input.sig;
245 let block = input.block;
246
247 sig.asyncness
250 .take()
251 .expect("test_async macro can only be used with async functions");
252
253 let expanded = quote! {
255 #[test]
256 #(#attrs)*
257 #vis #sig {
258 futures::executor::block_on(async #block);
259 }
260 };
261 TokenStream::from(expanded)
262}
263
264#[proc_macro_attribute]
265pub fn test_traced(attr: TokenStream, item: TokenStream) -> TokenStream {
266 let input = parse_macro_input!(item as ItemFn);
268
269 let log_level = if attr.is_empty() {
271 quote! { tracing::Level::DEBUG }
273 } else {
274 let level_str = parse_macro_input!(attr as LitStr);
276 let level_ident = level_str.value().to_uppercase();
277 match level_ident.as_str() {
278 "TRACE" => quote! { tracing::Level::TRACE },
279 "DEBUG" => quote! { tracing::Level::DEBUG },
280 "INFO" => quote! { tracing::Level::INFO },
281 "WARN" => quote! { tracing::Level::WARN },
282 "ERROR" => quote! { tracing::Level::ERROR },
283 _ => {
284 return Error::new_spanned(
286 level_str,
287 "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
288 )
289 .to_compile_error()
290 .into();
291 }
292 }
293 };
294
295 let attrs = input.attrs;
297 let vis = input.vis;
298 let sig = input.sig;
299 let block = input.block;
300
301 let expanded = quote! {
303 #[test]
304 #(#attrs)*
305 #vis #sig {
306 let subscriber = tracing_subscriber::fmt()
308 .with_test_writer()
309 .with_max_level(#log_level)
310 .with_line_number(true)
311 .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
312 .finish();
313 let dispatcher = tracing::Dispatch::new(subscriber);
314
315 tracing::dispatcher::with_default(&dispatcher, || {
317 #block
318 });
319 }
320 };
321 TokenStream::from(expanded)
322}
323
324#[proc_macro_attribute]
325pub fn test_group(attr: TokenStream, item: TokenStream) -> TokenStream {
326 if attr.is_empty() {
327 return Error::new(
328 Span::call_site(),
329 "test_group requires a string literal filter group name",
330 )
331 .to_compile_error()
332 .into();
333 }
334
335 let mut input = parse_macro_input!(item as ItemFn);
336 let group_literal = parse_macro_input!(attr as LitStr);
337
338 let group = match nextest::sanitize_group_literal(&group_literal) {
339 Ok(group) => group,
340 Err(err) => return err.to_compile_error().into(),
341 };
342 let groups = match configured_test_groups() {
343 Ok(groups) => groups,
344 Err(_) => {
345 return TokenStream::from(quote!(#input));
347 }
348 };
349
350 if let Err(err) = nextest::ensure_group_known(groups, &group, group_literal.span()) {
351 return err.to_compile_error().into();
352 }
353
354 let original_name = input.sig.ident.to_string();
355 let new_ident = Ident::new(&format!("{original_name}_{group}_"), input.sig.ident.span());
356
357 input.sig.ident = new_ident;
358
359 TokenStream::from(quote!(#input))
360}
361
362#[proc_macro_attribute]
363pub fn test_collect_traces(attr: TokenStream, item: TokenStream) -> TokenStream {
364 let input = parse_macro_input!(item as ItemFn);
365
366 let log_level = if attr.is_empty() {
368 quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG }
370 } else {
371 let level_str = parse_macro_input!(attr as LitStr);
373 let level_ident = level_str.value().to_uppercase();
374 match level_ident.as_str() {
375 "TRACE" => quote! { ::tracing_subscriber::filter::LevelFilter::TRACE },
376 "DEBUG" => quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG },
377 "INFO" => quote! { ::tracing_subscriber::filter::LevelFilter::INFO },
378 "WARN" => quote! { ::tracing_subscriber::filter::LevelFilter::WARN },
379 "ERROR" => quote! { ::tracing_subscriber::filter::LevelFilter::ERROR },
380 _ => {
381 return Error::new_spanned(
383 level_str,
384 "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
385 )
386 .to_compile_error()
387 .into();
388 }
389 }
390 };
391
392 let attrs = input.attrs;
393 let vis = input.vis;
394 let sig = input.sig;
395 let block = input.block;
396
397 let inner_ident = format_ident!("__{}_inner_traced", sig.ident);
399 let mut inner_sig = sig.clone();
400 inner_sig.ident = inner_ident.clone();
401
402 let mut outer_sig = sig;
404 outer_sig.inputs.clear();
405
406 let rt_path = match crate_name("commonware-runtime") {
410 Ok(FoundCrate::Itself) => quote!(crate),
411 Ok(FoundCrate::Name(name)) => {
412 let ident = syn::Ident::new(&name, Span::call_site());
413 quote!(#ident)
414 }
415 Err(_) => quote!(::commonware_runtime), };
417
418 let expanded = quote! {
419 #(#attrs)*
422 #vis #inner_sig #block
423
424 #[test]
425 #vis #outer_sig {
426 use ::tracing_subscriber::{Layer, fmt, Registry, layer::SubscriberExt, util::SubscriberInitExt};
427 use ::tracing::{Dispatch, dispatcher};
428 use #rt_path::telemetry::traces::collector::{CollectingLayer, TraceStorage};
429
430 let trace_store = TraceStorage::default();
431 let collecting_layer = CollectingLayer::new(trace_store.clone());
432
433 let fmt_layer = fmt::layer()
434 .with_test_writer()
435 .with_line_number(true)
436 .with_span_events(fmt::format::FmtSpan::CLOSE)
437 .with_filter(#log_level);
438
439 let subscriber = Registry::default().with(collecting_layer).with(fmt_layer);
440 let dispatcher = Dispatch::new(subscriber);
441 dispatcher::with_default(&dispatcher, || {
442 #inner_ident(trace_store);
443 });
444 }
445 };
446
447 TokenStream::from(expanded)
448}
449
450struct SelectInput {
451 branches: Vec<Branch>,
452}
453
454struct Branch {
455 pattern: Pat,
456 future: Expr,
457 body: Expr,
458}
459
460struct SelectLoopBranch {
462 pattern: Pat,
463 future: Expr,
464 else_body: Option<Expr>,
465 body: Expr,
466}
467
468impl Parse for SelectInput {
469 fn parse(input: ParseStream<'_>) -> Result<Self> {
470 let mut branches = Vec::new();
471
472 while !input.is_empty() {
473 let pattern = Pat::parse_single(input)?;
474 input.parse::<Token![=]>()?;
475 let future: Expr = input.parse()?;
476 input.parse::<Token![=>]>()?;
477 let body: Expr = input.parse()?;
478
479 branches.push(Branch {
480 pattern,
481 future,
482 body,
483 });
484
485 if input.peek(Token![,]) {
486 input.parse::<Token![,]>()?;
487 } else {
488 break;
489 }
490 }
491
492 Ok(Self { branches })
493 }
494}
495
496#[proc_macro]
497pub fn select(input: TokenStream) -> TokenStream {
498 let SelectInput { branches } = parse_macro_input!(input as SelectInput);
500
501 let mut select_branches = Vec::new();
503 for Branch {
504 pattern,
505 future,
506 body,
507 } in branches.into_iter()
508 {
509 let branch_code = quote! {
511 #pattern = #future => #body,
512 };
513 select_branches.push(branch_code);
514 }
515
516 quote! {
518 {
519 ::commonware_macros::__reexport::tokio::select! {
520 biased;
521 #(#select_branches)*
522 }
523 }
524 }
525 .into()
526}
527
528struct SelectLoopInput {
532 context: Expr,
533 start_expr: Option<Expr>,
534 shutdown_expr: Expr,
535 branches: Vec<SelectLoopBranch>,
536 end_expr: Option<Expr>,
537}
538
539impl Parse for SelectLoopInput {
540 fn parse(input: ParseStream<'_>) -> Result<Self> {
541 let context: Expr = input.parse()?;
543 input.parse::<Token![,]>()?;
544
545 let start_expr = if input.peek(Ident) {
547 let ident: Ident = input.fork().parse()?;
548 if ident == "on_start" {
549 input.parse::<Ident>()?; input.parse::<Token![=>]>()?;
551 let expr: Expr = input.parse()?;
552 input.parse::<Token![,]>()?;
553 Some(expr)
554 } else {
555 None
556 }
557 } else {
558 None
559 };
560
561 let on_stopped_ident: Ident = input.parse()?;
563 if on_stopped_ident != "on_stopped" {
564 return Err(Error::new(
565 on_stopped_ident.span(),
566 "expected `on_stopped` keyword",
567 ));
568 }
569 input.parse::<Token![=>]>()?;
570
571 let shutdown_expr: Expr = input.parse()?;
573
574 input.parse::<Token![,]>()?;
576
577 let mut branches = Vec::new();
580 while !input.is_empty() {
581 if input.peek(Ident) {
583 let ident: Ident = input.fork().parse()?;
584 if ident == "on_end" {
585 break;
586 }
587 }
588
589 let pattern = Pat::parse_single(input)?;
590 input.parse::<Token![=]>()?;
591 let future: Expr = input.parse()?;
592
593 let else_body = if input.peek(Token![else]) {
595 input.parse::<Token![else]>()?;
596 Some(input.parse::<Expr>()?)
597 } else {
598 None
599 };
600
601 input.parse::<Token![=>]>()?;
602 let body: Expr = input.parse()?;
603
604 branches.push(SelectLoopBranch {
605 pattern,
606 future,
607 else_body,
608 body,
609 });
610
611 if input.peek(Token![,]) {
612 input.parse::<Token![,]>()?;
613 } else {
614 break;
615 }
616 }
617
618 let end_expr = if !input.is_empty() && input.peek(Ident) {
620 let ident: Ident = input.parse()?;
621 if ident == "on_end" {
622 input.parse::<Token![=>]>()?;
623 let expr: Expr = input.parse()?;
624 if input.peek(Token![,]) {
625 input.parse::<Token![,]>()?;
626 }
627 Some(expr)
628 } else {
629 return Err(Error::new(ident.span(), "expected `on_end` keyword"));
630 }
631 } else {
632 None
633 };
634
635 Ok(Self {
636 context,
637 start_expr,
638 shutdown_expr,
639 branches,
640 end_expr,
641 })
642 }
643}
644
645#[proc_macro]
646pub fn select_loop(input: TokenStream) -> TokenStream {
647 let SelectLoopInput {
648 context,
649 start_expr,
650 shutdown_expr,
651 branches,
652 end_expr,
653 } = parse_macro_input!(input as SelectLoopInput);
654
655 fn is_irrefutable(pat: &Pat) -> bool {
656 match pat {
657 Pat::Wild(_) | Pat::Rest(_) => true,
658 Pat::Ident(i) => i.subpat.as_ref().is_none_or(|(_, p)| is_irrefutable(p)),
659 Pat::Type(t) => is_irrefutable(&t.pat),
660 Pat::Tuple(t) => t.elems.iter().all(is_irrefutable),
661 Pat::Reference(r) => is_irrefutable(&r.pat),
662 Pat::Paren(p) => is_irrefutable(&p.pat),
663 _ => false,
664 }
665 }
666
667 for b in &branches {
668 if b.else_body.is_none() && !is_irrefutable(&b.pattern) {
669 return Error::new_spanned(
670 &b.pattern,
671 "refutable patterns require an else clause: \
672 `Some(msg) = future else break => { ... }`",
673 )
674 .to_compile_error()
675 .into();
676 }
677 }
678
679 let branch_tokens: Vec<_> = branches
681 .iter()
682 .map(|b| {
683 let pattern = &b.pattern;
684 let future = &b.future;
685 let body = &b.body;
686
687 b.else_body.as_ref().map_or_else(
689 || quote! { #pattern = #future => #body, },
691 |else_expr| {
693 quote! {
694 __select_result = #future => {
695 let #pattern = __select_result else { #else_expr };
696 #body
697 },
698 }
699 },
700 )
701 })
702 .collect();
703
704 fn expr_to_tokens(expr: &Expr) -> proc_macro2::TokenStream {
707 match expr {
708 Expr::Block(block) => {
709 let stmts = &block.block.stmts;
710 quote! { #(#stmts)* }
711 }
712 other => quote! { #other; },
713 }
714 }
715
716 let on_start_tokens = start_expr.as_ref().map(expr_to_tokens);
718 let on_end_tokens = end_expr.as_ref().map(expr_to_tokens);
719 let shutdown_tokens = expr_to_tokens(&shutdown_expr);
720
721 quote! {
722 {
723 let mut shutdown = #context.stopped();
724 loop {
725 #on_start_tokens
726
727 commonware_macros::select! {
728 _ = &mut shutdown => {
729 #shutdown_tokens
730
731 #[allow(unreachable_code)]
734 break;
735 },
736 #(#branch_tokens)*
737 }
738
739 #on_end_tokens
740 }
741 }
742 }
743 .into()
744}