1use std::str::FromStr;
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::quote;
6use quote::quote_spanned;
7use quote::ToTokens;
8use syn::__private::TokenStream2;
9use syn::parse::Parser;
10use syn::ItemFn;
11
/// Returns the name of the async runtime the generated code targets.
///
/// `tokio` is chosen whenever its feature is enabled and is also the fallback
/// when neither feature is enabled; `monoio` is chosen only when it is the
/// sole enabled runtime feature.
fn get_runtime_name() -> &'static str {
    match (cfg!(feature = "tokio"), cfg!(feature = "monoio")) {
        (false, true) => "monoio",
        _ => "tokio",
    }
}
21
/// Scheduler flavor of the runtime that drives the generated function body.
#[derive(Debug, Clone, Copy, PartialEq)]
enum RuntimeFlavor {
    /// Spelled `current_thread` in the `flavor` option.
    CurrentThread,
    /// Spelled `multi_thread` in the `flavor` option.
    Threaded,
}

impl RuntimeFlavor {
    /// Parses the textual value of the `flavor` option.
    ///
    /// Returns a message listing the accepted spellings when `s` is not one
    /// of them.
    fn from_str(s: &str) -> Result<RuntimeFlavor, String> {
        let flavor = match s {
            "current_thread" => Self::CurrentThread,
            "multi_thread" => Self::Threaded,
            unknown => {
                return Err(format!(
                    "No such runtime flavor `{}`. The runtime flavors are `current_thread` and `multi_thread`.",
                    unknown
                ));
            }
        };

        Ok(flavor)
    }
}
40
41#[derive(Debug)]
42struct FinalConfig {
43 flavor: RuntimeFlavor,
44 worker_threads: Option<usize>,
45 start_paused: Option<bool>,
46 init: Option<(String, Span)>,
47 tracing_span: Option<(String, Span)>,
48 tracing_lib: Option<(String, Span)>,
49}
50
51struct Configuration {
52 rt_multi_thread_available: bool,
53 default_flavor: RuntimeFlavor,
54 flavor: Option<RuntimeFlavor>,
55 worker_threads: Option<(usize, Span)>,
56 start_paused: Option<(bool, Span)>,
57 is_test: bool,
58 init: Option<(String, Span)>,
59 tracing_span: Option<(String, Span)>,
60
61 tracing_lib: Option<(String, Span)>,
63}
64
65impl Configuration {
66 fn new(is_test: bool, rt_multi_thread: bool) -> Self {
67 Configuration {
68 rt_multi_thread_available: rt_multi_thread,
69 default_flavor: match is_test {
70 true => RuntimeFlavor::CurrentThread,
71 false => RuntimeFlavor::Threaded,
72 },
73 flavor: None,
74 worker_threads: None,
75 start_paused: None,
76 is_test,
77 init: None,
78 tracing_span: None,
79 tracing_lib: None,
80 }
81 }
82
83 fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> {
84 if self.flavor.is_some() {
85 return Err(syn::Error::new(span, "`flavor` set multiple times."));
86 }
87
88 let runtime_str = parse_string(runtime, span, "flavor")?;
89 let runtime = RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?;
90 self.flavor = Some(runtime);
91 Ok(())
92 }
93
94 fn set_init(&mut self, init_fn: syn::Lit, span: Span) -> Result<(), syn::Error> {
95 if self.init.is_some() {
96 return Err(syn::Error::new(span, "`init` set multiple times."));
97 }
98
99 let init_expr = parse_string(init_fn, span, "init")?;
100 self.init = Some((init_expr, span));
101
102 Ok(())
103 }
104
105 fn set_tracing_span(&mut self, level: syn::Lit, span: Span) -> Result<(), syn::Error> {
106 if self.tracing_span.is_some() {
107 return Err(syn::Error::new(span, "`tracing_span` set multiple times."));
108 }
109
110 let tracing_span = parse_string(level, span, "tracing_span")?;
111 self.tracing_span = Some((tracing_span, span));
112
113 Ok(())
114 }
115
116 fn set_tracing_lib(&mut self, level: syn::Lit, span: Span) -> Result<(), syn::Error> {
118 if self.tracing_lib.is_some() {
119 return Err(syn::Error::new(span, "`tracing_lib` set multiple times."));
120 }
121
122 let tracing_lib = parse_string(level, span, "tracing_lib")?;
123 self.tracing_lib = Some((tracing_lib, span));
124
125 Ok(())
126 }
127
128 fn set_worker_threads(&mut self, worker_threads: syn::Lit, span: Span) -> Result<(), syn::Error> {
129 if self.worker_threads.is_some() {
130 return Err(syn::Error::new(span, "`worker_threads` set multiple times."));
131 }
132
133 let worker_threads = parse_int(worker_threads, span, "worker_threads")?;
134 if worker_threads == 0 {
135 self.flavor = Some(RuntimeFlavor::CurrentThread);
136 self.worker_threads = None;
137 } else {
138 self.flavor = Some(RuntimeFlavor::Threaded);
139 self.worker_threads = Some((worker_threads, span));
140 }
141
142 Ok(())
143 }
144
145 fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> {
146 if self.start_paused.is_some() {
147 return Err(syn::Error::new(span, "`start_paused` set multiple times."));
148 }
149
150 let start_paused = parse_bool(start_paused, span, "start_paused")?;
151 self.start_paused = Some((start_paused, span));
152 Ok(())
153 }
154
155 fn macro_name(&self) -> &'static str {
156 if self.is_test {
157 match get_runtime_name() {
158 "tokio" => "tokio::test",
159 "monoio" => "monoio::test",
160 _ => unreachable!(),
161 }
162 } else {
163 match get_runtime_name() {
164 "tokio" => "tokio::main",
165 "monoio" => "monoio::main",
166 _ => unreachable!(),
167 }
168 }
169 }
170
171 fn build(&self) -> Result<FinalConfig, syn::Error> {
172 let flavor = self.flavor.unwrap_or(self.default_flavor);
173 use RuntimeFlavor::*;
174
175 let worker_threads = match (flavor, self.worker_threads) {
176 (CurrentThread, Some((_, worker_threads_span))) => {
177 let msg = format!(
178 "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[{}(flavor = \"multi_thread\")]`",
179 self.macro_name(),
180 );
181 return Err(syn::Error::new(worker_threads_span, msg));
182 }
183 (CurrentThread, None) => None,
184 (Threaded, worker_threads) if self.rt_multi_thread_available => worker_threads.map(|(val, _span)| val),
185 (Threaded, _) => {
186 let msg = if self.flavor.is_none() {
187 "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled."
188 } else {
189 "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature."
190 };
191 return Err(syn::Error::new(Span::call_site(), msg));
192 }
193 };
194
195 let start_paused = match (flavor, self.start_paused) {
196 (Threaded, Some((_, start_paused_span))) => {
197 let msg = format!(
198 "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`",
199 self.macro_name(),
200 );
201 return Err(syn::Error::new(start_paused_span, msg));
202 }
203 (CurrentThread, Some((start_paused, _))) => Some(start_paused),
204 (_, None) => None,
205 };
206
207 Ok(FinalConfig {
208 flavor,
209 worker_threads,
210 start_paused,
211 init: self.init.clone(),
212 tracing_span: self.tracing_span.clone(),
213 tracing_lib: self.tracing_lib.clone(),
214 })
215 }
216}
217
218fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
219 match int {
220 syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
221 Ok(value) => Ok(value),
222 Err(e) => Err(syn::Error::new(
223 span,
224 format!("Failed to parse value of `{}` as integer: {}", field, e),
225 )),
226 },
227 _ => Err(syn::Error::new(
228 span,
229 format!("Failed to parse value of `{}` as integer.", field),
230 )),
231 }
232}
233
234fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
235 match int {
236 syn::Lit::Str(s) => Ok(s.value()),
237 syn::Lit::Verbatim(s) => Ok(s.to_string()),
238 _ => Err(syn::Error::new(
239 span,
240 format!("Failed to parse value of `{}` as string.", field),
241 )),
242 }
243}
244
245fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
246 match bool {
247 syn::Lit::Bool(b) => Ok(b.value),
248 _ => Err(syn::Error::new(
249 span,
250 format!("Failed to parse value of `{}` as bool.", field),
251 )),
252 }
253}
254
255fn build_config(args: AttributeArgs, rt_multi_thread: bool) -> Result<FinalConfig, syn::Error> {
256 let mut config = Configuration::new(true, rt_multi_thread);
257 let macro_name = config.macro_name();
258
259 for arg in args {
260 match arg {
261 syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue)) => {
262 let ident = namevalue
263 .path
264 .get_ident()
265 .ok_or_else(|| syn::Error::new_spanned(&namevalue, "Must have specified ident"))?
266 .to_string()
267 .to_lowercase();
268 match ident.as_str() {
269 "worker_threads" => {
270 config
271 .set_worker_threads(namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit))?;
272 }
273 "flavor" => {
274 config.set_flavor(namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit))?;
275 }
276 "start_paused" => {
277 config.set_start_paused(namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit))?;
278 }
279 "init" => {
280 config.set_init(namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit))?;
281 }
282 "tracing_span" => {
283 config.set_tracing_span(namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit))?;
284 }
285 "tracing_lib" => {
286 config.set_tracing_lib(namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit))?;
287 }
288
289 name => {
290 let msg = format!(
291 "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `init`, `tracing_span`",
292 name,
293 );
294 return Err(syn::Error::new_spanned(namevalue, msg));
295 }
296 }
297 }
298 syn::NestedMeta::Meta(syn::Meta::Path(path)) => {
299 let name = path
300 .get_ident()
301 .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))?
302 .to_string()
303 .to_lowercase();
304 let msg = match name.as_str() {
305 "threaded_scheduler" | "multi_thread" => {
306 format!(
307 "Set the runtime flavor with #[{}(flavor = \"multi_thread\")].",
308 macro_name
309 )
310 }
311 "basic_scheduler" | "current_thread" | "single_threaded" => {
312 format!(
313 "Set the runtime flavor with #[{}(flavor = \"current_thread\")].",
314 macro_name
315 )
316 }
317 "flavor" | "worker_threads" | "start_paused" => {
318 format!("The `{}` attribute requires an argument.", name)
319 }
320 "init" => {
321 format!(
322 "The `{}` attribute requires an argument in string of the initializing statement to run.",
323 name
324 )
325 }
326 "tracing_span" => {
327 format!(
328 "The `{}` attribute requires an argument of level of the span, e.g. `debug` or `info`.",
329 name
330 )
331 }
332 "tracing_lib" => {
333 format!(
334 "The `{}` attribute requires an argument of level of the span, e.g. \"my_lib::\" or \"::\" or \"\".",
335 name
336 )
337 }
338 name => {
339 format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `init`, `tracing_span`, `tracing_lib`", name)
340 }
341 };
342 return Err(syn::Error::new_spanned(path, msg));
343 }
344 other => {
345 return Err(syn::Error::new_spanned(other, "Unknown attribute inside the macro"));
346 }
347 }
348 }
349
350 config.build()
351}
352
353type AttributeArgs = syn::punctuated::Punctuated<syn::NestedMeta, syn::Token![,]>;
354
355#[proc_macro_attribute]
489pub fn test(args: TokenStream, item: TokenStream) -> TokenStream {
490 let tokens = entry_test(args, item.clone());
491
492 match tokens {
493 Ok(x) => x,
494 Err(e) => token_stream_with_error(item, e),
495 }
496}
497
498fn entry_test(args: TokenStream, item: TokenStream) -> Result<TokenStream, syn::Error> {
500 let input = parse_item_fn(item)?;
501
502 let parsed_args = AttributeArgs::parse_terminated.parse(args)?;
504 let config = build_config(parsed_args, true)?;
505
506 let item = build_test_fn(input, config)?;
507 Ok(item)
508}
509
510fn build_test_fn(mut item_fn: ItemFn, config: FinalConfig) -> Result<TokenStream, syn::Error> {
513 item_fn.sig.asyncness = None;
514
515 let fn_name = item_fn.sig.ident.to_string();
516
517 let (last_stmt_start_span, last_stmt_end_span) = {
518 let mut last_stmt = item_fn.block.stmts.last().map(ToTokens::into_token_stream).unwrap_or_default().into_iter();
519 let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
524 let end = last_stmt.last().map_or(start, |t| t.span());
525 (start, end)
526 };
527
528 let test_attr = quote! { #[::core::prelude::v1::test] };
529
530 let rt = build_runtime(last_stmt_start_span, &config)?;
531
532 let init = if let Some(init) = config.init {
533 let init_str = format!("let _g = {};", init.0);
534 let init_tokens = str_to_p2tokens(&init_str, init.1)?;
535
536 quote! { #init_tokens }
537 } else {
538 quote! {}
539 };
540
541 let body_tracing_span = if let Some(tspan) = config.tracing_span {
542 let tracing_lib = if let Some(l) = config.tracing_lib {
543 l.0.clone()
544 } else {
545 "".to_string()
546 };
547
548 let level = tspan.0;
549 let add_tracing_span = format!(
550 r#"
551 use {} tracing::Instrument;
552 let body_span = {} tracing::{}_span!("{}");
553 let body = body.instrument(body_span);
554 "#,
555 tracing_lib, tracing_lib, level, fn_name
556 );
557
558 let tracing_span = str_to_p2tokens(&add_tracing_span, tspan.1)?;
559 quote! { #tracing_span }
560 } else {
561 quote! {}
562 };
563
564 let old_body = &item_fn.block;
565 let old_brace = old_body.brace_token;
566 let (tail_return, tail_semicolon) = match old_body.stmts.last() {
567 Some(syn::Stmt::Semi(syn::Expr::Return(_), _)) => (quote! { return }, quote! { ; }),
568 Some(syn::Stmt::Semi(..)) | Some(syn::Stmt::Local(..)) | None => {
569 match &item_fn.sig.output {
570 syn::ReturnType::Type(_, ty) if matches!(&**ty, syn::Type::Tuple(ty) if ty.elems.is_empty()) => {
571 (quote! {}, quote! { ; }) }
573 syn::ReturnType::Default => (quote! {}, quote! { ; }), syn::ReturnType::Type(..) => (quote! {}, quote! {}), }
576 }
577 _ => (quote! {}, quote! {}),
578 };
579
580 let body = quote_spanned! {last_stmt_end_span=>
583 {
584 #init
585
586 let body = async #old_body;
587
588 #body_tracing_span
589
590 #[allow(unused_mut)]
591 let mut rt = #rt;
592
593 #[allow(clippy::expect_used)]
594 #tail_return rt.block_on(body) #tail_semicolon
595
596 }
597 };
598
599 item_fn.block = syn::parse2(body).expect("parsing failure:::");
600 item_fn.block.brace_token = old_brace;
601
602 let res = quote! {
603 #test_attr
604 #item_fn
605 };
606
607 let x: TokenStream = res.into_token_stream().into();
608 Ok(x)
609}
610
611fn build_runtime(span: Span, config: &FinalConfig) -> Result<TokenStream2, syn::Error> {
614 let rt_builder = {
615 match get_runtime_name() {
616 "tokio" => {
617 let mut rt_builder = quote! { tokio::runtime::Builder };
618
619 rt_builder = match config.flavor {
620 RuntimeFlavor::CurrentThread => quote_spanned! {span=>
621 #rt_builder::new_current_thread()
622 },
623 RuntimeFlavor::Threaded => quote_spanned! {span=>
624 #rt_builder::new_multi_thread()
625 },
626 };
627
628 if let Some(v) = config.worker_threads {
629 rt_builder = quote! { #rt_builder.worker_threads(#v) };
630 }
631
632 if let Some(v) = config.start_paused {
633 rt_builder = quote! { #rt_builder.start_paused(#v) };
634 }
635 rt_builder
636 }
637 "monoio" => {
638 let rt_builder = quote! { monoio::RuntimeBuilder::<monoio::FusionDriver>::new() };
639 rt_builder
640 }
641 _ => unreachable!(),
642 }
643 };
644
645 let rt: TokenStream2 = quote! {
646 #rt_builder
647 .enable_all()
648 .build()
649 .expect("Failed building the Runtime")
650 };
651
652 Ok(rt)
653}
654
655fn parse_item_fn(item: TokenStream) -> Result<ItemFn, syn::Error> {
657 let input = syn::parse::<ItemFn>(item.clone())?;
658
659 if input.sig.asyncness.is_none() {
660 let msg = "the `async` keyword is missing from the function declaration";
661 return Err(syn::Error::new_spanned(input.sig.fn_token, msg));
662 }
663
664 check_dup_test_attr(&input)?;
665
666 Ok(input)
667}
668
669fn check_dup_test_attr(input: &ItemFn) -> Result<(), syn::Error> {
671 let mut attrs = input.attrs.iter();
672 let found = attrs.find(|a| a.path.is_ident("test"));
673 if let Some(attr) = found {
674 return Err(syn::Error::new_spanned(attr, "dup test"));
675 }
676
677 Ok(())
678}
679
680fn str_to_p2tokens(s: &str, span: Span) -> Result<proc_macro2::TokenStream, syn::Error> {
682 let toks = proc_macro2::TokenStream::from_str(s).map_err(|e| syn::Error::new(span, e))?;
683 Ok(toks)
684}
685
686fn token_stream_with_error(mut item: TokenStream, e: syn::Error) -> TokenStream {
687 item.extend(TokenStream::from(e.into_compile_error()));
688 item
689}