/*!
Prometheus instrumentation for [actix-web](https://github.com/actix/actix-web). This middleware is heavily influenced by the work in [sd2k/rocket_prometheus](https://github.com/sd2k/rocket_prometheus). We track the same default metrics and allow for adding user-defined metrics.

By default two metrics are tracked (this assumes the namespace `actix_web_prom`):

  - `actix_web_prom_http_requests_total` (labels: endpoint, method, status): the total number
    of HTTP requests handled by the actix HttpServer.

  - `actix_web_prom_http_requests_duration_seconds` (labels: endpoint, method, status): the
    request duration for all HTTP requests handled by the actix HttpServer.


# Usage

First add `actix-web-prom` to your `Cargo.toml`:

```toml
[dependencies]
actix-web-prom = "0.10.0"
```

You then instantiate the prometheus middleware and pass it to `.wrap()`:

```rust
use std::collections::HashMap;

use actix_web::{web, App, HttpResponse, HttpServer};
use actix_web_prom::PrometheusMetricsBuilder;

async fn health() -> HttpResponse {
    HttpResponse::Ok().finish()
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    let mut labels = HashMap::new();
    labels.insert("label1".to_string(), "value1".to_string());
    let prometheus = PrometheusMetricsBuilder::new("api")
        .endpoint("/metrics")
        .const_labels(labels)
        .build()
        .unwrap();

# if false {
        HttpServer::new(move || {
            App::new()
                .wrap(prometheus.clone())
                .service(web::resource("/health").to(health))
        })
        .bind("127.0.0.1:8080")?
        .run()
        .await?;
# }
    Ok(())
}
```

Using the above as an example, a few things are worth mentioning:
 - `api` is the metrics namespace
 - `/metrics` is exposed automatically (GET requests only) with Content-Type header `content-type: text/plain; version=0.0.4; charset=utf-8`
 - `const_labels(labels)` adds fixed labels to every metric; omit the call if no additional
   labels are necessary.
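
The middleware intercepts a request only when its path equals the configured endpoint exactly and its method is GET. Every other request is answered by your own routes as usual. Below is a minimal std-only sketch of that check, modelled on the crate's private `matches` helper. The `serves_metrics` name and the string-typed method are illustrative assumptions and are not part of the public API.

```rust
/// Hypothetical helper mirroring the middleware's internal check: metrics are
/// served only when an endpoint is configured, the request path equals it
/// exactly, and the method is GET.
fn serves_metrics(endpoint: Option<&str>, path: &str, method: &str) -> bool {
    endpoint == Some(path) && method == "GET"
}
```

In this model, `serves_metrics(Some("/metrics"), "/metrics", "POST")` is `false`. A trailing slash also counts as a different path, because the comparison is an exact string match.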


A call to the `/metrics` endpoint will expose your metrics:

```shell
$ curl http://localhost:8080/metrics
# HELP api_http_requests_duration_seconds HTTP request duration in seconds for all requests
# TYPE api_http_requests_duration_seconds histogram
api_http_requests_duration_seconds_bucket{endpoint="/metrics",label1="value1",method="GET",status="200",le="0.005"} 1
api_http_requests_duration_seconds_bucket{endpoint="/metrics",label1="value1",method="GET",status="200",le="0.01"} 1
api_http_requests_duration_seconds_bucket{endpoint="/metrics",label1="value1",method="GET",status="200",le="0.025"} 1
api_http_requests_duration_seconds_bucket{endpoint="/metrics",label1="value1",method="GET",status="200",le="0.05"} 1
api_http_requests_duration_seconds_bucket{endpoint="/metrics",label1="value1",method="GET",status="200",le="0.1"} 1
api_http_requests_duration_seconds_bucket{endpoint="/metrics",label1="value1",method="GET",status="200",le="0.25"} 1
api_http_requests_duration_seconds_bucket{endpoint="/metrics",label1="value1",method="GET",status="200",le="0.5"} 1
api_http_requests_duration_seconds_bucket{endpoint="/metrics",label1="value1",method="GET",status="200",le="1"} 1
api_http_requests_duration_seconds_bucket{endpoint="/metrics",label1="value1",method="GET",status="200",le="2.5"} 1
api_http_requests_duration_seconds_bucket{endpoint="/metrics",label1="value1",method="GET",status="200",le="5"} 1
api_http_requests_duration_seconds_bucket{endpoint="/metrics",label1="value1",method="GET",status="200",le="10"} 1
api_http_requests_duration_seconds_bucket{endpoint="/metrics",label1="value1",method="GET",status="200",le="+Inf"} 1
api_http_requests_duration_seconds_sum{endpoint="/metrics",label1="value1",method="GET",status="200"} 0.00003
api_http_requests_duration_seconds_count{endpoint="/metrics",label1="value1",method="GET",status="200"} 1
# HELP api_http_requests_total Total number of HTTP requests
# TYPE api_http_requests_total counter
api_http_requests_total{endpoint="/metrics",label1="value1",method="GET",status="200"} 1
```
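
Histogram buckets are cumulative. Each `le` bucket counts every observation less than or equal to its bound, which is why the single request above (0.00003 s) appears in all of them. The bounds shown are `prometheus::DEFAULT_BUCKETS`, which the builder uses unless `.buckets(...)` is called. Below is a std-only sketch of that bookkeeping. The `CumulativeHistogram` type is an illustrative model and is not the `prometheus` crate's implementation.

```rust
/// Illustrative model of a Prometheus-style histogram with cumulative buckets.
struct CumulativeHistogram {
    bounds: Vec<f64>, // upper bounds (`le`) in ascending order; `+Inf` is implicit
    counts: Vec<u64>, // cumulative count for each bound
    sum: f64,         // sum of all observed values
    count: u64,       // total observations, equal to the implicit `+Inf` bucket
}

impl CumulativeHistogram {
    fn new(bounds: &[f64]) -> Self {
        Self {
            bounds: bounds.to_vec(),
            counts: vec![0; bounds.len()],
            sum: 0.0,
            count: 0,
        }
    }

    /// Record one observation: every bucket whose bound is >= the value is incremented.
    fn observe(&mut self, value: f64) {
        for (bound, count) in self.bounds.iter().zip(self.counts.iter_mut()) {
            if value <= *bound {
                *count += 1;
            }
        }
        self.sum += value;
        self.count += 1;
    }
}
```

Observing `0.2` afterwards would leave `le="0.1"` at 1 but raise `le="0.25"` and every larger bucket to 2.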

## Features
If you enable the `process` feature of this crate, the
[default process metrics](https://prometheus.io/docs/instrumenting/writing_clientlibs/#process-metrics)
are also collected:

```shell
# HELP process_cpu_seconds_total Total user and system CPU time spent in seconds.
# TYPE process_cpu_seconds_total counter
process_cpu_seconds_total 0.22
# HELP process_max_fds Maximum number of open file descriptors.
# TYPE process_max_fds gauge
process_max_fds 1048576
# HELP process_open_fds Number of open file descriptors.
# TYPE process_open_fds gauge
process_open_fds 78
# HELP process_resident_memory_bytes Resident memory size in bytes.
# TYPE process_resident_memory_bytes gauge
process_resident_memory_bytes 17526784
# HELP process_start_time_seconds Start time of the process since unix epoch in seconds.
# TYPE process_start_time_seconds gauge
process_start_time_seconds 1628105774.92
# HELP process_virtual_memory_bytes Virtual memory size in bytes.
# TYPE process_virtual_memory_bytes gauge
process_virtual_memory_bytes 1893163008
```

## Custom metrics

You instantiate `PrometheusMetrics` and then use its `.registry` to register your custom
metric (in this case, we use an `IntCounterVec`).

Then you can pass this counter to the app with `.app_data(web::Data::new(...))` to have it
available within the resource responder.

```rust
use actix_web::{web, App, HttpResponse, HttpServer};
use actix_web_prom::PrometheusMetricsBuilder;
use prometheus::{opts, IntCounterVec};

async fn health(counter: web::Data<IntCounterVec>) -> HttpResponse {
    counter.with_label_values(&["endpoint", "method", "status"]).inc();
    HttpResponse::Ok().finish()
}

#[actix_web::main]
async fn main() -> std::io::Result<()> {
    let prometheus = PrometheusMetricsBuilder::new("api")
        .endpoint("/metrics")
        .build()
        .unwrap();

    let counter_opts = opts!("counter", "some random counter").namespace("api");
    let counter = IntCounterVec::new(counter_opts, &["endpoint", "method", "status"]).unwrap();
    prometheus
        .registry
        .register(Box::new(counter.clone()))
        .unwrap();

# if false {
        HttpServer::new(move || {
            App::new()
                .wrap(prometheus.clone())
                // `web::Data` makes the counter extractable in handlers
                .app_data(web::Data::new(counter.clone()))
                .service(web::resource("/health").to(health))
        })
        .bind("127.0.0.1:8080")?
        .run()
        .await?;
# }
    Ok(())
}
```
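
A counter vector such as `IntCounterVec` keeps one independent counter per distinct combination of label values. Every new value (for example, every distinct `endpoint`) therefore creates a new time series. Below is a std-only model of that behaviour. The `CounterVec` type is a hypothetical stand-in and is not the `prometheus` crate's API.

```rust
use std::collections::HashMap;

/// Hypothetical stand-in for a labelled counter: one count per label-value tuple.
struct CounterVec {
    label_names: Vec<&'static str>,
    series: HashMap<Vec<String>, u64>,
}

impl CounterVec {
    fn new(label_names: &[&'static str]) -> Self {
        Self {
            label_names: label_names.to_vec(),
            series: HashMap::new(),
        }
    }

    /// Increment the series identified by `values` (one value per label name).
    /// Panics on a label-count mismatch, as `with_label_values` does.
    fn inc(&mut self, values: &[&str]) {
        assert_eq!(values.len(), self.label_names.len(), "label count mismatch");
        let key: Vec<String> = values.iter().map(|v| v.to_string()).collect();
        *self.series.entry(key).or_insert(0) += 1;
    }

    /// Number of distinct time series created so far.
    fn cardinality(&self) -> usize {
        self.series.len()
    }
}
```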

## Custom `Registry`

Some apps might have more than one `actix_web::HttpServer`.
If that's the case, you might want to use your own registry:

```rust
use actix_web::{web, App, HttpResponse, HttpServer};
use actix_web_prom::PrometheusMetricsBuilder;
use actix_web::rt::System;
use prometheus::Registry;
use std::thread;

async fn public_handler() -> HttpResponse {
    HttpResponse::Ok().body("Everyone can see it!")
}

async fn private_handler() -> HttpResponse {
    HttpResponse::Ok().body("This can be hidden behind a firewall")
}

fn main() -> std::io::Result<()> {
    let shared_registry = Registry::new();

    let private_metrics = PrometheusMetricsBuilder::new("private_api")
        .registry(shared_registry.clone())
        .endpoint("/metrics")
        .build()
        // It is safe to unwrap when __no other app has the same namespace__
        .unwrap();

    let public_metrics = PrometheusMetricsBuilder::new("public_api")
        .registry(shared_registry.clone())
        // Metrics should not be available from the outside
        // so no endpoint is registered
        .build()
        .unwrap();

# if false {
    let private_thread = thread::spawn(move || {
        let sys = System::new();
        let srv = HttpServer::new(move || {
            App::new()
                .wrap(private_metrics.clone())
                .service(web::resource("/test").to(private_handler))
        })
        .bind("127.0.0.1:8081")
        .unwrap()
        .run();
        sys.block_on(srv).unwrap();
    });

    let public_thread = thread::spawn(move || {
        let sys = System::new();
        let srv = HttpServer::new(move || {
            App::new()
                .wrap(public_metrics.clone())
                .service(web::resource("/test").to(public_handler))
        })
        .bind("127.0.0.1:8082")
        .unwrap()
        .run();
        sys.block_on(srv).unwrap();
    });

    private_thread.join().unwrap();
    public_thread.join().unwrap();
# }
    Ok(())
}
```

## Configurable routes pattern cardinality

Suppose your app has a route that fetches posts by language and slug: `GET /posts/{language}/{slug}`.
By default, actix-web-prom reports metrics for the whole route, with the `endpoint` label set to the pattern `/posts/{language}/{slug}`.
This keeps cardinality low, but you cannot differentiate metrics across languages, even though there is only a limited set of them.
Actix-web-prom can be configured to allow more cardinality on selected route params.

To do so, add a middleware that inserts [extensions data](https://blog.adamchalmers.com/what-are-extensions/) into the request, specifically a `MetricsConfig` struct listing the params whose values you want to keep in the label.

```rust
use actix_web::{dev::Service, web, HttpMessage, HttpResponse};
use actix_web_prom::MetricsConfig;

async fn handler() -> HttpResponse {
    HttpResponse::Ok().finish()
}

web::resource("/posts/{language}/{slug}")
    .wrap_fn(|req, srv| {
        req.extensions_mut().insert::<MetricsConfig>(
            MetricsConfig { cardinality_keep_params: vec!["language".to_string()] }
        );
        srv.call(req)
    })
    .route(web::get().to(handler));
```

See the full example `with_cardinality_on_params.rs`.
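
Concretely, the middleware rebuilds the `endpoint` label from the matched route pattern. Params listed in `cardinality_keep_params` are filled in with the request's actual value, and every other param keeps its `{name}` placeholder. The crate performs this substitution with `strfmt`. The std-only sketch below does the same with plain string replacement. It assumes plain `{name}` segments without regex constraints, and the `mixed_pattern` function name is hypothetical.

```rust
/// Hypothetical re-implementation of the label rewrite: substitute the values of
/// the params to keep into the route pattern, leaving the others as `{name}`.
fn mixed_pattern(pattern: &str, params: &[(&str, &str)], keep: &[&str]) -> String {
    let mut out = pattern.to_string();
    for (name, value) in params {
        if keep.contains(name) {
            // e.g. "{language}" -> "en"
            out = out.replace(&format!("{{{name}}}"), value);
        }
    }
    out
}
```

With `language` kept, `/posts/en/hello-world` is recorded as `/posts/en/{slug}`. The number of series then grows with the number of languages, not the number of posts.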

## Configurable metric names

If you want to rename the default metrics, you can use `ActixMetricsConfiguration` to do so.

```rust
use actix_web_prom::{PrometheusMetricsBuilder, ActixMetricsConfiguration};

PrometheusMetricsBuilder::new("api")
    .endpoint("/metrics")
    .metrics_configuration(
        ActixMetricsConfiguration::default()
            .http_requests_duration_seconds_name("my_http_request_duration"),
    )
    .build()
    .unwrap();
```

See the full example `configuring_default_metrics.rs`.
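
The exposed name is the namespace passed to `PrometheusMetricsBuilder::new`, joined to the configured metric name with an underscore. The configuration above therefore exposes `api_my_http_request_duration`, plus the usual `_bucket`, `_sum` and `_count` series. Below is a std-only sketch of that joining rule. The helper `fq_name` is illustrative; the `prometheus` crate builds the fully-qualified name internally from `Opts`.

```rust
/// Illustrative helper: prefix a metric name with its namespace, separated by `_`,
/// leaving the name unchanged when no namespace is set.
fn fq_name(namespace: &str, name: &str) -> String {
    if namespace.is_empty() {
        name.to_string()
    } else {
        format!("{namespace}_{name}")
    }
}
```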

## Masking unknown paths

This is useful to avoid producing large numbers of useless metrics because of bots probing the internet for nonexistent paths.

What this does is collapse every request to a path that will never be found (404) into *one single
metric*. So, if you want metrics about every single path that is hit, even if it doesn't
exist, avoid this section altogether.

```rust,no_run
use actix_web_prom::PrometheusMetricsBuilder;

PrometheusMetricsBuilder::new("api")
    .endpoint("/metrics")
    .mask_unmatched_patterns("UNKNOWN")
    .build()
    .unwrap();
```

The above will convert all `/<nonexistent-path>` into `UNKNOWN`:

```text
http_requests_duration_seconds_sum{endpoint="/favicon.ico",method="GET",status="400"} 0.000424898
```

becomes

```text
http_requests_duration_seconds_sum{endpoint="UNKNOWN",method="GET",status="400"} 0.000424898
```
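
Excluded paths and statuses are dropped first. For the remaining requests, the recorded `endpoint` value is chosen in two steps:

1. A 404 or 405 response falls back to the plain route pattern (or raw path) rather than a pattern carrying user-kept param values.
2. When no route matched at all, the mask, if configured, replaces the path.

Below is a std-only sketch of that decision, mirroring the crate's private `update_metrics` logic. The function `endpoint_label` and its plain-integer status are hypothetical simplifications.

```rust
/// Hypothetical mirror of how the `endpoint` label value is chosen.
fn endpoint_label<'a>(
    mixed_pattern: &'a str,
    fallback_pattern: &'a str,
    status: u16,
    was_path_matched: bool,
    mask: Option<&'a str>,
) -> &'a str {
    // Do not record kept param values for requests the server rejected.
    let pattern = if fallback_pattern != mixed_pattern && (status == 404 || status == 405) {
        fallback_pattern
    } else {
        mixed_pattern
    };
    // Unmatched paths collapse into the mask when one is configured.
    match (was_path_matched, mask) {
        (false, Some(mask)) => mask,
        _ => pattern,
    }
}
```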
*/
#![deny(missing_docs)]

use log::warn;
use std::collections::{HashMap, HashSet};
use std::future::{ready, Future, Ready};
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Instant;

use actix_web::{
    body::{BodySize, EitherBody, MessageBody},
    dev::{self, Service, ServiceRequest, ServiceResponse, Transform},
    http::{
        header::{HeaderValue, CONTENT_TYPE},
        Method, StatusCode, Version,
    },
    web::Bytes,
    Error, HttpMessage,
};
use futures_core::ready;
use pin_project_lite::pin_project;
use prometheus::{
    Encoder, HistogramOpts, HistogramVec, IntCounterVec, Opts, Registry, TextEncoder,
};

use regex::RegexSet;
use strfmt::strfmt;

/// Per-route configuration that changes how the metrics middleware labels a request.
///
/// Insert it into the request extensions (for example from a `wrap_fn` middleware on a
/// resource) to keep the actual values of selected route params in the `endpoint` label.
#[derive(Debug, Clone)]
pub struct MetricsConfig {
    /// Names of the route params whose values are kept in the `endpoint` label
    pub cardinality_keep_params: Vec<String>,
}

#[derive(Debug)]
/// Builder to create a new `PrometheusMetrics` struct.
///
/// It allows setting optional parameters like registry, buckets, etc.
pub struct PrometheusMetricsBuilder {
    namespace: String,
    endpoint: Option<String>,
    const_labels: HashMap<String, String>,
    registry: Registry,
    buckets: Vec<f64>,
    exclude: HashSet<String>,
    exclude_regex: RegexSet,
    exclude_status: HashSet<StatusCode>,
    unmatched_patterns_mask: Option<String>,
    metrics_configuration: ActixMetricsConfiguration,
}

impl PrometheusMetricsBuilder {
    /// Create new `PrometheusMetricsBuilder`
    ///
    /// namespace example: "actix"
    pub fn new(namespace: &str) -> Self {
        Self {
            namespace: namespace.into(),
            endpoint: None,
            const_labels: HashMap::new(),
            registry: Registry::new(),
            buckets: prometheus::DEFAULT_BUCKETS.to_vec(),
            exclude: HashSet::new(),
            exclude_regex: RegexSet::empty(),
            exclude_status: HashSet::new(),
            unmatched_patterns_mask: None,
            metrics_configuration: ActixMetricsConfiguration::default(),
        }
    }

    /// Set the path on which the metrics are served (GET requests only)
    ///
    /// Example: "/metrics"
    pub fn endpoint(mut self, value: &str) -> Self {
        self.endpoint = Some(value.into());
        self
    }

    /// Set histogram buckets
    pub fn buckets(mut self, value: &[f64]) -> Self {
        self.buckets = value.to_vec();
        self
    }

    /// Set labels to add on every metric
    pub fn const_labels(mut self, value: HashMap<String, String>) -> Self {
        self.const_labels = value;
        self
    }

    /// Set registry
    ///
    /// By default one is set and is internal to `PrometheusMetrics`
    pub fn registry(mut self, value: Registry) -> Self {
        self.registry = value;
        self
    }

    /// Ignore and do not record metrics for specified path.
    pub fn exclude<T: Into<String>>(mut self, path: T) -> Self {
        self.exclude.insert(path.into());
        self
    }

    /// Ignore and do not record metrics for paths matching the regex.
    ///
    /// # Panics
    ///
    /// Panics if `path` is not a valid regular expression.
    pub fn exclude_regex<T: Into<String>>(mut self, path: T) -> Self {
        let mut patterns = self.exclude_regex.patterns().to_vec();
        patterns.push(path.into());
        self.exclude_regex = RegexSet::new(patterns).unwrap();
        self
    }

    /// Ignore and do not record metrics for responses with the given status code.
    pub fn exclude_status<T: Into<StatusCode>>(mut self, status: T) -> Self {
        self.exclude_status.insert(status.into());
        self
    }

    /// Replaces the request path with the supplied mask if no actix-web handler is matched
    pub fn mask_unmatched_patterns<T: Into<String>>(mut self, mask: T) -> Self {
        self.unmatched_patterns_mask = Some(mask.into());
        self
    }

    /// Set metrics configuration
    pub fn metrics_configuration(mut self, value: ActixMetricsConfiguration) -> Self {
        self.metrics_configuration = value;
        self
    }

    /// Instantiate `PrometheusMetrics` struct
    ///
    /// Returns an error if a metric cannot be created or registered, for example when a
    /// metric with the same name is already registered in the registry.
    pub fn build(self) -> Result<PrometheusMetrics, Box<dyn std::error::Error + Send + Sync>> {
        let labels_vec = self.metrics_configuration.labels.to_vec();
        let labels = &labels_vec.iter().map(|s| s.as_str()).collect::<Vec<&str>>();

        let http_requests_total_opts = Opts::new(
            self.metrics_configuration
                .http_requests_total_name
                .to_owned(),
            "Total number of HTTP requests",
        )
        .namespace(&self.namespace)
        .const_labels(self.const_labels.clone());

        let http_requests_total = IntCounterVec::new(http_requests_total_opts, labels)?;

        let http_requests_duration_seconds_opts = HistogramOpts::new(
            self.metrics_configuration
                .http_requests_duration_seconds_name
                .to_owned(),
            "HTTP request duration in seconds for all requests",
        )
        .namespace(&self.namespace)
        .buckets(self.buckets.clone())
        .const_labels(self.const_labels.clone());

        let http_requests_duration_seconds =
            HistogramVec::new(http_requests_duration_seconds_opts, labels)?;

        self.registry
            .register(Box::new(http_requests_total.clone()))?;
        self.registry
            .register(Box::new(http_requests_duration_seconds.clone()))?;

        Ok(PrometheusMetrics {
            http_requests_total,
            http_requests_duration_seconds,
            registry: self.registry,
            namespace: self.namespace,
            endpoint: self.endpoint,
            const_labels: self.const_labels,
            exclude: self.exclude,
            exclude_regex: self.exclude_regex,
            exclude_status: self.exclude_status,
            enable_http_version_label: self.metrics_configuration.labels.version.is_some(),
            unmatched_patterns_mask: self.unmatched_patterns_mask,
        })
    }
}

#[derive(Debug, Clone)]
/// Configuration of the label names used in metrics
pub struct LabelsConfiguration {
    endpoint: String,
    method: String,
    status: String,
    version: Option<String>,
}

impl Default for LabelsConfiguration {
    fn default() -> Self {
        Self {
            endpoint: String::from("endpoint"),
            method: String::from("method"),
            status: String::from("status"),
            version: None,
        }
    }
}

impl LabelsConfiguration {
    fn to_vec(&self) -> Vec<String> {
        let mut labels = vec![
            self.endpoint.clone(),
            self.method.clone(),
            self.status.clone(),
        ];
        if let Some(version) = self.version.clone() {
            labels.push(version);
        }
        labels
    }

    /// Set the name of the HTTP method label
    pub fn method(mut self, name: &str) -> Self {
        self.method = name.to_owned();
        self
    }

    /// Set the name of the HTTP endpoint label
    pub fn endpoint(mut self, name: &str) -> Self {
        self.endpoint = name.to_owned();
        self
    }

    /// Set the name of the HTTP status label
    pub fn status(mut self, name: &str) -> Self {
        self.status = name.to_owned();
        self
    }

    /// Set the name of the HTTP version label
    ///
    /// The HTTP version is only recorded as a label once this is set.
    pub fn version(mut self, name: &str) -> Self {
        self.version = Some(name.to_owned());
        self
    }
}

#[derive(Debug, Clone)]
/// Configuration for the collected metrics
///
/// Stores individual metric configuration objects
pub struct ActixMetricsConfiguration {
    http_requests_total_name: String,
    http_requests_duration_seconds_name: String,
    labels: LabelsConfiguration,
}

impl Default for ActixMetricsConfiguration {
    fn default() -> Self {
        Self {
            http_requests_total_name: String::from("http_requests_total"),
            http_requests_duration_seconds_name: String::from("http_requests_duration_seconds"),
            labels: LabelsConfiguration::default(),
        }
    }
}

impl ActixMetricsConfiguration {
    /// Set the labels collected for the metrics
    pub fn labels(mut self, labels: LabelsConfiguration) -> Self {
        self.labels = labels;
        self
    }

    /// Set name for `http_requests_total` metric
    pub fn http_requests_total_name(mut self, name: &str) -> Self {
        self.http_requests_total_name = name.to_owned();
        self
    }

    /// Set name for `http_requests_duration_seconds` metric
    pub fn http_requests_duration_seconds_name(mut self, name: &str) -> Self {
        self.http_requests_duration_seconds_name = name.to_owned();
        self
    }
}

#[derive(Clone)]
#[must_use = "must be set up as middleware for actix-web"]
/// By default two metrics are tracked (this assumes the namespace `actix_web_prom`):
///
///   - `actix_web_prom_http_requests_total` (labels: endpoint, method, status): the total
///     number of HTTP requests handled by the actix `HttpServer`.
///
///   - `actix_web_prom_http_requests_duration_seconds` (labels: endpoint, method,
///     status): the request duration for all HTTP requests handled by the actix
///     `HttpServer`.
pub struct PrometheusMetrics {
    pub(crate) http_requests_total: IntCounterVec,
    pub(crate) http_requests_duration_seconds: HistogramVec,

    /// exposed registry for custom prometheus metrics
    pub registry: Registry,
    #[allow(dead_code)]
    pub(crate) namespace: String,
    pub(crate) endpoint: Option<String>,
    #[allow(dead_code)]
    pub(crate) const_labels: HashMap<String, String>,

    pub(crate) exclude: HashSet<String>,
    pub(crate) exclude_regex: RegexSet,
    pub(crate) exclude_status: HashSet<StatusCode>,
    pub(crate) enable_http_version_label: bool,
    pub(crate) unmatched_patterns_mask: Option<String>,
}

impl PrometheusMetrics {
    fn metrics(&self) -> String {
        let mut buffer = vec![];
        TextEncoder::new()
            .encode(&self.registry.gather(), &mut buffer)
            .unwrap();

        #[cfg(feature = "process")]
        {
            let mut process_metrics = vec![];
            TextEncoder::new()
                .encode(&prometheus::gather(), &mut process_metrics)
                .unwrap();

            buffer.extend_from_slice(&process_metrics);
        }

        String::from_utf8(buffer).unwrap()
    }

    /// Returns `true` if the request targets the configured metrics endpoint with GET.
    fn matches(&self, path: &str, method: &Method) -> bool {
        self.endpoint.as_deref() == Some(path) && method == Method::GET
    }

    fn update_metrics(
        &self,
        http_version: Version,
        mixed_pattern: &str,
        fallback_pattern: &str,
        method: &Method,
        status: StatusCode,
        clock: Instant,
        was_path_matched: bool,
    ) {
        if self.exclude.contains(mixed_pattern)
            || self.exclude_regex.is_match(mixed_pattern)
            || self.exclude_status.contains(&status)
        {
            return;
        }

        // do not record mixed patterns that were considered invalid by the server
        let final_pattern = if fallback_pattern != mixed_pattern && (status == 404 || status == 405)
        {
            fallback_pattern
        } else {
            mixed_pattern
        };

        let final_pattern = if was_path_matched {
            final_pattern
        } else if let Some(mask) = &self.unmatched_patterns_mask {
            mask
        } else {
            final_pattern
        };

        let label_values = [
            final_pattern,
            method.as_str(),
            status.as_str(),
            Self::http_version_label(http_version),
        ];
        let label_values = if self.enable_http_version_label {
            &label_values[..]
        } else {
            &label_values[..3]
        };

        let duration = clock.elapsed().as_secs_f64();
        self.http_requests_duration_seconds
            .with_label_values(label_values)
            .observe(duration);

        self.http_requests_total
            .with_label_values(label_values)
            .inc();
    }

    fn http_version_label(version: Version) -> &'static str {
        match version {
            v if v == Version::HTTP_09 => "HTTP/0.9",
            v if v == Version::HTTP_10 => "HTTP/1.0",
            v if v == Version::HTTP_11 => "HTTP/1.1",
            v if v == Version::HTTP_2 => "HTTP/2.0",
            v if v == Version::HTTP_3 => "HTTP/3.0",
            _ => "<unrecognized>",
        }
    }
}
719
720impl<S, B> Transform<S, ServiceRequest> for PrometheusMetrics
721where
722    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
723{
724    type Response = ServiceResponse<EitherBody<StreamLog<B>, StreamLog<String>>>;
725    type Error = Error;
726    type InitError = ();
727    type Transform = PrometheusMetricsMiddleware<S>;
728    type Future = Ready<Result<Self::Transform, Self::InitError>>;
729
730    fn new_transform(&self, service: S) -> Self::Future {
731        ready(Ok(PrometheusMetricsMiddleware {
732            service,
733            inner: Arc::new(self.clone()),
734        }))
735    }
736}
737
738pin_project! {
739    #[doc(hidden)]
740    pub struct LoggerResponse<S>
741        where
742        S: Service<ServiceRequest>,
743    {
744        #[pin]
745        fut: S::Future,
746        time: Instant,
747        inner: Arc<PrometheusMetrics>,
748        _t: PhantomData<()>,
749    }
750}
751
752impl<S, B> Future for LoggerResponse<S>
753where
754    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
755{
756    type Output = Result<ServiceResponse<EitherBody<StreamLog<B>, StreamLog<String>>>, Error>;
757
758    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
759        let this = self.project();
760
761        let res = match ready!(this.fut.poll(cx)) {
762            Ok(res) => res,
763            Err(e) => return Poll::Ready(Err(e)),
764        };
765
766        let time = *this.time;
767        let req = res.request();
768        let method = req.method().clone();
769        let version = req.version();
770        let was_path_matched = req.match_pattern().is_some();
771
772        // get metrics config for this specific route
773        // piece of code to allow for more cardinality
774        let params_keep_path_cardinality = match req.extensions_mut().get::<MetricsConfig>() {
775            Some(config) => config.cardinality_keep_params.clone(),
776            None => vec![],
777        };
778
779        let full_pattern = req.match_pattern();
780        let path = req.path().to_string();
781        let fallback_pattern = full_pattern.clone().unwrap_or(path.clone());
782
783        // mixed_pattern is the final path used as label value in metrics
784        let mixed_pattern = match full_pattern {
785            None => path.clone(),
786            Some(full_pattern) => {
787                let mut params: HashMap<String, String> = HashMap::new();
788
789                for (key, val) in req.match_info().iter() {
790                    if params_keep_path_cardinality.contains(&key.to_string()) {
791                        params.insert(key.to_string(), val.to_string());
792                        continue;
793                    }
794                    params.insert(key.to_string(), format!("{{{key}}}"));
795                }
796
797                if let Ok(mixed_cardinality_pattern) = strfmt(&full_pattern, &params) {
798                    mixed_cardinality_pattern
799                } else {
800                    warn!(
801                        "Cannot build mixed cardinality pattern {full_pattern}, with params {params:?}"
802                    );
803                    full_pattern
804                }
805            }
806        };
807
808        let inner = this.inner.clone();
809
810        Poll::Ready(Ok(res.map_body(move |head, body| {
811            // We short circuit the response status and body to serve the endpoint
812            // automagically. This way the user does not need to set the middleware *AND*
813            // an endpoint to serve middleware results. The user is only required to set
814            // the middleware and tell us what the endpoint should be.
815            if inner.matches(&path, &method) {
816                head.status = StatusCode::OK;
817                head.headers.insert(
818                    CONTENT_TYPE,
819                    HeaderValue::from_static("text/plain; version=0.0.4; charset=utf-8"),
820                );
821
822                EitherBody::right(StreamLog {
823                    body: inner.metrics(),
824                    size: 0,
825                    clock: time,
826                    inner,
827                    status: head.status,
828                    mixed_pattern,
829                    fallback_pattern,
830                    method,
831                    version,
832                    was_path_matched: true,
833                })
834            } else {
835                EitherBody::left(StreamLog {
836                    body,
837                    size: 0,
838                    clock: time,
839                    inner,
840                    status: head.status,
841                    mixed_pattern,
842                    fallback_pattern,
843                    method,
844                    version,
845                    was_path_matched,
846                })
847            }
848        })))
849    }
850}
851
852#[doc(hidden)]
853/// Middleware service for PrometheusMetrics
854pub struct PrometheusMetricsMiddleware<S> {
855    service: S,
856    inner: Arc<PrometheusMetrics>,
857}
858
859impl<S, B> Service<ServiceRequest> for PrometheusMetricsMiddleware<S>
860where
861    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
862{
863    type Response = ServiceResponse<EitherBody<StreamLog<B>, StreamLog<String>>>;
864    type Error = S::Error;
865    type Future = LoggerResponse<S>;
866
867    dev::forward_ready!(service);
868
869    fn call(&self, req: ServiceRequest) -> Self::Future {
870        LoggerResponse {
871            fut: self.service.call(req),
872            time: Instant::now(),
873            inner: self.inner.clone(),
874            _t: PhantomData,
875        }
876    }
877}

pin_project! {
    #[doc(hidden)]
    pub struct StreamLog<B> {
        #[pin]
        body: B,
        size: usize,
        clock: Instant,
        inner: Arc<PrometheusMetrics>,
        status: StatusCode,
        // route pattern in which the params listed in `MetricsConfig::cardinality_keep_params`
        // are filled in with their actual values, while all other params stay `{placeholders}`
        mixed_pattern: String,
        fallback_pattern: String,
        method: Method,
        version: Version,
        was_path_matched: bool,
    }

    impl<B> PinnedDrop for StreamLog<B> {
        fn drop(this: Pin<&mut Self>) {
            // update the metrics for this request at the very end of responding
            this.inner.update_metrics(
                this.version,
                &this.mixed_pattern,
                &this.fallback_pattern,
                &this.method,
                this.status,
                this.clock,
                this.was_path_matched,
            );
        }
    }
}

impl<B: MessageBody> MessageBody for StreamLog<B> {
    type Error = B::Error;

    fn size(&self) -> BodySize {
        self.body.size()
    }

    fn poll_next(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Bytes, Self::Error>>> {
        let this = self.project();
        // forward each chunk unchanged, adding its length to the running byte count
        match ready!(this.body.poll_next(cx)) {
            Some(Ok(chunk)) => {
                *this.size += chunk.len();
                Poll::Ready(Some(Ok(chunk)))
            }
            Some(Err(err)) => Poll::Ready(Some(Err(err))),
            None => Poll::Ready(None),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use actix_web::dev::Service;
    use actix_web::test::{call_and_read_body, call_service, init_service, read_body, TestRequest};
    use actix_web::{web, App, HttpMessage, HttpResponse, Resource, Scope};

    use prometheus::{Counter, Opts};
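
    // `StreamLog` above wraps the response body, tallies bytes in `poll_next`, and
    // records the request in `PinnedDrop`. A request is therefore observed only once
    // its body has been fully sent or dropped, and the timer started in `call` covers
    // the whole response. The std-only sketch below illustrates that "record on drop"
    // pattern. It is hypothetical: a plain `Iterator` of byte chunks stands in for
    // actix's `MessageBody`, and a shared `Vec` stands in for the registry. (In the
    // real middleware the byte count is tracked but not passed to `update_metrics`.)
    #[allow(dead_code)]
    mod record_on_drop_sketch {
        use std::sync::{Arc, Mutex};
        use std::time::{Duration, Instant};

        /// What the hypothetical wrapper records when it is dropped.
        #[derive(Debug)]
        pub struct Recorded {
            pub bytes: usize,
            pub elapsed: Duration,
        }

        /// Hypothetical analogue of `StreamLog`: passes chunks through unchanged.
        pub struct CountingBody<I> {
            inner: I,
            size: usize,
            clock: Instant,
            sink: Arc<Mutex<Vec<Recorded>>>,
        }

        impl<I> CountingBody<I> {
            pub fn new(inner: I, sink: Arc<Mutex<Vec<Recorded>>>) -> Self {
                // the clock starts here, as `call` does with `Instant::now()`
                CountingBody { inner, size: 0, clock: Instant::now(), sink }
            }
        }

        impl<I: Iterator<Item = Vec<u8>>> Iterator for CountingBody<I> {
            type Item = Vec<u8>;

            fn next(&mut self) -> Option<Vec<u8>> {
                // like `poll_next`: forward the chunk and add its length to the tally
                let chunk = self.inner.next()?;
                self.size += chunk.len();
                Some(chunk)
            }
        }

        impl<I> Drop for CountingBody<I> {
            fn drop(&mut self) {
                // like `PinnedDrop`: record exactly once, at the very end of responding
                self.sink.lock().unwrap().push(Recorded {
                    bytes: self.size,
                    elapsed: self.clock.elapsed(),
                });
            }
        }

        #[test]
        fn records_only_after_drop() {
            let sink = Arc::new(Mutex::new(Vec::new()));
            let mut body = CountingBody::new(
                vec![b"hello".to_vec(), b"!".to_vec()].into_iter(),
                Arc::clone(&sink),
            );
            assert_eq!(body.next(), Some(b"hello".to_vec()));
            assert!(sink.lock().unwrap().is_empty()); // nothing recorded mid-stream
            drop(body);
            let recorded = sink.lock().unwrap();
            assert_eq!(recorded.len(), 1);
            assert_eq!(recorded[0].bytes, 5); // only the chunk that was actually pulled
        }
    }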

    #[actix_web::test]
    async fn middleware_basic() {
        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .endpoint("/metrics")
            .build()
            .unwrap();

        let app = init_service(
            App::new()
                .wrap(prometheus)
                .service(web::resource("/health_check").to(HttpResponse::Ok)),
        )
        .await;

        let res = call_service(&app, TestRequest::with_uri("/health_check").to_request()).await;
        assert!(res.status().is_success());
        assert_eq!(read_body(res).await, "");

        let res = call_service(&app, TestRequest::with_uri("/metrics").to_request()).await;
        assert_eq!(
            res.headers().get(CONTENT_TYPE).unwrap(),
            "text/plain; version=0.0.4; charset=utf-8"
        );
        let body = String::from_utf8(read_body(res).await.to_vec()).unwrap();
        assert!(&body.contains(
            &String::from_utf8(web::Bytes::from(
                "# HELP actix_web_prom_http_requests_duration_seconds HTTP request duration in seconds for all requests
# TYPE actix_web_prom_http_requests_duration_seconds histogram
actix_web_prom_http_requests_duration_seconds_bucket{endpoint=\"/health_check\",method=\"GET\",status=\"200\",le=\"0.005\"} 1
"
        ).to_vec()).unwrap()));
        assert!(body.contains(
            &String::from_utf8(
                web::Bytes::from(
                    "# HELP actix_web_prom_http_requests_total Total number of HTTP requests
# TYPE actix_web_prom_http_requests_total counter
actix_web_prom_http_requests_total{endpoint=\"/health_check\",method=\"GET\",status=\"200\"} 1
"
                )
                .to_vec()
            )
            .unwrap()
        ));
    }
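
    // The expected output above includes `..._bucket{...,le="0.005"} 1`, meaning the
    // single fast request was counted in the lowest bucket. Prometheus histogram buckets
    // are cumulative: one observation increments every bucket whose upper bound `le` is
    // >= the observed duration (plus the implicit `+Inf` bucket). Below is a minimal
    // std-only sketch of that counting rule. `DEFAULT_BUCKETS` copies the `prometheus`
    // crate's default bounds. This middleware is assumed, but not confirmed here, to use
    // them when no custom buckets are configured.
    #[allow(dead_code)]
    mod cumulative_buckets_sketch {
        pub const DEFAULT_BUCKETS: [f64; 11] =
            [0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5, 1.0, 2.5, 5.0, 10.0];

        /// One cumulative count per bound, in the same order as `bounds`
        /// (the `+Inf` bucket would simply be `observations.len()`).
        pub fn cumulative_counts(bounds: &[f64], observations: &[f64]) -> Vec<u64> {
            bounds
                .iter()
                .map(|&le| observations.iter().filter(|&&v| v <= le).count() as u64)
                .collect()
        }

        #[test]
        fn a_fast_request_lands_in_every_bucket() {
            // 3 ms is below the lowest bound, so every bucket counts it
            let counts = cumulative_counts(&DEFAULT_BUCKETS, &[0.003]);
            assert!(counts.iter().all(|&c| c == 1));
        }
    }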

    #[actix_web::test]
    async fn middleware_http_version() {
        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .endpoint("/metrics")
            .metrics_configuration(
                ActixMetricsConfiguration::default()
                    .labels(LabelsConfiguration::default().version("version")),
            )
            .build()
            .unwrap();

        let app = init_service(
            App::new()
                .wrap(prometheus)
                .service(web::resource("/health_check").to(HttpResponse::Ok)),
        )
        .await;

        let test_cases = HashMap::from([
            (Version::HTTP_09, 1),
            (Version::HTTP_10, 2),
            (Version::HTTP_11, 5),
            (Version::HTTP_2, 7),
            (Version::HTTP_3, 11),
        ]);

        for (http_version, repeats) in test_cases.iter() {
            for _ in 0..*repeats {
                let res = call_service(
                    &app,
                    TestRequest::with_uri("/health_check")
                        .version(*http_version)
                        .to_request(),
                )
                .await;
                assert!(res.status().is_success());
                assert_eq!(read_body(res).await, "");
            }
        }

        let res = call_service(&app, TestRequest::with_uri("/metrics").to_request()).await;
        assert_eq!(
            res.headers().get(CONTENT_TYPE).unwrap(),
            "text/plain; version=0.0.4; charset=utf-8"
        );
        let body = String::from_utf8(read_body(res).await.to_vec()).unwrap();
        println!("Body: {}", body);
        for (http_version, repeats) in test_cases {
            assert!(&body.contains(
                &String::from_utf8(web::Bytes::from(
                    format!(
                        "actix_web_prom_http_requests_duration_seconds_bucket{{endpoint=\"/health_check\",method=\"GET\",status=\"200\",version=\"{}\",le=\"0.005\"}} {}
", PrometheusMetrics::http_version_label(http_version), repeats)
            ).to_vec()).unwrap()));

            assert!(&body.contains(
                &String::from_utf8(web::Bytes::from(
                    format!(
                        "actix_web_prom_http_requests_total{{endpoint=\"/health_check\",method=\"GET\",status=\"200\",version=\"{}\"}} {}
", PrometheusMetrics::http_version_label(http_version), repeats)
            ).to_vec()).unwrap()));
        }
    }

    #[actix_web::test]
    async fn middleware_scope() {
        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .endpoint("/internal/metrics")
            .build()
            .unwrap();

        let app = init_service(
            App::new().service(
                web::scope("/internal")
                    .wrap(prometheus)
                    .service(web::resource("/health_check").to(HttpResponse::Ok)),
            ),
        )
        .await;

        let res = call_service(
            &app,
            TestRequest::with_uri("/internal/health_check").to_request(),
        )
        .await;
        assert!(res.status().is_success());
        assert_eq!(read_body(res).await, "");

        let res = call_service(
            &app,
            TestRequest::with_uri("/internal/metrics").to_request(),
        )
        .await;
        assert_eq!(
            res.headers().get(CONTENT_TYPE).unwrap(),
            "text/plain; version=0.0.4; charset=utf-8"
        );
        let body = String::from_utf8(read_body(res).await.to_vec()).unwrap();
        assert!(&body.contains(
            &String::from_utf8(web::Bytes::from(
                "# HELP actix_web_prom_http_requests_duration_seconds HTTP request duration in seconds for all requests
# TYPE actix_web_prom_http_requests_duration_seconds histogram
actix_web_prom_http_requests_duration_seconds_bucket{endpoint=\"/internal/health_check\",method=\"GET\",status=\"200\",le=\"0.005\"} 1
"
        ).to_vec()).unwrap()));
        assert!(body.contains(
            &String::from_utf8(
                web::Bytes::from(
                    "# HELP actix_web_prom_http_requests_total Total number of HTTP requests
# TYPE actix_web_prom_http_requests_total counter
actix_web_prom_http_requests_total{endpoint=\"/internal/health_check\",method=\"GET\",status=\"200\"} 1
"
                )
                .to_vec()
            )
            .unwrap()
        ));
    }

    #[actix_web::test]
    async fn middleware_match_pattern() {
        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .endpoint("/metrics")
            .build()
            .unwrap();

        let app = init_service(
            App::new()
                .wrap(prometheus)
                .service(web::resource("/resource/{id}").to(HttpResponse::Ok)),
        )
        .await;

        let res = call_service(&app, TestRequest::with_uri("/resource/123").to_request()).await;
        assert!(res.status().is_success());
        assert_eq!(read_body(res).await, "");

        let res = call_and_read_body(&app, TestRequest::with_uri("/metrics").to_request()).await;
        let body = String::from_utf8(res.to_vec()).unwrap();
        assert!(&body.contains(
            &String::from_utf8(web::Bytes::from(
                "# HELP actix_web_prom_http_requests_duration_seconds HTTP request duration in seconds for all requests
# TYPE actix_web_prom_http_requests_duration_seconds histogram
actix_web_prom_http_requests_duration_seconds_bucket{endpoint=\"/resource/{id}\",method=\"GET\",status=\"200\",le=\"0.005\"} 1
"
        ).to_vec()).unwrap()));
        assert!(body.contains(
            &String::from_utf8(
                web::Bytes::from(
                    "# HELP actix_web_prom_http_requests_total Total number of HTTP requests
# TYPE actix_web_prom_http_requests_total counter
actix_web_prom_http_requests_total{endpoint=\"/resource/{id}\",method=\"GET\",status=\"200\"} 1
"
                )
                .to_vec()
            )
            .unwrap()
        ));
    }

    #[actix_web::test]
    async fn middleware_with_mask_unmatched_pattern() {
        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .endpoint("/metrics")
            .mask_unmatched_patterns("UNKNOWN")
            .build()
            .unwrap();

        let app = init_service(
            App::new()
                .wrap(prometheus)
                .service(web::resource("/resource/{id}").to(HttpResponse::Ok)),
        )
        .await;

        let res = call_service(&app, TestRequest::with_uri("/not-real").to_request()).await;
        assert!(res.status().is_client_error());
        assert_eq!(read_body(res).await, "");

        let res = call_and_read_body(&app, TestRequest::with_uri("/metrics").to_request()).await;
        let body = String::from_utf8(res.to_vec()).unwrap();

        assert!(&body.contains(
            &String::from_utf8(web::Bytes::from(
                "actix_web_prom_http_requests_duration_seconds_bucket{endpoint=\"UNKNOWN\",method=\"GET\",status=\"404\",le=\"0.005\"} 1"
        ).to_vec()).unwrap()));
        assert!(body.contains(
            &String::from_utf8(
                web::Bytes::from(
                    "actix_web_prom_http_requests_total{endpoint=\"UNKNOWN\",method=\"GET\",status=\"404\"} 1"
                )
                .to_vec()
            )
            .unwrap()
        ));
    }
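
    // The two tests above show how the `endpoint` label is chosen. A matched request
    // reports its route pattern (`/resource/{id}`, not `/resource/123`). An unmatched
    // one reports the `mask_unmatched_patterns` value when that is set.
    // `middleware_basic_failure` below covers the unmasked case, where the raw path
    // (`/health_checkz`) becomes the label. Below is a minimal sketch of that decision.
    // The helper `endpoint_label` is hypothetical and not part of this crate's API.
    #[allow(dead_code)]
    mod endpoint_label_sketch {
        pub fn endpoint_label(matched_pattern: Option<&str>, raw_path: &str, mask: Option<&str>) -> String {
            match (matched_pattern, mask) {
                // matched route: the pattern keeps label cardinality bounded
                (Some(pattern), _) => pattern.to_string(),
                // unmatched with a mask: every unknown path shares one series
                (None, Some(mask)) => mask.to_string(),
                // unmatched without a mask: each distinct path gets its own series
                (None, None) => raw_path.to_string(),
            }
        }

        #[test]
        fn picks_pattern_mask_or_raw_path() {
            assert_eq!(endpoint_label(Some("/resource/{id}"), "/resource/123", None), "/resource/{id}");
            assert_eq!(endpoint_label(None, "/not-real", Some("UNKNOWN")), "UNKNOWN");
            assert_eq!(endpoint_label(None, "/health_checkz", None), "/health_checkz");
        }
    }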

    #[actix_web::test]
    async fn middleware_with_mixed_params_cardinality() {
        // keep the real value of the low-cardinality `cheap` param in the `endpoint` label,
        // but leave the high-cardinality `expensive` param as a `{expensive}` placeholder
        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .endpoint("/metrics")
            .build()
            .unwrap();

        let app = init_service(
            App::new().wrap(prometheus).service(
                web::resource("/resource/{cheap}/{expensive}")
                    .wrap_fn(|req, srv| {
                        req.extensions_mut().insert::<MetricsConfig>(MetricsConfig {
                            cardinality_keep_params: vec!["cheap".to_string()],
                        });
                        srv.call(req)
                    })
                    .to(|path: web::Path<(String, String)>| async {
                        let (cheap, _expensive) = path.into_inner();
                        if !["foo", "bar"].map(|x| x.to_string()).contains(&cheap) {
                            return HttpResponse::NotFound().finish();
                        }
                        HttpResponse::Ok().finish()
                    }),
            ),
        )
        .await;

        // first probe: a known `cheap` value is kept in the label, `expensive` is not
        let res = call_service(
            &app,
            TestRequest::with_uri("/resource/foo/12345").to_request(),
        )
        .await;
        assert!(res.status().is_success());
        assert_eq!(read_body(res).await, "");

        let res = call_and_read_body(&app, TestRequest::with_uri("/metrics").to_request()).await;
        let body = String::from_utf8(res.to_vec()).unwrap();
        println!("Body: {}", body);
        assert!(&body.contains(
            &String::from_utf8(web::Bytes::from(
                "actix_web_prom_http_requests_duration_seconds_bucket{endpoint=\"/resource/foo/{expensive}\",method=\"GET\",status=\"200\",le=\"0.005\"} 1"
        ).to_vec()).unwrap()));
        assert!(body.contains(
            &String::from_utf8(
                web::Bytes::from(
                    "actix_web_prom_http_requests_total{endpoint=\"/resource/foo/{expensive}\",method=\"GET\",status=\"200\"} 1"
                )
                .to_vec()
            )
            .unwrap()
        ));

        // second probe: on a 404 the unfilled route pattern is reported, so the
        // rejected `cheap` value ("invalid") does not become a label value
        let res = call_service(
            &app,
            TestRequest::with_uri("/resource/invalid/92945").to_request(),
        )
        .await;
        assert_eq!(res.status(), StatusCode::NOT_FOUND);
        assert_eq!(read_body(res).await, "");

        let res = call_and_read_body(&app, TestRequest::with_uri("/metrics").to_request()).await;
        let body = String::from_utf8(res.to_vec()).unwrap();
        println!("Body: {}", body);
        assert!(body.contains(
            &String::from_utf8(
                web::Bytes::from(
                    "actix_web_prom_http_requests_total{endpoint=\"/resource/{cheap}/{expensive}\",method=\"GET\",status=\"404\"} 1"
                )
                .to_vec()
            )
            .unwrap()
        ));
    }
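
    // In the test above, `MetricsConfig::cardinality_keep_params` keeps the real value
    // of `cheap` in the label (`/resource/foo/{expensive}`), while the 404 probe
    // reports the unfilled pattern. Below is a std-only sketch of the substitution step
    // only. `mixed_pattern` is a hypothetical helper. It assumes every param is a whole
    // `{name}` path segment with no regex constraint, and it does not model the 404
    // fallback.
    #[allow(dead_code)]
    mod mixed_pattern_sketch {
        use std::collections::HashMap;

        pub fn mixed_pattern(pattern: &str, params: &HashMap<&str, &str>, keep: &[&str]) -> String {
            pattern
                .split('/')
                .map(|segment| {
                    // `{name}` -> `name`; any other segment is literal text
                    match segment.strip_prefix('{').and_then(|s| s.strip_suffix('}')) {
                        Some(name) if keep.contains(&name) => {
                            params.get(name).copied().unwrap_or(segment).to_string()
                        }
                        _ => segment.to_string(),
                    }
                })
                .collect::<Vec<_>>()
                .join("/")
        }

        #[test]
        fn keeps_only_listed_params() {
            let params = HashMap::from([("cheap", "foo"), ("expensive", "12345")]);
            let pattern = "/resource/{cheap}/{expensive}";
            assert_eq!(mixed_pattern(pattern, &params, &["cheap"]), "/resource/foo/{expensive}");
            assert_eq!(mixed_pattern(pattern, &params, &[]), pattern);
        }
    }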

    #[actix_web::test]
    async fn middleware_metrics_exposed_with_conflicting_pattern() {
        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .endpoint("/metrics")
            .build()
            .unwrap();

        let app = init_service(
            App::new()
                .wrap(prometheus)
                .service(web::resource("/{path}").to(HttpResponse::Ok)),
        )
        .await;

        let res = call_service(&app, TestRequest::with_uri("/something").to_request()).await;
        assert!(res.status().is_success());
        assert_eq!(read_body(res).await, "");

        // the middleware serves `/metrics` itself, even though `/{path}` also matches it
        let res = call_and_read_body(&app, TestRequest::with_uri("/metrics").to_request()).await;
        let body = String::from_utf8(res.to_vec()).unwrap();
        assert!(&body.contains(
            &String::from_utf8(web::Bytes::from(
                "# HELP actix_web_prom_http_requests_duration_seconds HTTP request duration in seconds for all requests"
        ).to_vec()).unwrap()));
    }

    #[actix_web::test]
    async fn middleware_basic_failure() {
        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .endpoint("/prometheus")
            .build()
            .unwrap();

        let app = init_service(
            App::new()
                .wrap(prometheus)
                .service(web::resource("/health_check").to(HttpResponse::Ok)),
        )
        .await;

        // an unmatched path; without `mask_unmatched_patterns` it is used verbatim as the endpoint label
        call_service(&app, TestRequest::with_uri("/health_checkz").to_request()).await;
        let res = call_and_read_body(&app, TestRequest::with_uri("/prometheus").to_request()).await;
        assert!(String::from_utf8(res.to_vec()).unwrap().contains(
            &String::from_utf8(
                web::Bytes::from(
                    "# HELP actix_web_prom_http_requests_total Total number of HTTP requests
# TYPE actix_web_prom_http_requests_total counter
actix_web_prom_http_requests_total{endpoint=\"/health_checkz\",method=\"GET\",status=\"404\"} 1
"
                )
                .to_vec()
            )
            .unwrap()
        ));
    }

    #[actix_web::test]
    async fn middleware_custom_counter() {
        let counter_opts = Opts::new("counter", "some random counter").namespace("actix_web_prom");
        let counter = IntCounterVec::new(counter_opts, &["endpoint", "method", "status"]).unwrap();

        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .endpoint("/metrics")
            .build()
            .unwrap();

        prometheus
            .registry
            .register(Box::new(counter.clone()))
            .unwrap();

        let app = init_service(
            App::new()
                .wrap(prometheus)
                .service(web::resource("/health_check").to(HttpResponse::Ok)),
        )
        .await;

        // Verify that 'counter' does not appear in the output before we use it
        call_service(&app, TestRequest::with_uri("/health_check").to_request()).await;
        let res = call_and_read_body(&app, TestRequest::with_uri("/metrics").to_request()).await;
        assert!(!String::from_utf8(res.to_vec()).unwrap().contains(
            &String::from_utf8(
                web::Bytes::from(
                    "# HELP actix_web_prom_counter some random counter
# TYPE actix_web_prom_counter counter
actix_web_prom_counter{endpoint=\"endpoint\",method=\"method\",status=\"status\"} 1
"
                )
                .to_vec()
            )
            .unwrap()
        ));

        // Verify that 'counter' appears after we use it
        counter
            .with_label_values(&["endpoint", "method", "status"])
            .inc();
        counter
            .with_label_values(&["endpoint", "method", "status"])
            .inc();
        call_service(&app, TestRequest::with_uri("/metrics").to_request()).await;
        let res = call_and_read_body(&app, TestRequest::with_uri("/metrics").to_request()).await;
        assert!(String::from_utf8(res.to_vec()).unwrap().contains(
            &String::from_utf8(
                web::Bytes::from(
                    "# HELP actix_web_prom_counter some random counter
# TYPE actix_web_prom_counter counter
actix_web_prom_counter{endpoint=\"endpoint\",method=\"method\",status=\"status\"} 2
"
                )
                .to_vec()
            )
            .unwrap()
        ));
    }
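
    // The test above relies on a property of labeled ("vector") metrics. A counter
    // vec has no samples until a label-value combination is first used, and after that
    // each combination is its own series. Below is a std-only sketch of that behaviour.
    // The `CounterVec` here is hypothetical and only loosely mirrors
    // `prometheus::IntCounterVec::with_label_values`.
    #[allow(dead_code)]
    mod counter_vec_sketch {
        use std::collections::BTreeMap;

        pub struct CounterVec {
            name: String,
            label_names: Vec<String>,
            // one child counter per distinct tuple of label values
            children: BTreeMap<Vec<String>, u64>,
        }

        impl CounterVec {
            pub fn new(name: &str, label_names: &[&str]) -> Self {
                CounterVec {
                    name: name.to_string(),
                    label_names: label_names.iter().map(|n| n.to_string()).collect(),
                    children: BTreeMap::new(),
                }
            }

            /// Returns the child for these values, creating it at 0 on first use.
            pub fn with_label_values(&mut self, label_values: &[&str]) -> &mut u64 {
                assert_eq!(label_values.len(), self.label_names.len(), "wrong number of label values");
                let key = label_values.iter().map(|v| v.to_string()).collect();
                self.children.entry(key).or_insert(0)
            }

            /// Text-format sample lines; an unused vec renders nothing at all.
            pub fn render(&self) -> Vec<String> {
                self.children
                    .iter()
                    .map(|(values, count)| {
                        let labels: Vec<String> = self
                            .label_names
                            .iter()
                            .zip(values)
                            .map(|(name, value)| format!("{}=\"{}\"", name, value))
                            .collect();
                        format!("{}{{{}}} {}", self.name, labels.join(","), count)
                    })
                    .collect()
            }
        }

        #[test]
        fn no_samples_until_used() {
            let mut counter = CounterVec::new("actix_web_prom_counter", &["endpoint", "method", "status"]);
            assert!(counter.render().is_empty());
            *counter.with_label_values(&["endpoint", "method", "status"]) += 1;
            *counter.with_label_values(&["endpoint", "method", "status"]) += 1;
            assert_eq!(
                counter.render(),
                vec!["actix_web_prom_counter{endpoint=\"endpoint\",method=\"method\",status=\"status\"} 2"]
            );
        }
    }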

    #[actix_web::test]
    async fn middleware_none_endpoint() {
        // build PrometheusMetrics without an `endpoint`, so the middleware exposes no metrics route
        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .build()
            .unwrap();

        let app = init_service(App::new().wrap(prometheus.clone()).service(
            web::resource("/metrics").to(|| async { HttpResponse::Ok().body("not prometheus") }),
        ))
        .await;

        let response =
            call_and_read_body(&app, TestRequest::with_uri("/metrics").to_request()).await;

        // the app's own `/metrics` handler answered the request
        assert_eq!(
            String::from_utf8(response.to_vec()).unwrap(),
            "not prometheus"
        );

        // the request was still counted; read it straight from the registry
        let mut buffer = Vec::new();
        let encoder = TextEncoder::new();
        let metric_families = prometheus.registry.gather();
        encoder.encode(&metric_families, &mut buffer).unwrap();
        let output = String::from_utf8(buffer).unwrap();

        assert!(output.contains(
            "actix_web_prom_http_requests_total{endpoint=\"/metrics\",method=\"GET\",status=\"200\"} 1"
        ));
    }
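
    // This test and `middleware_metrics_exposed_with_conflicting_pattern` above show
    // when the middleware answers a request itself. It does so only when an `endpoint`
    // is configured and a GET request targets it (GET-only per the crate docs). Here no
    // endpoint is set, so `/metrics` reaches the app's handler and is counted like any
    // other request. Below is a minimal sketch of that decision. `serves_metrics` is
    // hypothetical, and exact path equality is assumed.
    #[allow(dead_code)]
    mod metrics_route_sketch {
        pub fn serves_metrics(endpoint: Option<&str>, path: &str, method: &str) -> bool {
            // no configured endpoint: never intercept, let the app handle the path
            endpoint.map_or(false, |e| e == path && method == "GET")
        }

        #[test]
        fn intercepts_only_configured_get() {
            assert!(serves_metrics(Some("/metrics"), "/metrics", "GET"));
            assert!(!serves_metrics(Some("/metrics"), "/metrics", "POST"));
            assert!(!serves_metrics(None, "/metrics", "GET"));
        }
    }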

    #[actix_web::test]
    async fn middleware_custom_registry_works() {
        // create a custom registry holding one pre-registered counter
        let registry = Registry::new();

        let counter_opts = Opts::new("test_counter", "test counter help");
        let counter = Counter::with_opts(counter_opts).unwrap();
        registry.register(Box::new(counter.clone())).unwrap();

        counter.inc_by(10_f64);

        // build PrometheusMetrics on top of that registry
        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .registry(registry)
            .endpoint("/metrics")
            .build()
            .unwrap();

        let app = init_service(
            App::new()
                .wrap(prometheus.clone())
                .service(web::resource("/test").to(|| async { HttpResponse::Ok().finish() })),
        )
        .await;

        // this is the first HTTP request, and requests are only recorded once their
        // response has been sent, so only the test counter (10) is checked here
        let response =
            call_and_read_body(&app, TestRequest::with_uri("/metrics").to_request()).await;
        let body = String::from_utf8(response.to_vec()).unwrap();

        let ten_test_counter =
            "# HELP test_counter test counter help\n# TYPE test_counter counter\ntest_counter 10\n";
        assert!(body.contains(ten_test_counter));

        // the first scrape has now been recorded, so the HTTP request counter
        // reads 1 and the test counter still reads 10
        let response =
            call_and_read_body(&app, TestRequest::with_uri("/metrics").to_request()).await;
        let response_string = String::from_utf8(response.to_vec()).unwrap();

        let one_http_counters = "# HELP actix_web_prom_http_requests_total Total number of HTTP requests\n# TYPE actix_web_prom_http_requests_total counter\nactix_web_prom_http_requests_total{endpoint=\"/metrics\",method=\"GET\",status=\"200\"} 1";

        assert!(response_string.contains(ten_test_counter));
        assert!(response_string.contains(one_http_counters));
    }

    #[actix_web::test]
    async fn middleware_const_labels() {
        let mut labels = HashMap::new();
        labels.insert("label1".to_string(), "value1".to_string());
        labels.insert("label2".to_string(), "value2".to_string());
        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .endpoint("/metrics")
            .const_labels(labels)
            .build()
            .unwrap();

        let app = init_service(
            App::new()
                .wrap(prometheus)
                .service(web::resource("/health_check").to(HttpResponse::Ok)),
        )
        .await;

        let res = call_service(&app, TestRequest::with_uri("/health_check").to_request()).await;
        assert!(res.status().is_success());
        assert_eq!(read_body(res).await, "");

        let res = call_and_read_body(&app, TestRequest::with_uri("/metrics").to_request()).await;
        let body = String::from_utf8(res.to_vec()).unwrap();
        assert!(&body.contains(
            &String::from_utf8(web::Bytes::from(
                "# HELP actix_web_prom_http_requests_duration_seconds HTTP request duration in seconds for all requests
# TYPE actix_web_prom_http_requests_duration_seconds histogram
actix_web_prom_http_requests_duration_seconds_bucket{endpoint=\"/health_check\",label1=\"value1\",label2=\"value2\",method=\"GET\",status=\"200\",le=\"0.005\"} 1
"
        ).to_vec()).unwrap()));
        assert!(body.contains(
            &String::from_utf8(
                web::Bytes::from(
                    "# HELP actix_web_prom_http_requests_total Total number of HTTP requests
# TYPE actix_web_prom_http_requests_total counter
actix_web_prom_http_requests_total{endpoint=\"/health_check\",label1=\"value1\",label2=\"value2\",method=\"GET\",status=\"200\"} 1
"
                )
                .to_vec()
            )
            .unwrap()
        ));
    }
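
    // In the expected lines above, the const labels are merged with the per-request
    // labels, and all of them appear sorted by name (`endpoint`, `label1`, `label2`,
    // `method`, `status`), with a histogram bucket's `le` last. Below is a minimal
    // std-only sketch that renders a sample line in that order. `sample_line` is
    // hypothetical. It skips the escaping of backslashes, double quotes and newlines in
    // label values that a real text encoder performs.
    #[allow(dead_code)]
    mod sample_line_sketch {
        use std::collections::BTreeMap;

        pub fn sample_line(name: &str, labels: &[(&str, &str)], le: Option<&str>, value: u64) -> String {
            // a BTreeMap iterates in key order, giving the alphabetical label order
            let sorted: BTreeMap<&str, &str> = labels.iter().copied().collect();
            let mut pairs: Vec<String> = sorted
                .iter()
                .map(|(k, v)| format!("{}=\"{}\"", k, v))
                .collect();
            // the bucket bound goes after the sorted labels
            if let Some(le) = le {
                pairs.push(format!("le=\"{}\"", le));
            }
            format!("{}{{{}}} {}", name, pairs.join(","), value)
        }

        #[test]
        fn const_labels_are_sorted_in() {
            let labels = [
                ("method", "GET"),
                ("status", "200"),
                ("label2", "value2"),
                ("endpoint", "/health_check"),
                ("label1", "value1"),
            ];
            assert_eq!(
                sample_line("actix_web_prom_http_requests_total", &labels, None, 1),
                "actix_web_prom_http_requests_total{endpoint=\"/health_check\",label1=\"value1\",label2=\"value2\",method=\"GET\",status=\"200\"} 1"
            );
        }
    }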

    #[actix_web::test]
    async fn middleware_metrics_configuration() {
        let metrics_config = ActixMetricsConfiguration::default()
            .http_requests_duration_seconds_name("my_http_request_duration")
            .http_requests_total_name("my_http_requests_total");

        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .endpoint("/metrics")
            .metrics_configuration(metrics_config)
            .build()
            .unwrap();

        let app = init_service(
            App::new()
                .wrap(prometheus)
                .service(web::resource("/health_check").to(HttpResponse::Ok)),
        )
        .await;

        let res = call_service(&app, TestRequest::with_uri("/health_check").to_request()).await;
        assert!(res.status().is_success());
        assert_eq!(read_body(res).await, "");

        let res = call_and_read_body(&app, TestRequest::with_uri("/metrics").to_request()).await;
        let body = String::from_utf8(res.to_vec()).unwrap();
        assert!(&body.contains(
            &String::from_utf8(web::Bytes::from(
                "# HELP actix_web_prom_my_http_request_duration HTTP request duration in seconds for all requests
# TYPE actix_web_prom_my_http_request_duration histogram
actix_web_prom_my_http_request_duration_bucket{endpoint=\"/health_check\",method=\"GET\",status=\"200\",le=\"0.005\"} 1
"
        ).to_vec()).unwrap()));
        assert!(body.contains(
            &String::from_utf8(
                web::Bytes::from(
                    "# HELP actix_web_prom_my_http_requests_total Total number of HTTP requests
# TYPE actix_web_prom_my_http_requests_total counter
actix_web_prom_my_http_requests_total{endpoint=\"/health_check\",method=\"GET\",status=\"200\"} 1
"
                )
                .to_vec()
            )
            .unwrap()
        ));
    }
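
    // The configured names above are prefixed with the namespace
    // (`actix_web_prom_my_http_request_duration`), and a histogram is exposed as several
    // series that share that base name. The test asserts `_bucket`; `_sum` and `_count`
    // are the other standard Prometheus histogram series. Below is a minimal sketch of
    // that naming with two hypothetical helpers. The empty-namespace rule, where the name
    // is used as-is, is assumed to match how the `prometheus` crate joins non-empty name
    // parts.
    #[allow(dead_code)]
    mod metric_name_sketch {
        pub fn full_name(namespace: &str, name: &str) -> String {
            if namespace.is_empty() {
                name.to_string()
            } else {
                format!("{}_{}", namespace, name)
            }
        }

        pub fn histogram_series(base: &str) -> [String; 3] {
            [format!("{}_bucket", base), format!("{}_sum", base), format!("{}_count", base)]
        }

        #[test]
        fn namespaced_histogram_names() {
            let base = full_name("actix_web_prom", "my_http_request_duration");
            assert_eq!(histogram_series(&base)[0], "actix_web_prom_my_http_request_duration_bucket");
        }
    }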

    #[test]
    fn compat_with_non_boxed_middleware() {
        let _app = App::new()
            .wrap(PrometheusMetricsBuilder::new("").build().unwrap())
            .wrap(actix_web::middleware::Logger::default())
            .route("", web::to(|| async { "" }));

        let _app = App::new()
            .wrap(actix_web::middleware::Logger::default())
            .wrap(PrometheusMetricsBuilder::new("").build().unwrap())
            .route("", web::to(|| async { "" }));

        let _scope = Scope::new("")
            .wrap(PrometheusMetricsBuilder::new("").build().unwrap())
            .route("", web::to(|| async { "" }));

        let _resource = Resource::new("")
            .wrap(PrometheusMetricsBuilder::new("").build().unwrap())
            .route(web::to(|| async { "" }));
    }

    #[actix_web::test]
    async fn middleware_excludes() {
        let prometheus = PrometheusMetricsBuilder::new("actix_web_prom")
            .endpoint("/metrics")
            .exclude("/ping")
            .exclude_regex("/readyz/.*")
            .exclude_status(StatusCode::NOT_FOUND)
            .build()
            .unwrap();

        let app = init_service(
            App::new()
                .wrap(prometheus)
                .service(web::resource("/health_check").to(HttpResponse::Ok))
                .service(web::resource("/ping").to(HttpResponse::Ok))
                .service(web::resource("/readyz/{subsystem}").to(HttpResponse::Ok)),
        )
        .await;

        let res = call_service(&app, TestRequest::with_uri("/health_check").to_request()).await;
        assert!(res.status().is_success());
        assert_eq!(read_body(res).await, "");

        let res = call_service(&app, TestRequest::with_uri("/ping").to_request()).await;
        assert!(res.status().is_success());
        assert_eq!(read_body(res).await, "");

        let res = call_service(&app, TestRequest::with_uri("/readyz/database").to_request()).await;
        assert!(res.status().is_success());
        assert_eq!(read_body(res).await, "");

        let res = call_service(&app, TestRequest::with_uri("/notfound").to_request()).await;
        assert!(res.status().is_client_error());
        assert_eq!(read_body(res).await, "");

        let res = call_service(&app, TestRequest::with_uri("/metrics").to_request()).await;
        assert_eq!(
            res.headers().get(CONTENT_TYPE).unwrap(),
            "text/plain; version=0.0.4; charset=utf-8"
        );
        let body = String::from_utf8(read_body(res).await.to_vec()).unwrap();
        assert!(&body.contains(
            &String::from_utf8(
                web::Bytes::from(
                    "# HELP actix_web_prom_http_requests_total Total number of HTTP requests
# TYPE actix_web_prom_http_requests_total counter
actix_web_prom_http_requests_total{endpoint=\"/health_check\",method=\"GET\",status=\"200\"} 1
"
                )
                .to_vec()
            )
            .unwrap()
        ));

        assert!(!&body.contains("endpoint=\"/ping\""));
        assert!(!&body.contains("endpoint=\"/readyz"));
        assert!(!body.contains("endpoint=\"/notfound"));
    }
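
    // `middleware_excludes` combines three filters: an exact path (`exclude`), a
    // pattern (`exclude_regex`), and a response status (`exclude_status`). A request
    // that matches any of them is not recorded. Below is a std-only sketch of that check.
    // A plain prefix match stands in for the regex, because the `regex` crate is not
    // used here. `Exclusions` is a hypothetical type, not this crate's API.
    #[allow(dead_code)]
    mod exclusion_sketch {
        pub struct Exclusions<'a> {
            pub paths: &'a [&'a str],
            // stand-in for `exclude_regex` patterns
            pub prefixes: &'a [&'a str],
            pub statuses: &'a [u16],
        }

        impl Exclusions<'_> {
            pub fn is_excluded(&self, path: &str, status: u16) -> bool {
                self.paths.contains(&path)
                    || self.prefixes.iter().any(|prefix| path.starts_with(*prefix))
                    || self.statuses.contains(&status)
            }
        }

        #[test]
        fn mirrors_middleware_excludes() {
            let ex = Exclusions { paths: &["/ping"], prefixes: &["/readyz/"], statuses: &[404] };
            assert!(!ex.is_excluded("/health_check", 200));
            assert!(ex.is_excluded("/ping", 200));
            assert!(ex.is_excluded("/readyz/database", 200));
            assert!(ex.is_excluded("/notfound", 404));
        }
    }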
}