Skip to main content

mcp_proxy/
metrics.rs

1//! Prometheus metrics middleware for the proxy.
2//!
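//! Records a per-request counter, labeled by MCP method and response
//! status, and a request-duration histogram, labeled by MCP method.
//! NOTE(review): no separate backend-namespace label is set in this file;
//! confirm whether one is expected.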
5
6use std::convert::Infallible;
7use std::future::Future;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use std::time::Instant;
11
12use metrics::{counter, histogram};
13use tower::Service;
14use tower_mcp::{RouterRequest, RouterResponse};
15
/// Tower service that records request metrics.
///
/// Wraps an inner service `S` and forwards every request to it unchanged.
/// For each request that runs to completion it increments the
/// `mcp_proxy_requests_total` counter and records the elapsed time in the
/// `mcp_proxy_request_duration_seconds` histogram.
///
/// `Debug` is derived so the wrapper can be printed in diagnostics whenever
/// the wrapped service is itself `Debug`.
#[derive(Debug, Clone)]
pub struct MetricsService<S> {
    /// The wrapped service that actually handles each request.
    inner: S,
}
21
22impl<S> MetricsService<S> {
23    /// Create a new metrics service wrapping `inner`.
24    pub fn new(inner: S) -> Self {
25        Self { inner }
26    }
27}
28
29impl<S> Service<RouterRequest> for MetricsService<S>
30where
31    S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
32        + Clone
33        + Send
34        + 'static,
35    S::Future: Send,
36{
37    type Response = RouterResponse;
38    type Error = Infallible;
39    type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
40
41    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
42        self.inner.poll_ready(cx)
43    }
44
45    fn call(&mut self, req: RouterRequest) -> Self::Future {
46        let method = req.inner.method_name().to_string();
47        let start = Instant::now();
48        let fut = self.inner.call(req);
49
50        Box::pin(async move {
51            let result = fut.await;
52            let duration = start.elapsed().as_secs_f64();
53
54            let status = match &result {
55                Ok(resp) => {
56                    if resp.inner.is_ok() {
57                        "ok"
58                    } else {
59                        "error"
60                    }
61                }
62                Err(_) => "error",
63            };
64
65            counter!("mcp_proxy_requests_total", "method" => method.clone(), "status" => status)
66                .increment(1);
67            histogram!(
68                "mcp_proxy_request_duration_seconds",
69                "method" => method,
70            )
71            .record(duration);
72
73            result
74        })
75    }
76}
77
#[cfg(test)]
mod tests {
    // These tests check that wrapping a service in `MetricsService` leaves
    // the response untouched. They do not inspect the recorded metric values.
    use tower_mcp::protocol::{CallToolParams, McpRequest};

    use super::MetricsService;
    use crate::test_util::{ErrorMockService, MockService, call_service};

    /// A `ListTools` request still succeeds when routed through the wrapper.
    #[tokio::test]
    async fn test_metrics_passes_through_request() {
        let mut wrapped = MetricsService::new(MockService::with_tools(&["tool"]));

        let request = McpRequest::ListTools(Default::default());
        let response = call_service(&mut wrapped, request).await;

        assert!(response.inner.is_ok());
    }

    /// A `CallTool` request for a known tool still succeeds through the wrapper.
    #[tokio::test]
    async fn test_metrics_passes_through_tool_call() {
        let mut wrapped = MetricsService::new(MockService::with_tools(&["tool"]));

        let params = CallToolParams {
            name: "tool".to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        };
        let response = call_service(&mut wrapped, McpRequest::CallTool(params)).await;

        assert!(response.inner.is_ok());
    }

    /// An error produced by the inner service is surfaced unchanged.
    #[tokio::test]
    async fn test_metrics_records_error_responses() {
        let mut wrapped = MetricsService::new(ErrorMockService);

        let request = McpRequest::ListTools(Default::default());
        let response = call_service(&mut wrapped, request).await;

        // ErrorMockService returns a JSON-RPC error; metrics should record
        // status="error". Only the pass-through of the error is asserted here.
        assert!(response.inner.is_err());
    }

    /// A `Ping` request succeeds even when the mock exposes no tools.
    #[tokio::test]
    async fn test_metrics_handles_ping() {
        let mut wrapped = MetricsService::new(MockService::with_tools(&[]));

        let response = call_service(&mut wrapped, McpRequest::Ping).await;

        assert!(response.inner.is_ok());
    }
}