Skip to main content

mcp_proxy/
metrics.rs

1//! Prometheus metrics middleware for the proxy.
2//!
//! Records a per-request counter, labeled by MCP method and response
//! status, and a request duration histogram, labeled by MCP method.
5
6use std::convert::Infallible;
7use std::future::Future;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use std::time::Instant;
11
12use metrics::{counter, histogram};
13use tower::{Layer, Service};
14use tower_mcp::{RouterRequest, RouterResponse};
15
/// Tower layer that produces a [`MetricsService`].
///
/// The layer carries no state, so it can be cloned freely and built with
/// either [`MetricsLayer::new`] or [`Default::default`].
#[derive(Clone, Debug, Default)]
pub struct MetricsLayer;

impl MetricsLayer {
    /// Create a new metrics layer.
    pub fn new() -> Self {
        Self
    }
}
26
27impl<S> Layer<S> for MetricsLayer {
28    type Service = MetricsService<S>;
29
30    fn layer(&self, inner: S) -> Self::Service {
31        MetricsService::new(inner)
32    }
33}
34
/// Tower service that records request metrics.
///
/// Wraps an inner service `S`. `Clone` and `Debug` are derived, so each is
/// available whenever `S` implements the corresponding trait.
#[derive(Clone, Debug)]
pub struct MetricsService<S> {
    /// The wrapped service.
    inner: S,
}

impl<S> MetricsService<S> {
    /// Create a new metrics service wrapping `inner`.
    pub fn new(inner: S) -> Self {
        Self { inner }
    }
}
47
48impl<S> Service<RouterRequest> for MetricsService<S>
49where
50    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
51        + Clone
52        + Send
53        + 'static,
54    S::Future: Send,
55{
56    type Response = RouterResponse;
57    type Error = Infallible;
58    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
59
60    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
61        self.inner.poll_ready(cx)
62    }
63
64    fn call(&mut self, req: RouterRequest) -> Self::Future {
65        let method = req.inner.method_name().to_string();
66        let start = Instant::now();
67        let fut = self.inner.call(req);
68
69        Box::pin(async move {
70            let result = fut.await;
71            let duration = start.elapsed().as_secs_f64();
72
73            let status = match &result {
74                Ok(resp) => {
75                    if resp.inner.is_ok() {
76                        "ok"
77                    } else {
78                        "error"
79                    }
80                }
81                Err(_) => "error",
82            };
83
84            counter!("mcp_proxy_requests_total", "method" => method.clone(), "status" => status)
85                .increment(1);
86            histogram!(
87                "mcp_proxy_request_duration_seconds",
88                "method" => method,
89            )
90            .record(duration);
91
92            result
93        })
94    }
95}
96
#[cfg(test)]
mod tests {
    //! Pass-through tests: the metrics wrapper must hand back whatever the
    //! wrapped service produced. These tests do not inspect recorded metrics.

    use tower_mcp::protocol::{CallToolParams, McpRequest};

    use super::MetricsService;
    use crate::test_util::{ErrorMockService, MockService, call_service};

    /// A `ListTools` request succeeds through the wrapper.
    #[tokio::test]
    async fn test_metrics_passes_through_request() {
        let mut wrapped = MetricsService::new(MockService::with_tools(&["tool"]));

        let request = McpRequest::ListTools(Default::default());
        let response = call_service(&mut wrapped, request).await;

        assert!(response.inner.is_ok());
    }

    /// A `CallTool` request for a known tool succeeds through the wrapper.
    #[tokio::test]
    async fn test_metrics_passes_through_tool_call() {
        let mut wrapped = MetricsService::new(MockService::with_tools(&["tool"]));

        let params = CallToolParams {
            name: "tool".to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        };
        let response = call_service(&mut wrapped, McpRequest::CallTool(params)).await;

        assert!(response.inner.is_ok());
    }

    /// An error response from the wrapped service is returned unchanged.
    #[tokio::test]
    async fn test_metrics_records_error_responses() {
        let mut wrapped = MetricsService::new(ErrorMockService);

        let request = McpRequest::ListTools(Default::default());
        let response = call_service(&mut wrapped, request).await;

        // The response from ErrorMockService is expected to be an error. Only
        // the pass-through is asserted here; the `status` label recorded by
        // the metrics layer is not checked.
        assert!(response.inner.is_err());
    }

    /// A `Ping` request succeeds even when the mock exposes no tools.
    #[tokio::test]
    async fn test_metrics_handles_ping() {
        let mut wrapped = MetricsService::new(MockService::with_tools(&[]));

        let response = call_service(&mut wrapped, McpRequest::Ping).await;

        assert!(response.inner.is_ok());
    }
}