1use std::{sync::Arc, time::Instant};
2
3#[cfg(feature = "actix")]
4use actix_web::{
5 body::MessageBody,
6 dev::{Service, ServiceRequest, ServiceResponse, Transform},
7 http::{Method, StatusCode},
8 web::Data,
9 Error,
10};
11#[cfg(feature = "actix")]
12use actix_web_lab::middleware::{from_fn, Next};
13
14use prometheus::{
15 Encoder, HistogramOpts, HistogramVec, IntCounterVec, Opts, Registry, TextEncoder,
16};
17
/// Upper bounds, in seconds, of the request-duration histogram buckets used
/// when the builder is not given custom ones.
///
/// The bounds span 5 ms to 60 s.
const DEFAULT_BUCKETS: [f64; 14] = [
    // sub-100 ms
    0.005, 0.01, 0.025, 0.05,
    // sub-second
    0.1, 0.25, 0.5,
    // seconds
    1.0, 2.5, 5.0, 10.0,
    // slow requests
    20.0, 30.0, 60.0,
];
21
22pub struct HttpMetricsCollectorBuilder {
23 registry: Registry,
24 endpoint: Option<String>,
25 buckets: Vec<f64>,
26}
27
28impl HttpMetricsCollectorBuilder {
29 pub fn new() -> Self {
30 Self {
31 endpoint: None,
32 buckets: DEFAULT_BUCKETS.to_vec(),
33 registry: Registry::new(),
34 }
35 }
36
37 pub fn registry(mut self, registry: Registry) -> Self {
38 self.registry = registry;
39 self
40 }
41
42 pub fn buckets(mut self, buckets: &[f64]) -> Self {
43 self.buckets = buckets.to_vec();
44 self
45 }
46
47 pub fn endpoint(mut self, endpoint: &str) -> Self {
48 self.endpoint = Some(endpoint.to_string());
49 self
50 }
51
52 pub fn build(self) -> HttpMetricsCollector {
53 let http_requests_total_opts =
54 Opts::new("http_requests_total", "Total number of HTTP requests");
55
56 let label_names = ["method", "handler", "code"];
57
58 let http_requests_total =
59 IntCounterVec::new(http_requests_total_opts, &label_names).unwrap();
60
61 let http_requests_duration_seconds_opts = HistogramOpts::new(
62 "http_request_duration_seconds",
63 "HTTP request duration in seconds for all requests",
64 )
65 .buckets(self.buckets);
66
67 let http_requests_duration_seconds =
68 HistogramVec::new(http_requests_duration_seconds_opts, &label_names).unwrap();
69
70 self.registry
71 .register(Box::new(http_requests_total.clone()))
72 .unwrap();
73 self.registry
74 .register(Box::new(http_requests_duration_seconds.clone()))
75 .unwrap();
76
77 HttpMetricsCollector {
78 registry: self.registry,
79 http_request_duration_seconds: http_requests_duration_seconds,
80 http_requests_total,
81 endpoint: self.endpoint.unwrap_or("/metrics".to_string()),
82 }
83 }
84}
85
86impl Default for HttpMetricsCollectorBuilder {
87 fn default() -> Self {
88 Self::new()
89 }
90}
91
92pub struct HttpMetricsCollector {
93 registry: Registry,
94 http_requests_total: IntCounterVec,
95 http_request_duration_seconds: HistogramVec,
96 endpoint: String,
97}
98
99impl HttpMetricsCollector {
100 pub fn update_metrics(
101 &self,
102 method: &Method,
103 handler: &str,
104 code: StatusCode,
105 timestamp: Instant,
106 ) {
107 let label_values = [method.as_str(), handler, code.as_str()];
108
109 let elapsed = timestamp.elapsed();
110 let duration =
111 (elapsed.as_secs() as f64) + f64::from(elapsed.subsec_nanos()) / 1_000_000_000_f64;
112
113 self.http_request_duration_seconds
114 .with_label_values(&label_values)
115 .observe(duration);
116
117 self.http_requests_total
118 .with_label_values(&label_values)
119 .inc();
120 }
121
122 pub fn collect(&self) -> Result<String, String> {
123 let encoder = TextEncoder::new();
124 let mut buffer = vec![];
125
126 if let Err(err) = encoder.encode(&self.registry.gather(), &mut buffer) {
127 return Err(err.to_string());
128 }
129
130 match String::from_utf8(buffer) {
131 Ok(metrics) => Ok(metrics),
132 Err(_) => Err("Metrics corrupted".to_string()),
133 }
134 }
135
136 pub fn is_endpoint(&self, path: &str, method: &Method) -> bool {
137 path == self.endpoint && method == Method::GET
138 }
139}
140
/// One in-flight request whose metrics are reported when the value is dropped.
///
/// Reporting happens in `Drop`. A future that owns a `MetricLog` therefore
/// still records the request if it is dropped before completing. The
/// recorded `code` is whatever value the field holds at drop time.
// Gated like its only user (the actix middleware): `Method` and `StatusCode`
// are only imported with the `actix` feature.
#[cfg(feature = "actix")]
struct MetricLog {
    /// Collector that receives the measurement.
    collector: Arc<HttpMetricsCollector>,
    /// Value of the `handler` label.
    handler: String,
    /// Request method, used as the `method` label.
    method: Method,
    /// Response status, used as the `code` label.
    code: StatusCode,
    /// Instant the request started; the duration is measured from here.
    timestamp: Instant,
}

#[cfg(feature = "actix")]
impl Drop for MetricLog {
    /// Reports this request to the collector.
    fn drop(&mut self) {
        self.collector
            .update_metrics(&self.method, &self.handler, self.code, self.timestamp);
    }
}
155
/// Actix middleware that records per-request metrics and serves them.
///
/// Every request is recorded in the [`HttpMetricsCollector`] registered as
/// `Data<HttpMetricsCollector>` app data. Its labels are the request method,
/// the route pattern (or `"*"` when the path matches no registered resource)
/// and the response status. A `GET` on the collector's endpoint is answered
/// directly with the encoded metrics. If encoding fails, the response is
/// `500 Internal Server Error` with the error message as body.
///
/// # Panics
///
/// The returned middleware panics if no `Data<HttpMetricsCollector>` is
/// registered as app data.
#[cfg(feature = "actix")]
pub fn metrics<S, B>() -> impl Transform<
    S,
    ServiceRequest,
    Response = ServiceResponse<impl MessageBody>,
    Error = Error,
    InitError = (),
>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    B: MessageBody + 'static,
{
    from_fn(move |req: ServiceRequest, next: Next<B>| {
        let timestamp = Instant::now();

        let method = req.method().clone();
        let collector = req
            .app_data::<Data<HttpMetricsCollector>>()
            .expect(
                "metrics middleware requires `Data<HttpMetricsCollector>` to be registered as app data",
            )
            .clone();

        // Paths that match no registered resource share the "*" label, which
        // keeps the number of distinct `handler` label values bounded.
        let handler = {
            let path = req
                .match_pattern()
                .unwrap_or_else(|| req.path().to_string());

            if req.resource_map().has_resource(&path) {
                path
            } else {
                "*".to_string()
            }
        };

        async move {
            // Reported on drop, so the request is recorded even if this future
            // is dropped before it completes; `code` is updated below.
            let mut log = MetricLog {
                collector: collector.clone().into_inner(),
                method,
                timestamp,
                code: StatusCode::OK,
                handler,
            };

            if collector.is_endpoint(req.path(), req.method()) {
                // Encoding failures become a 500 response instead of a panic.
                let response = match collector.collect() {
                    Ok(body) => req.into_response(body).map_into_boxed_body(),
                    Err(err) => {
                        log.code = StatusCode::INTERNAL_SERVER_ERROR;
                        req.into_response(
                            actix_web::HttpResponse::InternalServerError().body(err),
                        )
                    }
                };
                Ok(response.map_into_right_body())
            } else {
                match next.call(req).await {
                    Ok(res) => {
                        log.code = res.status();
                        Ok(res.map_into_left_body())
                    }
                    Err(err) => {
                        log.code = err.error_response().status();
                        Err(err)
                    }
                }
            }
        }
    })
}