mockforge_http/
metrics_middleware.rs

1//! HTTP metrics collection middleware
2//!
3//! Collects Prometheus metrics for all HTTP requests including:
4//! - Request counts by method and status
5//! - Request duration histograms
6//! - In-flight request tracking
7//! - Error counts
8
9use axum::{
10    extract::{MatchedPath, Request},
11    middleware::Next,
12    response::Response,
13};
14use mockforge_observability::get_global_registry;
15use std::time::Instant;
16use tracing::debug;
17
18/// Metrics collection middleware for HTTP requests
19///
20/// This middleware should be applied to all HTTP routes to collect comprehensive
21/// metrics for Prometheus. It tracks:
22/// - Total request counts (by method and status code)
23/// - Request duration (as histograms for percentile calculations)
24/// - In-flight requests
25/// - Error rates
26pub async fn collect_http_metrics(
27    matched_path: Option<MatchedPath>,
28    req: Request,
29    next: Next,
30) -> Response {
31    let start_time = Instant::now();
32    let method = req.method().to_string();
33    let uri_path = req.uri().path().to_string();
34    let path = matched_path.as_ref().map(|mp| mp.as_str().to_string()).unwrap_or(uri_path);
35
36    // Get metrics registry
37    let registry = get_global_registry();
38
39    // Track in-flight requests
40    registry.increment_in_flight("http");
41    debug!(
42        method = %method,
43        path = %path,
44        "Starting HTTP request metrics collection"
45    );
46
47    // Process the request
48    let response = next.run(req).await;
49
50    // Decrement in-flight requests
51    registry.decrement_in_flight("http");
52
53    // Calculate metrics
54    let duration = start_time.elapsed();
55    let duration_seconds = duration.as_secs_f64();
56    let status_code = response.status().as_u16();
57
58    // Record metrics with path information
59    registry.record_http_request_with_path(&path, &method, status_code, duration_seconds);
60
61    // Record errors separately
62    if status_code >= 400 {
63        let error_type = if status_code >= 500 {
64            "server_error"
65        } else {
66            "client_error"
67        };
68        registry.record_error("http", error_type);
69    }
70
71    debug!(
72        method = %method,
73        path = %path,
74        status = status_code,
75        duration_ms = duration.as_millis(),
76        "HTTP request metrics recorded (including path-based metrics)"
77    );
78
79    response
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        response::IntoResponse,
        routing::get,
        Router,
    };
    use tower::ServiceExt;

    /// Handler that always answers `200 OK`.
    async fn test_handler() -> impl IntoResponse {
        (StatusCode::OK, "test response")
    }

    /// Sends a single body-less request for `uri` through `app` and returns the
    /// status code of the response.
    async fn status_for(app: Router, uri: &str) -> StatusCode {
        let request = Request::builder().uri(uri).body(Body::empty()).unwrap();
        app.oneshot(request).await.unwrap().status()
    }

    /// A successful response passes through the metrics layer untouched.
    #[tokio::test]
    async fn test_metrics_middleware_records_success() {
        let app = Router::new()
            .route("/test", get(test_handler))
            .layer(middleware::from_fn(collect_http_metrics));

        let status = status_for(app, "/test").await;

        assert_eq!(status, StatusCode::OK);
    }

    /// A server error response passes through the metrics layer untouched.
    #[tokio::test]
    async fn test_metrics_middleware_records_errors() {
        async fn error_handler() -> impl IntoResponse {
            (StatusCode::INTERNAL_SERVER_ERROR, "error")
        }

        let app = Router::new()
            .route("/error", get(error_handler))
            .layer(middleware::from_fn(collect_http_metrics));

        let status = status_for(app, "/error").await;

        assert_eq!(status, StatusCode::INTERNAL_SERVER_ERROR);
    }
}