forge_runtime/gateway/metrics.rs

1//! HTTP metrics middleware for the gateway.
2
3use std::sync::Arc;
4use std::time::Instant;
5
6use axum::{extract::State, middleware::Next, response::Response};
7use forge_core::observability::{LogEntry, LogLevel, Metric, Span, SpanKind};
8
9use crate::observability::ObservabilityState;
10
11/// State for metrics middleware.
12#[derive(Clone)]
13pub struct MetricsState {
14    /// Observability state for recording metrics.
15    pub observability: ObservabilityState,
16}
17
18impl MetricsState {
19    /// Create a new metrics state.
20    pub fn new(observability: ObservabilityState) -> Self {
21        Self { observability }
22    }
23}
24
25/// Metrics middleware that records HTTP request metrics.
26///
27/// Records the following metrics:
28/// - `http_requests_total`: Total number of HTTP requests (counter)
29/// - `http_request_duration_seconds`: Request duration (gauge)
30/// - `http_errors_total`: Total number of error responses (counter)
31pub async fn metrics_middleware(
32    State(state): State<Arc<MetricsState>>,
33    req: axum::extract::Request,
34    next: Next,
35) -> Response {
36    let start = Instant::now();
37    let method = req.method().to_string();
38    let path = req.uri().path().to_string();
39
40    // Execute the request
41    let response = next.run(req).await;
42
43    let duration = start.elapsed();
44    let status = response.status();
45    let status_code = status.as_u16().to_string();
46
47    // Record metrics asynchronously
48    let obs = state.observability.clone();
49    let method_clone = method.clone();
50    let path_clone = path.clone();
51    let status_clone = status_code.clone();
52
53    tokio::spawn(async move {
54        // Record request count
55        let mut request_metric = Metric::counter("http_requests_total", 1.0);
56        request_metric
57            .labels
58            .insert("method".to_string(), method_clone.clone());
59        request_metric
60            .labels
61            .insert("path".to_string(), path_clone.clone());
62        request_metric
63            .labels
64            .insert("status".to_string(), status_clone.clone());
65        obs.record_metric(request_metric).await;
66
67        // Record request duration
68        let mut duration_metric =
69            Metric::gauge("http_request_duration_seconds", duration.as_secs_f64());
70        duration_metric
71            .labels
72            .insert("method".to_string(), method_clone.clone());
73        duration_metric
74            .labels
75            .insert("path".to_string(), path_clone.clone());
76        obs.record_metric(duration_metric).await;
77
78        // Record log entry for each request
79        let log_level = if status.is_server_error() {
80            LogLevel::Error
81        } else if status.is_client_error() {
82            LogLevel::Warn
83        } else {
84            LogLevel::Info
85        };
86        let mut log = LogEntry::new(
87            log_level,
88            format!(
89                "{} {} -> {} ({:.2}ms)",
90                method_clone,
91                path_clone,
92                status_clone,
93                duration.as_secs_f64() * 1000.0
94            ),
95        );
96        log.fields.insert(
97            "method".to_string(),
98            serde_json::Value::String(method_clone.clone()),
99        );
100        log.fields.insert(
101            "path".to_string(),
102            serde_json::Value::String(path_clone.clone()),
103        );
104        log.fields.insert(
105            "status".to_string(),
106            serde_json::Value::String(status_clone.clone()),
107        );
108        log.fields.insert(
109            "duration_ms".to_string(),
110            serde_json::Value::Number(
111                serde_json::Number::from_f64(duration.as_secs_f64() * 1000.0)
112                    .unwrap_or(serde_json::Number::from(0)),
113            ),
114        );
115        obs.record_log(log).await;
116
117        // Record trace span for each request
118        let mut span = Span::new(format!("{} {}", method_clone, path_clone));
119        span.kind = SpanKind::Server;
120        span.attributes.insert(
121            "http.method".to_string(),
122            serde_json::Value::String(method_clone.clone()),
123        );
124        span.attributes.insert(
125            "http.url".to_string(),
126            serde_json::Value::String(path_clone.clone()),
127        );
128        span.attributes.insert(
129            "http.status_code".to_string(),
130            serde_json::Value::String(status_clone.clone()),
131        );
132        if status.is_server_error() {
133            span.end_error("Server error");
134        } else {
135            span.end_ok();
136        }
137        obs.record_span(span).await;
138
139        // Record errors if status >= 400
140        if status.is_client_error() || status.is_server_error() {
141            let mut error_metric = Metric::counter("http_errors_total", 1.0);
142            error_metric
143                .labels
144                .insert("method".to_string(), method_clone);
145            error_metric.labels.insert("path".to_string(), path_clone);
146            error_metric
147                .labels
148                .insert("status".to_string(), status_clone);
149            obs.record_metric(error_metric).await;
150        }
151    });
152
153    response
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Compile-time check that `MetricsState` is `Clone` and can be built from
    /// an `ObservabilityState` through `MetricsState::new`.
    ///
    /// Building a real `ObservabilityState` requires a database pool, so this
    /// test pins the constructor signature and derive without running the
    /// middleware.
    #[test]
    fn test_metrics_state_new() {
        fn assert_clone<T: Clone>() {}
        assert_clone::<MetricsState>();

        let _constructor: fn(ObservabilityState) -> MetricsState = MetricsState::new;
    }
}