axum_prometheus/
lib.rs

//! A middleware to collect HTTP metrics for Axum applications.
2//!
3//! `axum-prometheus` relies on [`metrics.rs`](https://metrics.rs/) and its ecosystem to collect and export metrics - for instance for Prometheus, `metrics_exporter_prometheus` is used as a backend to interact with Prometheus.
4//!
5//! ## Metrics
6//!
//! By default three HTTP metrics are tracked:
8//! - `axum_http_requests_total` (labels: endpoint, method, status): the total number of HTTP requests handled (counter)
9//! - `axum_http_requests_duration_seconds` (labels: endpoint, method, status): the request duration for all HTTP requests handled (histogram)
10//! - `axum_http_requests_pending` (labels: endpoint, method): the number of currently in-flight requests (gauge)
11//!
//! This crate also allows tracking response body sizes as a histogram — see [`PrometheusMetricLayerBuilder::enable_response_body_size`].
13//!
14//! ### Renaming Metrics
15//!
16//! These metrics can be renamed by specifying environmental variables at compile time:
17//! - `AXUM_HTTP_REQUESTS_TOTAL`
18//! - `AXUM_HTTP_REQUESTS_DURATION_SECONDS`
19//! - `AXUM_HTTP_REQUESTS_PENDING`
20//! - `AXUM_HTTP_RESPONSE_BODY_SIZE` (if body size tracking is enabled)
21//!
22//! These environmental variables can be set in your `.cargo/config.toml` since Cargo 1.56:
23//! ```toml
24//! [env]
25//! AXUM_HTTP_REQUESTS_TOTAL = "my_app_requests_total"
26//! AXUM_HTTP_REQUESTS_DURATION_SECONDS = "my_app_requests_duration_seconds"
27//! AXUM_HTTP_REQUESTS_PENDING = "my_app_requests_pending"
28//! AXUM_HTTP_RESPONSE_BODY_SIZE = "my_app_response_body_size"
29//! ```
30//!
//! ...or optionally use the [`PrometheusMetricLayerBuilder::with_prefix`] function.
32//!
33//! ## Usage
34//!
35//! For more elaborate use-cases, see the builder-example that leverages [`PrometheusMetricLayerBuilder`].
36//!
37//! Add `axum-prometheus` to your `Cargo.toml`.
38//! ```not_rust
39//! [dependencies]
40//! axum-prometheus = "0.9.0"
41//! ```
42//!
43//! Then you instantiate the prometheus middleware:
44//! ```rust,no_run
45//! use std::{net::SocketAddr, time::Duration};
46//! use axum::{routing::get, Router};
47//! use axum_prometheus::PrometheusMetricLayer;
48//!
49//! #[tokio::main]
50//! async fn main() {
51//!     let (prometheus_layer, metric_handle) = PrometheusMetricLayer::pair();
52//!     let app = Router::new()
53//!         .route("/fast", get(|| async {}))
54//!         .route(
55//!             "/slow",
56//!             get(|| async {
57//!                 tokio::time::sleep(Duration::from_secs(1)).await;
58//!             }),
59//!         )
60//!         .route("/metrics", get(|| async move { metric_handle.render() }))
61//!         .layer(prometheus_layer);
62//!
63//!     let listener = tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 3000)))
64//!         .await
65//!         .unwrap();
66//!     axum::serve(listener, app).await.unwrap()
67//! }
68//! ```
69//!
70//! Note that the `/metrics` endpoint is not automatically exposed, so you need to add that as a route manually.
71//! Calling the `/metrics` endpoint will expose your metrics:
72//! ```not_rust
73//! axum_http_requests_total{method="GET",endpoint="/metrics",status="200"} 5
74//! axum_http_requests_pending{method="GET",endpoint="/metrics"} 1
75//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.005"} 4
76//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.01"} 4
77//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.025"} 4
78//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.05"} 4
79//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.1"} 4
80//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.25"} 4
81//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="0.5"} 4
82//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="1"} 4
83//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="2.5"} 4
84//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="5"} 4
85//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="10"} 4
86//! axum_http_requests_duration_seconds_bucket{method="GET",status="200",endpoint="/metrics",le="+Inf"} 4
87//! axum_http_requests_duration_seconds_sum{method="GET",status="200",endpoint="/metrics"} 0.001997171
88//! axum_http_requests_duration_seconds_count{method="GET",status="200",endpoint="/metrics"} 4
89//! ```
90//!
91//! ## Prometheus push gateway feature
92//! This crate currently has no higher level API for the `push-gateway` feature. If you plan to use it, enable the
93//! `push-gateway` feature in `axum-prometheus`, use `BaseMetricLayer`, and setup your recorder manually, similar to
94//! the `base-metric-layer-example`.
95//!
96//! ## Using a different exporter than Prometheus
97//!
98//! This crate may be used with other exporters than Prometheus. First, disable the default features:
99//!
100//! ```toml
101//! axum-prometheus = { version = "0.9.0", default-features = false }
102//! ```
103//!
104//! Then implement the `MakeDefaultHandle` for the provider you'd like to use. For `StatsD`:
105//!
106//! ```rust,ignore
107//! use metrics_exporter_statsd::StatsdBuilder;
108//! use axum_prometheus::{MakeDefaultHandle, GenericMetricLayer};
109//!
110//! // The custom StatsD exporter struct. It may take fields as well.
111//! struct Recorder { port: u16 }
112//!
113//! // In order to use this with `axum_prometheus`, we must implement `MakeDefaultHandle`.
114//! impl MakeDefaultHandle for Recorder {
115//!     // We don't need to return anything meaningful from here (unlike PrometheusHandle)
116//!     // Let's just return an empty tuple.
117//!     type Out = ();
118//!
119//!     fn make_default_handle(self) -> Self::Out {
120//!         // The regular setup for StatsD. Notice that `self` is passed in by value.
121//!         let recorder = StatsdBuilder::from("127.0.0.1", self.port)
122//!             .with_queue_size(5000)
123//!             .with_buffer_size(1024)
124//!             .build(Some("prefix"))
125//!             .expect("Could not create StatsdRecorder");
126//!
127//!         metrics::set_boxed_recorder(Box::new(recorder)).unwrap();
128//!     }
129//! }
130//!
131//! fn main() {
132//!     // Use `GenericMetricLayer` instead of `PrometheusMetricLayer`.
133//!     // Generally `GenericMetricLayer::pair_from` is what you're looking for.
134//!     // It lets you pass in a concrete initialized `Recorder`.
135//!     let (metric_layer, _handle) = GenericMetricLayer::pair_from(Recorder { port: 8125 });
136//! }
137//! ```
138//!
139//! It's also possible to use `GenericMetricLayer::pair`, however it's only callable if the recorder struct implements `Default` as well.
140//!
141//! ```rust,ignore
142//! use metrics_exporter_statsd::StatsdBuilder;
143//! use axum_prometheus::{MakeDefaultHandle, GenericMetricLayer};
144//!
145//! #[derive(Default)]
146//! struct Recorder { port: u16 }
147//!
148//! impl MakeDefaultHandle for Recorder {
149//!    /* .. same as before .. */
150//! }
151//!
152//! fn main() {
//!     // This will internally call `Recorder::make_default_handle(Recorder::default())`.
154//!     let (metric_layer, _handle) = GenericMetricLayer::<_, Recorder>::pair();
155//! }
156//! ```
157//!
158//! This crate is similar to (and takes inspiration from) [`actix-web-prom`](https://github.com/nlopes/actix-web-prom) and [`rocket_prometheus`](https://github.com/sd2k/rocket_prometheus),
159//! and also builds on top of davidpdrsn's [earlier work with LifeCycleHooks](https://github.com/tower-rs/tower-http/pull/96) in `tower-http`.
160//!
161//! [`PrometheusMetricLayerBuilder`]: crate::PrometheusMetricLayerBuilder
162
163#![allow(clippy::module_name_repetitions, clippy::unreadable_literal)]
164
/// Name of the gauge that tracks in-flight (pending) requests.
///
/// Resolved at compile time: the value of the `AXUM_HTTP_REQUESTS_PENDING` environment
/// variable if it is set while building, otherwise `axum_http_requests_pending`.
pub const AXUM_HTTP_REQUESTS_PENDING: &str =
    if let Some(name) = option_env!("AXUM_HTTP_REQUESTS_PENDING") {
        name
    } else {
        "axum_http_requests_pending"
    };
172
/// Name of the histogram/summary that tracks request latency.
///
/// Resolved at compile time: the value of the `AXUM_HTTP_REQUESTS_DURATION_SECONDS`
/// environment variable if it is set while building, otherwise
/// `axum_http_requests_duration_seconds`.
pub const AXUM_HTTP_REQUESTS_DURATION_SECONDS: &str =
    if let Some(name) = option_env!("AXUM_HTTP_REQUESTS_DURATION_SECONDS") {
        name
    } else {
        "axum_http_requests_duration_seconds"
    };
180
/// Name of the counter that tracks the total number of handled requests.
///
/// Resolved at compile time: the value of the `AXUM_HTTP_REQUESTS_TOTAL` environment
/// variable if it is set while building, otherwise `axum_http_requests_total`.
pub const AXUM_HTTP_REQUESTS_TOTAL: &str =
    if let Some(name) = option_env!("AXUM_HTTP_REQUESTS_TOTAL") {
        name
    } else {
        "axum_http_requests_total"
    };
187
/// Name of the histogram/summary that tracks response body sizes.
///
/// Resolved at compile time: the value of the `AXUM_HTTP_RESPONSE_BODY_SIZE` environment
/// variable if it is set while building, otherwise `axum_http_response_body_size`.
pub const AXUM_HTTP_RESPONSE_BODY_SIZE: &str =
    if let Some(name) = option_env!("AXUM_HTTP_RESPONSE_BODY_SIZE") {
        name
    } else {
        "axum_http_response_body_size"
    };
194
// The four cells below hold the metric names produced by `set_prefix`. Each cell can be
// written once; while a cell is empty, readers in this file fall back to the matching
// `AXUM_HTTP_*` compile-time constant.

/// Prefixed name of the requests-total counter (`{prefix}_http_requests_total`).
#[doc(hidden)]
pub static PREFIXED_HTTP_REQUESTS_TOTAL: OnceLock<String> = OnceLock::new();
/// Prefixed name of the request-duration histogram (`{prefix}_http_requests_duration_seconds`).
#[doc(hidden)]
pub static PREFIXED_HTTP_REQUESTS_DURATION_SECONDS: OnceLock<String> = OnceLock::new();
/// Prefixed name of the pending-requests gauge (`{prefix}_http_requests_pending`).
// NOTE(review): not read in this file; presumably consumed by `utils::requests_pending_name` — confirm.
#[doc(hidden)]
pub static PREFIXED_HTTP_REQUESTS_PENDING: OnceLock<String> = OnceLock::new();
/// Prefixed name of the response-body-size histogram (`{prefix}_http_response_body_size`).
#[doc(hidden)]
pub static PREFIXED_HTTP_RESPONSE_BODY_SIZE: OnceLock<String> = OnceLock::new();
203
204use std::borrow::Cow;
205use std::collections::HashMap;
206use std::marker::PhantomData;
207use std::sync::atomic::AtomicBool;
208use std::sync::{Arc, OnceLock};
209use std::time::Duration;
210use std::time::Instant;
211
212mod builder;
213pub mod lifecycle;
214pub mod utils;
215use axum::extract::MatchedPath;
216pub use builder::EndpointLabel;
217pub use builder::MetricLayerBuilder;
218#[cfg(feature = "prometheus")]
219pub use builder::PrometheusMetricLayerBuilder;
220use builder::{LayerOnly, Paired};
221use lifecycle::layer::LifeCycleLayer;
222use lifecycle::OnBodyChunk;
223use lifecycle::{service::LifeCycle, Callbacks};
224use metrics::{counter, gauge, histogram, Gauge};
225use tower::Layer;
226use tower_http::classify::{ClassifiedResponse, SharedClassifier, StatusInRangeAsFailures};
227
228#[cfg(feature = "prometheus")]
229use metrics_exporter_prometheus::{Matcher, PrometheusBuilder, PrometheusHandle};
230
231pub use metrics;
232#[cfg(feature = "prometheus")]
233pub use metrics_exporter_prometheus;
234
235/// Use a prefix for the metrics instead of `axum`. This will use the following
236/// metric names:
237///  - `{prefix}_http_requests_total`
238///  - `{prefix}_http_requests_pending`
239///  - `{prefix}_http_requests_duration_seconds`
240///
241/// Note that this will take precedence over environment variables, and can only
242/// be called once. Attempts to call this a second time will panic.
243fn set_prefix(prefix: impl AsRef<str>) {
244    PREFIXED_HTTP_REQUESTS_TOTAL
245        .set(format!("{}_http_requests_total", prefix.as_ref()))
246        .expect("the prefix has already been set, and can only be set once.");
247    PREFIXED_HTTP_REQUESTS_DURATION_SECONDS
248        .set(format!(
249            "{}_http_requests_duration_seconds",
250            prefix.as_ref()
251        ))
252        .expect("the prefix has already been set, and can only be set once.");
253    PREFIXED_HTTP_REQUESTS_PENDING
254        .set(format!("{}_http_requests_pending", prefix.as_ref()))
255        .expect("the prefix has already been set, and can only be set once.");
256    PREFIXED_HTTP_RESPONSE_BODY_SIZE
257        .set(format!("{}_http_response_body_size", prefix.as_ref()))
258        .expect("the prefix has already been set, and can only be set once.");
259}
260
261/// A marker struct that implements the [`lifecycle::Callbacks`] trait.
262#[derive(Clone, Default)]
263pub struct Traffic<'a> {
264    filter_mode: FilterMode,
265    ignore_patterns: matchit::Router<()>,
266    allow_patterns: matchit::Router<()>,
267    group_patterns: HashMap<&'a str, matchit::Router<()>>,
268    endpoint_label: EndpointLabel,
269}
270
271#[derive(Clone, Default)]
272enum FilterMode {
273    #[default]
274    Ignore,
275    AllowOnly,
276}
277
278impl<'a> Traffic<'a> {
279    pub(crate) fn new() -> Self {
280        Traffic::default()
281    }
282
283    pub(crate) fn with_ignore_pattern(&mut self, ignore_pattern: &'a str) {
284        if !matches!(self.filter_mode, FilterMode::Ignore) {
285            self.filter_mode = FilterMode::Ignore;
286            self.allow_patterns = matchit::Router::new();
287            self.ignore_patterns = matchit::Router::new();
288        }
289        self.ignore_patterns
290            .insert(ignore_pattern, ())
291            .expect("good route specs");
292    }
293
294    pub(crate) fn with_allow_pattern(&mut self, allow_pattern: &'a str) {
295        if !matches!(self.filter_mode, FilterMode::AllowOnly) {
296            self.filter_mode = FilterMode::AllowOnly;
297            self.ignore_patterns = matchit::Router::new();
298            self.allow_patterns = matchit::Router::new();
299        }
300        self.allow_patterns
301            .insert(allow_pattern, ())
302            .expect("good route specs");
303    }
304
305    pub(crate) fn with_ignore_patterns(&mut self, ignore_patterns: &'a [&'a str]) {
306        for pattern in ignore_patterns {
307            self.with_ignore_pattern(pattern);
308        }
309    }
310
311    pub(crate) fn with_allow_patterns(&mut self, allow_patterns: &'a [&'a str]) {
312        for pattern in allow_patterns {
313            self.with_allow_pattern(pattern);
314        }
315    }
316
317    pub(crate) fn with_group_patterns_as(&mut self, group_pattern: &'a str, patterns: &'a [&str]) {
318        self.group_patterns
319            .entry(group_pattern)
320            .and_modify(|router| {
321                for pattern in patterns {
322                    router.insert(*pattern, ()).expect("good route specs");
323                }
324            })
325            .or_insert_with(|| {
326                let mut inner_router = matchit::Router::new();
327                for pattern in patterns {
328                    inner_router.insert(*pattern, ()).expect("good route specs");
329                }
330                inner_router
331            });
332    }
333
334    pub(crate) fn ignores(&self, path: &str) -> bool {
335        match self.filter_mode {
336            FilterMode::Ignore => self.ignore_patterns.at(path).is_ok(),
337            FilterMode::AllowOnly => !self.allow_patterns.at(path).is_ok(),
338        }
339    }
340
341    pub(crate) fn apply_group_pattern(&self, path: &'a str) -> &'a str {
342        self.group_patterns
343            .iter()
344            .find_map(|(&group, router)| router.at(path).ok().and(Some(group)))
345            .unwrap_or(path)
346    }
347
348    pub(crate) fn with_endpoint_label_type(&mut self, endpoint_label: EndpointLabel) {
349        self.endpoint_label = endpoint_label;
350    }
351}
352
/// Struct used for storing and calculating information about the current request.
///
/// One value is created per tracked request and handed to the lifecycle callbacks
/// as part of their callback data.
#[derive(Debug, Clone)]
pub struct MetricsData {
    /// Value used for the `endpoint` metric label.
    pub endpoint: String,
    /// Instant at which tracking of the request started; request duration is measured from here.
    pub start: Instant,
    /// Value used for the `method` metric label.
    pub method: &'static str,
    /// Response body size recorded so far: either the exact size, or the running sum of chunk sizes.
    pub body_size: f64,
    // FIXME: Unclear at the moment, maybe just a simple bool could suffice here?
    // Set once the exact body size has been recorded, so that it is recorded only once.
    pub(crate) exact_body_size_called: Arc<AtomicBool>,
}
363
364#[doc(hidden)]
365pub struct Pending(Gauge);
366
367impl Drop for Pending {
368    fn drop(&mut self) {
369        self.0.decrement(1);
370    }
371}
372
373// The `Pending` struct is behind an Arc to make sure we only drop it once (since we're cloning this across the lifecycle).
374type DefaultCallbackData = Option<(MetricsData, Arc<Pending>)>;
375
376/// A marker struct that implements [`lifecycle::OnBodyChunk`], so it can be used to track response body sizes.
377#[derive(Clone)]
378pub struct BodySizeRecorder;
379
380impl<B> OnBodyChunk<B> for BodySizeRecorder
381where
382    B: bytes::Buf,
383{
384    type Data = DefaultCallbackData;
385
386    #[inline]
387    fn call(&mut self, body: &B, body_size: Option<u64>, data: &mut Self::Data) {
388        let Some((metrics_data, _pending_guard)) = data else {
389            return;
390        };
391        // If the exact body size is known ahead of time, we'll just call this whole thing once.
392        if let Some(exact_size) = body_size {
393            if !metrics_data
394                .exact_body_size_called
395                .swap(true, std::sync::atomic::Ordering::Relaxed)
396            {
397                // If the body size is enormous, we lose some precision. It shouldn't matter really.
398                metrics_data.body_size = exact_size as f64;
399                body_size_histogram(metrics_data);
400            }
401        } else {
402            // Otherwise, sum all the chunks.
403            metrics_data.body_size += body.remaining() as f64;
404            body_size_histogram(metrics_data);
405        }
406    }
407}
408
409impl<T, B> OnBodyChunk<B> for Option<T>
410where
411    T: OnBodyChunk<B>,
412    B: bytes::Buf,
413{
414    type Data = T::Data;
415
416    fn call(&mut self, body: &B, body_size: Option<u64>, data: &mut Self::Data) {
417        if let Some(this) = self {
418            T::call(this, body, body_size, data);
419        }
420    }
421}
422
423fn body_size_histogram(metrics_data: &MetricsData) {
424    let labels = &[
425        ("method", metrics_data.method.to_owned()),
426        ("endpoint", metrics_data.endpoint.clone()),
427    ];
428    let response_body_size = PREFIXED_HTTP_RESPONSE_BODY_SIZE
429        .get()
430        .map_or(AXUM_HTTP_RESPONSE_BODY_SIZE, |s| s.as_str());
431    metrics::histogram!(response_body_size, labels).record(metrics_data.body_size);
432}
433
434impl<'a, FailureClass> Callbacks<FailureClass> for Traffic<'a> {
435    type Data = DefaultCallbackData;
436
437    fn prepare<B>(&mut self, request: &http::Request<B>) -> Self::Data {
438        let now = std::time::Instant::now();
439        let exact_endpoint = request.uri().path();
440        if self.ignores(exact_endpoint) {
441            return None;
442        }
443        let endpoint = match self.endpoint_label {
444            EndpointLabel::Exact => Cow::from(exact_endpoint),
445            EndpointLabel::MatchedPath => Cow::from(
446                request
447                    .extensions()
448                    .get::<MatchedPath>()
449                    .map_or(exact_endpoint, MatchedPath::as_str),
450            ),
451            EndpointLabel::MatchedPathWithFallbackFn(fallback_fn) => {
452                if let Some(mp) = request
453                    .extensions()
454                    .get::<MatchedPath>()
455                    .map(MatchedPath::as_str)
456                {
457                    Cow::from(mp)
458                } else {
459                    Cow::from(fallback_fn(exact_endpoint))
460                }
461            }
462        };
463        let endpoint = self.apply_group_pattern(&endpoint).to_owned();
464        let method = utils::as_label(request.method());
465
466        let pending = gauge!(
467            utils::requests_pending_name(),
468            &[
469                ("method", method.to_owned()),
470                ("endpoint", endpoint.clone()),
471            ]
472        );
473        pending.increment(1);
474
475        Some((
476            MetricsData {
477                endpoint,
478                start: now,
479                method,
480                body_size: 0.0,
481                exact_body_size_called: Arc::new(AtomicBool::new(false)),
482            },
483            Arc::new(Pending(pending)),
484        ))
485    }
486
487    fn on_response<B>(
488        &mut self,
489        res: &http::Response<B>,
490        _cls: ClassifiedResponse<FailureClass, ()>,
491        data: &mut Self::Data,
492    ) {
493        if let Some((data, _pending_guard)) = data {
494            let duration_seconds = data.start.elapsed().as_secs_f64();
495
496            let labels = [
497                ("method", data.method.to_string()),
498                ("status", res.status().as_u16().to_string()),
499                ("endpoint", data.endpoint.to_string()),
500            ];
501
502            let requests_total = PREFIXED_HTTP_REQUESTS_TOTAL
503                .get()
504                .map_or(AXUM_HTTP_REQUESTS_TOTAL, |s| s.as_str());
505            counter!(requests_total, &labels).increment(1);
506
507            let requests_duration = PREFIXED_HTTP_REQUESTS_DURATION_SECONDS
508                .get()
509                .map_or(AXUM_HTTP_REQUESTS_DURATION_SECONDS, |s| s.as_str());
510            histogram!(requests_duration, &labels).record(duration_seconds);
511        }
512    }
513}
514
515/// The tower middleware layer for recording HTTP metrics.
516///
517/// Unlike [`GenericMetricLayer`], this struct __does not__ know about the metrics exporter, or the recorder. It will only emit
518/// metrics via the `metrics` crate's macros. It's entirely up to the user to set the global metrics recorder/exporter before using this.
519///
520/// You may use this if `GenericMetricLayer`'s requirements are too strict for your use case.
521#[derive(Clone)]
522pub struct BaseMetricLayer<'a> {
523    pub(crate) inner_layer: LifeCycleLayer<
524        SharedClassifier<StatusInRangeAsFailures>,
525        Traffic<'a>,
526        Option<BodySizeRecorder>,
527    >,
528}
529
530impl<'a> BaseMetricLayer<'a> {
531    /// Construct a new `BaseMetricLayer`.
532    ///
533    /// # Example
534    /// ```
535    /// use axum::{routing::get, Router};
536    /// use axum_prometheus::{AXUM_HTTP_REQUESTS_DURATION_SECONDS, utils::SECONDS_DURATION_BUCKETS, BaseMetricLayer};
537    /// use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};
538    /// use std::net::SocketAddr;
539    ///
540    /// #[tokio::main]
541    /// async fn main() {
542    ///    // Initialize the recorder as you like.
543    ///     let metric_handle = PrometheusBuilder::new()
544    ///        .set_buckets_for_metric(
545    ///            Matcher::Full(AXUM_HTTP_REQUESTS_DURATION_SECONDS.to_string()),
546    ///            SECONDS_DURATION_BUCKETS,
547    ///        )
548    ///        .unwrap()
549    ///        .install_recorder()
550    ///        .unwrap();
551    ///
552    ///     let app = Router::<()>::new()
553    ///       .route("/fast", get(|| async {}))
554    ///       .route(
555    ///           "/slow",
556    ///           get(|| async {
557    ///               tokio::time::sleep(std::time::Duration::from_secs(1)).await;
558    ///           }),
559    ///       )
560    ///       // Expose the metrics somehow to the outer world.
561    ///       .route("/metrics", get(|| async move { metric_handle.render() }))
562    ///       // Only need to add this layer at the end.
563    ///       .layer(BaseMetricLayer::new());
564    ///
565    ///    // Run the server as usual:
566    ///    // let listener = tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 3000)))
567    ///    //     .await
568    ///    //     .unwrap();
569    ///    // axum::serve(listener, app).await.unwrap()
570    /// }
571    /// ```
572    pub fn new() -> Self {
573        let make_classifier =
574            StatusInRangeAsFailures::new_for_client_and_server_errors().into_make_classifier();
575        let inner_layer = LifeCycleLayer::new(make_classifier, Traffic::new(), None);
576        Self { inner_layer }
577    }
578
579    /// Construct a new `BaseMetricLayer` with response body size tracking enabled.
580    pub fn with_response_body_size() -> Self {
581        let mut this = Self::new();
582        this.inner_layer.on_body_chunk(Some(BodySizeRecorder));
583        this
584    }
585}
586
587impl<'a> Default for BaseMetricLayer<'a> {
588    fn default() -> Self {
589        Self::new()
590    }
591}
592
593impl<'a, S> Layer<S> for BaseMetricLayer<'a> {
594    type Service = LifeCycle<
595        S,
596        SharedClassifier<StatusInRangeAsFailures>,
597        Traffic<'a>,
598        Option<BodySizeRecorder>,
599    >;
600
601    fn layer(&self, inner: S) -> Self::Service {
602        self.inner_layer.layer(inner)
603    }
604}
605
606/// The tower middleware layer for recording http metrics with different exporters.
607pub struct GenericMetricLayer<'a, T, M> {
608    pub(crate) inner_layer: LifeCycleLayer<
609        SharedClassifier<StatusInRangeAsFailures>,
610        Traffic<'a>,
611        Option<BodySizeRecorder>,
612    >,
613    _marker: PhantomData<(T, M)>,
614}
615
616// We don't require that `T` nor `M` is `Clone`, since none of them is actually contained in this type.
617impl<'a, T, M> std::clone::Clone for GenericMetricLayer<'a, T, M> {
618    fn clone(&self) -> Self {
619        GenericMetricLayer {
620            inner_layer: self.inner_layer.clone(),
621            _marker: self._marker,
622        }
623    }
624}
625
626impl<'a, T, M> GenericMetricLayer<'a, T, M>
627where
628    M: MakeDefaultHandle<Out = T>,
629{
630    /// Create a new tower middleware that can be used to track metrics.
631    ///
632    /// By default, this __will not__ "install" the exporter which sets it as the
633    /// global recorder for all `metrics` calls.
634    /// If you're using Prometheus, here you can use [`metrics_exporter_prometheus::PrometheusBuilder`]
635    /// to build your own customized metrics exporter.
636    ///
637    /// This middleware is using the following constants for identifying different HTTP metrics:
638    ///
639    /// - [`AXUM_HTTP_REQUESTS_PENDING`]
640    /// - [`AXUM_HTTP_REQUESTS_TOTAL`]
641    /// - [`AXUM_HTTP_REQUESTS_DURATION_SECONDS`].
642    ///
643    /// In terms of setup, the most important one is [`AXUM_HTTP_REQUESTS_DURATION_SECONDS`], which is a histogram metric
644    /// used for request latency. You may set customized buckets tailored for your used case here.
645    ///
646    /// # Example
647    /// ```
648    /// use axum::{routing::get, Router};
649    /// use axum_prometheus::{AXUM_HTTP_REQUESTS_DURATION_SECONDS, utils::SECONDS_DURATION_BUCKETS, PrometheusMetricLayer};
650    /// use metrics_exporter_prometheus::{Matcher, PrometheusBuilder};
651    /// use std::net::SocketAddr;
652    ///
653    /// #[tokio::main]
654    /// async fn main() {
655    ///     let metric_layer = PrometheusMetricLayer::new();
656    ///     // This is the default if you use `PrometheusMetricLayer::pair`.
657    ///     let metric_handle = PrometheusBuilder::new()
658    ///        .set_buckets_for_metric(
659    ///            Matcher::Full(AXUM_HTTP_REQUESTS_DURATION_SECONDS.to_string()),
660    ///            SECONDS_DURATION_BUCKETS,
661    ///        )
662    ///        .unwrap()
663    ///        .install_recorder()
664    ///        .unwrap();
665    ///
666    ///     let app = Router::<()>::new()
667    ///       .route("/fast", get(|| async {}))
668    ///       .route(
669    ///           "/slow",
670    ///           get(|| async {
671    ///               tokio::time::sleep(std::time::Duration::from_secs(1)).await;
672    ///           }),
673    ///       )
674    ///       .route("/metrics", get(|| async move { metric_handle.render() }))
675    ///       .layer(metric_layer);
676    ///
677    ///    // Run the server as usual:
678    ///    // let listener = tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 3000)))
679    ///    //     .await
680    ///    //     .unwrap();
681    ///    // axum::serve(listener, app).await.unwrap()
682    /// }
683    /// ```
684    pub fn new() -> Self {
685        let make_classifier =
686            StatusInRangeAsFailures::new_for_client_and_server_errors().into_make_classifier();
687        let inner_layer = LifeCycleLayer::new(make_classifier, Traffic::new(), None);
688        Self {
689            inner_layer,
690            _marker: PhantomData,
691        }
692    }
693
694    pub(crate) fn from_builder(builder: MetricLayerBuilder<'a, T, M, LayerOnly>) -> Self {
695        let make_classifier =
696            StatusInRangeAsFailures::new_for_client_and_server_errors().into_make_classifier();
697        let inner_layer = if builder.enable_body_size {
698            LifeCycleLayer::new(make_classifier, builder.traffic, Some(BodySizeRecorder))
699        } else {
700            LifeCycleLayer::new(make_classifier, builder.traffic, None)
701        };
702        Self {
703            inner_layer,
704            _marker: PhantomData,
705        }
706    }
707
708    /// Enable tracking response body sizes.
709    pub fn enable_response_body_size(&mut self) {
710        self.inner_layer.on_body_chunk(Some(BodySizeRecorder));
711    }
712
713    /// Crate a new tower middleware and a default exporter from the provided value of the passed in argument.
714    ///
715    /// This function is useful when additional data needs to be injected into `MakeDefaultHandle::make_default_handle`.
716    ///
717    /// # Example
718    ///
719    /// ```rust,no_run
720    /// use axum_prometheus::{GenericMetricLayer, MakeDefaultHandle};
721    ///
722    /// struct Recorder { host: String }
723    ///
724    /// impl MakeDefaultHandle for Recorder {
725    ///     type Out = ();
726    ///
727    ///     fn make_default_handle(self) -> Self::Out {
728    ///         // Perform the initialization. `self` is passed in by value.
729    ///         todo!();
730    ///     }
731    /// }
732    ///
733    /// fn main() {
734    ///     let (metric_layer, metric_handle) = GenericMetricLayer::pair_from(
735    ///         Recorder { host: "0.0.0.0".to_string() }
736    ///     );
737    /// }
738    /// ```
739    pub fn pair_from(m: M) -> (Self, T) {
740        (Self::new(), M::make_default_handle(m))
741    }
742}
743
744impl<'a, T, M> GenericMetricLayer<'a, T, M>
745where
746    M: MakeDefaultHandle<Out = T> + Default,
747{
748    pub(crate) fn pair_from_builder(builder: MetricLayerBuilder<'a, T, M, Paired>) -> (Self, T) {
749        let make_classifier =
750            StatusInRangeAsFailures::new_for_client_and_server_errors().into_make_classifier();
751        let inner_layer = if builder.enable_body_size {
752            LifeCycleLayer::new(make_classifier, builder.traffic, Some(BodySizeRecorder))
753        } else {
754            LifeCycleLayer::new(make_classifier, builder.traffic, None)
755        };
756
757        (
758            Self {
759                inner_layer,
760                _marker: PhantomData,
761            },
762            builder
763                .metric_handle
764                .unwrap_or_else(|| M::make_default_handle(M::default())),
765        )
766    }
767
768    /// Crate a new tower middleware and a default global Prometheus exporter with sensible defaults.
769    ///
770    /// If used with a custom exporter that's different from Prometheus, the exporter struct
771    /// must implement `MakeDefaultHandle + Default`.
772    ///
773    /// # Example
774    /// ```
775    /// use axum::{routing::get, Router};
776    /// use axum_prometheus::PrometheusMetricLayer;
777    /// use std::net::SocketAddr;
778    ///
779    /// #[tokio::main]
780    /// async fn main() {
781    ///     let (metric_layer, metric_handle) = PrometheusMetricLayer::pair();
782    ///
783    ///     let app = Router::<()>::new()
784    ///       .route("/fast", get(|| async {}))
785    ///       .route(
786    ///           "/slow",
787    ///           get(|| async {
788    ///               tokio::time::sleep(std::time::Duration::from_secs(1)).await;
789    ///           }),
790    ///       )
791    ///       .route("/metrics", get(|| async move { metric_handle.render() }))
792    ///       .layer(metric_layer);
793    ///
794    ///    // Run the server as usual:
795    ///    // let listener = tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 3000)))
796    ///    //     .await
797    ///    //     .unwrap();
798    ///    // axum::serve(listener, app).await.unwrap()
799    /// }
800    /// ```
801    pub fn pair() -> (Self, T) {
802        (Self::new(), M::make_default_handle(M::default()))
803    }
804}
805
806impl<'a, T, M> Default for GenericMetricLayer<'a, T, M>
807where
808    M: MakeDefaultHandle<Out = T>,
809{
810    fn default() -> Self {
811        Self::new()
812    }
813}
814
815impl<'a, S, T, M> Layer<S> for GenericMetricLayer<'a, T, M> {
816    type Service = LifeCycle<
817        S,
818        SharedClassifier<StatusInRangeAsFailures>,
819        Traffic<'a>,
820        Option<BodySizeRecorder>,
821    >;
822
823    fn layer(&self, inner: S) -> Self::Service {
824        self.inner_layer.layer(inner)
825    }
826}
827
/// The trait that allows a metrics exporter to be used with [`GenericMetricLayer`].
///
/// [`GenericMetricLayer::pair_from`] and [`GenericMetricLayer::pair`] call
/// [`MakeDefaultHandle::make_default_handle`] on the exporter value to obtain the
/// handle they return alongside the layer.
pub trait MakeDefaultHandle {
    /// The type of the metrics handle to return from [`MetricLayerBuilder`].
    type Out;

    /// Defines how a metric exporter is initialized by default.
    ///
    /// The exporter value is taken by value (`self`) and the resulting handle is returned.
    ///
    /// # Example
    ///
    /// ```rust, no_run
    /// use axum_prometheus::{MakeDefaultHandle, GenericMetricLayer};
    ///
    /// pub struct MyHandle(pub String);
    ///
    /// impl MakeDefaultHandle for MyHandle {
    ///     type Out = ();
    ///
    ///     fn make_default_handle(self) -> Self::Out {
    ///        // This is where you initialize and register everything you need.
    ///        // Notice that self is passed in by value.
    ///     }
    /// }
    /// ```
    /// and then, to use it:
    /// ```rust,ignore
    /// // Initialize the struct, then use `pair_from`.
    /// let my_handle = MyHandle(String::from("localhost"));
    /// let (layer, handle) =  GenericMetricLayer::pair_from(my_handle);
    ///
    /// // Or optionally if your custom struct implements `Default` too, you may call `pair`.
    /// // That's going to use `MyHandle::default()`.
    /// let (layer, handle) =  GenericMetricLayer::<'_, _, MyHandle>::pair();
    /// ```
    fn make_default_handle(self) -> Self::Out;
}
863
/// The default handle for the Prometheus exporter.
///
/// A newtype around `PrometheusHandle`. Its [`Default`] implementation builds a
/// Prometheus recorder and installs it as the global `metrics` recorder; its
/// [`MakeDefaultHandle`] implementation returns the wrapped `PrometheusHandle`.
#[cfg(feature = "prometheus")]
#[derive(Clone)]
pub struct Handle(pub PrometheusHandle);
868
#[cfg(feature = "prometheus")]
impl Default for Handle {
    /// Builds a Prometheus recorder, installs it as the global `metrics` recorder and
    /// returns a `Handle` wrapping the recorder's handle.
    ///
    /// The request duration metric (under its prefixed name when one has been set,
    /// `AXUM_HTTP_REQUESTS_DURATION_SECONDS` otherwise) is configured with the buckets
    /// in `utils::SECONDS_DURATION_BUCKETS`. A Tokio task is also spawned that loops
    /// forever, sleeping five seconds and then calling `run_upkeep` on the handle.
    ///
    /// # Panics
    ///
    /// Panics if configuring the buckets fails, or if the global recorder cannot be
    /// set (with the message "Failed to set global recorder"). Because `tokio::spawn`
    /// is called, this presumably also has to run inside a Tokio runtime.
    fn default() -> Self {
        // Prefer the prefixed metric name when one was registered.
        let duration_metric_name = PREFIXED_HTTP_REQUESTS_DURATION_SECONDS
            .get()
            .map(|name| name.as_str())
            .unwrap_or(AXUM_HTTP_REQUESTS_DURATION_SECONDS)
            .to_string();

        let builder = PrometheusBuilder::new()
            .set_buckets_for_metric(
                Matcher::Full(duration_metric_name),
                utils::SECONDS_DURATION_BUCKETS,
            )
            .unwrap();
        let recorder = builder.build_recorder();

        let handle = recorder.handle();

        // The background task owns its own clone of the handle.
        let upkeep_handle = handle.clone();
        let upkeep_period = Duration::from_secs(5);
        tokio::spawn(async move {
            loop {
                tokio::time::sleep(upkeep_period).await;
                upkeep_handle.run_upkeep();
            }
        });

        metrics::set_global_recorder(recorder).expect("Failed to set global recorder");

        Self(handle)
    }
}
896
#[cfg(feature = "prometheus")]
impl MakeDefaultHandle for Handle {
    type Out = PrometheusHandle;

    /// Unwraps the newtype and returns the inner `PrometheusHandle`.
    fn make_default_handle(self) -> Self::Out {
        let Handle(prometheus_handle) = self;
        prometheus_handle
    }
}
905
#[cfg(feature = "prometheus")]
/// The tower middleware layer for recording HTTP metrics with Prometheus.
///
/// This is [`GenericMetricLayer`] with `PrometheusHandle` as the handle type and
/// [`Handle`] as the exporter type.
pub type PrometheusMetricLayer<'a> = GenericMetricLayer<'a, PrometheusHandle, Handle>;