1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
//! This crate implements the macro for `performance_mark` and should not be used directly.

use proc_macro2::{Ident, TokenStream, TokenTree};
use quote::{ToTokens, TokenStreamExt};
use syn::{
    parse2, parse_quote,
    spanned::Spanned,
    visit_mut::{visit_stmt_mut, VisitMut},
    Expr, Item, Stmt,
};

#[doc(hidden)]
pub fn performance_mark(attr: TokenStream, item: TokenStream) -> Result<TokenStream, syn::Error> {
    let mut item = match parse2::<Item>(item.clone()).unwrap() {
        Item::Fn(function) => function,
        _ => {
            return Err(syn::Error::new(
                item.into_token_stream().span(),
                "Unexpected token",
            ))
        }
    };

    let mut asyncness = false;
    let mut log_function = None;

    for token in attr {
        match token {
            TokenTree::Ident(ref ident) => {
                if ident.to_string() == "async" {
                    asyncness = true;
                } else if log_function.is_none() {
                    log_function = Some(ArbitraryFunction(ident.clone()))
                } else {
                    return Err(syn::Error::new(token.span(), "Unexpected token"));
                }
            }
            _ => return Err(syn::Error::new(token.span(), "Unexpected token")),
        }
    }

    let start_stmt: Stmt = parse_quote!(let start = std::time::Instant::now(););

    item.block.stmts.insert(0, start_stmt);

    let function_name = item.sig.ident.to_string();
    let mut end_stmts: Vec<Stmt> = parse_quote! {
        let ctx = performance_mark::LogContext {
            function: #function_name.to_string(),
            duration: std::time::Instant::now().duration_since(start),
        };
    };

    if let Some(log_function) = log_function {
        if asyncness {
            end_stmts.push(parse_quote! {
                #log_function(ctx).await;
            })
        } else {
            end_stmts.push(parse_quote! {
                #log_function(ctx);
            });
        }
    } else {
        end_stmts.push(parse_quote! {
            println!("(performance_mark) {} took {:?}", ctx.function, ctx.duration);
        })
    }

    let mut visitor = InsertBeforeReturnVisitor {
        end_stmts: &end_stmts,
        asyncness,
    };
    visitor.visit_item_fn_mut(&mut item);
    item.block.stmts.extend(end_stmts);

    Ok(item.into_token_stream())
}

/// Syntax-tree visitor that rewrites `return` statements and tail expressions so that
/// `end_stmts` run before the value is produced.
struct InsertBeforeReturnVisitor<'a> {
    /// Statements spliced in ahead of each rewritten statement.
    end_stmts: &'a Vec<Stmt>,
    /// When `true`, the spliced statements are wrapped in an `async { ... }.await` block
    /// rather than a plain block.
    asyncness: bool,
}

impl<'a> InsertBeforeReturnVisitor<'a> {
    /// Builds the expression that replaces a returning statement.
    ///
    /// The result is a block that first runs `end_stmts` and then the original
    /// `return_stmt`. With `asyncness` set, the block is an `async` block that is awaited
    /// immediately; otherwise it is a plain block expression.
    fn construct_expr(&self, return_stmt: &Stmt) -> Expr {
        let stmts = VecStmt(self.end_stmts);

        if self.asyncness {
            Expr::Await(parse_quote! {
                async {
                    #stmts
                    #return_stmt
                }.await
            })
        } else {
            // The interpolated statements already carry their own terminators; any
            // separator token between them would make the block unparsable and cause
            // `parse_quote!` to panic.
            Expr::Block(parse_quote! {
                {
                    #stmts
                    #return_stmt
                }
            })
        }
    }
}

impl<'a> VisitMut for InsertBeforeReturnVisitor<'a> {
    /// Rewrites a single statement in place.
    ///
    /// * A `for` / `if` / `loop` / `while` tail expression is not rewritten itself; the
    ///   default traversal descends into it so nested statements are visited.
    /// * A `return` statement gets its value replaced by the constructed block.
    /// * Any other tail expression is replaced by the constructed block.
    /// * Every other statement is left untouched and is not descended into.
    fn visit_stmt_mut(&mut self, stmt: &mut Stmt) {
        // Snapshot taken before mutation; it is embedded inside the replacement block.
        let snapshot = stmt.clone();

        let is_control_flow_tail = matches!(
            *stmt,
            Stmt::Expr(
                Expr::ForLoop(_) | Expr::If(_) | Expr::Loop(_) | Expr::While(_),
                None
            )
        );
        if is_control_flow_tail {
            visit_stmt_mut(self, stmt);
            return;
        }

        match stmt {
            Stmt::Expr(Expr::Return(ret), _) => {
                ret.expr = Some(Box::new(self.construct_expr(&snapshot)));
            }
            Stmt::Expr(tail, None) => {
                *tail = self.construct_expr(&snapshot);
            }
            _ => {}
        }
    }

    /// Closures are skipped entirely, so nothing inside a closure body is rewritten.
    fn visit_expr_closure_mut(&mut self, _: &mut syn::ExprClosure) {}
}

/// Borrowed list of statements that can be interpolated into `quote`-style macros.
struct VecStmt<'a>(&'a Vec<Stmt>);

impl<'a> ToTokens for VecStmt<'a> {
    /// Emits every wrapped statement, in order, with nothing in between.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        tokens.append_all(self.0);
    }
}

/// Identifier of the user-supplied log function, as written in the attribute arguments.
struct ArbitraryFunction(Ident);

impl ToTokens for ArbitraryFunction {
    /// Emits the wrapped identifier as a single token.
    fn to_tokens(&self, tokens: &mut TokenStream) {
        self.0.to_tokens(tokens);
    }
}