1use crate::nextest::configured_test_groups;
7use proc_macro::TokenStream;
8use proc_macro2::Span;
9use proc_macro_crate::{crate_name, FoundCrate};
10use quote::{format_ident, quote};
11use syn::{
12 braced,
13 parse::{Parse, ParseStream, Result},
14 parse_macro_input, Error, Expr, Ident, ItemFn, LitInt, LitStr, Pat, Token, Visibility,
15};
16
17mod nextest;
18
/// A stability level as accepted by the `stability*` macros in this file.
///
/// The `Parse` implementation only ever produces values in `0..=4`.
struct StabilityLevel {
    /// Numeric level: 0 = ALPHA, 1 = BETA, 2 = GAMMA, 3 = DELTA, 4 = EPSILON.
    value: u8,
}
24
25impl Parse for StabilityLevel {
26 fn parse(input: ParseStream<'_>) -> Result<Self> {
27 let lookahead = input.lookahead1();
28 if lookahead.peek(LitInt) {
29 let lit: LitInt = input.parse()?;
30 let value: u8 = lit
31 .base10_parse()
32 .map_err(|_| Error::new(lit.span(), "stability level must be 0, 1, 2, 3, or 4"))?;
33 if value > 4 {
34 return Err(Error::new(
35 lit.span(),
36 "stability level must be 0, 1, 2, 3, or 4",
37 ));
38 }
39 Ok(Self { value })
40 } else if lookahead.peek(Ident) {
41 let ident: Ident = input.parse()?;
42 let value = match ident.to_string().as_str() {
43 "ALPHA" => 0,
44 "BETA" => 1,
45 "GAMMA" => 2,
46 "DELTA" => 3,
47 "EPSILON" => 4,
48 _ => {
49 return Err(Error::new(
50 ident.span(),
51 "expected stability level: ALPHA, BETA, GAMMA, DELTA, EPSILON, or 0-4",
52 ));
53 }
54 };
55 Ok(Self { value })
56 } else {
57 Err(lookahead.error())
58 }
59 }
60}
61
/// Returns the upper-case name of a stability level (`0` => `"ALPHA"` ...
/// `4` => `"EPSILON"`).
///
/// # Panics
///
/// Panics via `unreachable!()` if `level` is greater than 4.
fn level_name(level: u8) -> &'static str {
    // Indexed by numeric level.
    const NAMES: [&str; 5] = ["ALPHA", "BETA", "GAMMA", "DELTA", "EPSILON"];

    match NAMES.get(usize::from(level)) {
        Some(&name) => name,
        None => unreachable!(),
    }
}
72
73fn exclusion_cfg_names(level: u8) -> Vec<proc_macro2::Ident> {
94 let mut names: Vec<_> = ((level + 1)..=4)
95 .map(|l| format_ident!("commonware_stability_{}", level_name(l)))
96 .collect();
97
98 names.push(format_ident!("commonware_stability_RESERVED"));
99 names
100}
101
102#[proc_macro_attribute]
103pub fn stability(attr: TokenStream, item: TokenStream) -> TokenStream {
104 let level = parse_macro_input!(attr as StabilityLevel);
105 let exclude_names = exclusion_cfg_names(level.value);
106
107 let item2: proc_macro2::TokenStream = item.into();
108 let expanded = quote! {
109 #[cfg(not(any(#(#exclude_names),*)))]
110 #item2
111 };
112
113 TokenStream::from(expanded)
114}
115
/// Parsed input of `stability_mod!`: `<level>, <visibility> mod <name>`.
struct StabilityModInput {
    /// Stability level that gates the module declaration.
    level: StabilityLevel,
    /// Visibility written before `mod` (may be inherited/empty).
    visibility: Visibility,
    /// Name of the module being declared.
    name: Ident,
}
122
123impl Parse for StabilityModInput {
124 fn parse(input: ParseStream<'_>) -> Result<Self> {
125 let level: StabilityLevel = input.parse()?;
126 input.parse::<Token![,]>()?;
127 let visibility: Visibility = input.parse()?;
128 input.parse::<Token![mod]>()?;
129 let name: Ident = input.parse()?;
130 Ok(Self {
131 level,
132 visibility,
133 name,
134 })
135 }
136}
137
138#[proc_macro]
139pub fn stability_mod(input: TokenStream) -> TokenStream {
140 let StabilityModInput {
141 level,
142 visibility,
143 name,
144 } = parse_macro_input!(input as StabilityModInput);
145
146 let exclude_names = exclusion_cfg_names(level.value);
147
148 let expanded = quote! {
149 #[cfg(not(any(#(#exclude_names),*)))]
150 #visibility mod #name;
151 };
152
153 TokenStream::from(expanded)
154}
155
/// Parsed input of `stability_scope!`: `<level>[, cfg(<predicate>)] { <items> }`.
struct StabilityScopeInput {
    /// Stability level applied to every item in the scope.
    level: StabilityLevel,
    /// Optional extra `cfg` predicate (the content of `cfg(...)`), combined
    /// with the stability gate via `all(...)` when present.
    predicate: Option<syn::Meta>,
    /// Items found inside the braces, in source order.
    items: Vec<syn::Item>,
}
162
163impl Parse for StabilityScopeInput {
164 fn parse(input: ParseStream<'_>) -> Result<Self> {
165 let level: StabilityLevel = input.parse()?;
166
167 let predicate = if input.peek(Token![,]) {
169 input.parse::<Token![,]>()?;
170
171 let cfg_ident: Ident = input.parse()?;
173 if cfg_ident != "cfg" {
174 return Err(Error::new(cfg_ident.span(), "expected `cfg`"));
175 }
176 let cfg_content;
177 syn::parenthesized!(cfg_content in input);
178 Some(cfg_content.parse()?)
179 } else {
180 None
181 };
182
183 let content;
184 braced!(content in input);
185
186 let mut items = Vec::new();
187 while !content.is_empty() {
188 items.push(content.parse()?);
189 }
190
191 Ok(Self {
192 level,
193 predicate,
194 items,
195 })
196 }
197}
198
199#[proc_macro]
200pub fn stability_scope(input: TokenStream) -> TokenStream {
201 let StabilityScopeInput {
202 level,
203 predicate,
204 items,
205 } = parse_macro_input!(input as StabilityScopeInput);
206
207 let exclude_names = exclusion_cfg_names(level.value);
208
209 let cfg_attr = predicate.map_or_else(
210 || quote! { #[cfg(not(any(#(#exclude_names),*)))] },
211 |pred| quote! { #[cfg(all(#pred, not(any(#(#exclude_names),*))))] },
212 );
213
214 let expanded_items: Vec<_> = items
215 .into_iter()
216 .map(|item| {
217 quote! {
218 #cfg_attr
219 #item
220 }
221 })
222 .collect();
223
224 let expanded = quote! {
225 #(#expanded_items)*
226 };
227
228 TokenStream::from(expanded)
229}
230
231#[proc_macro_attribute]
232pub fn test_async(_: TokenStream, item: TokenStream) -> TokenStream {
233 let input = parse_macro_input!(item as ItemFn);
235
236 let attrs = input.attrs;
238 let vis = input.vis;
239 let mut sig = input.sig;
240 let block = input.block;
241
242 sig.asyncness
245 .take()
246 .expect("test_async macro can only be used with async functions");
247
248 let expanded = quote! {
250 #[test]
251 #(#attrs)*
252 #vis #sig {
253 futures::executor::block_on(async #block);
254 }
255 };
256 TokenStream::from(expanded)
257}
258
259#[proc_macro_attribute]
260pub fn test_traced(attr: TokenStream, item: TokenStream) -> TokenStream {
261 let input = parse_macro_input!(item as ItemFn);
263
264 let log_level = if attr.is_empty() {
266 quote! { tracing::Level::DEBUG }
268 } else {
269 let level_str = parse_macro_input!(attr as LitStr);
271 let level_ident = level_str.value().to_uppercase();
272 match level_ident.as_str() {
273 "TRACE" => quote! { tracing::Level::TRACE },
274 "DEBUG" => quote! { tracing::Level::DEBUG },
275 "INFO" => quote! { tracing::Level::INFO },
276 "WARN" => quote! { tracing::Level::WARN },
277 "ERROR" => quote! { tracing::Level::ERROR },
278 _ => {
279 return Error::new_spanned(
281 level_str,
282 "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
283 )
284 .to_compile_error()
285 .into();
286 }
287 }
288 };
289
290 let attrs = input.attrs;
292 let vis = input.vis;
293 let sig = input.sig;
294 let block = input.block;
295
296 let expanded = quote! {
298 #[test]
299 #(#attrs)*
300 #vis #sig {
301 let subscriber = tracing_subscriber::fmt()
303 .with_test_writer()
304 .with_max_level(#log_level)
305 .with_line_number(true)
306 .with_span_events(tracing_subscriber::fmt::format::FmtSpan::CLOSE)
307 .finish();
308 let dispatcher = tracing::Dispatch::new(subscriber);
309
310 tracing::dispatcher::with_default(&dispatcher, || {
312 #block
313 });
314 }
315 };
316 TokenStream::from(expanded)
317}
318
319#[proc_macro_attribute]
320pub fn test_group(attr: TokenStream, item: TokenStream) -> TokenStream {
321 if attr.is_empty() {
322 return Error::new(
323 Span::call_site(),
324 "test_group requires a string literal filter group name",
325 )
326 .to_compile_error()
327 .into();
328 }
329
330 let mut input = parse_macro_input!(item as ItemFn);
331 let group_literal = parse_macro_input!(attr as LitStr);
332
333 let group = match nextest::sanitize_group_literal(&group_literal) {
334 Ok(group) => group,
335 Err(err) => return err.to_compile_error().into(),
336 };
337 let groups = match configured_test_groups() {
338 Ok(groups) => groups,
339 Err(_) => {
340 return TokenStream::from(quote!(#input));
342 }
343 };
344
345 if let Err(err) = nextest::ensure_group_known(groups, &group, group_literal.span()) {
346 return err.to_compile_error().into();
347 }
348
349 let original_name = input.sig.ident.to_string();
350 let new_ident = Ident::new(&format!("{original_name}_{group}_"), input.sig.ident.span());
351
352 input.sig.ident = new_ident;
353
354 TokenStream::from(quote!(#input))
355}
356
357#[proc_macro_attribute]
358pub fn test_collect_traces(attr: TokenStream, item: TokenStream) -> TokenStream {
359 let input = parse_macro_input!(item as ItemFn);
360
361 let log_level = if attr.is_empty() {
363 quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG }
365 } else {
366 let level_str = parse_macro_input!(attr as LitStr);
368 let level_ident = level_str.value().to_uppercase();
369 match level_ident.as_str() {
370 "TRACE" => quote! { ::tracing_subscriber::filter::LevelFilter::TRACE },
371 "DEBUG" => quote! { ::tracing_subscriber::filter::LevelFilter::DEBUG },
372 "INFO" => quote! { ::tracing_subscriber::filter::LevelFilter::INFO },
373 "WARN" => quote! { ::tracing_subscriber::filter::LevelFilter::WARN },
374 "ERROR" => quote! { ::tracing_subscriber::filter::LevelFilter::ERROR },
375 _ => {
376 return Error::new_spanned(
378 level_str,
379 "Invalid log level. Expected one of: TRACE, DEBUG, INFO, WARN, ERROR.",
380 )
381 .to_compile_error()
382 .into();
383 }
384 }
385 };
386
387 let attrs = input.attrs;
388 let vis = input.vis;
389 let sig = input.sig;
390 let block = input.block;
391
392 let inner_ident = format_ident!("__{}_inner_traced", sig.ident);
394 let mut inner_sig = sig.clone();
395 inner_sig.ident = inner_ident.clone();
396
397 let mut outer_sig = sig;
399 outer_sig.inputs.clear();
400
401 let rt_path = match crate_name("commonware-runtime") {
405 Ok(FoundCrate::Itself) => quote!(crate),
406 Ok(FoundCrate::Name(name)) => {
407 let ident = syn::Ident::new(&name, Span::call_site());
408 quote!(#ident)
409 }
410 Err(_) => quote!(::commonware_runtime), };
412
413 let expanded = quote! {
414 #(#attrs)*
417 #vis #inner_sig #block
418
419 #[test]
420 #vis #outer_sig {
421 use ::tracing_subscriber::{Layer, fmt, Registry, layer::SubscriberExt, util::SubscriberInitExt};
422 use ::tracing::{Dispatch, dispatcher};
423 use #rt_path::telemetry::traces::collector::{CollectingLayer, TraceStorage};
424
425 let trace_store = TraceStorage::default();
426 let collecting_layer = CollectingLayer::new(trace_store.clone());
427
428 let fmt_layer = fmt::layer()
429 .with_test_writer()
430 .with_line_number(true)
431 .with_span_events(fmt::format::FmtSpan::CLOSE)
432 .with_filter(#log_level);
433
434 let subscriber = Registry::default().with(collecting_layer).with(fmt_layer);
435 let dispatcher = Dispatch::new(subscriber);
436 dispatcher::with_default(&dispatcher, || {
437 #inner_ident(trace_store);
438 });
439 }
440 };
441
442 TokenStream::from(expanded)
443}
444
/// Parsed input of `select!`: a comma-separated list of branches.
struct SelectInput {
    /// Branches in the order they were written.
    branches: Vec<Branch>,
}
448
/// One select branch, written as `<pattern> = <future> => <body>`.
struct Branch {
    /// Pattern on the left of `=`.
    pattern: Pat,
    /// Expression between `=` and `=>`.
    future: Expr,
    /// Expression after `=>`.
    body: Expr,
}
454
455impl Parse for SelectInput {
456 fn parse(input: ParseStream<'_>) -> Result<Self> {
457 let mut branches = Vec::new();
458
459 while !input.is_empty() {
460 let pattern = Pat::parse_single(input)?;
461 input.parse::<Token![=]>()?;
462 let future: Expr = input.parse()?;
463 input.parse::<Token![=>]>()?;
464 let body: Expr = input.parse()?;
465
466 branches.push(Branch {
467 pattern,
468 future,
469 body,
470 });
471
472 if input.peek(Token![,]) {
473 input.parse::<Token![,]>()?;
474 } else {
475 break;
476 }
477 }
478
479 Ok(Self { branches })
480 }
481}
482
483#[proc_macro]
484pub fn select(input: TokenStream) -> TokenStream {
485 let SelectInput { branches } = parse_macro_input!(input as SelectInput);
487
488 let mut select_branches = Vec::new();
490 for Branch {
491 pattern,
492 future,
493 body,
494 } in branches.into_iter()
495 {
496 let branch_code = quote! {
498 #pattern = #future => #body,
499 };
500 select_branches.push(branch_code);
501 }
502
503 quote! {
505 {
506 ::commonware_macros::__reexport::tokio::select! {
507 biased;
508 #(#select_branches)*
509 }
510 }
511 }
512 .into()
513}
514
/// Parsed input of `select_loop!`:
/// `<context>, [on_start => <expr>,] on_stopped => <expr>, <branches...> [on_end => <expr>]`.
struct SelectLoopInput {
    /// Expression on which `.stopped()` is called to obtain the shutdown future.
    context: Expr,
    /// Optional `on_start` expression, emitted at the top of every loop iteration.
    start_expr: Option<Expr>,
    /// `on_stopped` expression, emitted in the shutdown branch before `break`.
    shutdown_expr: Expr,
    /// Select branches in the order they were written.
    branches: Vec<Branch>,
    /// Optional `on_end` expression, emitted after the select in every iteration.
    end_expr: Option<Expr>,
}
525
526impl Parse for SelectLoopInput {
527 fn parse(input: ParseStream<'_>) -> Result<Self> {
528 let context: Expr = input.parse()?;
530 input.parse::<Token![,]>()?;
531
532 let start_expr = if input.peek(Ident) {
534 let ident: Ident = input.fork().parse()?;
535 if ident == "on_start" {
536 input.parse::<Ident>()?; input.parse::<Token![=>]>()?;
538 let expr: Expr = input.parse()?;
539 input.parse::<Token![,]>()?;
540 Some(expr)
541 } else {
542 None
543 }
544 } else {
545 None
546 };
547
548 let on_stopped_ident: Ident = input.parse()?;
550 if on_stopped_ident != "on_stopped" {
551 return Err(Error::new(
552 on_stopped_ident.span(),
553 "expected `on_stopped` keyword",
554 ));
555 }
556 input.parse::<Token![=>]>()?;
557
558 let shutdown_expr: Expr = input.parse()?;
560
561 input.parse::<Token![,]>()?;
563
564 let mut branches = Vec::new();
567 while !input.is_empty() {
568 if input.peek(Ident) {
570 let ident: Ident = input.fork().parse()?;
571 if ident == "on_end" {
572 break;
573 }
574 }
575
576 let pattern = Pat::parse_single(input)?;
577 input.parse::<Token![=]>()?;
578 let future: Expr = input.parse()?;
579 input.parse::<Token![=>]>()?;
580 let body: Expr = input.parse()?;
581
582 branches.push(Branch {
583 pattern,
584 future,
585 body,
586 });
587
588 if input.peek(Token![,]) {
589 input.parse::<Token![,]>()?;
590 } else {
591 break;
592 }
593 }
594
595 let end_expr = if !input.is_empty() && input.peek(Ident) {
597 let ident: Ident = input.parse()?;
598 if ident == "on_end" {
599 input.parse::<Token![=>]>()?;
600 let expr: Expr = input.parse()?;
601 if input.peek(Token![,]) {
602 input.parse::<Token![,]>()?;
603 }
604 Some(expr)
605 } else {
606 return Err(Error::new(ident.span(), "expected `on_end` keyword"));
607 }
608 } else {
609 None
610 };
611
612 Ok(Self {
613 context,
614 start_expr,
615 shutdown_expr,
616 branches,
617 end_expr,
618 })
619 }
620}
621
622#[proc_macro]
623pub fn select_loop(input: TokenStream) -> TokenStream {
624 let SelectLoopInput {
625 context,
626 start_expr,
627 shutdown_expr,
628 branches,
629 end_expr,
630 } = parse_macro_input!(input as SelectLoopInput);
631
632 let branch_tokens: Vec<_> = branches
634 .iter()
635 .map(|b| {
636 let pattern = &b.pattern;
637 let future = &b.future;
638 let body = &b.body;
639 quote! { #pattern = #future => #body, }
640 })
641 .collect();
642
643 fn expr_to_tokens(expr: &Expr) -> proc_macro2::TokenStream {
646 match expr {
647 Expr::Block(block) => {
648 let stmts = &block.block.stmts;
649 quote! { #(#stmts)* }
650 }
651 other => quote! { #other; },
652 }
653 }
654
655 let on_start_tokens = start_expr.as_ref().map(expr_to_tokens);
657 let on_end_tokens = end_expr.as_ref().map(expr_to_tokens);
658 let shutdown_tokens = expr_to_tokens(&shutdown_expr);
659
660 quote! {
661 {
662 let mut shutdown = #context.stopped();
663 loop {
664 #on_start_tokens
665
666 commonware_macros::select! {
667 _ = &mut shutdown => {
668 #shutdown_tokens
669
670 #[allow(unreachable_code)]
673 break;
674 },
675 #(#branch_tokens)*
676 }
677
678 #on_end_tokens
679 }
680 }
681 }
682 .into()
683}