1use std::convert::Infallible;
7use std::future::Future;
8use std::pin::Pin;
9use std::task::{Context, Poll};
10use std::time::Instant;
11
12use metrics::{counter, histogram};
13use tower::{Layer, Service};
14use tower_mcp::{RouterRequest, RouterResponse};
15
/// Tower [`Layer`] that wraps an inner service in a [`MetricsService`].
///
/// The layer itself holds no state; it only exists so the metrics
/// middleware can be added to a tower service stack.
///
/// `Debug` is derived so that service stacks containing this layer can be
/// formatted for diagnostics.
#[derive(Debug, Clone, Default)]
pub struct MetricsLayer;
19
20impl MetricsLayer {
21 pub fn new() -> Self {
23 Self
24 }
25}
26
27impl<S> Layer<S> for MetricsLayer {
28 type Service = MetricsService<S>;
29
30 fn layer(&self, inner: S) -> Self::Service {
31 MetricsService::new(inner)
32 }
33}
34
/// Service wrapper that records request metrics around an inner service.
///
/// For every request handled through its `Service<RouterRequest>`
/// implementation it emits a request counter and a duration histogram via
/// the `metrics` crate macros, then returns the inner response unchanged.
///
/// `Debug` is derived (available whenever `S: Debug`) so that composed
/// service stacks can be formatted for diagnostics.
#[derive(Debug, Clone)]
pub struct MetricsService<S> {
    /// The wrapped service that actually handles each request.
    inner: S,
}
40
41impl<S> MetricsService<S> {
42 pub fn new(inner: S) -> Self {
44 Self { inner }
45 }
46}
47
48impl<S> Service<RouterRequest> for MetricsService<S>
49where
50 S: Service<RouterRequest, Response = RouterResponse, Error = Infallible>
51 + Clone
52 + Send
53 + 'static,
54 S::Future: Send,
55{
56 type Response = RouterResponse;
57 type Error = Infallible;
58 type Future = Pin<Box<dyn Future<Output = Result<RouterResponse, Infallible>> + Send>>;
59
60 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
61 self.inner.poll_ready(cx)
62 }
63
64 fn call(&mut self, req: RouterRequest) -> Self::Future {
65 let method = req.inner.method_name().to_string();
66 let start = Instant::now();
67 let fut = self.inner.call(req);
68
69 Box::pin(async move {
70 let result = fut.await;
71 let duration = start.elapsed().as_secs_f64();
72
73 let status = match &result {
74 Ok(resp) => {
75 if resp.inner.is_ok() {
76 "ok"
77 } else {
78 "error"
79 }
80 }
81 Err(_) => "error",
82 };
83
84 counter!("mcp_proxy_requests_total", "method" => method.clone(), "status" => status)
85 .increment(1);
86 histogram!(
87 "mcp_proxy_request_duration_seconds",
88 "method" => method,
89 )
90 .record(duration);
91
92 result
93 })
94 }
95}
96
#[cfg(test)]
mod tests {
    //! Pass-through tests: the metrics wrapper must hand back whatever the
    //! inner mock service produces, for both success and error responses.

    use tower_mcp::protocol::{CallToolParams, McpRequest};

    use super::MetricsService;
    use crate::test_util::{ErrorMockService, MockService, call_service};

    /// A list-tools request answered successfully by the mock stays successful.
    #[tokio::test]
    async fn test_metrics_passes_through_request() {
        let mut service = MetricsService::new(MockService::with_tools(&["tool"]));

        let request = McpRequest::ListTools(Default::default());
        let response = call_service(&mut service, request).await;

        assert!(response.inner.is_ok());
    }

    /// A tool call for a tool the mock knows about stays successful.
    #[tokio::test]
    async fn test_metrics_passes_through_tool_call() {
        let mut service = MetricsService::new(MockService::with_tools(&["tool"]));

        let params = CallToolParams {
            name: "tool".to_string(),
            arguments: serde_json::json!({}),
            meta: None,
            task: None,
        };
        let response = call_service(&mut service, McpRequest::CallTool(params)).await;

        assert!(response.inner.is_ok());
    }

    /// An error response from the inner service is surfaced unchanged.
    #[tokio::test]
    async fn test_metrics_records_error_responses() {
        let mut service = MetricsService::new(ErrorMockService);

        let request = McpRequest::ListTools(Default::default());
        let response = call_service(&mut service, request).await;

        assert!(response.inner.is_err());
    }

    /// A ping succeeds even when the mock exposes no tools.
    #[tokio::test]
    async fn test_metrics_handles_ping() {
        let mut service = MetricsService::new(MockService::with_tools(&[]));

        let response = call_service(&mut service, McpRequest::Ping).await;

        assert!(response.inner.is_ok());
    }
}