1use crate::parse::{AutometricsArgs, Item};
2use percent_encoding::{utf8_percent_encode, NON_ALPHANUMERIC};
3use proc_macro2::TokenStream;
4use quote::{quote, ToTokens};
5use regex::Regex;
6use std::env;
7use std::str::FromStr;
8use syn::{
9 parse_macro_input, GenericArgument, ImplItem, ItemFn, ItemImpl, PathArguments, Result,
10 ReturnType, Type,
11};
12
13mod parse;
14mod result_labels;
15
/// PromQL fragment appended inside the `sum by (...)` of every generated query.
///
/// It matches each series against `build_info` on the `instance` and `job` labels and
/// copies that series' `version` and `commit` labels over via `group_left`, so the outer
/// `sum by (..., commit, version)` can group by them.
const ADD_BUILD_INFO_LABELS: &str =
    "* on (instance, job) group_left(version, commit) last_over_time(build_info[1s])";

/// Prometheus base URL used for the generated documentation links when the
/// `PROMETHEUS_URL` environment variable is not set at macro-expansion time.
const DEFAULT_PROMETHEUS_URL: &str = "http://localhost:9090";
20
21#[proc_macro_attribute]
22pub fn autometrics(
23 args: proc_macro::TokenStream,
24 item: proc_macro::TokenStream,
25) -> proc_macro::TokenStream {
26 let args = parse_macro_input!(args as AutometricsArgs);
27
28 let async_trait = check_async_trait(&item);
29 let item = parse_macro_input!(item as Item);
30
31 let result = match item {
32 Item::Function(item) => instrument_function(&args, item, args.struct_name.as_deref()),
33 Item::Impl(item) => instrument_impl_block(&args, item, &async_trait),
34 };
35
36 let output = match result {
37 Ok(output) => output,
38 Err(err) => err.into_compile_error(),
39 };
40
41 output.into()
42}
43
44fn check_async_trait(input: &proc_macro::TokenStream) -> String {
46 let regex = Regex::new(r#"#\[[^\]]*async_trait\]"#)
47 .expect("The regex is hardcoded and thus guaranteed to be successfully parseable");
48
49 let original = input.to_string();
50 let attributes: Vec<_> = regex.find_iter(&original).map(|m| m.as_str()).collect();
51
52 attributes.join("\n")
53}
54
55#[proc_macro_derive(ResultLabels, attributes(label))]
56pub fn result_labels(input: proc_macro::TokenStream) -> proc_macro::TokenStream {
57 let input = parse_macro_input!(input as syn::DeriveInput);
58 result_labels::expand(input)
59 .unwrap_or_else(syn::Error::into_compile_error)
60 .into()
61}
62
63fn instrument_function(
65 args: &AutometricsArgs,
66 item: ItemFn,
67 struct_name: Option<&str>,
68) -> Result<TokenStream> {
69 let sig = item.sig;
70 let block = item.block;
71 let vis = item.vis;
72 let attrs = item.attrs;
73
74 let function_name = match struct_name {
76 Some(struct_name) => format!("{}::{}", struct_name, sig.ident),
77 None => sig.ident.to_string(),
78 };
79
80 let prometheus_url =
82 env::var("PROMETHEUS_URL").unwrap_or_else(|_| DEFAULT_PROMETHEUS_URL.to_string());
83
84 let metrics_docs = if env::var("AUTOMETRICS_DISABLE_DOCS").is_ok() {
86 String::new()
87 } else {
88 create_metrics_docs(&prometheus_url, &function_name, args.track_concurrency)
89 };
90
91 let return_type = match sig.output {
116 ReturnType::Default => quote! { : () },
117 ReturnType::Type(_, ref t) => match t.as_ref() {
118 Type::ImplTrait(_) => quote! {},
119 Type::Path(path) => {
120 let mut ts = vec![];
121 let mut first = true;
122
123 for segment in &path.path.segments {
124 let ident = &segment.ident;
125 let args = &segment.arguments;
126
127 let suffix = match args {
141 PathArguments::AngleBracketed(brackets) => {
142 let mut ts = vec![];
143
144 for args in &brackets.args {
145 ts.push(match args {
146 GenericArgument::Type(Type::ImplTrait(_)) => {
147 quote! { _ }
148 }
149 generic_arg => quote! { #generic_arg },
150 });
151 }
152
153 quote! { ::<#(#ts),*> }
154 }
155 _ => quote! {},
156 };
157
158 if !first {
162 ts.push(quote! { :: });
163 } else {
164 first = false;
165 }
166
167 ts.push(quote! { #ident });
168 ts.push(quote! { #suffix });
169 }
170
171 quote! { : #(#ts)* }
172 }
173 _ => quote! { : #t },
174 },
175 };
176
177 let caller_info = quote! {
180 use autometrics::__private::{CALLER, CallerInfo};
181 let caller = CallerInfo {
182 caller_function: #function_name,
183 caller_module: module_path!(),
184 };
185 };
186
187 let call_function = if sig.asyncness.is_some() {
189 quote! {
190 {
191 #caller_info
192 CALLER.scope(caller, async move {
193 #block
194 }).await
195 }
196 }
197 } else {
198 quote! {
199 {
200 #caller_info
201 CALLER.sync_scope(caller, move || {
202 #block
203 })
204 }
205 }
206 };
207
208 let objective = if let Some(objective) = &args.objective {
209 quote! { Some(#objective) }
210 } else {
211 quote! { None }
212 };
213
214 let counter_labels = if args.ok_if.is_some() || args.error_if.is_some() {
215 let result_label = if let Some(ok_if) = &args.ok_if {
217 quote! { if #ok_if (&result) { "ok" } else { "error" } }
218 } else if let Some(error_if) = &args.error_if {
219 quote! { if #error_if (&result) { "error" } else { "ok" } }
220 } else {
221 unreachable!()
222 };
223 quote! {
224 {
225 use autometrics::__private::{CALLER, CounterLabels, GetStaticStrFromIntoStaticStr, GetStaticStr};
226 let result_label = #result_label;
227 let value_type = (&result).__autometrics_static_str();
229 let caller = CALLER.get();
230 CounterLabels::new(
231 #function_name,
232 module_path!(),
233 caller.caller_function,
234 caller.caller_module,
235 Some((result_label, value_type)),
236 #objective,
237 )
238 }
239 }
240 } else {
241 quote! {
242 {
243 use autometrics::__private::{CALLER, CounterLabels, GetLabels};
244 let result_labels = autometrics::get_result_labels_for_value!(&result);
245 let caller = CALLER.get();
246 CounterLabels::new(
247 #function_name,
248 module_path!(),
249 caller.caller_function,
250 caller.caller_module,
251 result_labels,
252 #objective,
253 )
254 }
255 }
256 };
257
258 let gauge_labels = if args.track_concurrency {
259 quote! { {
260 use autometrics::__private::GaugeLabels;
261 Some(&GaugeLabels::new(
262 #function_name,
263 module_path!(),
264 )) }
265 }
266 } else {
267 quote! { None }
268 };
269
270 let collect_function_descriptions = if cfg!(debug_assertions) {
274 quote! {
275 {
276 use autometrics::__private::{linkme::distributed_slice, FUNCTION_DESCRIPTIONS, FunctionDescription};
277 #[distributed_slice(FUNCTION_DESCRIPTIONS)]
278 #[linkme(crate = autometrics::__private::linkme)]
280 static FUNCTION_DESCRIPTION: FunctionDescription = FunctionDescription {
281 name: #function_name,
282 module: module_path!(),
283 objective: #objective,
284 };
285 }
286 }
287 } else {
288 quote! {}
289 };
290
291 Ok(quote! {
292 #(#attrs)*
293
294 #[doc=#metrics_docs]
296
297 #vis #sig {
298 #collect_function_descriptions
299
300 let __autometrics_tracker = {
301 use autometrics::__private::{AutometricsTracker, BuildInfoLabels, TrackMetrics};
302 AutometricsTracker::set_build_info(&BuildInfoLabels::new(
303 option_env!("AUTOMETRICS_VERSION").or(option_env!("CARGO_PKG_VERSION")).unwrap_or_default(),
304 option_env!("AUTOMETRICS_COMMIT").or(option_env!("VERGEN_GIT_SHA")).unwrap_or_default(),
305 option_env!("AUTOMETRICS_BRANCH").or(option_env!("VERGEN_GIT_BRANCH")).unwrap_or_default(),
306 ));
307 AutometricsTracker::start(#gauge_labels)
308 };
309
310 let result #return_type = #call_function;
311
312 {
313 use autometrics::__private::{HistogramLabels, TrackMetrics};
314 let counter_labels = #counter_labels;
315 let histogram_labels = HistogramLabels::new(
316 #function_name,
317 module_path!(),
318 #objective,
319 );
320 __autometrics_tracker.finish(&counter_labels, &histogram_labels);
321 }
322
323 result
324 }
325 })
326}
327
328fn instrument_impl_block(
330 args: &AutometricsArgs,
331 mut item: ItemImpl,
332 attributes_to_re_add: &str,
333) -> Result<TokenStream> {
334 let struct_name = Some(item.self_ty.to_token_stream().to_string());
335
336 item.items = item
338 .items
339 .into_iter()
340 .map(|item| match item {
341 ImplItem::Fn(mut method) => {
342 if method
344 .attrs
345 .iter()
346 .any(|attr| attr.path().is_ident("skip_autometrics"))
347 {
348 method
349 .attrs
350 .retain(|attr| !attr.path().is_ident("skip_autometrics"));
351 return ImplItem::Fn(method);
352 }
353
354 let item_fn = ItemFn {
355 attrs: method.attrs,
356 vis: method.vis,
357 sig: method.sig,
358 block: Box::new(method.block),
359 };
360 let tokens = match instrument_function(args, item_fn, struct_name.as_deref()) {
361 Ok(tokens) => tokens,
362 Err(err) => err.to_compile_error(),
363 };
364 ImplItem::Verbatim(tokens)
365 }
366 _ => item,
367 })
368 .collect();
369
370 let ts = TokenStream::from_str(attributes_to_re_add)?;
371
372 Ok(quote! {
373 #ts
374 #item
375 })
376}
377
378fn create_metrics_docs(prometheus_url: &str, function: &str, track_concurrency: bool) -> String {
381 let request_rate = request_rate_query("function", function);
382 let request_rate_url = make_prometheus_url(
383 prometheus_url,
384 &request_rate,
385 &format!(
386 "Rate of calls to the `{function}` function per second, averaged over 5 minute windows"
387 ),
388 );
389 let callee_request_rate = request_rate_query("caller_function", function);
390 let callee_request_rate_url = make_prometheus_url(prometheus_url, &callee_request_rate, &format!("Rate of calls to functions called by `{function}` per second, averaged over 5 minute windows"));
391
392 let error_ratio = &error_ratio_query("function", function);
393 let error_ratio_url = make_prometheus_url(prometheus_url, error_ratio, &format!("Percentage of calls to the `{function}` function that return errors, averaged over 5 minute windows"));
394 let callee_error_ratio = &error_ratio_query("caller_function", function);
395 let callee_error_ratio_url = make_prometheus_url(prometheus_url, callee_error_ratio, &format!("Percentage of calls to functions called by `{function}` that return errors, averaged over 5 minute windows"));
396
397 let latency = latency_query("function", function);
398 let latency_url = make_prometheus_url(
399 prometheus_url,
400 &latency,
401 &format!("95th and 99th percentile latencies (in seconds) for the `{function}` function"),
402 );
403
404 let concurrent_calls_doc = if track_concurrency {
406 let concurrent_calls = concurrent_calls_query("function", function);
407 let concurrent_calls_url = make_prometheus_url(
408 prometheus_url,
409 &concurrent_calls,
410 &format!("Concurrent calls to the `{function}` function"),
411 );
412 format!("\n- [Concurrent Calls]({concurrent_calls_url}")
413 } else {
414 String::new()
415 };
416
417 format!(
418 "\n\n---
419
420## Autometrics
421
422View the live metrics for the `{function}` function:
423- [Request Rate]({request_rate_url})
424- [Error Ratio]({error_ratio_url})
425- [Latency (95th and 99th percentiles)]({latency_url}){concurrent_calls_doc}
426
427Or, dig into the metrics of *functions called by* `{function}`:
428- [Request Rate]({callee_request_rate_url})
429- [Error Ratio]({callee_error_ratio_url})
430"
431 )
432}
433
434fn make_prometheus_url(url: &str, query: &str, comment: &str) -> String {
435 let mut url = url.to_string();
436 let comment_and_query = format!("# {comment}\n\n{query}");
437 let query = utf8_percent_encode(&comment_and_query, NON_ALPHANUMERIC).to_string();
438
439 if !url.ends_with('/') {
440 url.push('/');
441 }
442 url.push_str("graph?g0.expr=");
443 url.push_str(&query);
444 url.push_str("&g0.tab=0");
446 url
447}
448
449fn request_rate_query(label_key: &str, label_value: &str) -> String {
450 format!("sum by (function, module, service_name, commit, version) (rate({{__name__=~\"function_calls(_count)?(_total)?\",{label_key}=\"{label_value}\"}}[5m]) {ADD_BUILD_INFO_LABELS})")
451}
452
453fn error_ratio_query(label_key: &str, label_value: &str) -> String {
454 let request_rate = request_rate_query(label_key, label_value);
455 format!("(sum by (function, module, service_name, commit, version) (rate({{__name__=~\"function_calls(_count)?(_total)?\",{label_key}=\"{label_value}\",result=\"error\"}}[5m]) {ADD_BUILD_INFO_LABELS}))
456/
457({request_rate})",)
458}
459
460fn latency_query(label_key: &str, label_value: &str) -> String {
461 let latency = format!(
462 "sum by (le, function, module, service_name, commit, version) (rate({{__name__=~\"function_calls_duration(_seconds)?_bucket\",{label_key}=\"{label_value}\"}}[5m]) {ADD_BUILD_INFO_LABELS})"
463 );
464 format!(
465 "label_replace(histogram_quantile(0.99, {latency}), \"percentile_latency\", \"99\", \"\", \"\")
466or
467label_replace(histogram_quantile(0.95, {latency}), \"percentile_latency\", \"95\", \"\", \"\")"
468 )
469}
470
471fn concurrent_calls_query(label_key: &str, label_value: &str) -> String {
472 format!("sum by (function, module, service_name, commit, version) (function_calls_concurrent{{{label_key}=\"{label_value}\"}} {ADD_BUILD_INFO_LABELS})")
473}