use proc_macro2::TokenStream;
use quote::{quote, quote_spanned, ToTokens};
use syn::visit_mut::VisitMut;
use syn::{
punctuated::Punctuated, spanned::Spanned, Block, Expr, ExprAsync, ExprCall, FnArg, Ident, Item,
ItemFn, Pat, PatIdent, Path, ReturnType, Signature, Stmt, Token, Type, TypePath,
};
use crate::MaybeItemFnRef;
/// Re-emits a function with an instrumented body.
///
/// The original attributes, visibility, qualifiers (`const`, `unsafe`, `async`, `extern`),
/// name, generic parameters, parameters, return type and `where` clause are reproduced. The
/// body becomes whatever [`gen_block`] produces for the original block. That body is prefixed
/// with a never-taken `if false { return ...; }` branch whose value is annotated with the
/// declared return type (with any `impl Trait` erased to `_`) and spanned at that return type,
/// or at the function name when there is no explicit return type.
///
/// NOTE(review): the fake return edge presumably exists to steer return-type inference and
/// diagnostics toward the declared return type (mirroring `tracing-attributes`) — confirm.
pub(crate) fn gen_function<'a, B: ToTokens + 'a>(
    input: MaybeItemFnRef<'a, B>,
    instrumented_function_name: &str,
    self_type: Option<&TypePath>,
) -> proc_macro2::TokenStream {
    let MaybeItemFnRef { attrs, vis, sig, block } = input;
    let Signature {
        output,
        inputs: params,
        unsafety,
        asyncness,
        constness,
        abi,
        ident,
        generics: syn::Generics { params: gen_params, where_clause, .. },
        ..
    } = sig;

    // With no explicit return type the function returns `()`; anchor at the function name.
    let (return_type, return_span) = match output {
        ReturnType::Type(_, ty) => (erase_impl_trait(ty), ty.span()),
        ReturnType::Default => (syn::parse_quote! { () }, ident.span()),
    };

    // Never executed: only there so a value of the declared return type is `return`ed.
    let fake_return_edge = quote_spanned! {return_span=>
        #[allow(unreachable_code, clippy::diverging_sub_expression, clippy::let_unit_value, clippy::empty_loop)]
        if false {
            let __backtrace_attr_fake_return: #return_type = loop {};
            return __backtrace_attr_fake_return;
        }
    };
    let body_with_edge = quote! {
        {
            #fake_return_edge
            #block
        }
    };
    let instrumented_body = gen_block(
        &body_with_edge,
        params,
        asyncness.is_some(),
        instrumented_function_name,
        self_type,
    );

    quote!(
        #(#attrs) *
        #vis #constness #unsafety #asyncness #abi fn #ident<#gen_params>(#params) #output
        #where_clause
        {
            #instrumented_body
        }
    )
}
/// Produces the token stream that replaces a function body.
///
/// When `async_context` is `true`, `block` is moved into an `async move` block that is passed
/// to `async_backtrace::frame!(..)` and then `.await`ed. Otherwise `block` is re-emitted as-is.
/// The remaining parameters are accepted but currently unused.
fn gen_block<B: ToTokens>(
    block: &B,
    _params: &Punctuated<FnArg, Token![,]>,
    async_context: bool,
    _instrumented_function_name: &str,
    _self_type: Option<&TypePath>,
) -> proc_macro2::TokenStream {
    if !async_context {
        return quote_spanned!(block.span() => #block);
    }
    quote!(async_backtrace::frame!(async move { #block }).await)
}
/// The future-producing construct that [`AsyncInfo::from_fn`] found in a function body.
enum AsyncKind<'a> {
    /// An `async fn` item declared directly in the outer body and invoked as
    /// `Box::pin(name(..))` in the body's trailing expression.
    Function(&'a ItemFn),
    /// An `async` block: either the body's trailing expression, or the first argument of a
    /// `Box::pin(..)` call in that position.
    Async {
        /// The `async` block expression.
        async_expr: &'a ExprAsync,
        /// `true` when the block was passed to `Box::pin`, so the rewrite re-wraps it in
        /// `Box::pin(..)`.
        pinned_box: bool,
    },
}

/// A non-`async` function whose body returns a future, plus what is needed to rewrite the one
/// statement that produces that future.
pub(crate) struct AsyncInfo<'block> {
    /// The function being inspected.
    input: &'block ItemFn,
    /// Statement of `input`'s body that gets replaced by its instrumented form.
    source_stmt: &'block Stmt,
    /// What kind of construct `source_stmt` refers to.
    kind: AsyncKind<'block>,
    /// Type of the inner `async fn`'s `_self` parameter (one reference layer removed), if any.
    self_type: Option<TypePath>,
}
impl<'block> AsyncInfo<'block> {
    /// Inspects a non-`async` function and reports whether the future it returns can be
    /// instrumented in place.
    ///
    /// Only the body's trailing expression (its last `Stmt::Expr`) is examined. Recognized
    /// shapes:
    /// - `async { .. }`: the async block itself is instrumented;
    /// - `<path ending in Box::pin>(async { .. }, ..)`: the async block passed as the first
    ///   argument is instrumented and later re-wrapped in `Box::pin`;
    /// - `<path ending in Box::pin>(name(..), ..)`, where `name` is an `async fn` item declared
    ///   directly in the body: that inner function is instrumented. If the inner function has a
    ///   parameter named `_self`, its type (with one outer reference removed) is recorded as
    ///   `self_type`.
    ///
    /// Returns `None` for `async fn`s and for every other body shape.
    pub(crate) fn from_fn(input: &'block ItemFn) -> Option<Self> {
        if input.sig.asyncness.is_some() {
            return None;
        }
        let block = &input.block;

        let (last_expr_stmt, last_expr) = block.stmts.iter().rev().find_map(|stmt| {
            if let Stmt::Expr(expr) = stmt {
                Some((stmt, expr))
            } else {
                None
            }
        })?;

        if let Expr::Async(async_expr) = last_expr {
            return Some(AsyncInfo {
                source_stmt: last_expr_stmt,
                kind: AsyncKind::Async {
                    async_expr,
                    pinned_box: false,
                },
                self_type: None,
                input,
            });
        }

        // Any other candidate must be a call through a path ending in `Box::pin`
        // (so `Box::pin`, `std::boxed::Box::pin`, ... all match).
        let (outside_func, outside_args) = match last_expr {
            Expr::Call(ExprCall { func, args, .. }) => (func, args),
            _ => return None,
        };
        let path = match &**outside_func {
            Expr::Path(path) => &path.path,
            _ => return None,
        };
        if !path_to_string(path).ends_with("Box::pin") {
            return None;
        }
        // Only the first argument is inspected; a call with no arguments is rejected.
        let pinned = outside_args.first()?;

        if let Expr::Async(async_expr) = pinned {
            return Some(AsyncInfo {
                source_stmt: last_expr_stmt,
                kind: AsyncKind::Async {
                    async_expr,
                    pinned_box: true,
                },
                self_type: None,
                input,
            });
        }

        // `Box::pin(name(..))`: resolve `name` to an `async fn` item declared in this body.
        let func_name = match pinned {
            Expr::Call(ExprCall { func, .. }) => match &**func {
                Expr::Path(func_path) => path_to_string(&func_path.path),
                _ => return None,
            },
            _ => return None,
        };
        let (stmt_func_declaration, func) = block.stmts.iter().find_map(|stmt| match stmt {
            Stmt::Item(Item::Fn(fun))
                if fun.sig.asyncness.is_some() && fun.sig.ident == func_name =>
            {
                Some((stmt, fun))
            }
            _ => None,
        })?;

        // A `_self` parameter carries the concrete type standing in for `Self`; peel one
        // reference layer (`&T` / `&mut T`) and keep it only if it is a plain type path.
        let mut self_type = None;
        for arg in &func.sig.inputs {
            if let FnArg::Typed(ty) = arg {
                if let Pat::Ident(PatIdent { ref ident, .. }) = *ty.pat {
                    if ident == "_self" {
                        let mut ty = *ty.ty.clone();
                        if let Type::Reference(syn::TypeReference { elem, .. }) = ty {
                            ty = *elem;
                        }
                        if let Type::Path(tp) = ty {
                            self_type = Some(tp);
                            break;
                        }
                    }
                }
            }
        }

        Some(AsyncInfo {
            source_stmt: stmt_func_declaration,
            kind: AsyncKind::Function(func),
            self_type,
            input,
        })
    }

    /// Re-emits `self.input` (attributes, visibility and signature unchanged) with the single
    /// statement `source_stmt` replaced by its instrumented form:
    /// - [`AsyncKind::Function`]: the inner `async fn` is regenerated by `gen_function`,
    ///   passing along `self_type`;
    /// - [`AsyncKind::Async`]: the async block's body is wrapped by `gen_block` and re-emitted
    ///   as `async move { .. }` (keeping its attributes), inside `Box::pin(..)` if it was pinned.
    pub(crate) fn gen_async(self, instrumented_function_name: &str) -> proc_macro::TokenStream {
        let input = self.input;
        let mut out_stmts: Vec<TokenStream> = input
            .block
            .stmts
            .iter()
            .map(|stmt| stmt.to_token_stream())
            .collect();

        // `source_stmt` always borrows from `input.block.stmts` (see `from_fn`), so locate it by
        // address; structural `==` could in principle match an identical earlier statement.
        let target = input
            .block
            .stmts
            .iter()
            .position(|stmt| std::ptr::eq(stmt, self.source_stmt));
        if let Some(idx) = target {
            out_stmts[idx] = match self.kind {
                AsyncKind::Function(fun) => gen_function(
                    fun.into(),
                    instrumented_function_name,
                    self.self_type.as_ref(),
                ),
                AsyncKind::Async {
                    async_expr,
                    pinned_box,
                } => {
                    let instrumented_block = gen_block(
                        &async_expr.block,
                        &input.sig.inputs,
                        true,
                        instrumented_function_name,
                        None,
                    );
                    let async_attrs = &async_expr.attrs;
                    if pinned_box {
                        quote! {
                            Box::pin(#(#async_attrs) * async move { #instrumented_block })
                        }
                    } else {
                        quote! {
                            #(#async_attrs) * async move { #instrumented_block }
                        }
                    }
                }
            };
        }

        let vis = &input.vis;
        let sig = &input.sig;
        let attrs = &input.attrs;
        quote!(
            #(#attrs) *
            #vis #sig {
                #(#out_stmts) *
            }
        )
        .into()
    }
}
/// Joins the identifiers of a path's segments with `::`.
///
/// Any leading `::` and all generic arguments are omitted, so `::std::boxed::Box::<T>::pin`
/// renders as `std::boxed::Box::pin`.
fn path_to_string(path: &Path) -> String {
    path.segments
        .iter()
        .map(|segment| segment.ident.to_string())
        .collect::<Vec<_>>()
        .join("::")
}
/// Syntax-tree visitor that substitutes identifiers and type paths from fixed tables.
///
/// NOTE(review): nothing in this part of the file constructs this visitor; it looks carried
/// over from `tracing-attributes` — confirm whether it is still needed.
struct IdentAndTypesRenamer<'a> {
    /// `(path, replacement)` pairs: a `Type::Path` whose `::`-joined segment identifiers equal
    /// `path` is replaced by `replacement`.
    types: Vec<(&'a str, TypePath)>,
    /// `(old, new)` pairs: an identifier whose text equals `old` is replaced by `new`.
    idents: Vec<(Ident, Ident)>,
}

impl<'a> VisitMut for IdentAndTypesRenamer<'a> {
    // Identifiers are compared through their rendered text; building those `String`s is what
    // `clippy::cmp_owned` would otherwise flag.
    #[allow(clippy::cmp_owned)]
    fn visit_ident_mut(&mut self, id: &mut Ident) {
        // Every pair is tried in order with no early exit, so one substitution can feed a
        // later one.
        for (from, to) in self.idents.iter() {
            if id.to_string() != from.to_string() {
                continue;
            }
            *id = to.clone();
        }
    }

    // The default walker is never invoked here, so types nested inside `ty` (and identifiers
    // within them) are not visited.
    fn visit_type_mut(&mut self, ty: &mut Type) {
        for (name, replacement) in self.types.iter() {
            let is_match = matches!(
                &*ty,
                Type::Path(TypePath { path, .. }) if path_to_string(path) == *name
            );
            if is_match {
                *ty = Type::Path(replacement.clone());
            }
        }
    }
}
/// Syntax-tree visitor that swaps any visited block equal to `block` for `patched_block`.
///
/// NOTE(review): not constructed anywhere in this part of the file — confirm it is still used.
struct AsyncTraitBlockReplacer<'a> {
    /// Block to look for; compared with `==`, i.e. by value rather than by address.
    block: &'a Block,
    /// Replacement cloned into place for every match.
    patched_block: Block,
}

impl<'a> VisitMut for AsyncTraitBlockReplacer<'a> {
    // The default walker is not invoked, so blocks nested inside a visited block are not
    // examined.
    fn visit_block_mut(&mut self, i: &mut Block) {
        if *i != *self.block {
            return;
        }
        *i = self.patched_block.clone();
    }
}
struct ImplTraitEraser;
impl VisitMut for ImplTraitEraser {
fn visit_type_mut(&mut self, t: &mut Type) {
if let Type::ImplTrait(..) = t {
*t = syn::TypeInfer {
underscore_token: Token),
}
.into();
} else {
syn::visit_mut::visit_type_mut(self, t);
}
}
}
/// Returns a copy of `ty` in which every `impl Trait` (including nested ones) is replaced by
/// `_`, so the result can serve as a `let` type annotation (as `gen_function` uses it).
fn erase_impl_trait(ty: &Type) -> Type {
    let mut erased = ty.clone();
    let mut eraser = ImplTraitEraser;
    eraser.visit_type_mut(&mut erased);
    erased
}