forge_runtime/gateway/
metrics.rs1use std::sync::Arc;
4use std::time::Instant;
5
6use axum::{extract::State, middleware::Next, response::Response};
7use forge_core::observability::{LogEntry, LogLevel, Metric, Span, SpanKind};
8
9use crate::observability::ObservabilityState;
10
11#[derive(Clone)]
13pub struct MetricsState {
14 pub observability: ObservabilityState,
16}
17
18impl MetricsState {
19 pub fn new(observability: ObservabilityState) -> Self {
21 Self { observability }
22 }
23}
24
25pub async fn metrics_middleware(
32 State(state): State<Arc<MetricsState>>,
33 req: axum::extract::Request,
34 next: Next,
35) -> Response {
36 let start = Instant::now();
37 let method = req.method().to_string();
38 let path = req.uri().path().to_string();
39
40 let response = next.run(req).await;
42
43 let duration = start.elapsed();
44 let status = response.status();
45 let status_code = status.as_u16().to_string();
46
47 let obs = state.observability.clone();
49 let method_clone = method.clone();
50 let path_clone = path.clone();
51 let status_clone = status_code.clone();
52
53 tokio::spawn(async move {
54 let mut request_metric = Metric::counter("http_requests_total", 1.0);
56 request_metric
57 .labels
58 .insert("method".to_string(), method_clone.clone());
59 request_metric
60 .labels
61 .insert("path".to_string(), path_clone.clone());
62 request_metric
63 .labels
64 .insert("status".to_string(), status_clone.clone());
65 obs.record_metric(request_metric).await;
66
67 let mut duration_metric =
69 Metric::gauge("http_request_duration_seconds", duration.as_secs_f64());
70 duration_metric
71 .labels
72 .insert("method".to_string(), method_clone.clone());
73 duration_metric
74 .labels
75 .insert("path".to_string(), path_clone.clone());
76 obs.record_metric(duration_metric).await;
77
78 let log_level = if status.is_server_error() {
80 LogLevel::Error
81 } else if status.is_client_error() {
82 LogLevel::Warn
83 } else {
84 LogLevel::Info
85 };
86 let mut log = LogEntry::new(
87 log_level,
88 format!(
89 "{} {} -> {} ({:.2}ms)",
90 method_clone,
91 path_clone,
92 status_clone,
93 duration.as_secs_f64() * 1000.0
94 ),
95 );
96 log.fields.insert(
97 "method".to_string(),
98 serde_json::Value::String(method_clone.clone()),
99 );
100 log.fields.insert(
101 "path".to_string(),
102 serde_json::Value::String(path_clone.clone()),
103 );
104 log.fields.insert(
105 "status".to_string(),
106 serde_json::Value::String(status_clone.clone()),
107 );
108 log.fields.insert(
109 "duration_ms".to_string(),
110 serde_json::Value::Number(
111 serde_json::Number::from_f64(duration.as_secs_f64() * 1000.0)
112 .unwrap_or(serde_json::Number::from(0)),
113 ),
114 );
115 obs.record_log(log).await;
116
117 let mut span = Span::new(format!("{} {}", method_clone, path_clone));
119 span.kind = SpanKind::Server;
120 span.attributes.insert(
121 "http.method".to_string(),
122 serde_json::Value::String(method_clone.clone()),
123 );
124 span.attributes.insert(
125 "http.url".to_string(),
126 serde_json::Value::String(path_clone.clone()),
127 );
128 span.attributes.insert(
129 "http.status_code".to_string(),
130 serde_json::Value::String(status_clone.clone()),
131 );
132 if status.is_server_error() {
133 span.end_error("Server error");
134 } else {
135 span.end_ok();
136 }
137 obs.record_span(span).await;
138
139 if status.is_client_error() || status.is_server_error() {
141 let mut error_metric = Metric::counter("http_errors_total", 1.0);
142 error_metric
143 .labels
144 .insert("method".to_string(), method_clone);
145 error_metric.labels.insert("path".to_string(), path_clone);
146 error_metric
147 .labels
148 .insert("status".to_string(), status_clone);
149 obs.record_metric(error_metric).await;
150 }
151 });
152
153 response
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails to compile unless `T` meets the bounds axum places on state shared with handlers
    /// and middleware.
    fn assert_state_bounds<T: Clone + Send + Sync + 'static>() {}

    #[test]
    fn test_metrics_state_new() {
        // `new` must keep its `ObservabilityState -> MetricsState` signature for callers.
        let _constructor: fn(ObservabilityState) -> MetricsState = MetricsState::new;
        // `metrics_middleware` extracts `State<Arc<MetricsState>>`. That `Arc` can only be
        // Send + Sync if `MetricsState` itself is.
        assert_state_bounds::<MetricsState>();
        assert_state_bounds::<Arc<MetricsState>>();
    }
}