Skip to main content

rmcp_server_kit/
metrics.rs

//! Prometheus metrics for MCP servers.
//!
//! Provides a shared [`crate::metrics::McpMetrics`] registry with a standard HTTP
//! request counter and request-duration histogram.
//! The transport layer exposes these via a `/metrics` endpoint on a
//! dedicated listener when `metrics_enabled` is true.
//!
//! # Public surface and the `prometheus` crate
//!
//! [`crate::metrics::McpMetrics::registry`] and the `IntCounterVec` / `HistogramVec` fields are
//! intentionally exposed so downstream crates can register additional custom
//! collectors against the same registry. This makes the [`prometheus`]
//! crate's types part of `rmcp-server-kit`'s public API; pin the same major version to
//! avoid type-identity mismatches when registering custom metrics.
14
15use std::sync::Arc;
16
17use prometheus::{
18    Encoder, HistogramOpts, HistogramVec, IntCounterVec, Registry, TextEncoder, opts,
19};
20
21use crate::error::McpxError;
22
/// Upper bounds, in seconds, of the default HTTP request-latency
/// histogram buckets.
///
/// The values run from one millisecond to five seconds in strictly
/// increasing order, tuned for low-latency service work: health-check
/// fast paths at the bottom, slow outbound dependencies at the top.
/// Operators that need a different layout can register their own
/// histogram against [`McpMetrics::registry`].
const HTTP_DURATION_BUCKETS: &[f64] = &[
    // 1 ms to 50 ms
    0.001, 0.005, 0.01, 0.025, 0.05,
    // 100 ms to 500 ms
    0.1, 0.25, 0.5,
    // 1 s to 5 s
    1.0, 2.5, 5.0,
];
32
33/// Collected Prometheus metrics for an MCP server.
34#[derive(Clone, Debug)]
35#[non_exhaustive]
36pub struct McpMetrics {
37    /// Prometheus registry holding all counters and histograms.
38    pub registry: Registry,
39    /// Total HTTP requests by method, path, and status code.
40    pub http_requests_total: IntCounterVec,
41    /// HTTP request duration in seconds by method and path.
42    pub http_request_duration_seconds: HistogramVec,
43}
44
45impl McpMetrics {
46    /// Create a new metrics registry with default MCP counters.
47    ///
48    /// # Errors
49    ///
50    /// Returns [`McpxError::Metrics`] if counter registration fails (should
51    /// not happen unless duplicate registrations occur).
52    pub fn new() -> Result<Self, McpxError> {
53        let registry = Registry::new();
54
55        let http_requests_total = IntCounterVec::new(
56            opts!("rmcp_server_kit_http_requests_total", "Total HTTP requests"),
57            &["method", "path", "status"],
58        )
59        .map_err(|e| McpxError::Metrics(e.to_string()))?;
60        registry
61            .register(Box::new(http_requests_total.clone()))
62            .map_err(|e| McpxError::Metrics(e.to_string()))?;
63
64        let http_request_duration_seconds = HistogramVec::new(
65            HistogramOpts::new(
66                "rmcp_server_kit_http_request_duration_seconds",
67                "HTTP request duration in seconds",
68            )
69            .buckets(HTTP_DURATION_BUCKETS.to_vec()),
70            &["method", "path"],
71        )
72        .map_err(|e| McpxError::Metrics(e.to_string()))?;
73        registry
74            .register(Box::new(http_request_duration_seconds.clone()))
75            .map_err(|e| McpxError::Metrics(e.to_string()))?;
76
77        Ok(Self {
78            registry,
79            http_requests_total,
80            http_request_duration_seconds,
81        })
82    }
83
84    /// Encode all collected metrics as Prometheus text format.
85    #[must_use]
86    pub fn encode(&self) -> String {
87        let encoder = TextEncoder::new();
88        let metric_families = self.registry.gather();
89        let mut buf = Vec::new();
90        if let Err(e) = encoder.encode(&metric_families, &mut buf) {
91            tracing::warn!(error = %e, "prometheus encode failed");
92            return String::new();
93        }
94        // TextEncoder always produces valid UTF-8; fall back to empty on
95        // the near-impossible chance it doesn't.
96        String::from_utf8(buf).unwrap_or_default()
97    }
98}
99
100/// Spawn a dedicated HTTP listener that serves Prometheus metrics on `/metrics`.
101///
102/// The listener exits and releases the bound port when `shutdown` is
103/// cancelled, keeping the metrics endpoint tied to the parent server's
104/// graceful-shutdown lifecycle (M7).
105///
106/// # Errors
107///
108/// Returns [`McpxError::Startup`] if the TCP listener cannot bind or the
109/// underlying axum server fails.
110pub async fn serve_metrics(
111    bind: String,
112    metrics: Arc<McpMetrics>,
113    shutdown: tokio_util::sync::CancellationToken,
114) -> Result<(), McpxError> {
115    let app = axum::Router::new().route(
116        "/metrics",
117        axum::routing::get(move || {
118            let m = Arc::clone(&metrics);
119            async move { m.encode() }
120        }),
121    );
122
123    let listener = tokio::net::TcpListener::bind(&bind)
124        .await
125        .map_err(|e| McpxError::Startup(format!("metrics bind {bind}: {e}")))?;
126    tracing::info!("metrics endpoint listening on http://{bind}/metrics");
127    axum::serve(listener, app)
128        .with_graceful_shutdown(async move { shutdown.cancelled().await })
129        .await
130        .map_err(|e| McpxError::Startup(format!("metrics serve: {e}")))?;
131    Ok(())
132}
133
#[cfg(test)]
mod tests {
    // Test code may unwrap, panic, index, and print freely.
    #![allow(
        clippy::unwrap_used,
        clippy::expect_used,
        clippy::panic,
        clippy::indexing_slicing,
        clippy::unwrap_in_result,
        clippy::print_stdout,
        clippy::print_stderr
    )]
    use super::*;

    /// Construct a fresh metrics set; construction must succeed in every test.
    fn fresh_metrics() -> McpMetrics {
        McpMetrics::new().unwrap()
    }

    #[test]
    fn new_creates_registry_with_counters() {
        let metrics = fresh_metrics();
        // A collector only shows up in `gather` once it has a sample,
        // so touch both of them first.
        let requests = metrics
            .http_requests_total
            .with_label_values(&["GET", "/test", "200"]);
        requests.inc();
        let latency = metrics
            .http_request_duration_seconds
            .with_label_values(&["GET", "/test"]);
        latency.observe(0.1);

        let families = metrics.registry.gather();
        assert_eq!(families.len(), 2);
    }

    #[test]
    fn encode_empty_registry() {
        let rendered = fresh_metrics().encode();
        // With no samples recorded the output is either empty or only
        // mentions the built-in metric names.
        assert!(rendered.is_empty() || rendered.contains("rmcp_server_kit_"));
    }

    #[test]
    fn counter_increment_shows_in_encode() {
        let metrics = fresh_metrics();
        metrics
            .http_requests_total
            .with_label_values(&["GET", "/healthz", "200"])
            .inc();

        let rendered = metrics.encode();
        assert!(rendered.contains("rmcp_server_kit_http_requests_total"));
        assert!(rendered.contains("method=\"GET\""));
        assert!(rendered.contains("path=\"/healthz\""));
        assert!(rendered.contains("status=\"200\""));
        assert!(rendered.contains(" 1")); // count = 1
    }

    #[test]
    fn histogram_observe_shows_in_encode() {
        let metrics = fresh_metrics();
        metrics
            .http_request_duration_seconds
            .with_label_values(&["POST", "/mcp"])
            .observe(0.042);

        let rendered = metrics.encode();
        assert!(rendered.contains("rmcp_server_kit_http_request_duration_seconds"));
        assert!(rendered.contains("method=\"POST\""));
        assert!(rendered.contains("path=\"/mcp\""));
    }

    #[test]
    fn multiple_increments_accumulate() {
        let metrics = fresh_metrics();
        let requests = metrics
            .http_requests_total
            .with_label_values(&["POST", "/mcp", "200"]);
        requests.inc();
        requests.inc();
        requests.inc();

        let rendered = metrics.encode();
        assert!(rendered.contains(" 3")); // count = 3
    }

    #[test]
    fn clone_shares_registry() {
        let original = fresh_metrics();
        let copy = original.clone();
        original
            .http_requests_total
            .with_label_values(&["GET", "/test", "200"])
            .inc();

        // The sample recorded through `original` must be visible via `copy`.
        let rendered = copy.encode();
        assert!(rendered.contains(" 1"));
    }

    // M7 regression: cancelling the shutdown token must release the
    // metrics listener's bound port so a subsequent bind to the same
    // address succeeds. Prior to M7 the metrics endpoint ran without
    // graceful_shutdown wiring and would leak the port until process
    // exit.
    #[tokio::test]
    async fn serve_metrics_releases_port_on_shutdown() {
        use std::time::{Duration, Instant};

        // Reserve an ephemeral port, then free it again so that
        // serve_metrics can bind the same address.
        let addr = {
            let probe = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
            probe.local_addr().unwrap()
        };

        let metrics = Arc::new(McpMetrics::new().unwrap());
        let shutdown = tokio_util::sync::CancellationToken::new();
        let server = tokio::spawn(serve_metrics(
            addr.to_string(),
            Arc::clone(&metrics),
            shutdown.clone(),
        ));

        // Poll until the listener accepts a connection, giving up after 2 s.
        let give_up_at = Instant::now() + Duration::from_secs(2);
        while tokio::net::TcpStream::connect(addr).await.is_err() {
            assert!(
                Instant::now() < give_up_at,
                "metrics listener never accepted on {addr}"
            );
            tokio::time::sleep(Duration::from_millis(20)).await;
        }

        // Trigger shutdown and wait for the server task to finish cleanly.
        shutdown.cancel();
        let joined = tokio::time::timeout(Duration::from_secs(5), server)
            .await
            .expect("serve_metrics did not return within timeout");
        let served = joined.expect("join error");
        served.expect("serve_metrics returned Err");

        // The port must be immediately rebindable.
        let rebound = tokio::net::TcpListener::bind(addr)
            .await
            .expect("port not released after shutdown");
        drop(rebound);
    }
}