use prometheus_client::encoding::text::encode;
use prometheus_client::metrics::counter::Counter;
use prometheus_client::metrics::family::Family;
use prometheus_client::metrics::gauge::Gauge;
use prometheus_client::metrics::histogram::{exponential_buckets, Histogram};
use prometheus_client::registry::Registry;
use std::sync::atomic::AtomicU64;
use std::sync::Arc;
use std::time::Instant;
/// Label set shared by the per-tool request metrics of `ServerMetrics`
/// (request counter, request-duration histogram and error counter).
#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
pub struct RequestLabels {
/// Name of the MCP tool that handled the request, as passed by the caller.
pub tool: String,
/// Request outcome; `ServerMetrics::record_request` sets `"success"` or `"error"`.
pub status: String,
}
/// Label set for the cache-operation counter of `ServerMetrics`.
#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
pub struct CacheLabels {
/// Kind of operation: `"hit"` / `"miss"` from the convenience helpers, or any
/// caller-supplied string passed to `ServerMetrics::record_cache_operation`.
pub operation: String,
/// Caller-supplied identifier of the cache (the tests use `"memory"`).
pub cache_type: String,
}
/// Label set for the HTTP request counter and duration histogram of `ServerMetrics`.
#[derive(Clone, Debug, Hash, PartialEq, Eq, prometheus_client::encoding::EncodeLabelSet)]
pub struct HttpLabels {
/// HTTP method string as given by the caller (e.g. `"GET"`).
pub method: String,
/// HTTP status code rendered in decimal (`u16::to_string` in `record_http_request`).
pub status: String,
/// Host string as given by the caller.
/// NOTE(review): every distinct host creates new time series — confirm the
/// set of hosts is bounded.
pub host: String,
}
/// Prometheus metrics for the MCP server.
///
/// Every metric is registered in one private [`Registry`], which
/// `ServerMetrics::export` encodes as text.
pub struct ServerMetrics {
/// Request count per (tool, status).
request_counter: Family<RequestLabels, Counter>,
/// Request latency per (tool, status), observed in seconds.
request_duration: Family<RequestLabels, Histogram>,
/// Cache operation count per (operation, cache_type).
cache_counter: Family<CacheLabels, Counter>,
/// Cache-hit total as last reported to `update_cache_stats` (overwritten each call).
cache_hits: Gauge<u64, AtomicU64>,
/// Cache-miss total as last reported to `update_cache_stats` (overwritten each call).
cache_misses: Gauge<u64, AtomicU64>,
/// Cache-set total as last reported to `update_cache_stats` (overwritten each call).
cache_sets: Gauge<u64, AtomicU64>,
/// hits / (hits + misses) from the most recent update with a non-zero total;
/// an f64 gauge whose value is held in an `AtomicU64`.
cache_hit_rate: Gauge<f64, AtomicU64>,
/// HTTP request count per (method, status, host).
http_counter: Family<HttpLabels, Counter>,
/// HTTP request latency per (method, status, host), observed in seconds.
http_duration: Family<HttpLabels, Histogram>,
/// Currently open connections, adjusted by `inc_active_connections` / `dec_active_connections`.
active_connections: Gauge<u64, AtomicU64>,
/// Failed request count per (tool, status); only incremented with `status = "error"`.
error_counter: Family<RequestLabels, Counter>,
/// Registry holding clones of every metric above.
registry: Arc<Registry>,
}
impl ServerMetrics {
#[must_use]
pub fn new() -> Self {
let mut registry = Registry::default();
let request_counter = Family::<RequestLabels, Counter>::default();
registry.register(
"mcp_requests_total",
"Total number of MCP tool requests",
request_counter.clone(),
);
let request_duration = Family::<RequestLabels, Histogram>::new_with_constructor(|| {
Histogram::new(exponential_buckets(0.001, 2.0, 15))
});
registry.register(
"mcp_request_duration_seconds",
"MCP tool request duration in seconds",
request_duration.clone(),
);
let cache_counter = Family::<CacheLabels, Counter>::default();
registry.register(
"mcp_cache_operations_total",
"Total number of cache operations",
cache_counter.clone(),
);
let cache_hits = Gauge::default();
registry.register(
"mcp_cache_hits",
"Number of cache hits (gauge)",
cache_hits.clone(),
);
let cache_misses = Gauge::default();
registry.register(
"mcp_cache_misses",
"Number of cache misses (gauge)",
cache_misses.clone(),
);
let cache_sets = Gauge::default();
registry.register(
"mcp_cache_sets",
"Number of cache set operations (gauge)",
cache_sets.clone(),
);
let cache_hit_rate = Gauge::default();
registry.register(
"mcp_cache_hit_rate",
"Cache hit rate (0.0 to 1.0)",
cache_hit_rate.clone(),
);
let http_counter = Family::<HttpLabels, Counter>::default();
registry.register(
"mcp_http_requests_total",
"Total number of HTTP requests",
http_counter.clone(),
);
let http_duration = Family::<HttpLabels, Histogram>::new_with_constructor(|| {
Histogram::new(exponential_buckets(0.001, 2.0, 15))
});
registry.register(
"mcp_http_request_duration_seconds",
"HTTP request duration in seconds",
http_duration.clone(),
);
let active_connections = Gauge::<u64, AtomicU64>::default();
registry.register(
"mcp_active_connections",
"Number of active connections",
active_connections.clone(),
);
let error_counter = Family::<RequestLabels, Counter>::default();
registry.register(
"mcp_errors_total",
"Total number of errors",
error_counter.clone(),
);
Self {
request_counter,
request_duration,
cache_counter,
cache_hits,
cache_misses,
cache_sets,
cache_hit_rate,
http_counter,
http_duration,
active_connections,
error_counter,
registry: Arc::new(registry),
}
}
pub fn record_request(&self, tool: &str, success: bool, duration: std::time::Duration) {
let labels = RequestLabels {
tool: tool.to_string(),
status: if success {
"success".to_string()
} else {
"error".to_string()
},
};
self.request_counter.get_or_create(&labels).inc();
self.request_duration
.get_or_create(&labels)
.observe(duration.as_secs_f64());
if !success {
self.error_counter.get_or_create(&labels).inc();
}
}
pub fn record_cache_operation(&self, operation: &str, cache_type: &str) {
let labels = CacheLabels {
operation: operation.to_string(),
cache_type: cache_type.to_string(),
};
self.cache_counter.get_or_create(&labels).inc();
}
pub fn record_cache_hit(&self, cache_type: &str) {
self.record_cache_operation("hit", cache_type);
}
pub fn record_cache_miss(&self, cache_type: &str) {
self.record_cache_operation("miss", cache_type);
}
#[allow(clippy::cast_precision_loss)]
pub fn update_cache_hit_rate(&self, hits: u64, misses: u64) {
let total = hits + misses;
if total > 0 {
let rate = hits as f64 / total as f64;
self.cache_hit_rate.set(rate);
}
}
pub fn update_cache_stats(&self, hits: u64, misses: u64, sets: u64) {
self.cache_hits.set(hits);
self.cache_misses.set(misses);
self.cache_sets.set(sets);
self.update_cache_hit_rate(hits, misses);
}
pub fn record_http_request(
&self,
method: &str,
status: u16,
host: &str,
duration: std::time::Duration,
) {
let labels = HttpLabels {
method: method.to_string(),
status: status.to_string(),
host: host.to_string(),
};
self.http_counter.get_or_create(&labels).inc();
self.http_duration
.get_or_create(&labels)
.observe(duration.as_secs_f64());
}
pub fn inc_active_connections(&self) {
self.active_connections.inc();
}
pub fn dec_active_connections(&self) {
self.active_connections.dec();
}
pub fn export(&self) -> crate::error::Result<String> {
let mut output = String::new();
encode(&mut output, self.registry.as_ref())
.map_err(|e| crate::error::Error::Other(format!("Failed to encode metrics: {e}")))?;
Ok(output)
}
#[must_use]
pub fn registry(&self) -> &Arc<Registry> {
&self.registry
}
}
impl Default for ServerMetrics {
fn default() -> Self {
Self::new()
}
}
/// Measures the wall-clock duration of one MCP tool request.
///
/// The clock starts in `new`. Only `success` or `failure` record anything;
/// both consume the timer. A timer that is simply dropped records nothing.
pub struct RequestTimer {
/// Instant captured when the timer was created.
start: Instant,
/// Tool name used as the `tool` label.
tool: String,
/// Where the measurement is reported; `None` turns the timer into a no-op.
metrics: Option<Arc<ServerMetrics>>,
}
impl RequestTimer {
    /// Starts timing a request handled by `tool`.
    ///
    /// The clock starts immediately. Pass `None` as `metrics` to get a timer
    /// that never records anything.
    #[must_use]
    pub fn new(tool: &str, metrics: Option<Arc<ServerMetrics>>) -> Self {
        let start = Instant::now();
        Self {
            start,
            tool: String::from(tool),
            metrics,
        }
    }

    /// Stops the timer and reports the request as successful.
    pub fn success(self) {
        self.record(true);
    }

    /// Stops the timer and reports the request as failed.
    pub fn failure(self) {
        self.record(false);
    }

    /// Shared finish path. When a metrics handle is present, it forwards the
    /// tool name, outcome and elapsed time to
    /// [`ServerMetrics::record_request`].
    fn record(self, success: bool) {
        let Some(sink) = self.metrics else {
            return;
        };
        let elapsed = self.start.elapsed();
        sink.record_request(&self.tool, success, elapsed);
    }
}
/// Measures the wall-clock duration of one HTTP request.
///
/// The clock starts in `new`. `finish` consumes the timer and reports the
/// request to `ServerMetrics::record_http_request`. A timer that is simply
/// dropped records nothing.
pub struct HttpRequestTimer {
/// Instant captured when the timer was created.
start: Instant,
/// HTTP method used as the `method` label (e.g. `"GET"`).
method: String,
/// Host used as the `host` label.
host: String,
/// Where the measurement is reported; `None` turns the timer into a no-op.
metrics: Option<Arc<ServerMetrics>>,
}
impl HttpRequestTimer {
    /// Starts timing an HTTP request with the given `method` and `host`.
    ///
    /// The clock starts immediately. Pass `None` as `metrics` to get a timer
    /// that never records anything.
    #[must_use]
    pub fn new(method: &str, host: &str, metrics: Option<Arc<ServerMetrics>>) -> Self {
        let start = Instant::now();
        Self {
            start,
            method: method.to_owned(),
            host: host.to_owned(),
            metrics,
        }
    }

    /// Stops the timer and reports the request with the response `status` code.
    pub fn finish(self, status: u16) {
        let Self {
            start,
            method,
            host,
            metrics,
        } = self;
        if let Some(sink) = metrics {
            sink.record_http_request(&method, status, &host, start.elapsed());
        }
    }
}
use std::sync::OnceLock;

/// Process-wide metrics instance, created at most once.
static GLOBAL_METRICS: OnceLock<Arc<ServerMetrics>> = OnceLock::new();

/// Returns the process-wide instance and creates it on the first call.
fn global_instance() -> &'static Arc<ServerMetrics> {
    GLOBAL_METRICS.get_or_init(|| Arc::new(ServerMetrics::new()))
}

/// Eagerly creates the process-wide [`ServerMetrics`].
///
/// This is idempotent. If the global instance already exists, from an
/// earlier call or from [`global_metrics`], it does nothing.
pub fn init_global_metrics() {
    global_instance();
}

/// Returns a shared handle to the process-wide [`ServerMetrics`].
///
/// The instance is created on first access if it does not exist yet, so
/// calling [`init_global_metrics`] beforehand is optional and this function
/// never panics.
#[must_use]
pub fn global_metrics() -> Arc<ServerMetrics> {
    Arc::clone(global_instance())
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Exporting a freshly built instance succeeds and yields some text.
    #[test]
    fn test_metrics_creation() {
        let exported = ServerMetrics::new().export();
        assert!(exported.is_ok());
        let text = exported.unwrap();
        assert!(!text.is_empty());
    }

    /// Both outcomes of a tool request appear in the export.
    #[test]
    fn test_request_recording() {
        let metrics = ServerMetrics::new();
        for (succeeded, millis) in [(true, 100), (false, 200)] {
            metrics.record_request("test_tool", succeeded, Duration::from_millis(millis));
        }
        let text = metrics.export().unwrap();
        for needle in ["mcp_requests_total", "test_tool"] {
            assert!(text.contains(needle));
        }
    }

    /// Cache hit/miss recording populates the cache-operations family.
    #[test]
    fn test_cache_metrics() {
        let metrics = ServerMetrics::new();
        metrics.record_cache_hit("memory");
        metrics.record_cache_miss("memory");
        metrics.update_cache_hit_rate(1, 1);
        let text = metrics.export().unwrap();
        assert!(text.contains("mcp_cache_operations_total"));
    }

    /// A single HTTP request shows up under the HTTP request counter.
    #[test]
    fn test_http_metrics() {
        let metrics = ServerMetrics::new();
        let elapsed = Duration::from_millis(500);
        metrics.record_http_request("GET", 200, "docs.rs", elapsed);
        assert!(metrics.export().unwrap().contains("mcp_http_requests_total"));
    }

    /// Finishing a `RequestTimer` reports through the shared metrics handle.
    #[test]
    fn test_request_timer() {
        let metrics = Arc::new(ServerMetrics::new());
        RequestTimer::new("test_tool", Some(Arc::clone(&metrics))).success();
        assert!(metrics.export().unwrap().contains("mcp_requests_total"));
    }

    /// The active-connection gauge is always exported.
    #[test]
    fn test_active_connections() {
        let metrics = ServerMetrics::new();
        (0..2).for_each(|_| metrics.inc_active_connections());
        metrics.dec_active_connections();
        assert!(metrics.export().unwrap().contains("mcp_active_connections"));
    }
}