1use prometheus_client::encoding::text::encode;
6use prometheus_client::metrics::counter::Counter;
7use prometheus_client::metrics::family::Family;
8use prometheus_client::metrics::gauge::Gauge;
9use prometheus_client::metrics::histogram::{exponential_buckets, Histogram};
10use prometheus_client::registry::Registry;
11use std::sync::atomic::AtomicU64;
12use std::sync::Arc;
13use std::time::Instant;
14
15#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
17pub struct RequestLabels {
18 pub tool: String,
20 pub status: String,
22}
23
24#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
26pub struct CacheLabels {
27 pub operation: String,
29 pub cache_type: String,
31}
32
33#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
35pub struct HttpLabels {
36 pub method: String,
38 pub status: String,
40 pub host: String,
42}
43
44pub struct ServerMetrics {
46 request_counter: Family<RequestLabels, Counter>,
48 request_duration: Family<RequestLabels, Histogram>,
50 cache_counter: Family<CacheLabels, Counter>,
52 cache_hits: Gauge<u64, AtomicU64>,
54 cache_misses: Gauge<u64, AtomicU64>,
56 cache_sets: Gauge<u64, AtomicU64>,
58 cache_hit_rate: Gauge<f64, AtomicU64>,
60 http_counter: Family<HttpLabels, Counter>,
62 http_duration: Family<HttpLabels, Histogram>,
64 active_connections: Gauge<u64, AtomicU64>,
66 error_counter: Family<RequestLabels, Counter>,
68 registry: Arc<Registry>,
70}
71
72impl ServerMetrics {
73 #[must_use]
75 pub fn new() -> Self {
76 let mut registry = Registry::default();
77
78 let request_counter = Family::<RequestLabels, Counter>::default();
80 registry.register(
81 "mcp_requests_total",
82 "Total number of MCP tool requests",
83 request_counter.clone(),
84 );
85
86 let request_duration = Family::<RequestLabels, Histogram>::new_with_constructor(|| {
88 Histogram::new(exponential_buckets(0.001, 2.0, 15))
89 });
90 registry.register(
91 "mcp_request_duration_seconds",
92 "MCP tool request duration in seconds",
93 request_duration.clone(),
94 );
95
96 let cache_counter = Family::<CacheLabels, Counter>::default();
98 registry.register(
99 "mcp_cache_operations_total",
100 "Total number of cache operations",
101 cache_counter.clone(),
102 );
103
104 let cache_hits = Gauge::default();
106 registry.register(
107 "mcp_cache_hits",
108 "Number of cache hits (gauge)",
109 cache_hits.clone(),
110 );
111
112 let cache_misses = Gauge::default();
114 registry.register(
115 "mcp_cache_misses",
116 "Number of cache misses (gauge)",
117 cache_misses.clone(),
118 );
119
120 let cache_sets = Gauge::default();
122 registry.register(
123 "mcp_cache_sets",
124 "Number of cache set operations (gauge)",
125 cache_sets.clone(),
126 );
127
128 let cache_hit_rate = Gauge::default();
130 registry.register(
131 "mcp_cache_hit_rate",
132 "Cache hit rate (0.0 to 1.0)",
133 cache_hit_rate.clone(),
134 );
135
136 let http_counter = Family::<HttpLabels, Counter>::default();
138 registry.register(
139 "mcp_http_requests_total",
140 "Total number of HTTP requests",
141 http_counter.clone(),
142 );
143
144 let http_duration = Family::<HttpLabels, Histogram>::new_with_constructor(|| {
146 Histogram::new(exponential_buckets(0.001, 2.0, 15))
147 });
148 registry.register(
149 "mcp_http_request_duration_seconds",
150 "HTTP request duration in seconds",
151 http_duration.clone(),
152 );
153
154 let active_connections = Gauge::<u64, AtomicU64>::default();
156 registry.register(
157 "mcp_active_connections",
158 "Number of active connections",
159 active_connections.clone(),
160 );
161
162 let error_counter = Family::<RequestLabels, Counter>::default();
164 registry.register(
165 "mcp_errors_total",
166 "Total number of errors",
167 error_counter.clone(),
168 );
169
170 Self {
171 request_counter,
172 request_duration,
173 cache_counter,
174 cache_hits,
175 cache_misses,
176 cache_sets,
177 cache_hit_rate,
178 http_counter,
179 http_duration,
180 active_connections,
181 error_counter,
182 registry: Arc::new(registry),
183 }
184 }
185
186 pub fn record_request(&self, tool: &str, success: bool, duration: std::time::Duration) {
188 let labels = RequestLabels {
189 tool: tool.to_string(),
190 status: if success {
191 "success".to_string()
192 } else {
193 "error".to_string()
194 },
195 };
196
197 self.request_counter.get_or_create(&labels).inc();
198 self.request_duration
199 .get_or_create(&labels)
200 .observe(duration.as_secs_f64());
201
202 if !success {
203 self.error_counter.get_or_create(&labels).inc();
204 }
205 }
206
207 pub fn record_cache_operation(&self, operation: &str, cache_type: &str) {
209 let labels = CacheLabels {
210 operation: operation.to_string(),
211 cache_type: cache_type.to_string(),
212 };
213 self.cache_counter.get_or_create(&labels).inc();
214 }
215
216 pub fn record_cache_hit(&self, cache_type: &str) {
218 self.record_cache_operation("hit", cache_type);
219 }
220
221 pub fn record_cache_miss(&self, cache_type: &str) {
223 self.record_cache_operation("miss", cache_type);
224 }
225
226 #[allow(clippy::cast_precision_loss)]
228 pub fn update_cache_hit_rate(&self, hits: u64, misses: u64) {
229 let total = hits + misses;
230 if total > 0 {
231 let rate = hits as f64 / total as f64;
232 self.cache_hit_rate.set(rate);
233 }
234 }
235
236 pub fn update_cache_stats(&self, hits: u64, misses: u64, sets: u64) {
248 self.cache_hits.set(hits);
249 self.cache_misses.set(misses);
250 self.cache_sets.set(sets);
251 self.update_cache_hit_rate(hits, misses);
252 }
253
254 pub fn record_http_request(
256 &self,
257 method: &str,
258 status: u16,
259 host: &str,
260 duration: std::time::Duration,
261 ) {
262 let labels = HttpLabels {
263 method: method.to_string(),
264 status: status.to_string(),
265 host: host.to_string(),
266 };
267
268 self.http_counter.get_or_create(&labels).inc();
269 self.http_duration
270 .get_or_create(&labels)
271 .observe(duration.as_secs_f64());
272 }
273
274 pub fn inc_active_connections(&self) {
276 self.active_connections.inc();
277 }
278
279 pub fn dec_active_connections(&self) {
281 self.active_connections.dec();
282 }
283
284 pub fn export(&self) -> crate::error::Result<String> {
290 let mut output = String::new();
291 encode(&mut output, self.registry.as_ref())
292 .map_err(|e| crate::error::Error::Other(format!("Failed to encode metrics: {e}")))?;
293 Ok(output)
294 }
295
296 #[must_use]
298 pub fn registry(&self) -> &Arc<Registry> {
299 &self.registry
300 }
301}
302
303impl Default for ServerMetrics {
304 fn default() -> Self {
305 Self::new()
306 }
307}
308
309pub struct RequestTimer {
311 start: Instant,
312 tool: String,
313 metrics: Option<Arc<ServerMetrics>>,
314}
315
316impl RequestTimer {
317 #[must_use]
319 pub fn new(tool: &str, metrics: Option<Arc<ServerMetrics>>) -> Self {
320 Self {
321 start: Instant::now(),
322 tool: tool.to_string(),
323 metrics,
324 }
325 }
326
327 pub fn success(self) {
329 self.record(true);
330 }
331
332 pub fn failure(self) {
334 self.record(false);
335 }
336
337 fn record(self, success: bool) {
338 if let Some(metrics) = self.metrics {
339 metrics.record_request(&self.tool, success, self.start.elapsed());
340 }
341 }
342}
343
344pub struct HttpRequestTimer {
346 start: Instant,
347 method: String,
348 host: String,
349 metrics: Option<Arc<ServerMetrics>>,
350}
351
352impl HttpRequestTimer {
353 #[must_use]
355 pub fn new(method: &str, host: &str, metrics: Option<Arc<ServerMetrics>>) -> Self {
356 Self {
357 start: Instant::now(),
358 method: method.to_string(),
359 host: host.to_string(),
360 metrics,
361 }
362 }
363
364 pub fn finish(self, status: u16) {
366 if let Some(metrics) = self.metrics {
367 metrics.record_http_request(&self.method, status, &self.host, self.start.elapsed());
368 }
369 }
370}
371
372use std::sync::OnceLock;
373
374static GLOBAL_METRICS: OnceLock<Arc<ServerMetrics>> = OnceLock::new();
376
377pub fn init_global_metrics() {
379 let _ = GLOBAL_METRICS.set(Arc::new(ServerMetrics::new()));
380}
381
382#[must_use]
388pub fn global_metrics() -> Arc<ServerMetrics> {
389 GLOBAL_METRICS
390 .get()
391 .cloned()
392 .expect("Global metrics not initialized")
393}
394
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_metrics_creation() {
        let metrics = ServerMetrics::new();
        let output = metrics
            .export()
            .expect("encoding a fresh registry should succeed");
        assert!(!output.is_empty());
    }

    #[test]
    fn test_request_recording() {
        let metrics = ServerMetrics::new();

        metrics.record_request("test_tool", true, std::time::Duration::from_millis(100));
        metrics.record_request("test_tool", false, std::time::Duration::from_millis(200));

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_requests_total"));
        // One sample per (tool, status) label set, not just the metric header.
        assert!(output.contains(r#"tool="test_tool",status="success""#));
        assert!(output.contains(r#"tool="test_tool",status="error""#));
        // Failures are additionally counted in the error family.
        assert!(output.contains("mcp_errors_total"));
    }

    #[test]
    fn test_cache_metrics() {
        let metrics = ServerMetrics::new();

        metrics.record_cache_hit("memory");
        metrics.record_cache_miss("memory");
        metrics.update_cache_hit_rate(1, 1);

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_cache_operations_total"));
        assert!(output.contains(r#"operation="hit",cache_type="memory""#));
        assert!(output.contains(r#"operation="miss",cache_type="memory""#));
        assert!(output.contains("mcp_cache_hit_rate 0.5"));
    }

    #[test]
    fn test_cache_stats() {
        let metrics = ServerMetrics::new();

        metrics.update_cache_stats(3, 1, 2);

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_cache_hits 3"));
        assert!(output.contains("mcp_cache_misses 1"));
        assert!(output.contains("mcp_cache_sets 2"));
        assert!(output.contains("mcp_cache_hit_rate 0.75"));
    }

    #[test]
    fn test_http_metrics() {
        let metrics = ServerMetrics::new();

        metrics.record_http_request("GET", 200, "docs.rs", std::time::Duration::from_millis(500));

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_http_requests_total"));
        assert!(output.contains(r#"method="GET",status="200",host="docs.rs""#));
    }

    #[test]
    fn test_request_timer() {
        let metrics = Arc::new(ServerMetrics::new());
        let timer = RequestTimer::new("test_tool", Some(metrics.clone()));
        timer.success();

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_requests_total"));
        assert!(output.contains(r#"tool="test_tool",status="success""#));
    }

    #[test]
    fn test_timers_without_metrics_are_noops() {
        RequestTimer::new("test_tool", None).failure();
        HttpRequestTimer::new("GET", "docs.rs", None).finish(404);
    }

    #[test]
    fn test_active_connections() {
        let metrics = ServerMetrics::new();

        metrics.inc_active_connections();
        metrics.inc_active_connections();
        metrics.dec_active_connections();

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_active_connections 1"));
    }
}