Skip to main content

crates_docs/metrics/
mod.rs

1//! Metrics module for Prometheus monitoring
2//!
3//! Provides metrics collection and export functionality for the MCP server.
4
5use prometheus_client::encoding::text::encode;
6use prometheus_client::metrics::counter::Counter;
7use prometheus_client::metrics::family::Family;
8use prometheus_client::metrics::gauge::Gauge;
9use prometheus_client::metrics::histogram::{exponential_buckets, Histogram};
10use prometheus_client::registry::Registry;
11use std::sync::atomic::AtomicU64;
12use std::sync::Arc;
13use std::time::Instant;
14
15/// Metrics labels for request tracking
16#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
17pub struct RequestLabels {
18    /// Tool name
19    pub tool: String,
20    /// Status: success or error
21    pub status: String,
22}
23
24/// Metrics labels for cache operations
25#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
26pub struct CacheLabels {
27    /// Operation type: get, set, hit, miss
28    pub operation: String,
29    /// Cache type: memory, redis
30    pub cache_type: String,
31}
32
33/// Metrics labels for HTTP requests
34#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
35pub struct HttpLabels {
36    /// HTTP method
37    pub method: String,
38    /// HTTP status code
39    pub status: String,
40    /// Target host
41    pub host: String,
42}
43
44/// Server metrics collection
45pub struct ServerMetrics {
46    /// Request counter
47    request_counter: Family<RequestLabels, Counter>,
48    /// Request duration histogram
49    request_duration: Family<RequestLabels, Histogram>,
50    /// Cache operation counter
51    cache_counter: Family<CacheLabels, Counter>,
52    /// Cache hit rate gauge
53    cache_hit_rate: Gauge<f64, AtomicU64>,
54    /// HTTP request counter
55    http_counter: Family<HttpLabels, Counter>,
56    /// HTTP request duration
57    http_duration: Family<HttpLabels, Histogram>,
58    /// Active connections gauge
59    active_connections: Gauge<u64, AtomicU64>,
60    /// Error counter
61    error_counter: Family<RequestLabels, Counter>,
62    /// Registry
63    registry: Arc<Registry>,
64}
65
66impl ServerMetrics {
67    /// Create a new metrics collection
68    #[must_use]
69    pub fn new() -> Self {
70        let mut registry = Registry::default();
71
72        // Request counter
73        let request_counter = Family::<RequestLabels, Counter>::default();
74        registry.register(
75            "mcp_requests_total",
76            "Total number of MCP tool requests",
77            request_counter.clone(),
78        );
79
80        // Request duration histogram (exponential buckets from 1ms to 30s)
81        let request_duration = Family::<RequestLabels, Histogram>::new_with_constructor(|| {
82            Histogram::new(exponential_buckets(0.001, 2.0, 15))
83        });
84        registry.register(
85            "mcp_request_duration_seconds",
86            "MCP tool request duration in seconds",
87            request_duration.clone(),
88        );
89
90        // Cache operation counter
91        let cache_counter = Family::<CacheLabels, Counter>::default();
92        registry.register(
93            "mcp_cache_operations_total",
94            "Total number of cache operations",
95            cache_counter.clone(),
96        );
97
98        // Cache hit rate gauge
99        let cache_hit_rate = Gauge::default();
100        registry.register(
101            "mcp_cache_hit_rate",
102            "Cache hit rate (0.0 to 1.0)",
103            cache_hit_rate.clone(),
104        );
105
106        // HTTP request counter
107        let http_counter = Family::<HttpLabels, Counter>::default();
108        registry.register(
109            "mcp_http_requests_total",
110            "Total number of HTTP requests",
111            http_counter.clone(),
112        );
113
114        // HTTP request duration
115        let http_duration = Family::<HttpLabels, Histogram>::new_with_constructor(|| {
116            Histogram::new(exponential_buckets(0.001, 2.0, 15))
117        });
118        registry.register(
119            "mcp_http_request_duration_seconds",
120            "HTTP request duration in seconds",
121            http_duration.clone(),
122        );
123
124        // Active connections gauge
125        let active_connections = Gauge::<u64, AtomicU64>::default();
126        registry.register(
127            "mcp_active_connections",
128            "Number of active connections",
129            active_connections.clone(),
130        );
131
132        // Error counter
133        let error_counter = Family::<RequestLabels, Counter>::default();
134        registry.register(
135            "mcp_errors_total",
136            "Total number of errors",
137            error_counter.clone(),
138        );
139
140        Self {
141            request_counter,
142            request_duration,
143            cache_counter,
144            cache_hit_rate,
145            http_counter,
146            http_duration,
147            active_connections,
148            error_counter,
149            registry: Arc::new(registry),
150        }
151    }
152
153    /// Record a tool request
154    pub fn record_request(&self, tool: &str, success: bool, duration: std::time::Duration) {
155        let labels = RequestLabels {
156            tool: tool.to_string(),
157            status: if success {
158                "success".to_string()
159            } else {
160                "error".to_string()
161            },
162        };
163
164        self.request_counter.get_or_create(&labels).inc();
165        self.request_duration
166            .get_or_create(&labels)
167            .observe(duration.as_secs_f64());
168
169        if !success {
170            self.error_counter.get_or_create(&labels).inc();
171        }
172    }
173
174    /// Record a cache operation
175    pub fn record_cache_operation(&self, operation: &str, cache_type: &str) {
176        let labels = CacheLabels {
177            operation: operation.to_string(),
178            cache_type: cache_type.to_string(),
179        };
180        self.cache_counter.get_or_create(&labels).inc();
181    }
182
183    /// Record a cache hit
184    pub fn record_cache_hit(&self, cache_type: &str) {
185        self.record_cache_operation("hit", cache_type);
186    }
187
188    /// Record a cache miss
189    pub fn record_cache_miss(&self, cache_type: &str) {
190        self.record_cache_operation("miss", cache_type);
191    }
192
193    /// Update cache hit rate
194    #[allow(clippy::cast_precision_loss)]
195    pub fn update_cache_hit_rate(&self, hits: u64, misses: u64) {
196        let total = hits + misses;
197        if total > 0 {
198            let rate = hits as f64 / total as f64;
199            self.cache_hit_rate.set(rate);
200        }
201    }
202
203    /// Record an HTTP request
204    pub fn record_http_request(
205        &self,
206        method: &str,
207        status: u16,
208        host: &str,
209        duration: std::time::Duration,
210    ) {
211        let labels = HttpLabels {
212            method: method.to_string(),
213            status: status.to_string(),
214            host: host.to_string(),
215        };
216
217        self.http_counter.get_or_create(&labels).inc();
218        self.http_duration
219            .get_or_create(&labels)
220            .observe(duration.as_secs_f64());
221    }
222
223    /// Increment active connections
224    pub fn inc_active_connections(&self) {
225        self.active_connections.inc();
226    }
227
228    /// Decrement active connections
229    pub fn dec_active_connections(&self) {
230        self.active_connections.dec();
231    }
232
233    /// Export metrics as Prometheus text format
234    ///
235    /// # Errors
236    ///
237    /// Returns an error if encoding fails
238    pub fn export(&self) -> crate::error::Result<String> {
239        let mut output = String::new();
240        encode(&mut output, self.registry.as_ref())
241            .map_err(|e| crate::error::Error::Other(format!("Failed to encode metrics: {e}")))?;
242        Ok(output)
243    }
244
245    /// Get the registry
246    #[must_use]
247    pub fn registry(&self) -> &Arc<Registry> {
248        &self.registry
249    }
250}
251
252impl Default for ServerMetrics {
253    fn default() -> Self {
254        Self::new()
255    }
256}
257
258/// Request timer for tracking request duration
259pub struct RequestTimer {
260    start: Instant,
261    tool: String,
262    metrics: Option<Arc<ServerMetrics>>,
263}
264
265impl RequestTimer {
266    /// Create a new request timer
267    #[must_use]
268    pub fn new(tool: &str, metrics: Option<Arc<ServerMetrics>>) -> Self {
269        Self {
270            start: Instant::now(),
271            tool: tool.to_string(),
272            metrics,
273        }
274    }
275
276    /// Record successful completion
277    pub fn success(self) {
278        self.record(true);
279    }
280
281    /// Record failed completion
282    pub fn failure(self) {
283        self.record(false);
284    }
285
286    fn record(self, success: bool) {
287        if let Some(metrics) = self.metrics {
288            metrics.record_request(&self.tool, success, self.start.elapsed());
289        }
290    }
291}
292
293/// HTTP request timer
294pub struct HttpRequestTimer {
295    start: Instant,
296    method: String,
297    host: String,
298    metrics: Option<Arc<ServerMetrics>>,
299}
300
301impl HttpRequestTimer {
302    /// Create a new HTTP request timer
303    #[must_use]
304    pub fn new(method: &str, host: &str, metrics: Option<Arc<ServerMetrics>>) -> Self {
305        Self {
306            start: Instant::now(),
307            method: method.to_string(),
308            host: host.to_string(),
309            metrics,
310        }
311    }
312
313    /// Record request completion with status code
314    pub fn finish(self, status: u16) {
315        if let Some(metrics) = self.metrics {
316            metrics.record_http_request(&self.method, status, &self.host, self.start.elapsed());
317        }
318    }
319}
320
321use std::sync::OnceLock;
322
323/// Global metrics instance (optional, for simple use cases)
324static GLOBAL_METRICS: OnceLock<Arc<ServerMetrics>> = OnceLock::new();
325
326/// Initialize global metrics
327pub fn init_global_metrics() {
328    let _ = GLOBAL_METRICS.set(Arc::new(ServerMetrics::new()));
329}
330
331/// Get global metrics
332///
333/// # Panics
334///
335/// Panics if global metrics have not been initialized
336#[must_use]
337pub fn global_metrics() -> Arc<ServerMetrics> {
338    GLOBAL_METRICS
339        .get()
340        .cloned()
341        .expect("Global metrics not initialized")
342}
343
#[cfg(test)]
mod tests {
    //! Unit tests for the metrics module.
    //!
    //! A registered metric's name shows up in the exported descriptor lines
    //! even when nothing was recorded, so these tests assert on label values
    //! and sample values rather than on metric names alone.

    use super::*;
    use std::time::Duration;

    #[test]
    fn test_metrics_creation() {
        let metrics = ServerMetrics::new();
        let output = metrics.export();
        assert!(output.is_ok());
        assert!(!output.unwrap().is_empty());
    }

    #[test]
    fn test_request_recording() {
        let metrics = ServerMetrics::new();

        // One successful and one failed request for the same tool.
        metrics.record_request("test_tool", true, Duration::from_millis(100));
        metrics.record_request("test_tool", false, Duration::from_millis(200));

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_requests_total"));
        assert!(output.contains("mcp_errors_total"));
        assert!(output.contains("tool=\"test_tool\""));
        assert!(output.contains("status=\"success\""));
        assert!(output.contains("status=\"error\""));
    }

    #[test]
    fn test_cache_metrics() {
        let metrics = ServerMetrics::new();

        metrics.record_cache_hit("memory");
        metrics.record_cache_miss("memory");
        metrics.update_cache_hit_rate(1, 1);

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_cache_operations_total"));
        assert!(output.contains("operation=\"hit\""));
        assert!(output.contains("operation=\"miss\""));
        assert!(output.contains("cache_type=\"memory\""));
        assert!(output.contains("mcp_cache_hit_rate 0.5"));
    }

    #[test]
    fn test_cache_hit_rate_without_samples_does_not_panic() {
        let metrics = ServerMetrics::new();

        // No hits and no misses: there is no ratio to compute.
        metrics.update_cache_hit_rate(0, 0);

        assert!(metrics.export().is_ok());
    }

    #[test]
    fn test_http_metrics() {
        let metrics = ServerMetrics::new();

        metrics.record_http_request("GET", 200, "docs.rs", Duration::from_millis(500));

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_http_requests_total"));
        assert!(output.contains("method=\"GET\""));
        assert!(output.contains("status=\"200\""));
        assert!(output.contains("host=\"docs.rs\""));
    }

    #[test]
    fn test_request_timer() {
        let metrics = Arc::new(ServerMetrics::new());
        let timer = RequestTimer::new("test_tool", Some(metrics.clone()));
        timer.success();

        // The timer must have produced a labelled sample, not just a name.
        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_requests_total"));
        assert!(output.contains("tool=\"test_tool\""));
        assert!(output.contains("status=\"success\""));
    }

    #[test]
    fn test_request_timer_failure() {
        let metrics = Arc::new(ServerMetrics::new());
        RequestTimer::new("failing_tool", Some(metrics.clone())).failure();

        let output = metrics.export().unwrap();
        assert!(output.contains("tool=\"failing_tool\""));
        assert!(output.contains("status=\"error\""));
    }

    #[test]
    fn test_timers_without_metrics_are_noops() {
        // Finishing a timer that has no metrics sink must not panic.
        RequestTimer::new("test_tool", None).success();
        RequestTimer::new("test_tool", None).failure();
        HttpRequestTimer::new("GET", "docs.rs", None).finish(200);
    }

    #[test]
    fn test_http_request_timer() {
        let metrics = Arc::new(ServerMetrics::new());
        HttpRequestTimer::new("GET", "crates.io", Some(metrics.clone())).finish(404);

        let output = metrics.export().unwrap();
        assert!(output.contains("method=\"GET\""));
        assert!(output.contains("status=\"404\""));
        assert!(output.contains("host=\"crates.io\""));
    }

    #[test]
    fn test_active_connections() {
        let metrics = ServerMetrics::new();

        metrics.inc_active_connections();
        metrics.inc_active_connections();
        metrics.dec_active_connections();

        // Two increments and one decrement leave exactly one connection.
        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_active_connections 1"));
    }
}
415}