use std::sync::Arc;
use prometheus::{
Encoder, HistogramOpts, HistogramVec, IntCounterVec, Registry, TextEncoder, opts,
};
use crate::error::McpxError;
/// Upper bounds (in seconds) of the buckets used by the
/// `rmcp_server_kit_http_request_duration_seconds` histogram.
///
/// Eleven strictly increasing bounds from 1 ms up to 5 s.
const HTTP_DURATION_BUCKETS: &[f64] = &[
    // 1 ms .. 5 ms
    1e-3, 5e-3,
    // 10 ms .. 50 ms
    1e-2, 2.5e-2, 5e-2,
    // 100 ms .. 500 ms
    1e-1, 2.5e-1, 5e-1,
    // 1 s .. 5 s
    1e0, 2.5e0, 5e0,
];
/// Prometheus metrics for the HTTP layer, bundled with the registry they are
/// registered on.
///
/// Construct with `McpMetrics::new`, which registers both collectors on
/// `registry`. Cloning yields another handle to the same registry and
/// collectors, so increments made through one clone are visible when
/// encoding through another.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct McpMetrics {
    /// Registry holding `http_requests_total` and `http_request_duration_seconds`;
    /// gather/encode this to expose the metrics.
    pub registry: Registry,
    /// Counter `rmcp_server_kit_http_requests_total`, labelled by
    /// `method`, `path`, `status` (in that order for `with_label_values`).
    pub http_requests_total: IntCounterVec,
    /// Histogram `rmcp_server_kit_http_request_duration_seconds`, labelled by
    /// `method`, `path` (in that order for `with_label_values`).
    pub http_request_duration_seconds: HistogramVec,
}
impl McpMetrics {
pub fn new() -> Result<Self, McpxError> {
let registry = Registry::new();
let http_requests_total = IntCounterVec::new(
opts!("rmcp_server_kit_http_requests_total", "Total HTTP requests"),
&["method", "path", "status"],
)
.map_err(|e| McpxError::Metrics(e.to_string()))?;
registry
.register(Box::new(http_requests_total.clone()))
.map_err(|e| McpxError::Metrics(e.to_string()))?;
let http_request_duration_seconds = HistogramVec::new(
HistogramOpts::new(
"rmcp_server_kit_http_request_duration_seconds",
"HTTP request duration in seconds",
)
.buckets(HTTP_DURATION_BUCKETS.to_vec()),
&["method", "path"],
)
.map_err(|e| McpxError::Metrics(e.to_string()))?;
registry
.register(Box::new(http_request_duration_seconds.clone()))
.map_err(|e| McpxError::Metrics(e.to_string()))?;
Ok(Self {
registry,
http_requests_total,
http_request_duration_seconds,
})
}
#[must_use]
pub fn encode(&self) -> String {
let encoder = TextEncoder::new();
let metric_families = self.registry.gather();
let mut buf = Vec::new();
if let Err(e) = encoder.encode(&metric_families, &mut buf) {
tracing::warn!(error = %e, "prometheus encode failed");
return String::new();
}
String::from_utf8(buf).unwrap_or_default()
}
}
/// Serves `metrics` in Prometheus text format at `GET /metrics`.
///
/// Binds a TCP listener on `bind` and runs an axum server on it. The log line
/// reports the address the listener actually bound. This matters when `bind`
/// asks for an ephemeral port (`…:0`); if the OS cannot report the local
/// address, the requested `bind` string is logged instead.
///
/// # Errors
///
/// Returns [`McpxError::Startup`] if the listener cannot be bound, or if the
/// server future resolves with an error.
pub async fn serve_metrics(bind: String, metrics: Arc<McpMetrics>) -> Result<(), McpxError> {
    let app = axum::Router::new().route(
        "/metrics",
        axum::routing::get(move || {
            // Hand each request future its own owned handle to the metrics.
            let m = Arc::clone(&metrics);
            async move { m.encode() }
        }),
    );
    let listener = tokio::net::TcpListener::bind(&bind)
        .await
        .map_err(|e| McpxError::Startup(format!("metrics bind {bind}: {e}")))?;
    // Prefer the OS-reported address so a `:0` bind logs the real port.
    let addr = listener
        .local_addr()
        .map_or(bind, |a| a.to_string());
    tracing::info!("metrics endpoint listening on http://{addr}/metrics");
    axum::serve(listener, app)
        .await
        .map_err(|e| McpxError::Startup(format!("metrics serve: {e}")))?;
    Ok(())
}
#[cfg(test)]
mod tests {
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr
    )]
    use super::*;

    /// True when `output` contains a line exactly equal to `expected`.
    /// Exact line matching avoids false positives such as `" 1"` matching a
    /// sample whose value is `10` or `13`.
    fn has_line(output: &str, expected: &str) -> bool {
        output.lines().any(|line| line == expected)
    }

    #[test]
    fn new_creates_registry_with_counters() {
        let m = McpMetrics::new().unwrap();
        m.http_requests_total
            .with_label_values(&["GET", "/test", "200"])
            .inc();
        m.http_request_duration_seconds
            .with_label_values(&["GET", "/test"])
            .observe(0.1);
        assert_eq!(m.registry.gather().len(), 2);
    }

    #[test]
    fn encode_empty_registry() {
        let m = McpMetrics::new().unwrap();
        let output = m.encode();
        assert!(output.is_empty() || output.contains("rmcp_server_kit_"));
    }

    #[test]
    fn counter_increment_shows_in_encode() {
        let m = McpMetrics::new().unwrap();
        m.http_requests_total
            .with_label_values(&["GET", "/healthz", "200"])
            .inc();
        let output = m.encode();
        assert!(
            has_line(
                &output,
                "rmcp_server_kit_http_requests_total{method=\"GET\",path=\"/healthz\",status=\"200\"} 1"
            ),
            "unexpected output:\n{output}"
        );
    }

    #[test]
    fn histogram_observe_shows_in_encode() {
        let m = McpMetrics::new().unwrap();
        m.http_request_duration_seconds
            .with_label_values(&["POST", "/mcp"])
            .observe(0.042);
        let output = m.encode();
        assert!(output.contains("rmcp_server_kit_http_request_duration_seconds"));
        assert!(
            has_line(
                &output,
                "rmcp_server_kit_http_request_duration_seconds_count{method=\"POST\",path=\"/mcp\"} 1"
            ),
            "unexpected output:\n{output}"
        );
    }

    #[test]
    fn multiple_increments_accumulate() {
        let m = McpMetrics::new().unwrap();
        let counter = m
            .http_requests_total
            .with_label_values(&["POST", "/mcp", "200"]);
        counter.inc();
        counter.inc();
        counter.inc();
        let output = m.encode();
        assert!(
            has_line(
                &output,
                "rmcp_server_kit_http_requests_total{method=\"POST\",path=\"/mcp\",status=\"200\"} 3"
            ),
            "unexpected output:\n{output}"
        );
    }

    #[test]
    fn clone_shares_registry() {
        let m = McpMetrics::new().unwrap();
        let m2 = m.clone();
        m.http_requests_total
            .with_label_values(&["GET", "/test", "200"])
            .inc();
        let output = m2.encode();
        assert!(
            has_line(
                &output,
                "rmcp_server_kit_http_requests_total{method=\"GET\",path=\"/test\",status=\"200\"} 1"
            ),
            "unexpected output:\n{output}"
        );
    }
}