aleo_std_timed/
lib.rs

1// Copyright (C) 2019-2021 Aleo Systems Inc.
2// This file is part of the aleo-std library.
3
4// The aleo-std library is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, either version 3 of the License, or
7// (at your option) any later version.
8
9// The aleo-std library is distributed in the hope that it will be useful,
10// but WITHOUT ANY WARRANTY; without even the implied warranty of
11// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12// GNU General Public License for more details.
13
14// You should have received a copy of the GNU General Public License
15// along with the aleo-std library. If not, see <https://www.gnu.org/licenses/>.
16
17// With credits to kardeiz/funtime.
18
19extern crate proc_macro;
20
21use proc_macro::TokenStream;
22use quote::quote;
23use syn::*;
24
25#[proc_macro_attribute]
26pub fn timed(_attrs: TokenStream, item: TokenStream) -> TokenStream {
27    if let Ok(mut fun) = parse::<ItemFn>(item.clone()) {
28        let new_stmts = rewrite_stmts(fun.sig.ident.to_string(), &mut fun.block.stmts);
29        fun.block.stmts = new_stmts;
30        return quote!(#fun).into();
31    }
32
33    if let Ok(mut fun) = parse::<TraitItemMethod>(item.clone()) {
34        if let Some(block) = fun.default.as_mut() {
35            let new_stmts = rewrite_stmts(fun.sig.ident.to_string(), &mut block.stmts);
36            block.stmts = new_stmts;
37            return quote!(#fun).into();
38        }
39    }
40
41    if let Ok(mut fun) = parse::<ImplItemMethod>(item) {
42        let new_stmts = rewrite_stmts(fun.sig.ident.to_string(), &mut fun.block.stmts);
43        fun.block.stmts = new_stmts;
44        return quote!(#fun).into();
45    }
46
47    panic!("`timed` only works on functions")
48}
49
/// Rewrites a function body so that the time taken by each statement is recorded and printed.
///
/// The returned statements begin with a private timer type and a timer local created with the
/// function's `name`. Every original statement is followed by a call that records the time
/// elapsed since the previous mark. Each mark is labelled `L{index}: ` plus the statement's
/// source, truncated to `LENGTH` characters. When the timer is dropped, it appends the total
/// elapsed time and prints the accumulated report with `print!`. The report is therefore also
/// produced when the function exits early via `return` or `?`.
///
/// The last statement is evaluated inside its own block into a local, marked, and that local
/// is returned, which preserves the value of a tail expression. A trailing item that is not a
/// macro invocation is the exception: it stays at the outer level, because moving it into a
/// nested block would hide it from the statements before it.
///
/// `stmts` is drained: on return it is empty and its statements have been moved into the
/// result.
#[cfg(feature = "timed")]
fn rewrite_stmts(name: String, stmts: &mut Vec<Stmt>) -> Vec<Stmt> {
    /// Renders `stmt` on a single line, cut to `len` characters and suffixed with `...` when
    /// it is longer than that.
    fn truncate(stmt: &Stmt, len: usize) -> String {
        let string = quote::ToTokens::to_token_stream(stmt).to_string().replace('\n', " ");
        match string.chars().count() > len {
            true => string.chars().take(len).chain("...".chars()).collect::<String>(),
            false => string,
        }
    }

    /// Maximum number of characters of statement source shown per report line.
    const LENGTH: usize = 45;

    // The injected locals use `Span::mixed_site()`, which resolves them hygienically (as in
    // `macro_rules!`), so user code in the body can neither observe nor shadow them.
    let timer = Ident::new("timed", proc_macro::Span::mixed_site().into());
    let output = Ident::new("return_stmt", proc_macro::Span::mixed_site().into());

    let setup: Block = parse_quote! {{
        // Items are not covered by `mixed_site` hygiene, so the helper type gets a name that is
        // unlikely to clash with, or shadow, a type the user's function body refers to.
        struct __AleoStdTimed {
            start: ::std::time::Instant,
            name: &'static str,
            buffer: String,
            prev_mark: Option<::std::time::Duration>,
        }

        impl __AleoStdTimed {
            fn new(name: &'static str) -> Self {
                use ::std::fmt::Write;
                let mut buffer = String::new();
                writeln!(&mut buffer, "Start: `{}`", name).unwrap();
                __AleoStdTimed {
                    start: ::std::time::Instant::now(),
                    name,
                    buffer,
                    prev_mark: None,
                }
            }

            fn mark_elapsed(&mut self, short: &str) {
                use ::std::fmt::Write;

                // Time since the previous mark (or since `new`, for the first mark).
                let mut elapsed = self.start.elapsed();
                if let Some(prev) = self.prev_mark.replace(elapsed) {
                    elapsed -= prev;
                }

                let elapsed = {
                    let secs = elapsed.as_secs();
                    let millis = elapsed.subsec_millis();
                    let micros = elapsed.subsec_micros() % 1000;
                    let nanos = elapsed.subsec_nanos() % 1000;
                    if secs != 0 {
                        format!("{}.{:0>3}s", secs, millis)
                    } else if millis > 0 {
                        format!("{}.{:0>3}ms", millis, micros)
                    } else if micros > 0 {
                        format!("{}.{:0>3}µs", micros, nanos)
                    } else {
                        format!("{}ns", elapsed.subsec_nanos())
                    }
                };

                writeln!(&mut self.buffer, "    {:<55} {:->25}", short, elapsed).unwrap();
            }
        }

        impl Drop for __AleoStdTimed {
            fn drop(&mut self) {
                use ::std::fmt::Write;
                writeln!(&mut self.buffer, "End: `{}` took {:?}", self.name, self.start.elapsed()).unwrap();
                print!("{}", &self.buffer);
            }
        }

        // The timer is never marked when the body is empty; keep that from warning.
        #[allow(unused_mut, unused_variables)]
        let mut #timer = __AleoStdTimed::new(#name);
    }};

    let mut new_stmts = setup.stmts;

    // Decide whether the last statement is split off so its value can be captured, marked,
    // and returned. Brace-style macro invocations parse as items but may be the function's
    // tail expression, so they are still captured.
    let has_tail = match stmts.last() {
        None => false,
        Some(Stmt::Item(Item::Macro(_))) => true,
        Some(Stmt::Item(_)) => false,
        Some(_) => true,
    };
    let tail = if has_tail { stmts.pop() } else { None };
    // Index the tail would have had in the loop below, so its label numbering is consistent.
    let tail_index = stmts.len();

    for (index, stmt) in stmts.drain(..).enumerate() {
        let short = truncate(&stmt, LENGTH);
        let short = format!("L{index}: {short}");

        new_stmts.push(stmt);
        new_stmts.push(parse_quote!(#timer.mark_elapsed(#short);));
    }

    if let Some(stmt) = tail {
        let short = truncate(&stmt, LENGTH);
        let short = format!("L{tail_index}: {short}");

        // The nested block keeps the statement's own value (a tail expression's result).
        new_stmts.push(parse_quote!(let #output = { #stmt };));
        new_stmts.push(parse_quote!(#timer.mark_elapsed(#short);));
        new_stmts.push(parse_quote!(return #output;));
    }

    new_stmts
}
158
159#[cfg(not(feature = "timed"))]
160fn rewrite_stmts(_name: String, stmts: &mut [Stmt]) -> Vec<Stmt> {
161    stmts.to_vec()
162}