autometrics_macros/
lib.rs

1use crate::parse::{AutometricsArgs, Item};
2use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
3use proc_macro2::TokenStream;
4use quote::{quote, ToTokens};
5use regex::Regex;
6use std::env;
7use std::str::FromStr;
8use syn::{
9    parse_macro_input, GenericArgument, ImplItem, ItemFn, ItemImpl, PathArguments, Result,
10    ReturnType, Type,
11};
12
13mod parse;
14mod result_labels;
15
// PromQL fragment interpolated into the generated queries below. It multiplies the
// query result with the most recent `build_info` sample, matching on `instance` and `job`,
// and copies the `version` and `commit` labels onto the result via `group_left`.
const ADD_BUILD_INFO_LABELS: &str =
    "* on (instance, job) group_left(version, commit) last_over_time(build_info[1s])";

// Base URL used for the documentation links when the `PROMETHEUS_URL` environment
// variable is not set while the macro is being expanded.
const DEFAULT_PROMETHEUS_URL: &str = "http://localhost:9090";
20
21#[proc_macro_attribute]
22pub fn autometrics(
23    args: proc_macro::TokenStream,
24    item: proc_macro::TokenStream,
25) -> proc_macro::TokenStream {
26    let args = parse_macro_input!(args as AutometricsArgs);
27
28    let async_trait = check_async_trait(&item);
29    let item = parse_macro_input!(item as Item);
30
31    let result = match item {
32        Item::Function(item) => instrument_function(&args, item, args.struct_name.as_deref()),
33        Item::Impl(item) => instrument_impl_block(&args, item, &async_trait),
34    };
35
36    let output = match result {
37        Ok(output) => output,
38        Err(err) => err.into_compile_error(),
39    };
40
41    output.into()
42}
43
44/// returns the `async_trait` attributes that have to be re-added after our instrumentation magic has been added
45fn check_async_trait(input: &proc_macro::TokenStream) -> String {
46    let regex = Regex::new(r#"#\[[^\]]*async_trait\]"#)
47        .expect("The regex is hardcoded and thus guaranteed to be successfully parseable");
48
49    let original = input.to_string();
50    let attributes: Vec<_> = regex.find_iter(&original).map(|m| m.as_str()).collect();
51
52    attributes.join("\n")
53}
54
55#[proc_macro_derive(ResultLabels, attributes(label))]
56pub fn result_labels(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
57    let input = parse_macro_input!(input as syn::DeriveInput);
58    result_labels::expand(input)
59        .unwrap_or_else(syn::Error::into_compile_error)
60        .into()
61}
62
63/// Add autometrics instrumentation to a single function
64fn instrument_function(
65    args: &AutometricsArgs,
66    item: ItemFn,
67    struct_name: Option<&str>,
68) -> Result<TokenStream> {
69    let sig = item.sig;
70    let block = item.block;
71    let vis = item.vis;
72    let attrs = item.attrs;
73
74    // Methods are identified as Struct::method
75    let function_name = match struct_name {
76        Some(struct_name) => format!("{}::{}", struct_name, sig.ident),
77        None => sig.ident.to_string(),
78    };
79
80    // The PROMETHEUS_URL can be configured by passing the environment variable during build time
81    let prometheus_url =
82        env::var("PROMETHEUS_URL").unwrap_or_else(|_| DEFAULT_PROMETHEUS_URL.to_string());
83
84    // Build the documentation we'll add to the function's RustDocs, unless it is disabled by the environment variable
85    let metrics_docs = if env::var("AUTOMETRICS_DISABLE_DOCS").is_ok() {
86        String::new()
87    } else {
88        create_metrics_docs(&prometheus_url, &function_name, args.track_concurrency)
89    };
90
91    // Type annotation to allow type inference to work on return expressions (such as `.collect()`), as
92    // well as prevent compiler type-inference from selecting the wrong branch in the `spez` macro later.
93    //
94    // Type inference can make the compiler select one of the early cases of `autometrics::result_labels!`
95    // even if the types `T` or `E` do not implement the `GetLabels` trait. That leads to a compilation error
96    // looking like this:
97    // ```
98    // error[E0277]: the trait bound `ApiError: GetLabels` is not satisfied
99    //  --> examples/full-api/src/routes.rs:48:1
100    //   |
101    //48 | #[autometrics(objective = API_SLO)]
102    //   | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ the trait `GetLabels` is not implemented for `ApiError`
103    //   |
104    //   = help: the trait `create_user::{closure#0}::Match2` is implemented for `&&&&create_user::{closure#0}::Match<&Result<T, E>>`
105    //note: required for `&&&&create_user::{closure#0}::Match<&Result<Json<User>, ApiError>>` to implement `create_user::{closure#0}::Match2`
106    //  --> examples/full-api/src/routes.rs:48:1
107    //   |
108    //48 | #[autometrics(objective = API_SLO)]
109    //   | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
110    //   = note: this error originates in the macro `$crate::__private::spez` which comes from the expansion of the attribute macro `autometrics` (in Nightly builds, run with -Z macro-backtrace for more info)
111    // ```
112    //
113    // specifying the return type makes the compiler select the (correct) fallback case of `ApiError` not being a
114    // `GetLabels` implementor.
115    let return_type = match sig.output {
116        ReturnType::Default => quote! { : () },
117        ReturnType::Type(_, ref t) => match t.as_ref() {
118            Type::ImplTrait(_) => quote! {},
119            Type::Path(path) => {
120                let mut ts = vec![];
121                let mut first = true;
122
123                for segment in &path.path.segments {
124                    let ident = &segment.ident;
125                    let args = &segment.arguments;
126
127                    // special handling in case the type is angle bracket with a `impl` trait
128                    // in such a case, we would run into the following error
129                    //
130                    // ```
131                    // error[E0562]: `impl Trait` only allowed in function and inherent method return types, not in variable bindings
132                    //   --> src/main.rs:11:28
133                    //    |
134                    // 11 | async fn hello() -> Result<impl ToString, std::io::Error> {
135                    //    |                            ^^^^^^^^^^^^^
136                    // ```
137                    //
138                    // this whole block just re-creates the angle bracketed `<impl ToString, std::io::Error>`
139                    // manually but the trait `impl` replaced with an infer `_`, which fixes this issue
140                    let suffix = match args {
141                        PathArguments::AngleBracketed(brackets) => {
142                            let mut ts = vec![];
143
144                            for args in &brackets.args {
145                                ts.push(match args {
146                                    GenericArgument::Type(Type::ImplTrait(_)) => {
147                                        quote! { _ }
148                                    }
149                                    generic_arg => quote! { #generic_arg },
150                                });
151                            }
152
153                            quote! { ::<#(#ts),*> }
154                        }
155                        _ => quote! {},
156                    };
157
158                    // primitive way to check whenever this is the first iteration or not
159                    // as on the first iteration, we don't want to prepend `::`,
160                    // as types may be local and/or imported and then couldn't be found
161                    if !first {
162                        ts.push(quote! { :: });
163                    } else {
164                        first = false;
165                    }
166
167                    ts.push(quote! { #ident });
168                    ts.push(quote! { #suffix });
169                }
170
171                quote! { : #(#ts)* }
172            }
173            _ => quote! { : #t },
174        },
175    };
176
177    // Track the name and module of the current function as a task-local variable
178    // so that any functions it calls know which function they were called by
179    let caller_info = quote! {
180        use autometrics::__private::{CALLER, CallerInfo};
181        let caller = CallerInfo {
182            caller_function: #function_name,
183            caller_module: module_path!(),
184        };
185    };
186
187    // Wrap the body of the original function, using a slightly different approach based on whether the function is async
188    let call_function = if sig.asyncness.is_some() {
189        quote! {
190            {
191                #caller_info
192                CALLER.scope(caller, async move {
193                    #block
194                }).await
195            }
196        }
197    } else {
198        quote! {
199            {
200                #caller_info
201                CALLER.sync_scope(caller, move || {
202                    #block
203                })
204            }
205        }
206    };
207
208    let objective = if let Some(objective) = &args.objective {
209        quote! { Some(#objective) }
210    } else {
211        quote! { None }
212    };
213
214    let counter_labels = if args.ok_if.is_some() || args.error_if.is_some() {
215        // Apply the predicate to determine whether to consider the result as "ok" or "error"
216        let result_label = if let Some(ok_if) = &args.ok_if {
217            quote! { if #ok_if (&result) { "ok" } else { "error" } }
218        } else if let Some(error_if) = &args.error_if {
219            quote! { if #error_if (&result) { "error" } else { "ok" } }
220        } else {
221            unreachable!()
222        };
223        quote! {
224            {
225                use autometrics::__private::{CALLER, CounterLabels, GetStaticStrFromIntoStaticStr, GetStaticStr};
226                let result_label = #result_label;
227                // If the return type implements Into<&'static str>, attach that as a label
228                let value_type = (&result).__autometrics_static_str();
229                let caller = CALLER.get();
230                CounterLabels::new(
231                    #function_name,
232                    module_path!(),
233                    caller.caller_function,
234                    caller.caller_module,
235                    Some((result_label, value_type)),
236                    #objective,
237                )
238            }
239        }
240    } else {
241        quote! {
242            {
243                use autometrics::__private::{CALLER, CounterLabels, GetLabels};
244                let result_labels = autometrics::get_result_labels_for_value!(&result);
245                let caller = CALLER.get();
246                CounterLabels::new(
247                    #function_name,
248                    module_path!(),
249                    caller.caller_function,
250                    caller.caller_module,
251                    result_labels,
252                    #objective,
253                )
254            }
255        }
256    };
257
258    let gauge_labels = if args.track_concurrency {
259        quote! { {
260            use autometrics::__private::GaugeLabels;
261            Some(&GaugeLabels::new(
262                #function_name,
263                module_path!(),
264            )) }
265        }
266    } else {
267        quote! { None }
268    };
269
270    // This is a little nuts.
271    // In debug mode, we're using the `linkme` crate to collect all the function descriptions into a static slice.
272    // We're then using that to start all the function counters at zero, even before the function is called.
273    let collect_function_descriptions = if cfg!(debug_assertions) {
274        quote! {
275            {
276                use autometrics::__private::{linkme::distributed_slice, FUNCTION_DESCRIPTIONS, FunctionDescription};
277                #[distributed_slice(FUNCTION_DESCRIPTIONS)]
278                // Point the distributed_slice macro to the linkme crate re-exported from autometrics
279                #[linkme(crate = autometrics::__private::linkme)]
280                static FUNCTION_DESCRIPTION: FunctionDescription = FunctionDescription {
281                    name: #function_name,
282                    module: module_path!(),
283                    objective: #objective,
284                };
285            }
286        }
287    } else {
288        quote! {}
289    };
290
291    Ok(quote! {
292        #(#attrs)*
293
294        // Append the metrics documentation to the end of the function's documentation
295        #[doc=#metrics_docs]
296
297        #vis #sig {
298            #collect_function_descriptions
299
300            let __autometrics_tracker = {
301                use autometrics::__private::{AutometricsTracker, BuildInfoLabels, TrackMetrics};
302                AutometricsTracker::set_build_info(&BuildInfoLabels::new(
303                    option_env!("AUTOMETRICS_VERSION").or(option_env!("CARGO_PKG_VERSION")).unwrap_or_default(),
304                    option_env!("AUTOMETRICS_COMMIT").or(option_env!("VERGEN_GIT_SHA")).unwrap_or_default(),
305                    option_env!("AUTOMETRICS_BRANCH").or(option_env!("VERGEN_GIT_BRANCH")).unwrap_or_default(),
306                ));
307                AutometricsTracker::start(#gauge_labels)
308            };
309
310            let result #return_type = #call_function;
311
312            {
313                use autometrics::__private::{HistogramLabels, TrackMetrics};
314                let counter_labels = #counter_labels;
315                let histogram_labels = HistogramLabels::new(
316                    #function_name,
317                     module_path!(),
318                     #objective,
319                );
320                __autometrics_tracker.finish(&counter_labels, &histogram_labels);
321            }
322
323            result
324        }
325    })
326}
327
328/// Add autometrics instrumentation to an entire impl block
329fn instrument_impl_block(
330    args: &AutometricsArgs,
331    mut item: ItemImpl,
332    attributes_to_re_add: &str,
333) -> Result<TokenStream> {
334    let struct_name = Some(item.self_ty.to_token_stream().to_string());
335
336    // Replace all of the method items in place
337    item.items = item
338        .items
339        .into_iter()
340        .map(|item| match item {
341            ImplItem::Fn(mut method) => {
342                // Skip any methods that have the #[skip_autometrics] attribute
343                if method
344                    .attrs
345                    .iter()
346                    .any(|attr| attr.path().is_ident("skip_autometrics"))
347                {
348                    method
349                        .attrs
350                        .retain(|attr| !attr.path().is_ident("skip_autometrics"));
351                    return ImplItem::Fn(method);
352                }
353
354                let item_fn = ItemFn {
355                    attrs: method.attrs,
356                    vis: method.vis,
357                    sig: method.sig,
358                    block: Box::new(method.block),
359                };
360                let tokens = match instrument_function(args, item_fn, struct_name.as_deref()) {
361                    Ok(tokens) => tokens,
362                    Err(err) => err.to_compile_error(),
363                };
364                ImplItem::Verbatim(tokens)
365            }
366            _ => item,
367        })
368        .collect();
369
370    let ts = TokenStream::from_str(attributes_to_re_add)?;
371
372    Ok(quote! {
373        #ts
374        #item
375    })
376}
377
378/// Create Prometheus queries for the generated metric and
379/// package them up into a RustDoc string
380fn create_metrics_docs(prometheus_url: &str, function: &str, track_concurrency: bool) -> String {
381    let request_rate = request_rate_query("function", function);
382    let request_rate_url = make_prometheus_url(
383        prometheus_url,
384        &request_rate,
385        &format!(
386            "Rate of calls to the `{function}` function per second, averaged over 5 minute windows"
387        ),
388    );
389    let callee_request_rate = request_rate_query("caller_function", function);
390    let callee_request_rate_url = make_prometheus_url(prometheus_url, &callee_request_rate, &format!("Rate of calls to functions called by `{function}` per second, averaged over 5 minute windows"));
391
392    let error_ratio = &error_ratio_query("function", function);
393    let error_ratio_url = make_prometheus_url(prometheus_url, error_ratio, &format!("Percentage of calls to the `{function}` function that return errors, averaged over 5 minute windows"));
394    let callee_error_ratio = &error_ratio_query("caller_function", function);
395    let callee_error_ratio_url = make_prometheus_url(prometheus_url, callee_error_ratio, &format!("Percentage of calls to functions called by `{function}` that return errors, averaged over 5 minute windows"));
396
397    let latency = latency_query("function", function);
398    let latency_url = make_prometheus_url(
399        prometheus_url,
400        &latency,
401        &format!("95th and 99th percentile latencies (in seconds) for the `{function}` function"),
402    );
403
404    // Only include the concurrent calls query if the user has enabled it for this function
405    let concurrent_calls_doc = if track_concurrency {
406        let concurrent_calls = concurrent_calls_query("function", function);
407        let concurrent_calls_url = make_prometheus_url(
408            prometheus_url,
409            &concurrent_calls,
410            &format!("Concurrent calls to the `{function}` function"),
411        );
412        format!("\n- [Concurrent Calls]({concurrent_calls_url}")
413    } else {
414        String::new()
415    };
416
417    format!(
418        "\n\n---
419
420## Autometrics
421
422View the live metrics for the `{function}` function:
423- [Request Rate]({request_rate_url})
424- [Error Ratio]({error_ratio_url})
425- [Latency (95th and 99th percentiles)]({latency_url}){concurrent_calls_doc}
426
427Or, dig into the metrics of *functions called by* `{function}`:
428- [Request Rate]({callee_request_rate_url})
429- [Error Ratio]({callee_error_ratio_url})
430"
431    )
432}
433
434fn make_prometheus_url(url: &str, query: &str, comment: &str) -> String {
435    let mut url = url.to_string();
436    let comment_and_query = format!("# {comment}\n\n{query}");
437    let query = utf8_percent_encode(&comment_and_query, NON_ALPHANUMERIC).to_string();
438
439    if !url.ends_with('/') {
440        url.push('/');
441    }
442    url.push_str("graph?g0.expr=");
443    url.push_str(&query);
444    // Go straight to the graph tab
445    url.push_str("&g0.tab=0");
446    url
447}
448
449fn request_rate_query(label_key: &str, label_value: &str) -> String {
450    format!("sum by (function, module, service_name, commit, version) (rate({{__name__=~\"function_calls(_count)?(_total)?\",{label_key}=\"{label_value}\"}}[5m]) {ADD_BUILD_INFO_LABELS})")
451}
452
453fn error_ratio_query(label_key: &str, label_value: &str) -> String {
454    let request_rate = request_rate_query(label_key, label_value);
455    format!("(sum by (function, module, service_name, commit, version) (rate({{__name__=~\"function_calls(_count)?(_total)?\",{label_key}=\"{label_value}\",result=\"error\"}}[5m]) {ADD_BUILD_INFO_LABELS}))
456/
457({request_rate})",)
458}
459
460fn latency_query(label_key: &str, label_value: &str) -> String {
461    let latency = format!(
462        "sum by (le, function, module, service_name, commit, version) (rate({{__name__=~\"function_calls_duration(_seconds)?_bucket\",{label_key}=\"{label_value}\"}}[5m]) {ADD_BUILD_INFO_LABELS})"
463    );
464    format!(
465        "label_replace(histogram_quantile(0.99, {latency}), \"percentile_latency\", \"99\", \"\", \"\")
466or
467label_replace(histogram_quantile(0.95, {latency}), \"percentile_latency\", \"95\", \"\", \"\")"
468    )
469}
470
471fn concurrent_calls_query(label_key: &str, label_value: &str) -> String {
472    format!("sum by (function, module, service_name, commit, version) (function_calls_concurrent{{{label_key}=\"{label_value}\"}} {ADD_BUILD_INFO_LABELS})")
473}