1use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
6use std::{
7 collections::HashMap,
8 sync::{Arc, Mutex},
9 time::{Duration, Instant},
10};
11
12#[derive(Debug, Clone)]
31pub struct ApiMetrics {
32 inner: Arc<Mutex<ApiMetricsInner>>,
33}
34
35#[derive(Debug)]
36struct ApiMetricsInner {
37 request_counts: HashMap<String, u64>,
38 response_times: HashMap<String, Vec<Duration>>,
39 error_counts: HashMap<String, u64>,
40 active_requests: u64,
41 start_time: Instant,
42}
43
44impl ApiMetrics {
45 pub fn new() -> Self {
54 Self {
55 inner: Arc::new(Mutex::new(ApiMetricsInner {
56 request_counts: HashMap::new(),
57 response_times: HashMap::new(),
58 error_counts: HashMap::new(),
59 active_requests: 0,
60 start_time: Instant::now(),
61 })),
62 }
63 }
64
65 pub fn record_request(&self, path: &str) {
76 let Ok(mut inner) = self.inner.lock() else {
77 return;
78 };
79 *inner.request_counts.entry(path.to_string()).or_insert(0) += 1;
80 inner.active_requests += 1;
81 }
82
83 pub fn record_response(&self, path: &str, duration: Duration, status: StatusCode) {
99 let Ok(mut inner) = self.inner.lock() else {
100 return;
101 };
102 inner
103 .response_times
104 .entry(path.to_string())
105 .or_default()
106 .push(duration);
107
108 if status.is_client_error() || status.is_server_error() {
109 *inner.error_counts.entry(path.to_string()).or_insert(0) += 1;
110 }
111
112 inner.active_requests = inner.active_requests.saturating_sub(1);
113 }
114
115 pub fn get_metrics(&self) -> MetricsSnapshot {
129 let Ok(inner) = self.inner.lock() else {
130 return MetricsSnapshot {
131 uptime: Duration::ZERO,
132 total_requests: 0,
133 active_requests: 0,
134 endpoint_metrics: HashMap::new(),
135 };
136 };
137 let mut endpoint_metrics = HashMap::new();
138
139 for (path, &count) in &inner.request_counts {
140 let response_times = inner.response_times.get(path).cloned().unwrap_or_default();
141 let error_count = inner.error_counts.get(path).copied().unwrap_or(0);
142
143 let avg_response_time = if !response_times.is_empty() {
144 response_times.iter().sum::<Duration>() / response_times.len() as u32
145 } else {
146 Duration::ZERO
147 };
148
149 let p95_response_time = calculate_percentile(&response_times, 95.0);
150 let p99_response_time = calculate_percentile(&response_times, 99.0);
151
152 endpoint_metrics.insert(
153 path.clone(),
154 EndpointMetrics {
155 request_count: count,
156 error_count,
157 error_rate: if count > 0 {
158 error_count as f64 / count as f64
159 } else {
160 0.0
161 },
162 avg_response_time,
163 p95_response_time,
164 p99_response_time,
165 },
166 );
167 }
168
169 MetricsSnapshot {
170 uptime: inner.start_time.elapsed(),
171 total_requests: inner.request_counts.values().sum(),
172 active_requests: inner.active_requests,
173 endpoint_metrics,
174 }
175 }
176
177 pub fn reset(&self) {
195 let Ok(mut inner) = self.inner.lock() else {
196 return;
197 };
198 inner.request_counts.clear();
199 inner.response_times.clear();
200 inner.error_counts.clear();
201 inner.start_time = Instant::now();
202 }
203}
204
205impl Default for ApiMetrics {
206 fn default() -> Self {
207 Self::new()
208 }
209}
210
/// Point-in-time copy of the metrics held by an `ApiMetrics` collector.
#[derive(Debug, Clone)]
pub struct MetricsSnapshot {
    /// Time elapsed since the collector was created or last reset.
    pub uptime: Duration,
    /// Sum of the per-endpoint request counts.
    pub total_requests: u64,
    /// Requests recorded as started but not yet recorded as completed.
    pub active_requests: u64,
    /// Per-endpoint statistics, keyed by the path string passed to the collector.
    pub endpoint_metrics: HashMap<String, EndpointMetrics>,
}
231
/// Statistics for a single endpoint inside a `MetricsSnapshot`.
#[derive(Debug, Clone)]
pub struct EndpointMetrics {
    /// Number of requests recorded for this endpoint.
    pub request_count: u64,
    /// Number of responses with a 4xx or 5xx status.
    pub error_count: u64,
    /// `error_count / request_count`, or `0.0` when there were no requests.
    pub error_rate: f64,
    /// Mean response time; zero when no response has been recorded.
    pub avg_response_time: Duration,
    /// 95th-percentile response time; zero when no response has been recorded.
    pub p95_response_time: Duration,
    /// 99th-percentile response time; zero when no response has been recorded.
    pub p99_response_time: Duration,
}
257
/// Returns the `percentile`-th percentile of `durations`.
///
/// Uses the nearest-rank element at zero-based index `round(p / 100 * (len - 1))` of the
/// sorted samples.
/// - `percentile` is clamped to `0.0..=100.0`, so out-of-range requests yield the minimum
///   or maximum instead of a bogus zero.
/// - NaN is treated as `0.0`.
/// - Returns [`Duration::ZERO`] for an empty slice.
/// - The input slice is left untouched.
fn calculate_percentile(durations: &[Duration], percentile: f64) -> Duration {
    if durations.is_empty() {
        return Duration::ZERO;
    }

    // `f64::clamp` passes NaN through; the saturating `as usize` cast then maps it to 0.
    let fraction = percentile.clamp(0.0, 100.0) / 100.0;
    let last = durations.len() - 1;
    let index = ((fraction * last as f64).round() as usize).min(last);

    // Only the element at `index` is needed, so an O(n) selection replaces a full sort.
    let mut scratch = durations.to_vec();
    let (_, nth, _) = scratch.select_nth_unstable(index);
    *nth
}
270
271pub async fn metrics_middleware(request: Request, next: Next) -> Result<Response, StatusCode> {
284 let start_time = Instant::now();
285 let path = request.uri().path().to_string();
286
287 let metrics = request
289 .extensions()
290 .get::<ApiMetrics>()
291 .cloned()
292 .unwrap_or_default();
293
294 metrics.record_request(&path);
295
296 let response = next.run(request).await;
297 let duration = start_time.elapsed();
298
299 metrics.record_response(&path, duration, response.status());
300
301 Ok(response)
302}
303
304impl MetricsSnapshot {
305 pub fn to_prometheus_format(&self) -> String {
315 let mut output = String::new();
316
317 output.push_str(&format!(
319 "# HELP auth_framework_uptime_seconds Total uptime in seconds\n\
320 # TYPE auth_framework_uptime_seconds counter\n\
321 auth_framework_uptime_seconds {}\n\n",
322 self.uptime.as_secs()
323 ));
324
325 output.push_str(&format!(
326 "# HELP auth_framework_requests_total Total number of requests\n\
327 # TYPE auth_framework_requests_total counter\n\
328 auth_framework_requests_total {}\n\n",
329 self.total_requests
330 ));
331
332 output.push_str(&format!(
333 "# HELP auth_framework_active_requests Current number of active requests\n\
334 # TYPE auth_framework_active_requests gauge\n\
335 auth_framework_active_requests {}\n\n",
336 self.active_requests
337 ));
338
339 for (endpoint, metrics) in &self.endpoint_metrics {
341 let _safe_endpoint = endpoint.replace(['/', '-'], "_");
342
343 output.push_str(&format!(
344 "auth_framework_endpoint_requests_total{{endpoint=\"{}\"}} {}\n",
345 endpoint, metrics.request_count
346 ));
347
348 output.push_str(&format!(
349 "auth_framework_endpoint_errors_total{{endpoint=\"{}\"}} {}\n",
350 endpoint, metrics.error_count
351 ));
352
353 output.push_str(&format!(
354 "auth_framework_endpoint_response_time_avg{{endpoint=\"{}\"}} {}\n",
355 endpoint,
356 metrics.avg_response_time.as_secs_f64()
357 ));
358
359 output.push_str(&format!(
360 "auth_framework_endpoint_response_time_p95{{endpoint=\"{}\"}} {}\n",
361 endpoint,
362 metrics.p95_response_time.as_secs_f64()
363 ));
364 }
365
366 output
367 }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// Records one completed request for `path` on a fresh collector and snapshots it.
    fn snapshot_after_one(path: &str, latency_ms: u64, status: StatusCode) -> MetricsSnapshot {
        let metrics = ApiMetrics::new();
        metrics.record_request(path);
        metrics.record_response(path, Duration::from_millis(latency_ms), status);
        metrics.get_metrics()
    }

    #[test]
    fn test_metrics_collection() {
        let snapshot = snapshot_after_one("/api/login", 100, StatusCode::OK);
        let login = &snapshot.endpoint_metrics["/api/login"];

        assert_eq!(snapshot.total_requests, 1);
        assert_eq!(login.request_count, 1);
        assert_eq!(login.error_count, 0);
    }

    #[test]
    fn test_error_tracking() {
        let snapshot = snapshot_after_one("/api/test", 50, StatusCode::BAD_REQUEST);
        let endpoint = &snapshot.endpoint_metrics["/api/test"];

        assert_eq!(endpoint.error_count, 1);
        assert!(endpoint.error_rate > 0.0);
    }

    #[test]
    fn test_percentile_calculation() {
        let durations: Vec<Duration> = [10, 20, 30, 40, 100]
            .iter()
            .copied()
            .map(Duration::from_millis)
            .collect();

        assert_eq!(
            calculate_percentile(&durations, 95.0),
            Duration::from_millis(100)
        );
    }

    #[test]
    fn test_prometheus_format() {
        let prometheus =
            snapshot_after_one("/api/test", 100, StatusCode::OK).to_prometheus_format();

        for needle in [
            "auth_framework_requests_total",
            "auth_framework_active_requests",
            "endpoint=\"/api/test\"",
        ] {
            assert!(prometheus.contains(needle));
        }
    }
}