Skip to main content

crates_docs/metrics/
mod.rs

1//! Metrics module for Prometheus monitoring
2//!
3//! Provides metrics collection and export functionality for the MCP server.
4
5use prometheus_client::encoding::text::encode;
6use prometheus_client::metrics::counter::Counter;
7use prometheus_client::metrics::family::Family;
8use prometheus_client::metrics::gauge::Gauge;
9use prometheus_client::metrics::histogram::{exponential_buckets, Histogram};
10use prometheus_client::registry::Registry;
11use std::sync::atomic::AtomicU64;
12use std::sync::Arc;
13use std::time::Instant;
14
15/// Metrics labels for request tracking
16#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
17pub struct RequestLabels {
18    /// Tool name
19    pub tool: String,
20    /// Status: success or error
21    pub status: String,
22}
23
24/// Metrics labels for cache operations
25#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
26pub struct CacheLabels {
27    /// Operation type: get, set, hit, miss
28    pub operation: String,
29    /// Cache type: memory, redis
30    pub cache_type: String,
31}
32
33/// Metrics labels for HTTP requests
34#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
35pub struct HttpLabels {
36    /// HTTP method
37    pub method: String,
38    /// HTTP status code
39    pub status: String,
40    /// Target host
41    pub host: String,
42}
43
44/// Server metrics collection
45pub struct ServerMetrics {
46    /// Request counter
47    request_counter: Family<RequestLabels, Counter>,
48    /// Request duration histogram
49    request_duration: Family<RequestLabels, Histogram>,
50    /// Cache operation counter
51    cache_counter: Family<CacheLabels, Counter>,
52    /// Cache hits gauge
53    cache_hits: Gauge<u64, AtomicU64>,
54    /// Cache misses gauge
55    cache_misses: Gauge<u64, AtomicU64>,
56    /// Cache sets gauge
57    cache_sets: Gauge<u64, AtomicU64>,
58    /// Cache hit rate gauge
59    cache_hit_rate: Gauge<f64, AtomicU64>,
60    /// HTTP request counter
61    http_counter: Family<HttpLabels, Counter>,
62    /// HTTP request duration
63    http_duration: Family<HttpLabels, Histogram>,
64    /// Active connections gauge
65    active_connections: Gauge<u64, AtomicU64>,
66    /// Error counter
67    error_counter: Family<RequestLabels, Counter>,
68    /// Registry
69    registry: Arc<Registry>,
70}
71
72impl ServerMetrics {
73    /// Create a new metrics collection
74    #[must_use]
75    pub fn new() -> Self {
76        let mut registry = Registry::default();
77
78        // Request counter
79        let request_counter = Family::<RequestLabels, Counter>::default();
80        registry.register(
81            "mcp_requests_total",
82            "Total number of MCP tool requests",
83            request_counter.clone(),
84        );
85
86        // Request duration histogram (exponential buckets from 1ms to 30s)
87        let request_duration = Family::<RequestLabels, Histogram>::new_with_constructor(|| {
88            Histogram::new(exponential_buckets(0.001, 2.0, 15))
89        });
90        registry.register(
91            "mcp_request_duration_seconds",
92            "MCP tool request duration in seconds",
93            request_duration.clone(),
94        );
95
96        // Cache operation counter
97        let cache_counter = Family::<CacheLabels, Counter>::default();
98        registry.register(
99            "mcp_cache_operations_total",
100            "Total number of cache operations",
101            cache_counter.clone(),
102        );
103
104        // Cache hits gauge (count of cache hits)
105        let cache_hits = Gauge::default();
106        registry.register(
107            "mcp_cache_hits",
108            "Number of cache hits (gauge)",
109            cache_hits.clone(),
110        );
111
112        // Cache misses gauge (count of cache misses)
113        let cache_misses = Gauge::default();
114        registry.register(
115            "mcp_cache_misses",
116            "Number of cache misses (gauge)",
117            cache_misses.clone(),
118        );
119
120        // Cache sets gauge (count of cache set operations)
121        let cache_sets = Gauge::default();
122        registry.register(
123            "mcp_cache_sets",
124            "Number of cache set operations (gauge)",
125            cache_sets.clone(),
126        );
127
128        // Cache hit rate gauge
129        let cache_hit_rate = Gauge::default();
130        registry.register(
131            "mcp_cache_hit_rate",
132            "Cache hit rate (0.0 to 1.0)",
133            cache_hit_rate.clone(),
134        );
135
136        // HTTP request counter
137        let http_counter = Family::<HttpLabels, Counter>::default();
138        registry.register(
139            "mcp_http_requests_total",
140            "Total number of HTTP requests",
141            http_counter.clone(),
142        );
143
144        // HTTP request duration
145        let http_duration = Family::<HttpLabels, Histogram>::new_with_constructor(|| {
146            Histogram::new(exponential_buckets(0.001, 2.0, 15))
147        });
148        registry.register(
149            "mcp_http_request_duration_seconds",
150            "HTTP request duration in seconds",
151            http_duration.clone(),
152        );
153
154        // Active connections gauge
155        let active_connections = Gauge::<u64, AtomicU64>::default();
156        registry.register(
157            "mcp_active_connections",
158            "Number of active connections",
159            active_connections.clone(),
160        );
161
162        // Error counter
163        let error_counter = Family::<RequestLabels, Counter>::default();
164        registry.register(
165            "mcp_errors_total",
166            "Total number of errors",
167            error_counter.clone(),
168        );
169
170        Self {
171            request_counter,
172            request_duration,
173            cache_counter,
174            cache_hits,
175            cache_misses,
176            cache_sets,
177            cache_hit_rate,
178            http_counter,
179            http_duration,
180            active_connections,
181            error_counter,
182            registry: Arc::new(registry),
183        }
184    }
185
186    /// Record a tool request
187    pub fn record_request(&self, tool: &str, success: bool, duration: std::time::Duration) {
188        let labels = RequestLabels {
189            tool: tool.to_string(),
190            status: if success {
191                "success".to_string()
192            } else {
193                "error".to_string()
194            },
195        };
196
197        self.request_counter.get_or_create(&labels).inc();
198        self.request_duration
199            .get_or_create(&labels)
200            .observe(duration.as_secs_f64());
201
202        if !success {
203            self.error_counter.get_or_create(&labels).inc();
204        }
205    }
206
207    /// Record a cache operation
208    pub fn record_cache_operation(&self, operation: &str, cache_type: &str) {
209        let labels = CacheLabels {
210            operation: operation.to_string(),
211            cache_type: cache_type.to_string(),
212        };
213        self.cache_counter.get_or_create(&labels).inc();
214    }
215
216    /// Record a cache hit
217    pub fn record_cache_hit(&self, cache_type: &str) {
218        self.record_cache_operation("hit", cache_type);
219    }
220
221    /// Record a cache miss
222    pub fn record_cache_miss(&self, cache_type: &str) {
223        self.record_cache_operation("miss", cache_type);
224    }
225
226    /// Update cache hit rate
227    #[allow(clippy::cast_precision_loss)]
228    pub fn update_cache_hit_rate(&self, hits: u64, misses: u64) {
229        let total = hits + misses;
230        if total > 0 {
231            let rate = hits as f64 / total as f64;
232            self.cache_hit_rate.set(rate);
233        }
234    }
235
236    /// Update cache statistics gauges from provided counts
237    ///
238    /// # Arguments
239    ///
240    /// * `hits` - Total cache hits
241    /// * `misses` - Total cache misses
242    /// * `sets` - Total cache set operations
243    ///
244    /// # Note
245    ///
246    /// This method also updates the cache hit rate automatically.
247    pub fn update_cache_stats(&self, hits: u64, misses: u64, sets: u64) {
248        self.cache_hits.set(hits);
249        self.cache_misses.set(misses);
250        self.cache_sets.set(sets);
251        self.update_cache_hit_rate(hits, misses);
252    }
253
254    /// Record an HTTP request
255    pub fn record_http_request(
256        &self,
257        method: &str,
258        status: u16,
259        host: &str,
260        duration: std::time::Duration,
261    ) {
262        let labels = HttpLabels {
263            method: method.to_string(),
264            status: status.to_string(),
265            host: host.to_string(),
266        };
267
268        self.http_counter.get_or_create(&labels).inc();
269        self.http_duration
270            .get_or_create(&labels)
271            .observe(duration.as_secs_f64());
272    }
273
274    /// Increment active connections
275    pub fn inc_active_connections(&self) {
276        self.active_connections.inc();
277    }
278
279    /// Decrement active connections
280    pub fn dec_active_connections(&self) {
281        self.active_connections.dec();
282    }
283
284    /// Export metrics as Prometheus text format
285    ///
286    /// # Errors
287    ///
288    /// Returns an error if encoding fails
289    pub fn export(&self) -> crate::error::Result<String> {
290        let mut output = String::new();
291        encode(&mut output, self.registry.as_ref())
292            .map_err(|e| crate::error::Error::Other(format!("Failed to encode metrics: {e}")))?;
293        Ok(output)
294    }
295
296    /// Get the registry
297    #[must_use]
298    pub fn registry(&self) -> &Arc<Registry> {
299        &self.registry
300    }
301}
302
303impl Default for ServerMetrics {
304    fn default() -> Self {
305        Self::new()
306    }
307}
308
309/// Request timer for tracking request duration
310pub struct RequestTimer {
311    start: Instant,
312    tool: String,
313    metrics: Option<Arc<ServerMetrics>>,
314}
315
316impl RequestTimer {
317    /// Create a new request timer
318    #[must_use]
319    pub fn new(tool: &str, metrics: Option<Arc<ServerMetrics>>) -> Self {
320        Self {
321            start: Instant::now(),
322            tool: tool.to_string(),
323            metrics,
324        }
325    }
326
327    /// Record successful completion
328    pub fn success(self) {
329        self.record(true);
330    }
331
332    /// Record failed completion
333    pub fn failure(self) {
334        self.record(false);
335    }
336
337    fn record(self, success: bool) {
338        if let Some(metrics) = self.metrics {
339            metrics.record_request(&self.tool, success, self.start.elapsed());
340        }
341    }
342}
343
344/// HTTP request timer
345pub struct HttpRequestTimer {
346    start: Instant,
347    method: String,
348    host: String,
349    metrics: Option<Arc<ServerMetrics>>,
350}
351
352impl HttpRequestTimer {
353    /// Create a new HTTP request timer
354    #[must_use]
355    pub fn new(method: &str, host: &str, metrics: Option<Arc<ServerMetrics>>) -> Self {
356        Self {
357            start: Instant::now(),
358            method: method.to_string(),
359            host: host.to_string(),
360            metrics,
361        }
362    }
363
364    /// Record request completion with status code
365    pub fn finish(self, status: u16) {
366        if let Some(metrics) = self.metrics {
367            metrics.record_http_request(&self.method, status, &self.host, self.start.elapsed());
368        }
369    }
370}
371
372use std::sync::OnceLock;
373
374/// Global metrics instance (optional, for simple use cases)
375static GLOBAL_METRICS: OnceLock<Arc<ServerMetrics>> = OnceLock::new();
376
377/// Initialize global metrics
378pub fn init_global_metrics() {
379    let _ = GLOBAL_METRICS.set(Arc::new(ServerMetrics::new()));
380}
381
382/// Get global metrics
383///
384/// # Panics
385///
386/// Panics if global metrics have not been initialized
387#[must_use]
388pub fn global_metrics() -> Arc<ServerMetrics> {
389    GLOBAL_METRICS
390        .get()
391        .cloned()
392        .expect("Global metrics not initialized")
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Export `metrics` as text, failing the test if encoding fails.
    fn exported(metrics: &ServerMetrics) -> String {
        metrics.export().expect("metrics should encode")
    }

    /// A freshly created collection exports a non-empty document.
    #[test]
    fn test_metrics_creation() {
        let metrics = ServerMetrics::new();
        assert!(!exported(&metrics).is_empty());
    }

    /// Recorded tool requests show up under the request counter.
    #[test]
    fn test_request_recording() {
        let metrics = ServerMetrics::new();

        // One successful and one failed request for the same tool.
        metrics.record_request("test_tool", true, Duration::from_millis(100));
        metrics.record_request("test_tool", false, Duration::from_millis(200));

        let text = exported(&metrics);
        assert!(text.contains("mcp_requests_total"));
        assert!(text.contains("test_tool"));
    }

    /// Cache hits, misses and hit-rate updates are exported.
    #[test]
    fn test_cache_metrics() {
        let metrics = ServerMetrics::new();

        metrics.record_cache_hit("memory");
        metrics.record_cache_miss("memory");
        metrics.update_cache_hit_rate(1, 1);

        assert!(exported(&metrics).contains("mcp_cache_operations_total"));
    }

    /// Recorded HTTP requests show up under the HTTP counter.
    #[test]
    fn test_http_metrics() {
        let metrics = ServerMetrics::new();

        metrics.record_http_request("GET", 200, "docs.rs", Duration::from_millis(500));

        assert!(exported(&metrics).contains("mcp_http_requests_total"));
    }

    /// A completed request timer reports into the shared metrics.
    #[test]
    fn test_request_timer() {
        let metrics = Arc::new(ServerMetrics::new());
        RequestTimer::new("test_tool", Some(Arc::clone(&metrics))).success();

        assert!(exported(&metrics).contains("mcp_requests_total"));
    }

    /// Connection increments and decrements are exported.
    #[test]
    fn test_active_connections() {
        let metrics = ServerMetrics::new();

        metrics.inc_active_connections();
        metrics.inc_active_connections();
        metrics.dec_active_connections();

        assert!(exported(&metrics).contains("mcp_active_connections"));
    }
}