Skip to main content

dial9_macro/
lib.rs

1use proc_macro::TokenStream;
2use proc_macro2::TokenStream as TokenStream2;
3use quote::quote;
4use syn::parse::{Parse, ParseStream};
5use syn::{ExprClosure, ItemFn, Path, Token, parse_macro_input};
6
7enum ConfigSource {
8    Path(Path),
9    Closure(ExprClosure),
10}
11
12struct MainArgs {
13    config: ConfigSource,
14}
15
/// Builds a `config` diagnostic: the given lead sentence followed by the
/// shared usage examples. Keeping the examples in one place stops the help
/// messages below from drifting apart.
macro_rules! config_help {
    ($lead:literal) => {
        concat!(
            $lead,
            "\n  #[dial9_tokio_telemetry::main(config = my_config_fn)]\n",
            "or with an inline closure:\n  ",
            "#[dial9_tokio_telemetry::main(config = || Dial9Config::builder()",
            ".base_path(...).max_file_size(...).max_total_size(...).build().unwrap())]",
        )
    };
}

/// Reported when the attribute has no arguments or its first argument is
/// not named `config`.
const MISSING_CONFIG_HELP: &str = config_help!("missing required `config` argument, e.g.");

/// Reported when the `config` value is not a bare path or parameterless
/// closure, or when unexpected tokens follow it.
const CONFIG_MUST_BE_ZERO_ARG_HELP: &str = config_help!(
    "`config` must be a zero-argument function path or a zero-argument closure, e.g."
);
25impl Parse for MainArgs {
26    fn parse(input: ParseStream) -> syn::Result<Self> {
27        if input.is_empty() {
28            return Err(input.error(MISSING_CONFIG_HELP));
29        }
30        let ident: syn::Ident = input.parse()?;
31        if ident != "config" {
32            return Err(syn::Error::new(ident.span(), MISSING_CONFIG_HELP));
33        }
34        input.parse::<Token![=]>()?;
35
36        let config = if input.peek(Token![|]) || input.peek(Token![move]) {
37            let closure: ExprClosure = input.parse()?;
38            if !closure.inputs.is_empty() {
39                return Err(syn::Error::new_spanned(
40                    &closure.inputs,
41                    CONFIG_MUST_BE_ZERO_ARG_HELP,
42                ));
43            }
44            ConfigSource::Closure(closure)
45        } else {
46            ConfigSource::Path(input.parse()?)
47        };
48
49        if !input.is_empty() {
50            return Err(input.error(CONFIG_MUST_BE_ZERO_ARG_HELP));
51        }
52        Ok(MainArgs { config })
53    }
54}
55
56fn expand_main(args: MainArgs, input: ItemFn) -> Result<TokenStream2, syn::Error> {
57    if input.sig.asyncness.is_none() {
58        return Err(syn::Error::new_spanned(
59            input.sig.fn_token,
60            "the `async` keyword is missing from the function declaration",
61        ));
62    }
63
64    if !input.sig.inputs.is_empty() {
65        return Err(syn::Error::new_spanned(
66            &input.sig.inputs,
67            "#[dial9_tokio_telemetry::main] does not support function arguments",
68        ));
69    }
70
71    if !input.sig.generics.params.is_empty() {
72        return Err(syn::Error::new_spanned(
73            &input.sig.generics,
74            "#[dial9_tokio_telemetry::main] does not support generics",
75        ));
76    }
77
78    if input.sig.generics.where_clause.is_some() {
79        return Err(syn::Error::new_spanned(
80            &input.sig.generics.where_clause,
81            "#[dial9_tokio_telemetry::main] does not support where clauses",
82        ));
83    }
84
85    let config_call = match &args.config {
86        ConfigSource::Path(p) => quote! { #p() },
87        ConfigSource::Closure(c) => quote! { (#c)() },
88    };
89    let attrs = &input.attrs;
90    let vis = &input.vis;
91    let name = &input.sig.ident;
92    let ret = &input.sig.output;
93    let body_stmts = &input.block.stmts;
94
95    Ok(quote! {
96        #(#attrs)*
97        #vis fn #name() #ret {
98            let __dial9_rt = ::dial9_tokio_telemetry::TracedRuntime::new(#config_call);
99            __dial9_rt.block_on(async move { #(#body_stmts)* })
100        }
101    })
102}
103
104/// Instrument an async main function with dial9 telemetry.
105///
106/// This macro is a **replacement** for `#[tokio::main]`, not a complement —
107/// do not use both attributes on the same function. It builds the Tokio
108/// runtime internally and wraps the function body in a spawned task so that
109/// poll events are recorded by dial9. Without this, code running directly in
110/// `runtime.block_on(...)` is invisible to the telemetry hooks.
111///
112/// To spawn sub-tasks with wake-event tracking from anywhere inside the
113/// body, call `TelemetryHandle::current()` — the handle is installed on
114/// every runtime-owned thread by `on_thread_start`.
115///
116/// # Arguments
117///
118/// * `config` — a zero-argument function path or a zero-argument closure
119///   returning any value convertible into a `TracedRuntime`. In
120///   practice that means one of:
121///     - `Dial9Config` from `Dial9Config::builder().build()` (strict):
122///       any builder validation or writer-I/O failure surfaces from
123///       `.build()` as a `Dial9ConfigBuilderError`; runtime construction
124///       under the macro panics on tokio-builder or telemetry-core I/O.
125///     - `Dial9Config` from `Dial9Config::builder().build_or_disabled()`
126///       (lenient): the same `Dial9Config` type, but validation and
127///       writer-I/O failures are logged at `error!` and downgraded to a
128///       disabled config that still preserves your `with_tokio`
129///       configurators.
130///     - The deprecated positional `dial9_tokio_telemetry::config::Dial9Config`,
131///       kept compatible via a bridge impl.
132///
133///   Use `.enabled(false)` on the builder to run without telemetry
134///   while keeping your `with_tokio` configurators.
135///
136/// # Examples
137///
138/// Using a named function:
139///
140/// ```rust,ignore
141/// use dial9_tokio_telemetry::{main, Dial9Config, telemetry::TelemetryHandle};
142///
143/// fn my_config() -> Dial9Config {
144///     Dial9Config::builder()
145///         .base_path("/tmp/trace.bin")
146///         .max_file_size(1024 * 1024)
147///         .max_total_size(16 * 1024 * 1024)
148///         .build()
149///         .expect("config build failed")
150/// }
151///
152/// #[dial9_tokio_telemetry::main(config = my_config)]
153/// async fn main() {
154///     let handle = TelemetryHandle::current();
155///     handle
156///         .spawn(async { /* instrumented sub-task */ })
157///         .await
158///         .unwrap();
159/// }
160/// ```
161///
162/// Using an inline closure:
163///
164/// ```rust,ignore
165/// #[dial9_tokio_telemetry::main(config = || {
166///     Dial9Config::builder()
167///         .base_path("/tmp/trace.bin")
168///         .max_file_size(1024 * 1024)
169///         .max_total_size(16 * 1024 * 1024)
170///         .build()
171///         .expect("config build failed")
172/// })]
173/// async fn main() {
174///     /* ... */
175/// }
176/// ```
177///
178/// Lenient (telemetry is best-effort; falls back to a plain tokio
179/// runtime if writer setup fails):
180///
181/// ```rust,ignore
182/// #[dial9_tokio_telemetry::main(config = || {
183///     Dial9Config::builder()
184///         .base_path("/tmp/trace.bin")
185///         .max_file_size(1024 * 1024)
186///         .max_total_size(16 * 1024 * 1024)
187///         .build_or_disabled()
188/// })]
189/// async fn main() {
190///     /* ... */
191/// }
192/// ```
193///
194/// Disabled (no telemetry, plain tokio runtime — useful for toggling
195/// dial9 off via a feature flag or env var without removing the macro):
196///
197/// ```rust,ignore
198/// #[dial9_tokio_telemetry::main(config = || {
199///     Dial9Config::builder()
200///         .enabled(false)
201///         .build()
202///         .expect("config build failed")
203/// })]
204/// async fn main() {
205///     /* ... */
206/// }
207/// ```
208#[proc_macro_attribute]
209pub fn main(attr: TokenStream, item: TokenStream) -> TokenStream {
210    let args = parse_macro_input!(attr as MainArgs);
211    let input = parse_macro_input!(item as ItemFn);
212
213    match expand_main(args, input) {
214        Ok(tokens) => tokens.into(),
215        Err(err) => err.to_compile_error().into(),
216    }
217}
218
#[cfg(test)]
mod tests {
    // Expansion snapshots and diagnostic checks for the `main` attribute.
    //
    // NOTE(review): the `expand_*` tests call `insta::assert_snapshot!`
    // without an explicit name; insta derives the stored snapshot's name from
    // the test function, so renaming a test orphans its `.snap` file.
    use super::*;
    use quote::quote;

    /// Parses `attr`/`item`, expands them, and pretty-prints the result.
    ///
    /// Panics if parsing or expansion fails, so only use it for inputs that
    /// are expected to expand successfully.
    fn expand(attr: TokenStream2, item: TokenStream2) -> String {
        let args: MainArgs = syn::parse2(attr).expect("failed to parse args");
        let input: ItemFn = syn::parse2(item).expect("failed to parse fn");
        let expanded = expand_main(args, input).expect("expansion failed");
        // Round-trip through `syn::File` (type inferred from `unparse`) so the
        // snapshot is stable, formatted Rust rather than raw token spacing.
        let file = syn::parse2(expanded).expect("failed to parse expansion");
        prettyplease::unparse(&file)
    }

    /// Snapshot: plain `async fn main()` with a path-based config.
    #[test]
    fn expand_basic() {
        let output = expand(
            quote! { config = my_config },
            quote! {
                async fn main() {
                    do_work().await;
                }
            },
        );
        insta::assert_snapshot!(output);
    }

    /// Snapshot: the declared return type is carried over to the sync fn.
    #[test]
    fn expand_with_return_type() {
        let output = expand(
            quote! { config = my_config },
            quote! {
                async fn main() -> Result<(), Box<dyn std::error::Error>> {
                    do_work().await?;
                    Ok(())
                }
            },
        );
        insta::assert_snapshot!(output);
    }

    /// Snapshot: outer attributes on the function are preserved.
    #[test]
    fn expand_with_attributes() {
        let output = expand(
            quote! { config = my_config },
            quote! {
                #[allow(unused)]
                async fn main() {
                    let _ = 42;
                }
            },
        );
        insta::assert_snapshot!(output);
    }

    /// Like `expand`, but expects `expand_main` to fail and returns the
    /// error message. Argument parsing must still succeed.
    fn expand_err(attr: TokenStream2, item: TokenStream2) -> String {
        let args: MainArgs = syn::parse2(attr).expect("failed to parse args");
        let input: ItemFn = syn::parse2(item).expect("failed to parse fn");
        expand_main(args, input)
            .expect_err("expected error")
            .to_string()
    }

    #[test]
    fn error_with_arguments() {
        let msg = expand_err(
            quote! { config = my_config },
            quote! { async fn main(port: u16) {} },
        );
        assert!(
            msg.contains("does not support function arguments"),
            "unexpected error: {msg}"
        );
    }

    #[test]
    fn error_with_generics() {
        let msg = expand_err(
            quote! { config = my_config },
            quote! { async fn main<T>() {} },
        );
        assert!(
            msg.contains("does not support generics"),
            "unexpected error: {msg}"
        );
    }

    /// Parses `attr` as `MainArgs`, expecting failure, and returns the error
    /// message.
    fn parse_args_err(attr: TokenStream2) -> String {
        match syn::parse2::<MainArgs>(attr) {
            Err(e) => e.to_string(),
            Ok(_) => panic!("expected parse error"),
        }
    }

    #[test]
    fn error_empty_args() {
        let msg = parse_args_err(quote! {});
        assert!(
            msg.contains("missing required `config`"),
            "unexpected error: {msg}"
        );
    }

    /// A wrongly named argument reuses the "missing `config`" diagnostic.
    #[test]
    fn error_wrong_arg_name() {
        let msg = parse_args_err(quote! { foo = bar });
        assert!(
            msg.contains("missing required `config`"),
            "unexpected error: {msg}"
        );
    }

    /// `my_config(arg)`: path parsing stops at `(`, leaving trailing tokens.
    #[test]
    fn error_config_with_args() {
        let msg = parse_args_err(quote! { config = my_config(arg) });
        assert!(msg.contains("zero-argument"), "unexpected error: {msg}");
    }

    /// Extra arguments after `config` are rejected.
    #[test]
    fn error_config_trailing_tokens() {
        let msg = parse_args_err(quote! { config = my_config, extra = stuff });
        assert!(msg.contains("zero-argument"), "unexpected error: {msg}");
    }

    /// Snapshot: `||` is detected as a closure and invoked as `(closure)()`.
    #[test]
    fn expand_with_inline_closure() {
        let output = expand(
            quote! { config = || my_config() },
            quote! {
                async fn main() {
                    do_work().await;
                }
            },
        );
        insta::assert_snapshot!(output);
    }

    /// Snapshot: a leading `move` is also detected as a closure.
    #[test]
    fn expand_with_move_closure() {
        let output = expand(
            quote! { config = move || my_config() },
            quote! {
                async fn main() {
                    do_work().await;
                }
            },
        );
        insta::assert_snapshot!(output);
    }

    #[test]
    fn error_closure_with_args() {
        let msg = parse_args_err(quote! { config = |x| my_config() });
        assert!(msg.contains("zero-argument"), "unexpected error: {msg}");
    }

    /// A non-async function parses fine as `ItemFn` but must be rejected by
    /// `expand_main`.
    #[test]
    fn error_not_async() {
        let args: MainArgs =
            syn::parse2(quote! { config = my_config }).expect("failed to parse args");
        let input: ItemFn = syn::parse2(quote! {
            fn main() {}
        })
        .expect("failed to parse fn");
        let err = expand_main(args, input).expect_err("expected error for non-async fn");
        let msg = err.to_string();
        assert!(msg.contains("async"), "error should mention async: {msg}");
    }
}