1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
extern crate proc_macro;

use proc_macro::TokenStream;
use quote::quote;
use syn::*;

#[proc_macro_attribute]
pub fn timed(_attrs: TokenStream, item: TokenStream) -> TokenStream {
    if let Ok(mut fun) = parse::<ItemFn>(item.clone()) {
        let new_stmts = rewrite_stmts(fun.sig.ident.to_string(), &mut fun.block.stmts);
        fun.block.stmts = new_stmts;
        return quote!(#fun).into();
    }

    if let Ok(mut fun) = parse::<TraitItemMethod>(item.clone()) {
        if let Some(block) = fun.default.as_mut() {
            let new_stmts = rewrite_stmts(fun.sig.ident.to_string(), &mut block.stmts);
            block.stmts = new_stmts;
            return quote!(#fun).into();
        }
    }

    if let Ok(mut fun) = parse::<ImplItemMethod>(item) {
        let new_stmts = rewrite_stmts(fun.sig.ident.to_string(), &mut fun.block.stmts);
        fun.block.stmts = new_stmts;
        return quote!(#fun).into();
    }

    panic!("`funtime::timed` only works on functions")
}

/// Builds the instrumented statement list for a timed function body.
///
/// The returned statements consist of:
///
/// 1. A prelude that declares a local `FuntimeTimer` type and creates a
///    `funtime_timer` value. The timer appends report lines to an internal
///    string buffer and prints the whole buffer when it is dropped.
/// 2. Every original statement except the last, each followed by a
///    `funtime_timer.mark_elapsed(..)` call labelled with a shortened rendering
///    of that statement.
/// 3. The last original statement (if any) wrapped in
///    `let funtime_return_val = { .. };`, followed by a final `mark_elapsed`
///    call and `return funtime_return_val;`.
///
/// # Parameters
///
/// * `name` - label used in the generated "start"/"end" report lines.
/// * `stmts` - the original body statements; they are moved out, so the vector
///   is empty when this function returns.
///
/// # Returns
///
/// The new list of statements to install as the function body.
fn rewrite_stmts(name: String, stmts: &mut Vec<Stmt>) -> Vec<Stmt> {
    /// Maximum number of characters of a statement shown in a report line.
    const MAX_LABEL_LEN: usize = 40;

    /// Renders `stmt` as text, capped at `len` characters.
    ///
    /// Text longer than `len` characters is cut so that the kept prefix plus a
    /// trailing `"..."` is `len` characters long. For `len < 3` there is no room
    /// for a prefix, so only `"..."` is returned instead of underflowing.
    fn truncate_stmt(stmt: &Stmt, len: usize) -> String {
        let rendered = quote::ToTokens::to_token_stream(stmt).to_string();

        // Measure and cut in characters (not bytes) so multi-byte characters
        // are never split.
        if rendered.chars().count() <= len {
            return rendered;
        }

        let mut short: String = rendered.chars().take(len.saturating_sub(3)).collect();
        short.push_str("...");
        short
    }

    let setup: Block = parse_quote! {{
        struct FuntimeTimer {
            start: std::time::Instant,
            name: &'static str,
            buffer: String,
            prev_mark: Option<std::time::Duration>,
        }


        impl Drop for FuntimeTimer {
            fn drop(&mut self) {
                use std::fmt::Write;
                writeln!(&mut self.buffer, "funtime end: `{}` took {:?}", self.name, self.start.elapsed()).unwrap();
                print!("{}", &self.buffer);
            }
        }

        impl FuntimeTimer {
            fn new(name: &'static str) -> Self {
                use std::fmt::Write;
                let mut buffer = String::new();
                writeln!(&mut buffer, "funtime start: `{}`", name).unwrap();
                FuntimeTimer {
                    start: std::time::Instant::now(),
                    name,
                    buffer,
                    prev_mark: None,
                }
            }

            fn mark_elapsed(&mut self, short: &str) {
                use std::fmt::Write;
                let mut elapsed = self.start.elapsed();
                if let Some(prev) = self.prev_mark.replace(elapsed) {
                    elapsed = elapsed - prev;
                }
                writeln!(&mut self.buffer, "  took {:?}: `{}`", elapsed, short).unwrap();
            }
        }

        let mut funtime_timer = FuntimeTimer::new(#name);

    }};

    let mut new_stmts = setup.stmts;

    // The final statement is handled separately below because its value has to
    // be captured before the last timing mark is recorded.
    let last = stmts.pop();

    for stmt in stmts.drain(..) {
        let short = truncate_stmt(&stmt, MAX_LABEL_LEN);
        let mark: Stmt = parse_quote!(funtime_timer.mark_elapsed(#short););

        new_stmts.push(stmt);
        new_stmts.push(mark);
    }

    if let Some(stmt) = last {
        let short = truncate_stmt(&stmt, MAX_LABEL_LEN);

        // Evaluate the last statement inside a block bound to a local, so the
        // timing mark can run between computing the value and returning it.
        let capture: Stmt = parse_quote! {
            let funtime_return_val = {
                #stmt
            };
        };
        let mark: Stmt = parse_quote!(funtime_timer.mark_elapsed(#short););
        let give_back: Stmt = parse_quote!(return funtime_return_val;);

        new_stmts.push(capture);
        new_stmts.push(mark);
        new_stmts.push(give_back);
    }

    new_stmts
}