// auth_framework/api/metrics.rs

1//! API Metrics and Observability
2//!
3//! Provides comprehensive metrics collection for API endpoints
4
5use axum::{extract::Request, http::StatusCode, middleware::Next, response::Response};
6use std::{
7    collections::HashMap,
8    sync::{Arc, Mutex},
9    time::{Duration, Instant},
10};
11
12/// Thread-safe metrics collector for API endpoints.
13///
14/// Tracks per-path request counts, response times, error counts,
15/// and active request gauge.
16///
17/// # Example
18/// ```rust
19/// use auth_framework::api::metrics::ApiMetrics;
20/// use std::time::Duration;
21/// use axum::http::StatusCode;
22///
23/// let m = ApiMetrics::new();
24/// m.record_request("/login");
25/// m.record_response("/login", Duration::from_millis(5), StatusCode::OK);
26///
27/// let snap = m.get_metrics();
28/// assert_eq!(snap.total_requests, 1);
29/// ```
30#[derive(Debug, Clone)]
31pub struct ApiMetrics {
32    inner: Arc<Mutex<ApiMetricsInner>>,
33}
34
/// Mutable state stored behind the mutex of an `ApiMetrics` handle.
#[derive(Debug)]
struct ApiMetricsInner {
    /// Number of requests recorded per path.
    request_counts: HashMap<String, u64>,
    /// Every response duration recorded per path, in recording order.
    response_times: HashMap<String, Vec<Duration>>,
    /// Number of 4xx/5xx responses recorded per path.
    error_counts: HashMap<String, u64>,
    /// Requests recorded as started whose response has not been recorded yet.
    active_requests: u64,
    /// Instant at which the collector was created or last reset.
    start_time: Instant,
}
43
44impl ApiMetrics {
45    /// Create a new, empty metrics collector.
46    ///
47    /// # Example
48    /// ```rust
49    /// use auth_framework::api::metrics::ApiMetrics;
50    /// let m = ApiMetrics::new();
51    /// assert_eq!(m.get_metrics().total_requests, 0);
52    /// ```
53    pub fn new() -> Self {
54        Self {
55            inner: Arc::new(Mutex::new(ApiMetricsInner {
56                request_counts: HashMap::new(),
57                response_times: HashMap::new(),
58                error_counts: HashMap::new(),
59                active_requests: 0,
60                start_time: Instant::now(),
61            })),
62        }
63    }
64
65    /// Record an incoming request for `path`.
66    ///
67    /// Increments the request counter and the active-requests gauge.
68    ///
69    /// # Example
70    /// ```rust
71    /// use auth_framework::api::metrics::ApiMetrics;
72    /// let m = ApiMetrics::new();
73    /// m.record_request("/health");
74    /// ```
75    pub fn record_request(&self, path: &str) {
76        let Ok(mut inner) = self.inner.lock() else {
77            return;
78        };
79        *inner.request_counts.entry(path.to_string()).or_insert(0) += 1;
80        inner.active_requests += 1;
81    }
82
83    /// Record a completed response for `path`.
84    ///
85    /// Stores the response `duration`, increments error counters for
86    /// 4xx/5xx status codes, and decrements the active-requests gauge.
87    ///
88    /// # Example
89    /// ```rust
90    /// use auth_framework::api::metrics::ApiMetrics;
91    /// use std::time::Duration;
92    /// use axum::http::StatusCode;
93    ///
94    /// let m = ApiMetrics::new();
95    /// m.record_request("/api");
96    /// m.record_response("/api", Duration::from_millis(12), StatusCode::OK);
97    /// ```
98    pub fn record_response(&self, path: &str, duration: Duration, status: StatusCode) {
99        let Ok(mut inner) = self.inner.lock() else {
100            return;
101        };
102        inner
103            .response_times
104            .entry(path.to_string())
105            .or_default()
106            .push(duration);
107
108        if status.is_client_error() || status.is_server_error() {
109            *inner.error_counts.entry(path.to_string()).or_insert(0) += 1;
110        }
111
112        inner.active_requests = inner.active_requests.saturating_sub(1);
113    }
114
115    /// Snapshot all collected metrics.
116    ///
117    /// Returns a [`MetricsSnapshot`] with uptime, totals, and per-endpoint
118    /// statistics including average, p95, and p99 response times.
119    ///
120    /// # Example
121    /// ```rust
122    /// use auth_framework::api::metrics::ApiMetrics;
123    ///
124    /// let m = ApiMetrics::new();
125    /// let snap = m.get_metrics();
126    /// assert_eq!(snap.total_requests, 0);
127    /// ```
128    pub fn get_metrics(&self) -> MetricsSnapshot {
129        let Ok(inner) = self.inner.lock() else {
130            return MetricsSnapshot {
131                uptime: Duration::ZERO,
132                total_requests: 0,
133                active_requests: 0,
134                endpoint_metrics: HashMap::new(),
135            };
136        };
137        let mut endpoint_metrics = HashMap::new();
138
139        for (path, &count) in &inner.request_counts {
140            let response_times = inner.response_times.get(path).cloned().unwrap_or_default();
141            let error_count = inner.error_counts.get(path).copied().unwrap_or(0);
142
143            let avg_response_time = if !response_times.is_empty() {
144                response_times.iter().sum::<Duration>() / response_times.len() as u32
145            } else {
146                Duration::ZERO
147            };
148
149            let p95_response_time = calculate_percentile(&response_times, 95.0);
150            let p99_response_time = calculate_percentile(&response_times, 99.0);
151
152            endpoint_metrics.insert(
153                path.clone(),
154                EndpointMetrics {
155                    request_count: count,
156                    error_count,
157                    error_rate: if count > 0 {
158                        error_count as f64 / count as f64
159                    } else {
160                        0.0
161                    },
162                    avg_response_time,
163                    p95_response_time,
164                    p99_response_time,
165                },
166            );
167        }
168
169        MetricsSnapshot {
170            uptime: inner.start_time.elapsed(),
171            total_requests: inner.request_counts.values().sum(),
172            active_requests: inner.active_requests,
173            endpoint_metrics,
174        }
175    }
176
177    /// Reset all counters and timers.
178    ///
179    /// Clears request counts, response times, error counts, and resets
180    /// the uptime clock.
181    ///
182    /// # Example
183    /// ```rust
184    /// use auth_framework::api::metrics::ApiMetrics;
185    /// use std::time::Duration;
186    /// use axum::http::StatusCode;
187    ///
188    /// let m = ApiMetrics::new();
189    /// m.record_request("/x");
190    /// m.record_response("/x", Duration::from_millis(1), StatusCode::OK);
191    /// m.reset();
192    /// assert_eq!(m.get_metrics().total_requests, 0);
193    /// ```
194    pub fn reset(&self) {
195        let Ok(mut inner) = self.inner.lock() else {
196            return;
197        };
198        inner.request_counts.clear();
199        inner.response_times.clear();
200        inner.error_counts.clear();
201        inner.start_time = Instant::now();
202    }
203}
204
205impl Default for ApiMetrics {
206    fn default() -> Self {
207        Self::new()
208    }
209}
210
211/// Point-in-time snapshot of API metrics.
212///
213/// Returned by [`ApiMetrics::get_metrics()`]. Use
214/// [`to_prometheus_format()`](MetricsSnapshot::to_prometheus_format)
215/// to export for Prometheus scraping.
216///
217/// # Example
218/// ```rust
219/// use auth_framework::api::metrics::ApiMetrics;
220///
221/// let snap = ApiMetrics::new().get_metrics();
222/// assert_eq!(snap.active_requests, 0);
223/// ```
224#[derive(Debug, Clone)]
225pub struct MetricsSnapshot {
226    pub uptime: Duration,
227    pub total_requests: u64,
228    pub active_requests: u64,
229    pub endpoint_metrics: HashMap<String, EndpointMetrics>,
230}
231
/// Statistics for a single path inside a [`MetricsSnapshot`].
///
/// # Example
/// ```rust
/// use auth_framework::api::metrics::{ApiMetrics, EndpointMetrics};
/// use std::time::Duration;
/// use axum::http::StatusCode;
///
/// let m = ApiMetrics::new();
/// m.record_request("/test");
/// m.record_response("/test", Duration::from_millis(50), StatusCode::OK);
///
/// let snap = m.get_metrics();
/// let ep: &EndpointMetrics = &snap.endpoint_metrics["/test"];
/// assert_eq!(ep.request_count, 1);
/// ```
#[derive(Clone, Debug)]
pub struct EndpointMetrics {
    /// Requests recorded for this path.
    pub request_count: u64,
    /// Responses with a 4xx or 5xx status recorded for this path.
    pub error_count: u64,
    /// `error_count / request_count`, or `0.0` when no request was recorded.
    pub error_rate: f64,
    /// Mean of the recorded response durations (`Duration::ZERO` if none).
    pub avg_response_time: Duration,
    /// 95th-percentile response duration (`Duration::ZERO` if none).
    pub p95_response_time: Duration,
    /// 99th-percentile response duration (`Duration::ZERO` if none).
    pub p99_response_time: Duration,
}
257
/// Pick the sample at the given `percentile` (0–100) of `durations`.
///
/// The input does not need to be sorted; a sorted copy is made internally.
/// The sample returned is the one at index
/// `round(percentile / 100 * (len - 1))` of that sorted copy, so the result
/// is always one of the inputs (no interpolation).
///
/// `percentile` is clamped to `0.0..=100.0`: values below 0 select the
/// smallest sample and values above 100 select the largest. A NaN
/// percentile selects the smallest sample. Returns `Duration::ZERO` when
/// `durations` is empty.
fn calculate_percentile(durations: &[Duration], percentile: f64) -> Duration {
    if durations.is_empty() {
        return Duration::ZERO;
    }

    let mut sorted = durations.to_vec();
    // Equal durations are indistinguishable, so stability is not needed.
    sorted.sort_unstable();

    let last = sorted.len() - 1;
    let fraction = percentile.clamp(0.0, 100.0) / 100.0;
    // Float-to-int `as` saturates (NaN becomes 0); `min` keeps the index in
    // range even if rounding were to overshoot.
    let index = ((fraction * last as f64).round() as usize).min(last);
    sorted[index]
}
270
271/// Axum middleware that records request and response metrics.
272///
273/// Attach via `axum::middleware::from_fn(metrics_middleware)`.
274///
275/// # Example
276/// ```rust,ignore
277/// use axum::{Router, middleware};
278/// use auth_framework::api::metrics::metrics_middleware;
279///
280/// let app = Router::new()
281///     .layer(middleware::from_fn(metrics_middleware));
282/// ```
283pub async fn metrics_middleware(request: Request, next: Next) -> Result<Response, StatusCode> {
284    let start_time = Instant::now();
285    let path = request.uri().path().to_string();
286
287    // Get metrics collector from extensions or create new one
288    let metrics = request
289        .extensions()
290        .get::<ApiMetrics>()
291        .cloned()
292        .unwrap_or_default();
293
294    metrics.record_request(&path);
295
296    let response = next.run(request).await;
297    let duration = start_time.elapsed();
298
299    metrics.record_response(&path, duration, response.status());
300
301    Ok(response)
302}
303
304impl MetricsSnapshot {
305    /// Render the snapshot in Prometheus text exposition format.
306    ///
307    /// # Example
308    /// ```rust
309    /// use auth_framework::api::metrics::ApiMetrics;
310    ///
311    /// let prom = ApiMetrics::new().get_metrics().to_prometheus_format();
312    /// assert!(prom.contains("auth_framework_uptime_seconds"));
313    /// ```
314    pub fn to_prometheus_format(&self) -> String {
315        let mut output = String::new();
316
317        // System metrics
318        output.push_str(&format!(
319            "# HELP auth_framework_uptime_seconds Total uptime in seconds\n\
320             # TYPE auth_framework_uptime_seconds counter\n\
321             auth_framework_uptime_seconds {}\n\n",
322            self.uptime.as_secs()
323        ));
324
325        output.push_str(&format!(
326            "# HELP auth_framework_requests_total Total number of requests\n\
327             # TYPE auth_framework_requests_total counter\n\
328             auth_framework_requests_total {}\n\n",
329            self.total_requests
330        ));
331
332        output.push_str(&format!(
333            "# HELP auth_framework_active_requests Current number of active requests\n\
334             # TYPE auth_framework_active_requests gauge\n\
335             auth_framework_active_requests {}\n\n",
336            self.active_requests
337        ));
338
339        // Endpoint metrics
340        for (endpoint, metrics) in &self.endpoint_metrics {
341            let _safe_endpoint = endpoint.replace(['/', '-'], "_");
342
343            output.push_str(&format!(
344                "auth_framework_endpoint_requests_total{{endpoint=\"{}\"}} {}\n",
345                endpoint, metrics.request_count
346            ));
347
348            output.push_str(&format!(
349                "auth_framework_endpoint_errors_total{{endpoint=\"{}\"}} {}\n",
350                endpoint, metrics.error_count
351            ));
352
353            output.push_str(&format!(
354                "auth_framework_endpoint_response_time_avg{{endpoint=\"{}\"}} {}\n",
355                endpoint,
356                metrics.avg_response_time.as_secs_f64()
357            ));
358
359            output.push_str(&format!(
360                "auth_framework_endpoint_response_time_p95{{endpoint=\"{}\"}} {}\n",
361                endpoint,
362                metrics.p95_response_time.as_secs_f64()
363            ));
364        }
365
366        output
367    }
368}
369
#[cfg(test)]
mod tests {
    use super::*;

    /// One successful request shows up in the totals and in its endpoint
    /// entry, with no error counted.
    #[test]
    fn test_metrics_collection() {
        let collector = ApiMetrics::new();
        let path = "/api/login";

        collector.record_request(path);
        collector.record_response(path, Duration::from_millis(100), StatusCode::OK);

        let snapshot = collector.get_metrics();
        assert_eq!(snapshot.total_requests, 1);
        let login = &snapshot.endpoint_metrics[path];
        assert_eq!(login.request_count, 1);
        assert_eq!(login.error_count, 0);
    }

    /// A 4xx response is counted as an error and yields a non-zero rate.
    #[test]
    fn test_error_tracking() {
        let collector = ApiMetrics::new();
        let path = "/api/test";

        collector.record_request(path);
        collector.record_response(path, Duration::from_millis(50), StatusCode::BAD_REQUEST);

        let snapshot = collector.get_metrics();
        let endpoint = &snapshot.endpoint_metrics[path];
        assert_eq!(endpoint.error_count, 1);
        assert!(endpoint.error_rate > 0.0);
    }

    /// With five samples the 95th percentile is the largest one.
    #[test]
    fn test_percentile_calculation() {
        let durations: Vec<Duration> = [10u64, 20, 30, 40, 100]
            .iter()
            .map(|&millis| Duration::from_millis(millis))
            .collect();

        let p95 = calculate_percentile(&durations, 95.0);
        assert_eq!(p95, Duration::from_millis(100));
    }

    /// The Prometheus rendering contains the system metrics and the
    /// endpoint label of a recorded path.
    #[test]
    fn test_prometheus_format() {
        let collector = ApiMetrics::new();
        collector.record_request("/api/test");
        collector.record_response("/api/test", Duration::from_millis(100), StatusCode::OK);

        let prometheus = collector.get_metrics().to_prometheus_format();

        for needle in [
            "auth_framework_requests_total",
            "auth_framework_active_requests",
            "endpoint=\"/api/test\"",
        ] {
            assert!(prometheus.contains(needle));
        }
    }
}