mockforge_http/
metrics_middleware.rs

//! HTTP metrics collection middleware
//!
//! Collects Prometheus metrics for every HTTP request that passes through
//! [`collect_http_metrics`], including:
//! - Request counts by method and status
//! - Request duration histograms
//! - In-flight request tracking
//! - Error counts, split into client errors (4xx) and server errors (5xx)
//! - A pillar dimension for usage tracking
9
10use axum::{
11    extract::{MatchedPath, Request},
12    middleware::Next,
13    response::Response,
14};
15use mockforge_observability::get_global_registry;
16use std::time::Instant;
17use tracing::debug;
18
/// Map a request path to the product pillar used as a telemetry dimension.
///
/// The path is lower-cased and scanned for known fragments (for example
/// `/personas`, `/contracts`, `/plugins`, `/workspace`, `/llm`). Pillars are
/// tried in a fixed order (reality, contracts, devx, cloud, ai). The first
/// pillar with a matching fragment is returned, and `"unknown"` is returned
/// when nothing matches.
///
/// Matching is plain substring search, so a fragment also matches longer
/// segments that begin with it. For example, `/workspaces` hits `/workspace`,
/// and `/clients` hits `/cli`.
fn determine_pillar_from_path(path: &str) -> &'static str {
    // Ordered (pillar, path fragments) table; earlier rows take precedence.
    // NOTE(review): substring matching can misattribute user-defined mock
    // routes such as `/clients` or `/airports`; consider segment-aware matching.
    const PILLARS: [(&str, &[&str]); 5] = [
        ("reality", &["/reality", "/personas", "/chaos", "/fidelity", "/continuum"]),
        ("contracts", &["/contracts", "/validation", "/drift", "/schema", "/sync"]),
        ("devx", &["/sdk", "/playground", "/plugins", "/cli", "/generator"]),
        ("cloud", &["/registry", "/workspace", "/org", "/marketplace", "/collab"]),
        ("ai", &["/ai", "/mockai", "/voice", "/llm", "/studio"]),
    ];

    let haystack = path.to_lowercase();
    PILLARS
        .iter()
        .find(|(_, fragments)| fragments.iter().any(|fragment| haystack.contains(*fragment)))
        .map_or("unknown", |&(pillar, _)| pillar)
}
79
80/// Metrics collection middleware for HTTP requests
81///
82/// This middleware should be applied to all HTTP routes to collect comprehensive
83/// metrics for Prometheus. It tracks:
84/// - Total request counts (by method and status code)
85/// - Request duration (as histograms for percentile calculations)
86/// - In-flight requests
87/// - Error rates
88pub async fn collect_http_metrics(
89    matched_path: Option<MatchedPath>,
90    req: Request,
91    next: Next,
92) -> Response {
93    let start_time = Instant::now();
94    let method = req.method().to_string();
95    let uri_path = req.uri().path().to_string();
96    let path = matched_path.as_ref().map(|mp| mp.as_str().to_string()).unwrap_or(uri_path);
97
98    // Get metrics registry
99    let registry = get_global_registry();
100
101    // Track in-flight requests
102    registry.increment_in_flight("http");
103    debug!(
104        method = %method,
105        path = %path,
106        "Starting HTTP request metrics collection"
107    );
108
109    // Process the request
110    let response = next.run(req).await;
111
112    // Decrement in-flight requests
113    registry.decrement_in_flight("http");
114
115    // Calculate metrics
116    let duration = start_time.elapsed();
117    let duration_seconds = duration.as_secs_f64();
118    let status_code = response.status().as_u16();
119
120    // Determine pillar from path
121    let pillar = determine_pillar_from_path(&path);
122
123    // Record metrics with pillar information
124    registry.record_http_request_with_pillar(&method, status_code, duration_seconds, pillar);
125
126    // Record errors separately with pillar
127    if status_code >= 400 {
128        let error_type = if status_code >= 500 {
129            "server_error"
130        } else {
131            "client_error"
132        };
133        registry.record_error_with_pillar("http", error_type, pillar);
134    }
135
136    debug!(
137        method = %method,
138        path = %path,
139        status = status_code,
140        duration_ms = duration.as_millis(),
141        pillar = pillar,
142        "HTTP request metrics recorded with pillar dimension"
143    );
144
145    response
146}
147
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode},
        middleware,
        response::IntoResponse,
        Router,
    };
    use tower::ServiceExt;

    async fn test_handler() -> impl IntoResponse {
        (StatusCode::OK, "test response")
    }

    async fn client_error_handler() -> impl IntoResponse {
        (StatusCode::BAD_REQUEST, "bad request")
    }

    async fn error_handler() -> impl IntoResponse {
        (StatusCode::INTERNAL_SERVER_ERROR, "error")
    }

    /// Sends an empty-bodied request for `uri` through `app` and returns the
    /// response status.
    async fn status_for(app: Router, uri: &str) -> StatusCode {
        let request = Request::builder().uri(uri).body(Body::empty()).unwrap();
        app.oneshot(request).await.unwrap().status()
    }

    #[tokio::test]
    async fn test_metrics_middleware_records_success() {
        let app = Router::new()
            .route("/test", axum::routing::get(test_handler))
            .layer(middleware::from_fn(collect_http_metrics));

        assert_eq!(status_for(app, "/test").await, StatusCode::OK);
    }

    #[tokio::test]
    async fn test_metrics_middleware_records_client_errors() {
        let app = Router::new()
            .route("/bad", axum::routing::get(client_error_handler))
            .layer(middleware::from_fn(collect_http_metrics));

        assert_eq!(status_for(app, "/bad").await, StatusCode::BAD_REQUEST);
    }

    #[tokio::test]
    async fn test_metrics_middleware_records_errors() {
        let app = Router::new()
            .route("/error", axum::routing::get(error_handler))
            .layer(middleware::from_fn(collect_http_metrics));

        assert_eq!(status_for(app, "/error").await, StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[tokio::test]
    async fn test_metrics_middleware_passes_through_unmatched_route() {
        let app = Router::new()
            .route("/test", axum::routing::get(test_handler))
            .layer(middleware::from_fn(collect_http_metrics));

        assert_eq!(status_for(app, "/missing").await, StatusCode::NOT_FOUND);
    }

    #[test]
    fn test_determine_pillar_from_path() {
        assert_eq!(determine_pillar_from_path("/__mockforge/reality/level"), "reality");
        assert_eq!(determine_pillar_from_path("/API/Contracts/123"), "contracts");
        assert_eq!(determine_pillar_from_path("/plugins"), "devx");
        assert_eq!(determine_pillar_from_path("/workspaces"), "cloud");
        assert_eq!(determine_pillar_from_path("/mockai/generate"), "ai");
        assert_eq!(determine_pillar_from_path("/api/users"), "unknown");
        assert_eq!(determine_pillar_from_path(""), "unknown");
    }

    #[test]
    fn test_determine_pillar_first_match_wins() {
        // "/chaos" (reality) is checked before "/schema" (contracts).
        assert_eq!(determine_pillar_from_path("/chaos/schema"), "reality");
    }
}