auth_framework/api/
metrics.rs

1//! API Metrics and Observability
2//!
3//! Provides comprehensive metrics collection for API endpoints
4
5use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
6use std::{
7    collections::HashMap,
8    sync::{Arc, Mutex},
9    time::{Duration, Instant},
10};
11
12/// Metrics collector for API endpoints
13#[derive(Debug, Clone)]
14pub struct ApiMetrics {
15    inner: Arc<Mutex<ApiMetricsInner>>,
16}
17
/// Mutable state behind the `ApiMetrics` mutex.
///
/// All maps are keyed by the path string passed to `record_request` /
/// `record_response`.
#[derive(Debug)]
struct ApiMetricsInner {
    /// Number of requests recorded per path.
    request_counts: HashMap<String, u64>,
    /// Every response duration recorded per path, in the order it was recorded.
    response_times: HashMap<String, Vec<Duration>>,
    /// Number of responses per path whose status was a client or server error.
    error_counts: HashMap<String, u64>,
    /// Requests that have been recorded as started but not yet as finished.
    active_requests: u64,
    /// Moment the collector was created or last reset; basis for the uptime.
    start_time: Instant,
}
26
27impl ApiMetrics {
28    pub fn new() -> Self {
29        Self {
30            inner: Arc::new(Mutex::new(ApiMetricsInner {
31                request_counts: HashMap::new(),
32                response_times: HashMap::new(),
33                error_counts: HashMap::new(),
34                active_requests: 0,
35                start_time: Instant::now(),
36            })),
37        }
38    }
39
40    pub fn record_request(&self, path: &str) {
41        let mut inner = self.inner.lock().unwrap();
42        *inner.request_counts.entry(path.to_string()).or_insert(0) += 1;
43        inner.active_requests += 1;
44    }
45
46    pub fn record_response(&self, path: &str, duration: Duration, status: StatusCode) {
47        let mut inner = self.inner.lock().unwrap();
48        inner
49            .response_times
50            .entry(path.to_string())
51            .or_default()
52            .push(duration);
53
54        if status.is_client_error() || status.is_server_error() {
55            *inner.error_counts.entry(path.to_string()).or_insert(0) += 1;
56        }
57
58        inner.active_requests = inner.active_requests.saturating_sub(1);
59    }
60
61    pub fn get_metrics(&self) -> MetricsSnapshot {
62        let inner = self.inner.lock().unwrap();
63        let mut endpoint_metrics = HashMap::new();
64
65        for (path, &count) in &inner.request_counts {
66            let response_times = inner.response_times.get(path).cloned().unwrap_or_default();
67            let error_count = inner.error_counts.get(path).copied().unwrap_or(0);
68
69            let avg_response_time = if !response_times.is_empty() {
70                response_times.iter().sum::<Duration>() / response_times.len() as u32
71            } else {
72                Duration::ZERO
73            };
74
75            let p95_response_time = calculate_percentile(&response_times, 95.0);
76            let p99_response_time = calculate_percentile(&response_times, 99.0);
77
78            endpoint_metrics.insert(
79                path.clone(),
80                EndpointMetrics {
81                    request_count: count,
82                    error_count,
83                    error_rate: if count > 0 {
84                        error_count as f64 / count as f64
85                    } else {
86                        0.0
87                    },
88                    avg_response_time,
89                    p95_response_time,
90                    p99_response_time,
91                },
92            );
93        }
94
95        MetricsSnapshot {
96            uptime: inner.start_time.elapsed(),
97            total_requests: inner.request_counts.values().sum(),
98            active_requests: inner.active_requests,
99            endpoint_metrics,
100        }
101    }
102
103    pub fn reset(&self) {
104        let mut inner = self.inner.lock().unwrap();
105        inner.request_counts.clear();
106        inner.response_times.clear();
107        inner.error_counts.clear();
108        inner.start_time = Instant::now();
109    }
110}
111
112impl Default for ApiMetrics {
113    fn default() -> Self {
114        Self::new()
115    }
116}
117
118#[derive(Debug, Clone)]
119pub struct MetricsSnapshot {
120    pub uptime: Duration,
121    pub total_requests: u64,
122    pub active_requests: u64,
123    pub endpoint_metrics: HashMap<String, EndpointMetrics>,
124}
125
/// Aggregated statistics for a single endpoint path.
///
/// `Default` yields an all-zero record; `PartialEq` allows whole records to be
/// compared, e.g. in tests.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct EndpointMetrics {
    /// Number of requests recorded for the endpoint.
    pub request_count: u64,
    /// Number of responses with a client-error or server-error status.
    pub error_count: u64,
    /// `error_count / request_count`, or `0.0` when no request was recorded.
    pub error_rate: f64,
    /// Mean of the recorded response times (`Duration::ZERO` without samples).
    pub avg_response_time: Duration,
    /// 95th-percentile response time.
    pub p95_response_time: Duration,
    /// 99th-percentile response time.
    pub p99_response_time: Duration,
}
135
/// Returns the requested percentile of `durations` using the nearest rank.
///
/// The input does not need to be sorted; a sorted copy is made internally.
/// The rank is `round(percentile / 100 * (len - 1))`, so `0.0` selects the
/// smallest sample and `100.0` the largest.
///
/// `percentile` is clamped to `0.0..=100.0`, so an out-of-range request
/// returns the nearest extreme instead of silently reporting zero. `NaN` is
/// treated like `0.0`. An empty slice yields `Duration::ZERO`.
fn calculate_percentile(durations: &[Duration], percentile: f64) -> Duration {
    if durations.is_empty() {
        return Duration::ZERO;
    }

    let mut sorted = durations.to_vec();
    // Equal durations are indistinguishable, so stability buys nothing here.
    sorted.sort_unstable();

    let last = sorted.len() - 1;
    let fraction = percentile.clamp(0.0, 100.0) / 100.0;
    // A NaN fraction casts to 0; `min` guards against float rounding pushing
    // the rank past the final element.
    let index = ((fraction * last as f64).round() as usize).min(last);
    sorted[index]
}
148
149/// Middleware for collecting API metrics
150pub async fn metrics_middleware(request: Request, next: Next) -> Result<Response, StatusCode> {
151    let start_time = Instant::now();
152    let path = request.uri().path().to_string();
153
154    // Get metrics collector from extensions or create new one
155    let metrics = request
156        .extensions()
157        .get::<ApiMetrics>()
158        .cloned()
159        .unwrap_or_default();
160
161    metrics.record_request(&path);
162
163    let response = next.run(request).await;
164    let duration = start_time.elapsed();
165
166    metrics.record_response(&path, duration, response.status());
167
168    Ok(response)
169}
170
171/// Prometheus metrics format output
172impl MetricsSnapshot {
173    pub fn to_prometheus_format(&self) -> String {
174        let mut output = String::new();
175
176        // System metrics
177        output.push_str(&format!(
178            "# HELP auth_framework_uptime_seconds Total uptime in seconds\n\
179             # TYPE auth_framework_uptime_seconds counter\n\
180             auth_framework_uptime_seconds {}\n\n",
181            self.uptime.as_secs()
182        ));
183
184        output.push_str(&format!(
185            "# HELP auth_framework_requests_total Total number of requests\n\
186             # TYPE auth_framework_requests_total counter\n\
187             auth_framework_requests_total {}\n\n",
188            self.total_requests
189        ));
190
191        output.push_str(&format!(
192            "# HELP auth_framework_active_requests Current number of active requests\n\
193             # TYPE auth_framework_active_requests gauge\n\
194             auth_framework_active_requests {}\n\n",
195            self.active_requests
196        ));
197
198        // Endpoint metrics
199        for (endpoint, metrics) in &self.endpoint_metrics {
200            let _safe_endpoint = endpoint.replace(['/', '-'], "_");
201
202            output.push_str(&format!(
203                "auth_framework_endpoint_requests_total{{endpoint=\"{}\"}} {}\n",
204                endpoint, metrics.request_count
205            ));
206
207            output.push_str(&format!(
208                "auth_framework_endpoint_errors_total{{endpoint=\"{}\"}} {}\n",
209                endpoint, metrics.error_count
210            ));
211
212            output.push_str(&format!(
213                "auth_framework_endpoint_response_time_avg{{endpoint=\"{}\"}} {}\n",
214                endpoint,
215                metrics.avg_response_time.as_secs_f64()
216            ));
217
218            output.push_str(&format!(
219                "auth_framework_endpoint_response_time_p95{{endpoint=\"{}\"}} {}\n",
220                endpoint,
221                metrics.p95_response_time.as_secs_f64()
222            ));
223        }
224
225        output
226    }
227}
228
#[cfg(test)]
mod tests {
    use super::*;

    /// A successful request is counted once and is not reported as an error.
    #[test]
    fn test_metrics_collection() {
        let collector = ApiMetrics::new();

        collector.record_request("/api/login");
        collector.record_response("/api/login", Duration::from_millis(100), StatusCode::OK);

        let snapshot = collector.get_metrics();
        assert_eq!(snapshot.total_requests, 1);

        let login = &snapshot.endpoint_metrics["/api/login"];
        assert_eq!(login.request_count, 1);
        assert_eq!(login.error_count, 0);
    }

    /// A 4xx response increments the error count and yields a non-zero rate.
    #[test]
    fn test_error_tracking() {
        let collector = ApiMetrics::new();
        let elapsed = Duration::from_millis(50);

        collector.record_request("/api/test");
        collector.record_response("/api/test", elapsed, StatusCode::BAD_REQUEST);

        let snapshot = collector.get_metrics();
        let endpoint = &snapshot.endpoint_metrics["/api/test"];
        assert_eq!(endpoint.error_count, 1);
        assert!(endpoint.error_rate > 0.0);
    }

    /// With five samples the 95th percentile is the largest one.
    #[test]
    fn test_percentile_calculation() {
        let samples: Vec<Duration> = [10_u64, 20, 30, 40, 100]
            .iter()
            .map(|&millis| Duration::from_millis(millis))
            .collect();

        assert_eq!(
            calculate_percentile(&samples, 95.0),
            Duration::from_millis(100)
        );
    }

    /// The rendered text names the system metrics and labels the endpoint.
    #[test]
    fn test_prometheus_format() {
        let collector = ApiMetrics::new();
        collector.record_request("/api/test");
        collector.record_response("/api/test", Duration::from_millis(100), StatusCode::OK);

        let rendered = collector.get_metrics().to_prometheus_format();

        for expected in [
            "auth_framework_requests_total",
            "auth_framework_active_requests",
            "endpoint=\"/api/test\"",
        ] {
            assert!(rendered.contains(expected));
        }
    }
}
290
291