Skip to main content

async_reify_macros/
lib.rs

1//! Attribute proc macros for [`async-reify`](https://docs.rs/async-reify).
2//!
3//! Currently provides [`macro@trace_async`], an attribute that rewrites
4//! every `.await` point in an `async fn` body to record into a shared
5//! `async_reify::Trace` without you having to wrap each await in
6//! `LabeledFuture` by hand.
7//!
8//! You normally do not depend on this crate directly. Enable the `macros`
9//! feature on `async-reify` and the attribute is re-exported as
10//! [`async_reify::trace_async`](https://docs.rs/async-reify/latest/async_reify/attr.trace_async.html).
11//!
12//! # What the macro does
13//!
14//! `#[trace_async(trace = my_trace)]` on a function rewrites every
15//! `.await` inside the body so it is wrapped in a
16//! [`LabeledFuture`](https://docs.rs/async-reify/latest/async_reify/struct.LabeledFuture.html)
17//! that records into the trace handle named by the `trace = IDENT`
18//! argument. Labels are auto-generated as `"<expr> @ <file>:<line>"`, so
19//! every step in the resulting trace points back to the source line that
20//! produced it.
21//!
22//! See the [`async-reify` crate docs](https://docs.rs/async-reify) for
23//! the recording, inspection, and rendering pipeline this feeds into,
24//! and [`docs/phase4-async-reify.md`][phase4] for the design rationale.
25//!
26//! [phase4]: https://github.com/joshburgess/reify-reflect/blob/main/docs/phase4-async-reify.md
27
28#![deny(unsafe_code)]
29
30extern crate proc_macro;
31
32use proc_macro::TokenStream;
33use proc_macro2::TokenStream as TokenStream2;
34use quote::quote;
35use syn::parse::{Parse, ParseStream};
36use syn::visit_mut::VisitMut;
37use syn::{parse_macro_input, parse_quote, Expr, Ident, ItemFn, Token};
38
39/// Parsed `#[trace_async(trace = IDENT)]` arguments.
40struct TraceAsyncArgs {
41    trace_ident: Ident,
42}
43
44impl Parse for TraceAsyncArgs {
45    fn parse(input: ParseStream) -> syn::Result<Self> {
46        // Single arg: `trace = IDENT`
47        let key: Ident = input.parse()?;
48        if key != "trace" {
49            return Err(syn::Error::new(key.span(), "expected `trace = IDENT`"));
50        }
51        let _eq: Token![=] = input.parse()?;
52        let trace_ident: Ident = input.parse()?;
53        if !input.is_empty() {
54            return Err(input.error("unexpected tokens after `trace = IDENT`"));
55        }
56        Ok(TraceAsyncArgs { trace_ident })
57    }
58}
59
60/// Rewrite every `.await` in the visited body so it goes through a
61/// [`async_reify::LabeledFuture`] backed by the named trace handle.
62///
63/// The label is auto-generated as `"<expr> @ <file>:<line>"`.
64struct AwaitRewriter {
65    trace_ident: Ident,
66}
67
68impl VisitMut for AwaitRewriter {
69    fn visit_expr_mut(&mut self, expr: &mut Expr) {
70        // Recurse into children first so nested awaits are rewritten before
71        // we wrap their parent. This is safe because we replace `expr` with
72        // a block whose `.await` is already inside our wrapper expression
73        // and will not be re-visited.
74        syn::visit_mut::visit_expr_mut(self, expr);
75
76        if let Expr::Await(await_expr) = expr {
77            let inner = &*await_expr.base;
78            let label_str = inner_to_label(inner);
79            let trace = &self.trace_ident;
80            let replacement: Expr = parse_quote! {
81                {
82                    let __label = format!(
83                        "{} @ {}:{}",
84                        #label_str,
85                        file!(),
86                        line!(),
87                    );
88                    ::async_reify::LabeledFuture::new(#inner, &__label, #trace.clone()).await
89                }
90            };
91            *expr = replacement;
92        }
93    }
94
95    // Don't descend into nested closures or item definitions; their `.await`
96    // (if any, e.g. inside a nested `async fn`) is in a different async
97    // context and uses its own trace.
98    fn visit_expr_closure_mut(&mut self, _: &mut syn::ExprClosure) {}
99    fn visit_item_mut(&mut self, _: &mut syn::Item) {}
100}
101
102fn inner_to_label(expr: &Expr) -> String {
103    let s = quote!(#expr).to_string();
104    // Collapse runs of whitespace from token-stream pretty-printing.
105    let mut out = String::with_capacity(s.len());
106    let mut prev_space = false;
107    for ch in s.chars() {
108        if ch.is_whitespace() {
109            if !prev_space {
110                out.push(' ');
111                prev_space = true;
112            }
113        } else {
114            out.push(ch);
115            prev_space = false;
116        }
117    }
118    out.trim().to_string()
119}
120
121/// Rewrites an async function body so every `.await` records a labeled
122/// `async_reify::PollEvent` into the named shared trace.
123///
124/// The macro takes a single mandatory argument: `trace = IDENT`, where
125/// `IDENT` names a value of type
126/// `std::sync::Arc<std::sync::Mutex<async_reify::Trace>>` that is in
127/// scope inside the function. Typically `IDENT` is a function parameter.
128///
129/// Each `.await` inside the function body is replaced with an
130/// `async_reify::LabeledFuture` that records into the shared trace.
131/// The label is `"<expr> @ <file>:<line>"`. Awaits inside nested
132/// closures or nested item definitions are left alone (they belong
133/// to a different async scope).
134///
135/// # Examples
136///
137/// ```ignore
138/// use std::sync::{Arc, Mutex};
139/// use async_reify::Trace;
140/// use async_reify_macros::trace_async;
141///
142/// #[trace_async(trace = trace)]
143/// async fn workflow(trace: Arc<Mutex<Trace>>) -> i32 {
144///     fetch().await;
145///     compute().await;
146///     42
147/// }
148/// ```
149#[proc_macro_attribute]
150pub fn trace_async(attr: TokenStream, item: TokenStream) -> TokenStream {
151    let args = parse_macro_input!(attr as TraceAsyncArgs);
152    let mut func = parse_macro_input!(item as ItemFn);
153
154    if func.sig.asyncness.is_none() {
155        return syn::Error::new_spanned(&func.sig, "#[trace_async] requires an `async fn`")
156            .to_compile_error()
157            .into();
158    }
159
160    let mut rewriter = AwaitRewriter {
161        trace_ident: args.trace_ident,
162    };
163    rewriter.visit_block_mut(&mut func.block);
164
165    let tokens: TokenStream2 = quote! { #func };
166    tokens.into()
167}