Skip to main content

dial9_macro/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::parse::{Parse, ParseStream};
5use syn::{ItemFn, Path, Token, parse_macro_input};
6
/// Parsed arguments of the `main` attribute, i.e. the `config = <path>` part.
struct MainArgs {
    /// Path written after `config =`; the expansion invokes it with no arguments.
    config: Path,
}
10
/// Diagnostic used when the attribute arguments are empty or do not start with `config`.
const MISSING_CONFIG_HELP: &str = "missing required `config = <fn>` argument, \
                           e.g. #[dial9_tokio_telemetry::main(config = my_config)]";

/// Diagnostic used when tokens remain after the `config` path (for example a call
/// expression or additional arguments).
const CONFIG_MUST_BE_ZERO_ARG_HELP: &str = "`config` must be a path to a zero-argument function, \
                           e.g. #[dial9_tokio_telemetry::main(config = my_config)]";
16impl Parse for MainArgs {
17    fn parse(input: ParseStream) -> syn::Result<Self> {
18        if input.is_empty() {
19            return Err(input.error(MISSING_CONFIG_HELP));
20        }
21        let ident: syn::Ident = input.parse()?;
22        if ident != "config" {
23            return Err(syn::Error::new(ident.span(), MISSING_CONFIG_HELP));
24        }
25        input.parse::<Token![=]>()?;
26        let config: Path = input.parse()?;
27        if !input.is_empty() {
28            return Err(input.error(CONFIG_MUST_BE_ZERO_ARG_HELP));
29        }
30        Ok(MainArgs { config })
31    }
32}
33
34fn expand_main(args: MainArgs, input: ItemFn) -> Result<TokenStream2, syn::Error> {
35    if input.sig.asyncness.is_none() {
36        return Err(syn::Error::new_spanned(
37            input.sig.fn_token,
38            "the `async` keyword is missing from the function declaration",
39        ));
40    }
41
42    if !input.sig.inputs.is_empty() {
43        return Err(syn::Error::new_spanned(
44            &input.sig.inputs,
45            "#[dial9_tokio_telemetry::main] does not support function arguments",
46        ));
47    }
48
49    if !input.sig.generics.params.is_empty() {
50        return Err(syn::Error::new_spanned(
51            &input.sig.generics,
52            "#[dial9_tokio_telemetry::main] does not support generics",
53        ));
54    }
55
56    if input.sig.generics.where_clause.is_some() {
57        return Err(syn::Error::new_spanned(
58            &input.sig.generics.where_clause,
59            "#[dial9_tokio_telemetry::main] does not support where clauses",
60        ));
61    }
62
63    let config_fn = &args.config;
64    let attrs = &input.attrs;
65    let vis = &input.vis;
66    let name = &input.sig.ident;
67    let ret = &input.sig.output;
68    let body_stmts = &input.block.stmts;
69
70    Ok(quote! {
71        #(#attrs)*
72        #vis fn #name() #ret {
73            let (__tokio_runtime, __maybe_guard) = #config_fn()
74                .build()
75                .expect("failed to initialize runtime");
76            if let Some(__dial9_guard) = __maybe_guard {
77                let __dial9_handle = __dial9_guard.handle();
78                __tokio_runtime.block_on(async move {
79                    match __dial9_handle.spawn(async move { #(#body_stmts)* }).await {
80                        Ok(output) => output,
81                        Err(err) if err.is_panic() => {
82                            ::std::panic::resume_unwind(err.into_panic())
83                        }
84                        Err(_) => unreachable!("task cannot be cancelled inside block_on"),
85                    }
86                })
87            } else {
88                __tokio_runtime.block_on(async move { #(#body_stmts)* })
89            }
90        }
91    })
92}
93
94/// Instrument an async main function with dial9 telemetry.
95///
96/// This macro is a **replacement** for `#[tokio::main]`, not a complement —
97/// do not use both attributes on the same function. It builds the Tokio
98/// runtime internally and wraps the function body in a spawned task so that
99/// poll events are recorded by dial9. Without this, code running directly in
100/// `runtime.block_on(...)` is invisible to the telemetry hooks.
101///
102/// To spawn sub-tasks with wake-event tracking from anywhere inside the
103/// body, call `TelemetryHandle::current()` — the handle is installed on
104/// every runtime-owned thread by `on_thread_start`.
105///
106/// # Arguments
107///
108/// * `config` — path to a zero-argument function returning [`Dial9Config`].
109///   Build one with [`Dial9ConfigBuilder::new`] (telemetry enabled) or
110///   [`Dial9ConfigBuilder::disabled`] (plain tokio, no telemetry).
111///
112/// # Example
113///
114/// ```rust,ignore
115/// use dial9_tokio_telemetry::{main, config::{Dial9Config, Dial9ConfigBuilder}, telemetry::TelemetryHandle};
116///
117/// fn my_config() -> Dial9Config {
118///     Dial9ConfigBuilder::new("/tmp/trace.bin", 1024 * 1024, 16 * 1024 * 1024)
119///         .build()
120/// }
121///
122/// #[dial9_tokio_telemetry::main(config = my_config)]
123/// async fn main() {
124///     let handle = TelemetryHandle::current();
125///     handle
126///         .spawn(async { /* instrumented sub-task */ })
127///         .await
128///         .unwrap();
129/// }
130/// ```
131#[proc_macro_attribute]
132pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream {
133    let args = parse_macro_input!(attr as MainArgs);
134    let input = parse_macro_input!(item as ItemFn);
135
136    match expand_main(args, input) {
137        Ok(tokens) => tokens.into(),
138        Err(err) => err.to_compile_error().into(),
139    }
140}
141
#[cfg(test)]
mod tests {
    use super::*;
    use quote::quote;

    /// Parses the attribute arguments and then the annotated function, panicking on
    /// malformed test input.
    fn parse_inputs(attr: TokenStream2, item: TokenStream2) -> (MainArgs, ItemFn) {
        let args = syn::parse2::<MainArgs>(attr).expect("failed to parse args");
        let input = syn::parse2::<ItemFn>(item).expect("failed to parse fn");
        (args, input)
    }

    /// Runs a successful expansion and pretty-prints the generated code.
    fn expand(attr: TokenStream2, item: TokenStream2) -> String {
        let (args, input) = parse_inputs(attr, item);
        let expanded = expand_main(args, input).expect("expansion failed");
        let file: syn::File = syn::parse2(expanded).expect("failed to parse expansion");
        prettyplease::unparse(&file)
    }

    /// Runs an expansion that must fail and returns the error text.
    fn expand_err(attr: TokenStream2, item: TokenStream2) -> String {
        let (args, input) = parse_inputs(attr, item);
        expand_main(args, input)
            .expect_err("expected error")
            .to_string()
    }

    /// Parses attribute arguments that must be rejected and returns the error text.
    fn parse_args_err(attr: TokenStream2) -> String {
        syn::parse2::<MainArgs>(attr)
            .err()
            .expect("expected parse error")
            .to_string()
    }

    /// Asserts that an error message contains the expected fragment.
    #[track_caller]
    fn assert_mentions(msg: &str, needle: &str) {
        assert!(msg.contains(needle), "unexpected error: {msg}");
    }

    // --- successful expansions (snapshot-tested) ---

    #[test]
    fn expand_basic() {
        let attr = quote! { config = my_config };
        let item = quote! {
            async fn main() {
                do_work().await;
            }
        };
        let output = expand(attr, item);
        insta::assert_snapshot!(output);
    }

    #[test]
    fn expand_with_return_type() {
        let attr = quote! { config = my_config };
        let item = quote! {
            async fn main() -> Result<(), Box<dyn std::error::Error>> {
                do_work().await?;
                Ok(())
            }
        };
        let output = expand(attr, item);
        insta::assert_snapshot!(output);
    }

    #[test]
    fn expand_with_attributes() {
        let attr = quote! { config = my_config };
        let item = quote! {
            #[allow(unused)]
            async fn main() {
                let _ = 42;
            }
        };
        let output = expand(attr, item);
        insta::assert_snapshot!(output);
    }

    // --- rejected function signatures ---

    #[test]
    fn error_with_arguments() {
        let msg = expand_err(
            quote! { config = my_config },
            quote! { async fn main(port: u16) {} },
        );
        assert_mentions(&msg, "does not support function arguments");
    }

    #[test]
    fn error_with_generics() {
        let msg = expand_err(
            quote! { config = my_config },
            quote! { async fn main<T>() {} },
        );
        assert_mentions(&msg, "does not support generics");
    }

    #[test]
    fn error_not_async() {
        let (args, input) = parse_inputs(
            quote! { config = my_config },
            quote! {
                fn main() {}
            },
        );
        let msg = expand_main(args, input)
            .expect_err("expected error for non-async fn")
            .to_string();
        assert!(msg.contains("async"), "error should mention async: {msg}");
    }

    // --- rejected attribute arguments ---

    #[test]
    fn error_empty_args() {
        let msg = parse_args_err(quote! {});
        assert_mentions(&msg, "config = <fn>");
    }

    #[test]
    fn error_wrong_arg_name() {
        let msg = parse_args_err(quote! { foo = bar });
        assert_mentions(&msg, "config = <fn>");
    }

    #[test]
    fn error_config_with_args() {
        let msg = parse_args_err(quote! { config = my_config(arg) });
        assert_mentions(&msg, "zero-argument function");
    }

    #[test]
    fn error_config_trailing_tokens() {
        let msg = parse_args_err(quote! { config = my_config, extra = stuff });
        assert_mentions(&msg, "zero-argument function");
    }
}