// ferro_rs/metrics/mod.rs

1//! Request metrics collection for performance monitoring
2//!
3//! Collects request counts, response times, and error rates per route.
4//! Metrics are stored in-memory and exposed via `/_ferro/metrics`.
5
6use serde::Serialize;
7use std::collections::HashMap;
8use std::sync::{OnceLock, RwLock};
9use std::time::Duration;
10
/// Hard upper bound on the number of distinct entries kept in the metrics store.
///
/// Route normalization covers the usual source of growth (unmatched routes
/// are folded into "UNMATCHED"); this limit is the backstop for any other
/// scenario that would otherwise let the store grow without bound.
const MAX_ROUTE_ENTRIES: usize = 1_000;
15
/// Request metrics for a single route
///
/// One accumulator per method/route pair; the counters only ever grow as
/// requests are recorded.
#[derive(Debug, Clone, Serialize)]
pub struct RouteMetrics {
    /// Route pattern (e.g., "/users/{id}")
    pub route: String,
    /// HTTP method
    pub method: String,
    /// Total request count
    pub count: u64,
    /// Total duration in milliseconds (for average calculation)
    pub total_duration_ms: u64,
    /// Number of error responses (4xx and 5xx)
    pub error_count: u64,
    /// Minimum response time in ms.
    ///
    /// Initialized to `u64::MAX` as a "no samples yet" sentinel, so the first
    /// recorded duration always becomes the minimum.
    pub min_duration_ms: u64,
    /// Maximum response time in ms (0 until a request is recorded)
    pub max_duration_ms: u64,
}
34
35impl RouteMetrics {
36    fn new(route: String, method: String) -> Self {
37        Self {
38            route,
39            method,
40            count: 0,
41            total_duration_ms: 0,
42            error_count: 0,
43            min_duration_ms: u64::MAX,
44            max_duration_ms: 0,
45        }
46    }
47
48    /// Calculate average response time in milliseconds
49    pub fn avg_duration_ms(&self) -> f64 {
50        if self.count == 0 {
51            0.0
52        } else {
53            self.total_duration_ms as f64 / self.count as f64
54        }
55    }
56}
57
/// Aggregated metrics response
///
/// Point-in-time copy of the store's contents, produced by `get_metrics`.
#[derive(Debug, Serialize)]
pub struct MetricsSnapshot {
    /// Per-route metrics
    pub routes: Vec<RouteMetricsView>,
    /// Total requests across all routes
    pub total_requests: u64,
    /// Total errors across all routes
    pub total_errors: u64,
    /// Uptime since metrics collection started (whole seconds).
    ///
    /// Measured from the store's start time, which is set when the store is
    /// created and set again by `reset_metrics`.
    pub uptime_seconds: u64,
}
70
/// View of route metrics for serialization (includes computed avg)
///
/// Built from a `RouteMetrics` accumulator by `get_metrics`; the average and
/// error rate are derived values rather than stored ones.
#[derive(Debug, Serialize)]
pub struct RouteMetricsView {
    /// Route pattern (e.g., "/users/{id}").
    pub route: String,
    /// HTTP method (e.g., "GET", "POST").
    pub method: String,
    /// Total number of requests for this route.
    pub count: u64,
    /// Average response time in milliseconds (`0.0` when `count` is zero).
    pub avg_duration_ms: f64,
    /// Minimum response time in ms, or `None` if no requests recorded
    /// (i.e. the accumulator still holds its `u64::MAX` sentinel).
    pub min_duration_ms: Option<u64>,
    /// Maximum response time in milliseconds.
    pub max_duration_ms: u64,
    /// Number of error responses (4xx and 5xx).
    pub error_count: u64,
    /// Ratio of error responses to total requests (0.0–1.0); `0.0` when
    /// `count` is zero.
    pub error_rate: f64,
}
91
92/// Global metrics storage
93static METRICS: OnceLock<RwLock<MetricsStore>> = OnceLock::new();
94
95struct MetricsStore {
96    routes: HashMap<String, RouteMetrics>,
97    start_time: std::time::Instant,
98}
99
100impl MetricsStore {
101    fn new() -> Self {
102        Self {
103            routes: HashMap::new(),
104            start_time: std::time::Instant::now(),
105        }
106    }
107}
108
109fn get_store() -> &'static RwLock<MetricsStore> {
110    METRICS.get_or_init(|| RwLock::new(MetricsStore::new()))
111}
112
/// Build the map key under which a method/route pair is stored.
///
/// The key is the method, a colon, then the route — e.g. `"GET:/users"`.
fn route_key(method: &str, route: &str) -> String {
    let mut key = String::with_capacity(method.len() + 1 + route.len());
    key.push_str(method);
    key.push(':');
    key.push_str(route);
    key
}
117
118/// Record a request completion
119///
120/// # Arguments
121/// * `route` - Route pattern (e.g., "/users/{id}")
122/// * `method` - HTTP method
123/// * `duration` - Request duration
124/// * `is_error` - Whether response was an error (4xx or 5xx)
125pub fn record_request(route: &str, method: &str, duration: Duration, is_error: bool) {
126    let key = route_key(method, route);
127    let duration_ms = duration.as_millis() as u64;
128
129    if let Ok(mut store) = get_store().write() {
130        // Safety cap: if at capacity and this is a new key, skip recording
131        if store.routes.len() >= MAX_ROUTE_ENTRIES && !store.routes.contains_key(&key) {
132            return;
133        }
134
135        let metrics = store
136            .routes
137            .entry(key)
138            .or_insert_with(|| RouteMetrics::new(route.to_string(), method.to_string()));
139
140        metrics.count += 1;
141        metrics.total_duration_ms += duration_ms;
142
143        if duration_ms < metrics.min_duration_ms {
144            metrics.min_duration_ms = duration_ms;
145        }
146        if duration_ms > metrics.max_duration_ms {
147            metrics.max_duration_ms = duration_ms;
148        }
149
150        if is_error {
151            metrics.error_count += 1;
152        }
153    }
154}
155
156/// Get current metrics snapshot
157pub fn get_metrics() -> MetricsSnapshot {
158    let store = get_store().read().unwrap();
159
160    let mut total_requests = 0u64;
161    let mut total_errors = 0u64;
162
163    let routes: Vec<RouteMetricsView> = store
164        .routes
165        .values()
166        .map(|m| {
167            total_requests += m.count;
168            total_errors += m.error_count;
169
170            RouteMetricsView {
171                route: m.route.clone(),
172                method: m.method.clone(),
173                count: m.count,
174                avg_duration_ms: m.avg_duration_ms(),
175                min_duration_ms: if m.min_duration_ms == u64::MAX {
176                    None
177                } else {
178                    Some(m.min_duration_ms)
179                },
180                max_duration_ms: m.max_duration_ms,
181                error_count: m.error_count,
182                error_rate: if m.count == 0 {
183                    0.0
184                } else {
185                    m.error_count as f64 / m.count as f64
186                },
187            }
188        })
189        .collect();
190
191    MetricsSnapshot {
192        routes,
193        total_requests,
194        total_errors,
195        uptime_seconds: store.start_time.elapsed().as_secs(),
196    }
197}
198
199/// Reset all metrics (useful for testing)
200pub fn reset_metrics() {
201    if let Ok(mut store) = get_store().write() {
202        store.routes.clear();
203        store.start_time = std::time::Instant::now();
204    }
205}
206
/// Check if metrics collection is enabled
///
/// Reads the `FERRO_COLLECT_METRICS` environment variable: exactly `"true"`
/// or `"1"` enables collection, and any other value disables it. When the
/// variable is unset (or cannot be read as Unicode) collection is enabled.
pub fn is_enabled() -> bool {
    match std::env::var("FERRO_COLLECT_METRICS") {
        Ok(value) => matches!(value.as_str(), "true" | "1"),
        // Enabled by default in development
        Err(_) => true,
    }
}
213
#[cfg(test)]
mod tests {
    //! Tests for the metrics store.
    //!
    //! Every test that touches the global store is marked `#[serial]` and
    //! starts from `setup()`, because the store is shared process-wide state.

    use super::*;
    use serial_test::serial;
    use std::time::Duration;

    /// Start each test from an empty store with a fresh uptime clock.
    fn setup() {
        reset_metrics();
    }

    #[test]
    #[serial]
    fn test_record_request_increments_count() {
        setup();

        record_request("/users", "GET", Duration::from_millis(10), false);
        record_request("/users", "GET", Duration::from_millis(20), false);

        let snapshot = get_metrics();
        let route = snapshot
            .routes
            .iter()
            .find(|r| r.route == "/users")
            .unwrap();

        assert_eq!(route.count, 2);
        assert_eq!(snapshot.total_requests, 2);
    }

    #[test]
    #[serial]
    fn test_record_request_tracks_duration() {
        setup();

        record_request("/api/test", "POST", Duration::from_millis(10), false);
        record_request("/api/test", "POST", Duration::from_millis(30), false);
        record_request("/api/test", "POST", Duration::from_millis(20), false);

        let snapshot = get_metrics();
        let route = snapshot
            .routes
            .iter()
            .find(|r| r.route == "/api/test")
            .unwrap();

        assert_eq!(route.min_duration_ms, Some(10));
        assert_eq!(route.max_duration_ms, 30);
        assert!((route.avg_duration_ms - 20.0).abs() < 0.01);
    }

    #[test]
    #[serial]
    fn test_record_request_counts_errors() {
        setup();

        record_request("/error", "GET", Duration::from_millis(5), false);
        record_request("/error", "GET", Duration::from_millis(5), true);
        record_request("/error", "GET", Duration::from_millis(5), true);

        let snapshot = get_metrics();
        let route = snapshot
            .routes
            .iter()
            .find(|r| r.route == "/error")
            .unwrap();

        assert_eq!(route.count, 3);
        assert_eq!(route.error_count, 2);
        assert!((route.error_rate - 2.0 / 3.0).abs() < 0.01);
        assert_eq!(snapshot.total_errors, 2);
    }

    #[test]
    #[serial]
    fn test_different_methods_tracked_separately() {
        setup();

        record_request("/resource", "GET", Duration::from_millis(10), false);
        record_request("/resource", "POST", Duration::from_millis(20), false);
        record_request("/resource", "GET", Duration::from_millis(15), false);

        let snapshot = get_metrics();

        let get_route = snapshot
            .routes
            .iter()
            .find(|r| r.route == "/resource" && r.method == "GET")
            .unwrap();
        let post_route = snapshot
            .routes
            .iter()
            .find(|r| r.route == "/resource" && r.method == "POST")
            .unwrap();

        assert_eq!(get_route.count, 2);
        assert_eq!(post_route.count, 1);
    }

    #[test]
    fn test_route_metrics_avg_duration_zero_count() {
        let metrics = RouteMetrics::new("/test".to_string(), "GET".to_string());
        assert_eq!(metrics.avg_duration_ms(), 0.0);
    }

    #[test]
    #[serial]
    fn test_min_duration_none_when_no_requests() {
        setup();

        // `record_request` always adds a sample, so an entry with zero
        // requests has to be seeded directly into the store.
        get_store().write().unwrap().routes.insert(
            route_key("GET", "/idle"),
            RouteMetrics::new("/idle".to_string(), "GET".to_string()),
        );

        let snapshot = get_metrics();
        let route = snapshot
            .routes
            .iter()
            .find(|r| r.route == "/idle")
            .unwrap();

        // The `u64::MAX` sentinel must surface as `None`, not as a number.
        assert_eq!(route.count, 0);
        assert_eq!(route.min_duration_ms, None);
        assert_eq!(route.avg_duration_ms, 0.0);
        assert_eq!(route.error_rate, 0.0);
    }

    #[test]
    #[serial]
    fn test_reset_metrics_clears_data() {
        setup();

        record_request("/clear-test", "GET", Duration::from_millis(10), false);

        let snapshot = get_metrics();
        assert!(!snapshot.routes.is_empty());

        reset_metrics();

        let snapshot = get_metrics();
        assert!(snapshot.routes.is_empty());
        assert_eq!(snapshot.total_requests, 0);
    }

    #[test]
    #[serial]
    fn test_uptime_tracking() {
        setup();

        // Rewind the start time instead of sleeping: uptime is reported in
        // whole seconds, so a short sleep could never change it.
        // `checked_sub` can return `None` if the platform clock cannot
        // represent the earlier instant; the check is skipped in that case.
        let rewound = std::time::Instant::now().checked_sub(Duration::from_secs(120));
        if let Some(earlier) = rewound {
            get_store().write().unwrap().start_time = earlier;
            assert!(get_metrics().uptime_seconds >= 120);
        }

        // Resetting restarts the uptime clock.
        reset_metrics();
        assert!(get_metrics().uptime_seconds < 120);
    }

    #[test]
    fn test_route_key_format() {
        assert_eq!(route_key("GET", "/users"), "GET:/users");
        assert_eq!(route_key("POST", "/api/v1/items"), "POST:/api/v1/items");
    }

    #[test]
    #[serial]
    fn test_unmatched_routes_use_fixed_bucket() {
        setup();

        // Simulate what the middleware now does: all unmatched routes → "UNMATCHED"
        record_request("UNMATCHED", "GET", Duration::from_millis(5), true);
        record_request("UNMATCHED", "GET", Duration::from_millis(10), true);
        record_request("UNMATCHED", "POST", Duration::from_millis(8), true);

        let snapshot = get_metrics();

        // GET:UNMATCHED and POST:UNMATCHED — only 2 entries, not N unique paths
        let unmatched_routes: Vec<_> = snapshot
            .routes
            .iter()
            .filter(|r| r.route == "UNMATCHED")
            .collect();
        assert_eq!(unmatched_routes.len(), 2);

        let get_unmatched = unmatched_routes.iter().find(|r| r.method == "GET").unwrap();
        assert_eq!(get_unmatched.count, 2);
    }

    #[test]
    #[serial]
    fn test_entry_cap_prevents_unbounded_growth() {
        setup();

        // Try to record more unique routes than the cap allows
        for i in 0..MAX_ROUTE_ENTRIES + 100 {
            let route = format!("/route/{i}");
            record_request(&route, "GET", Duration::from_millis(1), false);
        }

        // Exactly the first MAX_ROUTE_ENTRIES routes are kept; the rest are dropped.
        let snapshot = get_metrics();
        assert_eq!(snapshot.routes.len(), MAX_ROUTE_ENTRIES);
    }

    #[test]
    #[serial]
    fn test_existing_entries_updated_after_cap() {
        setup();

        // Fill exactly to MAX_ROUTE_ENTRIES
        for i in 0..MAX_ROUTE_ENTRIES {
            let route = format!("/route/{i}");
            record_request(&route, "GET", Duration::from_millis(1), false);
        }

        let snapshot = get_metrics();
        assert_eq!(snapshot.routes.len(), MAX_ROUTE_ENTRIES);

        // Record another request for an existing route — should still be tracked
        record_request("/route/0", "GET", Duration::from_millis(50), false);

        let snapshot = get_metrics();
        let route = snapshot
            .routes
            .iter()
            .find(|r| r.route == "/route/0" && r.method == "GET")
            .unwrap();
        assert_eq!(route.count, 2); // Original + new request
    }
}