1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::parse::{Parse, ParseStream};
5use syn::{ExprClosure, ItemFn, Path, Token, parse_macro_input};
6
/// The two forms the `config = ...` attribute argument may take.
///
/// During expansion both forms are invoked with no arguments to produce the
/// value handed to the runtime constructor.
enum ConfigSource {
    /// A path to a function, e.g. `config = my_config_fn`; expanded as `path()`.
    Path(Path),
    /// An inline closure, e.g. `config = || ...`; expanded as `(closure)()`.
    Closure(ExprClosure),
}
11
/// Parsed arguments of the `main` attribute macro.
struct MainArgs {
    /// Source of the configuration value, taken from the `config = ...` argument.
    config: ConfigSource,
}
15
/// Help text reported when the attribute has no arguments at all or when its
/// first identifier is not `config`.
///
/// A trailing `\` inside the literal continues the string on the next source
/// line and skips that line's leading whitespace.
const MISSING_CONFIG_HELP: &str = "missing required `config` argument, e.g.\n \
 #[dial9_tokio_telemetry::main(config = my_config_fn)]\n\
 or with an inline closure:\n \
 #[dial9_tokio_telemetry::main(config = || Dial9Config::builder().base_path(...).max_file_size(...).max_total_size(...).build().unwrap())]";
20
/// Help text reported when the `config` value is not acceptable, e.g. a closure
/// that declares parameters, or tokens left over after the value.
///
/// A trailing `\` inside the literal continues the string on the next source
/// line and skips that line's leading whitespace.
const CONFIG_MUST_BE_ZERO_ARG_HELP: &str = "`config` must be a zero-argument function path or a zero-argument closure, e.g.\n \
 #[dial9_tokio_telemetry::main(config = my_config_fn)]\n\
 or with an inline closure:\n \
 #[dial9_tokio_telemetry::main(config = || Dial9Config::builder().base_path(...).max_file_size(...).max_total_size(...).build().unwrap())]";
25impl Parse for MainArgs {
26 fn parse(input: ParseStream) -> syn::Result<Self> {
27 if input.is_empty() {
28 return Err(input.error(MISSING_CONFIG_HELP));
29 }
30 let ident: syn::Ident = input.parse()?;
31 if ident != "config" {
32 return Err(syn::Error::new(ident.span(), MISSING_CONFIG_HELP));
33 }
34 input.parse::<Token![=]>()?;
35
36 let config = if input.peek(Token![|]) || input.peek(Token![move]) {
37 let closure: ExprClosure = input.parse()?;
38 if !closure.inputs.is_empty() {
39 return Err(syn::Error::new_spanned(
40 &closure.inputs,
41 CONFIG_MUST_BE_ZERO_ARG_HELP,
42 ));
43 }
44 ConfigSource::Closure(closure)
45 } else {
46 ConfigSource::Path(input.parse()?)
47 };
48
49 if !input.is_empty() {
50 return Err(input.error(CONFIG_MUST_BE_ZERO_ARG_HELP));
51 }
52 Ok(MainArgs { config })
53 }
54}
55
56fn expand_main(args: MainArgs, input: ItemFn) -> Result<TokenStream2, syn::Error> {
57 if input.sig.asyncness.is_none() {
58 return Err(syn::Error::new_spanned(
59 input.sig.fn_token,
60 "the `async` keyword is missing from the function declaration",
61 ));
62 }
63
64 if !input.sig.inputs.is_empty() {
65 return Err(syn::Error::new_spanned(
66 &input.sig.inputs,
67 "#[dial9_tokio_telemetry::main] does not support function arguments",
68 ));
69 }
70
71 if !input.sig.generics.params.is_empty() {
72 return Err(syn::Error::new_spanned(
73 &input.sig.generics,
74 "#[dial9_tokio_telemetry::main] does not support generics",
75 ));
76 }
77
78 if input.sig.generics.where_clause.is_some() {
79 return Err(syn::Error::new_spanned(
80 &input.sig.generics.where_clause,
81 "#[dial9_tokio_telemetry::main] does not support where clauses",
82 ));
83 }
84
85 let config_call = match &args.config {
86 ConfigSource::Path(p) => quote! { #p() },
87 ConfigSource::Closure(c) => quote! { (#c)() },
88 };
89 let attrs = &input.attrs;
90 let vis = &input.vis;
91 let name = &input.sig.ident;
92 let ret = &input.sig.output;
93 let body_stmts = &input.block.stmts;
94
95 Ok(quote! {
96 #(#attrs)*
97 #vis fn #name() #ret {
98 let __dial9_rt = ::dial9_tokio_telemetry::TracedRuntime::new(#config_call);
99 __dial9_rt.block_on(async move { #(#body_stmts)* })
100 }
101 })
102}
103
104#[proc_macro_attribute]
209pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream {
210 let args = parse_macro_input!(attr as MainArgs);
211 let input = parse_macro_input!(item as ItemFn);
212
213 match expand_main(args, input) {
214 Ok(tokens) => tokens.into(),
215 Err(err) => err.to_compile_error().into(),
216 }
217}
218
// Unit tests for argument parsing and expansion. They call `MainArgs` parsing
// and `expand_main` directly on `proc_macro2` token streams.
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    /// Parses `attr` and `item`, expands them, and returns the generated code
    /// pretty-printed with `prettyplease`. Panics if any step fails.
    fn expand(attr: TokenStream2, item: TokenStream2) -> String {
        let args: MainArgs = syn::parse2(attr).expect("failed to parse args");
        let input: ItemFn = syn::parse2(item).expect("failed to parse fn");
        let expanded = expand_main(args, input).expect("expansion failed");
        let file = syn::parse2(expanded).expect("failed to parse expansion");
        prettyplease::unparse(&file)
    }

    // Snapshot: `config` given as a function path, body without a return type.
    #[test]
    fn expand_basic() {
        let output = expand(
            quote! { config = my_config },
            quote! {
                async fn main() {
                    do_work().await;
                }
            },
        );
        insta::assert_snapshot!(output);
    }

    // Snapshot: an explicit return type and `?` in the body.
    #[test]
    fn expand_with_return_type() {
        let output = expand(
            quote! { config = my_config },
            quote! {
                async fn main() -> Result<(), Box<dyn std::error::Error>> {
                    do_work().await?;
                    Ok(())
                }
            },
        );
        insta::assert_snapshot!(output);
    }

    // Snapshot: an attribute on the annotated function.
    #[test]
    fn expand_with_attributes() {
        let output = expand(
            quote! { config = my_config },
            quote! {
                #[allow(unused)]
                async fn main() {
                    let _ = 42;
                }
            },
        );
        insta::assert_snapshot!(output);
    }

    /// Parses `attr` and `item` (both must parse), then returns the message of
    /// the error produced by `expand_main`. Panics if expansion succeeds.
    fn expand_err(attr: TokenStream2, item: TokenStream2) -> String {
        let args: MainArgs = syn::parse2(attr).expect("failed to parse args");
        let input: ItemFn = syn::parse2(item).expect("failed to parse fn");
        expand_main(args, input)
            .expect_err("expected error")
            .to_string()
    }

    // A function with parameters must be rejected during expansion.
    #[test]
    fn error_with_arguments() {
        let msg = expand_err(
            quote! { config = my_config },
            quote! { async fn main(port: u16) {} },
        );
        assert!(
            msg.contains("does not support function arguments"),
            "unexpected error: {msg}"
        );
    }

    // A function with generic parameters must be rejected during expansion.
    #[test]
    fn error_with_generics() {
        let msg = expand_err(
            quote! { config = my_config },
            quote! { async fn main<T>() {} },
        );
        assert!(
            msg.contains("does not support generics"),
            "unexpected error: {msg}"
        );
    }

    /// Returns the message of the error produced when parsing `attr` as
    /// `MainArgs`. Panics if parsing succeeds.
    fn parse_args_err(attr: TokenStream2) -> String {
        match syn::parse2::<MainArgs>(attr) {
            Err(e) => e.to_string(),
            Ok(_) => panic!("expected parse error"),
        }
    }

    // No attribute arguments at all: the missing-`config` help is expected.
    #[test]
    fn error_empty_args() {
        let msg = parse_args_err(quote! {});
        assert!(
            msg.contains("missing required `config`"),
            "unexpected error: {msg}"
        );
    }

    // An argument name other than `config`: the missing-`config` help is expected.
    #[test]
    fn error_wrong_arg_name() {
        let msg = parse_args_err(quote! { foo = bar });
        assert!(
            msg.contains("missing required `config`"),
            "unexpected error: {msg}"
        );
    }

    // A call expression instead of a bare path: the zero-argument help is expected.
    #[test]
    fn error_config_with_args() {
        let msg = parse_args_err(quote! { config = my_config(arg) });
        assert!(msg.contains("zero-argument"), "unexpected error: {msg}");
    }

    // Extra arguments after the config value: the zero-argument help is expected.
    #[test]
    fn error_config_trailing_tokens() {
        let msg = parse_args_err(quote! { config = my_config, extra = stuff });
        assert!(msg.contains("zero-argument"), "unexpected error: {msg}");
    }

    // Snapshot: `config` given as an inline zero-argument closure.
    #[test]
    fn expand_with_inline_closure() {
        let output = expand(
            quote! { config = || my_config() },
            quote! {
                async fn main() {
                    do_work().await;
                }
            },
        );
        insta::assert_snapshot!(output);
    }

    // Snapshot: `config` given as a `move` closure.
    #[test]
    fn expand_with_move_closure() {
        let output = expand(
            quote! { config = move || my_config() },
            quote! {
                async fn main() {
                    do_work().await;
                }
            },
        );
        insta::assert_snapshot!(output);
    }

    // A closure that declares a parameter: the zero-argument help is expected.
    #[test]
    fn error_closure_with_args() {
        let msg = parse_args_err(quote! { config = |x| my_config() });
        assert!(msg.contains("zero-argument"), "unexpected error: {msg}");
    }

    // A non-`async` function must be rejected with a message mentioning `async`.
    #[test]
    fn error_not_async() {
        let args: MainArgs =
            syn::parse2(quote! { config = my_config }).expect("failed to parse args");
        let input: ItemFn = syn::parse2(quote! {
            fn main() {}
        })
        .expect("failed to parse fn");
        let err = expand_main(args, input).expect_err("expected error for non-async fn");
        let msg = err.to_string();
        assert!(msg.contains("async"), "error should mention async: {msg}");
    }
}