1use prometheus_client::encoding::text::encode;
6use prometheus_client::metrics::counter::Counter;
7use prometheus_client::metrics::family::Family;
8use prometheus_client::metrics::gauge::Gauge;
9use prometheus_client::metrics::histogram::{exponential_buckets, Histogram};
10use prometheus_client::registry::Registry;
11use std::sync::atomic::AtomicU64;
12use std::sync::Arc;
13use std::time::Instant;
14
15#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
17pub struct RequestLabels {
18 pub tool: String,
20 pub status: String,
22}
23
24#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
26pub struct CacheLabels {
27 pub operation: String,
29 pub cache_type: String,
31}
32
33#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
35pub struct HttpLabels {
36 pub method: String,
38 pub status: String,
40 pub host: String,
42}
43
44pub struct ServerMetrics {
46 request_counter: Family<RequestLabels, Counter>,
48 request_duration: Family<RequestLabels, Histogram>,
50 cache_counter: Family<CacheLabels, Counter>,
52 cache_hit_rate: Gauge<f64, AtomicU64>,
54 http_counter: Family<HttpLabels, Counter>,
56 http_duration: Family<HttpLabels, Histogram>,
58 active_connections: Gauge<u64, AtomicU64>,
60 error_counter: Family<RequestLabels, Counter>,
62 registry: Arc<Registry>,
64}
65
66impl ServerMetrics {
67 #[must_use]
69 pub fn new() -> Self {
70 let mut registry = Registry::default();
71
72 let request_counter = Family::<RequestLabels, Counter>::default();
74 registry.register(
75 "mcp_requests_total",
76 "Total number of MCP tool requests",
77 request_counter.clone(),
78 );
79
80 let request_duration = Family::<RequestLabels, Histogram>::new_with_constructor(|| {
82 Histogram::new(exponential_buckets(0.001, 2.0, 15))
83 });
84 registry.register(
85 "mcp_request_duration_seconds",
86 "MCP tool request duration in seconds",
87 request_duration.clone(),
88 );
89
90 let cache_counter = Family::<CacheLabels, Counter>::default();
92 registry.register(
93 "mcp_cache_operations_total",
94 "Total number of cache operations",
95 cache_counter.clone(),
96 );
97
98 let cache_hit_rate = Gauge::default();
100 registry.register(
101 "mcp_cache_hit_rate",
102 "Cache hit rate (0.0 to 1.0)",
103 cache_hit_rate.clone(),
104 );
105
106 let http_counter = Family::<HttpLabels, Counter>::default();
108 registry.register(
109 "mcp_http_requests_total",
110 "Total number of HTTP requests",
111 http_counter.clone(),
112 );
113
114 let http_duration = Family::<HttpLabels, Histogram>::new_with_constructor(|| {
116 Histogram::new(exponential_buckets(0.001, 2.0, 15))
117 });
118 registry.register(
119 "mcp_http_request_duration_seconds",
120 "HTTP request duration in seconds",
121 http_duration.clone(),
122 );
123
124 let active_connections = Gauge::<u64, AtomicU64>::default();
126 registry.register(
127 "mcp_active_connections",
128 "Number of active connections",
129 active_connections.clone(),
130 );
131
132 let error_counter = Family::<RequestLabels, Counter>::default();
134 registry.register(
135 "mcp_errors_total",
136 "Total number of errors",
137 error_counter.clone(),
138 );
139
140 Self {
141 request_counter,
142 request_duration,
143 cache_counter,
144 cache_hit_rate,
145 http_counter,
146 http_duration,
147 active_connections,
148 error_counter,
149 registry: Arc::new(registry),
150 }
151 }
152
153 pub fn record_request(&self, tool: &str, success: bool, duration: std::time::Duration) {
155 let labels = RequestLabels {
156 tool: tool.to_string(),
157 status: if success {
158 "success".to_string()
159 } else {
160 "error".to_string()
161 },
162 };
163
164 self.request_counter.get_or_create(&labels).inc();
165 self.request_duration
166 .get_or_create(&labels)
167 .observe(duration.as_secs_f64());
168
169 if !success {
170 self.error_counter.get_or_create(&labels).inc();
171 }
172 }
173
174 pub fn record_cache_operation(&self, operation: &str, cache_type: &str) {
176 let labels = CacheLabels {
177 operation: operation.to_string(),
178 cache_type: cache_type.to_string(),
179 };
180 self.cache_counter.get_or_create(&labels).inc();
181 }
182
183 pub fn record_cache_hit(&self, cache_type: &str) {
185 self.record_cache_operation("hit", cache_type);
186 }
187
188 pub fn record_cache_miss(&self, cache_type: &str) {
190 self.record_cache_operation("miss", cache_type);
191 }
192
193 #[allow(clippy::cast_precision_loss)]
195 pub fn update_cache_hit_rate(&self, hits: u64, misses: u64) {
196 let total = hits + misses;
197 if total > 0 {
198 let rate = hits as f64 / total as f64;
199 self.cache_hit_rate.set(rate);
200 }
201 }
202
203 pub fn record_http_request(
205 &self,
206 method: &str,
207 status: u16,
208 host: &str,
209 duration: std::time::Duration,
210 ) {
211 let labels = HttpLabels {
212 method: method.to_string(),
213 status: status.to_string(),
214 host: host.to_string(),
215 };
216
217 self.http_counter.get_or_create(&labels).inc();
218 self.http_duration
219 .get_or_create(&labels)
220 .observe(duration.as_secs_f64());
221 }
222
223 pub fn inc_active_connections(&self) {
225 self.active_connections.inc();
226 }
227
228 pub fn dec_active_connections(&self) {
230 self.active_connections.dec();
231 }
232
233 pub fn export(&self) -> crate::error::Result<String> {
239 let mut output = String::new();
240 encode(&mut output, self.registry.as_ref())
241 .map_err(|e| crate::error::Error::Other(format!("Failed to encode metrics: {e}")))?;
242 Ok(output)
243 }
244
245 #[must_use]
247 pub fn registry(&self) -> &Arc<Registry> {
248 &self.registry
249 }
250}
251
252impl Default for ServerMetrics {
253 fn default() -> Self {
254 Self::new()
255 }
256}
257
258pub struct RequestTimer {
260 start: Instant,
261 tool: String,
262 metrics: Option<Arc<ServerMetrics>>,
263}
264
265impl RequestTimer {
266 #[must_use]
268 pub fn new(tool: &str, metrics: Option<Arc<ServerMetrics>>) -> Self {
269 Self {
270 start: Instant::now(),
271 tool: tool.to_string(),
272 metrics,
273 }
274 }
275
276 pub fn success(self) {
278 self.record(true);
279 }
280
281 pub fn failure(self) {
283 self.record(false);
284 }
285
286 fn record(self, success: bool) {
287 if let Some(metrics) = self.metrics {
288 metrics.record_request(&self.tool, success, self.start.elapsed());
289 }
290 }
291}
292
293pub struct HttpRequestTimer {
295 start: Instant,
296 method: String,
297 host: String,
298 metrics: Option<Arc<ServerMetrics>>,
299}
300
301impl HttpRequestTimer {
302 #[must_use]
304 pub fn new(method: &str, host: &str, metrics: Option<Arc<ServerMetrics>>) -> Self {
305 Self {
306 start: Instant::now(),
307 method: method.to_string(),
308 host: host.to_string(),
309 metrics,
310 }
311 }
312
313 pub fn finish(self, status: u16) {
315 if let Some(metrics) = self.metrics {
316 metrics.record_http_request(&self.method, status, &self.host, self.start.elapsed());
317 }
318 }
319}
320
321use std::sync::OnceLock;
322
323static GLOBAL_METRICS: OnceLock<Arc<ServerMetrics>> = OnceLock::new();
325
326pub fn init_global_metrics() {
328 let _ = GLOBAL_METRICS.set(Arc::new(ServerMetrics::new()));
329}
330
331#[must_use]
337pub fn global_metrics() -> Arc<ServerMetrics> {
338 GLOBAL_METRICS
339 .get()
340 .cloned()
341 .expect("Global metrics not initialized")
342}
343
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_metrics_creation() {
        let metrics = ServerMetrics::new();
        let output = metrics.export();
        assert!(output.is_ok());
        let output = output.unwrap();
        assert!(!output.is_empty());
        // Gauges are exported even before anything is recorded.
        assert!(output.contains("mcp_active_connections"));
        assert!(output.contains("mcp_cache_hit_rate"));
    }

    #[test]
    fn test_request_recording() {
        let metrics = ServerMetrics::new();

        metrics.record_request("test_tool", true, Duration::from_millis(100));
        metrics.record_request("test_tool", false, Duration::from_millis(200));

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_requests_total"));
        assert!(output.contains(r#"tool="test_tool""#));
        assert!(output.contains(r#"status="success""#));
        assert!(output.contains(r#"status="error""#));
        // The failed request also feeds the error counter.
        assert!(output.contains("mcp_errors_total"));
        assert!(output.contains("mcp_request_duration_seconds"));
    }

    #[test]
    fn test_cache_metrics() {
        let metrics = ServerMetrics::new();

        metrics.record_cache_hit("memory");
        metrics.record_cache_miss("memory");
        metrics.update_cache_hit_rate(1, 1);

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_cache_operations_total"));
        assert!(output.contains(r#"operation="hit""#));
        assert!(output.contains(r#"operation="miss""#));
        assert!(output.contains(r#"cache_type="memory""#));
        assert!(output.contains("mcp_cache_hit_rate 0.5\n"));
    }

    #[test]
    fn test_http_metrics() {
        let metrics = ServerMetrics::new();

        metrics.record_http_request("GET", 200, "docs.rs", Duration::from_millis(500));

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_http_requests_total"));
        assert!(output.contains(r#"method="GET""#));
        assert!(output.contains(r#"status="200""#));
        assert!(output.contains(r#"host="docs.rs""#));
        assert!(output.contains("mcp_http_request_duration_seconds"));
    }

    #[test]
    fn test_request_timer() {
        let metrics = Arc::new(ServerMetrics::new());
        let timer = RequestTimer::new("test_tool", Some(metrics.clone()));
        timer.success();

        // A timer without a sink must be a silent no-op.
        RequestTimer::new("unused_tool", None).failure();

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_requests_total"));
        assert!(output.contains(r#"tool="test_tool""#));
        assert!(output.contains(r#"status="success""#));
        assert!(!output.contains("unused_tool"));
    }

    #[test]
    fn test_http_request_timer() {
        let metrics = Arc::new(ServerMetrics::new());
        let timer = HttpRequestTimer::new("POST", "example.com", Some(metrics.clone()));
        timer.finish(404);

        HttpRequestTimer::new("GET", "unused.example", None).finish(500);

        let output = metrics.export().unwrap();
        assert!(output.contains(r#"method="POST""#));
        assert!(output.contains(r#"status="404""#));
        assert!(output.contains(r#"host="example.com""#));
        assert!(!output.contains("unused.example"));
    }

    #[test]
    fn test_active_connections() {
        let metrics = ServerMetrics::new();

        metrics.inc_active_connections();
        metrics.inc_active_connections();
        metrics.dec_active_connections();

        let output = metrics.export().unwrap();
        assert!(output.contains("mcp_active_connections 1\n"));
    }
}