axum_prometheus/lib.rs
//! A middleware to collect HTTP metrics for Axum applications.
2//!
3//! `axum-prometheus` relies on [`metrics.rs`](https://metrics.rs/) and its ecosystem to collect and export metrics - for instance for Prometheus, `metrics_exporter_prometheus` is used as a backend to interact with Prometheus.
4//!
5//! ## Metrics
6//!
//! By default, three HTTP metrics are tracked:
8//! - `axum_http_requests_total` (labels: endpoint, method, status): the total number of HTTP requests handled (counter)
9//! - `axum_http_requests_duration_seconds` (labels: endpoint, method, status): the request duration for all HTTP requests handled (histogram)
10//! - `axum_http_requests_pending` (labels: endpoint, method): the number of currently in-flight requests (gauge)
11//!
//! This crate can also track response body sizes as a histogram — see [`PrometheusMetricLayerBuilder::enable_response_body_size`].
13//!
14//! ### Renaming Metrics
15//!
//! These metrics can be renamed by setting environment variables at compile time:
17//! - `AXUM_HTTP_REQUESTS_TOTAL`
18//! - `AXUM_HTTP_REQUESTS_DURATION_SECONDS`
19//! - `AXUM_HTTP_REQUESTS_PENDING`
20//! - `AXUM_HTTP_RESPONSE_BODY_SIZE` (if body size tracking is enabled)
21//!
//! Since Cargo 1.56, these environment variables can be set in your `.cargo/config.toml`:
23//! ```toml
24//! [env]
25//! AXUM_HTTP_REQUESTS_TOTAL = "my_app_requests_total"
26//! AXUM_HTTP_REQUESTS_DURATION_SECONDS = "my_app_requests_duration_seconds"
27//! AXUM_HTTP_REQUESTS_PENDING = "my_app_requests_pending"
28//! AXUM_HTTP_RESPONSE_BODY_SIZE = "my_app_response_body_size"
29//! ```
30//!
//! ...or, alternatively, use the [`PrometheusMetricLayerBuilder::with_prefix`] function.
32//!
33//! ## Usage
34//!
//! For more elaborate use cases, see the builder-example that leverages [`PrometheusMetricLayerBuilder`].
36//!
37//! Add `axum-prometheus` to your `Cargo.toml`.
38//! ```not_rust
39//! [dependencies]
40//! axum-prometheus = "0.9.0"
41//! ```
42//!
//! Then instantiate the Prometheus middleware:
44//! ```rust,no_run
45//! use std::{net::SocketAddr, time::Duration};
46//! use axum::{routing::get, Router};
47//! use axum_prometheus::PrometheusMetricLayer;
48//!
49//! #[tokio::main]
50//! async fn main() {
51//! let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair();
52//! let app = Router::new()
53//! .route("/fast", get(|| async {}))
54//! .route(
55//! "/slow",
56//! get(|| async {
57//! tokio::time::sleep(Duration::from_secs(1)).await;
58//! }),
59//! )
60//! .route("/metrics", get(|| async move { metric_handle.render() }))
61//! .layer(prometheus_layer);
62//!
63//! let listener = tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 3000)))
64//! .await
65//! .unwrap();
66//! axum::serve(listener, app).await.unwrap()
67//! }
68//! ```
69//!
70//! Note that the `/metrics` endpoint is not automatically exposed, so you need to add that as a route manually.
71//! Calling the `/metrics` endpoint will expose your metrics:
72//! ```not_rust
73//! axum_http_requests_total{method="GET",endpoint="/metrics",status="200"} 5
74//! axum_http_requests_pending{method="GET",endpoint="/metrics"} 1
75//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.005"} 4
76//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.01"} 4
77//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.025"} 4
78//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.05"} 4
79//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.1"} 4
80//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.25"} 4
81//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.5"} 4
82//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="1"} 4
83//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="2.5"} 4
84//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="5"} 4
85//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="10"} 4
86//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="+Inf"} 4
87//! axum_http_requests_duration_seconds_sum{method="GET",status="200",endpoint="/metrics"} 0.001997171
88//! axum_http_requests_duration_seconds_count{method="GET",status="200",endpoint="/metrics"} 4
89//! ```
90//!
91//! ## Prometheus push gateway feature
//! This crate currently has no higher-level API for the `push-gateway` feature. If you plan to use it, enable the
//! `push-gateway` feature in `axum-prometheus`, use `BaseMetricLayer`, and set up your recorder manually, similar to
//! the `base-metric-layer-example`.
95//!
96//! ## Using a different exporter than Prometheus
97//!
//! This crate may also be used with exporters other than Prometheus. First, disable the default features:
99//!
100//! ```toml
101//! axum-prometheus = { version = "0.9.0", default-features = false }
102//! ```
103//!
//! Then implement `MakeDefaultHandle` for the exporter you'd like to use. For `StatsD`:
105//!
106//! ```rust,ignore
107//! use metrics_exporter_statsd::StatsdBuilder;
108//! use axum_prometheus::{MakeDefaultHandle, GenericMetricLayer};
109//!
110//! // The custom StatsD exporter struct. It may take fields as well.
111//! struct Recorder { port: u16 }
112//!
113//! // In order to use this with `axum_prometheus`, we must implement `MakeDefaultHandle`.
114//! impl MakeDefaultHandle for Recorder {
115//! // We don't need to return anything meaningful from here (unlike PrometheusHandle)
116//! // Let's just return an empty tuple.
117//! type Out = ();
118//!
119//! fn make_default_handle(self) -> Self::Out {
120//! // The regular setup for StatsD. Notice that `self` is passed in by value.
121//! let recorder = StatsdBuilder::from("127.0.0.1", self.port)
122//! .with_queue_size(5000)
123//! .with_buffer_size(1024)
124//! .build(Some("prefix"))
125//! .expect("Could not create StatsdRecorder");
126//!
//!         metrics::set_global_recorder(recorder).unwrap();
128//! }
129//! }
130//!
131//! fn main() {
132//! // Use `GenericMetricLayer` instead of `PrometheusMetricLayer`.
133//! // Generally `GenericMetricLayer::pair_from` is what you're looking for.
134//! // It lets you pass in a concrete initialized `Recorder`.
135//! let (metric_layer, _handle) = GenericMetricLayer::pair_from(Recorder { port: 8125 });
136//! }
137//! ```
138//!
//! It's also possible to use `GenericMetricLayer::pair`; however, it's only callable if the recorder struct also implements `Default`.
140//!
141//! ```rust,ignore
142//! use metrics_exporter_statsd::StatsdBuilder;
143//! use axum_prometheus::{MakeDefaultHandle, GenericMetricLayer};
144//!
145//! #[derive(Default)]
146//! struct Recorder { port: u16 }
147//!
148//! impl MakeDefaultHandle for Recorder {
149//! /* .. same as before .. */
150//! }
151//!
152//! fn main() {
//!     // This will internally call `Recorder::make_default_handle(Recorder::default())`.
154//! let (metric_layer, _handle) = GenericMetricLayer::<_, Recorder>::pair();
155//! }
156//! ```
157//!
158//! This crate is similar to (and takes inspiration from) [`actix-web-prom`](https://github.com/nlopes/actix-web-prom) and [`rocket_prometheus`](https://github.com/sd2k/rocket_prometheus),
159//! and also builds on top of davidpdrsn's [earlier work with LifeCycleHooks](https://github.com/tower-rs/tower-http/pull/96) in `tower-http`.
160//!
161//! [`PrometheusMetricLayerBuilder`]: crate::PrometheusMetricLayerBuilder
162
// Crate-wide clippy allowances.
// NOTE(review): the reason is not recorded here — presumably public names such as
// `PrometheusMetricLayer` repeating the crate name, and long numeric literals in bucket
// tables; confirm and document.
#![allow(clippy::module_name_repetitions, clippy::unreadable_literal)]
164
/// Identifies the gauge used for the requests pending metric.
///
/// Defaults to `axum_http_requests_pending`, but can be changed by setting the
/// `AXUM_HTTP_REQUESTS_PENDING` env variable at compile time.
///
/// If a metric prefix is configured at runtime (`PrometheusMetricLayerBuilder::with_prefix`),
/// the middleware emits `{prefix}_http_requests_pending` instead of this name.
pub const AXUM_HTTP_REQUESTS_PENDING: &str = match option_env!("AXUM_HTTP_REQUESTS_PENDING") {
    Some(n) => n,
    None => "axum_http_requests_pending",
};

/// Identifies the histogram/summary used for request latency.
///
/// Defaults to `axum_http_requests_duration_seconds`, but can be changed by setting the
/// `AXUM_HTTP_REQUESTS_DURATION_SECONDS` env variable at compile time.
///
/// If a metric prefix is configured at runtime (`PrometheusMetricLayerBuilder::with_prefix`),
/// the middleware emits `{prefix}_http_requests_duration_seconds` instead of this name.
pub const AXUM_HTTP_REQUESTS_DURATION_SECONDS: &str =
    match option_env!("AXUM_HTTP_REQUESTS_DURATION_SECONDS") {
        Some(n) => n,
        None => "axum_http_requests_duration_seconds",
    };

/// Identifies the counter used for requests total.
///
/// Defaults to `axum_http_requests_total`, but can be changed by setting the
/// `AXUM_HTTP_REQUESTS_TOTAL` env variable at compile time.
///
/// If a metric prefix is configured at runtime (`PrometheusMetricLayerBuilder::with_prefix`),
/// the middleware emits `{prefix}_http_requests_total` instead of this name.
pub const AXUM_HTTP_REQUESTS_TOTAL: &str = match option_env!("AXUM_HTTP_REQUESTS_TOTAL") {
    Some(n) => n,
    None => "axum_http_requests_total",
};

/// Identifies the histogram/summary used for response body size.
///
/// Defaults to `axum_http_response_body_size`, but can be changed by setting the
/// `AXUM_HTTP_RESPONSE_BODY_SIZE` env variable at compile time. Only emitted when response
/// body size tracking is enabled.
///
/// If a metric prefix is configured at runtime (`PrometheusMetricLayerBuilder::with_prefix`),
/// the middleware emits `{prefix}_http_response_body_size` instead of this name.
pub const AXUM_HTTP_RESPONSE_BODY_SIZE: &str = match option_env!("AXUM_HTTP_RESPONSE_BODY_SIZE") {
    Some(n) => n,
    None => "axum_http_response_body_size",
};
194
// Runtime overrides for the metric names. In this file they are written only by `set_prefix`
// (each `OnceLock` can be set at most once). Wherever a metric name is resolved, a set value
// is used in preference to the matching compile-time `AXUM_HTTP_*` constant.
#[doc(hidden)]
pub static PREFIXED_HTTP_REQUESTS_TOTAL: OnceLock<String> = OnceLock::new();
#[doc(hidden)]
pub static PREFIXED_HTTP_REQUESTS_DURATION_SECONDS: OnceLock<String> = OnceLock::new();
#[doc(hidden)]
pub static PREFIXED_HTTP_REQUESTS_PENDING: OnceLock<String> = OnceLock::new();
#[doc(hidden)]
pub static PREFIXED_HTTP_RESPONSE_BODY_SIZE: OnceLock<String> = OnceLock::new();
203
204use std::borrow::Cow;
205use std::collections::HashMap;
206use std::marker::PhantomData;
207use std::sync::atomic::AtomicBool;
208use std::sync::{Arc, OnceLock};
209use std::time::Duration;
210use std::time::Instant;
211
212mod builder;
213pub mod lifecycle;
214pub mod utils;
215use axum::extract::MatchedPath;
216pub use builder::EndpointLabel;
217pub use builder::MetricLayerBuilder;
218#[cfg(feature = "prometheus")]
219pub use builder::PrometheusMetricLayerBuilder;
220use builder::{LayerOnly, Paired};
221use lifecycle::layer::LifeCycleLayer;
222use lifecycle::OnBodyChunk;
223use lifecycle::{service::LifeCycle, Callbacks};
224use metrics::{counter, gauge, histogram, Gauge};
225use tower::Layer;
226use tower_http::classify::{ClassifiedResponse, SharedClassifier, StatusInRangeAsFailures};
227
228#[cfg(feature = "prometheus")]
229use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
230
231pub use metrics;
232#[cfg(feature = "prometheus")]
233pub use metrics_exporter_prometheus;
234
235/// Use a prefix for the metrics instead of `axum`. This will use the following
236/// metric names:
237/// - `{prefix}_http_requests_total`
238/// - `{prefix}_http_requests_pending`
239/// - `{prefix}_http_requests_duration_seconds`
240///
241/// Note that this will take precedence over environment variables, and can only
242/// be called once. Attempts to call this a second time will panic.
243fn set_prefix(prefix: impl AsRef<str>) {
244 PREFIXED_HTTP_REQUESTS_TOTAL
245 .set(format!("{}_http_requests_total", prefix.as_ref()))
246 .expect("the prefix has already been set, and can only be set once.");
247 PREFIXED_HTTP_REQUESTS_DURATION_SECONDS
248 .set(format!(
249 "{}_http_requests_duration_seconds",
250 prefix.as_ref()
251 ))
252 .expect("the prefix has already been set, and can only be set once.");
253 PREFIXED_HTTP_REQUESTS_PENDING
254 .set(format!("{}_http_requests_pending", prefix.as_ref()))
255 .expect("the prefix has already been set, and can only be set once.");
256 PREFIXED_HTTP_RESPONSE_BODY_SIZE
257 .set(format!("{}_http_response_body_size", prefix.as_ref()))
258 .expect("the prefix has already been set, and can only be set once.");
259}
260
261/// A marker struct that implements the [`lifecycle::Callbacks`] trait.
262#[derive(Clone, Default)]
263pub struct Traffic<'a> {
264 filter_mode: FilterMode,
265 ignore_patterns: matchit::Router<()>,
266 allow_patterns: matchit::Router<()>,
267 group_patterns: HashMap<&'a str, matchit::Router<()>>,
268 endpoint_label: EndpointLabel,
269}
270
/// Chooses which pattern set `Traffic::ignores` consults.
#[derive(Default, Clone)]
enum FilterMode {
    /// Track everything except paths matching the ignore patterns (the default).
    #[default]
    Ignore,
    /// Track only the paths matching the allow patterns.
    AllowOnly,
}
277
278impl<'a> Traffic<'a> {
279 pub(crate) fn new() -> Self {
280 Traffic::default()
281 }
282
283 pub(crate) fn with_ignore_pattern(&mut self, ignore_pattern: &'a str) {
284 if !matches!(self.filter_mode, FilterMode::Ignore) {
285 self.filter_mode = FilterMode::Ignore;
286 self.allow_patterns = matchit::Router::new();
287 self.ignore_patterns = matchit::Router::new();
288 }
289 self.ignore_patterns
290 .insert(ignore_pattern, ())
291 .expect("good route specs");
292 }
293
294 pub(crate) fn with_allow_pattern(&mut self, allow_pattern: &'a str) {
295 if !matches!(self.filter_mode, FilterMode::AllowOnly) {
296 self.filter_mode = FilterMode::AllowOnly;
297 self.ignore_patterns = matchit::Router::new();
298 self.allow_patterns = matchit::Router::new();
299 }
300 self.allow_patterns
301 .insert(allow_pattern, ())
302 .expect("good route specs");
303 }
304
305 pub(crate) fn with_ignore_patterns(&mut self, ignore_patterns: &'a [&'a str]) {
306 for pattern in ignore_patterns {
307 self.with_ignore_pattern(pattern);
308 }
309 }
310
311 pub(crate) fn with_allow_patterns(&mut self, allow_patterns: &'a [&'a str]) {
312 for pattern in allow_patterns {
313 self.with_allow_pattern(pattern);
314 }
315 }
316
317 pub(crate) fn with_group_patterns_as(&mut self, group_pattern: &'a str, patterns: &'a [&str]) {
318 self.group_patterns
319 .entry(group_pattern)
320 .and_modify(|router| {
321 for pattern in patterns {
322 router.insert(*pattern, ()).expect("good route specs");
323 }
324 })
325 .or_insert_with(|| {
326 let mut inner_router = matchit::Router::new();
327 for pattern in patterns {
328 inner_router.insert(*pattern, ()).expect("good route specs");
329 }
330 inner_router
331 });
332 }
333
334 pub(crate) fn ignores(&self, path: &str) -> bool {
335 match self.filter_mode {
336 FilterMode::Ignore => self.ignore_patterns.at(path).is_ok(),
337 FilterMode::AllowOnly => !self.allow_patterns.at(path).is_ok(),
338 }
339 }
340
341 pub(crate) fn apply_group_pattern(&self, path: &'a str) -> &'a str {
342 self.group_patterns
343 .iter()
344 .find_map(|(&group, router)| router.at(path).ok().and(Some(group)))
345 .unwrap_or(path)
346 }
347
348 pub(crate) fn with_endpoint_label_type(&mut self, endpoint_label: EndpointLabel) {
349 self.endpoint_label = endpoint_label;
350 }
351}
352
/// Struct used for storing and calculating information about the current request.
///
/// Built in `Traffic::prepare` for every tracked request and passed (as callback data) to the
/// response and body-chunk callbacks.
#[derive(Debug, Clone)]
pub struct MetricsData {
    /// The `endpoint` label value, after the endpoint-label strategy and group patterns apply.
    pub endpoint: String,
    /// When `prepare` ran; the request duration is measured from here.
    pub start: Instant,
    /// The `method` label value.
    pub method: &'static str,
    /// Response body size in bytes seen so far, kept as `f64` for the histogram.
    pub body_size: f64,
    // FIXME: Unclear at the moment, maybe just a simple bool could suffice here?
    // Behind an `Arc`, so all clones of this request's data share the "exact size already
    // recorded" flag.
    pub(crate) exact_body_size_called: Arc<AtomicBool>,
}
363
364#[doc(hidden)]
365pub struct Pending(Gauge);
366
367impl Drop for Pending {
368 fn drop(&mut self) {
369 self.0.decrement(1);
370 }
371}
372
// Per-request callback state: `None` for requests filtered out in `prepare`, otherwise the
// request's metrics data together with the pending-gauge guard.
// The `Pending` struct is behind an Arc to make sure we only drop it once (since we're cloning this across the lifecycle).
type DefaultCallbackData = Option<(MetricsData, Arc<Pending>)>;
375
/// A marker struct that implements [`lifecycle::OnBodyChunk`], so it can be used to track response body sizes.
///
/// It holds no state itself; the running size lives in each request's [`MetricsData`].
/// Layers enable it by passing `Some(BodySizeRecorder)` as their body-chunk callback.
#[derive(Clone)]
pub struct BodySizeRecorder;
379
/// Feeds response body chunks into the response-body-size histogram.
///
/// Requests filtered out in `prepare` (callback data `None`) are skipped.
impl<B> OnBodyChunk<B> for BodySizeRecorder
where
    B: bytes::Buf,
{
    type Data = DefaultCallbackData;

    #[inline]
    fn call(&mut self, body: &B, body_size: Option<u64>, data: &mut Self::Data) {
        let Some((metrics_data, _pending_guard)) = data else {
            return;
        };
        // If the exact body size is known ahead of time, we'll just call this whole thing once.
        // `swap` returns the previous flag value, so only the first chunk records the size.
        if let Some(exact_size) = body_size {
            if !metrics_data
                .exact_body_size_called
                .swap(true, std::sync::atomic::Ordering::Relaxed)
            {
                // If the body size is enormous, we lose some precision. It shouldn't matter really.
                metrics_data.body_size = exact_size as f64;
                body_size_histogram(metrics_data);
            }
        } else {
            // Otherwise, sum all the chunks.
            // NOTE(review): the histogram is recorded after *every* chunk with the running
            // total, so a body streamed in N chunks yields N observations (partial sums).
            // Recording it once would need an end-of-body hook — confirm whether intended.
            metrics_data.body_size += body.remaining() as f64;
            body_size_histogram(metrics_data);
        }
    }
}
408
409impl<T, B> OnBodyChunk<B> for Option<T>
410where
411 T: OnBodyChunk<B>,
412 B: bytes::Buf,
413{
414 type Data = T::Data;
415
416 fn call(&mut self, body: &B, body_size: Option<u64>, data: &mut Self::Data) {
417 if let Some(this) = self {
418 T::call(this, body, body_size, data);
419 }
420 }
421}
422
423fn body_size_histogram(metrics_data: &MetricsData) {
424 let labels = &[
425 ("method", metrics_data.method.to_owned()),
426 ("endpoint", metrics_data.endpoint.clone()),
427 ];
428 let response_body_size = PREFIXED_HTTP_RESPONSE_BODY_SIZE
429 .get()
430 .map_or(AXUM_HTTP_RESPONSE_BODY_SIZE, |s| s.as_str());
431 metrics::histogram!(response_body_size, labels).record(metrics_data.body_size);
432}
433
/// Request-level metrics: the pending gauge, the request counter and the duration histogram.
impl<'a, FailureClass> Callbacks<FailureClass> for Traffic<'a> {
    type Data = DefaultCallbackData;

    /// Decides whether `request` is tracked and, if so, computes its labels and increments
    /// the pending-requests gauge.
    ///
    /// Returns `None` for filtered-out requests; `on_response` and `BodySizeRecorder` skip
    /// those.
    fn prepare<B>(&mut self, request: &http::Request<B>) -> Self::Data {
        let now = std::time::Instant::now();
        // Filtering looks at the raw URI path, before any endpoint-label rewriting.
        let exact_endpoint = request.uri().path();
        if self.ignores(exact_endpoint) {
            return None;
        }
        // `Cow` lets the label either borrow from the request or own the fallback's output.
        let endpoint = match self.endpoint_label {
            EndpointLabel::Exact => Cow::from(exact_endpoint),
            EndpointLabel::MatchedPath => Cow::from(
                request
                    .extensions()
                    .get::<MatchedPath>()
                    .map_or(exact_endpoint, MatchedPath::as_str),
            ),
            EndpointLabel::MatchedPathWithFallbackFn(fallback_fn) => {
                if let Some(mp) = request
                    .extensions()
                    .get::<MatchedPath>()
                    .map(MatchedPath::as_str)
                {
                    Cow::from(mp)
                } else {
                    Cow::from(fallback_fn(exact_endpoint))
                }
            }
        };
        // Group patterns are applied last, so a group name replaces whichever label was chosen.
        let endpoint = self.apply_group_pattern(&endpoint).to_owned();
        let method = utils::as_label(request.method());

        let pending = gauge!(
            utils::requests_pending_name(),
            &[
                ("method", method.to_owned()),
                ("endpoint", endpoint.clone()),
            ]
        );
        pending.increment(1);

        Some((
            MetricsData {
                endpoint,
                start: now,
                method,
                body_size: 0.0,
                exact_body_size_called: Arc::new(AtomicBool::new(false)),
            },
            // Decremented again when the last clone of this guard is dropped.
            Arc::new(Pending(pending)),
        ))
    }

    /// Increments the request counter and records the request duration, labelled by method,
    /// status code and endpoint.
    fn on_response<B>(
        &mut self,
        res: &http::Response<B>,
        _cls: ClassifiedResponse<FailureClass, ()>,
        data: &mut Self::Data,
    ) {
        if let Some((data, _pending_guard)) = data {
            // NOTE(review): measured up to this callback; presumably excludes the time spent
            // streaming the response body — confirm against the lifecycle module.
            let duration_seconds = data.start.elapsed().as_secs_f64();

            let labels = [
                ("method", data.method.to_string()),
                ("status", res.status().as_u16().to_string()),
                ("endpoint", data.endpoint.to_string()),
            ];

            // A runtime prefix (if set) overrides the compile-time metric names.
            let requests_total = PREFIXED_HTTP_REQUESTS_TOTAL
                .get()
                .map_or(AXUM_HTTP_REQUESTS_TOTAL, |s| s.as_str());
            counter!(requests_total, &labels).increment(1);

            let requests_duration = PREFIXED_HTTP_REQUESTS_DURATION_SECONDS
                .get()
                .map_or(AXUM_HTTP_REQUESTS_DURATION_SECONDS, |s| s.as_str());
            histogram!(requests_duration, &labels).record(duration_seconds);
        }
    }
}
514
/// The tower middleware layer for recording HTTP metrics.
///
/// Unlike [`GenericMetricLayer`], this struct __does not__ know about the metrics exporter, or the recorder. It will only emit
/// metrics via the `metrics` crate's macros. It's entirely up to the user to set the global metrics recorder/exporter before using this.
///
/// You may use this if `GenericMetricLayer`'s requirements are too strict for your use case.
#[derive(Clone)]
pub struct BaseMetricLayer<'a> {
    // Classifies 4xx/5xx responses as failures, filters and labels requests via `Traffic`,
    // and optionally tracks body sizes (`Some(BodySizeRecorder)`).
    pub(crate) inner_layer: LifeCycleLayer<
        SharedClassifier<StatusInRangeAsFailures>,
        Traffic<'a>,
        Option<BodySizeRecorder>,
    >,
}
529
530impl<'a> BaseMetricLayer<'a> {
531 /// Construct a new `BaseMetricLayer`.
532 ///
533 /// # Example
534 /// ```
535 /// use axum::{routing::get, Router};
536 /// use axum_prometheus::{AXUM_HTTP_REQUESTS_DURATION_SECONDS, utils::SECONDS_DURATION_BUCKETS, BaseMetricLayer};
537 /// use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};
538 /// use std::net::SocketAddr;
539 ///
540 /// #[tokio::main]
541 /// async fn main() {
542 /// // Initialize the recorder as you like.
543 /// let metric_handle = PrometheusBuilder::new()
544 /// .set_buckets_for_metric(
545 /// Matcher::Full(AXUM_HTTP_REQUESTS_DURATION_SECONDS.to_string()),
546 /// SECONDS_DURATION_BUCKETS,
547 /// )
548 /// .unwrap()
549 /// .install_recorder()
550 /// .unwrap();
551 ///
552 /// let app = Router::<()>::new()
553 /// .route("/fast", get(|| async {}))
554 /// .route(
555 /// "/slow",
556 /// get(|| async {
557 /// tokio::time::sleep(std::time::Duration::from_secs(1)).await;
558 /// }),
559 /// )
560 /// // Expose the metrics somehow to the outer world.
561 /// .route("/metrics", get(|| async move { metric_handle.render() }))
562 /// // Only need to add this layer at the end.
563 /// .layer(BaseMetricLayer::new());
564 ///
565 /// // Run the server as usual:
566 /// // let listener = tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 3000)))
567 /// // .await
568 /// // .unwrap();
569 /// // axum::serve(listener, app).await.unwrap()
570 /// }
571 /// ```
572 pub fn new() -> Self {
573 let make_classifier =
574 StatusInRangeAsFailures::new_for_client_and_server_errors().into_make_classifier();
575 let inner_layer = LifeCycleLayer::new(make_classifier, Traffic::new(), None);
576 Self { inner_layer }
577 }
578
579 /// Construct a new `BaseMetricLayer` with response body size tracking enabled.
580 pub fn with_response_body_size() -> Self {
581 let mut this = Self::new();
582 this.inner_layer.on_body_chunk(Some(BodySizeRecorder));
583 this
584 }
585}
586
587impl<'a> Default for BaseMetricLayer<'a> {
588 fn default() -> Self {
589 Self::new()
590 }
591}
592
593impl<'a, S> Layer<S> for BaseMetricLayer<'a> {
594 type Service = LifeCycle<
595 S,
596 SharedClassifier<StatusInRangeAsFailures>,
597 Traffic<'a>,
598 Option<BodySizeRecorder>,
599 >;
600
601 fn layer(&self, inner: S) -> Self::Service {
602 self.inner_layer.layer(inner)
603 }
604}
605
/// The tower middleware layer for recording http metrics with different exporters.
///
/// `M` is the [`MakeDefaultHandle`] implementation that creates the exporter, and `T` is the
/// handle type it produces. Neither is stored; they only appear in `PhantomData`.
pub struct GenericMetricLayer<'a, T, M> {
    // Same layer as `BaseMetricLayer`: 4xx/5xx classified as failures, `Traffic` callbacks,
    // optional body-size tracking.
    pub(crate) inner_layer: LifeCycleLayer<
        SharedClassifier<StatusInRangeAsFailures>,
        Traffic<'a>,
        Option<BodySizeRecorder>,
    >,
    _marker: PhantomData<(T, M)>,
}
615
616// We don't require that `T` nor `M` is `Clone`, since none of them is actually contained in this type.
617impl<'a, T, M> std::clone::Clone for GenericMetricLayer<'a, T, M> {
618 fn clone(&self) -> Self {
619 GenericMetricLayer {
620 inner_layer: self.inner_layer.clone(),
621 _marker: self._marker,
622 }
623 }
624}
625
626impl<'a, T, M> GenericMetricLayer<'a, T, M>
627where
628 M: MakeDefaultHandle<Out = T>,
629{
630 /// Create a new tower middleware that can be used to track metrics.
631 ///
632 /// By default, this __will not__ "install" the exporter which sets it as the
633 /// global recorder for all `metrics` calls.
634 /// If you're using Prometheus, here you can use [`metrics_exporter_prometheus::PrometheusBuilder`]
635 /// to build your own customized metrics exporter.
636 ///
637 /// This middleware is using the following constants for identifying different HTTP metrics:
638 ///
639 /// - [`AXUM_HTTP_REQUESTS_PENDING`]
640 /// - [`AXUM_HTTP_REQUESTS_TOTAL`]
641 /// - [`AXUM_HTTP_REQUESTS_DURATION_SECONDS`].
642 ///
643 /// In terms of setup, the most important one is [`AXUM_HTTP_REQUESTS_DURATION_SECONDS`], which is a histogram metric
644 /// used for request latency. You may set customized buckets tailored for your used case here.
645 ///
646 /// # Example
647 /// ```
648 /// use axum::{routing::get, Router};
649 /// use axum_prometheus::{AXUM_HTTP_REQUESTS_DURATION_SECONDS, utils::SECONDS_DURATION_BUCKETS, PrometheusMetricLayer};
650 /// use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};
651 /// use std::net::SocketAddr;
652 ///
653 /// #[tokio::main]
654 /// async fn main() {
655 /// let metric_layer = PrometheusMetricLayer::new();
656 /// // This is the default if you use `PrometheusMetricLayer::pair`.
657 /// let metric_handle = PrometheusBuilder::new()
658 /// .set_buckets_for_metric(
659 /// Matcher::Full(AXUM_HTTP_REQUESTS_DURATION_SECONDS.to_string()),
660 /// SECONDS_DURATION_BUCKETS,
661 /// )
662 /// .unwrap()
663 /// .install_recorder()
664 /// .unwrap();
665 ///
666 /// let app = Router::<()>::new()
667 /// .route("/fast", get(|| async {}))
668 /// .route(
669 /// "/slow",
670 /// get(|| async {
671 /// tokio::time::sleep(std::time::Duration::from_secs(1)).await;
672 /// }),
673 /// )
674 /// .route("/metrics", get(|| async move { metric_handle.render() }))
675 /// .layer(metric_layer);
676 ///
677 /// // Run the server as usual:
678 /// // let listener = tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 3000)))
679 /// // .await
680 /// // .unwrap();
681 /// // axum::serve(listener, app).await.unwrap()
682 /// }
683 /// ```
684 pub fn new() -> Self {
685 let make_classifier =
686 StatusInRangeAsFailures::new_for_client_and_server_errors().into_make_classifier();
687 let inner_layer = LifeCycleLayer::new(make_classifier, Traffic::new(), None);
688 Self {
689 inner_layer,
690 _marker: PhantomData,
691 }
692 }
693
694 pub(crate) fn from_builder(builder: MetricLayerBuilder<'a, T, M, LayerOnly>) -> Self {
695 let make_classifier =
696 StatusInRangeAsFailures::new_for_client_and_server_errors().into_make_classifier();
697 let inner_layer = if builder.enable_body_size {
698 LifeCycleLayer::new(make_classifier, builder.traffic, Some(BodySizeRecorder))
699 } else {
700 LifeCycleLayer::new(make_classifier, builder.traffic, None)
701 };
702 Self {
703 inner_layer,
704 _marker: PhantomData,
705 }
706 }
707
708 /// Enable tracking response body sizes.
709 pub fn enable_response_body_size(&mut self) {
710 self.inner_layer.on_body_chunk(Some(BodySizeRecorder));
711 }
712
713 /// Crate a new tower middleware and a default exporter from the provided value of the passed in argument.
714 ///
715 /// This function is useful when additional data needs to be injected into `MakeDefaultHandle::make_default_handle`.
716 ///
717 /// # Example
718 ///
719 /// ```rust,no_run
720 /// use axum_prometheus::{GenericMetricLayer, MakeDefaultHandle};
721 ///
722 /// struct Recorder { host: String }
723 ///
724 /// impl MakeDefaultHandle for Recorder {
725 /// type Out = ();
726 ///
727 /// fn make_default_handle(self) -> Self::Out {
728 /// // Perform the initialization. `self` is passed in by value.
729 /// todo!();
730 /// }
731 /// }
732 ///
733 /// fn main() {
734 /// let (metric_layer, metric_handle) = GenericMetricLayer::pair_from(
735 /// Recorder { host: "0.0.0.0".to_string() }
736 /// );
737 /// }
738 /// ```
739 pub fn pair_from(m: M) -> (Self, T) {
740 (Self::new(), M::make_default_handle(m))
741 }
742}
743
744impl<'a, T, M> GenericMetricLayer<'a, T, M>
745where
746 M: MakeDefaultHandle<Out = T> + Default,
747{
748 pub(crate) fn pair_from_builder(builder: MetricLayerBuilder<'a, T, M, Paired>) -> (Self, T) {
749 let make_classifier =
750 StatusInRangeAsFailures::new_for_client_and_server_errors().into_make_classifier();
751 let inner_layer = if builder.enable_body_size {
752 LifeCycleLayer::new(make_classifier, builder.traffic, Some(BodySizeRecorder))
753 } else {
754 LifeCycleLayer::new(make_classifier, builder.traffic, None)
755 };
756
757 (
758 Self {
759 inner_layer,
760 _marker: PhantomData,
761 },
762 builder
763 .metric_handle
764 .unwrap_or_else(|| M::make_default_handle(M::default())),
765 )
766 }
767
768 /// Crate a new tower middleware and a default global Prometheus exporter with sensible defaults.
769 ///
770 /// If used with a custom exporter that's different from Prometheus, the exporter struct
771 /// must implement `MakeDefaultHandle + Default`.
772 ///
773 /// # Example
774 /// ```
775 /// use axum::{routing::get, Router};
776 /// use axum_prometheus::PrometheusMetricLayer;
777 /// use std::net::SocketAddr;
778 ///
779 /// #[tokio::main]
780 /// async fn main() {
781 /// let (metric_layer, metric_handle) = PrometheusMetricLayer::pair();
782 ///
783 /// let app = Router::<()>::new()
784 /// .route("/fast", get(|| async {}))
785 /// .route(
786 /// "/slow",
787 /// get(|| async {
788 /// tokio::time::sleep(std::time::Duration::from_secs(1)).await;
789 /// }),
790 /// )
791 /// .route("/metrics", get(|| async move { metric_handle.render() }))
792 /// .layer(metric_layer);
793 ///
794 /// // Run the server as usual:
795 /// // let listener = tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 3000)))
796 /// // .await
797 /// // .unwrap();
798 /// // axum::serve(listener, app).await.unwrap()
799 /// }
800 /// ```
801 pub fn pair() -> (Self, T) {
802 (Self::new(), M::make_default_handle(M::default()))
803 }
804}
805
806impl<'a, T, M> Default for GenericMetricLayer<'a, T, M>
807where
808 M: MakeDefaultHandle<Out = T>,
809{
810 fn default() -> Self {
811 Self::new()
812 }
813}
814
815impl<'a, S, T, M> Layer<S> for GenericMetricLayer<'a, T, M> {
816 type Service = LifeCycle<
817 S,
818 SharedClassifier<StatusInRangeAsFailures>,
819 Traffic<'a>,
820 Option<BodySizeRecorder>,
821 >;
822
823 fn layer(&self, inner: S) -> Self::Service {
824 self.inner_layer.layer(inner)
825 }
826}
827
/// The trait that lets a metrics exporter be used with `GenericMetricLayer`.
///
/// `GenericMetricLayer::pair_from` calls it on the value you pass in. `GenericMetricLayer::pair`
/// calls it on `Self::default()`.
pub trait MakeDefaultHandle {
    /// The type of the metrics handle to return from [`MetricLayerBuilder`].
    type Out;

    /// The function that defines how to initialize a metric exporter by default.
    ///
    /// # Example
    ///
    /// ```rust,no_run
    /// use axum_prometheus::{MakeDefaultHandle, GenericMetricLayer};
    ///
    /// pub struct MyHandle(pub String);
    ///
    /// impl MakeDefaultHandle for MyHandle {
    ///     type Out = ();
    ///
    ///     fn make_default_handle(self) -> Self::Out {
    ///         // This is where you initialize and register everything you need.
    ///         // Notice that self is passed in by value.
    ///     }
    /// }
    /// ```
    /// and then, to use it:
    /// ```rust,ignore
    /// // Initialize the struct, then use `pair_from`.
    /// let my_handle = MyHandle(String::from("localhost"));
    /// let (layer, handle) = GenericMetricLayer::pair_from(my_handle);
    ///
    /// // Or optionally if your custom struct implements `Default` too, you may call `pair`.
    /// // That's going to use `MyHandle::default()`.
    /// let (layer, handle) = GenericMetricLayer::<'_, _, MyHandle>::pair();
    /// ```
    fn make_default_handle(self) -> Self::Out;
}
863
/// The default handle for the Prometheus exporter.
///
/// Wraps a `PrometheusHandle`. `Handle::default()` builds a Prometheus recorder and installs it
/// as the global `metrics` recorder.
#[cfg(feature = "prometheus")]
#[derive(Clone)]
pub struct Handle(pub PrometheusHandle);
868
/// Builds a Prometheus recorder, installs it as the global `metrics` recorder and returns a
/// handle to it.
///
/// The request-duration histogram (using the runtime-prefixed name, if a prefix was set) gets
/// `utils::SECONDS_DURATION_BUCKETS` as buckets. A background Tokio task runs the recorder's
/// upkeep every 5 seconds.
///
/// # Panics
///
/// - if called outside of a Tokio runtime (the upkeep task needs one);
/// - if a global `metrics` recorder is already installed (e.g. on a second call).
///
/// Both conditions are checked before the upkeep task is spawned, so a failed call does not
/// leave an orphaned background task running.
#[cfg(feature = "prometheus")]
impl Default for Handle {
    fn default() -> Self {
        // Fail fast, before anything global is touched, if there is no runtime for the
        // upkeep task.
        let runtime = tokio::runtime::Handle::try_current()
            .expect("the default Prometheus handle must be created within a Tokio runtime");

        let duration_metric = PREFIXED_HTTP_REQUESTS_DURATION_SECONDS
            .get()
            .map_or(AXUM_HTTP_REQUESTS_DURATION_SECONDS, |s| s.as_str())
            .to_string();
        let recorder = PrometheusBuilder::new()
            .set_buckets_for_metric(
                Matcher::Full(duration_metric),
                utils::SECONDS_DURATION_BUCKETS,
            )
            .expect("utils::SECONDS_DURATION_BUCKETS must be a valid, non-empty bucket list")
            .build_recorder();
        let handle = recorder.handle();

        // Install the recorder before spawning: if this panics, no upkeep task exists yet.
        metrics::set_global_recorder(recorder).expect("Failed to set global recorder");

        let upkeep_handle = handle.clone();
        runtime.spawn(async move {
            loop {
                tokio::time::sleep(Duration::from_secs(5)).await;
                upkeep_handle.run_upkeep();
            }
        });
        Self(handle)
    }
}
896
/// Unwraps the inner `PrometheusHandle`.
#[cfg(feature = "prometheus")]
impl MakeDefaultHandle for Handle {
    type Out = PrometheusHandle;

    fn make_default_handle(self) -> Self::Out {
        let Handle(prometheus_handle) = self;
        prometheus_handle
    }
}
905
#[cfg(feature = "prometheus")]
/// The tower middleware layer for recording http metrics with Prometheus.
///
/// A [`GenericMetricLayer`] whose default exporter is [`Handle`]. `PrometheusMetricLayer::pair()`
/// therefore returns a `PrometheusHandle`, whose `render()` output can be served on a
/// `/metrics` route.
pub type PrometheusMetricLayer<'a> = GenericMetricLayer<'a, PrometheusHandle, Handle>;