async_entry/
lib.rs

1use std::str::FromStr;
2
3use proc_macro::TokenStream;
4use proc_macro2::Span;
5use quote::quote;
6use quote::quote_spanned;
7use quote::ToTokens;
8use syn::__private::TokenStream2;
9use syn::parse::Parser;
10use syn::ItemFn;
11
/// Name of the async runtime the generated code targets, chosen by cargo features.
///
/// `monoio` is selected only when its feature is enabled and the `tokio` feature is not;
/// every other combination (including no runtime feature at all) selects `tokio`.
fn get_runtime_name() -> &'static str {
    let monoio_only = cfg!(feature = "monoio") && !cfg!(feature = "tokio");
    if monoio_only {
        "monoio"
    } else {
        "tokio"
    }
}
21
/// Kind of tokio runtime the generated test builds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum RuntimeFlavor {
    /// Single-threaded runtime, selected by `flavor = "current_thread"`.
    CurrentThread,
    /// Multi-threaded runtime, selected by `flavor = "multi_thread"`.
    Threaded,
}
27
28impl RuntimeFlavor {
29    fn from_str(s: &str) -> Result<RuntimeFlavor, String> {
30        match s {
31            "current_thread" => Ok(RuntimeFlavor::CurrentThread),
32            "multi_thread" => Ok(RuntimeFlavor::Threaded),
33            _ => Err(format!(
34                "No such runtime flavor `{}`. The runtime flavors are `current_thread` and `multi_thread`.",
35                s
36            )),
37        }
38    }
39}
40
/// Validated configuration produced by `Configuration::build` and consumed by code
/// generation (`build_test_fn` / `build_runtime`).
#[derive(Debug)]
struct FinalConfig {
    /// Runtime flavor to build.
    flavor: RuntimeFlavor,
    /// Worker thread count; only chained onto the builder when `Some`.
    worker_threads: Option<usize>,
    /// Value passed to the builder's `start_paused`; `build` only produces `Some` for the
    /// current-thread flavor.
    start_paused: Option<bool>,
    /// Source text of the `init` expression and the span of its literal.
    init: Option<(String, Span)>,
    /// Tracing level name used as `{level}_span!` and the span of its literal.
    tracing_span: Option<(String, Span)>,
    /// Path prefix pasted in front of `tracing::` and the span of its literal.
    tracing_lib: Option<(String, Span)>,
}
50
51struct Configuration {
52    rt_multi_thread_available: bool,
53    default_flavor: RuntimeFlavor,
54    flavor: Option<RuntimeFlavor>,
55    worker_threads: Option<(usize, Span)>,
56    start_paused: Option<(bool, Span)>,
57    is_test: bool,
58    init: Option<(String, Span)>,
59    tracing_span: Option<(String, Span)>,
60
61    /// Import `tracing` and `tracing_future` in another crate or mod, e.g. `tracing_lib::tracing`, instead of using `tracing`.
62    tracing_lib: Option<(String, Span)>,
63}
64
65impl Configuration {
66    fn new(is_test: bool, rt_multi_thread: bool) -> Self {
67        Configuration {
68            rt_multi_thread_available: rt_multi_thread,
69            default_flavor: match is_test {
70                true => RuntimeFlavor::CurrentThread,
71                false => RuntimeFlavor::Threaded,
72            },
73            flavor: None,
74            worker_threads: None,
75            start_paused: None,
76            is_test,
77            init: None,
78            tracing_span: None,
79            tracing_lib: None,
80        }
81    }
82
83    fn set_flavor(&mut self, runtime: syn::Lit, span: Span) -> Result<(), syn::Error> {
84        if self.flavor.is_some() {
85            return Err(syn::Error::new(span, "`flavor` set multiple times."));
86        }
87
88        let runtime_str = parse_string(runtime, span, "flavor")?;
89        let runtime = RuntimeFlavor::from_str(&runtime_str).map_err(|err| syn::Error::new(span, err))?;
90        self.flavor = Some(runtime);
91        Ok(())
92    }
93
94    fn set_init(&mut self, init_fn: syn::Lit, span: Span) -> Result<(), syn::Error> {
95        if self.init.is_some() {
96            return Err(syn::Error::new(span, "`init` set multiple times."));
97        }
98
99        let init_expr = parse_string(init_fn, span, "init")?;
100        self.init = Some((init_expr, span));
101
102        Ok(())
103    }
104
105    fn set_tracing_span(&mut self, level: syn::Lit, span: Span) -> Result<(), syn::Error> {
106        if self.tracing_span.is_some() {
107            return Err(syn::Error::new(span, "`tracing_span` set multiple times."));
108        }
109
110        let tracing_span = parse_string(level, span, "tracing_span")?;
111        self.tracing_span = Some((tracing_span, span));
112
113        Ok(())
114    }
115
116    // TODO: test internal, test-async-entry
117    fn set_tracing_lib(&mut self, level: syn::Lit, span: Span) -> Result<(), syn::Error> {
118        if self.tracing_lib.is_some() {
119            return Err(syn::Error::new(span, "`tracing_lib` set multiple times."));
120        }
121
122        let tracing_lib = parse_string(level, span, "tracing_lib")?;
123        self.tracing_lib = Some((tracing_lib, span));
124
125        Ok(())
126    }
127
128    fn set_worker_threads(&mut self, worker_threads: syn::Lit, span: Span) -> Result<(), syn::Error> {
129        if self.worker_threads.is_some() {
130            return Err(syn::Error::new(span, "`worker_threads` set multiple times."));
131        }
132
133        let worker_threads = parse_int(worker_threads, span, "worker_threads")?;
134        if worker_threads == 0 {
135            self.flavor = Some(RuntimeFlavor::CurrentThread);
136            self.worker_threads = None;
137        } else {
138            self.flavor = Some(RuntimeFlavor::Threaded);
139            self.worker_threads = Some((worker_threads, span));
140        }
141
142        Ok(())
143    }
144
145    fn set_start_paused(&mut self, start_paused: syn::Lit, span: Span) -> Result<(), syn::Error> {
146        if self.start_paused.is_some() {
147            return Err(syn::Error::new(span, "`start_paused` set multiple times."));
148        }
149
150        let start_paused = parse_bool(start_paused, span, "start_paused")?;
151        self.start_paused = Some((start_paused, span));
152        Ok(())
153    }
154
155    fn macro_name(&self) -> &'static str {
156        if self.is_test {
157            match get_runtime_name() {
158                "tokio" => "tokio::test",
159                "monoio" => "monoio::test",
160                _ => unreachable!(),
161            }
162        } else {
163            match get_runtime_name() {
164                "tokio" => "tokio::main",
165                "monoio" => "monoio::main",
166                _ => unreachable!(),
167            }
168        }
169    }
170
171    fn build(&self) -> Result<FinalConfig, syn::Error> {
172        let flavor = self.flavor.unwrap_or(self.default_flavor);
173        use RuntimeFlavor::*;
174
175        let worker_threads = match (flavor, self.worker_threads) {
176            (CurrentThread, Some((_, worker_threads_span))) => {
177                let msg = format!(
178                    "The `worker_threads` option requires the `multi_thread` runtime flavor. Use `#[{}(flavor = \"multi_thread\")]`",
179                    self.macro_name(),
180                );
181                return Err(syn::Error::new(worker_threads_span, msg));
182            }
183            (CurrentThread, None) => None,
184            (Threaded, worker_threads) if self.rt_multi_thread_available => worker_threads.map(|(val, _span)| val),
185            (Threaded, _) => {
186                let msg = if self.flavor.is_none() {
187                    "The default runtime flavor is `multi_thread`, but the `rt-multi-thread` feature is disabled."
188                } else {
189                    "The runtime flavor `multi_thread` requires the `rt-multi-thread` feature."
190                };
191                return Err(syn::Error::new(Span::call_site(), msg));
192            }
193        };
194
195        let start_paused = match (flavor, self.start_paused) {
196            (Threaded, Some((_, start_paused_span))) => {
197                let msg = format!(
198                    "The `start_paused` option requires the `current_thread` runtime flavor. Use `#[{}(flavor = \"current_thread\")]`",
199                    self.macro_name(),
200                );
201                return Err(syn::Error::new(start_paused_span, msg));
202            }
203            (CurrentThread, Some((start_paused, _))) => Some(start_paused),
204            (_, None) => None,
205        };
206
207        Ok(FinalConfig {
208            flavor,
209            worker_threads,
210            start_paused,
211            init: self.init.clone(),
212            tracing_span: self.tracing_span.clone(),
213            tracing_lib: self.tracing_lib.clone(),
214        })
215    }
216}
217
218fn parse_int(int: syn::Lit, span: Span, field: &str) -> Result<usize, syn::Error> {
219    match int {
220        syn::Lit::Int(lit) => match lit.base10_parse::<usize>() {
221            Ok(value) => Ok(value),
222            Err(e) => Err(syn::Error::new(
223                span,
224                format!("Failed to parse value of `{}` as integer: {}", field, e),
225            )),
226        },
227        _ => Err(syn::Error::new(
228            span,
229            format!("Failed to parse value of `{}` as integer.", field),
230        )),
231    }
232}
233
234fn parse_string(int: syn::Lit, span: Span, field: &str) -> Result<String, syn::Error> {
235    match int {
236        syn::Lit::Str(s) => Ok(s.value()),
237        syn::Lit::Verbatim(s) => Ok(s.to_string()),
238        _ => Err(syn::Error::new(
239            span,
240            format!("Failed to parse value of `{}` as string.", field),
241        )),
242    }
243}
244
245fn parse_bool(bool: syn::Lit, span: Span, field: &str) -> Result<bool, syn::Error> {
246    match bool {
247        syn::Lit::Bool(b) => Ok(b.value),
248        _ => Err(syn::Error::new(
249            span,
250            format!("Failed to parse value of `{}` as bool.", field),
251        )),
252    }
253}
254
255fn build_config(args: AttributeArgs, rt_multi_thread: bool) -> Result<FinalConfig, syn::Error> {
256    let mut config = Configuration::new(true, rt_multi_thread);
257    let macro_name = config.macro_name();
258
259    for arg in args {
260        match arg {
261            syn::NestedMeta::Meta(syn::Meta::NameValue(namevalue)) => {
262                let ident = namevalue
263                    .path
264                    .get_ident()
265                    .ok_or_else(|| syn::Error::new_spanned(&namevalue, "Must have specified ident"))?
266                    .to_string()
267                    .to_lowercase();
268                match ident.as_str() {
269                    "worker_threads" => {
270                        config
271                            .set_worker_threads(namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit))?;
272                    }
273                    "flavor" => {
274                        config.set_flavor(namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit))?;
275                    }
276                    "start_paused" => {
277                        config.set_start_paused(namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit))?;
278                    }
279                    "init" => {
280                        config.set_init(namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit))?;
281                    }
282                    "tracing_span" => {
283                        config.set_tracing_span(namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit))?;
284                    }
285                    "tracing_lib" => {
286                        config.set_tracing_lib(namevalue.lit.clone(), syn::spanned::Spanned::span(&namevalue.lit))?;
287                    }
288
289                    name => {
290                        let msg = format!(
291                            "Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `init`, `tracing_span`",
292                            name,
293                        );
294                        return Err(syn::Error::new_spanned(namevalue, msg));
295                    }
296                }
297            }
298            syn::NestedMeta::Meta(syn::Meta::Path(path)) => {
299                let name = path
300                    .get_ident()
301                    .ok_or_else(|| syn::Error::new_spanned(&path, "Must have specified ident"))?
302                    .to_string()
303                    .to_lowercase();
304                let msg = match name.as_str() {
305                    "threaded_scheduler" | "multi_thread" => {
306                        format!(
307                            "Set the runtime flavor with #[{}(flavor = \"multi_thread\")].",
308                            macro_name
309                        )
310                    }
311                    "basic_scheduler" | "current_thread" | "single_threaded" => {
312                        format!(
313                            "Set the runtime flavor with #[{}(flavor = \"current_thread\")].",
314                            macro_name
315                        )
316                    }
317                    "flavor" | "worker_threads" | "start_paused" => {
318                        format!("The `{}` attribute requires an argument.", name)
319                    }
320                    "init" => {
321                        format!(
322                            "The `{}` attribute requires an argument in string of the initializing statement to run.",
323                            name
324                        )
325                    }
326                    "tracing_span" => {
327                        format!(
328                            "The `{}` attribute requires an argument of level of the span, e.g. `debug` or `info`.",
329                            name
330                        )
331                    }
332                    "tracing_lib" => {
333                        format!(
334                            "The `{}` attribute requires an argument of level of the span, e.g. \"my_lib::\" or \"::\" or \"\".",
335                            name
336                        )
337                    }
338                    name => {
339                        format!("Unknown attribute {} is specified; expected one of: `flavor`, `worker_threads`, `start_paused`, `init`, `tracing_span`, `tracing_lib`", name)
340                    }
341                };
342                return Err(syn::Error::new_spanned(path, msg));
343            }
344            other => {
345                return Err(syn::Error::new_spanned(other, "Unknown attribute inside the macro"));
346            }
347        }
348    }
349
350    config.build()
351}
352
353type AttributeArgs = syn::punctuated::Punctuated<syn::NestedMeta, syn::Token![,]>;
354
355/// Marks async function to be executed by async runtime, suitable to test environment
356///
357/// It supports:
358/// - [tokio](https://tokio.rs/)
359/// - [monoio](https://github.com/bytedance/monoio)
360///
361/// By default it uses `tokio` runtime. Switch runtime with feature flags:
362/// - `tokio`: tokio runtime;
363/// - `monoio`: monoio runtime;
364///
365/// ## Usage for tokio runtime
366///
367/// ### Multi-thread runtime
368///
369/// ```no_run
370/// #[async_entry::test(flavor = "multi_thread", worker_threads = 1)]
371/// async fn my_test() {
372///     assert!(true);
373/// }
374/// ```
375///
376/// `worker_threads>0` implies `flavor="multi_thread"`.
377/// `worker_threads==0` implies `flavor="current_thread"`.
378///
379/// ### Using default
380///
381/// The default test runtime is single-threaded.
382///
383/// ```no_run
384/// #[async_entry::test]
385/// async fn my_test() {
386///     assert!(true);
387/// }
388/// ```
389///
390/// ### Configure the runtime to start with time paused
391///
392/// ```no_run
393/// #[async_entry::test(start_paused = true)]
394/// async fn my_test() {
395///     assert!(true);
396/// }
397/// ```
398///
399/// Note that `start_paused` requires the `test-util` feature to be enabled.
400///
401/// ### Add initialization statement
402///
403/// ```no_run
404/// #[async_entry::test(init = "init_log!()")]
405/// async fn my_test() {
406///     assert!(true);
407/// }
408/// // Will produce:
409/// //
410/// // fn my_test() {
411/// //
412/// //     let _g = init_log!();  // Add init statement
413/// //
414/// //     let body = async { assert!(true); };
415/// //     let rt = ...
416/// //     rt.block_on(body);
417/// // }
418/// ```
419///
420/// ### Add tracing span over the test fn
421///
422/// ```no_run
423/// #[async_entry::test(tracing_span = "info")]
424/// async fn my_test() {
425///     assert!(true);
426/// }
427/// // Will produce:
428/// //
429/// // fn my_test() {
430/// //     let body = async { assert!(true); };
431/// //
432/// //     use ::tracing::Instrument;                       // Add tracing span
433/// //     let body_span = ::tracing::info_span("my_test"); //
434/// //     let body = body.instrument(body_span);           //
435/// //
436/// //     let rt = ...
437/// //     rt.block_on(body);
438/// // }
439/// ```
440///
441/// ### Use other lib to import `tracing` and `tracing_future`
442///
443/// ```no_run
444/// #[async_entry::test(tracing_span = "info" ,tracing_lib="my_lib::")]
445/// async fn my_test() {
446///     assert!(true);
447/// }
448/// // Will produce:
449/// //
450/// // fn my_test() {
451/// //     let body = async { assert!(true); };
452/// //
453/// //     use my_lib::tracing::Instrument;                         // Add tracing span
454/// //     let body_span = my_lib::tracing::info_span("my_test");   //
455/// //     let body = body.instrument(body_span);                   //
456/// //
457/// //     let rt = ...
458/// //     rt.block_on(body);
459/// // }
460/// ```
461///
462/// ## Usage for monoio runtime
463///
464/// **When using `monoio` runtime with feature flag `monoio` enabled:
465/// `flavor`, `worker_threads` and `start_paused` are ignored**.
466///
467/// It is the same as using `tokio` runtime, except the runtime is `monoio`:
468///
469/// ```no_run
470/// #[async_entry::test()]
471/// async fn my_test() {
472///     assert!(true);
473/// }
474/// // Will produce:
475/// //
476/// // fn my_test() {
477/// //     // ...
478/// //
479/// //     let body = async { assert!(true); };
480/// //     let rt = monoio::RuntimeBuilder::<_>::new()...
481/// //     rt.block_on(body);
482/// // }
483/// ```
484///
485/// ### NOTE:
486///
487/// If you rename the async_entry crate in your dependencies this macro will not work.
488#[proc_macro_attribute]
489pub fn test(args: TokenStream, item: TokenStream) -> TokenStream {
490    let tokens = entry_test(args, item.clone());
491
492    match tokens {
493        Ok(x) => x,
494        Err(e) => token_stream_with_error(item, e),
495    }
496}
497
498/// Entry of async test fn.
499fn entry_test(args: TokenStream, item: TokenStream) -> Result<TokenStream, syn::Error> {
500    let input = parse_item_fn(item)?;
501
502    // parse attribute arguments in the parentheses:
503    let parsed_args = AttributeArgs::parse_terminated.parse(args)?;
504    let config = build_config(parsed_args, true)?;
505
506    let item = build_test_fn(input, config)?;
507    Ok(item)
508}
509
510// Copied from tokio-macros:
511// `fn parse_knobs()` in `tokio/tokio-macros/src/entry.rs`.
512fn build_test_fn(mut item_fn: ItemFn, config: FinalConfig) -> Result<TokenStream, syn::Error> {
513    item_fn.sig.asyncness = None;
514
515    let fn_name = item_fn.sig.ident.to_string();
516
517    let (last_stmt_start_span, last_stmt_end_span) = {
518        let mut last_stmt = item_fn.block.stmts.last().map(ToTokens::into_token_stream).unwrap_or_default().into_iter();
519        // `Span` on stable Rust has a limitation that only points to the first
520        // token, not the whole tokens. We can work around this limitation by
521        // using the first/last span of the tokens like
522        // `syn::Error::new_spanned` does.
523        let start = last_stmt.next().map_or_else(Span::call_site, |t| t.span());
524        let end = last_stmt.last().map_or(start, |t| t.span());
525        (start, end)
526    };
527
528    let test_attr = quote! { #[::core::prelude::v1::test] };
529
530    let rt = build_runtime(last_stmt_start_span, &config)?;
531
532    let init = if let Some(init) = config.init {
533        let init_str = format!("let _g = {};", init.0);
534        let init_tokens = str_to_p2tokens(&init_str, init.1)?;
535
536        quote! { #init_tokens }
537    } else {
538        quote! {}
539    };
540
541    let body_tracing_span = if let Some(tspan) = config.tracing_span {
542        let tracing_lib = if let Some(l) = config.tracing_lib {
543            l.0.clone()
544        } else {
545            "".to_string()
546        };
547
548        let level = tspan.0;
549        let add_tracing_span = format!(
550            r#"
551            use {} tracing::Instrument;
552            let body_span = {} tracing::{}_span!("{}");
553            let body = body.instrument(body_span);
554        "#,
555            tracing_lib, tracing_lib, level, fn_name
556        );
557
558        let tracing_span = str_to_p2tokens(&add_tracing_span, tspan.1)?;
559        quote! { #tracing_span }
560    } else {
561        quote! {}
562    };
563
564    let old_body = &item_fn.block;
565    let old_brace = old_body.brace_token;
566    let (tail_return, tail_semicolon) = match old_body.stmts.last() {
567        Some(syn::Stmt::Semi(syn::Expr::Return(_), _)) => (quote! { return }, quote! { ; }),
568        Some(syn::Stmt::Semi(..)) | Some(syn::Stmt::Local(..)) | None => {
569            match &item_fn.sig.output {
570                syn::ReturnType::Type(_, ty) if matches!(&**ty, syn::Type::Tuple(ty) if ty.elems.is_empty()) => {
571                    (quote! {}, quote! { ; }) // unit
572                }
573                syn::ReturnType::Default => (quote! {}, quote! { ; }), // unit
574                syn::ReturnType::Type(..) => (quote! {}, quote! {}),   // ! or another
575            }
576        }
577        _ => (quote! {}, quote! {}),
578    };
579
580    // Assemble new body
581
582    let body = quote_spanned! {last_stmt_end_span=>
583        {
584            #init
585
586            let body = async #old_body;
587
588            #body_tracing_span
589
590            #[allow(unused_mut)]
591            let mut rt = #rt;
592
593            #[allow(clippy::expect_used)]
594            #tail_return rt.block_on(body) #tail_semicolon
595
596        }
597    };
598
599    item_fn.block = syn::parse2(body).expect("parsing failure:::");
600    item_fn.block.brace_token = old_brace;
601
602    let res = quote! {
603        #test_attr
604        #item_fn
605    };
606
607    let x: TokenStream = res.into_token_stream().into();
608    Ok(x)
609}
610
611/// Build a statement that builds a async runtime,
612/// e.g. `let rt = Builder::new_multi_thread().build().expect("");`
613fn build_runtime(span: Span, config: &FinalConfig) -> Result<TokenStream2, syn::Error> {
614    let rt_builder = {
615        match get_runtime_name() {
616            "tokio" => {
617                let mut rt_builder = quote! { tokio::runtime::Builder };
618
619                rt_builder = match config.flavor {
620                    RuntimeFlavor::CurrentThread => quote_spanned! {span=>
621                        #rt_builder::new_current_thread()
622                    },
623                    RuntimeFlavor::Threaded => quote_spanned! {span=>
624                        #rt_builder::new_multi_thread()
625                    },
626                };
627
628                if let Some(v) = config.worker_threads {
629                    rt_builder = quote! { #rt_builder.worker_threads(#v) };
630                }
631
632                if let Some(v) = config.start_paused {
633                    rt_builder = quote! { #rt_builder.start_paused(#v) };
634                }
635                rt_builder
636            }
637            "monoio" => {
638                let rt_builder = quote! { monoio::RuntimeBuilder::<monoio::FusionDriver>::new() };
639                rt_builder
640            }
641            _ => unreachable!(),
642        }
643    };
644
645    let rt: TokenStream2 = quote! {
646        #rt_builder
647            .enable_all()
648            .build()
649            .expect("Failed building the Runtime")
650    };
651
652    Ok(rt)
653}
654
655/// Parse TokenStream of some fn
656fn parse_item_fn(item: TokenStream) -> Result<ItemFn, syn::Error> {
657    let input = syn::parse::<ItemFn>(item.clone())?;
658
659    if input.sig.asyncness.is_none() {
660        let msg = "the `async` keyword is missing from the function declaration";
661        return Err(syn::Error::new_spanned(input.sig.fn_token, msg));
662    }
663
664    check_dup_test_attr(&input)?;
665
666    Ok(input)
667}
668
669/// Check if there is already a `#[test]` attribute for input function
670fn check_dup_test_attr(input: &ItemFn) -> Result<(), syn::Error> {
671    let mut attrs = input.attrs.iter();
672    let found = attrs.find(|a| a.path.is_ident("test"));
673    if let Some(attr) = found {
674        return Err(syn::Error::new_spanned(attr, "dup test"));
675    }
676
677    Ok(())
678}
679
680/// Parse rust source code in str and produce a TokenStream
681fn str_to_p2tokens(s: &str, span: Span) -> Result<proc_macro2::TokenStream, syn::Error> {
682    let toks = proc_macro2::TokenStream::from_str(s).map_err(|e| syn::Error::new(span, e))?;
683    Ok(toks)
684}
685
686fn token_stream_with_error(mut item: TokenStream, e: syn::Error) -> TokenStream {
687    item.extend(TokenStream::from(e.into_compile_error()));
688    item
689}