ipfrs_interface/
metrics_middleware.rs1use axum::{
10 body::Body,
11 extract::MatchedPath,
12 http::{Request, Response, StatusCode},
13 middleware::Next,
14 response::IntoResponse,
15};
16use http_body_util::BodyExt;
17use std::time::Instant;
18
19use crate::metrics::{
20 record_http_request, record_http_request_size, record_http_response_size,
21 HTTP_CONNECTIONS_ACTIVE, HTTP_REQUEST_DURATION_SECONDS,
22};
23
24pub async fn metrics_middleware(
26 req: Request<Body>,
27 next: Next,
28) -> Result<Response<Body>, StatusCode> {
29 let start = Instant::now();
30
31 let path = req
33 .extensions()
34 .get::<MatchedPath>()
35 .map(|p| p.as_str().to_string())
36 .unwrap_or_else(|| req.uri().path().to_string());
37 let method = req.method().to_string();
38
39 let (parts, body) = req.into_parts();
41 let body_bytes = body
42 .collect()
43 .await
44 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
45 .to_bytes();
46 let request_size = body_bytes.len();
47
48 let req = Request::from_parts(parts, Body::from(body_bytes));
50
51 HTTP_CONNECTIONS_ACTIVE.with_label_values(&[&path]).inc();
53
54 let response = next.run(req).await;
56
57 let status = response.status();
59 let (parts, body) = response.into_parts();
60 let body_bytes = body
61 .collect()
62 .await
63 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
64 .to_bytes();
65 let response_size = body_bytes.len();
66
67 let response = Response::from_parts(parts, Body::from(body_bytes));
69
70 let duration = start.elapsed().as_secs_f64();
72
73 HTTP_REQUEST_DURATION_SECONDS
74 .with_label_values(&[&path, &method])
75 .observe(duration);
76
77 record_http_request(&path, &method, status.as_u16());
78 record_http_request_size(&path, &method, request_size);
79 record_http_response_size(&path, &method, response_size);
80
81 HTTP_CONNECTIONS_ACTIVE.with_label_values(&[&path]).dec();
82
83 Ok(response)
84}
85
86pub async fn metrics_middleware_streaming(req: Request<Body>, next: Next) -> impl IntoResponse {
89 let start = Instant::now();
90
91 let path = req
93 .extensions()
94 .get::<MatchedPath>()
95 .map(|p| p.as_str().to_string())
96 .unwrap_or_else(|| req.uri().path().to_string());
97 let method = req.method().to_string();
98
99 HTTP_CONNECTIONS_ACTIVE.with_label_values(&[&path]).inc();
101
102 let response = next.run(req).await;
104
105 let status = response.status();
107
108 let duration = start.elapsed().as_secs_f64();
110
111 HTTP_REQUEST_DURATION_SECONDS
112 .with_label_values(&[&path, &method])
113 .observe(duration);
114
115 record_http_request(&path, &method, status.as_u16());
116
117 HTTP_CONNECTIONS_ACTIVE.with_label_values(&[&path]).dec();
118
119 response
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{routing::get, Router};
    use tower::ServiceExt;

    /// Trivial handler with a known, fixed response body.
    async fn test_handler() -> &'static str {
        "Hello, World!"
    }

    /// The buffering middleware must hand the client the exact response body the handler
    /// produced, even though it collects and rebuilds that body, and must record the core
    /// request metrics.
    #[tokio::test]
    async fn test_metrics_middleware_preserves_body() {
        let app = Router::new()
            .route("/test", get(test_handler))
            .layer(axum::middleware::from_fn(metrics_middleware));

        let request = Request::builder().uri("/test").body(Body::empty()).unwrap();

        let response = app.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"Hello, World!");

        let metrics = crate::metrics::encode_metrics().unwrap();
        assert!(metrics.contains("ipfrs_http_requests_total"));
        assert!(metrics.contains("ipfrs_http_request_duration_seconds"));
    }

    /// The streaming middleware passes the response through and records the core metrics.
    #[tokio::test]
    async fn test_metrics_middleware_streaming() {
        let app = Router::new()
            .route("/test", get(test_handler))
            .layer(axum::middleware::from_fn(metrics_middleware_streaming));

        let request = Request::builder().uri("/test").body(Body::empty()).unwrap();

        let response = app.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        let metrics = crate::metrics::encode_metrics().unwrap();
        assert!(metrics.contains("ipfrs_http_requests_total"));
        assert!(metrics.contains("ipfrs_http_request_duration_seconds"));
    }
}