1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::parse::{Parse, ParseStream};
5use syn::{ItemFn, Path, Token, parse_macro_input};
6
7struct MainArgs {
8 config: Path,
9}
10
/// Error text used when the attribute has no arguments or its key is not `config`.
const MISSING_CONFIG_HELP: &str = concat!(
    "missing required `config = <fn>` argument, ",
    "e.g. #[dial9_tokio_telemetry::main(config = my_config)]"
);

/// Error text used when any tokens follow the `config` path (call arguments,
/// additional keys, ...).
const CONFIG_MUST_BE_ZERO_ARG_HELP: &str = concat!(
    "`config` must be a path to a zero-argument function, ",
    "e.g. #[dial9_tokio_telemetry::main(config = my_config)]"
);

16impl Parse for MainArgs {
17 fn parse(input: ParseStream) -> syn::Result<Self> {
18 if input.is_empty() {
19 return Err(input.error(MISSING_CONFIG_HELP));
20 }
21 let ident: syn::Ident = input.parse()?;
22 if ident != "config" {
23 return Err(syn::Error::new(ident.span(), MISSING_CONFIG_HELP));
24 }
25 input.parse::<Token![=]>()?;
26 let config: Path = input.parse()?;
27 if !input.is_empty() {
28 return Err(input.error(CONFIG_MUST_BE_ZERO_ARG_HELP));
29 }
30 Ok(MainArgs { config })
31 }
32}
33
34fn expand_main(args: MainArgs, input: ItemFn) -> Result<TokenStream2, syn::Error> {
35 if input.sig.asyncness.is_none() {
36 return Err(syn::Error::new_spanned(
37 input.sig.fn_token,
38 "the `async` keyword is missing from the function declaration",
39 ));
40 }
41
42 if !input.sig.inputs.is_empty() {
43 return Err(syn::Error::new_spanned(
44 &input.sig.inputs,
45 "#[dial9_tokio_telemetry::main] does not support function arguments",
46 ));
47 }
48
49 if !input.sig.generics.params.is_empty() {
50 return Err(syn::Error::new_spanned(
51 &input.sig.generics,
52 "#[dial9_tokio_telemetry::main] does not support generics",
53 ));
54 }
55
56 if input.sig.generics.where_clause.is_some() {
57 return Err(syn::Error::new_spanned(
58 &input.sig.generics.where_clause,
59 "#[dial9_tokio_telemetry::main] does not support where clauses",
60 ));
61 }
62
63 let config_fn = &args.config;
64 let attrs = &input.attrs;
65 let vis = &input.vis;
66 let name = &input.sig.ident;
67 let ret = &input.sig.output;
68 let body_stmts = &input.block.stmts;
69
70 Ok(quote! {
71 #(#attrs)*
72 #vis fn #name() #ret {
73 let (__tokio_runtime, __maybe_guard) = #config_fn()
74 .build()
75 .expect("failed to initialize runtime");
76 if let Some(__dial9_guard) = __maybe_guard {
77 let __dial9_handle = __dial9_guard.handle();
78 __tokio_runtime.block_on(async move {
79 match __dial9_handle.spawn(async move { #(#body_stmts)* }).await {
80 Ok(output) => output,
81 Err(err) if err.is_panic() => {
82 ::std::panic::resume_unwind(err.into_panic())
83 }
84 Err(_) => unreachable!("task cannot be cancelled inside block_on"),
85 }
86 })
87 } else {
88 __tokio_runtime.block_on(async move { #(#body_stmts)* })
89 }
90 }
91 })
92}
93
94#[proc_macro_attribute]
132pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream {
133 let args = parse_macro_input!(attr as MainArgs);
134 let input = parse_macro_input!(item as ItemFn);
135
136 match expand_main(args, input) {
137 Ok(tokens) => tokens.into(),
138 Err(err) => err.to_compile_error().into(),
139 }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    /// Parses `attr` and `item`, expands them, and pretty-prints the result.
    /// Panics if any step fails.
    fn expand(attr: TokenStream2, item: TokenStream2) -> String {
        let args: MainArgs = syn::parse2(attr).expect("failed to parse args");
        let input: ItemFn = syn::parse2(item).expect("failed to parse fn");
        let expanded = expand_main(args, input).expect("expansion failed");
        let file = syn::parse2(expanded).expect("failed to parse expansion");
        prettyplease::unparse(&file)
    }

    /// Parses `attr` and `item` and returns the rendered expansion error.
    /// Panics if expansion unexpectedly succeeds.
    fn expand_err(attr: TokenStream2, item: TokenStream2) -> String {
        let args: MainArgs = syn::parse2(attr).expect("failed to parse args");
        let input: ItemFn = syn::parse2(item).expect("failed to parse fn");
        expand_main(args, input)
            .expect_err("expected error")
            .to_string()
    }

    /// Parses `attr` as `MainArgs` and returns the rendered parse error.
    /// Panics if parsing unexpectedly succeeds.
    fn parse_args_err(attr: TokenStream2) -> String {
        match syn::parse2::<MainArgs>(attr) {
            Err(e) => e.to_string(),
            Ok(_) => panic!("expected parse error"),
        }
    }

    // Snapshot tests: insta names each snapshot after its test function, so these
    // function names must stay in sync with the stored `.snap` files.

    #[test]
    fn expand_basic() {
        let output = expand(
            quote! { config = my_config },
            quote! {
                async fn main() {
                    do_work().await;
                }
            },
        );
        insta::assert_snapshot!(output);
    }

    #[test]
    fn expand_with_return_type() {
        let output = expand(
            quote! { config = my_config },
            quote! {
                async fn main() -> Result<(), Box<dyn std::error::Error>> {
                    do_work().await?;
                    Ok(())
                }
            },
        );
        insta::assert_snapshot!(output);
    }

    #[test]
    fn expand_with_attributes() {
        let output = expand(
            quote! { config = my_config },
            quote! {
                #[allow(unused)]
                async fn main() {
                    let _ = 42;
                }
            },
        );
        insta::assert_snapshot!(output);
    }

    #[test]
    fn expand_keeps_visibility_and_config_path() {
        let output = expand(
            quote! { config = crate::telemetry::build_config },
            quote! {
                pub async fn run() {}
            },
        );
        assert!(output.contains("pub fn run()"), "unexpected expansion: {output}");
        assert!(
            output.contains("crate::telemetry::build_config()"),
            "unexpected expansion: {output}"
        );
        assert!(
            !output.contains("async fn"),
            "`async` should be stripped from the signature: {output}"
        );
    }

    #[test]
    fn error_with_arguments() {
        let msg = expand_err(
            quote! { config = my_config },
            quote! { async fn main(port: u16) {} },
        );
        assert!(
            msg.contains("does not support function arguments"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn error_with_generics() {
        let msg = expand_err(
            quote! { config = my_config },
            quote! { async fn main<T>() {} },
        );
        assert!(
            msg.contains("does not support generics"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn error_with_where_clause() {
        let msg = expand_err(
            quote! { config = my_config },
            quote! { async fn main() where u8: Copy {} },
        );
        assert!(
            msg.contains("does not support where clauses"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn error_not_async() {
        let msg = expand_err(quote! { config = my_config }, quote! { fn main() {} });
        assert!(msg.contains("async"), "error should mention async: {msg}");
    }

    #[test]
    fn error_empty_args() {
        let msg = parse_args_err(quote! {});
        assert!(msg.contains("config = <fn>"), "unexpected error: {msg}");
    }

    #[test]
    fn error_wrong_arg_name() {
        let msg = parse_args_err(quote! { foo = bar });
        assert!(msg.contains("config = <fn>"), "unexpected error: {msg}");
    }

    #[test]
    fn error_config_with_args() {
        let msg = parse_args_err(quote! { config = my_config(arg) });
        assert!(
            msg.contains("zero-argument function"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn error_config_trailing_tokens() {
        let msg = parse_args_err(quote! { config = my_config, extra = stuff });
        assert!(
            msg.contains("zero-argument function"),
            "unexpected error: {msg}"
        );
    }
}