1use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
6use std::{
7 collections::HashMap,
8 sync::{Arc, Mutex},
9 time::{Duration, Instant},
10};
11
12#[derive(Debug, Clone)]
14pub struct ApiMetrics {
15 inner: Arc<Mutex<ApiMetricsInner>>,
16}
17
/// Mutable metric state stored behind the mutex inside `ApiMetrics`.
#[derive(Debug)]
struct ApiMetricsInner {
    /// Number of requests recorded so far, keyed by request path.
    request_counts: HashMap<String, u64>,
    /// Response durations recorded so far, keyed by request path.
    response_times: HashMap<String, Vec<Duration>>,
    /// Number of responses recorded as errors, keyed by request path.
    error_counts: HashMap<String, u64>,
    /// Requests that have been recorded as started but not yet as finished.
    active_requests: u64,
    /// Instant from which uptime is measured.
    start_time: Instant,
}
26
27impl ApiMetrics {
28 pub fn new() -> Self {
29 Self {
30 inner: Arc::new(Mutex::new(ApiMetricsInner {
31 request_counts: HashMap::new(),
32 response_times: HashMap::new(),
33 error_counts: HashMap::new(),
34 active_requests: 0,
35 start_time: Instant::now(),
36 })),
37 }
38 }
39
40 pub fn record_request(&self, path: &str) {
41 let mut inner = self.inner.lock().unwrap();
42 *inner.request_counts.entry(path.to_string()).or_insert(0) += 1;
43 inner.active_requests += 1;
44 }
45
46 pub fn record_response(&self, path: &str, duration: Duration, status: StatusCode) {
47 let mut inner = self.inner.lock().unwrap();
48 inner
49 .response_times
50 .entry(path.to_string())
51 .or_default()
52 .push(duration);
53
54 if status.is_client_error() || status.is_server_error() {
55 *inner.error_counts.entry(path.to_string()).or_insert(0) += 1;
56 }
57
58 inner.active_requests = inner.active_requests.saturating_sub(1);
59 }
60
61 pub fn get_metrics(&self) -> MetricsSnapshot {
62 let inner = self.inner.lock().unwrap();
63 let mut endpoint_metrics = HashMap::new();
64
65 for (path, &count) in &inner.request_counts {
66 let response_times = inner.response_times.get(path).cloned().unwrap_or_default();
67 let error_count = inner.error_counts.get(path).copied().unwrap_or(0);
68
69 let avg_response_time = if !response_times.is_empty() {
70 response_times.iter().sum::<Duration>() / response_times.len() as u32
71 } else {
72 Duration::ZERO
73 };
74
75 let p95_response_time = calculate_percentile(&response_times, 95.0);
76 let p99_response_time = calculate_percentile(&response_times, 99.0);
77
78 endpoint_metrics.insert(
79 path.clone(),
80 EndpointMetrics {
81 request_count: count,
82 error_count,
83 error_rate: if count > 0 {
84 error_count as f64 / count as f64
85 } else {
86 0.0
87 },
88 avg_response_time,
89 p95_response_time,
90 p99_response_time,
91 },
92 );
93 }
94
95 MetricsSnapshot {
96 uptime: inner.start_time.elapsed(),
97 total_requests: inner.request_counts.values().sum(),
98 active_requests: inner.active_requests,
99 endpoint_metrics,
100 }
101 }
102
103 pub fn reset(&self) {
104 let mut inner = self.inner.lock().unwrap();
105 inner.request_counts.clear();
106 inner.response_times.clear();
107 inner.error_counts.clear();
108 inner.start_time = Instant::now();
109 }
110}
111
112impl Default for ApiMetrics {
113 fn default() -> Self {
114 Self::new()
115 }
116}
117
118#[derive(Debug, Clone)]
119pub struct MetricsSnapshot {
120 pub uptime: Duration,
121 pub total_requests: u64,
122 pub active_requests: u64,
123 pub endpoint_metrics: HashMap<String, EndpointMetrics>,
124}
125
/// Statistics for a single endpoint (request path).
#[derive(Clone, Debug)]
pub struct EndpointMetrics {
    /// Number of requests recorded for the endpoint.
    pub request_count: u64,
    /// Number of responses recorded with a client- or server-error status.
    pub error_count: u64,
    /// `error_count / request_count`, or `0.0` when there were no requests.
    pub error_rate: f64,
    /// Mean of the recorded response times.
    pub avg_response_time: Duration,
    /// 95th-percentile response time.
    pub p95_response_time: Duration,
    /// 99th-percentile response time.
    pub p99_response_time: Duration,
}
135
/// Returns the `percentile`-th percentile of `durations` using the nearest
/// rank on a sorted copy: index `round(percentile / 100 * (len - 1))`.
///
/// * An empty slice yields `Duration::ZERO`.
/// * `percentile` is clamped to `[0, 100]`, so values above 100 return the
///   maximum (instead of silently falling back to zero) and negative values
///   return the minimum.
/// * A NaN `percentile` returns the minimum, because a NaN-to-integer cast
///   produces index 0.
fn calculate_percentile(durations: &[Duration], percentile: f64) -> Duration {
    if durations.is_empty() {
        return Duration::ZERO;
    }

    let mut sorted = durations.to_vec();
    // Equal durations are indistinguishable, so stability is not needed.
    sorted.sort_unstable();

    let last = sorted.len() - 1;
    let fraction = (percentile / 100.0).clamp(0.0, 1.0);
    let rank = (fraction * last as f64).round() as usize;

    // `rank` cannot exceed `last` after clamping; `min` guards float rounding.
    sorted[rank.min(last)]
}
148
149pub async fn metrics_middleware(request: Request, next: Next) -> Result<Response, StatusCode> {
151 let start_time = Instant::now();
152 let path = request.uri().path().to_string();
153
154 let metrics = request
156 .extensions()
157 .get::<ApiMetrics>()
158 .cloned()
159 .unwrap_or_default();
160
161 metrics.record_request(&path);
162
163 let response = next.run(request).await;
164 let duration = start_time.elapsed();
165
166 metrics.record_response(&path, duration, response.status());
167
168 Ok(response)
169}
170
171impl MetricsSnapshot {
173 pub fn to_prometheus_format(&self) -> String {
174 let mut output = String::new();
175
176 output.push_str(&format!(
178 "# HELP auth_framework_uptime_seconds Total uptime in seconds\n\
179 # TYPE auth_framework_uptime_seconds counter\n\
180 auth_framework_uptime_seconds {}\n\n",
181 self.uptime.as_secs()
182 ));
183
184 output.push_str(&format!(
185 "# HELP auth_framework_requests_total Total number of requests\n\
186 # TYPE auth_framework_requests_total counter\n\
187 auth_framework_requests_total {}\n\n",
188 self.total_requests
189 ));
190
191 output.push_str(&format!(
192 "# HELP auth_framework_active_requests Current number of active requests\n\
193 # TYPE auth_framework_active_requests gauge\n\
194 auth_framework_active_requests {}\n\n",
195 self.active_requests
196 ));
197
198 for (endpoint, metrics) in &self.endpoint_metrics {
200 let _safe_endpoint = endpoint.replace(['/', '-'], "_");
201
202 output.push_str(&format!(
203 "auth_framework_endpoint_requests_total{{endpoint=\"{}\"}} {}\n",
204 endpoint, metrics.request_count
205 ));
206
207 output.push_str(&format!(
208 "auth_framework_endpoint_errors_total{{endpoint=\"{}\"}} {}\n",
209 endpoint, metrics.error_count
210 ));
211
212 output.push_str(&format!(
213 "auth_framework_endpoint_response_time_avg{{endpoint=\"{}\"}} {}\n",
214 endpoint,
215 metrics.avg_response_time.as_secs_f64()
216 ));
217
218 output.push_str(&format!(
219 "auth_framework_endpoint_response_time_p95{{endpoint=\"{}\"}} {}\n",
220 endpoint,
221 metrics.p95_response_time.as_secs_f64()
222 ));
223 }
224
225 output
226 }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// One successful request is counted once and produces no errors.
    #[test]
    fn test_metrics_collection() {
        let path = "/api/login";
        let collector = ApiMetrics::new();

        collector.record_request(path);
        collector.record_response(path, Duration::from_millis(100), StatusCode::OK);

        let snapshot = collector.get_metrics();
        assert_eq!(snapshot.total_requests, 1);

        let login = &snapshot.endpoint_metrics[path];
        assert_eq!(login.request_count, 1);
        assert_eq!(login.error_count, 0);
    }

    /// A 4xx response is counted as an error and yields a non-zero error rate.
    #[test]
    fn test_error_tracking() {
        let path = "/api/test";
        let collector = ApiMetrics::new();

        collector.record_request(path);
        collector.record_response(path, Duration::from_millis(50), StatusCode::BAD_REQUEST);

        let snapshot = collector.get_metrics();
        let endpoint = &snapshot.endpoint_metrics[path];
        assert_eq!(endpoint.error_count, 1);
        assert!(endpoint.error_rate > 0.0);
    }

    /// With five samples, the 95th percentile resolves to the largest one.
    #[test]
    fn test_percentile_calculation() {
        let samples: Vec<Duration> = [10u64, 20, 30, 40, 100]
            .iter()
            .map(|&millis| Duration::from_millis(millis))
            .collect();

        let p95 = calculate_percentile(&samples, 95.0);

        assert_eq!(p95, Duration::from_millis(100));
    }

    /// The Prometheus rendering contains the global metrics and the endpoint label.
    #[test]
    fn test_prometheus_format() {
        let collector = ApiMetrics::new();
        collector.record_request("/api/test");
        collector.record_response("/api/test", Duration::from_millis(100), StatusCode::OK);

        let rendered = collector.get_metrics().to_prometheus_format();

        assert!(rendered.contains("auth_framework_requests_total"));
        assert!(rendered.contains("auth_framework_active_requests"));
        assert!(rendered.contains("endpoint=\"/api/test\""));
    }
}
290
291