1use std::convert::Infallible;
7use std::future::Future;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use std::time::Instant;
11
12use metrics::{counter, histogram};
13use tower::Service;
14use tower_mcp::{RouterRequest, RouterResponse};
15
/// Tower middleware that records per-request metrics for a wrapped MCP router service.
///
/// For every request passing through, it increments the `mcp_proxy_requests_total`
/// counter (labelled by `method` and `status`). It also records the request latency,
/// in seconds, in the `mcp_proxy_request_duration_seconds` histogram (labelled by
/// `method`). Requests and responses are forwarded unchanged.
#[derive(Clone, Debug)]
pub struct MetricsService<S> {
    /// The wrapped service that actually handles each request.
    inner: S,
}
21
22impl<S> MetricsService<S> {
23 pub fn new(inner: S) -> Self {
25 Self { inner }
26 }
27}
28
29impl<S> Service<RouterRequest> for MetricsService<S>
30where
31 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
32 + Clone
33 + Send
34 + 'static,
35 S::Future: Send,
36{
37 type Response = RouterResponse;
38 type Error = Infallible;
39 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
40
41 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
42 self.inner.poll_ready(cx)
43 }
44
45 fn call(&mut self, req: RouterRequest) -> Self::Future {
46 let method = req.inner.method_name().to_string();
47 let start = Instant::now();
48 let fut = self.inner.call(req);
49
50 Box::pin(async move {
51 let result = fut.await;
52 let duration = start.elapsed().as_secs_f64();
53
54 let status = match &result {
55 Ok(resp) => {
56 if resp.inner.is_ok() {
57 "ok"
58 } else {
59 "error"
60 }
61 }
62 Err(_) => "error",
63 };
64
65 counter!("mcp_proxy_requests_total", "method" => method.clone(), "status" => status)
66 .increment(1);
67 histogram!(
68 "mcp_proxy_request_duration_seconds",
69 "method" => method,
70 )
71 .record(duration);
72
73 result
74 })
75 }
76}
77
#[cfg(test)]
mod tests {
    use tower_mcp::protocol::{CallToolParams, McpRequest};

    use super::MetricsService;
    use crate::test_util::{ErrorMockService, MockService, call_service};

    // These tests check that the middleware is transparent: whatever the wrapped
    // service returns comes back unchanged. They do not install a metrics recorder,
    // so the recorded counter/histogram values themselves are not asserted.

    #[tokio::test]
    async fn test_metrics_passes_through_request() {
        let mut service = MetricsService::new(MockService::with_tools(&["tool"]));
        let request = McpRequest::ListTools(Default::default());

        let response = call_service(&mut service, request).await;

        assert!(response.inner.is_ok());
    }

    #[tokio::test]
    async fn test_metrics_passes_through_tool_call() {
        let mut service = MetricsService::new(MockService::with_tools(&["tool"]));
        let params = CallToolParams {
            name: "tool".to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        };

        let response = call_service(&mut service, McpRequest::CallTool(params)).await;

        assert!(response.inner.is_ok());
    }

    #[tokio::test]
    async fn test_metrics_records_error_responses() {
        // The error must surface in the response payload, not as a service error.
        let mut service = MetricsService::new(ErrorMockService);
        let request = McpRequest::ListTools(Default::default());

        let response = call_service(&mut service, request).await;

        assert!(response.inner.is_err());
    }

    #[tokio::test]
    async fn test_metrics_handles_ping() {
        let mut service = MetricsService::new(MockService::with_tools(&[]));

        let response = call_service(&mut service, McpRequest::Ping).await;

        assert!(response.inner.is_ok());
    }
}