instrumented_codegen/
lib.rs

1#![recursion_limit = "128"]
2
//! # Instrumented
//!
//! `instrumented` provides an attribute macro that instruments functions with
//! Prometheus metrics and optional logging.
//!
//! This crate is largely based on the `log-derive` crate (some parts of the
//! code were taken directly from it) and is also inspired by the `metered`
//! crate.
//!
//! For details, refer to the documentation of the
//! [`instrumented`](https://docs.rs/instrumented) crate.
13extern crate proc_macro;
14extern crate syn;
15use darling::FromMeta;
16use proc_macro2::TokenStream;
17use quote::{quote, ToTokens};
18use syn::{
19    parse_macro_input, spanned::Spanned, token, AttributeArgs, Expr, ExprBlock, ExprClosure, Ident,
20    ItemFn, Meta, NestedMeta, Result, ReturnType, Type, TypePath,
21};
22
23struct FormattedAttributes {
24    ok_expr: TokenStream,
25    err_expr: TokenStream,
26    ctx: String,
27}
28
29impl FormattedAttributes {
30    pub fn parse_attributes(
31        attr: &[NestedMeta],
32        fmt_default: &str,
33        ctx_default: &str,
34    ) -> darling::Result<Self> {
35        Options::from_list(attr)
36            .map(|opts| Self::get_ok_err_streams(&opts, fmt_default, ctx_default))
37    }
38
39    fn get_ok_err_streams(att: &Options, fmt_default: &str, ctx_default: &str) -> Self {
40        let ok_log = att.ok_log();
41        let err_log = att.err_log();
42        let fmt = att.fmt().unwrap_or(fmt_default);
43        let ctx = att.ctx().unwrap_or(ctx_default).to_string();
44
45        let ok_expr = match ok_log {
46            Some(loglevel) => {
47                let log_token = get_logger_token(&loglevel);
48                quote! {log::log!(#log_token, #fmt, result);}
49            }
50            None => quote! {()},
51        };
52
53        let err_expr = match err_log {
54            Some(loglevel) => {
55                let log_token = get_logger_token(&loglevel);
56                quote! {log::log!(#log_token, #fmt, err);}
57            }
58            None => quote! {()},
59        };
60        FormattedAttributes {
61            ok_expr,
62            err_expr,
63            ctx,
64        }
65    }
66}
67
68#[derive(Default, FromMeta)]
69#[darling(default)]
70struct NamedOptions {
71    ok: Option<Ident>,
72    err: Option<Ident>,
73    fmt: Option<String>,
74    ctx: Option<String>,
75}
76
77struct Options {
78    /// The log level specified as the first word in the attribute.
79    leading_level: Option<Ident>,
80    named: NamedOptions,
81}
82
83impl Options {
84    pub fn ok_log(&self) -> Option<&Ident> {
85        self.named
86            .ok
87            .as_ref()
88            .or_else(|| self.leading_level.as_ref())
89    }
90
91    pub fn err_log(&self) -> Option<&Ident> {
92        self.named
93            .err
94            .as_ref()
95            .or_else(|| self.leading_level.as_ref())
96    }
97
98    pub fn fmt(&self) -> Option<&str> {
99        self.named.fmt.as_ref().map(String::as_str)
100    }
101
102    pub fn ctx(&self) -> Option<&str> {
103        self.named.ctx.as_ref().map(String::as_str)
104    }
105}
106
107impl FromMeta for Options {
108    fn from_list(items: &[NestedMeta]) -> darling::Result<Self> {
109        if items.is_empty() {
110            return Err(darling::Error::too_few_items(1));
111        }
112
113        let mut leading_level = None;
114
115        if let NestedMeta::Meta(first) = &items[0] {
116            if let Meta::Path(ident) = first {
117                leading_level = Some(ident.segments.first().unwrap().ident.clone());
118            }
119        }
120
121        let named = if leading_level.is_some() {
122            NamedOptions::from_list(&items[1..])?
123        } else {
124            NamedOptions::from_list(items)?
125        };
126
127        Ok(Options {
128            leading_level,
129            named,
130        })
131    }
132}
133
134/// Check if a return type is some form of `Result`. This assumes that all types named `Result`
135/// are in fact results, but is resilient to the possibility of `Result` types being referenced
136/// from specific modules.
137pub(crate) fn is_result_type(ty: &TypePath) -> bool {
138    if let Some(segment) = ty.path.segments.iter().last() {
139        segment.ident == "Result"
140    } else {
141        false
142    }
143}
144
145fn check_if_return_result(f: &ItemFn) -> bool {
146    if let ReturnType::Type(_, t) = &f.sig.output {
147        return match t.as_ref() {
148            Type::Path(path) => is_result_type(path),
149            _ => false,
150        };
151    }
152
153    false
154}
155
156fn get_logger_token(att: &Ident) -> TokenStream {
157    // Capitalize the first letter.
158    let attr_str = att.to_string().to_lowercase();
159    let mut attr_char = attr_str.chars();
160    let attr_str = attr_char.next().unwrap().to_uppercase().to_string() + attr_char.as_str();
161    let att_str = Ident::new(&attr_str, att.span());
162    quote!(log::Level::#att_str)
163}
164
165fn make_closure(original: &ItemFn) -> ExprClosure {
166    let body = Box::new(Expr::Block(ExprBlock {
167        attrs: Default::default(),
168        label: Default::default(),
169        block: *original.block.clone(),
170    }));
171
172    ExprClosure {
173        attrs: Default::default(),
174        asyncness: Default::default(),
175        movability: Default::default(),
176        capture: Some(token::Move {
177            span: original.span(),
178        }),
179        or1_token: Default::default(),
180        inputs: Default::default(),
181        or2_token: Default::default(),
182        output: ReturnType::Default,
183        body,
184    }
185}
186
187fn replace_function_headers(original: ItemFn, new: &mut ItemFn) {
188    let block = new.block.clone();
189    *new = original;
190    new.block = block;
191}
192
193#[allow(unused)]
194fn generate_function(
195    closure: &ExprClosure,
196    expressions: &FormattedAttributes,
197    result: bool,
198    function_name: String,
199    ctx: &str,
200) -> Result<ItemFn> {
201    let FormattedAttributes {
202        ok_expr,
203        err_expr,
204        ctx,
205    } = expressions;
206    let code = if result {
207        quote! {
208            fn temp() {
209                ::instrumented::inc_called_counter_for(#function_name, #ctx);
210                ::instrumented::inc_inflight_for(#function_name, #ctx);
211                let timer = ::instrumented::get_timer_for(#function_name, #ctx);
212                (#closure)()
213                    .map(|result| {
214                        #ok_expr;
215                        ::instrumented::dec_inflight_for(#function_name, #ctx);
216                        result
217                    })
218                    .map_err(|err| {
219                        #err_expr;
220                        ::instrumented::inc_error_counter_for(#function_name, #ctx, format!("{:?}", err));
221                        ::instrumented::dec_inflight_for(#function_name, #ctx);
222                        err
223                    })
224            }
225        }
226    } else {
227        quote! {
228            fn temp() {
229                ::instrumented::inc_called_counter_for(#function_name, #ctx);
230                ::instrumented::inc_inflight_for(#function_name, #ctx);
231                let timer = ::instrumented::get_timer_for(#function_name, #ctx);
232                let result = (#closure)();
233                #ok_expr;
234                ::instrumented::dec_inflight_for(#function_name, #ctx);
235                result
236            }
237        }
238    };
239
240    syn::parse2(code)
241}
242
243/// Instruments a function.
244///
245/// # Optional arguments
246/// * `ctx` - Specify a context label (defaults to `default`)
247/// * `fmt` - Provide a formatting string (defaults to `"() => {:?}`)
248///
249/// # Example
250/// ```rust
251/// extern crate instrumented;
252/// extern crate log;
253/// use instrumented::instrument;
254///
255/// // Logs at debug level with the `special` context using a custom log format.
256/// #[instrument(DEBUG, ctx = "special", fmt = "{}")]
257/// fn my_func() -> String {
258///     use std::{thread, time};
259///     let ten_millis = time::Duration::from_millis(10);
260///     thread::sleep(ten_millis);
261///     format!("slept for {:?} millis", ten_millis)
262/// }
263/// ```
264#[proc_macro_attribute]
265pub fn instrument(
266    attr: proc_macro::TokenStream,
267    item: proc_macro::TokenStream,
268) -> proc_macro::TokenStream {
269    let attr = parse_macro_input!(attr as AttributeArgs);
270    let original_fn: ItemFn = parse_macro_input!(item as ItemFn);
271    let fmt_default = original_fn.sig.ident.to_string() + "() => {:?}";
272    let ctx_default = "default";
273    let parsed_attributes =
274        match FormattedAttributes::parse_attributes(&attr, &fmt_default, &ctx_default) {
275            Ok(val) => val,
276            Err(err) => {
277                return err.write_errors().into();
278            }
279        };
280
281    let closure = make_closure(&original_fn);
282    let is_result = check_if_return_result(&original_fn);
283    let mut new_fn = generate_function(
284        &closure,
285        &parsed_attributes,
286        is_result,
287        original_fn.sig.ident.to_string(),
288        &parsed_attributes.ctx,
289    )
290    .expect("Failed Generating Function");
291    replace_function_headers(original_fn, &mut new_fn);
292    new_fn.into_token_stream().into()
293}
294
#[cfg(test)]
mod tests {
    use syn::parse_quote;

    use super::is_result_type;

    #[test]
    fn result_type() {
        assert!(is_result_type(&parse_quote!(Result<T, E>)));
        assert!(is_result_type(&parse_quote!(std::result::Result<T, E>)));
        assert!(is_result_type(&parse_quote!(fmt::Result)));
        assert!(is_result_type(&parse_quote!(io::Result<T>)));
    }

    #[test]
    fn non_result_type() {
        assert!(!is_result_type(&parse_quote!(Option<T>)));
        assert!(!is_result_type(&parse_quote!(String)));
        assert!(!is_result_type(&parse_quote!(MyResult<T>)));
        // Only the last segment counts, so a module named `result` is not enough.
        assert!(!is_result_type(&parse_quote!(result::Outcome)));
    }
}