dcl_http_prom_metrics/
lib.rs

1use std::{sync::Arc, time::Instant};
2
3#[cfg(feature = "actix")]
4use actix_web::{
5    body::MessageBody,
6    dev::{Service, ServiceRequest, ServiceResponse, Transform},
7    http::{Method, StatusCode},
8    web::Data,
9    Error,
10};
11#[cfg(feature = "actix")]
12use actix_web_lab::middleware::{from_fn, Next};
13
14use prometheus::{
15    Encoder, HistogramOpts, HistogramVec, IntCounterVec, Opts, Registry, TextEncoder,
16};
17
/// Default upper bounds, in seconds, of the `http_request_duration_seconds`
/// histogram buckets: from 5 ms up to one minute.
const DEFAULT_BUCKETS: [f64; 14] = [
    // sub-second latencies
    0.005, 0.01, 0.025, 0.05, 0.1, 0.25, 0.5,
    // slow requests
    1.0, 2.5, 5.0, 10.0, 20.0, 30.0, 60.0,
];
21
22pub struct HttpMetricsCollectorBuilder {
23    registry: Registry,
24    endpoint: Option<String>,
25    buckets: Vec<f64>,
26}
27
28impl HttpMetricsCollectorBuilder {
29    pub fn new() -> Self {
30        Self {
31            endpoint: None,
32            buckets: DEFAULT_BUCKETS.to_vec(),
33            registry: Registry::new(),
34        }
35    }
36
37    pub fn registry(mut self, registry: Registry) -> Self {
38        self.registry = registry;
39        self
40    }
41
42    pub fn buckets(mut self, buckets: &[f64]) -> Self {
43        self.buckets = buckets.to_vec();
44        self
45    }
46
47    pub fn endpoint(mut self, endpoint: &str) -> Self {
48        self.endpoint = Some(endpoint.to_string());
49        self
50    }
51
52    pub fn build(self) -> HttpMetricsCollector {
53        let http_requests_total_opts =
54            Opts::new("http_requests_total", "Total number of HTTP requests");
55
56        let label_names = ["method", "handler", "code"];
57
58        let http_requests_total =
59            IntCounterVec::new(http_requests_total_opts, &label_names).unwrap();
60
61        let http_requests_duration_seconds_opts = HistogramOpts::new(
62            "http_request_duration_seconds",
63            "HTTP request duration in seconds for all requests",
64        )
65        .buckets(self.buckets);
66
67        let http_requests_duration_seconds =
68            HistogramVec::new(http_requests_duration_seconds_opts, &label_names).unwrap();
69
70        self.registry
71            .register(Box::new(http_requests_total.clone()))
72            .unwrap();
73        self.registry
74            .register(Box::new(http_requests_duration_seconds.clone()))
75            .unwrap();
76
77        HttpMetricsCollector {
78            registry: self.registry,
79            http_request_duration_seconds: http_requests_duration_seconds,
80            http_requests_total,
81            endpoint: self.endpoint.unwrap_or("/metrics".to_string()),
82        }
83    }
84}
85
86impl Default for HttpMetricsCollectorBuilder {
87    fn default() -> Self {
88        Self::new()
89    }
90}
91
92pub struct HttpMetricsCollector {
93    registry: Registry,
94    http_requests_total: IntCounterVec,
95    http_request_duration_seconds: HistogramVec,
96    endpoint: String,
97}
98
99impl HttpMetricsCollector {
100    pub fn update_metrics(
101        &self,
102        method: &Method,
103        handler: &str,
104        code: StatusCode,
105        timestamp: Instant,
106    ) {
107        let label_values = [method.as_str(), handler, code.as_str()];
108
109        let elapsed = timestamp.elapsed();
110        let duration =
111            (elapsed.as_secs() as f64) + f64::from(elapsed.subsec_nanos()) / 1_000_000_000_f64;
112
113        self.http_request_duration_seconds
114            .with_label_values(&label_values)
115            .observe(duration);
116
117        self.http_requests_total
118            .with_label_values(&label_values)
119            .inc();
120    }
121
122    pub fn collect(&self) -> Result<String, String> {
123        let encoder = TextEncoder::new();
124        let mut buffer = vec![];
125
126        if let Err(err) = encoder.encode(&self.registry.gather(), &mut buffer) {
127            return Err(err.to_string());
128        }
129
130        match String::from_utf8(buffer) {
131            Ok(metrics) => Ok(metrics),
132            Err(_) => Err("Metrics corrupted".to_string()),
133        }
134    }
135
136    pub fn is_endpoint(&self, path: &str, method: &Method) -> bool {
137        path == self.endpoint && method == Method::GET
138    }
139}
140
/// Drop guard that reports one request to the collector when it goes out of
/// scope.
///
/// Reporting from `Drop` means the metrics are recorded on every exit path of
/// the owning future, including when that future is dropped before it
/// completes. In that case `code` still holds whatever was last assigned; the
/// actix middleware initialises it to `200 OK`.
/// NOTE(review): confirm that counting cancelled requests as 200 is intended.
///
/// Gated on the `actix` feature because it uses the actix `Method` and
/// `StatusCode` types and is only used by the actix middleware.
#[cfg(feature = "actix")]
struct MetricLog {
    // Collector the request is reported to.
    collector: Arc<HttpMetricsCollector>,
    // Value of the `handler` label.
    handler: String,
    // Value of the `method` label.
    method: Method,
    // Value of the `code` label; updated once the response status is known.
    code: StatusCode,
    // Instant the request started; the duration is measured from here.
    timestamp: Instant,
}

#[cfg(feature = "actix")]
impl Drop for MetricLog {
    fn drop(&mut self) {
        self.collector
            .update_metrics(&self.method, &self.handler, self.code, self.timestamp);
    }
}
155
/// Actix middleware that records HTTP metrics for every request and serves the
/// collected metrics on the collector's scrape endpoint.
///
/// The collector must be available as `Data<HttpMetricsCollector>` in the app
/// data. Each request is labelled with:
/// - its method,
/// - its matched route pattern, or `*` when the resource map does not know the
///   path (e.g. 404s; presumably to avoid one label value per unknown URL),
/// - its response status code.
///
/// A `GET` to the scrape endpoint is answered directly with the encoded
/// metrics. If encoding fails, the request fails with `500 Internal Server
/// Error` instead of panicking.
///
/// # Panics
///
/// The returned middleware panics on a request if no
/// `Data<HttpMetricsCollector>` is registered in the app data.
#[cfg(feature = "actix")]
pub fn metrics<S, B>() -> impl Transform<
    S,
    ServiceRequest,
    Response = ServiceResponse<impl MessageBody>,
    Error = Error,
    InitError = (),
>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    B: MessageBody + 'static,
{
    from_fn(move |req: ServiceRequest, next: Next<B>| {
        // Start timing before any other work so the whole middleware is measured.
        let timestamp = Instant::now();

        let method = req.method().clone();
        let collector: Arc<HttpMetricsCollector> = req
            .app_data::<Data<HttpMetricsCollector>>()
            .expect(
                "metrics middleware requires `Data<HttpMetricsCollector>` to be registered \
                 with `App::app_data`",
            )
            .clone()
            .into_inner();

        let handler = {
            let path = req
                .match_pattern()
                .unwrap_or_else(|| req.path().to_string());

            if req.resource_map().has_resource(&path) {
                path
            } else {
                "*".to_string() // 404
            }
        };

        async move {
            // Records the metrics when dropped, i.e. when this future finishes
            // or is cancelled.
            let mut log = MetricLog {
                collector: Arc::clone(&collector),
                method,
                timestamp,
                code: StatusCode::OK,
                handler,
            };

            if collector.is_endpoint(req.path(), req.method()) {
                match collector.collect() {
                    Ok(body) => Ok(req.into_response(body).map_into_right_body()),
                    Err(err) => {
                        log.code = StatusCode::INTERNAL_SERVER_ERROR;
                        Err(actix_web::error::ErrorInternalServerError(err))
                    }
                }
            } else {
                match next.call(req).await {
                    Ok(res) => {
                        log.code = res.status();
                        Ok(res.map_into_left_body())
                    }
                    Err(err) => {
                        log.code = err.error_response().status();
                        Err(err)
                    }
                }
            }
        }
    })
}