ipfrs_interface/
metrics_middleware.rs

1//! Prometheus metrics middleware for Axum
2//!
3//! This middleware automatically tracks HTTP request metrics including:
4//! - Request count by endpoint and status
5//! - Request duration
6//! - Request/response sizes
7//! - Active connections
8
9use axum::{
10    body::Body,
11    extract::MatchedPath,
12    http::{Request, Response, StatusCode},
13    middleware::Next,
14    response::IntoResponse,
15};
16use http_body_util::BodyExt;
17use std::time::Instant;
18
19use crate::metrics::{
20    record_http_request, record_http_request_size, record_http_response_size,
21    HTTP_CONNECTIONS_ACTIVE, HTTP_REQUEST_DURATION_SECONDS,
22};
23
24/// Middleware that records Prometheus metrics for HTTP requests
25pub async fn metrics_middleware(
26    req: Request<Body>,
27    next: Next,
28) -> Result<Response<Body>, StatusCode> {
29    let start = Instant::now();
30
31    // Extract path and method
32    let path = req
33        .extensions()
34        .get::<MatchedPath>()
35        .map(|p| p.as_str().to_string())
36        .unwrap_or_else(|| req.uri().path().to_string());
37    let method = req.method().to_string();
38
39    // Get request size
40    let (parts, body) = req.into_parts();
41    let body_bytes = body
42        .collect()
43        .await
44        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
45        .to_bytes();
46    let request_size = body_bytes.len();
47
48    // Reconstruct request
49    let req = Request::from_parts(parts, Body::from(body_bytes));
50
51    // Increment active connections
52    HTTP_CONNECTIONS_ACTIVE.with_label_values(&[&path]).inc();
53
54    // Call the next middleware/handler
55    let response = next.run(req).await;
56
57    // Get response status and size
58    let status = response.status();
59    let (parts, body) = response.into_parts();
60    let body_bytes = body
61        .collect()
62        .await
63        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
64        .to_bytes();
65    let response_size = body_bytes.len();
66
67    // Reconstruct response
68    let response = Response::from_parts(parts, Body::from(body_bytes));
69
70    // Record metrics
71    let duration = start.elapsed().as_secs_f64();
72
73    HTTP_REQUEST_DURATION_SECONDS
74        .with_label_values(&[&path, &method])
75        .observe(duration);
76
77    record_http_request(&path, &method, status.as_u16());
78    record_http_request_size(&path, &method, request_size);
79    record_http_response_size(&path, &method, response_size);
80
81    HTTP_CONNECTIONS_ACTIVE.with_label_values(&[&path]).dec();
82
83    Ok(response)
84}
85
86/// Lightweight metrics middleware that doesn't buffer bodies
87/// Use this for streaming endpoints to avoid memory issues
88pub async fn metrics_middleware_streaming(req: Request<Body>, next: Next) -> impl IntoResponse {
89    let start = Instant::now();
90
91    // Extract path and method
92    let path = req
93        .extensions()
94        .get::<MatchedPath>()
95        .map(|p| p.as_str().to_string())
96        .unwrap_or_else(|| req.uri().path().to_string());
97    let method = req.method().to_string();
98
99    // Increment active connections
100    HTTP_CONNECTIONS_ACTIVE.with_label_values(&[&path]).inc();
101
102    // Call the next middleware/handler
103    let response = next.run(req).await;
104
105    // Get response status
106    let status = response.status();
107
108    // Record metrics (without body sizes for streaming)
109    let duration = start.elapsed().as_secs_f64();
110
111    HTTP_REQUEST_DURATION_SECONDS
112        .with_label_values(&[&path, &method])
113        .observe(duration);
114
115    record_http_request(&path, &method, status.as_u16());
116
117    HTTP_CONNECTIONS_ACTIVE.with_label_values(&[&path]).dec();
118
119    response
120}
121
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        routing::{get, post},
        Router,
    };
    use tower::ServiceExt;

    async fn test_handler() -> &'static str {
        "Hello, World!"
    }

    /// Echoes the request body back, so tests can check that buffering in the
    /// middleware preserves both the request and the response payloads.
    async fn echo_handler(body: String) -> String {
        body
    }

    #[tokio::test]
    async fn test_metrics_middleware_streaming() {
        let app = Router::new()
            .route("/test", get(test_handler))
            .layer(axum::middleware::from_fn(metrics_middleware_streaming));

        let request = Request::builder().uri("/test").body(Body::empty()).unwrap();

        let response = app.oneshot(request).await.unwrap();

        assert_eq!(response.status(), StatusCode::OK);

        // Verify metrics were recorded
        let metrics = crate::metrics::encode_metrics().unwrap();
        assert!(metrics.contains("ipfrs_http_requests_total"));
        assert!(metrics.contains("ipfrs_http_request_duration_seconds"));
    }

    #[tokio::test]
    async fn test_metrics_middleware_preserves_bodies() {
        // A route path unique to this test, so the per-path gauge below is not
        // touched by other tests sharing the global registry.
        const ROUTE: &str = "/__metrics_middleware_echo";

        let app = Router::new()
            .route(ROUTE, post(echo_handler))
            .layer(axum::middleware::from_fn(metrics_middleware));

        let request = Request::builder()
            .method("POST")
            .uri(ROUTE)
            .body(Body::from("ping"))
            .unwrap();

        let response = app.oneshot(request).await.unwrap();
        assert_eq!(response.status(), StatusCode::OK);

        // The request body must reach the handler intact, and the response
        // body must survive being buffered and rebuilt.
        let body = response.into_body().collect().await.unwrap().to_bytes();
        assert_eq!(&body[..], b"ping");

        // Once the request has completed, the in-flight gauge must be back to 0.
        let active = HTTP_CONNECTIONS_ACTIVE.with_label_values(&[ROUTE]).get();
        assert_eq!(active as f64, 0.0);

        let metrics = crate::metrics::encode_metrics().unwrap();
        assert!(metrics.contains("ipfrs_http_requests_total"));
        assert!(metrics.contains("ipfrs_http_request_duration_seconds"));
    }
}
149}