#![deny(unsafe_code)]
extern crate proc_macro;
use proc_macro::TokenStream;
use proc_macro2::TokenStream as TokenStream2;
use quote::quote;
use syn::parse::{Parse, ParseStream};
use syn::visit_mut::VisitMut;
use syn::{parse_macro_input, parse_quote, Expr, Ident, ItemFn, Token};
struct TraceAsyncArgs {
trace_ident: Ident,
}
impl Parse for TraceAsyncArgs {
fn parse(input: ParseStream) -> syn::Result<Self> {
let key: Ident = input.parse()?;
if key != "trace" {
return Err(syn::Error::new(key.span(), "expected `trace = IDENT`"));
}
let _eq: Token![=] = input.parse()?;
let trace_ident: Ident = input.parse()?;
if !input.is_empty() {
return Err(input.error("unexpected tokens after `trace = IDENT`"));
}
Ok(TraceAsyncArgs { trace_ident })
}
}
struct AwaitRewriter {
trace_ident: Ident,
}
impl VisitMut for AwaitRewriter {
fn visit_expr_mut(&mut self, expr: &mut Expr) {
syn::visit_mut::visit_expr_mut(self, expr);
if let Expr::Await(await_expr) = expr {
let inner = &*await_expr.base;
let label_str = inner_to_label(inner);
let trace = &self.trace_ident;
let replacement: Expr = parse_quote! {
{
let __label = format!(
"{} @ {}:{}",
#label_str,
file!(),
line!(),
);
::async_reify::LabeledFuture::new(#inner, &__label, #trace.clone()).await
}
};
*expr = replacement;
}
}
fn visit_expr_closure_mut(&mut self, _: &mut syn::ExprClosure) {}
fn visit_item_mut(&mut self, _: &mut syn::Item) {}
}
/// Renders an expression as a single-line, human-readable label.
///
/// The expression is printed through `quote!`. Every run of whitespace
/// (as defined by `char::is_whitespace`) is collapsed to one ASCII space, and
/// leading and trailing whitespace is removed.
fn inner_to_label(expr: &Expr) -> String {
    let rendered = quote!(#expr).to_string();
    rendered.split_whitespace().collect::<Vec<_>>().join(" ")
}
#[proc_macro_attribute]
pub fn trace_async(attr: TokenStream, item: TokenStream) -> TokenStream {
let args = parse_macro_input!(attr as TraceAsyncArgs);
let mut func = parse_macro_input!(item as ItemFn);
if func.sig.asyncness.is_none() {
return syn::Error::new_spanned(&func.sig, "#[trace_async] requires an `async fn`")
.to_compile_error()
.into();
}
let mut rewriter = AwaitRewriter {
trace_ident: args.trace_ident,
};
rewriter.visit_block_mut(&mut func.block);
let tokens: TokenStream2 = quote! { #func };
tokens.into()
}